提交 35437e22 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix distillation model infer and export model

上级 f3f0b399
......@@ -69,7 +69,8 @@ class DistillationModel(nn.Layer):
def __init__(self,
models=None,
pretrained_list=None,
freeze_params_list=None):
freeze_params_list=None,
**kargs):
super().__init__()
assert isinstance(models, list)
self.model_list = []
......@@ -105,5 +106,5 @@ class DistillationModel(nn.Layer):
if label is None:
result_dict[model_name] = self.model_list[idx](x)
else:
result_dict[model_name] = self.model_list[idx](x)
result_dict[model_name] = self.model_list[idx](x, label)
return result_dict
......@@ -33,6 +33,8 @@ Arch:
name: MobileNetV3_small_x1_0
pretrained: False
infer_model_name: "Student"
# loss function config for traing/eval process
Loss:
......@@ -136,7 +138,8 @@ Infer:
order: ''
- ToCHWImage:
PostProcess:
name: Topk
name: DistillationPostProcess
func: Topk
topk: 5
class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"
......
......@@ -25,3 +25,17 @@ def build_postprocess(config):
mod = importlib.import_module(__name__)
postprocess_func = getattr(mod, model_name)(**config)
return postprocess_func
class DistillationPostProcess(object):
def __init__(self, model_name="Student", key=None, func="Topk", **kargs):
super().__init__()
self.func = eval(func)(**kargs)
self.model_name = model_name
self.key = key
def __call__(self, x, file_names=None):
x = x[self.model_name]
if self.key is not None:
x = x[self.key]
return self.func(x, file_names=file_names)
......@@ -24,20 +24,27 @@ import paddle
import paddle.nn as nn
from ppcls.utils import config
from ppcls.arch import build_model, RecModel
from ppcls.arch import build_model, RecModel, DistillationModel
from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.arch.gears.identity_head import IdentityHead
class ExportModel(nn.Layer):
"""
ClasModel: add softmax onto the model
ExportModel: add softmax onto the model
"""
def __init__(self, config):
super().__init__()
self.base_model = build_model(config)
self.infer_output_key = config.get("infer_output_key")
# we should choose a final model to export
if isinstance(self.base_model, DistillationModel):
self.infer_model_name = config["infer_model_name"]
else:
self.infer_model_name = None
self.infer_output_key = config.get("infer_output_key", None)
if self.infer_output_key == "features" and isinstance(self.base_model,
RecModel):
self.base_model.head = IdentityHead()
......@@ -54,6 +61,8 @@ class ExportModel(nn.Layer):
def forward(self, x):
x = self.base_model(x)
if self.infer_model_name is not None:
x = x[self.infer_model_name]
if self.infer_output_key is not None:
x = x[self.infer_output_key]
if self.softmax is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册