未验证 提交 333370d9 编写于 作者: W Wenyu 提交者: GitHub

fix bugs (#6240)

上级 1353aa5d
......@@ -32,6 +32,7 @@ from . import esnet
from . import cspresnet
from . import csp_darknet
from . import convnext
from . import vision_transformer
from .vgg import *
from .resnet import *
......@@ -52,4 +53,5 @@ from .hardnet import *
from .esnet import *
from .cspresnet import *
from .csp_darknet import *
from .convnext import *
\ No newline at end of file
from .convnext import *
from .vision_transformer import *
\ No newline at end of file
......@@ -347,6 +347,10 @@ class VisionTransformer(nn.Layer):
self.embed_dim = embed_dim
self.with_fpn = with_fpn
self.use_checkpoint = use_checkpoint
self.use_sincos_pos_emb = use_sincos_pos_emb
self.use_rel_pos_bias = use_rel_pos_bias
self.final_norm = final_norm
if use_checkpoint:
print('please set: FLAGS_allocator_strategy=naive_best_fit')
self.patch_embed = PatchEmbed(
......@@ -369,8 +373,8 @@ class VisionTransformer(nn.Layer):
std=.02))
elif use_sincos_pos_emb:
pos_embed = self.build_2d_sincos_position_embedding(embed_dim)
self.pos_embed = pos_embed
self.pos_embed = pos_embed
self.pos_embed = self.create_parameter(shape=pos_embed.shape)
self.pos_embed.set_value(pos_embed.numpy())
self.pos_embed.stop_gradient = True
......@@ -383,15 +387,9 @@ class VisionTransformer(nn.Layer):
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(
window_size=self.patch_embed.patch_shape, num_heads=num_heads)
elif self.use_sincos_pos_emb:
self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim)
else:
self.rel_pos_bias = None
self.use_rel_pos_bias = use_rel_pos_bias
dpr = np.linspace(0, drop_path_rate, depth)
self.blocks = nn.LayerList([
......@@ -411,17 +409,15 @@ class VisionTransformer(nn.Layer):
epsilon=epsilon) for i in range(depth)
])
self.final_norm = final_norm
######### del by xy
#if self.final_norm:
# self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
self.pretrained = pretrained
self.init_weight()
assert len(out_indices) <= 4, ''
self.out_indices = out_indices
self.out_channels = [embed_dim for _ in range(len(out_indices))]
self.out_strides = [4, 8, 16, 32][-len(out_indices):]
self.out_strides = [4, 8, 16, 32][-len(out_indices):] if with_fpn else [
8 for _ in range(len(out_indices))
]
self.norm = Identity()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册