diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 8fded687c62e8de9ff126037ec2a9fd88db9590d..e77a6ce0183611569193e1996e935f4bd30400a0 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -53,6 +53,7 @@ def load_model(config, model, optimizer=None, model_type='det'):
     checkpoints = global_config.get('checkpoints')
     pretrained_model = global_config.get('pretrained_model')
     best_model_dict = {}
+    is_float16 = False
 
     if model_type == 'vqa':
         # NOTE: for vqa model, resume training is not supported now
@@ -100,6 +101,9 @@ def load_model(config, model, optimizer=None, model_type='det'):
                         key, params.keys()))
                     continue
                 pre_value = params[key]
+                if pre_value.dtype == paddle.float16:
+                    pre_value = pre_value.astype(paddle.float32)
+                    is_float16 = True
                 if list(value.shape) == list(pre_value.shape):
                     new_state_dict[key] = pre_value
                 else:
@@ -107,7 +111,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
                         "The shape of model params {} {} not matched with loaded params shape {} !".
                         format(key, value.shape, pre_value.shape))
         model.set_state_dict(new_state_dict)
-
+        if is_float16:
+            logger.info(
+                "The parameter type is float16, which is converted to float32 when loading"
+            )
         if optimizer is not None:
             if os.path.exists(checkpoints + '.pdopt'):
                 optim_dict = paddle.load(checkpoints + '.pdopt')
@@ -126,9 +133,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
                 best_model_dict['start_epoch'] = states_dict['epoch'] + 1
         logger.info("resume from {}".format(checkpoints))
     elif pretrained_model:
-        load_pretrained_params(model, pretrained_model)
+        is_float16 = load_pretrained_params(model, pretrained_model)
     else:
         logger.info('train from scratch')
+    best_model_dict['is_float16'] = is_float16
     return best_model_dict
 
 
@@ -142,19 +150,28 @@ def load_pretrained_params(model, path):
     params = paddle.load(path + '.pdparams')
     state_dict = model.state_dict()
     new_state_dict = {}
+    is_float16 = False
     for k1 in params.keys():
         if k1 not in state_dict.keys():
             logger.warning("The pretrained params {} not in model".format(k1))
         else:
+            if params[k1].dtype == paddle.float16:
+                params[k1] = params[k1].astype(paddle.float32)
+                is_float16 = True
             if list(state_dict[k1].shape) == list(params[k1].shape):
                 new_state_dict[k1] = params[k1]
             else:
                 logger.warning(
                     "The shape of model params {} {} not matched with loaded params {} {} !".
                     format(k1, state_dict[k1].shape, k1, params[k1].shape))
+
     model.set_state_dict(new_state_dict)
+    if is_float16:
+        logger.info(
+            "The parameter type is float16, which is converted to float32 when loading"
+        )
     logger.info("load pretrain successful from {}".format(path))
-    return model
+    return is_float16
 
 
 def save_model(model,
diff --git a/test_tipc/configs/layoutxlm_ser/train_infer_python.txt b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt
index 53415b3e8df56dff804c39f5dc1e50a774c05d76..34082bc193a2ebd8f4c7a9e7c9ce55dc8dbf8e40 100644
--- a/test_tipc/configs/layoutxlm_ser/train_infer_python.txt
+++ b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
 Global.auto_cast:fp32
 Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
 Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=8|whole_train_whole_infer=8
+Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8
 Architecture.Backbone.checkpoints:null
 train_model_name:latest
 train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg
diff --git a/tools/program.py b/tools/program.py
index c4a9f916e692b0015855848db393ab8a083b9882..051fdf581b0d5cfcc2678c1cbc46bc7e7246805f 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -160,18 +160,18 @@ def to_float32(preds):
         for k in preds:
             if isinstance(preds[k], dict) or isinstance(preds[k], list):
                 preds[k] = to_float32(preds[k])
-            elif isinstance(preds[k], paddle.Tensor):
-                preds[k] = preds[k].astype(paddle.float32)
+            else:
+                preds[k] = paddle.to_tensor(preds[k], dtype='float32')
     elif isinstance(preds, list):
         for k in range(len(preds)):
             if isinstance(preds[k], dict):
                 preds[k] = to_float32(preds[k])
             elif isinstance(preds[k], list):
                 preds[k] = to_float32(preds[k])
-            elif isinstance(preds[k], paddle.Tensor):
-                preds[k] = preds[k].astype(paddle.float32)
-    elif isinstance(preds[k], paddle.Tensor):
-        preds = preds.astype(paddle.float32)
+            else:
+                preds[k] = paddle.to_tensor(preds[k], dtype='float32')
+    else:
+        preds = paddle.to_tensor(preds, dtype='float32')
     return preds
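For reference, a minimal standalone sketch of the fp16-to-fp32 cast-on-load pattern this diff adds to load_model and load_pretrained_params. It assumes only that PaddlePaddle is installed; the file name toy.pdparams and the helper cast_fp16_state_dict_to_fp32 are illustrative and not part of the PR.

```python
# Minimal sketch (illustrative, not part of this diff): the fp16 -> fp32
# cast-on-load pattern that load_model / load_pretrained_params now apply,
# demonstrated on a toy state dict. The file name "toy.pdparams" and the
# helper name below are made up for the example.
import paddle


def cast_fp16_state_dict_to_fp32(params):
    """Cast any float16 tensors in a loaded state dict to float32."""
    is_float16 = False
    for name, value in params.items():
        if value.dtype == paddle.float16:
            params[name] = value.astype(paddle.float32)
            is_float16 = True
    return params, is_float16


if __name__ == "__main__":
    # Fake "checkpoint" containing one fp16 and one fp32 tensor.
    paddle.save({
        "conv.weight": paddle.ones([2, 2], dtype="float16"),
        "fc.bias": paddle.zeros([2], dtype="float32"),
    }, "toy.pdparams")

    loaded = paddle.load("toy.pdparams")
    loaded, was_fp16 = cast_fp16_state_dict_to_fp32(loaded)
    # was_fp16 is True and both tensors are now float32, matching the
    # "converted to float32 when loading" log message added in save_load.py.
    print(was_fp16, [str(v.dtype) for v in loaded.values()])
```

The to_float32 change in tools/program.py covers the other side of the same concern: prediction dicts and lists produced under AMP are walked recursively, and every leaf is materialized as a float32 tensor via paddle.to_tensor before post-processing.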