Commit e9162595 authored by weishengyu

update return_res method

Parent 5131956d
@@ -23,6 +23,7 @@ from . import backbone, gears
 from .backbone import *
 from .gears import build_gear
 from .utils import *
+from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
 from ppcls.utils import logger
 from ppcls.utils.save_load import load_dygraph_pretrain
@@ -32,8 +33,11 @@ __all__ = ["build_model", "RecModel", "DistillationModel"]
 
 def build_model(config):
     config = copy.deepcopy(config)
     model_type = config.pop("name")
+    return_patterns = config.pop("return_patterns", None)
     mod = importlib.import_module(__name__)
     arch = getattr(mod, model_type)(**config)
+    if return_patterns is not None and isinstance(arch, TheseusLayer):
+        arch.update_res(return_patterns=return_patterns, return_dict=True)
     return arch
@@ -55,7 +59,10 @@ class RecModel(nn.Layer):
         super().__init__()
         backbone_config = config["Backbone"]
         backbone_name = backbone_config.pop("name")
+        return_patterns = config.pop("return_patterns", None)
         self.backbone = eval(backbone_name)(**backbone_config)
+        if return_patterns is not None and isinstance(self.backbone, TheseusLayer):
+            self.backbone.update_res(return_patterns=return_patterns, return_dict=True)
         if "BackboneStopLayer" in config:
             backbone_stop_layer = config["BackboneStopLayer"]["name"]
             self.backbone.stop_after(backbone_stop_layer)
......
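Reviewer note: a minimal sketch of how build_model picks this up from an Arch config. The model name and pattern string below are illustrative assumptions, and the sketch further assumes that update_res(..., return_dict=True) registers the dict-returning hook added in the next file; neither update_res's body nor real pattern values appear in this diff.

    import paddle
    from ppcls.arch import build_model

    arch_config = {
        "name": "MobileNetV3_large_x1_0",   # resolved via getattr(mod, model_type)
        "class_num": 1000,
        "return_patterns": ["blocks[0]"],   # hypothetical pattern string
    }
    model = build_model(arch_config)        # pops "return_patterns", calls update_res
    out = model(paddle.rand([1, 3, 224, 224]))
    # If the hook is registered, `out` is a dict: {"output": logits, "blocks[0]": ...}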
@@ -57,6 +57,12 @@ class TheseusLayer(nn.Layer):
     def _save_sub_res_hook(self, layer, input, output):
         self.res_dict[layer.full_name()] = output
 
+    def _return_dict_hook(self, layer, input, output):
+        res_dict = {"output": output}
+        for res_key in list(self.res_dict):
+            res_dict[res_key] = self.res_dict.pop(res_key)
+        return res_dict
+
     def replace_sub(self, layer_name_pattern, replace_function, recursive=True):
         for layer_i in self._sub_layers:
             layer_name = self._sub_layers[layer_i].full_name()
......
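Reviewer note: the new _return_dict_hook works because paddle.nn.Layer.register_forward_post_hook replaces a layer's output with whatever the hook returns (when it returns something). A self-contained sketch of that mechanism, outside PaddleClas:

    import paddle
    import paddle.nn as nn

    class Tiny(nn.Layer):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)
            self.res_dict = {}

        def forward(self, x):
            return self.fc(x)

    def return_dict_hook(layer, input, output):
        # Same body as _return_dict_hook above, unbound for the demo.
        res_dict = {"output": output}
        for res_key in list(layer.res_dict):
            res_dict[res_key] = layer.res_dict.pop(res_key)
        return res_dict

    m = Tiny()
    m.register_forward_post_hook(return_dict_hook)
    out = m(paddle.rand([1, 4]))
    assert isinstance(out, dict) and "output" in out  # tensor replaced by dict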
@@ -367,7 +367,7 @@ class HRNet(TheseusLayer):
         model: nn.Layer. Specific HRNet model depends on args.
     """
 
-    def __init__(self, width=18, has_se=False, class_num=1000):
+    def __init__(self, width=18, has_se=False, class_num=1000, return_patterns=None):
         super().__init__()
 
         self.width = width
@@ -456,8 +456,11 @@ class HRNet(TheseusLayer):
             2048,
             class_num,
             weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
+        if return_patterns is not None:
+            self.update_res(return_patterns)
+            self.register_forward_post_hook(self._return_dict_hook)
 
-    def forward(self, x, res_dict=None):
+    def forward(self, x):
         x = self.conv_layer1_1(x)
         x = self.conv_layer1_2(x)
......
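Reviewer note: HRNet above (and the backbones below) pair two hooks: _save_sub_res_hook fills res_dict as matched sublayers run, and _return_dict_hook drains it at the top level. update_res's body is not part of this diff, so the manual wiring in this standalone sketch is an assumption about what it does:

    import paddle
    import paddle.nn as nn

    class TinyBackbone(nn.Layer):
        def __init__(self):
            super().__init__()
            self.stem = nn.Linear(4, 8)
            self.head = nn.Linear(8, 2)
            self.res_dict = {}

        def _save_sub_res_hook(self, layer, input, output):
            # Same body as TheseusLayer._save_sub_res_hook in this diff.
            self.res_dict[layer.full_name()] = output

        def forward(self, x):
            return self.head(self.stem(x))

    m = TinyBackbone()
    # Hypothetical stand-in for update_res(["stem"]): hook one matching sublayer.
    m.stem.register_forward_post_hook(m._save_sub_res_hook)
    m(paddle.rand([1, 4]))
    print(list(m.res_dict))  # e.g. ['linear_0'], the stem's full_name()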
@@ -454,7 +454,7 @@ class Inception_V3(TheseusLayer):
         model: nn.Layer. Specific Inception_V3 model depends on args.
     """
 
-    def __init__(self, config, class_num=1000):
+    def __init__(self, config, class_num=1000, return_patterns=None):
         super().__init__()
         self.inception_a_list = config["inception_a"]
@@ -496,6 +496,9 @@ class Inception_V3(TheseusLayer):
             class_num,
             weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)),
             bias_attr=ParamAttr())
+        if return_patterns is not None:
+            self.update_res(return_patterns)
+            self.register_forward_post_hook(self._return_dict_hook)
 
     def forward(self, x):
         x = self.inception_stem(x)
......
@@ -102,7 +102,7 @@ class MobileNet(TheseusLayer):
         model: nn.Layer. Specific MobileNet model depends on args.
     """
 
-    def __init__(self, scale=1.0, class_num=1000):
+    def __init__(self, scale=1.0, class_num=1000, return_patterns=None):
         super().__init__()
         self.scale = scale
@@ -145,16 +145,16 @@ class MobileNet(TheseusLayer):
             int(1024 * scale),
             class_num,
             weight_attr=ParamAttr(initializer=KaimingNormal()))
+        if return_patterns is not None:
+            self.update_res(return_patterns)
+            self.register_forward_post_hook(self._return_dict_hook)
 
-    def forward(self, x, res_dict=None):
+    def forward(self, x):
         x = self.conv(x)
         x = self.blocks(x)
         x = self.avg_pool(x)
         x = self.flatten(x)
         x = self.fc(x)
-        if self.res_dict and res_dict is not None:
-            for res_key in list(self.res_dict):
-                res_dict[res_key] = self.res_dict.pop(res_key)
         return x
......
@@ -142,7 +142,8 @@ class MobileNetV3(TheseusLayer):
                  inplanes=STEM_CONV_NUMBER,
                  class_squeeze=LAST_SECOND_CONV_LARGE,
                  class_expand=LAST_CONV,
-                 dropout_prob=0.2):
+                 dropout_prob=0.2,
+                 return_patterns=None):
         super().__init__()
 
         self.cfg = config
@@ -199,6 +200,9 @@ class MobileNetV3(TheseusLayer):
         self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
         self.fc = Linear(self.class_expand, class_num)
+        if return_patterns is not None:
+            self.update_res(return_patterns)
+            self.register_forward_post_hook(self._return_dict_hook)
 
     def forward(self, x):
         x = self.conv(x)
......
@@ -269,7 +269,8 @@ class ResNet(TheseusLayer):
                  class_num=1000,
                  lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
                  data_format="NCHW",
-                 input_image_channel=3):
+                 input_image_channel=3,
+                 return_patterns=None):
         super().__init__()
 
         self.cfg = config
@@ -337,6 +338,9 @@ class ResNet(TheseusLayer):
             weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
 
         self.data_format = data_format
+        if return_patterns is not None:
+            self.update_res(return_patterns)
+            self.register_forward_post_hook(self._return_dict_hook)
 
     def forward(self, x):
         with paddle.static.amp.fp16_guard():
......
@@ -111,7 +111,7 @@ class VGGNet(TheseusLayer):
         model: nn.Layer. Specific VGG model depends on args.
     """
 
-    def __init__(self, config, stop_grad_layers=0, class_num=1000):
+    def __init__(self, config, stop_grad_layers=0, class_num=1000, return_patterns=None):
         super().__init__()
         self.stop_grad_layers = stop_grad_layers
@@ -137,8 +137,11 @@ class VGGNet(TheseusLayer):
         self.fc1 = Linear(7 * 7 * 512, 4096)
         self.fc2 = Linear(4096, 4096)
         self.fc3 = Linear(4096, class_num)
+        if return_patterns is not None:
+            self.update_res(return_patterns)
+            self.register_forward_post_hook(self._return_dict_hook)
 
-    def forward(self, inputs, res_dict=None):
+    def forward(self, inputs):
         x = self.conv_block_1(inputs)
         x = self.conv_block_2(x)
         x = self.conv_block_3(x)
@@ -152,9 +155,6 @@ class VGGNet(TheseusLayer):
         x = self.relu(x)
         x = self.drop(x)
         x = self.fc3(x)
-        if self.res_dict and res_dict is not None:
-            for res_key in list(self.res_dict):
-                res_dict[res_key] = self.res_dict.pop(res_key)
         return x
......
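Reviewer note: taken together, the commit moves feature extraction from a per-call res_dict argument to a construction-time option, so the return type now depends on how the model was built. A hypothetical helper (names are mine, not from the commit) keeps call sites uniform across both shapes:

    import paddle

    def split_output(out):
        # Old API: feats = {}; logits = model(x, res_dict=feats)
        # New API: out = model(x), a dict when return_patterns was set.
        if isinstance(out, dict):
            return out["output"], {k: v for k, v in out.items() if k != "output"}
        return out, {}

    logits, feats = split_output(paddle.rand([1, 1000]))              # tensor case
    logits, feats = split_output({"output": paddle.rand([1, 1000])})  # dict case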