未验证 提交 44852aae 编写于 作者: A andyjpaddle 提交者: GitHub

Merge pull request #7109 from WenmuZhou/tipc1

convert fp16 params to fp32 when params is fp16 format
......@@ -53,6 +53,7 @@ def load_model(config, model, optimizer=None, model_type='det'):
checkpoints = global_config.get('checkpoints')
pretrained_model = global_config.get('pretrained_model')
best_model_dict = {}
is_float16 = False
if model_type == 'vqa':
# NOTE: for vqa model, resume training is not supported now
......@@ -100,6 +101,9 @@ def load_model(config, model, optimizer=None, model_type='det'):
key, params.keys()))
continue
pre_value = params[key]
if pre_value.dtype == paddle.float16:
pre_value = pre_value.astype(paddle.float32)
is_float16 = True
if list(value.shape) == list(pre_value.shape):
new_state_dict[key] = pre_value
else:
......@@ -107,7 +111,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
"The shape of model params {} {} not matched with loaded params shape {} !".
format(key, value.shape, pre_value.shape))
model.set_state_dict(new_state_dict)
if is_float16:
logger.info(
"The parameter type is float16, which is converted to float32 when loading"
)
if optimizer is not None:
if os.path.exists(checkpoints + '.pdopt'):
optim_dict = paddle.load(checkpoints + '.pdopt')
......@@ -126,9 +133,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
load_pretrained_params(model, pretrained_model)
is_float16 = load_pretrained_params(model, pretrained_model)
else:
logger.info('train from scratch')
best_model_dict['is_float16'] = is_float16
return best_model_dict
......@@ -142,19 +150,28 @@ def load_pretrained_params(model, path):
params = paddle.load(path + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
is_float16 = False
for k1 in params.keys():
if k1 not in state_dict.keys():
logger.warning("The pretrained params {} not in model".format(k1))
else:
if params[k1].dtype == paddle.float16:
params[k1] = params[k1].astype(paddle.float32)
is_float16 = True
if list(state_dict[k1].shape) == list(params[k1].shape):
new_state_dict[k1] = params[k1]
else:
logger.warning(
"The shape of model params {} {} not matched with loaded params {} {} !".
format(k1, state_dict[k1].shape, k1, params[k1].shape))
model.set_state_dict(new_state_dict)
if is_float16:
logger.info(
"The parameter type is float16, which is converted to float32 when loading"
)
logger.info("load pretrain successful from {}".format(path))
return model
return is_float16
def save_model(model,
......
......@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=8|whole_train_whole_infer=8
Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8
Architecture.Backbone.checkpoints:null
train_model_name:latest
train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg
......
......@@ -108,7 +108,7 @@ if [ ${MODE} = "benchmark_train" ];then
fi
if [ ${model_name} == "layoutxlm_ser" ]; then
pip install -r ppstructure/vqa/requirements.txt
pip install paddlenlp\>=2.3.5 --force-reinstall
pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate
cd ./train_data/ && tar xf XFUND.tar
# expand gt.txt 10 times
......@@ -222,7 +222,7 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
fi
if [ ${model_name} == "layoutxlm_ser" ]; then
pip install -r ppstructure/vqa/requirements.txt
pip install paddlenlp\>=2.3.5 --force-reinstall
pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate
cd ./train_data/ && tar xf XFUND.tar
cd ../
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册