Commit f01dbb56 authored by 文幕地方

add LayoutLM ser

Parent a0a0a363
File: README.md (PP-Structure DOC-VQA)

@@ -18,12 +18,13 @@ The DOC-VQA algorithms in PP-Structure are developed on top of the PaddleNLP natural language processing library
 ## 1 Performance

-We evaluated the algorithms on the [XFUN](https://github.com/doc-analysis/XFUND) evaluation set; performance is as follows
+We evaluated the algorithms on the Chinese subset of the [XFUN](https://github.com/doc-analysis/XFUND) dataset; performance is as follows

-| Task | F1 | Model download |
-|:---:|:---:|:---:|
-| SER | 0.9056 | [link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
-| RE | 0.7113 | [link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
+| Model | Task | F1 | Model download |
+|:---:|:---:|:---:|:---:|
+| LayoutXLM | RE | 0.7113 | [link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
+| LayoutXLM | SER | 0.9056 | [link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
+| LayoutLM | SER | 0.78 | [link](https://paddleocr.bj.bcebos.com/pplayout/LayoutLM_ser_pretrained.tar) |

@@ -135,6 +136,7 @@ wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
 ```shell
 python3.7 train_ser.py \
     --model_name_or_path "layoutxlm-base-uncased" \
+    --ser_model_type "LayoutLM" \
     --train_data_dir "XFUND/zh_train/image" \
     --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
     --eval_data_dir "XFUND/zh_val/image" \

@@ -155,6 +157,7 @@ python3.7 train_ser.py \
 ```shell
 python3.7 train_ser.py \
     --model_name_or_path "model_path" \
+    --ser_model_type "LayoutXLM" \
     --train_data_dir "XFUND/zh_train/image" \
     --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
     --eval_data_dir "XFUND/zh_val/image" \

@@ -175,6 +178,7 @@ python3.7 train_ser.py \
 export CUDA_VISIBLE_DEVICES=0
 python3 eval_ser.py \
     --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
+    --ser_model_type "LayoutXLM" \
     --eval_data_dir "XFUND/zh_val/image" \
     --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
     --per_gpu_eval_batch_size 8 \

@@ -190,6 +194,7 @@ python3 eval_ser.py \
 export CUDA_VISIBLE_DEVICES=0
 python3.7 infer_ser.py \
     --model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
+    --ser_model_type "LayoutXLM" \
     --output_dir "output_res/" \
     --infer_imgs "XFUND/zh_val/image/" \
     --ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"

@@ -203,6 +208,7 @@ python3.7 infer_ser.py \
 export CUDA_VISIBLE_DEVICES=0
 python3.7 infer_ser_e2e.py \
     --model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \
+    --ser_model_type "LayoutXLM" \
     --max_seq_length 512 \
     --output_dir "output_res_e2e/" \
     --infer_imgs "images/input/zh_val_0.jpg"
......
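Note that the first training hunk pairs `--ser_model_type "LayoutLM"` with the `layoutxlm-base-uncased` checkpoint name; a LayoutLM run would more plausibly point at a LayoutLM checkpoint. A hypothetical invocation (the checkpoint name and output directory are illustrative assumptions, not taken from this commit):

```shell
# Hypothetical LayoutLM SER training run; only --ser_model_type and the
# checkpoint name differ from the commands documented above.
python3.7 train_ser.py \
    --model_name_or_path "layoutlm-base-uncased" \
    --ser_model_type "LayoutLM" \
    --train_data_dir "XFUND/zh_train/image" \
    --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
    --eval_data_dir "XFUND/zh_val/image" \
    --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
    --output_dir "output/ser_layoutlm/"
```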
File: eval_ser.py

@@ -29,11 +29,21 @@ import paddle
 import numpy as np
 from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
 from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification

 from xfun import XFUNDataset
+from losses import SERLoss
 from utils import parse_args, get_bio_label_maps, print_arguments
 from ppocr.utils.logging import get_logger

+MODELS = {
+    'LayoutXLM':
+    (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
+    'LayoutLM':
+    (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
+}


 def eval(args):
     logger = get_logger()

@@ -42,9 +52,9 @@ def eval(args):
     label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
     pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index

-    tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
-    model = LayoutXLMForTokenClassification.from_pretrained(
-        args.model_name_or_path)
+    tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
+    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
+    model = model_class.from_pretrained(args.model_name_or_path)

     eval_dataset = XFUNDataset(
         tokenizer,

@@ -65,8 +75,11 @@ def eval(args):
         use_shared_memory=True,
         collate_fn=None, )

-    results, _ = evaluate(args, model, tokenizer, eval_dataloader, label2id_map,
-                          id2label_map, pad_token_label_id, logger)
+    loss_class = SERLoss(len(label2id_map))
+
+    results, _ = evaluate(args, model, tokenizer, loss_class, eval_dataloader,
+                          label2id_map, id2label_map, pad_token_label_id,
+                          logger)

     logger.info(results)

@@ -74,6 +87,7 @@ def eval(args):
 def evaluate(args,
              model,
              tokenizer,
+             loss_class,
              eval_dataloader,
              label2id_map,
              id2label_map,

@@ -88,24 +102,29 @@ def evaluate(args,
     model.eval()
     for idx, batch in enumerate(eval_dataloader):
         with paddle.no_grad():
+            if args.ser_model_type == 'LayoutLM':
+                if 'image' in batch:
+                    batch.pop('image')
+            labels = batch.pop('labels')

             outputs = model(**batch)
-            tmp_eval_loss, logits = outputs[:2]
-
-            tmp_eval_loss = tmp_eval_loss.mean()
+            if args.ser_model_type == 'LayoutXLM':
+                outputs = outputs[0]
+            loss = loss_class(labels, outputs, batch['attention_mask'])
+            loss = loss.mean()

             if paddle.distributed.get_rank() == 0:
                 logger.info("[Eval]process: {}/{}, loss: {:.5f}".format(
-                    idx, len(eval_dataloader), tmp_eval_loss.numpy()[0]))
+                    idx, len(eval_dataloader), loss.numpy()[0]))

-        eval_loss += tmp_eval_loss.item()
+        eval_loss += loss.item()
         nb_eval_steps += 1
         if preds is None:
-            preds = logits.numpy()
-            out_label_ids = batch["labels"].numpy()
+            preds = outputs.numpy()
+            out_label_ids = labels.numpy()
         else:
-            preds = np.append(preds, logits.numpy(), axis=0)
-            out_label_ids = np.append(
-                out_label_ids, batch["labels"].numpy(), axis=0)
+            preds = np.append(preds, outputs.numpy(), axis=0)
+            out_label_ids = np.append(out_label_ids, labels.numpy(), axis=0)

     eval_loss = eval_loss / nb_eval_steps
     preds = np.argmax(preds, axis=2)
......
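eval_ser.py, infer_ser.py, infer_ser_e2e.py, and train_ser.py all resolve their classes through the same three-tuple `MODELS` registry rather than hard-coding LayoutXLM. A minimal sketch of the lookup, factored into a hypothetical helper (no such function exists in the commit):

```python
from paddlenlp.transformers import (
    LayoutLMForTokenClassification, LayoutLMModel, LayoutLMTokenizer,
    LayoutXLMForTokenClassification, LayoutXLMModel, LayoutXLMTokenizer)

# Same registry shape as in the diff: tokenizer, base model, SER head.
MODELS = {
    'LayoutXLM': (LayoutXLMTokenizer, LayoutXLMModel,
                  LayoutXLMForTokenClassification),
    'LayoutLM': (LayoutLMTokenizer, LayoutLMModel,
                 LayoutLMForTokenClassification),
}


def build_ser_model(ser_model_type, model_name_or_path):
    """Hypothetical helper mirroring the boilerplate repeated per script."""
    tokenizer_class, _base_model_class, model_class = MODELS[ser_model_type]
    tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
    model = model_class.from_pretrained(model_name_or_path)
    return tokenizer, model
```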
File: infer_re.py (RE inference)

@@ -56,7 +56,11 @@ def infer(args):
     ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)

     for idx, batch in enumerate(eval_dataloader):
-        logger.info("[Infer] process: {}/{}".format(idx, len(eval_dataloader)))
+        save_img_path = os.path.join(
+            args.output_dir,
+            os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg")
+        logger.info("[Infer] process: {}/{}, save_result to {}".format(
+            idx, len(eval_dataloader), save_img_path))
         with paddle.no_grad():
             outputs = model(**batch)
             pred_relations = outputs['pred_relations']

@@ -85,8 +89,7 @@ def infer(args):
         img = cv2.imread(image_path)
         img_show = draw_re_results(img, result)
-        save_path = os.path.join(args.output_dir, os.path.basename(image_path))
-        cv2.imwrite(save_path, img_show)
+        cv2.imwrite(save_img_path, img_show)


 def load_ocr(img_folder, json_path):
......
File: infer_ser.py

@@ -24,6 +24,14 @@ import paddle

 # relative reference
 from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
 from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
+
+MODELS = {
+    'LayoutXLM':
+    (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
+    'LayoutLM':
+    (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
+}


 def pad_sentences(tokenizer,

@@ -217,10 +225,10 @@ def infer(args):
     os.makedirs(args.output_dir, exist_ok=True)

     # init token and model
-    tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
-    # model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
-    model = LayoutXLMForTokenClassification.from_pretrained(
-        args.model_name_or_path)
+    tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
+    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
+    model = model_class.from_pretrained(args.model_name_or_path)

     model.eval()

     # load ocr results json

@@ -240,7 +248,10 @@ def infer(args):
             "w",
             encoding='utf-8') as fout:
         for idx, img_path in enumerate(infer_imgs):
-            print("process: [{}/{}]".format(idx, len(infer_imgs), img_path))
+            save_img_path = os.path.join(args.output_dir,
+                                         os.path.basename(img_path))
+            print("process: [{}/{}], save_result to {}".format(
+                idx, len(infer_imgs), save_img_path))

             img = cv2.imread(img_path)

@@ -250,15 +261,21 @@ def infer(args):
                 ori_img=img,
                 ocr_info=ocr_info,
                 max_seq_len=args.max_seq_length)
-            outputs = model(
-                input_ids=inputs["input_ids"],
-                bbox=inputs["bbox"],
-                image=inputs["image"],
-                token_type_ids=inputs["token_type_ids"],
-                attention_mask=inputs["attention_mask"])
-            preds = outputs[0]
+            if args.ser_model_type == 'LayoutLM':
+                preds = model(
+                    input_ids=inputs["input_ids"],
+                    bbox=inputs["bbox"],
+                    token_type_ids=inputs["token_type_ids"],
+                    attention_mask=inputs["attention_mask"])
+            elif args.ser_model_type == 'LayoutXLM':
+                preds = model(
+                    input_ids=inputs["input_ids"],
+                    bbox=inputs["bbox"],
+                    image=inputs["image"],
+                    token_type_ids=inputs["token_type_ids"],
+                    attention_mask=inputs["attention_mask"])
+                preds = preds[0]

             preds = postprocess(inputs["attention_mask"], preds,
                                 args.label_map_path)
             ocr_info = merge_preds_list_with_ocr_info(

@@ -271,9 +288,7 @@ def infer(args):
                 }, ensure_ascii=False) + "\n")

             img_res = draw_ser_results(img, ocr_info)
-            cv2.imwrite(
-                os.path.join(args.output_dir, os.path.basename(img_path)),
-                img_res)
+            cv2.imwrite(save_img_path, img_res)

     return
......
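infer_ser.py above and infer_ser_e2e.py below need the same branch because the two forward signatures differ: LayoutXLM additionally consumes the page image for its visual embeddings and wraps its logits in a tuple, while LayoutLM is text-and-layout only and returns the logits directly. A condensed sketch of that dispatch, assuming `inputs` is the dict produced by `preprocess`:

```python
def ser_forward(model, inputs, ser_model_type):
    """Hypothetical helper; the scripts inline this branch instead."""
    if ser_model_type == 'LayoutLM':
        # No visual branch; the head returns logits directly.
        preds = model(
            input_ids=inputs["input_ids"],
            bbox=inputs["bbox"],
            token_type_ids=inputs["token_type_ids"],
            attention_mask=inputs["attention_mask"])
    elif ser_model_type == 'LayoutXLM':
        # Also takes the image; the first tuple element holds the logits.
        preds = model(
            input_ids=inputs["input_ids"],
            bbox=inputs["bbox"],
            image=inputs["image"],
            token_type_ids=inputs["token_type_ids"],
            attention_mask=inputs["attention_mask"])
        preds = preds[0]
    return preds
```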
File: infer_ser_e2e.py

@@ -22,12 +22,20 @@ from PIL import Image

 import paddle
 from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification

 # relative reference
 from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
 from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info

+MODELS = {
+    'LayoutXLM':
+    (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
+    'LayoutLM':
+    (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
+}


 def trans_poly_to_bbox(poly):
     x1 = np.min([p[0] for p in poly])

@@ -50,14 +58,15 @@ def parse_ocr_info_for_ser(ocr_result):

 class SerPredictor(object):
     def __init__(self, args):
+        self.args = args
         self.max_seq_length = args.max_seq_length

         # init ser token and model
-        self.tokenizer = LayoutXLMTokenizer.from_pretrained(
-            args.model_name_or_path)
-        self.model = LayoutXLMForTokenClassification.from_pretrained(
-            args.model_name_or_path)
+        tokenizer_class, base_model_class, model_class = MODELS[
+            args.ser_model_type]
+        self.tokenizer = tokenizer_class.from_pretrained(
+            args.model_name_or_path)
+        self.model = model_class.from_pretrained(args.model_name_or_path)
         self.model.eval()

         # init ocr_engine

@@ -89,14 +98,21 @@ class SerPredictor(object):
             ocr_info=ocr_info,
             max_seq_len=self.max_seq_length)

-        outputs = self.model(
-            input_ids=inputs["input_ids"],
-            bbox=inputs["bbox"],
-            image=inputs["image"],
-            token_type_ids=inputs["token_type_ids"],
-            attention_mask=inputs["attention_mask"])
-        preds = outputs[0]
+        if args.ser_model_type == 'LayoutLM':
+            preds = self.model(
+                input_ids=inputs["input_ids"],
+                bbox=inputs["bbox"],
+                token_type_ids=inputs["token_type_ids"],
+                attention_mask=inputs["attention_mask"])
+        elif args.ser_model_type == 'LayoutXLM':
+            preds = self.model(
+                input_ids=inputs["input_ids"],
+                bbox=inputs["bbox"],
+                image=inputs["image"],
+                token_type_ids=inputs["token_type_ids"],
+                attention_mask=inputs["attention_mask"])
+            preds = preds[0]

         preds = postprocess(inputs["attention_mask"], preds, self.id2label_map)
         ocr_info = merge_preds_list_with_ocr_info(
             ocr_info, inputs["segment_offset_id"], preds,

@@ -118,7 +134,11 @@ if __name__ == "__main__":
             "w",
             encoding='utf-8') as fout:
         for idx, img_path in enumerate(infer_imgs):
-            print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
+            save_img_path = os.path.join(
+                args.output_dir,
+                os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
+            print("process: [{}/{}], save_result to {}".format(
+                idx, len(infer_imgs), save_img_path))

             img = cv2.imread(img_path)

@@ -129,7 +149,4 @@ if __name__ == "__main__":
                 }, ensure_ascii=False) + "\n")

             img_res = draw_ser_results(img, result)
-            cv2.imwrite(
-                os.path.join(args.output_dir,
-                             os.path.splitext(os.path.basename(img_path))[0] +
-                             "_ser.jpg"), img_res)
+            cv2.imwrite(save_img_path, img_res)
File: losses.py (new file)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import nn


class SERLoss(nn.Layer):
    def __init__(self, num_classes):
        super().__init__()
        self.loss_class = nn.CrossEntropyLoss()
        self.num_classes = num_classes
        self.ignore_index = self.loss_class.ignore_index

    def forward(self, labels, outputs, attention_mask):
        if attention_mask is not None:
            active_loss = attention_mask.reshape([-1, ]) == 1
            active_outputs = outputs.reshape(
                [-1, self.num_classes])[active_loss]
            active_labels = labels.reshape([-1, ])[active_loss]
            loss = self.loss_class(active_outputs, active_labels)
        else:
            loss = self.loss_class(
                outputs.reshape([-1, self.num_classes]),
                labels.reshape([-1, ]))
        return loss
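A quick sanity check of the new loss with made-up shapes (batch 2, sequence length 4, 3 classes); only positions whose attention_mask is 1 reach the cross-entropy, so padded tokens never dilute the loss:

```python
import paddle

from losses import SERLoss

loss_fn = SERLoss(num_classes=3)
logits = paddle.randn([2, 4, 3])             # [batch, seq_len, num_classes]
labels = paddle.randint(0, 3, shape=[2, 4])  # per-token class ids
mask = paddle.to_tensor([[1, 1, 1, 0],
                         [1, 1, 0, 0]])      # 0 marks padded positions
loss = loss_fn(labels, logits, mask)         # averaged over the 5 active tokens
print(float(loss))
```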
File: train_ser.py

@@ -29,11 +29,21 @@ import paddle
 import numpy as np
 from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
 from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification

 from xfun import XFUNDataset
 from utils import parse_args, get_bio_label_maps, print_arguments, set_seed
 from eval_ser import evaluate
+from losses import SERLoss
 from ppocr.utils.logging import get_logger

+MODELS = {
+    'LayoutXLM':
+    (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
+    'LayoutLM':
+    (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
+}


 def train(args):
     os.makedirs(args.output_dir, exist_ok=True)

@@ -44,22 +54,24 @@ def train(args):
     print_arguments(args, logger)

     label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
-    pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+    loss_class = SERLoss(len(label2id_map))
+    pad_token_label_id = loss_class.ignore_index

     # dist mode
     if distributed:
         paddle.distributed.init_parallel_env()

-    tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
+    tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
+    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
     if not args.resume:
-        model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
-        model = LayoutXLMForTokenClassification(
-            model, num_classes=len(label2id_map), dropout=None)
+        base_model = base_model_class.from_pretrained(args.model_name_or_path)
+        model = model_class(
+            base_model, num_classes=len(label2id_map), dropout=None)
         logger.info('train from scratch')
     else:
         logger.info('resume from {}'.format(args.model_name_or_path))
-        model = LayoutXLMForTokenClassification.from_pretrained(
-            args.model_name_or_path)
+        model = model_class.from_pretrained(args.model_name_or_path)

     # dist mode
     if distributed:

@@ -153,12 +165,19 @@ def train(args):
         for step, batch in enumerate(train_dataloader):
             train_reader_cost += time.time() - reader_start
+            if args.ser_model_type == 'LayoutLM':
+                if 'image' in batch:
+                    batch.pop('image')
+            labels = batch.pop('labels')

             train_start = time.time()
             outputs = model(**batch)
             train_run_cost += time.time() - train_start
+            if args.ser_model_type == 'LayoutXLM':
+                outputs = outputs[0]
+            loss = loss_class(labels, outputs, batch['attention_mask'])
             # model outputs are always tuple in ppnlp (see doc)
-            loss = outputs[0]
             loss = loss.mean()

             loss.backward()
             tr_loss += loss.item()

@@ -166,7 +185,7 @@ def train(args):
             lr_scheduler.step()  # Update learning rate schedule
             optimizer.clear_grad()
             global_step += 1
-            total_samples += batch['image'].shape[0]
+            total_samples += batch['input_ids'].shape[0]

             if rank == 0 and step % print_step == 0:
                 logger.info(

@@ -186,9 +205,9 @@ def train(args):
             if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
                 # Log metrics
                 # Only evaluate when single GPU otherwise metrics may not average well
-                results, _ = evaluate(args, model, tokenizer, eval_dataloader,
-                                      label2id_map, id2label_map,
-                                      pad_token_label_id, logger)
+                results, _ = evaluate(args, model, tokenizer, loss_class,
+                                      eval_dataloader, label2id_map,
+                                      id2label_map, pad_token_label_id, logger)

                 if best_metrics is None or results["f1"] >= best_metrics["f1"]:
                     best_metrics = copy.deepcopy(results)

@@ -201,7 +220,8 @@ def train(args):
                     tokenizer.save_pretrained(output_dir)
                     paddle.save(args,
                                 os.path.join(output_dir, "training_args.bin"))
-                    logger.info("Saving model checkpoint to %s", output_dir)
+                    logger.info("Saving model checkpoint to {}".format(
+                        output_dir))

                 logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format(
                     epoch_id, args.num_train_epochs, step,

@@ -219,7 +239,7 @@ def train(args):
     model.save_pretrained(output_dir)
     tokenizer.save_pretrained(output_dir)
     paddle.save(args, os.path.join(output_dir, "training_args.bin"))
-    logger.info("Saving model checkpoint to %s", output_dir)
+    logger.info("Saving model checkpoint to {}".format(output_dir))

     return global_step, tr_loss / global_step
......
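With the loss computed outside the model, training and evaluation now share the same per-batch flow: strip what the chosen head cannot accept, run the forward pass, normalize the output, then apply SERLoss. A condensed sketch (a summary of the diff above, not a function in the commit), assuming `batch` is the dict of tensors yielded by the XFUN dataloader:

```python
def ser_train_step(model, batch, loss_class, ser_model_type):
    """Hypothetical consolidation of the per-batch logic in train_ser.py."""
    if ser_model_type == 'LayoutLM' and 'image' in batch:
        batch.pop('image')        # LayoutLM has no visual branch
    labels = batch.pop('labels')  # SERLoss consumes these directly
    outputs = model(**batch)
    if ser_model_type == 'LayoutXLM':
        outputs = outputs[0]      # unwrap (logits, ...) to plain logits
    loss = loss_class(labels, outputs, batch['attention_mask'])
    return loss.mean()
```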
File: utils.py

@@ -350,6 +350,8 @@ def parse_args():
     # yapf: disable
     parser.add_argument("--model_name_or_path",
                         default=None, type=str, required=True,)
+    parser.add_argument("--ser_model_type",
+                        default='LayoutXLM', type=str)
     parser.add_argument("--re_model_name_or_path",
                         default=None, type=str, required=False,)
     parser.add_argument("--train_data_dir", default=None,
......