提交 e9162595 编写于 作者: W weishengyu

update return_res method

上级 5131956d
...@@ -23,6 +23,7 @@ from . import backbone, gears ...@@ -23,6 +23,7 @@ from . import backbone, gears
from .backbone import * from .backbone import *
from .gears import build_gear from .gears import build_gear
from .utils import * from .utils import *
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils import logger from ppcls.utils import logger
from ppcls.utils.save_load import load_dygraph_pretrain from ppcls.utils.save_load import load_dygraph_pretrain
...@@ -32,8 +33,11 @@ __all__ = ["build_model", "RecModel", "DistillationModel"] ...@@ -32,8 +33,11 @@ __all__ = ["build_model", "RecModel", "DistillationModel"]
def build_model(config): def build_model(config):
config = copy.deepcopy(config) config = copy.deepcopy(config)
model_type = config.pop("name") model_type = config.pop("name")
return_patterns = config.pop("return_patterns", None)
mod = importlib.import_module(__name__) mod = importlib.import_module(__name__)
arch = getattr(mod, model_type)(**config) arch = getattr(mod, model_type)(**config)
if return_patterns is not None and isinstance(arch, TheseusLayer):
arch.update_res(return_patterns=return_patterns, return_dict=True)
return arch return arch
...@@ -55,7 +59,10 @@ class RecModel(nn.Layer): ...@@ -55,7 +59,10 @@ class RecModel(nn.Layer):
super().__init__() super().__init__()
backbone_config = config["Backbone"] backbone_config = config["Backbone"]
backbone_name = backbone_config.pop("name") backbone_name = backbone_config.pop("name")
return_patterns = config.pop("return_patterns", None)
self.backbone = eval(backbone_name)(**backbone_config) self.backbone = eval(backbone_name)(**backbone_config)
if return_patterns is not None and isinstance(self.backbone, TheseusLayer):
self.backbone.update_res(return_patterns=return_patterns, return_dict=True)
if "BackboneStopLayer" in config: if "BackboneStopLayer" in config:
backbone_stop_layer = config["BackboneStopLayer"]["name"] backbone_stop_layer = config["BackboneStopLayer"]["name"]
self.backbone.stop_after(backbone_stop_layer) self.backbone.stop_after(backbone_stop_layer)
......
...@@ -57,6 +57,12 @@ class TheseusLayer(nn.Layer): ...@@ -57,6 +57,12 @@ class TheseusLayer(nn.Layer):
def _save_sub_res_hook(self, layer, input, output): def _save_sub_res_hook(self, layer, input, output):
self.res_dict[layer.full_name()] = output self.res_dict[layer.full_name()] = output
def _return_dict_hook(self, layer, input, output):
res_dict = {"output": output}
for res_key in list(self.res_dict):
res_dict[res_key] = self.res_dict.pop(res_key)
return res_dict
def replace_sub(self, layer_name_pattern, replace_function, recursive=True): def replace_sub(self, layer_name_pattern, replace_function, recursive=True):
for layer_i in self._sub_layers: for layer_i in self._sub_layers:
layer_name = self._sub_layers[layer_i].full_name() layer_name = self._sub_layers[layer_i].full_name()
......
...@@ -367,7 +367,7 @@ class HRNet(TheseusLayer): ...@@ -367,7 +367,7 @@ class HRNet(TheseusLayer):
model: nn.Layer. Specific HRNet model depends on args. model: nn.Layer. Specific HRNet model depends on args.
""" """
def __init__(self, width=18, has_se=False, class_num=1000): def __init__(self, width=18, has_se=False, class_num=1000, return_patterns=None):
super().__init__() super().__init__()
self.width = width self.width = width
...@@ -456,8 +456,11 @@ class HRNet(TheseusLayer): ...@@ -456,8 +456,11 @@ class HRNet(TheseusLayer):
2048, 2048,
class_num, class_num,
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv))) weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
if return_patterns is not None:
self.update_res(return_patterns)
self.register_forward_post_hook(self._return_dict_hook)
def forward(self, x, res_dict=None): def forward(self, x):
x = self.conv_layer1_1(x) x = self.conv_layer1_1(x)
x = self.conv_layer1_2(x) x = self.conv_layer1_2(x)
......
...@@ -454,7 +454,7 @@ class Inception_V3(TheseusLayer): ...@@ -454,7 +454,7 @@ class Inception_V3(TheseusLayer):
model: nn.Layer. Specific Inception_V3 model depends on args. model: nn.Layer. Specific Inception_V3 model depends on args.
""" """
def __init__(self, config, class_num=1000): def __init__(self, config, class_num=1000, return_patterns=None):
super().__init__() super().__init__()
self.inception_a_list = config["inception_a"] self.inception_a_list = config["inception_a"]
...@@ -496,6 +496,9 @@ class Inception_V3(TheseusLayer): ...@@ -496,6 +496,9 @@ class Inception_V3(TheseusLayer):
class_num, class_num,
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)), weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)),
bias_attr=ParamAttr()) bias_attr=ParamAttr())
if return_patterns is not None:
self.update_res(return_patterns)
self.register_forward_post_hook(self._return_dict_hook)
def forward(self, x): def forward(self, x):
x = self.inception_stem(x) x = self.inception_stem(x)
......
...@@ -102,7 +102,7 @@ class MobileNet(TheseusLayer): ...@@ -102,7 +102,7 @@ class MobileNet(TheseusLayer):
model: nn.Layer. Specific MobileNet model depends on args. model: nn.Layer. Specific MobileNet model depends on args.
""" """
def __init__(self, scale=1.0, class_num=1000): def __init__(self, scale=1.0, class_num=1000, return_patterns=None):
super().__init__() super().__init__()
self.scale = scale self.scale = scale
...@@ -145,16 +145,16 @@ class MobileNet(TheseusLayer): ...@@ -145,16 +145,16 @@ class MobileNet(TheseusLayer):
int(1024 * scale), int(1024 * scale),
class_num, class_num,
weight_attr=ParamAttr(initializer=KaimingNormal())) weight_attr=ParamAttr(initializer=KaimingNormal()))
if return_patterns is not None:
self.update_res(return_patterns)
self.register_forward_post_hook(self._return_dict_hook)
def forward(self, x, res_dict=None): def forward(self, x):
x = self.conv(x) x = self.conv(x)
x = self.blocks(x) x = self.blocks(x)
x = self.avg_pool(x) x = self.avg_pool(x)
x = self.flatten(x) x = self.flatten(x)
x = self.fc(x) x = self.fc(x)
if self.res_dict and res_dict is not None:
for res_key in list(self.res_dict):
res_dict[res_key] = self.res_dict.pop(res_key)
return x return x
......
...@@ -142,7 +142,8 @@ class MobileNetV3(TheseusLayer): ...@@ -142,7 +142,8 @@ class MobileNetV3(TheseusLayer):
inplanes=STEM_CONV_NUMBER, inplanes=STEM_CONV_NUMBER,
class_squeeze=LAST_SECOND_CONV_LARGE, class_squeeze=LAST_SECOND_CONV_LARGE,
class_expand=LAST_CONV, class_expand=LAST_CONV,
dropout_prob=0.2): dropout_prob=0.2,
return_patterns=None):
super().__init__() super().__init__()
self.cfg = config self.cfg = config
...@@ -199,6 +200,9 @@ class MobileNetV3(TheseusLayer): ...@@ -199,6 +200,9 @@ class MobileNetV3(TheseusLayer):
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1) self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
self.fc = Linear(self.class_expand, class_num) self.fc = Linear(self.class_expand, class_num)
if return_patterns is not None:
self.update_res(return_patterns)
self.register_forward_post_hook(self._return_dict_hook)
def forward(self, x): def forward(self, x):
x = self.conv(x) x = self.conv(x)
......
...@@ -269,7 +269,8 @@ class ResNet(TheseusLayer): ...@@ -269,7 +269,8 @@ class ResNet(TheseusLayer):
class_num=1000, class_num=1000,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0], lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
data_format="NCHW", data_format="NCHW",
input_image_channel=3): input_image_channel=3,
return_patterns=None):
super().__init__() super().__init__()
self.cfg = config self.cfg = config
...@@ -337,6 +338,9 @@ class ResNet(TheseusLayer): ...@@ -337,6 +338,9 @@ class ResNet(TheseusLayer):
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv))) weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
self.data_format = data_format self.data_format = data_format
if return_patterns is not None:
self.update_res(return_patterns)
self.register_forward_post_hook(self._return_dict_hook)
def forward(self, x): def forward(self, x):
with paddle.static.amp.fp16_guard(): with paddle.static.amp.fp16_guard():
......
...@@ -111,7 +111,7 @@ class VGGNet(TheseusLayer): ...@@ -111,7 +111,7 @@ class VGGNet(TheseusLayer):
model: nn.Layer. Specific VGG model depends on args. model: nn.Layer. Specific VGG model depends on args.
""" """
def __init__(self, config, stop_grad_layers=0, class_num=1000): def __init__(self, config, stop_grad_layers=0, class_num=1000, return_patterns=None):
super().__init__() super().__init__()
self.stop_grad_layers = stop_grad_layers self.stop_grad_layers = stop_grad_layers
...@@ -137,8 +137,11 @@ class VGGNet(TheseusLayer): ...@@ -137,8 +137,11 @@ class VGGNet(TheseusLayer):
self.fc1 = Linear(7 * 7 * 512, 4096) self.fc1 = Linear(7 * 7 * 512, 4096)
self.fc2 = Linear(4096, 4096) self.fc2 = Linear(4096, 4096)
self.fc3 = Linear(4096, class_num) self.fc3 = Linear(4096, class_num)
if return_patterns is not None:
self.update_res(return_patterns)
self.register_forward_post_hook(self._return_dict_hook)
def forward(self, inputs, res_dict=None): def forward(self, inputs):
x = self.conv_block_1(inputs) x = self.conv_block_1(inputs)
x = self.conv_block_2(x) x = self.conv_block_2(x)
x = self.conv_block_3(x) x = self.conv_block_3(x)
...@@ -152,9 +155,6 @@ class VGGNet(TheseusLayer): ...@@ -152,9 +155,6 @@ class VGGNet(TheseusLayer):
x = self.relu(x) x = self.relu(x)
x = self.drop(x) x = self.drop(x)
x = self.fc3(x) x = self.fc3(x)
if self.res_dict and res_dict is not None:
for res_key in list(self.res_dict):
res_dict[res_key] = self.res_dict.pop(res_key)
return x return x
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册