diff --git a/ppdet/modeling/backbones/resnet.py b/ppdet/modeling/backbones/resnet.py index 6be2fc6e16cfa94695aad596ffa2ecfd0705f7b8..c0f2a6b6c9195e14d08fbedfbf0eb657f86a1681 100755 --- a/ppdet/modeling/backbones/resnet.py +++ b/ppdet/modeling/backbones/resnet.py @@ -560,6 +560,15 @@ class ResNet(nn.Layer): self.res_layers.append(res_layer) self.ch_in = self._out_channels[i] + if freeze_at >= 0: + self._freeze_parameters(self.conv1) + for i in range(min(freeze_at + 1, num_stages)): + self._freeze_parameters(self.res_layers[i]) + + def _freeze_parameters(self, m): + for p in m.parameters(): + p.stop_gradient = True + @property def out_shape(self): return [ @@ -575,8 +584,6 @@ class ResNet(nn.Layer): outs = [] for idx, stage in enumerate(self.res_layers): x = stage(x) - if idx == self.freeze_at: - x.stop_gradient = True if idx in self.return_idx: outs.append(x) return outs