Tracking methods usually adopt a multi-stage pipeline of feature extraction, target-information integration, and bounding-box estimation. To simplify this pipeline and unify feature extraction with target-information integration, we propose a compact tracking framework called MixFormer. Its core design exploits the flexibility of attention operations and proposes a Mixed Attention Module (MAM) for simultaneous feature extraction and target-information integration. This synchronous modeling scheme extracts target-specific discriminative features and allows extensive communication between the target and the search region. Based on MAM, we build the MixFormer tracking framework simply by stacking multiple MAMs with progressive patch embedding and placing a localization head on top. In addition, to handle multiple target templates during online tracking, we devise an asymmetric attention scheme to reduce computational cost and propose an effective score prediction module to select high-quality templates. Our MixFormer sets a new state of the art on five tracking benchmarks, including LaSOT, TrackingNet, VOT2020, GOT-10k and UAV123. In particular, MixFormer-L achieves an NP score of 79.9% on LaSOT and 88.9% on TrackingNet, and an EAO of 0.555 on VOT2020. We also conduct in-depth ablation studies to demonstrate the effectiveness of simultaneous feature extraction and information integration.

https://zhuanlan.zhihu.com/p/485189978?utm_psn=1823159977482780672

https://arxiv.org/pdf/2203.11082

Problems with previous tracking frameworks: (1) they are assembled from multiple separate components; (2) CNN-based methods lack global modeling capability; (3) Transformer-based methods still rely on a CNN for feature extraction and only apply attention on top of the high-level features. To overcome these issues, the authors unify feature extraction and target-information integration. First, feature extraction becomes specific to the tracked target, producing more target-discriminative features. Second, target information is integrated into the search region more extensively, which better captures the correlation between the two. As a result, they obtain a more compact and elegant tracking framework that requires no explicit integration module.
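
To make this mixed-attention idea concrete, here is a minimal sketch (my own illustration with arbitrary sizes, using a plain nn.MultiheadAttention rather than the actual MAM implemented below): template and search tokens are concatenated into one sequence, and a single attention layer lets every token attend to both regions, so feature extraction and target-search integration happen in the same operation.

import torch
import torch.nn as nn

# toy sizes: 64 template tokens, 400 search tokens, 384-dim embeddings (all assumed)
mixed_attn = nn.MultiheadAttention(embed_dim=384, num_heads=6, batch_first=True)
template_tokens = torch.randn(1, 64, 384)    # flattened template feature map
search_tokens = torch.randn(1, 400, 384)     # flattened search-region feature map

tokens = torch.cat([template_tokens, search_tokens], dim=1)    # one mixed sequence
out, _ = mixed_attn(tokens, tokens, tokens)                    # every token attends to both regions
template_out, search_out = torch.split(out, [64, 400], dim=1)  # split back if needed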



Code for the MAM module (the imports collected at the top are the ones this listing needs; FrozenBatchNorm2d and is_main_process come from the MixFormer repo's own utility modules)

import logging
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

# FrozenBatchNorm2d and is_main_process are provided by the MixFormer repo's own
# utility modules and are assumed to be importable here.


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)

class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)

class Mlp(nn.Module):
    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class Attention(nn.Module):  
    def __init__(self,  
                 dim_in,  
                 dim_out,  
                 num_heads,  
                 qkv_bias=False,  
                 attn_drop=0.,  
                 proj_drop=0.,  
                 method='dw_bn',  
                 kernel_size=3,  
                 stride_kv=1,  
                 stride_q=1,  
                 padding_kv=1,  
                 padding_q=1,  
                 with_cls_token=True,  
                 freeze_bn=False,  
                 **kwargs  
                 ):  
        """  
        初始化 Attention 模块。  

        :param dim_in: 输入通道维度。  
        :param dim_out: 输出通道维度。  
        :param num_heads: 注意力头的数量。  
        :param qkv_bias: 是否在 q, k, v 的线性投影中使用偏置。  
        :param attn_drop: 注意力后的 dropout 概率。  
        :param proj_drop: 投影后的 dropout 概率。  
        :param method: 投影的卷积方式,比如 'dw_bn'。  
        :param kernel_size: 卷积核大小。  
        :param stride_kv: 键和值的卷积步幅。  
        :param stride_q: 查询的卷积步幅。  
        :param padding_kv: 键和值的填充。  
        :param padding_q: 查询的填充。  
        :param with_cls_token: 是否包含 class token。  
        :param freeze_bn: 是否冻结批归一化层。  
        """  
        super().__init__()  
        self.stride_kv = stride_kv  
        self.stride_q = stride_q  
        self.dim = dim_out  
        self.num_heads = num_heads  
        self.scale = dim_out ** -0.5  # scaling factor for the dot-product attention
        self.with_cls_token = with_cls_token

        # choose the batch norm used after the convolutional projections (frozen or regular)
        if freeze_bn:
            conv_proj_post_norm = FrozenBatchNorm2d  # frozen BN implementation
        else:
            conv_proj_post_norm = nn.BatchNorm2d  # regular BN

        # convolutional projections for query, key and value
        self.conv_proj_q = self._build_projection(dim_in, dim_out, kernel_size, padding_q, stride_q,
                                                  'linear' if method == 'avg' else method, conv_proj_post_norm)
        self.conv_proj_k = self._build_projection(dim_in, dim_out, kernel_size, padding_kv, stride_kv, method,
                                                  conv_proj_post_norm)
        self.conv_proj_v = self._build_projection(dim_in, dim_out, kernel_size, padding_kv, stride_kv, method,
                                                  conv_proj_post_norm)

        # linear projections for query, key and value
        self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias)

        # dropout on the attention weights and after the output projection
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim_out, dim_out)
        self.proj_drop = nn.Dropout(proj_drop)

    def _build_projection(self,  
                          dim_in,  
                          dim_out,  
                          kernel_size,  
                          padding,  
                          stride,  
                          method,  
                          norm):  
        """  
        创建卷积投影。  

        :param dim_in: 输入通道维度。  
        :param dim_out: 输出通道维度。  
        :param kernel_size: 卷积核大小。  
        :param padding: 填充。  
        :param stride: 步幅。  
        :param method: 卷积类型 ('dw_bn', 'avg', 'linear')。  
        :param norm: 归一化方法。  
        :return: 投影操作序列。  
        """  
        if method == 'dw_bn':  
            # 深度可分离卷积加批归一化  
            proj = nn.Sequential(OrderedDict([  
                ('conv', nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding, stride=stride,  
                                   bias=False, groups=dim_in)),  
                ('bn', norm(dim_in)),  
                ('rearrage', Rearrange('b c h w -> b (h w) c')),  # 转换张量形状  
            ]))  
        elif method == 'avg':  
            # 平均池化(模拟卷积)  
            proj = nn.Sequential(OrderedDict([  
                ('avg', nn.AvgPool2d(kernel_size=kernel_size, padding=padding, stride=stride, ceil_mode=True)),  
                ('rearrage', Rearrange('b c h w -> b (h w) c')),  # 转换张量形状  
            ]))  
        elif method == 'linear':  
            # 线性层跳过卷积,直接处理(需在调用时处理)  
            proj = None  
        else:  
            raise ValueError('Unknown method ({})'.format(method))  
        return proj  

    def forward_conv(self, x, t_h, t_w, s_h, s_w):
        """
        Convolutional projection used during training to produce the q, k, v feature maps.

        :param x: input tensor, the concatenation of template, online template and search tokens.
        :param t_h: template height.
        :param t_w: template width.
        :param s_h: search-region height.
        :param s_w: search-region width.
        :return: q, k, v tensors.
        """
        # split x into template, online_template and search according to their token counts
        template, online_template, search = torch.split(x, [t_h * t_w, t_h * t_w, s_h * s_w], dim=1)

        # reshape back to 2D maps so the convolutional projections can be applied
        template = rearrange(template, 'b (h w) c -> b c h w', h=t_h, w=t_w).contiguous()
        online_template = rearrange(online_template, 'b (h w) c -> b c h w', h=t_h, w=t_w).contiguous()
        search = rearrange(search, 'b (h w) c -> b c h w', h=s_h, w=s_w).contiguous()

        # apply the convolutional projection if it exists, otherwise just flatten the maps
        if self.conv_proj_q is not None:  
            t_q = self.conv_proj_q(template)  
            ot_q = self.conv_proj_q(online_template)  
            s_q = self.conv_proj_q(search)  
            q = torch.cat([t_q, ot_q, s_q], dim=1)  
        else:  
            t_q = rearrange(template, 'b c h w -> b (h w) c').contiguous()  
            ot_q = rearrange(online_template, 'b c h w -> b (h w) c').contiguous()  
            s_q = rearrange(search, 'b c h w -> b (h w) c').contiguous()  
            q = torch.cat([t_q, ot_q, s_q], dim=1)  

        if self.conv_proj_k is not None:  
            t_k = self.conv_proj_k(template)  
            ot_k = self.conv_proj_k(online_template)  
            s_k = self.conv_proj_k(search)  
            k = torch.cat([t_k, ot_k, s_k], dim=1)  
        else:  
            t_k = rearrange(template, 'b c h w -> b (h w) c').contiguous()  
            ot_k = rearrange(online_template, 'b c h w -> b (h w) c').contiguous()  
            s_k = rearrange(search, 'b c h w -> b (h w) c').contiguous()  
            k = torch.cat([t_k, ot_k, s_k], dim=1)  

        if self.conv_proj_v is not None:  
            t_v = self.conv_proj_v(template)  
            ot_v = self.conv_proj_v(online_template)  
            s_v = self.conv_proj_v(search)  
            v = torch.cat([t_v, ot_v, s_v], dim=1)  
        else:  
            t_v = rearrange(template, 'b c h w -> b (h w) c').contiguous()  
            ot_v = rearrange(online_template, 'b c h w -> b (h w) c').contiguous()  
            s_v = rearrange(search, 'b c h w -> b (h w) c').contiguous()  
            v = torch.cat([t_v, ot_v, s_v], dim=1)  

        return q, k, v  

    def forward_conv_test(self, x, s_h, s_w):
        """
        Convolutional projection used at test time, where only the search region is passed in;
        the cached template keys/values (prepared in set_online) are concatenated afterwards.

        :param x: search-region tensor.
        :param s_h: search-region height.
        :param s_w: search-region width.
        :return: q, k, v tensors.
        """
        search = x  
        search = rearrange(search, 'b (h w) c -> b c h w', h=s_h, w=s_w).contiguous()  

        if self.conv_proj_q is not None:  
            q = self.conv_proj_q(search)  
        else:  
            q = rearrange(search, 'b c h w -> b (h w) c').contiguous()  

        if self.conv_proj_k is not None:  
            k = self.conv_proj_k(search)  
        else:  
            k = rearrange(search, 'b c h w -> b (h w) c').contiguous()  
        k = torch.cat([self.t_k, self.ot_k, k], dim=1)  

        if self.conv_proj_v is not None:  
            v = self.conv_proj_v(search)  
        else:  
            v = rearrange(search, 'b c h w -> b (h w) c').contiguous()  
        v = torch.cat([self.t_v, self.ot_v, v], dim=1)  

        return q, k, v  

    def forward(self, x, t_h, t_w, s_h, s_w):
        """
        Forward pass of the asymmetric mixed attention (training mode).

        :param x: input tensor.
        :param t_h: template height.
        :param t_w: template width.
        :param s_h: search-region height.
        :param s_w: search-region width.
        :return: output of the attention operation.
        """
        if (  
            self.conv_proj_q is not None  
            or self.conv_proj_k is not None  
            or self.conv_proj_v is not None  
        ):  
            q, k, v = self.forward_conv(x, t_h, t_w, s_h, s_w)  

        # reshape q, k, v into multi-head form
        q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()  
        k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()  
        v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()  

        ### Note: the k/v tokens are spatially compressed to 1/4 of the query resolution (conv stride 2)
        q_mt, q_s = torch.split(q, [t_h * t_w * 2, s_h * s_w], dim=2)
        k_mt, k_s = torch.split(k, [((t_h + 1) // 2) ** 2 * 2, s_h * s_w // 4], dim=2)
        v_mt, v_s = torch.split(v, [((t_h + 1) // 2) ** 2 * 2, s_h * s_w // 4], dim=2)

        # template attention: template queries attend only to template keys/values (asymmetric scheme)
        attn_score = torch.einsum('bhlk,bhtk->bhlt', [q_mt, k_mt]) * self.scale  
        attn = F.softmax(attn_score, dim=-1)  
        attn = self.attn_drop(attn)  
        x_mt = torch.einsum('bhlt,bhtv->bhlv', [attn, v_mt])  
        x_mt = rearrange(x_mt, 'b h t d -> b t (h d)')  

        # search-region attention: search queries attend to all keys/values
        attn_score = torch.einsum('bhlk,bhtk->bhlt', [q_s, k]) * self.scale  
        attn = F.softmax(attn_score, dim=-1)  
        attn = self.attn_drop(attn)  
        x_s = torch.einsum('bhlt,bhtv->bhlv', [attn, v])  
        x_s = rearrange(x_s, 'b h t d -> b t (h d)')  

        x = torch.cat([x_mt, x_s], dim=1)  

        # output projection followed by dropout
        x = self.proj(x)  
        x = self.proj_drop(x)  

        return x  

    def forward_test(self, x, s_h, s_w):
        """
        Forward pass used at test time, when only the search region is fed through the network
        and the template keys/values are reused from the cache.

        :param x: search-region tensor.
        :param s_h: search-region height.
        :param s_w: search-region width.
        :return: output of the attention operation.
        """
        if (  
                self.conv_proj_q is not None  
                or self.conv_proj_k is not None  
                or self.conv_proj_v is not None  
        ):  
            q_s, k, v = self.forward_conv_test(x, s_h, s_w)  

        q_s = rearrange(self.proj_q(q_s), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()  
        k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()  
        v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()  

        attn_score = torch.einsum('bhlk,bhtk->bhlt', [q_s, k]) * self.scale  
        attn = F.softmax(attn_score, dim=-1)  
        attn = self.attn_drop(attn)  
        x_s = torch.einsum('bhlt,bhtv->bhlv', [attn, v])  
        x_s = rearrange(x_s, 'b h t d -> b t (h d)').contiguous()  

        x = x_s  

        x = self.proj(x)  
        x = self.proj_drop(x)  

        return x  

    def set_online(self, x, t_h, t_w):
        """
        Online-mode setup: pre-compute and cache the keys/values of the template and the
        online template(s), so test-time forward passes only need to process the search region.

        :param x: input tensor containing the template and the online template(s).
        :param t_h: template height.
        :param t_w: template width.
        :return: attended template features.
        """
        template = x[:, :t_h * t_w]  # (1, t_h*t_w, c), e.g. (1, 1024, c)
        online_template = x[:, t_h * t_w:]  # (1, b*t_h*t_w, c)
        template = rearrange(template, 'b (h w) c -> b c h w', h=t_h, w=t_w).contiguous()
        online_template = rearrange(online_template.squeeze(0), '(b h w) c -> b c h w', h=t_h, w=t_w).contiguous()  # (b, c, t_h, t_w)

        if self.conv_proj_q is not None:  
            t_q = self.conv_proj_q(template)  
            ot_q = self.conv_proj_q(online_template).flatten(end_dim=1).unsqueeze(0)  
        else:  
            t_q = rearrange(template, 'b c h w -> b (h w) c').contiguous()  
            ot_q = rearrange(online_template, 'b c h w -> (b h w) c').contiguous().unsqueeze(0)  
        q = torch.cat([t_q, ot_q], dim=1)  

        if self.conv_proj_k is not None:  
            self.t_k = self.conv_proj_k(template)  
            self.ot_k = self.conv_proj_k(online_template).flatten(end_dim=1).unsqueeze(0)  
        else:  
            self.t_k = rearrange(template, 'b c h w -> b (h w) c').contiguous()  
            self.ot_k = rearrange(online_template, 'b c h w -> (b h w) c').contiguous().unsqueeze(0)  
        k = torch.cat([self.t_k, self.ot_k], dim=1)  

        if self.conv_proj_v is not None:  
            self.t_v = self.conv_proj_v(template)  
            self.ot_v = self.conv_proj_v(online_template).flatten(end_dim=1).unsqueeze(0)  
        else:  
            self.t_v = rearrange(template, 'b c h w -> b (h w) c').contiguous()  
            self.ot_v = rearrange(online_template, 'b c h w -> (b h w) c').contiguous().unsqueeze(0)  
        v = torch.cat([self.t_v, self.ot_v], dim=1)  

        # reshape q, k, v into multi-head form for the attention computation
        q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()  
        k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()  
        v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()  

        attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale  
        attn = F.softmax(attn_score, dim=-1)  
        attn = self.attn_drop(attn)  
        x = torch.einsum('bhlt,bhtv->bhlv', [attn, v])  
        x = rearrange(x, 'b h t d -> b t (h d)').contiguous()  

        x = self.proj(x)  
        x = self.proj_drop(x)  

        return x
class Block(nn.Module):  
    """  
    Transformer Block Module  

    This class defines a single block of a Transformer model, which includes  
    attention mechanisms and a feed-forward neural network (MLP).  
    """  

    def __init__(self,  
                 dim_in,  
                 dim_out,  
                 num_heads,  
                 mlp_ratio=4.,  
                 qkv_bias=False,  
                 drop=0.,  
                 attn_drop=0.,  
                 drop_path=0.,  
                 act_layer=nn.GELU,  
                 norm_layer=nn.LayerNorm,  
                 freeze_bn=False,  
                 **kwargs):  
        """  
        Initialize the Transformer Block.  

        Parameters:  
        - dim_in: int, the input dimension size.  
        - dim_out: int, the output dimension size.  
        - num_heads: int, number of attention heads.  
        - mlp_ratio: float, ratio to determine the hidden layer size in MLP.  
        - qkv_bias: bool, whether to use bias in query, key, value projections.  
        - drop: float, dropout rate for MLP.  
        - attn_drop: float, dropout rate for attention.  
        - drop_path: float, drop path rate for stochastic depth.  
        - act_layer: activation layer, default is GELU.  
        - norm_layer: normalization layer, default is LayerNorm.  
        - freeze_bn: bool, whether to freeze batch normalization layers.  
        - kwargs: additional keyword arguments.  
        """  
        super().__init__()  

        # Whether to use class token  
        self.with_cls_token = kwargs['with_cls_token']  

        # First normalization layer  
        self.norm1 = norm_layer(dim_in)  

        # Attention layer  
        self.attn = Attention(  
            dim_in, dim_out, num_heads, qkv_bias, attn_drop, drop, freeze_bn=freeze_bn,  
            **kwargs  
        )  

        # Drop path for stochastic depth regularization  
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()  

        # Second normalization layer  
        self.norm2 = norm_layer(dim_out)  

        # Define the hidden dimension size for the MLP  
        dim_mlp_hidden = int(dim_out * mlp_ratio)  

        # MLP layer  
        self.mlp = Mlp(  
            in_features=dim_out,  
            hidden_features=dim_mlp_hidden,  
            act_layer=act_layer,  
            drop=drop  
        )  

    def forward(self, x, t_h, t_w, s_h, s_w):  
        """  
        Forward pass for training.  

        Parameters:  
        - x: input tensor.  
        - t_h, t_w: template height and width.  
        - s_h, s_w: search area height and width.  

        Returns:  
        - Output tensor after processing through the block.  
        """  
        res = x  # Save the input for residual connection  

        x = self.norm1(x)  # Apply first normalization  
        attn = self.attn(x, t_h, t_w, s_h, s_w)  # Compute attention  
        x = res + self.drop_path(attn)  # Add residual connection and apply drop path  
        x = x + self.drop_path(self.mlp(self.norm2(x)))  # Apply MLP with residual connection  

        return x  

    def forward_test(self, x, s_h, s_w):  
        """  
        Forward pass for testing.  

        Parameters:  
        - x: input tensor.  
        - s_h, s_w: search area height and width.  

        Returns:  
        - Output tensor after processing through the block.  
        """  
        res = x  # Save the input for residual connection  

        x = self.norm1(x)  # Apply first normalization  
        attn = self.attn.forward_test(x, s_h, s_w)  # Compute attention in test mode  
        x = res + self.drop_path(attn)  # Add residual connection and apply drop path  
        x = x + self.drop_path(self.mlp(self.norm2(x)))  # Apply MLP with residual connection  

        return x  

    def set_online(self, x, t_h, t_w):  
        """  
        Set the block for online mode.  

        Parameters:  
        - x: input tensor.  
        - t_h, t_w: template height and width.  

        Returns:  
        - Output tensor after processing through the block.  
        """  
        res = x  # Save the input for residual connection  
        x = self.norm1(x)  # Apply first normalization  
        attn = self.attn.set_online(x, t_h, t_w)  # Compute attention in online mode  
        x = res + self.drop_path(attn)  # Add residual connection and apply drop path  
        x = x + self.drop_path(self.mlp(self.norm2(x)))  # Apply MLP with residual connection  

        return x

class ConvEmbed(nn.Module):  
    """ 将图像转换为卷积嵌入  

    ConvEmbed 类用于将输入图像转换为卷积嵌入向量。这在视觉 Transformer 模型中是常见的,将二维图像块映射到嵌入空间。  
    """  

    def __init__(self,  
                 patch_size=7,  
                 in_chans=3,  
                 embed_dim=64,  
                 stride=4,  
                 padding=2,  
                 norm_layer=None):  
        """  
        初始化 ConvEmbed 类。  

        参数:  
        - patch_size: int or tuple, 指定每个卷积过滤器的大小。默认为 7。  
        - in_chans: int, 输入的通道数(例如,RGB 图像有 3 个通道)。默认为 3。  
        - embed_dim: int, 输出嵌入特征的维度。默认为 64。  
        - stride: int, 卷积操作的步幅。默认为 4。  
        - padding: int, 卷积操作的填充大小。默认为 2。  
        - norm_layer: callable or None, 用于标准化嵌入特征的层,可以是 nn.LayerNorm 类或其他自定义归一化层。默认为 None。  
        """  
        super().__init__()  

        # 当 patch_size 为整数时,将其转换为 (patch_size, patch_size) 的元组形式  
        patch_size = to_2tuple(patch_size)  
        self.patch_size = patch_size  # 储存块大小  

        # 定义卷积层,将输入通道数映射到指定的嵌入维度  
        self.proj = nn.Conv2d(  
            in_chans,        # 输入通道数  
            embed_dim,       # 输出通道数,即嵌入维度  
            kernel_size=patch_size,  # 卷积核大小  
            stride=stride,   # 卷积步幅  
            padding=padding  # 填充大小  
        )  

        # 如果提供了标准化层,则进行初始化,否则不使用标准化  
        self.norm = norm_layer(embed_dim) if norm_layer else None  

    def forward(self, x):
        """
        Forward pass.

        Parameters:
        - x: input tensor of shape (batch_size, channels, height, width)

        Returns:
        - embedded feature map of shape (batch_size, embed_dim, H', W')
        """
        # strided convolution producing the patch embeddings
        x = self.proj(x)

        # B is the batch size, C the embedding dim, H and W the spatial size after the convolution
        B, C, H, W = x.shape

        # flatten the 2D map (h, w) into a 1D token sequence (h w)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()

        # apply normalization on the token sequence if a norm_layer was given
        if self.norm:
            x = self.norm(x)

        # reshape back to the 2D feature-map layout
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W).contiguous()

        return x  # return the embedded feature map
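
As a quick shape check, a small usage sketch of ConvEmbed (sizes are illustrative and rely on the imports added at the top of this listing): a 128×128 template crop with the default stride of 4 becomes a 32×32 map, i.e. the 1024 template tokens mentioned in the set_online comments.

embed = ConvEmbed(patch_size=7, in_chans=3, embed_dim=64, stride=4, padding=2, norm_layer=nn.LayerNorm)
template = torch.randn(2, 3, 128, 128)
feat = embed(template)                             # (2, 64, 32, 32)
tokens = rearrange(feat, 'b c h w -> b (h w) c')   # (2, 1024, 64), the token sequence fed to the MAM blocks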

class VisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self,
                 patch_size=16,
                 patch_stride=16,
                 patch_padding=0,
                 in_chans=3,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 init='trunc_norm',
                 freeze_bn=False,
                 **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.rearrage = None

        self.patch_embed = ConvEmbed(
            # img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            stride=patch_stride,
            padding=patch_padding,
            embed_dim=embed_dim,
            norm_layer=norm_layer
        )

        with_cls_token = kwargs['with_cls_token']
        if with_cls_token:
            self.cls_token = nn.Parameter(
                torch.zeros(1, 1, embed_dim)
            )
        else:
            self.cls_token = None

        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        blocks = []
        for j in range(depth):
            blocks.append(
                Block(
                    dim_in=embed_dim,
                    dim_out=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[j],
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    freeze_bn=freeze_bn,
                    **kwargs
                )
            )
        self.blocks = nn.ModuleList(blocks)

        if self.cls_token is not None:
            trunc_normal_(self.cls_token, std=.02)

        if init == 'xavier':
            self.apply(self._init_weights_xavier)
        else:
            self.apply(self._init_weights_trunc_normal)

    def _init_weights_trunc_normal(self, m):
        if isinstance(m, nn.Linear):
            logging.info('=> init weight of Linear from trunc norm')
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                logging.info('=> init bias of Linear to zeros')
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _init_weights_xavier(self, m):
        if isinstance(m, nn.Linear):
            logging.info('=> init weight of Linear from xavier uniform')
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                logging.info('=> init bias of Linear to zeros')
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, template, online_template, search):
        """
        :param template: (batch, c, 128, 128)
        :param search: (batch, c, 320, 320)
        :return:
        """
        # x = self.patch_embed(x)
        # B, C, H, W = x.size()
        template = self.patch_embed(template)
        online_template = self.patch_embed(online_template)
        t_B, t_C, t_H, t_W = template.size()
        search = self.patch_embed(search)
        s_B, s_C, s_H, s_W = search.size()

        template = rearrange(template, 'b c h w -> b (h w) c').contiguous()
        online_template = rearrange(online_template, 'b c h w -> b (h w) c').contiguous()
        search = rearrange(search, 'b c h w -> b (h w) c').contiguous()
        x = torch.cat([template, online_template, search], dim=1)

        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x, t_H, t_W, s_H, s_W)

        # if self.cls_token is not None:
        #     cls_tokens, x = torch.split(x, [1, H*W], 1)
        template, online_template, search = torch.split(x, [t_H*t_W, t_H*t_W, s_H*s_W], dim=1)
        template = rearrange(template, 'b (h w) c -> b c h w', h=t_H, w=t_W).contiguous()
        online_template = rearrange(online_template, 'b (h w) c -> b c h w', h=t_H, w=t_W).contiguous()
        search = rearrange(search, 'b (h w) c -> b c h w', h=s_H, w=s_W).contiguous()

        return template, online_template, search

    def forward_test(self, search):
        # x = self.patch_embed(x)
        # B, C, H, W = x.size()
        search = self.patch_embed(search)
        s_B, s_C, s_H, s_W = search.size()

        search = rearrange(search, 'b c h w -> b (h w) c').contiguous()
        x = search
        # x = torch.cat([template, search], dim=1)

        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk.forward_test(x, s_H, s_W)

        # if self.cls_token is not None:
        #     cls_tokens, x = torch.split(x, [1, H*W], 1)
        # template, search = torch.split(x, [t_H*t_W, s_H*s_W], dim=1)
        search = x
        search = rearrange(search, 'b (h w) c -> b c h w', h=s_H, w=s_W)

        return search

    def set_online(self, template, online_template):
        template = self.patch_embed(template)
        online_template = self.patch_embed(online_template)
        t_B, t_C, t_H, t_W = template.size()

        template = rearrange(template, 'b c h w -> b (h w) c').contiguous()
        online_template = rearrange(online_template, 'b c h w -> (b h w) c').unsqueeze(0).contiguous()
        # 1, 1024, c
        # 1, b*1024, c
        # print(template.shape, online_template.shape)
        x = torch.cat([template, online_template], dim=1)

        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk.set_online(x, t_H, t_W)

        # if self.cls_token is not None:
        #     cls_tokens, x = torch.split(x, [1, H*W], 1)
        template = x[:, :t_H*t_W]
        online_template = x[:, t_H*t_W:]
        template = rearrange(template, 'b (h w) c -> b c h w', h=t_H, w=t_W)
        online_template = rearrange(online_template.squeeze(0), '(b h w) c -> b c h w', h=t_H, w=t_W)

        return template, online_template        
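
For reference, a single-stage sanity check of the class above (the hyper-parameters here are assumed for illustration and are not the released MixFormer configuration; the input sizes follow the 128×128 template / 320×320 search crops from the docstring):

stage = VisionTransformer(patch_size=7, patch_stride=4, patch_padding=2, in_chans=3,
                          embed_dim=64, depth=1, num_heads=1, with_cls_token=False,
                          method='dw_bn', kernel_size=3, padding_q=1, padding_kv=1,
                          stride_q=1, stride_kv=2)
t = torch.randn(1, 3, 128, 128)        # template crop
ot = torch.randn(1, 3, 128, 128)       # online template crop
s = torch.randn(1, 3, 320, 320)        # search crop
t_f, ot_f, s_f = stage(t, ot, s)       # (1, 64, 32, 32), (1, 64, 32, 32), (1, 64, 80, 80)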


Backbone

class ConvolutionalVisionTransformer(nn.Module):
    def __init__(self,
                 in_chans=3,
                 # num_classes=1000,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 init='trunc_norm',
                 spec=None):
        super().__init__()
        # self.num_classes = num_classes

        self.num_stages = spec['NUM_STAGES']
        for i in range(self.num_stages):
            kwargs = {
                'patch_size': spec['PATCH_SIZE'][i],
                'patch_stride': spec['PATCH_STRIDE'][i],
                'patch_padding': spec['PATCH_PADDING'][i],
                'embed_dim': spec['DIM_EMBED'][i],
                'depth': spec['DEPTH'][i],
                'num_heads': spec['NUM_HEADS'][i],
                'mlp_ratio': spec['MLP_RATIO'][i],
                'qkv_bias': spec['QKV_BIAS'][i],
                'drop_rate': spec['DROP_RATE'][i],
                'attn_drop_rate': spec['ATTN_DROP_RATE'][i],
                'drop_path_rate': spec['DROP_PATH_RATE'][i],
                'with_cls_token': spec['CLS_TOKEN'][i],
                'method': spec['QKV_PROJ_METHOD'][i],
                'kernel_size': spec['KERNEL_QKV'][i],
                'padding_q': spec['PADDING_Q'][i],
                'padding_kv': spec['PADDING_KV'][i],
                'stride_kv': spec['STRIDE_KV'][i],
                'stride_q': spec['STRIDE_Q'][i],
                'freeze_bn': spec['FREEZE_BN'],
            }

            stage = VisionTransformer(
                in_chans=in_chans,
                init=init,
                act_layer=act_layer,
                norm_layer=norm_layer,
                **kwargs
            )
            setattr(self, f'stage{i}', stage)

            in_chans = spec['DIM_EMBED'][i]

        dim_embed = spec['DIM_EMBED'][-1]
        self.norm = norm_layer(dim_embed)
        self.cls_token = spec['CLS_TOKEN'][-1]

        # Classifier head
        self.head = nn.Linear(dim_embed, 1000)
        trunc_normal_(self.head.weight, std=0.02)

    def forward(self, template, online_template, search):
        """
        :param template: (b, 3, 128, 128)
        :param search: (b, 3, 320, 320)
        :return:
        """
        # template = template + self.template_emb
        # search = search + self.search_emb
        for i in range(self.num_stages):
            template, online_template, search = getattr(self, f'stage{i}')(template, online_template, search)

        return template, search

    def forward_test(self, search):
        for i in range(self.num_stages):
            search = getattr(self, f'stage{i}').forward_test(search)
        return search

    def set_online(self, template, online_template):
        for i in range(self.num_stages):
            template, online_template = getattr(self, f'stage{i}').set_online(template, online_template)

def get_mixformer_model(config, **kwargs):
    msvit_spec = config.MODEL.BACKBONE
    msvit = ConvolutionalVisionTransformer(
        in_chans=3,
        act_layer=QuickGELU,
        norm_layer=partial(LayerNorm, eps=1e-5),
        init=getattr(msvit_spec, 'INIT', 'trunc_norm'),
        spec=msvit_spec
    )

    if config.MODEL.BACKBONE.PRETRAINED:
        try:
            ckpt_path = config.MODEL.BACKBONE.PRETRAINED_PATH
            ckpt = torch.load(ckpt_path, map_location='cpu')
            missing_keys, unexpected_keys = msvit.load_state_dict(ckpt, strict=False)
            if is_main_process():
                print("Load pretrained backbone checkpoint from:", ckpt_path)
                print("missing keys:", missing_keys)
                print("unexpected keys:", unexpected_keys)
                print("Loading pretrained CVT done.")
        except:
            print("Warning: Pretrained CVT weights are not loaded")

    return msvit
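
The spec dictionary is normally taken from the tracker's YAML configuration (config.MODEL.BACKBONE). Purely for illustration, an assumed three-stage spec loosely following the CvT-21 layout is shown below; the values in the released MixFormer configs may differ.

spec = {
    'NUM_STAGES': 3,
    'PATCH_SIZE': [7, 3, 3], 'PATCH_STRIDE': [4, 2, 2], 'PATCH_PADDING': [2, 1, 1],
    'DIM_EMBED': [64, 192, 384], 'DEPTH': [1, 4, 16], 'NUM_HEADS': [1, 3, 6],
    'MLP_RATIO': [4.0, 4.0, 4.0], 'QKV_BIAS': [True, True, True],
    'DROP_RATE': [0.0, 0.0, 0.0], 'ATTN_DROP_RATE': [0.0, 0.0, 0.0], 'DROP_PATH_RATE': [0.0, 0.0, 0.1],
    'CLS_TOKEN': [False, False, False],
    'QKV_PROJ_METHOD': ['dw_bn', 'dw_bn', 'dw_bn'], 'KERNEL_QKV': [3, 3, 3],
    'PADDING_Q': [1, 1, 1], 'PADDING_KV': [1, 1, 1],
    'STRIDE_KV': [2, 2, 2], 'STRIDE_Q': [1, 1, 1],
    'FREEZE_BN': False,
}
backbone = ConvolutionalVisionTransformer(in_chans=3, act_layer=QuickGELU,
                                          norm_layer=partial(LayerNorm, eps=1e-5), spec=spec)
template_feat, search_feat = backbone(torch.randn(1, 3, 128, 128),   # template
                                      torch.randn(1, 3, 128, 128),   # online template
                                      torch.randn(1, 3, 320, 320))   # search
# template_feat: (1, 384, 8, 8), search_feat: (1, 384, 20, 20)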

Head

The head module provides three types of heads: MLP, Corner, and Pyramid_Corner.

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1,
         freeze_bn=False):
    if freeze_bn:
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, bias=True),
            FrozenBatchNorm2d(out_planes),
            nn.ReLU(inplace=True))
    else:
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, bias=True),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True))

class Corner_Predictor(nn.Module):  
    """ Corner Predictor module"""  

    def __init__(self, inplanes=64, channel=256, feat_sz=20, stride=16, freeze_bn=False):  
        super(Corner_Predictor, self).__init__()  
        # feature-map size
        self.feat_sz = feat_sz
        # stride used to map feature-map coordinates back to image coordinates
        self.stride = stride
        # image size, derived from the feature-map size and the stride
        self.img_sz = self.feat_sz * self.stride

        '''top-left corner'''
        # conv layers that predict the top-left corner
        self.conv1_tl = conv(inplanes, channel, freeze_bn=freeze_bn)
        # halve the channels
        self.conv2_tl = conv(channel, channel // 2, freeze_bn=freeze_bn)
        # halve the channels again
        self.conv3_tl = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)
        # halve the channels again
        self.conv4_tl = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)
        # single-channel output: the corner probability map
        self.conv5_tl = nn.Conv2d(channel // 8, 1, kernel_size=1)

        '''bottom-right corner'''
        # conv layers that predict the bottom-right corner, mirroring the top-left branch
        self.conv1_br = conv(inplanes, channel, freeze_bn=freeze_bn)
        self.conv2_br = conv(channel, channel // 2, freeze_bn=freeze_bn)
        self.conv3_br = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)
        self.conv4_br = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)
        self.conv5_br = nn.Conv2d(channel // 8, 1, kernel_size=1)

        '''about coordinates and indices'''
        with torch.no_grad():
            # indices from 0 to feat_sz, scaled by the stride, used as image-plane coordinates
            self.indice = torch.arange(0, self.feat_sz).view(-1, 1) * self.stride
            # coordinate grids (note: created on the GPU)
            self.coord_x = self.indice.repeat((self.feat_sz, 1)) \
                .view((self.feat_sz * self.feat_sz,)).float().cuda()
            self.coord_y = self.indice.repeat((1, self.feat_sz)) \
                .view((self.feat_sz * self.feat_sz,)).float().cuda()

    def forward(self, x, return_dist=False, softmax=True):
        """ Forward pass with input x. """
        # score maps for the top-left and bottom-right corners
        score_map_tl, score_map_br = self.get_score_map(x)
        if return_dist:
            # also return the probability distributions from the soft-argmax
            coorx_tl, coory_tl, prob_vec_tl = self.soft_argmax(score_map_tl, return_dist=True, softmax=softmax)
            coorx_br, coory_br, prob_vec_br = self.soft_argmax(score_map_br, return_dist=True, softmax=softmax)
            return torch.stack((coorx_tl, coory_tl, coorx_br, coory_br), dim=1) / self.img_sz, prob_vec_tl, prob_vec_br
        else:
            # return only the normalized corner coordinates
            coorx_tl, coory_tl = self.soft_argmax(score_map_tl)
            coorx_br, coory_br = self.soft_argmax(score_map_br)
            return torch.stack((coorx_tl, coory_tl, coorx_br, coory_br), dim=1) / self.img_sz

    def get_score_map(self, x):  
        # top-left branch  
        # pass the features through the stacked conv layers
        x_tl1 = self.conv1_tl(x)  
        x_tl2 = self.conv2_tl(x_tl1)  
        x_tl3 = self.conv3_tl(x_tl2)  
        x_tl4 = self.conv4_tl(x_tl3)  
        score_map_tl = self.conv5_tl(x_tl4)  

        # bottom-right branch  
        # bottom-right branch, mirroring the top-left branch
        x_br1 = self.conv1_br(x)  
        x_br2 = self.conv2_br(x_br1)  
        x_br3 = self.conv3_br(x_br2)  
        x_br4 = self.conv4_br(x_br3)  
        score_map_br = self.conv5_br(x_br4)  
        return score_map_tl, score_map_br  

    def soft_argmax(self, score_map, return_dist=False, softmax=True):  
        """ get soft-argmax coordinate for a given heatmap """  
        # flatten the score map into a vector
        score_vec = score_map.view((-1, self.feat_sz * self.feat_sz))
        # softmax to obtain a probability distribution over locations
        prob_vec = nn.functional.softmax(score_vec, dim=1)
        # expected coordinates: probability-weighted average over the coordinate grids
        exp_x = torch.sum((self.coord_x * prob_vec), dim=1)
        exp_y = torch.sum((self.coord_y * prob_vec), dim=1)
        if return_dist:
            # also return the distribution (probabilities or raw scores)
            if softmax:
                return exp_x, exp_y, prob_vec
            else:
                return exp_x, exp_y, score_vec
        else:
            # return only the expected coordinates
            return exp_x, exp_y
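
A hedged usage sketch of the corner head (channel sizes are assumed to match a 384-dim backbone output; note that the coordinate grids above are created with .cuda() in __init__, so the head is expected to run on a GPU):

head = Corner_Predictor(inplanes=384, channel=384, feat_sz=20, stride=16).cuda()
search_feat = torch.randn(1, 384, 20, 20).cuda()   # backbone output for the search region
box = head(search_feat)                            # (1, 4): normalized (x_tl, y_tl, x_br, y_br) in [0, 1]
x1, y1, x2, y2 = box.unbind(dim=1)
box_cxcywh = torch.stack([(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1], dim=1)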

class Pyramid_Corner_Predictor(nn.Module):  
    """ Corner Predictor module"""  

    def __init__(self, inplanes=64, channel=256, feat_sz=20, stride=16, freeze_bn=False):  
        super(Pyramid_Corner_Predictor, self).__init__()  
        # feature-map size
        self.feat_sz = feat_sz
        # stride used to map feature-map coordinates back to image coordinates
        self.stride = stride
        # image size, derived from the feature-map size and the stride
        self.img_sz = self.feat_sz * self.stride

        '''top-left corner'''
        # conv layers that predict the top-left corner
        self.conv1_tl = conv(inplanes, channel, freeze_bn=freeze_bn)
        self.conv2_tl = conv(channel, channel // 2, freeze_bn=freeze_bn)
        self.conv3_tl = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)
        self.conv4_tl = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)
        self.conv5_tl = nn.Conv2d(channel // 8, 1, kernel_size=1)

        # lateral layers that adapt the input feature map's channels for the pyramid fusion
        self.adjust1_tl = conv(inplanes, channel // 2, freeze_bn=freeze_bn)
        self.adjust2_tl = conv(inplanes, channel // 4, freeze_bn=freeze_bn)

        # conv stacks that adapt intermediate feature maps at the different pyramid scales
        self.adjust3_tl = nn.Sequential(conv(channel // 2, channel // 4, freeze_bn=freeze_bn),
                                        conv(channel // 4, channel // 8, freeze_bn=freeze_bn),
                                        conv(channel // 8, 1, freeze_bn=freeze_bn))
        self.adjust4_tl = nn.Sequential(conv(channel // 4, channel // 8, freeze_bn=freeze_bn),
                                        conv(channel // 8, 1, freeze_bn=freeze_bn))

        '''bottom-right corner'''  
        # conv layers for the bottom-right corner, mirroring the top-left branch
        self.conv1_br = conv(inplanes, channel, freeze_bn=freeze_bn)  
        self.conv2_br = conv(channel, channel // 2, freeze_bn=freeze_bn)  
        self.conv3_br = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)  
        self.conv4_br = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)  
        self.conv5_br = nn.Conv2d(channel // 8, 1, kernel_size=1)  

        self.adjust1_br = conv(inplanes, channel // 2, freeze_bn=freeze_bn)  
        self.adjust2_br = conv(inplanes, channel // 4, freeze_bn=freeze_bn)  

        self.adjust3_br = nn.Sequential(conv(channel // 2, channel // 4, freeze_bn=freeze_bn),  
                                        conv(channel // 4, channel // 8, freeze_bn=freeze_bn),  
                                        conv(channel // 8, 1, freeze_bn=freeze_bn))  
        self.adjust4_br = nn.Sequential(conv(channel // 4, channel // 8, freeze_bn=freeze_bn),  
                                        conv(channel // 8, 1, freeze_bn=freeze_bn))  

        '''about coordinates and indices'''
        with torch.no_grad():
            # indices from 0 to feat_sz, scaled by the stride, used as image-plane coordinates
            self.indice = torch.arange(0, self.feat_sz).view(-1, 1) * self.stride
            # x-coordinate grid (note: created on the GPU)
            self.coord_x = self.indice.repeat((self.feat_sz, 1)) \
                .view((self.feat_sz * self.feat_sz,)).float().cuda()
            # y-coordinate grid (note: created on the GPU)
            self.coord_y = self.indice.repeat((1, self.feat_sz)) \
                .view((self.feat_sz * self.feat_sz,)).float().cuda()

    def forward(self, x, return_dist=False, softmax=True):  
        """ Forward pass with input x. """  
        # score maps for the top-left and bottom-right corners
        score_map_tl, score_map_br = self.get_score_map(x)  
        if return_dist:  
            # also return the probability distributions from the soft-argmax
            coorx_tl, coory_tl, prob_vec_tl = self.soft_argmax(score_map_tl, return_dist=True, softmax=softmax)  
            coorx_br, coory_br, prob_vec_br = self.soft_argmax(score_map_br, return_dist=True, softmax=softmax)  
            return torch.stack((coorx_tl, coory_tl, coorx_br, coory_br), dim=1) / self.img_sz, prob_vec_tl, prob_vec_br  
        else:  
            # return only the normalized corner coordinates
            coorx_tl, coory_tl = self.soft_argmax(score_map_tl)  
            coorx_br, coory_br = self.soft_argmax(score_map_br)  
            return torch.stack((coorx_tl, coory_tl, coorx_br, coory_br), dim=1) / self.img_sz  

    def get_score_map(self, x):  
        x_init = x  
        # top-left branch  
        x_tl1 = self.conv1_tl(x)  
        x_tl2 = self.conv2_tl(x_tl1)  

        #up-1  
        # upsample and fuse with the channel-adjusted input feature map
        x_init_up1 = F.interpolate(self.adjust1_tl(x_init), scale_factor=2)  
        x_up1 = F.interpolate(x_tl2, scale_factor=2)  
        x_up1 = x_init_up1 + x_up1  

        x_tl3 = self.conv3_tl(x_up1)  

        #up-2  
        x_init_up2 = F.interpolate(self.adjust2_tl(x_init), scale_factor=4)  
        x_up2 = F.interpolate(x_tl3, scale_factor=2)  
        x_up2 = x_init_up2 + x_up2  

        x_tl4 = self.conv4_tl(x_up2)  
        # aggregate the multi-scale (pyramid) predictions into the final score map
        score_map_tl = self.conv5_tl(x_tl4) + F.interpolate(self.adjust3_tl(x_tl2), scale_factor=4) + F.interpolate(self.adjust4_tl(x_tl3), scale_factor=2)  

        # bottom-right branch  
        # bottom-right branch, mirroring the top-left branch
        x_br1 = self.conv1_br(x)  
        x_br2 = self.conv2_br(x_br1)  

        # up-1  
        x_init_up1 = F.interpolate(self.adjust1_br(x_init), scale_factor=2)  
        x_up1 = F.interpolate(x_br2, scale_factor=2)  
        x_up1 = x_init_up1 + x_up1  

        x_br3 = self.conv3_br(x_up1)  

        # up-2  
        x_init_up2 = F.interpolate(self.adjust2_br(x_init), scale_factor=4)  
        x_up2 = F.interpolate(x_br3, scale_factor=2)  
        x_up2 = x_init_up2 + x_up2  

        x_br4 = self.conv4_br(x_up2)  
        score_map_br = self.conv5_br(x_br4) + F.interpolate(self.adjust3_br(x_br2), scale_factor=4) + F.interpolate(self.adjust4_br(x_br3), scale_factor=2)  

        return score_map_tl, score_map_br  

    def soft_argmax(self, score_map, return_dist=False, softmax=True):  
        """ get soft-argmax coordinate for a given heatmap """  
        # flatten the score map into a vector
        score_vec = score_map.view((-1, self.feat_sz * self.feat_sz))  # (batch, feat_sz * feat_sz)
        # softmax to obtain a probability distribution over locations
        prob_vec = nn.functional.softmax(score_vec, dim=1)
        # expected coordinates: probability-weighted average over the coordinate grids
        exp_x = torch.sum((self.coord_x * prob_vec), dim=1)
        exp_y = torch.sum((self.coord_y * prob_vec), dim=1)
        if return_dist:
            # also return the distribution (probabilities or raw scores)
            if softmax:
                return exp_x, exp_y, prob_vec
            else:
                return exp_x, exp_y, score_vec
        else:
            # return only the expected coordinates
            return exp_x, exp_y

class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, BN=False):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        if BN:
            self.layers = nn.ModuleList(nn.Sequential(nn.Linear(n, k), nn.BatchNorm1d(k))
                                        for n, k in zip([input_dim] + h, h + [output_dim]))
        else:
            self.layers = nn.ModuleList(nn.Linear(n, k)
                                        for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

Online template update

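A hedged paraphrase of the online-template update (this is not the repo's code; names such as OnlineTemplateUpdater, UPDATE_INTERVAL and SCORE_THRESHOLD are placeholders): the score prediction module rates the confidence of each tracking result, the most confident target crop within an update interval is remembered, and if its score passes a threshold it replaces the online template, after which set_online re-caches the template keys/values inside every MAM block.

# Hedged sketch of the online-template update; all names here are illustrative placeholders.
UPDATE_INTERVAL = 200      # assumed: how often the online template may be refreshed
SCORE_THRESHOLD = 0.5      # assumed: minimum predicted confidence for an update

class OnlineTemplateUpdater:
    def __init__(self, model, template):
        self.model = model                 # backbone exposing set_online(template, online_template)
        self.template = template           # fixed first-frame template crop
        self.online_template = template    # starts as a copy of the first-frame template
        self.best_score, self.best_crop = 0.0, None

    def step(self, frame_id, pred_score, target_crop):
        # remember the most confident tracking result seen in the current interval
        if pred_score > self.best_score:
            self.best_score, self.best_crop = pred_score, target_crop
        # at the end of each interval, adopt the best crop if it is confident enough
        if frame_id % UPDATE_INTERVAL == 0 and self.best_score > SCORE_THRESHOLD:
            self.online_template = self.best_crop
            # re-cache template keys/values inside every mixed-attention block
            self.model.set_online(self.template, self.online_template)
            self.best_score, self.best_crop = 0.0, None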