未验证 提交 b10ef7d9 编写于 作者: W wjm 提交者: GitHub

[cherry-pick] fix recursive call of DLA (#6771) (#6786)

上级 2382374e
...@@ -164,11 +164,11 @@ class IDAUp(nn.Layer): ...@@ -164,11 +164,11 @@ class IDAUp(nn.Layer):
for i in range(start_level + 1, end_level): for i in range(start_level + 1, end_level):
upsample = getattr(self, 'up_' + str(i - start_level)) upsample = getattr(self, 'up_' + str(i - start_level))
project = getattr(self, 'proj_' + str(i - start_level)) project = getattr(self, 'proj_' + str(i - start_level))
inputs[i] = project(inputs[i]) inputs[i] = project(inputs[i])
inputs[i] = upsample(inputs[i]) inputs[i] = upsample(inputs[i])
node = getattr(self, 'node_' + str(i - start_level)) node = getattr(self, 'node_' + str(i - start_level))
inputs[i] = node(paddle.add(inputs[i], inputs[i - 1])) inputs[i] = node(paddle.add(inputs[i], inputs[i - 1]))
return inputs
class DLAUp(nn.Layer): class DLAUp(nn.Layer):
...@@ -197,8 +197,8 @@ class DLAUp(nn.Layer): ...@@ -197,8 +197,8 @@ class DLAUp(nn.Layer):
out = [inputs[-1]] # start with 32 out = [inputs[-1]] # start with 32
for i in range(len(inputs) - self.start_level - 1): for i in range(len(inputs) - self.start_level - 1):
ida = getattr(self, 'ida_{}'.format(i)) ida = getattr(self, 'ida_{}'.format(i))
ida(inputs, len(inputs) - i - 2, len(inputs)) outputs = ida(inputs, len(inputs) - i - 2, len(inputs))
out.insert(0, inputs[-1]) out.insert(0, outputs[-1])
return out return out
...@@ -259,7 +259,9 @@ class CenterNetDLAFPN(nn.Layer): ...@@ -259,7 +259,9 @@ class CenterNetDLAFPN(nn.Layer):
def forward(self, body_feats): def forward(self, body_feats):
dla_up_feats = self.dla_up(body_feats) inputs = [body_feats[i] for i in range(len(body_feats))]
dla_up_feats = self.dla_up(inputs)
ida_up_feats = [] ida_up_feats = []
for i in range(self.last_level - self.first_level): for i in range(self.last_level - self.first_level):
...@@ -271,7 +273,11 @@ class CenterNetDLAFPN(nn.Layer): ...@@ -271,7 +273,11 @@ class CenterNetDLAFPN(nn.Layer):
if self.with_sge: if self.with_sge:
feat = self.sge_attention(feat) feat = self.sge_attention(feat)
if self.down_ratio != 4: if self.down_ratio != 4:
feat = F.interpolate(feat, scale_factor=self.down_ratio // 4, mode="bilinear", align_corners=True) feat = F.interpolate(
feat,
scale_factor=self.down_ratio // 4,
mode="bilinear",
align_corners=True)
return feat return feat
@property @property
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册