Unverified commit a27df36d, authored by W Wenyu, committed by GitHub

fix arange to static for inference (#7279)

Parent e639b354
@@ -184,7 +184,10 @@ class TransformerEncoder(nn.Layer):
 @register
 @serializable
 class CustomCSPPAN(nn.Layer):
-    __shared__ = ['norm_type', 'data_format', 'width_mult', 'depth_mult', 'trt']
+    __shared__ = [
+        'norm_type', 'data_format', 'width_mult', 'depth_mult', 'trt',
+        'eval_size'
+    ]

     def __init__(self,
                  in_channels=[256, 512, 1024],
@@ -212,7 +215,8 @@ class CustomCSPPAN(nn.Layer):
                  attn_dropout=None,
                  act_dropout=None,
                  normalize_before=False,
-                 use_trans=False):
+                 use_trans=False,
+                 eval_size=None):
         super(CustomCSPPAN, self).__init__()
         out_channels = [max(round(c * width_mult), 1) for c in out_channels]
@@ -223,19 +227,29 @@ class CustomCSPPAN(nn.Layer):
         self.num_blocks = len(in_channels)
         self.data_format = data_format
         self._out_channels = out_channels
         self.hidden_dim = in_channels[-1]
         in_channels = in_channels[::-1]
-        self.nhead = nhead
-        self.num_layers = num_layers
         self.use_trans = use_trans
+        self.eval_size = eval_size

         if use_trans:
+            if eval_size is not None:
+                self.pos_embed = self.build_2d_sincos_position_embedding(
+                    eval_size[1] // 32,
+                    eval_size[0] // 32,
+                    embed_dim=self.hidden_dim)
+            else:
+                self.pos_embed = None
+
             encoder_layer = TransformerEncoderLayer(
                 self.hidden_dim, nhead, dim_feedforward, dropout, activation,
                 attn_dropout, act_dropout, normalize_before)
             encoder_norm = nn.LayerNorm(
                 self.hidden_dim) if normalize_before else None
-            self.encoder = TransformerEncoder(encoder_layer, self.num_layers,
+            self.encoder = TransformerEncoder(encoder_layer, num_layers,
                                               encoder_norm)

         fpn_stages = []
         fpn_routes = []
         for i, (ch_in, ch_out) in enumerate(zip(in_channels, out_channels)):
@@ -340,8 +354,12 @@ class CustomCSPPAN(nn.Layer):
             # flatten [B, C, H, W] to [B, HxW, C]
             src_flatten = last_feat.flatten(2).transpose([0, 2, 1])
-            pos_embed = self.build_2d_sincos_position_embedding(
-                w=w, h=h, embed_dim=self.hidden_dim)
+            if self.eval_size is not None:
+                pos_embed = self.pos_embed
+            else:
+                pos_embed = self.build_2d_sincos_position_embedding(
+                    w=w, h=h, embed_dim=self.hidden_dim)
             memory = self.encoder(src_flatten, pos_embed=pos_embed)
             last_feat_encode = memory.transpose([0, 2, 1]).reshape([n, c, h, w])
             blocks[-1] = last_feat_encode
...
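
Note on the change: build_2d_sincos_position_embedding constructs the position embedding with paddle.arange over the feature-map width and height. Calling it inside forward() on every step traces those arange ops into the exported static graph, where the spatial size is no longer a plain Python constant; this appears to be what the commit title "fix arange to static for inference" refers to. When eval_size (the fixed inference input size, given as (h, w)) is known, the embedding for the stride-32 feature map (eval_size[i] // 32) can be built once in __init__ and reused as a constant tensor at inference. The sketch below illustrates the pattern; it assumes the standard 2D sin-cos formulation, and the helper body is a reimplementation for illustration, not necessarily the exact PaddleDetection code.

    import paddle

    def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.):
        """Assumed standard 2D sin-cos position embedding, shape [1, w*h, embed_dim]."""
        # These arange calls are what the commit moves out of forward() into __init__.
        grid_w = paddle.arange(int(w), dtype=paddle.float32)
        grid_h = paddle.arange(int(h), dtype=paddle.float32)
        grid_w, grid_h = paddle.meshgrid(grid_w, grid_h)
        assert embed_dim % 4 == 0, 'embed_dim must be divisible by 4'
        pos_dim = embed_dim // 4
        omega = paddle.arange(pos_dim, dtype=paddle.float32) / pos_dim
        omega = 1. / (temperature**omega)
        out_w = grid_w.flatten()[..., None] @ omega[None]
        out_h = grid_h.flatten()[..., None] @ omega[None]
        return paddle.concat(
            [paddle.sin(out_w), paddle.cos(out_w),
             paddle.sin(out_h), paddle.cos(out_h)], axis=1)[None, :, :]

    # With a fixed inference input size, the embedding for the stride-32 feature map
    # is built once here in eager mode and reused, instead of re-running arange in
    # every traced forward pass. eval_size below is a hypothetical example value.
    eval_size = (640, 640)  # (h, w) of the network input
    pos_embed = build_2d_sincos_position_embedding(
        eval_size[1] // 32, eval_size[0] // 32, embed_dim=256)

At training time, or when eval_size is not set, the embedding is still computed on the fly from the actual feature-map size, as in the else branch of the forward() change above.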