diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index a62922f06e8b08e3f536349404de26d2f5af6c83..539680d974100be2a301b27411c947ae0b608743 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -82,7 +82,7 @@ class DistillationDistanceLoss(DistanceLoss): key=None, name="loss_distance", **kargs): - super().__init__(mode=mode, name=name) + super().__init__(mode=mode, name=name, **kargs) assert isinstance(model_name_pairs, list) self.key = key self.model_name_pairs = model_name_pairs diff --git a/ppocr/modeling/architectures/distillation_model.py b/ppocr/modeling/architectures/distillation_model.py index bbb9dceb82fb29f21555a249f7437495bedc6d2b..255ff32b23743f34bc6bd7cabde9d607d8416928 100644 --- a/ppocr/modeling/architectures/distillation_model.py +++ b/ppocr/modeling/architectures/distillation_model.py @@ -34,8 +34,8 @@ class DistillationModel(nn.Layer): config (dict): the super parameters for module. """ super().__init__() - self.model_dict = dict() - index = 0 + self.model_list = [] + self.model_name_list = [] for key in config["Models"]: model_config = config["Models"][key] freeze_params = False @@ -46,15 +46,15 @@ class DistillationModel(nn.Layer): pretrained = model_config.pop("pretrained") model = BaseModel(model_config) if pretrained is not None: - load_dygraph_pretrain(model, path=pretrained[index]) + load_dygraph_pretrain(model, path=pretrained) if freeze_params: for param in model.parameters(): param.trainable = False - self.model_dict[key] = self.add_sublayer(key, model) - index += 1 + self.model_list.append(self.add_sublayer(key, model)) + self.model_name_list.append(key) def forward(self, x): result_dict = dict() - for key in self.model_dict: - result_dict[key] = self.model_dict[key](x) + for idx, model_name in enumerate(self.model_name_list): + result_dict[model_name] = self.model_list[idx](x) return result_dict diff --git a/tools/export_model.py b/tools/export_model.py index bdff89f755d465742f1c2a810f8ae76153a558c6..1d4538c829672d7780fdf01868e544311f6cd312 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -17,7 +17,7 @@ import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.append(os.path.abspath(os.path.join(__dir__, ".."))) import argparse @@ -31,32 +31,12 @@ from ppocr.utils.logging import get_logger from tools.program import load_config, merge_config, ArgsParser -def main(): - FLAGS = ArgsParser().parse_args() - config = load_config(FLAGS.config) - merge_config(FLAGS.opt) - logger = get_logger() - # build post process - - post_process_class = build_post_process(config['PostProcess'], - config['Global']) - - # build model - # for rec algorithm - if hasattr(post_process_class, 'character'): - char_num = len(getattr(post_process_class, 'character')) - config['Architecture']["Head"]['out_channels'] = char_num - model = build_model(config['Architecture']) - init_model(config, model, logger) - model.eval() - - save_path = '{}/inference'.format(config['Global']['save_inference_dir']) - - if config['Architecture']['algorithm'] == "SRN": - max_text_length = config['Architecture']['Head']['max_text_length'] +def export_single_model(model, arch_config, save_path, logger): + if arch_config["algorithm"] == "SRN": + max_text_length = arch_config["Head"]["max_text_length"] other_shape = [ paddle.static.InputSpec( - shape=[None, 1, 64, 256], dtype='float32'), [ + shape=[None, 1, 64, 256], dtype="float32"), [ paddle.static.InputSpec( shape=[None, 256, 1], dtype="int64"), paddle.static.InputSpec( @@ -71,24 +51,66 @@ def main(): model = to_static(model, input_spec=other_shape) else: infer_shape = [3, -1, -1] - if config['Architecture']['model_type'] == "rec": + if arch_config["model_type"] == "rec": infer_shape = [3, 32, -1] # for rec model, H must be 32 - if 'Transform' in config['Architecture'] and config['Architecture'][ - 'Transform'] is not None and config['Architecture'][ - 'Transform']['name'] == 'TPS': + if "Transform" in arch_config and arch_config[ + "Transform"] is not None and arch_config["Transform"][ + "name"] == "TPS": logger.info( - 'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training' + "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training" ) infer_shape[-1] = 100 + model = to_static( model, input_spec=[ paddle.static.InputSpec( - shape=[None] + infer_shape, dtype='float32') + shape=[None] + infer_shape, dtype="float32") ]) paddle.jit.save(model, save_path) - logger.info('inference model is saved to {}'.format(save_path)) + logger.info("inference model is saved to {}".format(save_path)) + return + + +def main(): + FLAGS = ArgsParser().parse_args() + config = load_config(FLAGS.config) + merge_config(FLAGS.opt) + logger = get_logger() + # build post process + + post_process_class = build_post_process(config["PostProcess"], + config["Global"]) + + # build model + # for rec algorithm + if hasattr(post_process_class, "character"): + char_num = len(getattr(post_process_class, "character")) + if config["Architecture"]["algorithm"] in ["Distillation", + ]: # distillation model + for key in config["Architecture"]["Models"]: + config["Architecture"]["Models"][key]["Head"][ + "out_channels"] = char_num + else: # base rec model + config["Architecture"]["Head"]["out_channels"] = char_num + model = build_model(config["Architecture"]) + init_model(config, model, logger) + model.eval() + + save_path = config["Global"]["save_inference_dir"] + + arch_config = config["Architecture"] + + if arch_config["algorithm"] in ["Distillation", ]: # distillation model + archs = list(arch_config["Models"].values()) + for idx, name in enumerate(model.model_name_list): + sub_model_save_path = os.path.join(save_path, name, "inference") + export_single_model(model.model_list[idx], archs[idx], + sub_model_save_path, logger) + else: + save_path = os.path.join(save_path, "inference") + export_single_model(model, arch_config, save_path, logger) if __name__ == "__main__":