未验证 提交 1f34af72 编写于 作者: W Wenyu 提交者: GitHub

Freeze reset blocks (#3441)

* modified freeze at in resnet

* freeze_at should less than num_stage
上级 6063d163
...@@ -560,6 +560,15 @@ class ResNet(nn.Layer): ...@@ -560,6 +560,15 @@ class ResNet(nn.Layer):
self.res_layers.append(res_layer) self.res_layers.append(res_layer)
self.ch_in = self._out_channels[i] self.ch_in = self._out_channels[i]
if freeze_at >= 0:
self._freeze_parameters(self.conv1)
for i in range(min(freeze_at + 1, num_stages)):
self._freeze_parameters(self.res_layers[i])
def _freeze_parameters(self, m):
for p in m.parameters():
p.stop_gradient = True
@property @property
def out_shape(self): def out_shape(self):
return [ return [
...@@ -575,8 +584,6 @@ class ResNet(nn.Layer): ...@@ -575,8 +584,6 @@ class ResNet(nn.Layer):
outs = [] outs = []
for idx, stage in enumerate(self.res_layers): for idx, stage in enumerate(self.res_layers):
x = stage(x) x = stage(x)
if idx == self.freeze_at:
x.stop_gradient = True
if idx in self.return_idx: if idx in self.return_idx:
outs.append(x) outs.append(x)
return outs return outs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册