Unverified commit 333370d9, authored by W Wenyu, committed by GitHub

fix bugs (#6240)

Parent 1353aa5d
@@ -32,6 +32,7 @@ from . import esnet
 from . import cspresnet
 from . import csp_darknet
 from . import convnext
+from . import vision_transformer
 
 from .vgg import *
 from .resnet import *
@@ -52,4 +53,5 @@ from .hardnet import *
 from .esnet import *
 from .cspresnet import *
 from .csp_darknet import *
-from .convnext import *
\ No newline at end of file
+from .convnext import *
+from .vision_transformer import *
\ No newline at end of file
@@ -347,6 +347,10 @@ class VisionTransformer(nn.Layer):
         self.embed_dim = embed_dim
         self.with_fpn = with_fpn
         self.use_checkpoint = use_checkpoint
+        self.use_sincos_pos_emb = use_sincos_pos_emb
+        self.use_rel_pos_bias = use_rel_pos_bias
+        self.final_norm = final_norm
+
         if use_checkpoint:
             print('please set: FLAGS_allocator_strategy=naive_best_fit')
         self.patch_embed = PatchEmbed(
@@ -369,8 +373,8 @@ class VisionTransformer(nn.Layer):
                     std=.02))
         elif use_sincos_pos_emb:
             pos_embed = self.build_2d_sincos_position_embedding(embed_dim)
-            self.pos_embed = pos_embed
 
+            self.pos_embed = pos_embed
             self.pos_embed = self.create_parameter(shape=pos_embed.shape)
             self.pos_embed.set_value(pos_embed.numpy())
             self.pos_embed.stop_gradient = True
@@ -383,15 +387,9 @@ class VisionTransformer(nn.Layer):
         if use_shared_rel_pos_bias:
             self.rel_pos_bias = RelativePositionBias(
                 window_size=self.patch_embed.patch_shape, num_heads=num_heads)
-        elif self.use_sincos_pos_emb:
-            self.pos_embed = self.build_2d_sincos_position_embedding(embed_dim)
         else:
             self.rel_pos_bias = None
-        self.use_rel_pos_bias = use_rel_pos_bias
         dpr = np.linspace(0, drop_path_rate, depth)
         self.blocks = nn.LayerList([
@@ -411,17 +409,15 @@ class VisionTransformer(nn.Layer):
                 epsilon=epsilon) for i in range(depth)
         ])
-        self.final_norm = final_norm
-        ######### del by xy
-        #if self.final_norm:
-        #    self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
         self.pretrained = pretrained
         self.init_weight()
         assert len(out_indices) <= 4, ''
         self.out_indices = out_indices
         self.out_channels = [embed_dim for _ in range(len(out_indices))]
-        self.out_strides = [4, 8, 16, 32][-len(out_indices):]
+        self.out_strides = [4, 8, 16, 32][-len(out_indices):] if with_fpn else [
+            8 for _ in range(len(out_indices))
+        ]
         self.norm = Identity()
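Note on the out_strides change in the last hunk: the reported feature strides now depend on with_fpn. With an FPN attached, the backbone reports the last len(out_indices) entries of [4, 8, 16, 32]; without one, every selected block is reported at the single ViT feature stride (8 in this code). A small, self-contained check of the new expression; the helper name and the out_indices values below are made up for illustration only:

def pick_out_strides(out_indices, with_fpn):
    # Same expression as the line added in the diff above.
    return [4, 8, 16, 32][-len(out_indices):] if with_fpn else [
        8 for _ in range(len(out_indices))
    ]

print(pick_out_strides([3, 5, 7, 11], with_fpn=True))   # [4, 8, 16, 32]
print(pick_out_strides([3, 5, 7, 11], with_fpn=False))  # [8, 8, 8, 8]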
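The earlier hunks tidy up the positional-embedding setup: the table returned by build_2d_sincos_position_embedding is copied into a non-trainable parameter (create_parameter, set_value, stop_gradient = True), and the stray elif self.use_sincos_pos_emb branch is removed from the relative-position-bias block, so that block now only decides between a shared RelativePositionBias and self.rel_pos_bias = None. The body of build_2d_sincos_position_embedding is not part of this diff; the following is only a minimal NumPy sketch of how such a 2-D sin-cos table is commonly precomputed, with an illustrative function name and grid size, not the repository's implementation:

import numpy as np

def sincos_pos_embed_2d(h, w, embed_dim, temperature=10000.):
    # Illustrative only: build a (1, h*w, embed_dim) sin-cos table for an
    # h x w grid of patch tokens.
    assert embed_dim % 4 == 0, 'embed_dim must be divisible by 4'
    grid_w, grid_h = np.meshgrid(
        np.arange(w, dtype=np.float32), np.arange(h, dtype=np.float32))
    pos_dim = embed_dim // 4
    omega = 1. / temperature**(np.arange(pos_dim, dtype=np.float32) / pos_dim)
    out_w = grid_w.reshape(-1)[:, None] * omega[None, :]   # (h*w, pos_dim)
    out_h = grid_h.reshape(-1)[:, None] * omega[None, :]
    return np.concatenate(
        [np.sin(out_w), np.cos(out_w), np.sin(out_h), np.cos(out_h)],
        axis=1)[None, :, :]

print(sincos_pos_embed_2d(14, 14, 768).shape)  # (1, 196, 768)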