diff --git a/ppcls/modeling/architectures/resnest.py b/ppcls/modeling/architectures/resnest.py index 769b763787f7468f7392a4ef2e1c7479c5523c76..0875c06a9b948f6597f616a522d34c508bbdb992 100644 --- a/ppcls/modeling/architectures/resnest.py +++ b/ppcls/modeling/architectures/resnest.py @@ -165,14 +165,15 @@ class SplatConv(nn.Layer): atten = self.conv3(gap) atten = self.rsoftmax(atten) - atten = paddle.reshape(x=atten, shape=[-1, atten.shape[1], 1, 1]) if self.radix > 1: attens = paddle.split(atten, num_or_sections=self.radix, axis=1) - y = paddle.add_n( - [split * att for (att, split) in zip(attens, splited)]) + y = paddle.add_n([ + paddle.multiply(split, att) + for (att, split) in zip(attens, splited) + ]) else: - y = x * atten + y = paddle.multiply(x, atten) return y