提交 bba9e0df 编写于 作者: W weishengyu

add vgg theseus funcs

上级 8c1515b5
...@@ -12,12 +12,10 @@ class Identity(nn.Layer): ...@@ -12,12 +12,10 @@ class Identity(nn.Layer):
class TheseusLayer(nn.Layer): class TheseusLayer(nn.Layer):
def __init__(self, *args, return_patterns=None, **kwargs): def __init__(self, *args, **kwargs):
super(TheseusLayer, self).__init__() super(TheseusLayer, self).__init__()
self.res_dict = None self.res_dict = None
self.register_forward_post_hook(self._disconnect_res_dict_hook) self.register_forward_post_hook(self._disconnect_res_dict_hook)
if return_patterns is not None:
self._update_res(return_patterns)
def forward(self, *input, res_dict=None, **kwargs): def forward(self, *input, res_dict=None, **kwargs):
if res_dict is not None: if res_dict is not None:
...@@ -51,7 +49,7 @@ class TheseusLayer(nn.Layer): ...@@ -51,7 +49,7 @@ class TheseusLayer(nn.Layer):
if self.res_dict is not None: if self.res_dict is not None:
self.res_dict[layer.full_name()] = output self.res_dict[layer.full_name()] = output
def _disconnect_res_dict_hook(self, input, output): def _disconnect_res_dict_hook(self, *args, **kwargs):
self.res_dict = None self.res_dict = None
def replace_sub(self, layer_name_pattern, replace_function, recursive=True): def replace_sub(self, layer_name_pattern, replace_function, recursive=True):
......
...@@ -45,7 +45,7 @@ NET_CONFIG = { ...@@ -45,7 +45,7 @@ NET_CONFIG = {
class ConvBlock(TheseusLayer): class ConvBlock(TheseusLayer):
def __init__(self, input_channels, output_channels, groups): def __init__(self, input_channels, output_channels, groups, return_patterns=None):
super().__init__() super().__init__()
self.groups = groups self.groups = groups
...@@ -83,6 +83,7 @@ class ConvBlock(TheseusLayer): ...@@ -83,6 +83,7 @@ class ConvBlock(TheseusLayer):
self.max_pool = MaxPool2D(kernel_size=2, stride=2, padding=0) self.max_pool = MaxPool2D(kernel_size=2, stride=2, padding=0)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self._update_res(return_patterns)
def forward(self, inputs): def forward(self, inputs):
x = self.conv1(inputs) x = self.conv1(inputs)
...@@ -111,16 +112,16 @@ class VGGNet(TheseusLayer): ...@@ -111,16 +112,16 @@ class VGGNet(TheseusLayer):
model: nn.Layer. Specific VGG model depends on args. model: nn.Layer. Specific VGG model depends on args.
""" """
def __init__(self, config, stop_grad_layers=0, class_num=1000): def __init__(self, config, stop_grad_layers=0, class_num=1000, return_patterns=None):
super().__init__() super().__init__()
self.stop_grad_layers = stop_grad_layers self.stop_grad_layers = stop_grad_layers
self.conv_block_1 = ConvBlock(3, 64, config[0]) self.conv_block_1 = ConvBlock(3, 64, config[0], return_patterns)
self.conv_block_2 = ConvBlock(64, 128, config[1]) self.conv_block_2 = ConvBlock(64, 128, config[1], return_patterns)
self.conv_block_3 = ConvBlock(128, 256, config[2]) self.conv_block_3 = ConvBlock(128, 256, config[2], return_patterns)
self.conv_block_4 = ConvBlock(256, 512, config[3]) self.conv_block_4 = ConvBlock(256, 512, config[3], return_patterns)
self.conv_block_5 = ConvBlock(512, 512, config[4]) self.conv_block_5 = ConvBlock(512, 512, config[4], return_patterns)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1) self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
...@@ -137,8 +138,9 @@ class VGGNet(TheseusLayer): ...@@ -137,8 +138,9 @@ class VGGNet(TheseusLayer):
self.fc1 = Linear(7 * 7 * 512, 4096) self.fc1 = Linear(7 * 7 * 512, 4096)
self.fc2 = Linear(4096, 4096) self.fc2 = Linear(4096, 4096)
self.fc3 = Linear(4096, class_num) self.fc3 = Linear(4096, class_num)
self._update_res(return_patterns)
def forward(self, inputs): def forward(self, inputs, res_dict=None):
x = self.conv_block_1(inputs) x = self.conv_block_1(inputs)
x = self.conv_block_2(x) x = self.conv_block_2(x)
x = self.conv_block_3(x) x = self.conv_block_3(x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册