提交 9790cc51 编写于 作者: W weishengyu

add return_dict to trainer

上级 bba9e0df
...@@ -44,6 +44,8 @@ class TheseusLayer(nn.Layer): ...@@ -44,6 +44,8 @@ class TheseusLayer(nn.Layer):
if return_layers is not None and re.match(return_pattern, layer_name): if return_layers is not None and re.match(return_pattern, layer_name):
self._sub_layers[layer_i].register_forward_post_hook( self._sub_layers[layer_i].register_forward_post_hook(
self._save_sub_res_hook) self._save_sub_res_hook)
if isinstance(self._sub_layers[layer_i], TheseusLayer):
self._sub_layers[layer_i]._update_res(return_layers)
def _save_sub_res_hook(self, layer, input, output): def _save_sub_res_hook(self, layer, input, output):
if self.res_dict is not None: if self.res_dict is not None:
......
...@@ -45,7 +45,7 @@ NET_CONFIG = { ...@@ -45,7 +45,7 @@ NET_CONFIG = {
class ConvBlock(TheseusLayer): class ConvBlock(TheseusLayer):
def __init__(self, input_channels, output_channels, groups, return_patterns=None): def __init__(self, input_channels, output_channels, groups):
super().__init__() super().__init__()
self.groups = groups self.groups = groups
...@@ -83,7 +83,6 @@ class ConvBlock(TheseusLayer): ...@@ -83,7 +83,6 @@ class ConvBlock(TheseusLayer):
self.max_pool = MaxPool2D(kernel_size=2, stride=2, padding=0) self.max_pool = MaxPool2D(kernel_size=2, stride=2, padding=0)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self._update_res(return_patterns)
def forward(self, inputs): def forward(self, inputs):
x = self.conv1(inputs) x = self.conv1(inputs)
...@@ -117,11 +116,11 @@ class VGGNet(TheseusLayer): ...@@ -117,11 +116,11 @@ class VGGNet(TheseusLayer):
self.stop_grad_layers = stop_grad_layers self.stop_grad_layers = stop_grad_layers
self.conv_block_1 = ConvBlock(3, 64, config[0], return_patterns) self.conv_block_1 = ConvBlock(3, 64, config[0])
self.conv_block_2 = ConvBlock(64, 128, config[1], return_patterns) self.conv_block_2 = ConvBlock(64, 128, config[1])
self.conv_block_3 = ConvBlock(128, 256, config[2], return_patterns) self.conv_block_3 = ConvBlock(128, 256, config[2])
self.conv_block_4 = ConvBlock(256, 512, config[3], return_patterns) self.conv_block_4 = ConvBlock(256, 512, config[3])
self.conv_block_5 = ConvBlock(512, 512, config[4], return_patterns) self.conv_block_5 = ConvBlock(512, 512, config[4])
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1) self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
......
...@@ -392,10 +392,14 @@ class Trainer(object): ...@@ -392,10 +392,14 @@ class Trainer(object):
return eval_result return eval_result
def forward(self, batch):
    """Run the model on one batch.

    When ``self.return_inter`` is set, an empty dict is passed as
    ``return_dict`` so the model can populate it with intermediate
    layer outputs (hooked via TheseusLayer); otherwise ``None`` is
    passed and no intermediates are collected.

    Args:
        batch: sequence whose first element is the input tensor; for
            recognition models (``self.is_rec``) the second element
            (labels/targets) is forwarded as well.

    Returns:
        Whatever ``self.model`` returns for this batch.
    """
    return_dict = {} if self.return_inter else None
    if self.is_rec:
        return self.model(batch[0], batch[1], return_dict=return_dict)
    return self.model(batch[0], return_dict=return_dict)
@paddle.no_grad() @paddle.no_grad()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册