diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 7ccadb005a8ad591d9927c0e028887caacb3e37b..1a377f9ed6ab49a15d8b29d886a6e4640926e991 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -102,8 +102,9 @@ def load_model(config, model, optimizer=None, model_type='det'): continue pre_value = params[key] if pre_value.dtype == paddle.float16: - pre_value = pre_value.astype(paddle.float32) is_float16 = True + if pre_value.dtype != value.dtype: + pre_value = pre_value.astype(value.dtype) if list(value.shape) == list(pre_value.shape): new_state_dict[key] = pre_value else: @@ -160,8 +161,9 @@ def load_pretrained_params(model, path): logger.warning("The pretrained params {} not in model".format(k1)) else: if params[k1].dtype == paddle.float16: - params[k1] = params[k1].astype(paddle.float32) is_float16 = True + if params[k1].dtype != state_dict[k1].dtype: + params[k1] = params[k1].astype(state_dict[k1].dtype) if list(state_dict[k1].shape) == list(params[k1].shape): new_state_dict[k1] = params[k1] else: