Tracking has typically adopted a multi-stage pipeline of feature extraction, target information integration, and bounding-box estimation. To simplify this pipeline and unify the processes of feature extraction and target information integration, we propose a compact tracking framework called MixFormer. Our core design is to exploit the flexibility of attention operations and propose a Mixed Attention Module (MAM) for simultaneous feature extraction and target information integration. This synchronous modeling scheme extracts target-specific discriminative features and enables extensive communication between the target and the search region. Based on MAM, we build the MixFormer tracking framework simply by stacking multiple MAMs with progressive patch embedding and placing a localization head on top. In addition, to handle multiple target templates during online tracking, we devise an asymmetric attention scheme to reduce the computational cost and propose an effective score prediction module to select high-quality templates. Our MixFormer sets new state-of-the-art performance on five tracking benchmarks, including LaSOT, TrackingNet, VOT2020, GOT-10k, and UAV123. In particular, MixFormer-L achieves an NP score of 79.9% on LaSOT, an NP score of 88.9% on TrackingNet, and an EAO of 0.555 on VOT2020. We also conduct in-depth ablation studies to demonstrate the effectiveness of simultaneous feature extraction and information integration.
https://zhuanlan.zhihu.com/p/485189978?utm_psn=1823159977482780672
https://arxiv.org/pdf/2203.11082
Problems with previous tracking frameworks: (1) they are built from multiple separate components; (2) CNN-based methods lack global modeling capability; (3) Transformer-based methods still rely on a CNN to extract features and only apply attention on top of the high-level features. To overcome these problems, the authors unify feature extraction and information integration. First, feature extraction becomes specific to the tracked target, yielding more target-discriminative features. Second, target information is integrated more extensively into the search region, which better captures the correlation between the two. Finally, this yields a more compact and elegant tracking framework that needs no explicit integration module.
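Before diving into the code, here is a toy sketch (not the paper's implementation) of what "mixed attention" means in this unified design: template and search tokens are concatenated into one sequence and attend to each other inside a single attention operation, so target-specific feature extraction and template-search information integration happen simultaneously. The real MAM below additionally uses convolutional projections and an asymmetric attention pattern.

import torch
import torch.nn as nn

class ToyMixedAttention(nn.Module):
    """Toy illustration only: joint attention over concatenated template + search tokens."""
    def __init__(self, dim=64, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, template_tokens, search_tokens):
        # (B, N_t, C) and (B, N_s, C) -> one joint token sequence
        x = torch.cat([template_tokens, search_tokens], dim=1)
        # every token (template or search) attends to every other token, so feature
        # extraction and target-search information integration happen in the same layer
        out, _ = self.attn(x, x, x)
        return out.split([template_tokens.size(1), search_tokens.size(1)], dim=1)

# toy usage
t = torch.randn(1, 64, 64)    # 8x8 template tokens
s = torch.randn(1, 256, 64)   # 16x16 search tokens
t_out, s_out = ToyMixedAttention()(t, s)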
Code related to the MAM module
# Imports assumed for the code below; FrozenBatchNorm2d and is_main_process are utilities from the
# MixFormer repository itself (exact import paths omitted), and the timm imports are one common
# source for DropPath / trunc_normal_ / to_2tuple.
import logging
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath, trunc_normal_, to_2tuple

class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""
    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class Mlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self,
dim_in,
dim_out,
num_heads,
qkv_bias=False,
attn_drop=0.,
proj_drop=0.,
method='dw_bn',
kernel_size=3,
stride_kv=1,
stride_q=1,
padding_kv=1,
padding_q=1,
with_cls_token=True,
freeze_bn=False,
**kwargs
):
"""
初始化 Attention 模块。
:param dim_in: 输入通道维度。
:param dim_out: 输出通道维度。
:param num_heads: 注意力头的数量。
:param qkv_bias: 是否在 q, k, v 的线性投影中使用偏置。
:param attn_drop: 注意力后的 dropout 概率。
:param proj_drop: 投影后的 dropout 概率。
:param method: 投影的卷积方式,比如 'dw_bn'。
:param kernel_size: 卷积核大小。
:param stride_kv: 键和值的卷积步幅。
:param stride_q: 查询的卷积步幅。
:param padding_kv: 键和值的填充。
:param padding_q: 查询的填充。
:param with_cls_token: 是否包含 class token。
:param freeze_bn: 是否冻结批归一化层。
"""
super().__init__()
self.stride_kv = stride_kv
self.stride_q = stride_q
self.dim = dim_out
self.num_heads = num_heads
        self.scale = dim_out ** -0.5  # scaling factor for the dot product
        self.with_cls_token = with_cls_token
        # Choose the batch-normalization implementation (frozen or regular)
        if freeze_bn:
            conv_proj_post_norm = FrozenBatchNorm2d  # frozen BN implementation
        else:
            conv_proj_post_norm = nn.BatchNorm2d  # regular BN by default
        # Convolutional projections for query, key and value
        self.conv_proj_q = self._build_projection(dim_in, dim_out, kernel_size, padding_q, stride_q,
                                                  'linear' if method == 'avg' else method, conv_proj_post_norm)
        self.conv_proj_k = self._build_projection(dim_in, dim_out, kernel_size, padding_kv, stride_kv, method,
                                                  conv_proj_post_norm)
        self.conv_proj_v = self._build_projection(dim_in, dim_out, kernel_size, padding_kv, stride_kv, method,
                                                  conv_proj_post_norm)
        # Linear projections for query, key and value
        self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        # Dropout applied to the attention weights and after the output projection
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim_out, dim_out)
self.proj_drop = nn.Dropout(proj_drop)
def _build_projection(self,
dim_in,
dim_out,
kernel_size,
padding,
stride,
method,
norm):
"""
创建卷积投影。
:param dim_in: 输入通道维度。
:param dim_out: 输出通道维度。
:param kernel_size: 卷积核大小。
:param padding: 填充。
:param stride: 步幅。
:param method: 卷积类型 ('dw_bn', 'avg', 'linear')。
:param norm: 归一化方法。
:return: 投影操作序列。
"""
if method == 'dw_bn':
# 深度可分离卷积加批归一化
proj = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding, stride=stride,
bias=False, groups=dim_in)),
('bn', norm(dim_in)),
('rearrage', Rearrange('b c h w -> b (h w) c')), # 转换张量形状
]))
elif method == 'avg':
# 平均池化(模拟卷积)
proj = nn.Sequential(OrderedDict([
('avg', nn.AvgPool2d(kernel_size=kernel_size, padding=padding, stride=stride, ceil_mode=True)),
('rearrage', Rearrange('b c h w -> b (h w) c')), # 转换张量形状
]))
elif method == 'linear':
# 线性层跳过卷积,直接处理(需在调用时处理)
proj = None
else:
raise ValueError('Unknown method ({})'.format(method))
return proj
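For intuition about the token counts used later in the attention splits, here is a small shape check (hypothetical sizes) of the 'dw_bn' projection: with kernel 3, padding 1 and stride 2, as used for keys/values, each spatial side is roughly halved, so the key/value sequence holds about a quarter of the query tokens.

# A 3x3 depthwise conv with padding 1 and stride 2 halves each spatial side (rounding up),
# so an 8x8 template map yields (8+1)//2 = 4 -> 4*4 = 16 key/value tokens,
# while the query projection (stride 1) keeps all 8*8 = 64 tokens.
import torch, torch.nn as nn
dw = nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=2, groups=64, bias=False)
print(dw(torch.randn(1, 64, 8, 8)).shape)    # torch.Size([1, 64, 4, 4])
print(dw(torch.randn(1, 64, 16, 16)).shape)  # torch.Size([1, 64, 8, 8]) -> 1/4 of the tokens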
def forward_conv(self, x, t_h, t_w, s_h, s_w):
"""
进行卷积前向传播,用于查询、键和值的特征图提取。
:param x: 输入张量,通常由 template 和 search 组合而成。
:param t_h: 模板高度。
:param t_w: 模板宽度。
:param s_h: 搜索图高度。
:param s_w: 搜索图宽度。
:return: q, k, v - 查询、键和值。
"""
# 将输入 x 分割成 template ,online_template 和 search,按照其各自的大小
template, online_template, search = torch.split(x, [t_h * t_w, t_h * t_w, s_h * s_w], dim=1)
# 重新排列尺寸以适合卷积操作
template = rearrange(template, 'b (h w) c -> b c h w', h=t_h, w=t_w).contiguous()
online_template = rearrange(online_template, 'b (h w) c -> b c h w', h=t_h, w=t_w).contiguous()
search = rearrange(search, 'b (h w) c -> b c h w', h=s_h, w=s_w).contiguous()
# 根据投影是否存在来进行相应的卷积或线性投影操作
if self.conv_proj_q is not None:
t_q = self.conv_proj_q(template)
ot_q = self.conv_proj_q(online_template)
s_q = self.conv_proj_q(search)
q = torch.cat([t_q, ot_q, s_q], dim=1)
else:
t_q = rearrange(template, 'b c h w -> b (h w) c').contiguous()
ot_q = rearrange(online_template, 'b c h w -> b (h w) c').contiguous()
s_q = rearrange(search, 'b c h w -> b (h w) c').contiguous()
q = torch.cat([t_q, ot_q, s_q], dim=1)
if self.conv_proj_k is not None:
t_k = self.conv_proj_k(template)
ot_k = self.conv_proj_k(online_template)
s_k = self.conv_proj_k(search)
k = torch.cat([t_k, ot_k, s_k], dim=1)
else:
t_k = rearrange(template, 'b c h w -> b (h w) c').contiguous()
ot_k = rearrange(online_template, 'b c h w -> b (h w) c').contiguous()
s_k = rearrange(search, 'b c h w -> b (h w) c').contiguous()
k = torch.cat([t_k, ot_k, s_k], dim=1)
if self.conv_proj_v is not None:
t_v = self.conv_proj_v(template)
ot_v = self.conv_proj_v(online_template)
s_v = self.conv_proj_v(search)
v = torch.cat([t_v, ot_v, s_v], dim=1)
else:
t_v = rearrange(template, 'b c h w -> b (h w) c').contiguous()
ot_v = rearrange(online_template, 'b c h w -> b (h w) c').contiguous()
s_v = rearrange(search, 'b c h w -> b (h w) c').contiguous()
v = torch.cat([t_v, ot_v, s_v], dim=1)
return q, k, v
def forward_conv_test(self, x, s_h, s_w):
"""
测试时的卷积操作。
:param x: 搜索区域张量。
:param s_h: 搜索图高度。
:param s_w: 搜索图宽度。
:return: q, k, v - 查询、键和值。
"""
search = x
search = rearrange(search, 'b (h w) c -> b c h w', h=s_h, w=s_w).contiguous()
if self.conv_proj_q is not None:
q = self.conv_proj_q(search)
else:
q = rearrange(search, 'b c h w -> b (h w) c').contiguous()
if self.conv_proj_k is not None:
k = self.conv_proj_k(search)
else:
k = rearrange(search, 'b c h w -> b (h w) c').contiguous()
k = torch.cat([self.t_k, self.ot_k, k], dim=1)
if self.conv_proj_v is not None:
v = self.conv_proj_v(search)
else:
v = rearrange(search, 'b c h w -> b (h w) c').contiguous()
v = torch.cat([self.t_v, self.ot_v, v], dim=1)
return q, k, v
def forward(self, x, t_h, t_w, s_h, s_w):
"""
非对称混合注意力的前向传播。
:param x: 输入张量。
:param t_h: 模板高度。
:param t_w: 模板宽度。
:param s_h: 搜索图高度。
:param s_w: 搜索图宽度。
:return: 经过 attention操作的输出。
"""
if (
self.conv_proj_q is not None
or self.conv_proj_k is not None
or self.conv_proj_v is not None
):
q, k, v = self.forward_conv(x, t_h, t_w, s_h, s_w)
# 将 q, k, v 调整为多头形式
q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()
k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()
v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()
### 注意: k/v 压缩,q_size 的 1/4(conv_stride=2)
q_mt, q_s = torch.split(q, [t_h * t_w * 2, s_h * s_w], dim=2)
k_mt, k_s = torch.split(k, [((t_h + 1) // 2) ** 2 * 2, s_h * s_w // 4], dim=2)
v_mt, v_s = torch.split(v, [((t_h + 1) // 2) ** 2 * 2, s_h * s_w // 4], dim=2)
# 模板注意力计算
attn_score = torch.einsum('bhlk,bhtk->bhlt', [q_mt, k_mt]) * self.scale
attn = F.softmax(attn_score, dim=-1)
attn = self.attn_drop(attn)
x_mt = torch.einsum('bhlt,bhtv->bhlv', [attn, v_mt])
x_mt = rearrange(x_mt, 'b h t d -> b t (h d)')
# 搜索区域注意力计算
attn_score = torch.einsum('bhlk,bhtk->bhlt', [q_s, k]) * self.scale
attn = F.softmax(attn_score, dim=-1)
attn = self.attn_drop(attn)
x_s = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
x_s = rearrange(x_s, 'b h t d -> b t (h d)')
x = torch.cat([x_mt, x_s], dim=1)
# 经过线性投影和投影后 dropout
x = self.proj(x)
x = self.proj_drop(x)
return x
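A small worked example of the split arithmetic above (assuming an 8x8 template, a 16x16 search region, stride_q=1 and stride_kv=2, which is what the hard-coded split formulas expect):

t_h = t_w = 8          # template feature size
s_h = s_w = 16         # search feature size
# queries keep full resolution (stride_q = 1)
q_tokens = t_h * t_w * 2 + s_h * s_w                     # 128 + 256 = 384
# keys/values are downsampled by the stride-2 projection
kv_tokens = ((t_h + 1) // 2) ** 2 * 2 + s_h * s_w // 4   # 32 + 64 = 96
# asymmetric scheme: the 128 template queries attend only to the 32 template k/v tokens,
# while the 256 search queries attend to all 96 k/v tokens; template tokens never attend
# to the search region, which keeps the cached templates clean and the computation cheap.
print(q_tokens, kv_tokens)  # 384 96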
def forward_test(self, x, s_h, s_w):
"""
测试阶段使用的前向传播,用于处理仅有搜索区域的场景。
:param x: 搜索区域张量。
:param s_h: 搜索图高度。
:param s_w: 搜索图宽度。
:return: 经过 attention 运算后的输出。
"""
if (
self.conv_proj_q is not None
or self.conv_proj_k is not None
or self.conv_proj_v is not None
):
q_s, k, v = self.forward_conv_test(x, s_h, s_w)
q_s = rearrange(self.proj_q(q_s), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()
k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()
v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()
attn_score = torch.einsum('bhlk,bhtk->bhlt', [q_s, k]) * self.scale
attn = F.softmax(attn_score, dim=-1)
attn = self.attn_drop(attn)
x_s = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
x_s = rearrange(x_s, 'b h t d -> b t (h d)').contiguous()
x = x_s
x = self.proj(x)
x = self.proj_drop(x)
return x
def set_online(self, x, t_h, t_w):
"""
在线模式设置,用于将模板和在线模板组合为一个查询、键和值。
:param x: 输入张量,包括模板和在线模板。
:param t_h: 模板高度。
:param t_w: 模板宽度。
:return: 设定后的特征。
"""
template = x[:, :t_h * t_w] # 1, 1024, c
online_template = x[:, t_h * t_w:] # 1, b*1024, c
template = rearrange(template, 'b (h w) c -> b c h w', h=t_h, w=t_w).contiguous()
online_template = rearrange(online_template.squeeze(0), '(b h w) c -> b c h w', h=t_h, w=t_w).contiguous() # b, c, 32, 32
if self.conv_proj_q is not None:
t_q = self.conv_proj_q(template)
ot_q = self.conv_proj_q(online_template).flatten(end_dim=1).unsqueeze(0)
else:
t_q = rearrange(template, 'b c h w -> b (h w) c').contiguous()
ot_q = rearrange(online_template, 'b c h w -> (b h w) c').contiguous().unsqueeze(0)
q = torch.cat([t_q, ot_q], dim=1)
if self.conv_proj_k is not None:
self.t_k = self.conv_proj_k(template)
self.ot_k = self.conv_proj_k(online_template).flatten(end_dim=1).unsqueeze(0)
else:
self.t_k = rearrange(template, 'b c h w -> b (h w) c').contiguous()
self.ot_k = rearrange(online_template, 'b c h w -> (b h w) c').contiguous().unsqueeze(0)
k = torch.cat([self.t_k, self.ot_k], dim=1)
if self.conv_proj_v is not None:
self.t_v = self.conv_proj_v(template)
self.ot_v = self.conv_proj_v(online_template).flatten(end_dim=1).unsqueeze(0)
else:
self.t_v = rearrange(template, 'b c h w -> b (h w) c').contiguous()
self.ot_v = rearrange(online_template, 'b c h w -> (b h w) c').contiguous().unsqueeze(0)
v = torch.cat([self.t_v, self.ot_v], dim=1)
        # Reshape q, k, v into multi-head form for the attention computation
q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()
k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()
v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads).contiguous()
attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale
attn = F.softmax(attn_score, dim=-1)
attn = self.attn_drop(attn)
x = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
x = rearrange(x, 'b h t d -> b t (h d)').contiguous()
x = self.proj(x)
x = self.proj_drop(x)
return x
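To make the online protocol concrete, here is a minimal usage sketch at the level of this Attention module (sizes and hyper-parameters are illustrative, not a released configuration): set_online caches the projected template keys/values once, and forward_test afterwards only projects the current search region and attends against the cached templates.

# illustrative sizes only: 8x8 templates, 16x16 search region, one head, dim 64
attn = Attention(dim_in=64, dim_out=64, num_heads=1,
                 method='dw_bn', stride_q=1, stride_kv=2, with_cls_token=False)
t_h = t_w = 8
templates = torch.randn(1, 2 * t_h * t_w, 64)    # template + one online template
attn.set_online(templates, t_h, t_w)             # caches self.t_k / ot_k / t_v / ot_v
search = torch.randn(1, 16 * 16, 64)
out = attn.forward_test(search, 16, 16)          # only the search tokens are recomputed
print(out.shape)                                 # torch.Size([1, 256, 64])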
class Block(nn.Module):
"""
Transformer Block Module
This class defines a single block of a Transformer model, which includes
attention mechanisms and a feed-forward neural network (MLP).
"""
def __init__(self,
dim_in,
dim_out,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
freeze_bn=False,
**kwargs):
"""
Initialize the Transformer Block.
Parameters:
- dim_in: int, the input dimension size.
- dim_out: int, the output dimension size.
- num_heads: int, number of attention heads.
- mlp_ratio: float, ratio to determine the hidden layer size in MLP.
- qkv_bias: bool, whether to use bias in query, key, value projections.
- drop: float, dropout rate for MLP.
- attn_drop: float, dropout rate for attention.
- drop_path: float, drop path rate for stochastic depth.
- act_layer: activation layer, default is GELU.
- norm_layer: normalization layer, default is LayerNorm.
- freeze_bn: bool, whether to freeze batch normalization layers.
- kwargs: additional keyword arguments.
"""
super().__init__()
# Whether to use class token
self.with_cls_token = kwargs['with_cls_token']
# First normalization layer
self.norm1 = norm_layer(dim_in)
# Attention layer
self.attn = Attention(
dim_in, dim_out, num_heads, qkv_bias, attn_drop, drop, freeze_bn=freeze_bn,
**kwargs
)
# Drop path for stochastic depth regularization
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
# Second normalization layer
self.norm2 = norm_layer(dim_out)
# Define the hidden dimension size for the MLP
dim_mlp_hidden = int(dim_out * mlp_ratio)
# MLP layer
self.mlp = Mlp(
in_features=dim_out,
hidden_features=dim_mlp_hidden,
act_layer=act_layer,
drop=drop
)
def forward(self, x, t_h, t_w, s_h, s_w):
"""
Forward pass for training.
Parameters:
- x: input tensor.
- t_h, t_w: template height and width.
- s_h, s_w: search area height and width.
Returns:
- Output tensor after processing through the block.
"""
res = x # Save the input for residual connection
x = self.norm1(x) # Apply first normalization
attn = self.attn(x, t_h, t_w, s_h, s_w) # Compute attention
x = res + self.drop_path(attn) # Add residual connection and apply drop path
x = x + self.drop_path(self.mlp(self.norm2(x))) # Apply MLP with residual connection
return x
def forward_test(self, x, s_h, s_w):
"""
Forward pass for testing.
Parameters:
- x: input tensor.
- s_h, s_w: search area height and width.
Returns:
- Output tensor after processing through the block.
"""
res = x # Save the input for residual connection
x = self.norm1(x) # Apply first normalization
attn = self.attn.forward_test(x, s_h, s_w) # Compute attention in test mode
x = res + self.drop_path(attn) # Add residual connection and apply drop path
x = x + self.drop_path(self.mlp(self.norm2(x))) # Apply MLP with residual connection
return x
def set_online(self, x, t_h, t_w):
"""
Set the block for online mode.
Parameters:
- x: input tensor.
- t_h, t_w: template height and width.
Returns:
- Output tensor after processing through the block.
"""
res = x # Save the input for residual connection
x = self.norm1(x) # Apply first normalization
attn = self.attn.set_online(x, t_h, t_w) # Compute attention in online mode
x = res + self.drop_path(attn) # Add residual connection and apply drop path
x = x + self.drop_path(self.mlp(self.norm2(x))) # Apply MLP with residual connection
return x
class ConvEmbed(nn.Module):
""" 将图像转换为卷积嵌入
ConvEmbed 类用于将输入图像转换为卷积嵌入向量。这在视觉 Transformer 模型中是常见的,将二维图像块映射到嵌入空间。
"""
def __init__(self,
patch_size=7,
in_chans=3,
embed_dim=64,
stride=4,
padding=2,
norm_layer=None):
"""
初始化 ConvEmbed 类。
参数:
- patch_size: int or tuple, 指定每个卷积过滤器的大小。默认为 7。
- in_chans: int, 输入的通道数(例如,RGB 图像有 3 个通道)。默认为 3。
- embed_dim: int, 输出嵌入特征的维度。默认为 64。
- stride: int, 卷积操作的步幅。默认为 4。
- padding: int, 卷积操作的填充大小。默认为 2。
- norm_layer: callable or None, 用于标准化嵌入特征的层,可以是 nn.LayerNorm 类或其他自定义归一化层。默认为 None。
"""
super().__init__()
# 当 patch_size 为整数时,将其转换为 (patch_size, patch_size) 的元组形式
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size # 储存块大小
# 定义卷积层,将输入通道数映射到指定的嵌入维度
self.proj = nn.Conv2d(
in_chans, # 输入通道数
embed_dim, # 输出通道数,即嵌入维度
kernel_size=patch_size, # 卷积核大小
stride=stride, # 卷积步幅
padding=padding # 填充大小
)
# 如果提供了标准化层,则进行初始化,否则不使用标准化
self.norm = norm_layer(embed_dim) if norm_layer else None
def forward(self, x):
"""
前向传播。
参数:
- x: 输入张量,形状通常为 (batch_size, channels, height, width)
返回:
- 转换后的嵌入张量,通常是与输入相同的形状。
"""
# 通过卷积层获取嵌入特征
x = self.proj(x)
# 获取卷积后的张量的形状,B 是批量大小,C 是通道数,H 和 W 是高和宽
B, C, H, W = x.shape
# 重新排列张量,并将 2D 图像(h, w)展平为 1D 序列(h w)
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
# 如果定义了 norm_layer,则对嵌入特征进行标准化
if self.norm:
x = self.norm(x)
# 将展平的张量转换回其原始形状,在应用标准化后
x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W).contiguous()
return x # 返回转换后的嵌入张量
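A quick sanity check of the patch embedding with illustrative input sizes: a 7x7 convolution with stride 4 and padding 2 maps a 128x128 template to a 32x32 token map and a 320x320 search image to 80x80.

embed = ConvEmbed(patch_size=7, in_chans=3, embed_dim=64, stride=4, padding=2)
print(embed(torch.randn(1, 3, 128, 128)).shape)  # torch.Size([1, 64, 32, 32])
print(embed(torch.randn(1, 3, 320, 320)).shape)  # torch.Size([1, 64, 80, 80])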
class VisionTransformer(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
patch_size=16,
patch_stride=16,
patch_padding=0,
in_chans=3,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
init='trunc_norm',
freeze_bn=False,
**kwargs):
super().__init__()
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.rearrage = None
self.patch_embed = ConvEmbed(
# img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
stride=patch_stride,
padding=patch_padding,
embed_dim=embed_dim,
norm_layer=norm_layer
)
with_cls_token = kwargs['with_cls_token']
if with_cls_token:
self.cls_token = nn.Parameter(
torch.zeros(1, 1, embed_dim)
)
else:
self.cls_token = None
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
blocks = []
for j in range(depth):
blocks.append(
Block(
dim_in=embed_dim,
dim_out=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[j],
act_layer=act_layer,
norm_layer=norm_layer,
freeze_bn=freeze_bn,
**kwargs
)
)
self.blocks = nn.ModuleList(blocks)
if self.cls_token is not None:
trunc_normal_(self.cls_token, std=.02)
if init == 'xavier':
self.apply(self._init_weights_xavier)
else:
self.apply(self._init_weights_trunc_normal)
def _init_weights_trunc_normal(self, m):
if isinstance(m, nn.Linear):
logging.info('=> init weight of Linear from trunc norm')
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
logging.info('=> init bias of Linear to zeros')
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _init_weights_xavier(self, m):
if isinstance(m, nn.Linear):
logging.info('=> init weight of Linear from xavier uniform')
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
logging.info('=> init bias of Linear to zeros')
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, template, online_template, search):
"""
:param template: (batch, c, 128, 128)
:param search: (batch, c, 320, 320)
:return:
"""
# x = self.patch_embed(x)
# B, C, H, W = x.size()
template = self.patch_embed(template)
online_template = self.patch_embed(online_template)
t_B, t_C, t_H, t_W = template.size()
search = self.patch_embed(search)
s_B, s_C, s_H, s_W = search.size()
template = rearrange(template, 'b c h w -> b (h w) c').contiguous()
online_template = rearrange(online_template, 'b c h w -> b (h w) c').contiguous()
search = rearrange(search, 'b c h w -> b (h w) c').contiguous()
x = torch.cat([template, online_template, search], dim=1)
x = self.pos_drop(x)
for i, blk in enumerate(self.blocks):
x = blk(x, t_H, t_W, s_H, s_W)
# if self.cls_token is not None:
# cls_tokens, x = torch.split(x, [1, H*W], 1)
template, online_template, search = torch.split(x, [t_H*t_W, t_H*t_W, s_H*s_W], dim=1)
template = rearrange(template, 'b (h w) c -> b c h w', h=t_H, w=t_W).contiguous()
online_template = rearrange(online_template, 'b (h w) c -> b c h w', h=t_H, w=t_W).contiguous()
search = rearrange(search, 'b (h w) c -> b c h w', h=s_H, w=s_W).contiguous()
return template, online_template, search
def forward_test(self, search):
# x = self.patch_embed(x)
# B, C, H, W = x.size()
search = self.patch_embed(search)
s_B, s_C, s_H, s_W = search.size()
search = rearrange(search, 'b c h w -> b (h w) c').contiguous()
x = search
# x = torch.cat([template, search], dim=1)
x = self.pos_drop(x)
for i, blk in enumerate(self.blocks):
x = blk.forward_test(x, s_H, s_W)
# if self.cls_token is not None:
# cls_tokens, x = torch.split(x, [1, H*W], 1)
# template, search = torch.split(x, [t_H*t_W, s_H*s_W], dim=1)
search = x
search = rearrange(search, 'b (h w) c -> b c h w', h=s_H, w=s_W)
return search
def set_online(self, template, online_template):
template = self.patch_embed(template)
online_template = self.patch_embed(online_template)
t_B, t_C, t_H, t_W = template.size()
template = rearrange(template, 'b c h w -> b (h w) c').contiguous()
online_template = rearrange(online_template, 'b c h w -> (b h w) c').unsqueeze(0).contiguous()
# 1, 1024, c
# 1, b*1024, c
# print(template.shape, online_template.shape)
x = torch.cat([template, online_template], dim=1)
x = self.pos_drop(x)
for i, blk in enumerate(self.blocks):
x = blk.set_online(x, t_H, t_W)
# if self.cls_token is not None:
# cls_tokens, x = torch.split(x, [1, H*W], 1)
template = x[:, :t_H*t_W]
online_template = x[:, t_H*t_W:]
template = rearrange(template, 'b (h w) c -> b c h w', h=t_H, w=t_W)
online_template = rearrange(online_template.squeeze(0), '(b h w) c -> b c h w', h=t_H, w=t_W)
return template, online_template
class ConvolutionalVisionTransformer(nn.Module):
def __init__(self,
in_chans=3,
# num_classes=1000,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
init='trunc_norm',
spec=None):
super().__init__()
# self.num_classes = num_classes
self.num_stages = spec['NUM_STAGES']
for i in range(self.num_stages):
kwargs = {
'patch_size': spec['PATCH_SIZE'][i],
'patch_stride': spec['PATCH_STRIDE'][i],
'patch_padding': spec['PATCH_PADDING'][i],
'embed_dim': spec['DIM_EMBED'][i],
'depth': spec['DEPTH'][i],
'num_heads': spec['NUM_HEADS'][i],
'mlp_ratio': spec['MLP_RATIO'][i],
'qkv_bias': spec['QKV_BIAS'][i],
'drop_rate': spec['DROP_RATE'][i],
'attn_drop_rate': spec['ATTN_DROP_RATE'][i],
'drop_path_rate': spec['DROP_PATH_RATE'][i],
'with_cls_token': spec['CLS_TOKEN'][i],
'method': spec['QKV_PROJ_METHOD'][i],
'kernel_size': spec['KERNEL_QKV'][i],
'padding_q': spec['PADDING_Q'][i],
'padding_kv': spec['PADDING_KV'][i],
'stride_kv': spec['STRIDE_KV'][i],
'stride_q': spec['STRIDE_Q'][i],
'freeze_bn': spec['FREEZE_BN'],
}
stage = VisionTransformer(
in_chans=in_chans,
init=init,
act_layer=act_layer,
norm_layer=norm_layer,
**kwargs
)
setattr(self, f'stage{i}', stage)
in_chans = spec['DIM_EMBED'][i]
dim_embed = spec['DIM_EMBED'][-1]
self.norm = norm_layer(dim_embed)
self.cls_token = spec['CLS_TOKEN'][-1]
# Classifier head
self.head = nn.Linear(dim_embed, 1000)
trunc_normal_(self.head.weight, std=0.02)
def forward(self, template, online_template, search):
"""
:param template: (b, 3, 128, 128)
:param search: (b, 3, 320, 320)
:return:
"""
# template = template + self.template_emb
# search = search + self.search_emb
for i in range(self.num_stages):
template, online_template, search = getattr(self, f'stage{i}')(template, online_template, search)
return template, search
def forward_test(self, search):
for i in range(self.num_stages):
search = getattr(self, f'stage{i}').forward_test(search)
return search
def set_online(self, template, online_template):
for i in range(self.num_stages):
template, online_template = getattr(self, f'stage{i}').set_online(template, online_template)
def get_mixformer_model(config, **kwargs):
msvit_spec = config.MODEL.BACKBONE
msvit = ConvolutionalVisionTransformer(
in_chans=3,
act_layer=QuickGELU,
norm_layer=partial(LayerNorm, eps=1e-5),
init=getattr(msvit_spec, 'INIT', 'trunc_norm'),
spec=msvit_spec
)
if config.MODEL.BACKBONE.PRETRAINED:
try:
ckpt_path = config.MODEL.BACKBONE.PRETRAINED_PATH
ckpt = torch.load(ckpt_path, map_location='cpu')
missing_keys, unexpected_keys = msvit.load_state_dict(ckpt, strict=False)
if is_main_process():
print("Load pretrained backbone checkpoint from:", ckpt_path)
print("missing keys:", missing_keys)
print("unexpected keys:", unexpected_keys)
print("Loading pretrained CVT done.")
except:
print("Warning: Pretrained CVT weights are not loaded")
return msvit
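For completeness, a minimal sketch of running the backbone end to end with random weights. The single-stage spec below is a toy configuration assembled from the spec keys read above; the released MixFormer models use multi-stage CVT configurations with pretrained weights.

toy_spec = {
    'NUM_STAGES': 1,
    'PATCH_SIZE': [7], 'PATCH_STRIDE': [4], 'PATCH_PADDING': [2],
    'DIM_EMBED': [64], 'DEPTH': [1], 'NUM_HEADS': [1], 'MLP_RATIO': [4.0],
    'QKV_BIAS': [True], 'DROP_RATE': [0.0], 'ATTN_DROP_RATE': [0.0],
    'DROP_PATH_RATE': [0.0], 'CLS_TOKEN': [False],
    'QKV_PROJ_METHOD': ['dw_bn'], 'KERNEL_QKV': [3],
    'PADDING_Q': [1], 'PADDING_KV': [1], 'STRIDE_KV': [2], 'STRIDE_Q': [1],
    'FREEZE_BN': False,
}
backbone = ConvolutionalVisionTransformer(in_chans=3, act_layer=QuickGELU,
                                           norm_layer=partial(LayerNorm, eps=1e-5),
                                           spec=toy_spec)
template = torch.randn(1, 3, 128, 128)
online_template = torch.randn(1, 3, 128, 128)
search = torch.randn(1, 3, 320, 320)
template_feat, search_feat = backbone(template, online_template, search)
print(template_feat.shape, search_feat.shape)  # (1, 64, 32, 32) (1, 64, 80, 80)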
The head part provides three types of heads: MLP, Corner, and Pyramid_Corner.
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1,
freeze_bn=False):
if freeze_bn:
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
FrozenBatchNorm2d(out_planes),
nn.ReLU(inplace=True))
else:
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.BatchNorm2d(out_planes),
nn.ReLU(inplace=True))
class Corner_Predictor(nn.Module):
""" Corner Predictor module"""
def __init__(self, inplanes=64, channel=256, feat_sz=20, stride=16, freeze_bn=False):
super(Corner_Predictor, self).__init__()
        # feature map size
        self.feat_sz = feat_sz
        # stride used to map feature-map coordinates back to image coordinates
        self.stride = stride
        # image size, derived from the feature map size and the stride
        self.img_sz = self.feat_sz * self.stride
        '''top-left corner'''
        # convolution layers that predict the top-left corner
        self.conv1_tl = conv(inplanes, channel, freeze_bn=freeze_bn)
        # halve the number of channels
        self.conv2_tl = conv(channel, channel // 2, freeze_bn=freeze_bn)
        # halve again
        self.conv3_tl = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)
        # halve again
        self.conv4_tl = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)
        # single-channel output: the probability map of the corner location
        self.conv5_tl = nn.Conv2d(channel // 8, 1, kernel_size=1)
        '''bottom-right corner'''
        # convolution layers for the bottom-right corner, mirroring the top-left branch
        self.conv1_br = conv(inplanes, channel, freeze_bn=freeze_bn)
        self.conv2_br = conv(channel, channel // 2, freeze_bn=freeze_bn)
        self.conv3_br = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)
        self.conv4_br = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)
        self.conv5_br = nn.Conv2d(channel // 8, 1, kernel_size=1)
        '''about coordinates and indices'''
        with torch.no_grad():
            # indices from 0 to feat_sz scaled by the stride, used to build the coordinate grid
            self.indice = torch.arange(0, self.feat_sz).view(-1, 1) * self.stride
            # coordinate grids
            self.coord_x = self.indice.repeat((self.feat_sz, 1)) \
                .view((self.feat_sz * self.feat_sz,)).float().cuda()
            self.coord_y = self.indice.repeat((1, self.feat_sz)) \
                .view((self.feat_sz * self.feat_sz,)).float().cuda()
def forward(self, x, return_dist=False, softmax=True):
""" Forward pass with input x. """
        # get the score maps for the top-left and bottom-right corners
        score_map_tl, score_map_br = self.get_score_map(x)
        if return_dist:
            # if the probability distributions are requested, return the soft-argmax results together with them
            coorx_tl, coory_tl, prob_vec_tl = self.soft_argmax(score_map_tl, return_dist=True, softmax=softmax)
            coorx_br, coory_br, prob_vec_br = self.soft_argmax(score_map_br, return_dist=True, softmax=softmax)
            return torch.stack((coorx_tl, coory_tl, coorx_br, coory_br), dim=1) / self.img_sz, prob_vec_tl, prob_vec_br
        else:
            # otherwise return only the coordinates
coorx_tl, coory_tl = self.soft_argmax(score_map_tl)
coorx_br, coory_br = self.soft_argmax(score_map_br)
return torch.stack((coorx_tl, coory_tl, coorx_br, coory_br), dim=1) / self.img_sz
def get_score_map(self, x):
# top-left branch
        # pass the features through the successive convolution layers
x_tl1 = self.conv1_tl(x)
x_tl2 = self.conv2_tl(x_tl1)
x_tl3 = self.conv3_tl(x_tl2)
x_tl4 = self.conv4_tl(x_tl3)
score_map_tl = self.conv5_tl(x_tl4)
# bottom-right branch
        # same feature extraction as in the top-left branch
x_br1 = self.conv1_br(x)
x_br2 = self.conv2_br(x_br1)
x_br3 = self.conv3_br(x_br2)
x_br4 = self.conv4_br(x_br3)
score_map_br = self.conv5_br(x_br4)
return score_map_tl, score_map_br
def soft_argmax(self, score_map, return_dist=False, softmax=True):
""" get soft-argmax coordinate for a given heatmap """
        # flatten the score map into a vector
        score_vec = score_map.view((-1, self.feat_sz * self.feat_sz))
        # softmax to obtain a probability distribution over locations
        prob_vec = nn.functional.softmax(score_vec, dim=1)
        # expected coordinates: probability-weighted sum over the coordinate grid
        exp_x = torch.sum((self.coord_x * prob_vec), dim=1)
        exp_y = torch.sum((self.coord_y * prob_vec), dim=1)
        if return_dist:
            # return the distribution vector together with the coordinates
            if softmax:
                return exp_x, exp_y, prob_vec
            else:
                return exp_x, exp_y, score_vec
        else:
            # return only the expected coordinates
return exp_x, exp_y
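A quick usage sketch of the corner head (illustrative sizes; the coordinate grids are created directly on the GPU in __init__, so this assumes a CUDA device): each branch outputs a single-channel score map, and soft-argmax converts the two maps into a normalized (x_tl, y_tl, x_br, y_br) box.

head = Corner_Predictor(inplanes=64, channel=256, feat_sz=20, stride=16).cuda()
feat = torch.randn(1, 64, 20, 20).cuda()   # fused search-region features
box = head(feat)                           # normalized (x_tl, y_tl, x_br, y_br) in [0, 1]
print(box.shape)                           # torch.Size([1, 4])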
class Pyramid_Corner_Predictor(nn.Module):
""" Corner Predictor module"""
def __init__(self, inplanes=64, channel=256, feat_sz=20, stride=16, freeze_bn=False):
super(Pyramid_Corner_Predictor, self).__init__()
        # feature map size
        self.feat_sz = feat_sz
        # stride used to map feature-map coordinates back to image coordinates
        self.stride = stride
        # total image size, derived from the feature map size and the stride
        self.img_sz = self.feat_sz * self.stride
        '''top-left corner'''
        # convolution layers that predict the top-left corner
        self.conv1_tl = conv(inplanes, channel, freeze_bn=freeze_bn)
        self.conv2_tl = conv(channel, channel // 2, freeze_bn=freeze_bn)
        self.conv3_tl = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)
        self.conv4_tl = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)
        self.conv5_tl = nn.Conv2d(channel // 8, 1, kernel_size=1)
        # layers that adjust the channel count of the input features for the lateral connections
        self.adjust1_tl = conv(inplanes, channel // 2, freeze_bn=freeze_bn)
        self.adjust2_tl = conv(inplanes, channel // 4, freeze_bn=freeze_bn)
        # convolution stacks that adjust intermediate feature maps at different resolutions
        self.adjust3_tl = nn.Sequential(conv(channel // 2, channel // 4, freeze_bn=freeze_bn),
                                        conv(channel // 4, channel // 8, freeze_bn=freeze_bn),
                                        conv(channel // 8, 1, freeze_bn=freeze_bn))
        self.adjust4_tl = nn.Sequential(conv(channel // 4, channel // 8, freeze_bn=freeze_bn),
                                        conv(channel // 8, 1, freeze_bn=freeze_bn))
        '''bottom-right corner'''
        # convolution layers for the bottom-right corner, mirroring the top-left branch
        self.conv1_br = conv(inplanes, channel, freeze_bn=freeze_bn)
        self.conv2_br = conv(channel, channel // 2, freeze_bn=freeze_bn)
        self.conv3_br = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)
        self.conv4_br = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)
        self.conv5_br = nn.Conv2d(channel // 8, 1, kernel_size=1)
        self.adjust1_br = conv(inplanes, channel // 2, freeze_bn=freeze_bn)
        self.adjust2_br = conv(inplanes, channel // 4, freeze_bn=freeze_bn)
        self.adjust3_br = nn.Sequential(conv(channel // 2, channel // 4, freeze_bn=freeze_bn),
                                        conv(channel // 4, channel // 8, freeze_bn=freeze_bn),
                                        conv(channel // 8, 1, freeze_bn=freeze_bn))
        self.adjust4_br = nn.Sequential(conv(channel // 4, channel // 8, freeze_bn=freeze_bn),
                                        conv(channel // 8, 1, freeze_bn=freeze_bn))
        '''about coordinates and indices'''
        with torch.no_grad():
            # indices scaled by the stride, used to map feature-map coordinates to image coordinates
            self.indice = torch.arange(0, self.feat_sz).view(-1, 1) * self.stride
            # x coordinates of the grid
            self.coord_x = self.indice.repeat((self.feat_sz, 1)) \
                .view((self.feat_sz * self.feat_sz,)).float().cuda()
            # y coordinates of the grid
            self.coord_y = self.indice.repeat((1, self.feat_sz)) \
                .view((self.feat_sz * self.feat_sz,)).float().cuda()
def forward(self, x, return_dist=False, softmax=True):
""" Forward pass with input x. """
        # compute the score maps of the top-left and bottom-right branches
        score_map_tl, score_map_br = self.get_score_map(x)
        if return_dist:
            # if the distribution vectors are requested, return the soft-argmax results together with them
            coorx_tl, coory_tl, prob_vec_tl = self.soft_argmax(score_map_tl, return_dist=True, softmax=softmax)
            coorx_br, coory_br, prob_vec_br = self.soft_argmax(score_map_br, return_dist=True, softmax=softmax)
            return torch.stack((coorx_tl, coory_tl, coorx_br, coory_br), dim=1) / self.img_sz, prob_vec_tl, prob_vec_br
        else:
            # otherwise return only the coordinates
coorx_tl, coory_tl = self.soft_argmax(score_map_tl)
coorx_br, coory_br = self.soft_argmax(score_map_br)
return torch.stack((coorx_tl, coory_tl, coorx_br, coory_br), dim=1) / self.img_sz
def get_score_map(self, x):
x_init = x
# top-left branch
x_tl1 = self.conv1_tl(x)
x_tl2 = self.conv2_tl(x_tl1)
        # up-1
        # upsample and fuse with the adjusted input features
x_init_up1 = F.interpolate(self.adjust1_tl(x_init), scale_factor=2)
x_up1 = F.interpolate(x_tl2, scale_factor=2)
x_up1 = x_init_up1 + x_up1
x_tl3 = self.conv3_tl(x_up1)
#up-2
x_init_up2 = F.interpolate(self.adjust2_tl(x_init), scale_factor=4)
x_up2 = F.interpolate(x_tl3, scale_factor=2)
x_up2 = x_init_up2 + x_up2
x_tl4 = self.conv4_tl(x_up2)
        # aggregate the multi-level (pyramid) feature maps into the final score map
score_map_tl = self.conv5_tl(x_tl4) + F.interpolate(self.adjust3_tl(x_tl2), scale_factor=4) + F.interpolate(self.adjust4_tl(x_tl3), scale_factor=2)
# bottom-right branch
        # same feature extraction as in the top-left branch
x_br1 = self.conv1_br(x)
x_br2 = self.conv2_br(x_br1)
# up-1
x_init_up1 = F.interpolate(self.adjust1_br(x_init), scale_factor=2)
x_up1 = F.interpolate(x_br2, scale_factor=2)
x_up1 = x_init_up1 + x_up1
x_br3 = self.conv3_br(x_up1)
# up-2
x_init_up2 = F.interpolate(self.adjust2_br(x_init), scale_factor=4)
x_up2 = F.interpolate(x_br3, scale_factor=2)
x_up2 = x_init_up2 + x_up2
x_br4 = self.conv4_br(x_up2)
score_map_br = self.conv5_br(x_br4) + F.interpolate(self.adjust3_br(x_br2), scale_factor=4) + F.interpolate(self.adjust4_br(x_br3), scale_factor=2)
return score_map_tl, score_map_br
def soft_argmax(self, score_map, return_dist=False, softmax=True):
""" get soft-argmax coordinate for a given heatmap """
        # flatten the score map into a vector
        score_vec = score_map.view((-1, self.feat_sz * self.feat_sz))  # (batch, feat_sz * feat_sz)
        # softmax to obtain a probability distribution over locations
        prob_vec = nn.functional.softmax(score_vec, dim=1)
        # probability-weighted x and y coordinates
        exp_x = torch.sum((self.coord_x * prob_vec), dim=1)
        exp_y = torch.sum((self.coord_y * prob_vec), dim=1)
        if return_dist:
            # return the expected coordinates together with the distribution
            if softmax:
                return exp_x, exp_y, prob_vec
            else:
                return exp_x, exp_y, score_vec
        else:
            # return only the expected coordinates
return exp_x, exp_y
class MLP(nn.Module):
""" Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, BN=False):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
if BN:
self.layers = nn.ModuleList(nn.Sequential(nn.Linear(n, k), nn.BatchNorm1d(k))
for n, k in zip([input_dim] + h, h + [output_dim]))
else:
self.layers = nn.ModuleList(nn.Linear(n, k)
for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
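Finally, a small illustrative instantiation of the MLP head (the dimensions here are hypothetical, not a released configuration); such an MLP can serve, for example, as a lightweight box-regression or score head.

# 3-layer MLP: 384 -> 384 -> 384 -> 4, with ReLU between layers (see forward above)
box_mlp = MLP(input_dim=384, hidden_dim=384, output_dim=4, num_layers=3)
tokens = torch.randn(1, 384)
print(box_mlp(tokens).shape)  # torch.Size([1, 4])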