未验证 提交 d98b8816 编写于 作者: W Wei Shengyu 提交者: GitHub

Merge pull request #1208 from weisy11/fix_bug_of_theseus

dbg theseus
......@@ -23,6 +23,7 @@ from . import backbone, gears
from .backbone import *
from .gears import build_gear
from .utils import *
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils import logger
from ppcls.utils.save_load import load_dygraph_pretrain
......
......@@ -38,17 +38,18 @@ class TheseusLayer(nn.Layer):
for layer_i in self._sub_layers:
layer_name = self._sub_layers[layer_i].full_name()
if isinstance(self._sub_layers[layer_i], (nn.Sequential, nn.LayerList)):
self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i])
self._sub_layers[layer_i].res_dict = self.res_dict
self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i], self.res_dict)
self._sub_layers[layer_i].update_res(return_patterns)
else:
for return_pattern in return_patterns:
if re.match(return_pattern, layer_name):
if not isinstance(self._sub_layers[layer_i], TheseusLayer):
self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i])
self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i], self.res_dict)
else:
self._sub_layers[layer_i].res_dict = self.res_dict
self._sub_layers[layer_i].register_forward_post_hook(
self._sub_layers[layer_i]._save_sub_res_hook)
self._sub_layers[layer_i].res_dict = self.res_dict
if isinstance(self._sub_layers[layer_i], TheseusLayer):
self._sub_layers[layer_i].res_dict = self.res_dict
self._sub_layers[layer_i].update_res(return_patterns)
......@@ -56,6 +57,12 @@ class TheseusLayer(nn.Layer):
def _save_sub_res_hook(self, layer, input, output):
self.res_dict[layer.full_name()] = output
def _return_dict_hook(self, layer, input, output):
res_dict = {"output": output}
for res_key in list(self.res_dict):
res_dict[res_key] = self.res_dict.pop(res_key)
return res_dict
def replace_sub(self, layer_name_pattern, replace_function, recursive=True):
for layer_i in self._sub_layers:
layer_name = self._sub_layers[layer_i].full_name()
......@@ -85,10 +92,12 @@ class TheseusLayer(nn.Layer):
class WrapLayer(TheseusLayer):
def __init__(self, sub_layer):
def __init__(self, sub_layer, res_dict=None):
super(WrapLayer, self).__init__()
self.sub_layer = sub_layer
self.name = sub_layer.full_name()
if res_dict is not None:
self.res_dict = res_dict
def full_name(self):
return self.name
......@@ -101,14 +110,14 @@ class WrapLayer(TheseusLayer):
return
for layer_i in self.sub_layer._sub_layers:
if isinstance(self.sub_layer._sub_layers[layer_i], (nn.Sequential, nn.LayerList)):
self.sub_layer._sub_layers[layer_i] = wrap_theseus(self.sub_layer._sub_layers[layer_i])
self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict
self.sub_layer._sub_layers[layer_i] = wrap_theseus(self.sub_layer._sub_layers[layer_i], self.res_dict)
self.sub_layer._sub_layers[layer_i].update_res(return_patterns)
elif isinstance(self.sub_layer._sub_layers[layer_i], TheseusLayer):
self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict
layer_name = self.sub_layer._sub_layers[layer_i].full_name()
for return_pattern in return_patterns:
if re.match(return_pattern, layer_name):
self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict
self.sub_layer._sub_layers[layer_i].register_forward_post_hook(
self._sub_layers[layer_i]._save_sub_res_hook)
......@@ -116,6 +125,6 @@ class WrapLayer(TheseusLayer):
self.sub_layer._sub_layers[layer_i].update_res(return_patterns)
def wrap_theseus(sub_layer):
wrapped_layer = WrapLayer(sub_layer)
def wrap_theseus(sub_layer, res_dict=None):
wrapped_layer = WrapLayer(sub_layer, res_dict)
return wrapped_layer
......@@ -367,7 +367,7 @@ class HRNet(TheseusLayer):
model: nn.Layer. Specific HRNet model depends on args.
"""
def __init__(self, width=18, has_se=False, class_num=1000):
def __init__(self, width=18, has_se=False, class_num=1000, return_patterns=None):
super().__init__()
self.width = width
......@@ -456,8 +456,11 @@ class HRNet(TheseusLayer):
2048,
class_num,
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
if return_patterns is not None:
self.update_res(return_patterns)
self.register_forward_post_hook(self._return_dict_hook)
def forward(self, x, res_dict=None):
def forward(self, x):
x = self.conv_layer1_1(x)
x = self.conv_layer1_2(x)
......
......@@ -454,7 +454,7 @@ class Inception_V3(TheseusLayer):
model: nn.Layer. Specific Inception_V3 model depends on args.
"""
def __init__(self, config, class_num=1000):
def __init__(self, config, class_num=1000, return_patterns=None):
super().__init__()
self.inception_a_list = config["inception_a"]
......@@ -496,6 +496,9 @@ class Inception_V3(TheseusLayer):
class_num,
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)),
bias_attr=ParamAttr())
if return_patterns is not None:
self.update_res(return_patterns)
self.register_forward_post_hook(self._return_dict_hook)
def forward(self, x):
x = self.inception_stem(x)
......
......@@ -102,7 +102,7 @@ class MobileNet(TheseusLayer):
model: nn.Layer. Specific MobileNet model depends on args.
"""
def __init__(self, scale=1.0, class_num=1000):
def __init__(self, scale=1.0, class_num=1000, return_patterns=None):
super().__init__()
self.scale = scale
......@@ -145,6 +145,9 @@ class MobileNet(TheseusLayer):
int(1024 * scale),
class_num,
weight_attr=ParamAttr(initializer=KaimingNormal()))
if return_patterns is not None:
self.update_res(return_patterns)
self.register_forward_post_hook(self._return_dict_hook)
def forward(self, x):
x = self.conv(x)
......
......@@ -142,7 +142,8 @@ class MobileNetV3(TheseusLayer):
inplanes=STEM_CONV_NUMBER,
class_squeeze=LAST_SECOND_CONV_LARGE,
class_expand=LAST_CONV,
dropout_prob=0.2):
dropout_prob=0.2,
return_patterns=None):
super().__init__()
self.cfg = config
......@@ -199,6 +200,9 @@ class MobileNetV3(TheseusLayer):
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
self.fc = Linear(self.class_expand, class_num)
if return_patterns is not None:
self.update_res(return_patterns)
self.register_forward_post_hook(self._return_dict_hook)
def forward(self, x):
x = self.conv(x)
......
......@@ -269,7 +269,8 @@ class ResNet(TheseusLayer):
class_num=1000,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
data_format="NCHW",
input_image_channel=3):
input_image_channel=3,
return_patterns=None):
super().__init__()
self.cfg = config
......@@ -337,6 +338,9 @@ class ResNet(TheseusLayer):
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
self.data_format = data_format
if return_patterns is not None:
self.update_res(return_patterns)
self.register_forward_post_hook(self._return_dict_hook)
def forward(self, x):
with paddle.static.amp.fp16_guard():
......
......@@ -137,8 +137,11 @@ class VGGNet(TheseusLayer):
self.fc1 = Linear(7 * 7 * 512, 4096)
self.fc2 = Linear(4096, 4096)
self.fc3 = Linear(4096, class_num)
if return_patterns is not None:
self.update_res(return_patterns)
self.register_forward_post_hook(self._return_dict_hook)
def forward(self, inputs, res_dict=None):
def forward(self, inputs):
x = self.conv_block_1(inputs)
x = self.conv_block_2(x)
x = self.conv_block_3(x)
......@@ -152,9 +155,6 @@ class VGGNet(TheseusLayer):
x = self.relu(x)
x = self.drop(x)
x = self.fc3(x)
if self.res_dict and res_dict is not None:
for res_key in list(self.res_dict):
res_dict[res_key] = self.res_dict.pop(res_key)
return x
......
......@@ -346,6 +346,8 @@ class Engine(object):
out = self.model(batch_tensor)
if isinstance(out, list):
out = out[0]
if isinstance(out, dict):
out = out["output"]
result = self.postprocess_func(out, image_file_list)
print(result)
batch_data.clear()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册