提交 c4fd5307 编写于 作者: 文幕地方's avatar 文幕地方

update fp16 load

上级 c3924a95
...@@ -102,8 +102,9 @@ def load_model(config, model, optimizer=None, model_type='det'): ...@@ -102,8 +102,9 @@ def load_model(config, model, optimizer=None, model_type='det'):
continue continue
pre_value = params[key] pre_value = params[key]
if pre_value.dtype == paddle.float16: if pre_value.dtype == paddle.float16:
pre_value = pre_value.astype(paddle.float32)
is_float16 = True is_float16 = True
if pre_value.dtype != value.dtype:
pre_value = pre_value.astype(value.dtype)
if list(value.shape) == list(pre_value.shape): if list(value.shape) == list(pre_value.shape):
new_state_dict[key] = pre_value new_state_dict[key] = pre_value
else: else:
...@@ -160,8 +161,9 @@ def load_pretrained_params(model, path): ...@@ -160,8 +161,9 @@ def load_pretrained_params(model, path):
logger.warning("The pretrained params {} not in model".format(k1)) logger.warning("The pretrained params {} not in model".format(k1))
else: else:
if params[k1].dtype == paddle.float16: if params[k1].dtype == paddle.float16:
params[k1] = params[k1].astype(paddle.float32)
is_float16 = True is_float16 = True
if params[k1].dtype != state_dict[k1].dtype:
params[k1] = params[k1].astype(state_dict[k1].dtype)
if list(state_dict[k1].shape) == list(params[k1].shape): if list(state_dict[k1].shape) == list(params[k1].shape):
new_state_dict[k1] = params[k1] new_state_dict[k1] = params[k1]
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册