diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 3647111fddaa848a75873ab689559c63dd6d4814..5d43d656f2f8aca91aa1adbeac95d26381264531 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -90,12 +90,16 @@ def load_model(config, model, optimizer=None, model_type='det'): params = paddle.load(checkpoints + '.pdparams') state_dict = model.state_dict() new_state_dict = {} + is_float16 = False for key, value in state_dict.items(): if key not in params: logger.warning("{} not in loaded params {} !".format( key, params.keys())) continue pre_value = params[key] + if pre_value.dtype == paddle.float16: + pre_value = pre_value.astype(paddle.float32) + is_float16 = True if list(value.shape) == list(pre_value.shape): new_state_dict[key] = pre_value else: @@ -104,6 +108,10 @@ def load_model(config, model, optimizer=None, model_type='det'): format(key, value.shape, pre_value.shape)) model.set_state_dict(new_state_dict) + if is_float16: + logger.info( + "The parameter type is float16, which is converted to float32 when loading" + ) if optimizer is not None: if os.path.exists(checkpoints + '.pdopt'): optim_dict = paddle.load(checkpoints + '.pdopt') @@ -138,17 +146,26 @@ def load_pretrained_params(model, path): params = paddle.load(path + '.pdparams') state_dict = model.state_dict() new_state_dict = {} + is_float16 = False for k1 in params.keys(): if k1 not in state_dict.keys(): logger.warning("The pretrained params {} not in model".format(k1)) else: + if params[k1].dtype == paddle.float16: + params[k1] = params[k1].astype(paddle.float32) + is_float16 = True if list(state_dict[k1].shape) == list(params[k1].shape): new_state_dict[k1] = params[k1] else: logger.warning( "The shape of model params {} {} not matched with loaded params {} {} !". format(k1, state_dict[k1].shape, k1, params[k1].shape)) + model.set_state_dict(new_state_dict) + if is_float16: + logger.info( + "The parameter type is float16, which is converted to float32 when loading" + ) logger.info("load pretrain successful from {}".format(path)) return model