Commit 6361a38f authored by littletomatodonkey

fix export model for distillation model

Parent ab4db2ac
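For context: the export fix below assumes a distillation architecture in which Architecture.algorithm is "Distillation" and the individual networks live under Architecture.Models, each with its own Head. A minimal sketch of the config shape the new code reads, written as a Python dict (the sub-model names "Teacher" and "Student" and the char_num value are illustrative, not taken from this commit):

    # Hypothetical config fragment; only the keys touched by the export fix are shown.
    char_num = 37  # e.g. len(post_process_class.character); value is illustrative
    arch_config = {
        "algorithm": "Distillation",
        "Models": {
            "Teacher": {"model_type": "rec", "Head": {}},  # placeholder sub-model
            "Student": {"model_type": "rec", "Head": {}},  # placeholder sub-model
        },
    }
    # The reworked main() writes the character count into every sub-model head:
    for key in arch_config["Models"]:
        arch_config["Models"][key]["Head"]["out_channels"] = char_num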
@@ -82,7 +82,7 @@ class DistillationDistanceLoss(DistanceLoss):
                  key=None,
                  name="loss_distance",
                  **kargs):
-        super().__init__(mode=mode, name=name)
+        super().__init__(mode=mode, name=name, **kargs)
         assert isinstance(model_name_pairs, list)
         self.key = key
         self.model_name_pairs = model_name_pairs
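The one-line change above forwards the extra keyword arguments to the parent DistanceLoss instead of dropping them, so options passed through the loss config actually take effect. A minimal, self-contained sketch of the pattern (simplified stand-ins, not the actual ppocr classes; the "reduction" option is illustrative):

    class DistanceLoss:
        def __init__(self, mode="l2", name="loss_distance", **kargs):
            # In the real class the extra kwargs configure the underlying
            # distance function; here we simply record them.
            self.mode = mode
            self.name = name
            self.kargs = kargs

    class DistillationDistanceLoss(DistanceLoss):
        def __init__(self, mode="l2", model_name_pairs=[], key=None,
                     name="loss_distance", **kargs):
            # Forwarding **kargs is the fix: extra options now reach the parent.
            super().__init__(mode=mode, name=name, **kargs)
            assert isinstance(model_name_pairs, list)
            self.key = key
            self.model_name_pairs = model_name_pairs

    loss = DistillationDistanceLoss(model_name_pairs=[["Student", "Teacher"]],
                                    reduction="mean")
    print(loss.kargs)  # {'reduction': 'mean'}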
@@ -34,8 +34,8 @@ class DistillationModel(nn.Layer):
             config (dict): the super parameters for module.
         """
         super().__init__()
-        self.model_dict = dict()
-        index = 0
+        self.model_list = []
+        self.model_name_list = []
         for key in config["Models"]:
             model_config = config["Models"][key]
             freeze_params = False
@@ -46,15 +46,15 @@ class DistillationModel(nn.Layer):
                 pretrained = model_config.pop("pretrained")
             model = BaseModel(model_config)
             if pretrained is not None:
-                load_dygraph_pretrain(model, path=pretrained[index])
+                load_dygraph_pretrain(model, path=pretrained)
             if freeze_params:
                 for param in model.parameters():
                     param.trainable = False
-            self.model_dict[key] = self.add_sublayer(key, model)
-            index += 1
+            self.model_list.append(self.add_sublayer(key, model))
+            self.model_name_list.append(key)
 
     def forward(self, x):
         result_dict = dict()
-        for key in self.model_dict:
-            result_dict[key] = self.model_dict[key](x)
+        for idx, model_name in enumerate(self.model_name_list):
+            result_dict[model_name] = self.model_list[idx](x)
         return result_dict
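Swapping the dict for two parallel lists keeps every sub-model addressable both by position and by name, which the reworked export script below relies on (model.model_list[idx] pairs with model.model_name_list[idx]). A minimal sketch of the pattern without the Paddle dependencies (the callables standing in for sub-models are illustrative):

    class DistillationModel:
        def __init__(self, config):
            # Insertion order is preserved: model_name_list[i] names model_list[i].
            self.model_list = []
            self.model_name_list = []
            for key, sub_model in config["Models"].items():
                self.model_list.append(sub_model)
                self.model_name_list.append(key)

        def forward(self, x):
            # One forward pass per sub-model, results keyed by sub-model name.
            return {name: self.model_list[idx](x)
                    for idx, name in enumerate(self.model_name_list)}

    models = {"Teacher": lambda x: x * 2, "Student": lambda x: x + 1}
    dist = DistillationModel({"Models": models})
    print(dist.forward(3))  # {'Teacher': 6, 'Student': 4}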
@@ -17,7 +17,7 @@ import sys
 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, "..")))
 
 import argparse
@@ -31,32 +31,12 @@ from ppocr.utils.logging import get_logger
 from tools.program import load_config, merge_config, ArgsParser
 
 
-def main():
-    FLAGS = ArgsParser().parse_args()
-    config = load_config(FLAGS.config)
-    merge_config(FLAGS.opt)
-    logger = get_logger()
-    # build post process
-
-    post_process_class = build_post_process(config['PostProcess'],
-                                            config['Global'])
-
-    # build model
-    # for rec algorithm
-    if hasattr(post_process_class, 'character'):
-        char_num = len(getattr(post_process_class, 'character'))
-        config['Architecture']["Head"]['out_channels'] = char_num
-    model = build_model(config['Architecture'])
-    init_model(config, model, logger)
-    model.eval()
-
-    save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
-
-    if config['Architecture']['algorithm'] == "SRN":
-        max_text_length = config['Architecture']['Head']['max_text_length']
+def export_single_model(model, arch_config, save_path, logger):
+    if arch_config["algorithm"] == "SRN":
+        max_text_length = arch_config["Head"]["max_text_length"]
         other_shape = [
             paddle.static.InputSpec(
-                shape=[None, 1, 64, 256], dtype='float32'), [
+                shape=[None, 1, 64, 256], dtype="float32"), [
                 paddle.static.InputSpec(
                     shape=[None, 256, 1],
                     dtype="int64"), paddle.static.InputSpec(
@@ -71,24 +51,66 @@ def main():
         model = to_static(model, input_spec=other_shape)
     else:
         infer_shape = [3, -1, -1]
-        if config['Architecture']['model_type'] == "rec":
+        if arch_config["model_type"] == "rec":
             infer_shape = [3, 32, -1]  # for rec model, H must be 32
-            if 'Transform' in config['Architecture'] and config['Architecture'][
-                    'Transform'] is not None and config['Architecture'][
-                    'Transform']['name'] == 'TPS':
+            if "Transform" in arch_config and arch_config[
+                    "Transform"] is not None and arch_config["Transform"][
+                        "name"] == "TPS":
                 logger.info(
-                    'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
+                    "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
                 )
                 infer_shape[-1] = 100
         model = to_static(
             model,
             input_spec=[
                 paddle.static.InputSpec(
-                    shape=[None] + infer_shape, dtype='float32')
+                    shape=[None] + infer_shape, dtype="float32")
             ])
 
     paddle.jit.save(model, save_path)
-    logger.info('inference model is saved to {}'.format(save_path))
+    logger.info("inference model is saved to {}".format(save_path))
+    return
+
+
+def main():
+    FLAGS = ArgsParser().parse_args()
+    config = load_config(FLAGS.config)
+    merge_config(FLAGS.opt)
+    logger = get_logger()
+    # build post process
+
+    post_process_class = build_post_process(config["PostProcess"],
+                                            config["Global"])
+
+    # build model
+    # for rec algorithm
+    if hasattr(post_process_class, "character"):
+        char_num = len(getattr(post_process_class, "character"))
+        if config["Architecture"]["algorithm"] in ["Distillation",
+                                                   ]:  # distillation model
+            for key in config["Architecture"]["Models"]:
+                config["Architecture"]["Models"][key]["Head"][
+                    "out_channels"] = char_num
+        else:  # base rec model
+            config["Architecture"]["Head"]["out_channels"] = char_num
+
+    model = build_model(config["Architecture"])
+    init_model(config, model, logger)
+    model.eval()
+
+    save_path = config["Global"]["save_inference_dir"]
+    arch_config = config["Architecture"]
+
+    if arch_config["algorithm"] in ["Distillation", ]:  # distillation model
+        archs = list(arch_config["Models"].values())
+        for idx, name in enumerate(model.model_name_list):
+            sub_model_save_path = os.path.join(save_path, name, "inference")
+            export_single_model(model.model_list[idx], archs[idx],
+                                sub_model_save_path, logger)
+    else:
+        save_path = os.path.join(save_path, "inference")
+        export_single_model(model, arch_config, save_path, logger)
+
 
 if __name__ == "__main__":
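With this refactor a distillation checkpoint is no longer exported as a single inference model: each sub-model gets its own directory under Global.save_inference_dir. A small sketch of the resulting paths, assuming sub-models named Teacher and Student and save_inference_dir set to ./inference (both assumptions, not taken from this commit):

    import os

    save_inference_dir = "./inference"        # Global.save_inference_dir (illustrative)
    model_name_list = ["Teacher", "Student"]  # hypothetical sub-model names

    # Mirrors the loop in the new main(): one inference model per sub-model.
    for name in model_name_list:
        print(os.path.join(save_inference_dir, name, "inference"))
    # ./inference/Teacher/inference
    # ./inference/Student/inference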