提交 9790cc51 编写于 作者: W weishengyu

add return_dict to trainer

上级 bba9e0df
......@@ -44,6 +44,8 @@ class TheseusLayer(nn.Layer):
if return_layers is not None and re.match(return_pattern, layer_name):
self._sub_layers[layer_i].register_forward_post_hook(
self._save_sub_res_hook)
if isinstance(self._sub_layers[layer_i], TheseusLayer):
self._sub_layers[layer_i]._update_res(return_layers)
def _save_sub_res_hook(self, layer, input, output):
if self.res_dict is not None:
......
......@@ -45,7 +45,7 @@ NET_CONFIG = {
class ConvBlock(TheseusLayer):
def __init__(self, input_channels, output_channels, groups, return_patterns=None):
def __init__(self, input_channels, output_channels, groups):
super().__init__()
self.groups = groups
......@@ -83,7 +83,6 @@ class ConvBlock(TheseusLayer):
self.max_pool = MaxPool2D(kernel_size=2, stride=2, padding=0)
self.relu = nn.ReLU()
self._update_res(return_patterns)
def forward(self, inputs):
x = self.conv1(inputs)
......@@ -117,11 +116,11 @@ class VGGNet(TheseusLayer):
self.stop_grad_layers = stop_grad_layers
self.conv_block_1 = ConvBlock(3, 64, config[0], return_patterns)
self.conv_block_2 = ConvBlock(64, 128, config[1], return_patterns)
self.conv_block_3 = ConvBlock(128, 256, config[2], return_patterns)
self.conv_block_4 = ConvBlock(256, 512, config[3], return_patterns)
self.conv_block_5 = ConvBlock(512, 512, config[4], return_patterns)
self.conv_block_1 = ConvBlock(3, 64, config[0])
self.conv_block_2 = ConvBlock(64, 128, config[1])
self.conv_block_3 = ConvBlock(128, 256, config[2])
self.conv_block_4 = ConvBlock(256, 512, config[3])
self.conv_block_5 = ConvBlock(512, 512, config[4])
self.relu = nn.ReLU()
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
......
......@@ -392,10 +392,14 @@ class Trainer(object):
return eval_result
def forward(self, batch):
if self.return_inter:
return_dict = {}
else:
return_dict = None
if not self.is_rec:
out = self.model(batch[0])
out = self.model(batch[0], return_dict=return_dict)
else:
out = self.model(batch[0], batch[1])
out = self.model(batch[0], batch[1], return_dict=return_dict)
return out
@paddle.no_grad()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册