Unverified commit d98b8816, authored by Wei Shengyu, committed by GitHub

Merge pull request #1208 from weisy11/fix_bug_of_theseus

dbg theseus
@@ -23,6 +23,7 @@ from . import backbone, gears
 from .backbone import *
 from .gears import build_gear
 from .utils import *
+from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
 from ppcls.utils import logger
 from ppcls.utils.save_load import load_dygraph_pretrain
...
@@ -38,24 +38,31 @@ class TheseusLayer(nn.Layer):
         for layer_i in self._sub_layers:
             layer_name = self._sub_layers[layer_i].full_name()
             if isinstance(self._sub_layers[layer_i], (nn.Sequential, nn.LayerList)):
-                self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i])
-                self._sub_layers[layer_i].res_dict = self.res_dict
+                self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i], self.res_dict)
                 self._sub_layers[layer_i].update_res(return_patterns)
             else:
                 for return_pattern in return_patterns:
                     if re.match(return_pattern, layer_name):
                         if not isinstance(self._sub_layers[layer_i], TheseusLayer):
-                            self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i])
+                            self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i], self.res_dict)
+                        else:
+                            self._sub_layers[layer_i].res_dict = self.res_dict
                         self._sub_layers[layer_i].register_forward_post_hook(
                             self._sub_layers[layer_i]._save_sub_res_hook)
-                        self._sub_layers[layer_i].res_dict = self.res_dict
             if isinstance(self._sub_layers[layer_i], TheseusLayer):
                 self._sub_layers[layer_i].res_dict = self.res_dict
                 self._sub_layers[layer_i].update_res(return_patterns)

     def _save_sub_res_hook(self, layer, input, output):
         self.res_dict[layer.full_name()] = output

+    def _return_dict_hook(self, layer, input, output):
+        res_dict = {"output": output}
+        for res_key in list(self.res_dict):
+            res_dict[res_key] = self.res_dict.pop(res_key)
+        return res_dict
+
     def replace_sub(self, layer_name_pattern, replace_function, recursive=True):
         for layer_i in self._sub_layers:
             layer_name = self._sub_layers[layer_i].full_name()
@@ -85,10 +92,12 @@ class TheseusLayer(nn.Layer):
 class WrapLayer(TheseusLayer):
-    def __init__(self, sub_layer):
+    def __init__(self, sub_layer, res_dict=None):
         super(WrapLayer, self).__init__()
         self.sub_layer = sub_layer
         self.name = sub_layer.full_name()
+        if res_dict is not None:
+            self.res_dict = res_dict

     def full_name(self):
         return self.name
@@ -101,14 +110,14 @@ class WrapLayer(TheseusLayer):
             return
         for layer_i in self.sub_layer._sub_layers:
             if isinstance(self.sub_layer._sub_layers[layer_i], (nn.Sequential, nn.LayerList)):
-                self.sub_layer._sub_layers[layer_i] = wrap_theseus(self.sub_layer._sub_layers[layer_i])
-                self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict
+                self.sub_layer._sub_layers[layer_i] = wrap_theseus(self.sub_layer._sub_layers[layer_i], self.res_dict)
                 self.sub_layer._sub_layers[layer_i].update_res(return_patterns)
+            elif isinstance(self.sub_layer._sub_layers[layer_i], TheseusLayer):
+                self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict

             layer_name = self.sub_layer._sub_layers[layer_i].full_name()
             for return_pattern in return_patterns:
                 if re.match(return_pattern, layer_name):
-                    self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict
                     self.sub_layer._sub_layers[layer_i].register_forward_post_hook(
                         self._sub_layers[layer_i]._save_sub_res_hook)
@@ -116,6 +125,6 @@ class WrapLayer(TheseusLayer):
                     self.sub_layer._sub_layers[layer_i].update_res(return_patterns)


-def wrap_theseus(sub_layer):
-    wrapped_layer = WrapLayer(sub_layer)
+def wrap_theseus(sub_layer, res_dict=None):
+    wrapped_layer = WrapLayer(sub_layer, res_dict)
     return wrapped_layer
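Taken together, the changes above mean that a backbone built with `return_patterns` wraps the matched sub-layers, gathers their outputs into a shared `res_dict` via `_save_sub_res_hook`, and finally replaces the model output with a dict via `_return_dict_hook`. Below is a minimal, self-contained sketch of that hook pattern using plain Paddle; the tiny model and its layer names are illustrative and not part of this PR.

```python
import paddle
import paddle.nn as nn


class TinyNet(nn.Layer):
    """Toy model that mimics the res_dict / hook wiring used in this PR."""

    def __init__(self):
        super().__init__()
        self.block1 = nn.Conv2D(3, 8, 3, padding=1)
        self.block2 = nn.Conv2D(8, 8, 3, padding=1)
        self.head = nn.Linear(8, 10)
        self.res_dict = {}
        # Capture the output of each selected sub-layer on every forward pass.
        self.block1.register_forward_post_hook(self._save_sub_res_hook)
        self.block2.register_forward_post_hook(self._save_sub_res_hook)
        # Replace the final output with a dict, as TheseusLayer._return_dict_hook does.
        self.register_forward_post_hook(self._return_dict_hook)

    def _save_sub_res_hook(self, layer, input, output):
        self.res_dict[layer.full_name()] = output

    def _return_dict_hook(self, layer, input, output):
        res_dict = {"output": output}
        for res_key in list(self.res_dict):
            res_dict[res_key] = self.res_dict.pop(res_key)
        return res_dict

    def forward(self, x):
        x = self.block2(self.block1(x))
        return self.head(x.mean(axis=[2, 3]))


model = TinyNet()
out = model(paddle.rand([1, 3, 32, 32]))
print(out["output"].shape)                # final logits, shape [1, 10]
print([k for k in out if k != "output"])  # names of the captured conv outputs
```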
@@ -367,7 +367,7 @@ class HRNet(TheseusLayer):
         model: nn.Layer. Specific HRNet model depends on args.
     """

-    def __init__(self, width=18, has_se=False, class_num=1000):
+    def __init__(self, width=18, has_se=False, class_num=1000, return_patterns=None):
         super().__init__()

         self.width = width
@@ -456,8 +456,11 @@ class HRNet(TheseusLayer):
             2048,
             class_num,
             weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
+        if return_patterns is not None:
+            self.update_res(return_patterns)
+            self.register_forward_post_hook(self._return_dict_hook)

-    def forward(self, x, res_dict=None):
+    def forward(self, x):
         x = self.conv_layer1_1(x)
         x = self.conv_layer1_2(x)
...
@@ -454,7 +454,7 @@ class Inception_V3(TheseusLayer):
         model: nn.Layer. Specific Inception_V3 model depends on args.
     """

-    def __init__(self, config, class_num=1000):
+    def __init__(self, config, class_num=1000, return_patterns=None):
         super().__init__()

         self.inception_a_list = config["inception_a"]
@@ -496,6 +496,9 @@ class Inception_V3(TheseusLayer):
             class_num,
             weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)),
             bias_attr=ParamAttr())
+        if return_patterns is not None:
+            self.update_res(return_patterns)
+            self.register_forward_post_hook(self._return_dict_hook)

     def forward(self, x):
         x = self.inception_stem(x)
...
@@ -102,7 +102,7 @@ class MobileNet(TheseusLayer):
         model: nn.Layer. Specific MobileNet model depends on args.
     """

-    def __init__(self, scale=1.0, class_num=1000):
+    def __init__(self, scale=1.0, class_num=1000, return_patterns=None):
         super().__init__()
         self.scale = scale
@@ -145,6 +145,9 @@ class MobileNet(TheseusLayer):
             int(1024 * scale),
             class_num,
             weight_attr=ParamAttr(initializer=KaimingNormal()))
+        if return_patterns is not None:
+            self.update_res(return_patterns)
+            self.register_forward_post_hook(self._return_dict_hook)

     def forward(self, x):
         x = self.conv(x)
...
@@ -142,7 +142,8 @@ class MobileNetV3(TheseusLayer):
                  inplanes=STEM_CONV_NUMBER,
                  class_squeeze=LAST_SECOND_CONV_LARGE,
                  class_expand=LAST_CONV,
-                 dropout_prob=0.2):
+                 dropout_prob=0.2,
+                 return_patterns=None):
         super().__init__()

         self.cfg = config
@@ -199,6 +200,9 @@ class MobileNetV3(TheseusLayer):
         self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
         self.fc = Linear(self.class_expand, class_num)
+        if return_patterns is not None:
+            self.update_res(return_patterns)
+            self.register_forward_post_hook(self._return_dict_hook)

     def forward(self, x):
         x = self.conv(x)
...
@@ -269,7 +269,8 @@ class ResNet(TheseusLayer):
                  class_num=1000,
                  lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
                  data_format="NCHW",
-                 input_image_channel=3):
+                 input_image_channel=3,
+                 return_patterns=None):
         super().__init__()

         self.cfg = config
@@ -337,6 +338,9 @@ class ResNet(TheseusLayer):
             weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))

         self.data_format = data_format
+        if return_patterns is not None:
+            self.update_res(return_patterns)
+            self.register_forward_post_hook(self._return_dict_hook)

     def forward(self, x):
         with paddle.static.amp.fp16_guard():
...
@@ -137,8 +137,11 @@ class VGGNet(TheseusLayer):
         self.fc1 = Linear(7 * 7 * 512, 4096)
         self.fc2 = Linear(4096, 4096)
         self.fc3 = Linear(4096, class_num)
+        if return_patterns is not None:
+            self.update_res(return_patterns)
+            self.register_forward_post_hook(self._return_dict_hook)

-    def forward(self, inputs, res_dict=None):
+    def forward(self, inputs):
         x = self.conv_block_1(inputs)
         x = self.conv_block_2(x)
         x = self.conv_block_3(x)
@@ -152,9 +155,6 @@ class VGGNet(TheseusLayer):
         x = self.relu(x)
         x = self.drop(x)
         x = self.fc3(x)
-        if self.res_dict and res_dict is not None:
-            for res_key in list(self.res_dict):
-                res_dict[res_key] = self.res_dict.pop(res_key)
         return x
...
@@ -346,6 +346,8 @@ class Engine(object):
                 out = self.model(batch_tensor)
                 if isinstance(out, list):
                     out = out[0]
+                if isinstance(out, dict):
+                    out = out["output"]
                 result = self.postprocess_func(out, image_file_list)
                 print(result)
                 batch_data.clear()
...
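Because a model configured with `return_patterns` now returns a dict instead of a bare tensor, any caller that consumed the raw forward output needs the same unwrapping the engine gains above. A hedged sketch of that normalization step follows; only the `"output"` key and the list/dict handling are taken from this diff, and the helper name is illustrative.

```python
def unwrap_model_output(out):
    """Normalize a model output that may be a tensor, a list, or the dict
    produced by TheseusLayer._return_dict_hook."""
    if isinstance(out, list):
        out = out[0]
    if isinstance(out, dict):
        # "output" holds the original forward() result; the remaining keys
        # are the intermediate features collected via return_patterns.
        out = out["output"]
    return out
```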