Commit 6361a38f, authored by littletomatodonkey

fix export model for distillation model

Parent ab4db2ac
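This change reworks DistillationModel to track its sub-models as an ordered list with a parallel name list, and splits the export script into a reusable export_single_model() so each sub-model of a distillation architecture is exported on its own. For orientation, here is a minimal sketch of the Architecture config structure the new export path expects, written as a Python dict. The "Teacher"/"Student" names and the Backbone/Head choices are illustrative assumptions; only the keys the diff actually reads ("algorithm", "Models", "model_type", "Head"/"out_channels", "Transform") come from this commit.

# Illustrative sketch only: sub-model names and component choices are assumed,
# not taken from this commit; the exporter relies only on the keys noted above.
arch_config = {
    "model_type": "rec",
    "algorithm": "Distillation",
    "Models": {
        "Teacher": {
            "model_type": "rec",
            "algorithm": "CRNN",  # assumed sub-model algorithm
            "Backbone": {"name": "MobileNetV3"},
            "Head": {"name": "CTCHead", "out_channels": 6625},
        },
        "Student": {
            "model_type": "rec",
            "algorithm": "CRNN",
            "Backbone": {"name": "MobileNetV3"},
            "Head": {"name": "CTCHead", "out_channels": 6625},
        },
    },
}
# main() in the diff overwrites every sub-model's Head.out_channels with char_num
# and exports each sub-model to <save_inference_dir>/<name>/inference.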
@@ -82,7 +82,7 @@ class DistillationDistanceLoss(DistanceLoss):
                  key=None,
                  name="loss_distance",
                  **kargs):
-        super().__init__(mode=mode, name=name)
+        super().__init__(mode=mode, name=name, **kargs)
         assert isinstance(model_name_pairs, list)
         self.key = key
         self.model_name_pairs = model_name_pairs
@@ -34,8 +34,8 @@ class DistillationModel(nn.Layer):
             config (dict): the super parameters for module.
         """
         super().__init__()
-        self.model_dict = dict()
-        index = 0
+        self.model_list = []
+        self.model_name_list = []
         for key in config["Models"]:
             model_config = config["Models"][key]
             freeze_params = False
@@ -46,15 +46,15 @@
                 pretrained = model_config.pop("pretrained")
             model = BaseModel(model_config)
             if pretrained is not None:
-                load_dygraph_pretrain(model, path=pretrained[index])
+                load_dygraph_pretrain(model, path=pretrained)
             if freeze_params:
                 for param in model.parameters():
                     param.trainable = False
-            self.model_dict[key] = self.add_sublayer(key, model)
-            index += 1
+            self.model_list.append(self.add_sublayer(key, model))
+            self.model_name_list.append(key)

     def forward(self, x):
         result_dict = dict()
-        for key in self.model_dict:
-            result_dict[key] = self.model_dict[key](x)
+        for idx, model_name in enumerate(self.model_name_list):
+            result_dict[model_name] = self.model_list[idx](x)
         return result_dict
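The switch from a single model_dict to the paired model_list / model_name_list keeps the sub-models in a fixed order and lets callers (such as the export script below) fetch a sub-model by index while still knowing its name. A minimal, self-contained sketch of this pattern, using toy nn.Linear layers in place of the real BaseModel sub-networks (the "Teacher"/"Student" names are illustrative):

import paddle
import paddle.nn as nn


class ToyContainer(nn.Layer):
    """Minimal sketch of DistillationModel's list-based bookkeeping."""

    def __init__(self, sub_model_names):
        super().__init__()
        self.model_list = []
        self.model_name_list = []
        for name in sub_model_names:
            model = nn.Linear(8, 4)  # stand-in for BaseModel(model_config)
            # add_sublayer registers the parameters under `name`; the returned
            # layer is appended so list index and name index stay aligned.
            self.model_list.append(self.add_sublayer(name, model))
            self.model_name_list.append(name)

    def forward(self, x):
        result_dict = dict()
        for idx, model_name in enumerate(self.model_name_list):
            result_dict[model_name] = self.model_list[idx](x)
        return result_dict


container = ToyContainer(["Teacher", "Student"])
out = container(paddle.randn([2, 8]))
print(sorted(out.keys()))  # ['Student', 'Teacher']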
@@ -17,7 +17,7 @@ import sys
 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, "..")))

 import argparse
@@ -31,32 +31,12 @@ from ppocr.utils.logging import get_logger
 from tools.program import load_config, merge_config, ArgsParser


-def main():
-    FLAGS = ArgsParser().parse_args()
-    config = load_config(FLAGS.config)
-    merge_config(FLAGS.opt)
-    logger = get_logger()
-    # build post process
-
-    post_process_class = build_post_process(config['PostProcess'],
-                                            config['Global'])
-
-    # build model
-    # for rec algorithm
-    if hasattr(post_process_class, 'character'):
-        char_num = len(getattr(post_process_class, 'character'))
-        config['Architecture']["Head"]['out_channels'] = char_num
-    model = build_model(config['Architecture'])
-    init_model(config, model, logger)
-    model.eval()
-
-    save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
-
-    if config['Architecture']['algorithm'] == "SRN":
-        max_text_length = config['Architecture']['Head']['max_text_length']
+def export_single_model(model, arch_config, save_path, logger):
+    if arch_config["algorithm"] == "SRN":
+        max_text_length = arch_config["Head"]["max_text_length"]
         other_shape = [
             paddle.static.InputSpec(
-                shape=[None, 1, 64, 256], dtype='float32'), [
+                shape=[None, 1, 64, 256], dtype="float32"), [
                     paddle.static.InputSpec(
                         shape=[None, 256, 1],
                         dtype="int64"), paddle.static.InputSpec(
@@ -71,24 +51,66 @@ def main():
         model = to_static(model, input_spec=other_shape)
     else:
         infer_shape = [3, -1, -1]
-        if config['Architecture']['model_type'] == "rec":
+        if arch_config["model_type"] == "rec":
             infer_shape = [3, 32, -1]  # for rec model, H must be 32
-            if 'Transform' in config['Architecture'] and config['Architecture'][
-                    'Transform'] is not None and config['Architecture'][
-                        'Transform']['name'] == 'TPS':
+            if "Transform" in arch_config and arch_config[
+                    "Transform"] is not None and arch_config["Transform"][
+                        "name"] == "TPS":
                 logger.info(
-                    'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
+                    "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
                 )
                 infer_shape[-1] = 100
         model = to_static(
             model,
             input_spec=[
                 paddle.static.InputSpec(
-                    shape=[None] + infer_shape, dtype='float32')
+                    shape=[None] + infer_shape, dtype="float32")
             ])

     paddle.jit.save(model, save_path)
-    logger.info('inference model is saved to {}'.format(save_path))
+    logger.info("inference model is saved to {}".format(save_path))
+    return
+
+
+def main():
+    FLAGS = ArgsParser().parse_args()
+    config = load_config(FLAGS.config)
+    merge_config(FLAGS.opt)
+    logger = get_logger()
+    # build post process
+
+    post_process_class = build_post_process(config["PostProcess"],
+                                            config["Global"])
+
+    # build model
+    # for rec algorithm
+    if hasattr(post_process_class, "character"):
+        char_num = len(getattr(post_process_class, "character"))
+        if config["Architecture"]["algorithm"] in ["Distillation",
+                                                   ]:  # distillation model
+            for key in config["Architecture"]["Models"]:
+                config["Architecture"]["Models"][key]["Head"][
+                    "out_channels"] = char_num
+        else:  # base rec model
+            config["Architecture"]["Head"]["out_channels"] = char_num
+
+    model = build_model(config["Architecture"])
+    init_model(config, model, logger)
+    model.eval()
+
+    save_path = config["Global"]["save_inference_dir"]
+    arch_config = config["Architecture"]
+
+    if arch_config["algorithm"] in ["Distillation", ]:  # distillation model
+        archs = list(arch_config["Models"].values())
+        for idx, name in enumerate(model.model_name_list):
+            sub_model_save_path = os.path.join(save_path, name, "inference")
+            export_single_model(model.model_list[idx], archs[idx],
+                                sub_model_save_path, logger)
+    else:
+        save_path = os.path.join(save_path, "inference")
+        export_single_model(model, arch_config, save_path, logger)
+
+
 if __name__ == "__main__":
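Not part of this diff, but as a usage sketch of what the new export path produces: for a distillation config, main() saves one inference model per sub-model under <save_inference_dir>/<name>/inference. Assuming sub-models named "Teacher" and "Student" (illustrative names) and the default file naming of paddle.jit.save, each one can be loaded back independently with paddle.jit.load:

import os

import paddle

# Hypothetical output directory, i.e. the value of Global.save_inference_dir.
save_inference_dir = "./inference_distill"
for name in ["Teacher", "Student"]:  # assumed sub-model names from Models
    prefix = os.path.join(save_inference_dir, name, "inference")
    sub_model = paddle.jit.load(prefix)  # reads <prefix>.pdmodel / <prefix>.pdiparams
    sub_model.eval()
    # For a rec model the exported input spec is [None, 3, 32, -1] (H fixed at 32).
    dummy = paddle.randn([1, 3, 32, 320])
    preds = sub_model(dummy)
    print(name, type(preds))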