未验证 提交 a6080a83 编写于 作者: Z zhoujun 提交者: GitHub

add LayoutLM ser (#4984)

* add LayoutLM ser

* add LayoutLM ser

* rm _

* Update README.md

* Update README.md
上级 88dee023
......@@ -18,12 +18,13 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进
## 1 性能
我们在 [XFUN](https://github.com/doc-analysis/XFUND) 评估数据集上对算法进行了评估,性能如下
我们在 [XFUN](https://github.com/doc-analysis/XFUND) 的中文数据集上对算法进行了评估,性能如下
|任务| f1 | 模型下载地址|
|:---:|:---:| :---:|
|SER|0.9056| [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar)|
|RE|0.7113| [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar)|
| 模型 | 任务 | f1 | 模型下载地址 |
|:---:|:---:|:---:| :---:|
| LayoutXLM | RE | 0.7113 | [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
| LayoutXLM | SER | 0.9056 | [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
| LayoutLM | SER | 0.78 | [链接](https://paddleocr.bj.bcebos.com/pplayout/LayoutLM_ser_pretrained.tar) |
......@@ -135,6 +136,7 @@ wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
```shell
python3.7 train_ser.py \
--model_name_or_path "layoutxlm-base-uncased" \
--ser_model_type "LayoutXLM" \
--train_data_dir "XFUND/zh_train/image" \
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
--eval_data_dir "XFUND/zh_val/image" \
......@@ -155,6 +157,7 @@ python3.7 train_ser.py \
```shell
python3.7 train_ser.py \
--model_name_or_path "model_path" \
--ser_model_type "LayoutXLM" \
--train_data_dir "XFUND/zh_train/image" \
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
--eval_data_dir "XFUND/zh_val/image" \
......@@ -175,6 +178,7 @@ python3.7 train_ser.py \
export CUDA_VISIBLE_DEVICES=0
python3 eval_ser.py \
--model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
--ser_model_type "LayoutXLM" \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--per_gpu_eval_batch_size 8 \
......@@ -189,8 +193,9 @@ python3 eval_ser.py \
```shell
export CUDA_VISIBLE_DEVICES=0
python3.7 infer_ser.py \
--model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
--output_dir "output_res/" \
--model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
--ser_model_type "LayoutXLM" \
--output_dir "output/ser/" \
--infer_imgs "XFUND/zh_val/image/" \
--ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"
```
......@@ -202,9 +207,10 @@ python3.7 infer_ser.py \
```shell
export CUDA_VISIBLE_DEVICES=0
python3.7 infer_ser_e2e.py \
--model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \
--model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
--ser_model_type "LayoutXLM" \
--max_seq_length 512 \
--output_dir "output_res_e2e/" \
--output_dir "output/ser_e2e/" \
--infer_imgs "images/input/zh_val_0.jpg"
```
......@@ -273,12 +279,12 @@ python3 train_re.py \
```shell
export CUDA_VISIBLE_DEVICES=0
python3 eval_re.py \
--model_name_or_path "output/check/checkpoint-best" \
--model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
--max_seq_length 512 \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--label_map_path 'labels/labels_ser.txt' \
--output_dir "output/re_test/" \
--output_dir "output/re/" \
--per_gpu_eval_batch_size 8 \
--num_workers 8 \
--seed 2048
......@@ -291,12 +297,12 @@ python3 eval_re.py \
```shell
export CUDA_VISIBLE_DEVICES=0
python3 infer_re.py \
--model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
--model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
--max_seq_length 512 \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--label_map_path 'labels/labels_ser.txt' \
--output_dir "output_res" \
--output_dir "output/re/" \
--per_gpu_eval_batch_size 1 \
--seed 2048
```
......@@ -308,10 +314,11 @@ python3 infer_re.py \
```shell
export CUDA_VISIBLE_DEVICES=0
python3.7 infer_ser_re_e2e.py \
--model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
--re_model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
--model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
--re_model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
--ser_model_type "LayoutXLM" \
--max_seq_length 512 \
--output_dir "output_ser_re_e2e_train/" \
--output_dir "output/ser_re_e2e/" \
--infer_imgs "images/input/zh_val_21.jpg"
```
......
......@@ -29,11 +29,21 @@ import paddle
import numpy as np
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
from xfun import XFUNDataset
from losses import SERLoss
from utils import parse_args, get_bio_label_maps, print_arguments
from ppocr.utils.logging import get_logger
MODELS = {
'LayoutXLM':
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
'LayoutLM':
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
def eval(args):
logger = get_logger()
......@@ -42,9 +52,9 @@ def eval(args):
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
model = LayoutXLMForTokenClassification.from_pretrained(
args.model_name_or_path)
tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path)
eval_dataset = XFUNDataset(
tokenizer,
......@@ -65,8 +75,11 @@ def eval(args):
use_shared_memory=True,
collate_fn=None, )
results, _ = evaluate(args, model, tokenizer, eval_dataloader, label2id_map,
id2label_map, pad_token_label_id, logger)
loss_class = SERLoss(len(label2id_map))
results, _ = evaluate(args, model, tokenizer, loss_class, eval_dataloader,
label2id_map, id2label_map, pad_token_label_id,
logger)
logger.info(results)
......@@ -74,6 +87,7 @@ def eval(args):
def evaluate(args,
model,
tokenizer,
loss_class,
eval_dataloader,
label2id_map,
id2label_map,
......@@ -88,24 +102,29 @@ def evaluate(args,
model.eval()
for idx, batch in enumerate(eval_dataloader):
with paddle.no_grad():
if args.ser_model_type == 'LayoutLM':
if 'image' in batch:
batch.pop('image')
labels = batch.pop('labels')
outputs = model(**batch)
tmp_eval_loss, logits = outputs[:2]
if args.ser_model_type == 'LayoutXLM':
outputs = outputs[0]
loss = loss_class(labels, outputs, batch['attention_mask'])
tmp_eval_loss = tmp_eval_loss.mean()
loss = loss.mean()
if paddle.distributed.get_rank() == 0:
logger.info("[Eval]process: {}/{}, loss: {:.5f}".format(
idx, len(eval_dataloader), tmp_eval_loss.numpy()[0]))
idx, len(eval_dataloader), loss.numpy()[0]))
eval_loss += tmp_eval_loss.item()
eval_loss += loss.item()
nb_eval_steps += 1
if preds is None:
preds = logits.numpy()
out_label_ids = batch["labels"].numpy()
preds = outputs.numpy()
out_label_ids = labels.numpy()
else:
preds = np.append(preds, logits.numpy(), axis=0)
out_label_ids = np.append(
out_label_ids, batch["labels"].numpy(), axis=0)
preds = np.append(preds, outputs.numpy(), axis=0)
out_label_ids = np.append(out_label_ids, labels.numpy(), axis=0)
eval_loss = eval_loss / nb_eval_steps
preds = np.argmax(preds, axis=2)
......
export CUDA_VISIBLE_DEVICES=6
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_distributed/best_model" \
# --max_seq_length 512 \
# --output_dir "output_res_e2e/" \
# --infer_imgs "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val/zh_val_0.jpg"
# python3.7 infer_ser_re_e2e.py \
# --model_name_or_path "output/ser_distributed/best_model" \
# --re_model_name_or_path "output/re_test/best_model" \
# --max_seq_length 512 \
# --output_dir "output_ser_re_e2e_train/" \
# --infer_imgs "images/input/zh_val_21.jpg"
# python3.7 infer_ser.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --output_dir "ser_LayoutLM/" \
# --infer_imgs "images/input/zh_val_21.jpg" \
# --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
python3.7 infer_ser.py \
--model_name_or_path "output/ser_new/best_model" \
--ser_model_type "LayoutXLM" \
--output_dir "ser_new/" \
--infer_imgs "images/input/zh_val_21.jpg" \
--ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_new/best_model" \
# --ser_model_type "LayoutXLM" \
# --max_seq_length 512 \
# --output_dir "output/ser_new/" \
# --infer_imgs "images/input/zh_val_0.jpg"
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --max_seq_length 512 \
# --output_dir "output/ser_LayoutLM/" \
# --infer_imgs "images/input/zh_val_0.jpg"
# python3 infer_re.py \
# --model_name_or_path "/ssd1/zhoujun20/VQA/PaddleOCR/ppstructure/vqa/output/re_test/best_model/" \
# --max_seq_length 512 \
# --eval_data_dir "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val" \
# --eval_label_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" \
# --label_map_path 'labels/labels_ser.txt' \
# --output_dir "output_res" \
# --per_gpu_eval_batch_size 1 \
# --seed 2048
# python3.7 infer_ser_re_e2e.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --re_model_name_or_path "output/re_new/best_model" \
# --max_seq_length 512 \
# --output_dir "output_ser_re_e2e/" \
# --infer_imgs "images/input/zh_val_21.jpg"
\ No newline at end of file
......@@ -56,15 +56,19 @@ def infer(args):
ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)
for idx, batch in enumerate(eval_dataloader):
logger.info("[Infer] process: {}/{}".format(idx, len(eval_dataloader)))
with paddle.no_grad():
outputs = model(**batch)
pred_relations = outputs['pred_relations']
ocr_info = ocr_info_list[idx]
image_path = ocr_info['image_path']
ocr_info = ocr_info['ocr_info']
save_img_path = os.path.join(
args.output_dir,
os.path.splitext(os.path.basename(image_path))[0] + "_re.jpg")
logger.info("[Infer] process: {}/{}, save result to {}".format(
idx, len(eval_dataloader), save_img_path))
with paddle.no_grad():
outputs = model(**batch)
pred_relations = outputs['pred_relations']
# 根据entity里的信息,做token解码后去过滤不要的ocr_info
ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)
......@@ -85,8 +89,7 @@ def infer(args):
img = cv2.imread(image_path)
img_show = draw_re_results(img, result)
save_path = os.path.join(args.output_dir, os.path.basename(image_path))
cv2.imwrite(save_path, img_show)
cv2.imwrite(save_img_path, img_show)
def load_ocr(img_folder, json_path):
......
......@@ -24,6 +24,14 @@ import paddle
# relative reference
from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
MODELS = {
'LayoutXLM':
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
'LayoutLM':
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
def pad_sentences(tokenizer,
......@@ -217,10 +225,10 @@ def infer(args):
os.makedirs(args.output_dir, exist_ok=True)
# init token and model
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
# model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
model = LayoutXLMForTokenClassification.from_pretrained(
args.model_name_or_path)
tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path)
model.eval()
# load ocr results json
......@@ -240,7 +248,10 @@ def infer(args):
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
print("process: [{}/{}]".format(idx, len(infer_imgs), img_path))
save_img_path = os.path.join(args.output_dir,
os.path.basename(img_path))
print("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
img = cv2.imread(img_path)
......@@ -250,15 +261,21 @@ def infer(args):
ori_img=img,
ocr_info=ocr_info,
max_seq_len=args.max_seq_length)
if args.ser_model_type == 'LayoutLM':
preds = model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
elif args.ser_model_type == 'LayoutXLM':
preds = model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
image=inputs["image"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
preds = preds[0]
outputs = model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
image=inputs["image"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
preds = outputs[0]
preds = postprocess(inputs["attention_mask"], preds,
args.label_map_path)
ocr_info = merge_preds_list_with_ocr_info(
......@@ -271,9 +288,7 @@ def infer(args):
}, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img, ocr_info)
cv2.imwrite(
os.path.join(args.output_dir, os.path.basename(img_path)),
img_res)
cv2.imwrite(save_img_path, img_res)
return
......
......@@ -22,12 +22,20 @@ from PIL import Image
import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
# relative reference
from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
MODELS = {
'LayoutXLM':
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
'LayoutLM':
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
......@@ -50,14 +58,15 @@ def parse_ocr_info_for_ser(ocr_result):
class SerPredictor(object):
def __init__(self, args):
self.args = args
self.max_seq_length = args.max_seq_length
# init ser token and model
self.tokenizer = LayoutXLMTokenizer.from_pretrained(
args.model_name_or_path)
self.model = LayoutXLMForTokenClassification.from_pretrained(
tokenizer_class, base_model_class, model_class = MODELS[
args.ser_model_type]
self.tokenizer = tokenizer_class.from_pretrained(
args.model_name_or_path)
self.model = model_class.from_pretrained(args.model_name_or_path)
self.model.eval()
# init ocr_engine
......@@ -89,14 +98,21 @@ class SerPredictor(object):
ocr_info=ocr_info,
max_seq_len=self.max_seq_length)
outputs = self.model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
image=inputs["image"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
if self.args.ser_model_type == 'LayoutLM':
preds = self.model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
elif self.args.ser_model_type == 'LayoutXLM':
preds = self.model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
image=inputs["image"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
preds = preds[0]
preds = outputs[0]
preds = postprocess(inputs["attention_mask"], preds, self.id2label_map)
ocr_info = merge_preds_list_with_ocr_info(
ocr_info, inputs["segment_offset_id"], preds,
......@@ -118,7 +134,11 @@ if __name__ == "__main__":
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
save_img_path = os.path.join(
args.output_dir,
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
print("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
img = cv2.imread(img_path)
......@@ -129,7 +149,4 @@ if __name__ == "__main__":
}, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img, result)
cv2.imwrite(
os.path.join(args.output_dir,
os.path.splitext(os.path.basename(img_path))[0] +
"_ser.jpg"), img_res)
cv2.imwrite(save_img_path, img_res)
......@@ -117,7 +117,11 @@ if __name__ == "__main__":
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
save_img_path = os.path.join(
args.output_dir,
os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg")
print("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
img = cv2.imread(img_path)
......@@ -128,7 +132,4 @@ if __name__ == "__main__":
}, ensure_ascii=False) + "\n")
img_res = draw_re_results(img, result)
cv2.imwrite(
os.path.join(args.output_dir,
os.path.splitext(os.path.basename(img_path))[0] +
"_re.jpg"), img_res)
cv2.imwrite(save_img_path, img_res)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import nn
class SERLoss(nn.Layer):
def __init__(self, num_classes):
super().__init__()
self.loss_class = nn.CrossEntropyLoss()
self.num_classes = num_classes
self.ignore_index = self.loss_class.ignore_index
def forward(self, labels, outputs, attention_mask):
if attention_mask is not None:
active_loss = attention_mask.reshape([-1, ]) == 1
active_outputs = outputs.reshape(
[-1, self.num_classes])[active_loss]
active_labels = labels.reshape([-1, ])[active_loss]
loss = self.loss_class(active_outputs, active_labels)
else:
loss = self.loss_class(
outputs.reshape([-1, self.num_classes]), labels.reshape([-1, ]))
return loss
......@@ -29,11 +29,21 @@ import paddle
import numpy as np
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
from xfun import XFUNDataset
from utils import parse_args, get_bio_label_maps, print_arguments, set_seed
from eval_ser import evaluate
from losses import SERLoss
from ppocr.utils.logging import get_logger
MODELS = {
'LayoutXLM':
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
'LayoutLM':
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
def train(args):
os.makedirs(args.output_dir, exist_ok=True)
......@@ -44,22 +54,24 @@ def train(args):
print_arguments(args, logger)
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
loss_class = SERLoss(len(label2id_map))
pad_token_label_id = loss_class.ignore_index
# dist mode
if distributed:
paddle.distributed.init_parallel_env()
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
if not args.resume:
model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
model = LayoutXLMForTokenClassification(
model, num_classes=len(label2id_map), dropout=None)
base_model = base_model_class.from_pretrained(args.model_name_or_path)
model = model_class(
base_model, num_classes=len(label2id_map), dropout=None)
logger.info('train from scratch')
else:
logger.info('resume from {}'.format(args.model_name_or_path))
model = LayoutXLMForTokenClassification.from_pretrained(
args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path)
# dist mode
if distributed:
......@@ -153,12 +165,19 @@ def train(args):
for step, batch in enumerate(train_dataloader):
train_reader_cost += time.time() - reader_start
if args.ser_model_type == 'LayoutLM':
if 'image' in batch:
batch.pop('image')
labels = batch.pop('labels')
train_start = time.time()
outputs = model(**batch)
train_run_cost += time.time() - train_start
if args.ser_model_type == 'LayoutXLM':
outputs = outputs[0]
loss = loss_class(labels, outputs, batch['attention_mask'])
# model outputs are always tuple in ppnlp (see doc)
loss = outputs[0]
loss = loss.mean()
loss.backward()
tr_loss += loss.item()
......@@ -166,7 +185,7 @@ def train(args):
lr_scheduler.step() # Update learning rate schedule
optimizer.clear_grad()
global_step += 1
total_samples += batch['image'].shape[0]
total_samples += batch['input_ids'].shape[0]
if rank == 0 and step % print_step == 0:
logger.info(
......@@ -186,9 +205,9 @@ def train(args):
if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
# Log metrics
# Only evaluate when single GPU otherwise metrics may not average well
results, _ = evaluate(args, model, tokenizer, eval_dataloader,
label2id_map, id2label_map,
pad_token_label_id, logger)
results, _ = evaluate(args, model, tokenizer, loss_class,
eval_dataloader, label2id_map,
id2label_map, pad_token_label_id, logger)
if best_metrics is None or results["f1"] >= best_metrics["f1"]:
best_metrics = copy.deepcopy(results)
......@@ -201,7 +220,8 @@ def train(args):
tokenizer.save_pretrained(output_dir)
paddle.save(args,
os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to %s", output_dir)
logger.info("Saving model checkpoint to {}".format(
output_dir))
logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format(
epoch_id, args.num_train_epochs, step,
......@@ -219,7 +239,7 @@ def train(args):
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
paddle.save(args, os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to %s", output_dir)
logger.info("Saving model checkpoint to {}".format(output_dir))
return global_step, tr_loss / global_step
......
......@@ -350,6 +350,8 @@ def parse_args():
# yapf: disable
parser.add_argument("--model_name_or_path",
default=None, type=str, required=True,)
parser.add_argument("--ser_model_type",
default='LayoutXLM', type=str)
parser.add_argument("--re_model_name_or_path",
default=None, type=str, required=False,)
parser.add_argument("--train_data_dir", default=None,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册