未验证 提交 223a0d2e 编写于 作者: L littletomatodonkey 提交者: GitHub

fix vgg stop grad (#558)

* fix vgg stop grad

* beautify code
上级 bbc6649d
......@@ -68,10 +68,11 @@ class ConvBlock(nn.Layer):
class VGGNet(nn.Layer):
def __init__(self, layers=11, class_dim=1000):
def __init__(self, layers=11, stop_grad_layers=0, class_dim=1000):
super(VGGNet, self).__init__()
self.layers = layers
self.stop_grad_layers = stop_grad_layers
self.vgg_configure = {
11: [1, 1, 2, 2, 2],
13: [2, 2, 2, 2, 2],
......@@ -89,6 +90,14 @@ class VGGNet(nn.Layer):
self._conv_block_4 = ConvBlock(256, 512, self.groups[3], name="conv4_")
self._conv_block_5 = ConvBlock(512, 512, self.groups[4], name="conv5_")
for idx, block in enumerate([
self._conv_block_1, self._conv_block_2, self._conv_block_3,
self._conv_block_4, self._conv_block_5
]):
if self.stop_grad_layers >= idx + 1:
for param in block.parameters():
param.trainable = False
self._drop = Dropout(p=0.5, mode="downscale_in_infer")
self._fc1 = Linear(
7 * 7 * 512,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册