提交 6ebe7f09 编写于 作者: W weishengyu

add return_inter flag

上级 9790cc51
...@@ -37,15 +37,15 @@ class TheseusLayer(nn.Layer): ...@@ -37,15 +37,15 @@ class TheseusLayer(nn.Layer):
stop_layer_name) stop_layer_name)
return after_stop return after_stop
def _update_res(self, return_layers): def _update_res(self, return_patterns):
for layer_i in self._sub_layers: for layer_i in self._sub_layers:
layer_name = self._sub_layers[layer_i].full_name() layer_name = self._sub_layers[layer_i].full_name()
for return_pattern in return_layers: for return_pattern in return_patterns:
if return_layers is not None and re.match(return_pattern, layer_name): if return_patterns is not None and re.match(return_pattern, layer_name):
self._sub_layers[layer_i].register_forward_post_hook( self._sub_layers[layer_i].register_forward_post_hook(
self._save_sub_res_hook) self._save_sub_res_hook)
if isinstance(self._sub_layers[layer_i], TheseusLayer): if isinstance(self._sub_layers[layer_i], TheseusLayer):
self._sub_layers[layer_i]._update_res(return_layers) self._sub_layers[layer_i]._update_res(return_patterns)
def _save_sub_res_hook(self, layer, input, output): def _save_sub_res_hook(self, layer, input, output):
if self.res_dict is not None: if self.res_dict is not None:
......
...@@ -75,6 +75,8 @@ class Trainer(object): ...@@ -75,6 +75,8 @@ class Trainer(object):
self.is_rec = False self.is_rec = False
self.model = build_model(self.config["Arch"]) self.model = build_model(self.config["Arch"])
if "return_pattern" in self.config["Arch"]:
self.return_inter = True
# set @to_static for benchmark, skip this by default. # set @to_static for benchmark, skip this by default.
apply_to_static(self.config, self.model) apply_to_static(self.config, self.model)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册