提交 d0efcc74 编写于 作者: 文幕地方's avatar 文幕地方

convert fp16 params to fp32 when params is fp16 format

上级 12aa058c
......@@ -90,12 +90,16 @@ def load_model(config, model, optimizer=None, model_type='det'):
params = paddle.load(checkpoints + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
is_float16 = False
for key, value in state_dict.items():
if key not in params:
logger.warning("{} not in loaded params {} !".format(
key, params.keys()))
continue
pre_value = params[key]
if pre_value.dtype == paddle.float16:
pre_value = pre_value.astype(paddle.float32)
is_float16 = True
if list(value.shape) == list(pre_value.shape):
new_state_dict[key] = pre_value
else:
......@@ -104,6 +108,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
format(key, value.shape, pre_value.shape))
model.set_state_dict(new_state_dict)
if is_float16:
logger.info(
"The parameter type is float16, which is converted to float32 when loading"
)
if optimizer is not None:
if os.path.exists(checkpoints + '.pdopt'):
optim_dict = paddle.load(checkpoints + '.pdopt')
......@@ -138,17 +146,26 @@ def load_pretrained_params(model, path):
params = paddle.load(path + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
is_float16 = False
for k1 in params.keys():
if k1 not in state_dict.keys():
logger.warning("The pretrained params {} not in model".format(k1))
else:
if params[k1].dtype == paddle.float16:
params[k1] = params[k1].astype(paddle.float32)
is_float16 = True
if list(state_dict[k1].shape) == list(params[k1].shape):
new_state_dict[k1] = params[k1]
else:
logger.warning(
"The shape of model params {} {} not matched with loaded params {} {} !".
format(k1, state_dict[k1].shape, k1, params[k1].shape))
model.set_state_dict(new_state_dict)
if is_float16:
logger.info(
"The parameter type is float16, which is converted to float32 when loading"
)
logger.info("load pretrain successful from {}".format(path))
return model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册