diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 5d43d656f2f8aca91aa1adbeac95d26381264531..53bf5fa9d03b891be964417fda3ae8d873191f9b 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -53,6 +53,7 @@ def load_model(config, model, optimizer=None, model_type='det'):
     checkpoints = global_config.get('checkpoints')
     pretrained_model = global_config.get('pretrained_model')
     best_model_dict = {}
+    is_float16 = False
 
     if model_type == 'vqa':
         checkpoints = config['Architecture']['Backbone']['checkpoints']
@@ -90,7 +91,6 @@ def load_model(config, model, optimizer=None, model_type='det'):
         params = paddle.load(checkpoints + '.pdparams')
         state_dict = model.state_dict()
         new_state_dict = {}
-        is_float16 = False
         for key, value in state_dict.items():
             if key not in params:
                 logger.warning("{} not in loaded params {} !".format(
@@ -107,7 +107,6 @@ def load_model(config, model, optimizer=None, model_type='det'):
                     "The shape of model params {} {} not matched with loaded params shape {} !".
                     format(key, value.shape, pre_value.shape))
         model.set_state_dict(new_state_dict)
-
         if is_float16:
             logger.info(
                 "The parameter type is float16, which is converted to float32 when loading"
@@ -130,9 +129,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
                 best_model_dict['start_epoch'] = states_dict['epoch'] + 1
         logger.info("resume from {}".format(checkpoints))
     elif pretrained_model:
-        load_pretrained_params(model, pretrained_model)
+        is_float16 = load_pretrained_params(model, pretrained_model)
     else:
         logger.info('train from scratch')
+    best_model_dict['is_float16'] = is_float16
     return best_model_dict
 
 
@@ -167,7 +167,7 @@ def load_pretrained_params(model, path):
             "The parameter type is float16, which is converted to float32 when loading"
         )
     logger.info("load pretrain successful from {}".format(path))
-    return model
+    return is_float16
 
 
 def save_model(model,