diff --git a/ppcls/arch/backbone/model_zoo/pvt_v2.py b/ppcls/arch/backbone/model_zoo/pvt_v2.py index 94175754fe40ae1a4f2cc40e4d291b71a8d17fe8..f435e87564cfdb6ccf32cd37bd0e4b5318484179 100644 --- a/ppcls/arch/backbone/model_zoo/pvt_v2.py +++ b/ppcls/arch/backbone/model_zoo/pvt_v2.py @@ -147,7 +147,8 @@ class Attention(nn.Layer): ]).transpose([2, 0, 3, 1, 4]) else: x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W]) - x_ = self.sr(self.pool(x_)).reshape([B, C, -1]).transpose( + x_ = self.sr(self.pool(x_)) + x_ = x_.reshape([B, C, x_.shape[2] * x_.shape[3]]).transpose( [0, 2, 1]) x_ = self.norm(x_) x_ = self.act(x_)