未验证 提交 8d16077e 编写于 作者: L littletomatodonkey 提交者: GitHub

fix resnest (#491)

上级 dc124ecb
...@@ -165,14 +165,15 @@ class SplatConv(nn.Layer): ...@@ -165,14 +165,15 @@ class SplatConv(nn.Layer):
atten = self.conv3(gap) atten = self.conv3(gap)
atten = self.rsoftmax(atten) atten = self.rsoftmax(atten)
atten = paddle.reshape(x=atten, shape=[-1, atten.shape[1], 1, 1])
if self.radix > 1: if self.radix > 1:
attens = paddle.split(atten, num_or_sections=self.radix, axis=1) attens = paddle.split(atten, num_or_sections=self.radix, axis=1)
y = paddle.add_n( y = paddle.add_n([
[split * att for (att, split) in zip(attens, splited)]) paddle.multiply(split, att)
for (att, split) in zip(attens, splited)
])
else: else:
y = x * atten y = paddle.multiply(x, atten)
return y return y
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册