提交 bba9e0df 编写于 作者: W weishengyu

add vgg theseus funcs

上级 8c1515b5
......@@ -12,12 +12,10 @@ class Identity(nn.Layer):
class TheseusLayer(nn.Layer):
def __init__(self, *args, return_patterns=None, **kwargs):
def __init__(self, *args, **kwargs):
super(TheseusLayer, self).__init__()
self.res_dict = None
self.register_forward_post_hook(self._disconnect_res_dict_hook)
if return_patterns is not None:
self._update_res(return_patterns)
def forward(self, *input, res_dict=None, **kwargs):
if res_dict is not None:
......@@ -51,7 +49,7 @@ class TheseusLayer(nn.Layer):
if self.res_dict is not None:
self.res_dict[layer.full_name()] = output
def _disconnect_res_dict_hook(self, input, output):
def _disconnect_res_dict_hook(self, *args, **kwargs):
self.res_dict = None
def replace_sub(self, layer_name_pattern, replace_function, recursive=True):
......
......@@ -45,7 +45,7 @@ NET_CONFIG = {
class ConvBlock(TheseusLayer):
def __init__(self, input_channels, output_channels, groups):
def __init__(self, input_channels, output_channels, groups, return_patterns=None):
super().__init__()
self.groups = groups
......@@ -83,6 +83,7 @@ class ConvBlock(TheseusLayer):
self.max_pool = MaxPool2D(kernel_size=2, stride=2, padding=0)
self.relu = nn.ReLU()
self._update_res(return_patterns)
def forward(self, inputs):
x = self.conv1(inputs)
......@@ -111,16 +112,16 @@ class VGGNet(TheseusLayer):
model: nn.Layer. Specific VGG model depends on args.
"""
def __init__(self, config, stop_grad_layers=0, class_num=1000):
def __init__(self, config, stop_grad_layers=0, class_num=1000, return_patterns=None):
super().__init__()
self.stop_grad_layers = stop_grad_layers
self.conv_block_1 = ConvBlock(3, 64, config[0])
self.conv_block_2 = ConvBlock(64, 128, config[1])
self.conv_block_3 = ConvBlock(128, 256, config[2])
self.conv_block_4 = ConvBlock(256, 512, config[3])
self.conv_block_5 = ConvBlock(512, 512, config[4])
self.conv_block_1 = ConvBlock(3, 64, config[0], return_patterns)
self.conv_block_2 = ConvBlock(64, 128, config[1], return_patterns)
self.conv_block_3 = ConvBlock(128, 256, config[2], return_patterns)
self.conv_block_4 = ConvBlock(256, 512, config[3], return_patterns)
self.conv_block_5 = ConvBlock(512, 512, config[4], return_patterns)
self.relu = nn.ReLU()
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
......@@ -137,8 +138,9 @@ class VGGNet(TheseusLayer):
self.fc1 = Linear(7 * 7 * 512, 4096)
self.fc2 = Linear(4096, 4096)
self.fc3 = Linear(4096, class_num)
self._update_res(return_patterns)
def forward(self, inputs):
def forward(self, inputs, res_dict=None):
x = self.conv_block_1(inputs)
x = self.conv_block_2(x)
x = self.conv_block_3(x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册