From c4fd53070eddb6929d57c397e8e086341c28a5d9 Mon Sep 17 00:00:00 2001
From: WenmuZhou <572459439@qq.com>
Date: Mon, 22 Aug 2022 11:33:00 +0000
Subject: [PATCH] update fp16 load

---
 ppocr/utils/save_load.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 7ccadb00..1a377f9e 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -102,8 +102,9 @@ def load_model(config, model, optimizer=None, model_type='det'):
                 continue
             pre_value = params[key]
             if pre_value.dtype == paddle.float16:
-                pre_value = pre_value.astype(paddle.float32)
                 is_float16 = True
+            if pre_value.dtype != value.dtype:
+                pre_value = pre_value.astype(value.dtype)
             if list(value.shape) == list(pre_value.shape):
                 new_state_dict[key] = pre_value
             else:
@@ -160,8 +161,9 @@ def load_pretrained_params(model, path):
             logger.warning("The pretrained params {} not in model".format(k1))
         else:
             if params[k1].dtype == paddle.float16:
-                params[k1] = params[k1].astype(paddle.float32)
                 is_float16 = True
+            if params[k1].dtype != state_dict[k1].dtype:
+                params[k1] = params[k1].astype(state_dict[k1].dtype)
             if list(state_dict[k1].shape) == list(params[k1].shape):
                 new_state_dict[k1] = params[k1]
             else:
--
GitLab
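
Note (not part of the patch): the change above stops force-casting float16 checkpoint weights to float32 and instead casts each loaded tensor to the dtype of the matching model parameter, so an fp16 model can load fp16 weights without promotion. Below is a minimal standalone sketch of that casting rule; the helper name cast_params_to_model_dtype is hypothetical and exists only to show the patched logic in isolation.

import paddle

def cast_params_to_model_dtype(params, state_dict):
    # Hypothetical helper mirroring the patched logic: rather than
    # hard-casting float16 -> float32, cast each loaded tensor to the
    # dtype the model parameter actually expects.
    new_state_dict = {}
    is_float16 = False
    for key, value in state_dict.items():
        if key not in params:
            continue  # the real code logs a warning for missing keys
        pre_value = params[key]
        if pre_value.dtype == paddle.float16:
            is_float16 = True  # remember the checkpoint stored fp16 weights
        if pre_value.dtype != value.dtype:
            pre_value = pre_value.astype(value.dtype)  # match the model's dtype
        if list(value.shape) == list(pre_value.shape):
            new_state_dict[key] = pre_value
    return new_state_dict, is_float16

With the old behavior, loading an fp16 checkpoint into an fp16 model would silently produce float32 parameters; with this rule, the cast only happens when the checkpoint and model dtypes actually differ.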