diff --git a/ppdet/modeling/reid/resnet.py b/ppdet/modeling/reid/resnet.py index 968fe9774f116c846cd372c7086dc9671d135b7c..2e2a85558d69cecb307df1f1098ec0bdd70a93e2 100644 --- a/ppdet/modeling/reid/resnet.py +++ b/ppdet/modeling/reid/resnet.py @@ -55,12 +55,14 @@ class ConvBNLayer(nn.Layer): bias_attr=False, data_format=data_format) - self._batch_norm = nn.BatchNorm( - num_filters, act=act, data_layout=data_format) + self._batch_norm = nn.BatchNorm2D(num_filters) + self.act = act def forward(self, inputs): y = self._conv(inputs) y = self._batch_norm(y) + if self.act: + y = getattr(F, self.act)(y) return y