未验证 提交 c26429de 编写于 作者: W Wei Shengyu 提交者: GitHub

Merge pull request #839 from littletomatodonkey/reg/add_distillation_infer

fix distillation model infer and export model
Global: Global:
infer_imgs: "../docs/images/whl/demo.jpg" infer_imgs: "../docs/images/whl/demo.jpg"
inference_model_dir: "./MobileNetV1_infer/" inference_model_dir: "../inference/"
batch_size: 1 batch_size: 1
use_gpu: True use_gpu: True
enable_mkldnn: True enable_mkldnn: True
...@@ -27,4 +27,4 @@ PreProcess: ...@@ -27,4 +27,4 @@ PreProcess:
PostProcess: PostProcess:
name: Topk name: Topk
topk: 5 topk: 5
class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt" class_id_map_file: "../ppcls/utils/imagenet1k_label_list.txt"
\ No newline at end of file \ No newline at end of file
...@@ -69,7 +69,8 @@ class DistillationModel(nn.Layer): ...@@ -69,7 +69,8 @@ class DistillationModel(nn.Layer):
def __init__(self, def __init__(self,
models=None, models=None,
pretrained_list=None, pretrained_list=None,
freeze_params_list=None): freeze_params_list=None,
**kargs):
super().__init__() super().__init__()
assert isinstance(models, list) assert isinstance(models, list)
self.model_list = [] self.model_list = []
...@@ -105,5 +106,5 @@ class DistillationModel(nn.Layer): ...@@ -105,5 +106,5 @@ class DistillationModel(nn.Layer):
if label is None: if label is None:
result_dict[model_name] = self.model_list[idx](x) result_dict[model_name] = self.model_list[idx](x)
else: else:
result_dict[model_name] = self.model_list[idx](x) result_dict[model_name] = self.model_list[idx](x, label)
return result_dict return result_dict
...@@ -33,6 +33,8 @@ Arch: ...@@ -33,6 +33,8 @@ Arch:
name: MobileNetV3_small_x1_0 name: MobileNetV3_small_x1_0
pretrained: False pretrained: False
infer_model_name: "Student"
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
...@@ -136,7 +138,8 @@ Infer: ...@@ -136,7 +138,8 @@ Infer:
order: '' order: ''
- ToCHWImage: - ToCHWImage:
PostProcess: PostProcess:
name: Topk name: DistillationPostProcess
func: Topk
topk: 5 topk: 5
class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt" class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"
......
...@@ -25,3 +25,17 @@ def build_postprocess(config): ...@@ -25,3 +25,17 @@ def build_postprocess(config):
mod = importlib.import_module(__name__) mod = importlib.import_module(__name__)
postprocess_func = getattr(mod, model_name)(**config) postprocess_func = getattr(mod, model_name)(**config)
return postprocess_func return postprocess_func
class DistillationPostProcess(object):
def __init__(self, model_name="Student", key=None, func="Topk", **kargs):
super().__init__()
self.func = eval(func)(**kargs)
self.model_name = model_name
self.key = key
def __call__(self, x, file_names=None):
x = x[self.model_name]
if self.key is not None:
x = x[self.key]
return self.func(x, file_names=file_names)
...@@ -24,20 +24,27 @@ import paddle ...@@ -24,20 +24,27 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
from ppcls.utils import config from ppcls.utils import config
from ppcls.arch import build_model, RecModel from ppcls.arch import build_model, RecModel, DistillationModel
from ppcls.utils.save_load import load_dygraph_pretrain from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.arch.gears.identity_head import IdentityHead from ppcls.arch.gears.identity_head import IdentityHead
class ExportModel(nn.Layer): class ExportModel(nn.Layer):
""" """
ClasModel: add softmax onto the model ExportModel: add softmax onto the model
""" """
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.base_model = build_model(config) self.base_model = build_model(config)
self.infer_output_key = config.get("infer_output_key")
# we should choose a final model to export
if isinstance(self.base_model, DistillationModel):
self.infer_model_name = config["infer_model_name"]
else:
self.infer_model_name = None
self.infer_output_key = config.get("infer_output_key", None)
if self.infer_output_key == "features" and isinstance(self.base_model, if self.infer_output_key == "features" and isinstance(self.base_model,
RecModel): RecModel):
self.base_model.head = IdentityHead() self.base_model.head = IdentityHead()
...@@ -54,6 +61,8 @@ class ExportModel(nn.Layer): ...@@ -54,6 +61,8 @@ class ExportModel(nn.Layer):
def forward(self, x): def forward(self, x):
x = self.base_model(x) x = self.base_model(x)
if self.infer_model_name is not None:
x = x[self.infer_model_name]
if self.infer_output_key is not None: if self.infer_output_key is not None:
x = x[self.infer_output_key] x = x[self.infer_output_key]
if self.softmax is not None: if self.softmax is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册