Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
f01dbb56
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f01dbb56
编写于
12月 20, 2021
作者:
文幕地方
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add LayoutLM ser
上级
a0a0a363
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
185 addition
and
68 deletion
+185
-68
ppstructure/vqa/README.md
ppstructure/vqa/README.md
+11
-5
ppstructure/vqa/eval_ser.py
ppstructure/vqa/eval_ser.py
+33
-14
ppstructure/vqa/infer_re.py
ppstructure/vqa/infer_re.py
+6
-3
ppstructure/vqa/infer_ser.py
ppstructure/vqa/infer_ser.py
+31
-16
ppstructure/vqa/infer_ser_e2e.py
ppstructure/vqa/infer_ser_e2e.py
+33
-16
ppstructure/vqa/losses.py
ppstructure/vqa/losses.py
+35
-0
ppstructure/vqa/train_ser.py
ppstructure/vqa/train_ser.py
+34
-14
ppstructure/vqa/utils.py
ppstructure/vqa/utils.py
+2
-0
未找到文件。
ppstructure/vqa/README.md
浏览文件 @
f01dbb56
...
...
@@ -18,12 +18,13 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进
## 1 性能
我们在
[
XFUN
](
https://github.com/doc-analysis/XFUND
)
评估
数据集上对算法进行了评估,性能如下
我们在
[
XFUN
](
https://github.com/doc-analysis/XFUND
)
的中文
数据集上对算法进行了评估,性能如下
|任务| f1 | 模型下载地址|
|:---:|:---:| :---:|
|SER|0.9056|
[
链接
](
https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar
)
|
|RE|0.7113|
[
链接
](
https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar
)
|
| 模型 | 任务 | f1 | 模型下载地址 |
|:---:|:---:|:---:| :---:|
| LayoutXLM | RE | 0.7113 |
[
链接
](
https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar
)
|
| LayoutXLM | SER | 0.9056 |
[
链接
](
https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar
)
|
| LayoutLM | SER | 0.78 |
[
链接
](
https://paddleocr.bj.bcebos.com/pplayout/LayoutLM_ser_pretrained.tar
)
|
...
...
@@ -135,6 +136,7 @@ wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
```
shell
python3.7 train_ser.py
\
--model_name_or_path
"layoutxlm-base-uncased"
\
--ser_model_type
"LayoutLM"
\
--train_data_dir
"XFUND/zh_train/image"
\
--train_label_path
"XFUND/zh_train/xfun_normalize_train.json"
\
--eval_data_dir
"XFUND/zh_val/image"
\
...
...
@@ -155,6 +157,7 @@ python3.7 train_ser.py \
```
shell
python3.7 train_ser.py
\
--model_name_or_path
"model_path"
\
--ser_model_type
"LayoutXLM"
\
--train_data_dir
"XFUND/zh_train/image"
\
--train_label_path
"XFUND/zh_train/xfun_normalize_train.json"
\
--eval_data_dir
"XFUND/zh_val/image"
\
...
...
@@ -175,6 +178,7 @@ python3.7 train_ser.py \
export
CUDA_VISIBLE_DEVICES
=
0
python3 eval_ser.py
\
--model_name_or_path
"PP-Layout_v1.0_ser_pretrained/"
\
--ser_model_type
"LayoutXLM"
\
--eval_data_dir
"XFUND/zh_val/image"
\
--eval_label_path
"XFUND/zh_val/xfun_normalize_val.json"
\
--per_gpu_eval_batch_size
8
\
...
...
@@ -190,6 +194,7 @@ python3 eval_ser.py \
export
CUDA_VISIBLE_DEVICES
=
0
python3.7 infer_ser.py
\
--model_name_or_path
"./PP-Layout_v1.0_ser_pretrained/"
\
--ser_model_type
"LayoutXLM"
\
--output_dir
"output_res/"
\
--infer_imgs
"XFUND/zh_val/image/"
\
--ocr_json_path
"XFUND/zh_val/xfun_normalize_val.json"
...
...
@@ -203,6 +208,7 @@ python3.7 infer_ser.py \
export
CUDA_VISIBLE_DEVICES
=
0
python3.7 infer_ser_e2e.py
\
--model_name_or_path
"./output/PP-Layout_v1.0_ser_pretrained/"
\
--ser_model_type
"LayoutXLM"
\
--max_seq_length
512
\
--output_dir
"output_res_e2e/"
\
--infer_imgs
"images/input/zh_val_0.jpg"
...
...
ppstructure/vqa/eval_ser.py
浏览文件 @
f01dbb56
...
...
@@ -29,11 +29,21 @@ import paddle
import
numpy
as
np
from
seqeval.metrics
import
classification_report
,
f1_score
,
precision_score
,
recall_score
from
paddlenlp.transformers
import
LayoutXLMModel
,
LayoutXLMTokenizer
,
LayoutXLMForTokenClassification
from
paddlenlp.transformers
import
LayoutLMModel
,
LayoutLMTokenizer
,
LayoutLMForTokenClassification
from
xfun
import
XFUNDataset
from
losses
import
SERLoss
from
utils
import
parse_args
,
get_bio_label_maps
,
print_arguments
from
ppocr.utils.logging
import
get_logger
MODELS
=
{
'LayoutXLM'
:
(
LayoutXLMTokenizer
,
LayoutXLMModel
,
LayoutXLMForTokenClassification
),
'LayoutLM'
:
(
LayoutLMTokenizer
,
LayoutLMModel
,
LayoutLMForTokenClassification
)
}
def
eval
(
args
):
logger
=
get_logger
()
...
...
@@ -42,9 +52,9 @@ def eval(args):
label2id_map
,
id2label_map
=
get_bio_label_maps
(
args
.
label_map_path
)
pad_token_label_id
=
paddle
.
nn
.
CrossEntropyLoss
().
ignore_index
tokenizer
=
LayoutXLMTokenizer
.
from_pretrained
(
args
.
model_name_or_path
)
model
=
LayoutXLMForTokenClassification
.
from_pretrained
(
args
.
model_name_or_path
)
tokenizer
_class
,
base_model_class
,
model_class
=
MODELS
[
args
.
ser_model_type
]
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
model_name_or_path
)
model
=
model_class
.
from_pretrained
(
args
.
model_name_or_path
)
eval_dataset
=
XFUNDataset
(
tokenizer
,
...
...
@@ -65,8 +75,11 @@ def eval(args):
use_shared_memory
=
True
,
collate_fn
=
None
,
)
results
,
_
=
evaluate
(
args
,
model
,
tokenizer
,
eval_dataloader
,
label2id_map
,
id2label_map
,
pad_token_label_id
,
logger
)
loss_class
=
SERLoss
(
len
(
label2id_map
))
results
,
_
=
evaluate
(
args
,
model
,
tokenizer
,
loss_class
,
eval_dataloader
,
label2id_map
,
id2label_map
,
pad_token_label_id
,
logger
)
logger
.
info
(
results
)
...
...
@@ -74,6 +87,7 @@ def eval(args):
def
evaluate
(
args
,
model
,
tokenizer
,
loss_class
,
eval_dataloader
,
label2id_map
,
id2label_map
,
...
...
@@ -88,24 +102,29 @@ def evaluate(args,
model
.
eval
()
for
idx
,
batch
in
enumerate
(
eval_dataloader
):
with
paddle
.
no_grad
():
if
args
.
ser_model_type
==
'LayoutLM'
:
if
'image'
in
batch
:
batch
.
pop
(
'image'
)
labels
=
batch
.
pop
(
'labels'
)
outputs
=
model
(
**
batch
)
tmp_eval_loss
,
logits
=
outputs
[:
2
]
if
args
.
ser_model_type
==
'LayoutXLM'
:
outputs
=
outputs
[
0
]
loss
=
loss_class
(
labels
,
outputs
,
batch
[
'attention_mask'
])
tmp_eval_loss
=
tmp_eval_
loss
.
mean
()
loss
=
loss
.
mean
()
if
paddle
.
distributed
.
get_rank
()
==
0
:
logger
.
info
(
"[Eval]process: {}/{}, loss: {:.5f}"
.
format
(
idx
,
len
(
eval_dataloader
),
tmp_eval_
loss
.
numpy
()[
0
]))
idx
,
len
(
eval_dataloader
),
loss
.
numpy
()[
0
]))
eval_loss
+=
tmp_eval_
loss
.
item
()
eval_loss
+=
loss
.
item
()
nb_eval_steps
+=
1
if
preds
is
None
:
preds
=
logi
ts
.
numpy
()
out_label_ids
=
batch
[
"labels"
]
.
numpy
()
preds
=
outpu
ts
.
numpy
()
out_label_ids
=
labels
.
numpy
()
else
:
preds
=
np
.
append
(
preds
,
logits
.
numpy
(),
axis
=
0
)
out_label_ids
=
np
.
append
(
out_label_ids
,
batch
[
"labels"
].
numpy
(),
axis
=
0
)
preds
=
np
.
append
(
preds
,
outputs
.
numpy
(),
axis
=
0
)
out_label_ids
=
np
.
append
(
out_label_ids
,
labels
.
numpy
(),
axis
=
0
)
eval_loss
=
eval_loss
/
nb_eval_steps
preds
=
np
.
argmax
(
preds
,
axis
=
2
)
...
...
ppstructure/vqa/infer_re.py
浏览文件 @
f01dbb56
...
...
@@ -56,7 +56,11 @@ def infer(args):
ocr_info_list
=
load_ocr
(
args
.
eval_data_dir
,
args
.
eval_label_path
)
for
idx
,
batch
in
enumerate
(
eval_dataloader
):
logger
.
info
(
"[Infer] process: {}/{}"
.
format
(
idx
,
len
(
eval_dataloader
)))
save_img_path
=
os
.
path
.
join
(
args
.
output_dir
,
os
.
path
.
splitext
(
os
.
path
.
basename
(
img_path
))[
0
]
+
"_re.jpg"
)
logger
.
info
(
"[Infer] process: {}/{}, save_result to {}"
.
format
(
idx
,
len
(
eval_dataloader
),
save_img_path
))
with
paddle
.
no_grad
():
outputs
=
model
(
**
batch
)
pred_relations
=
outputs
[
'pred_relations'
]
...
...
@@ -85,8 +89,7 @@ def infer(args):
img
=
cv2
.
imread
(
image_path
)
img_show
=
draw_re_results
(
img
,
result
)
save_path
=
os
.
path
.
join
(
args
.
output_dir
,
os
.
path
.
basename
(
image_path
))
cv2
.
imwrite
(
save_path
,
img_show
)
cv2
.
imwrite
(
save_img_path
,
img_show
)
def
load_ocr
(
img_folder
,
json_path
):
...
...
ppstructure/vqa/infer_ser.py
浏览文件 @
f01dbb56
...
...
@@ -24,6 +24,14 @@ import paddle
# relative reference
from
utils
import
parse_args
,
get_image_file_list
,
draw_ser_results
,
get_bio_label_maps
from
paddlenlp.transformers
import
LayoutXLMModel
,
LayoutXLMTokenizer
,
LayoutXLMForTokenClassification
from
paddlenlp.transformers
import
LayoutLMModel
,
LayoutLMTokenizer
,
LayoutLMForTokenClassification
MODELS
=
{
'LayoutXLM'
:
(
LayoutXLMTokenizer
,
LayoutXLMModel
,
LayoutXLMForTokenClassification
),
'LayoutLM'
:
(
LayoutLMTokenizer
,
LayoutLMModel
,
LayoutLMForTokenClassification
)
}
def
pad_sentences
(
tokenizer
,
...
...
@@ -217,10 +225,10 @@ def infer(args):
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
# init token and model
tokenizer
=
LayoutXLMTokenizer
.
from_pretrained
(
args
.
model_name_or_path
)
# model = LayoutXLMModel
.from_pretrained(args.model_name_or_path)
model
=
LayoutXLMForTokenClassification
.
from_pretrained
(
args
.
model_name_or_path
)
tokenizer
_class
,
base_model_class
,
model_class
=
MODELS
[
args
.
ser_model_type
]
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
model_name_or_path
)
model
=
model_class
.
from_pretrained
(
args
.
model_name_or_path
)
model
.
eval
()
# load ocr results json
...
...
@@ -240,7 +248,10 @@ def infer(args):
"w"
,
encoding
=
'utf-8'
)
as
fout
:
for
idx
,
img_path
in
enumerate
(
infer_imgs
):
print
(
"process: [{}/{}]"
.
format
(
idx
,
len
(
infer_imgs
),
img_path
))
save_img_path
=
os
.
path
.
join
(
args
.
output_dir
,
os
.
path
.
basename
(
img_path
))
print
(
"process: [{}/{}], save_result to {}"
.
format
(
idx
,
len
(
infer_imgs
),
save_img_path
))
img
=
cv2
.
imread
(
img_path
)
...
...
@@ -250,15 +261,21 @@ def infer(args):
ori_img
=
img
,
ocr_info
=
ocr_info
,
max_seq_len
=
args
.
max_seq_length
)
if
args
.
ser_model_type
==
'LayoutLM'
:
preds
=
model
(
input_ids
=
inputs
[
"input_ids"
],
bbox
=
inputs
[
"bbox"
],
token_type_ids
=
inputs
[
"token_type_ids"
],
attention_mask
=
inputs
[
"attention_mask"
])
elif
args
.
ser_model_type
==
'LayoutXLM'
:
preds
=
model
(
input_ids
=
inputs
[
"input_ids"
],
bbox
=
inputs
[
"bbox"
],
image
=
inputs
[
"image"
],
token_type_ids
=
inputs
[
"token_type_ids"
],
attention_mask
=
inputs
[
"attention_mask"
])
preds
=
preds
[
0
]
outputs
=
model
(
input_ids
=
inputs
[
"input_ids"
],
bbox
=
inputs
[
"bbox"
],
image
=
inputs
[
"image"
],
token_type_ids
=
inputs
[
"token_type_ids"
],
attention_mask
=
inputs
[
"attention_mask"
])
preds
=
outputs
[
0
]
preds
=
postprocess
(
inputs
[
"attention_mask"
],
preds
,
args
.
label_map_path
)
ocr_info
=
merge_preds_list_with_ocr_info
(
...
...
@@ -271,9 +288,7 @@ def infer(args):
},
ensure_ascii
=
False
)
+
"
\n
"
)
img_res
=
draw_ser_results
(
img
,
ocr_info
)
cv2
.
imwrite
(
os
.
path
.
join
(
args
.
output_dir
,
os
.
path
.
basename
(
img_path
)),
img_res
)
cv2
.
imwrite
(
save_img_path
,
img_res
)
return
...
...
ppstructure/vqa/infer_ser_e2e.py
浏览文件 @
f01dbb56
...
...
@@ -22,12 +22,20 @@ from PIL import Image
import
paddle
from
paddlenlp.transformers
import
LayoutXLMModel
,
LayoutXLMTokenizer
,
LayoutXLMForTokenClassification
from
paddlenlp.transformers
import
LayoutLMModel
,
LayoutLMTokenizer
,
LayoutLMForTokenClassification
# relative reference
from
utils
import
parse_args
,
get_image_file_list
,
draw_ser_results
,
get_bio_label_maps
from
utils
import
pad_sentences
,
split_page
,
preprocess
,
postprocess
,
merge_preds_list_with_ocr_info
MODELS
=
{
'LayoutXLM'
:
(
LayoutXLMTokenizer
,
LayoutXLMModel
,
LayoutXLMForTokenClassification
),
'LayoutLM'
:
(
LayoutLMTokenizer
,
LayoutLMModel
,
LayoutLMForTokenClassification
)
}
def
trans_poly_to_bbox
(
poly
):
x1
=
np
.
min
([
p
[
0
]
for
p
in
poly
])
...
...
@@ -50,14 +58,15 @@ def parse_ocr_info_for_ser(ocr_result):
class
SerPredictor
(
object
):
def
__init__
(
self
,
args
):
self
.
args
=
args
self
.
max_seq_length
=
args
.
max_seq_length
# init ser token and model
self
.
tokenizer
=
LayoutXLMTokenizer
.
from_pretrained
(
args
.
model_name_or_path
)
self
.
model
=
LayoutXLMForTokenClassification
.
from_pretrained
(
tokenizer_class
,
base_model_class
,
model_class
=
MODELS
[
args
.
ser_model_type
]
self
.
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
model_name_or_path
)
self
.
model
=
model_class
.
from_pretrained
(
args
.
model_name_or_path
)
self
.
model
.
eval
()
# init ocr_engine
...
...
@@ -89,14 +98,21 @@ class SerPredictor(object):
ocr_info
=
ocr_info
,
max_seq_len
=
self
.
max_seq_length
)
outputs
=
self
.
model
(
input_ids
=
inputs
[
"input_ids"
],
bbox
=
inputs
[
"bbox"
],
image
=
inputs
[
"image"
],
token_type_ids
=
inputs
[
"token_type_ids"
],
attention_mask
=
inputs
[
"attention_mask"
])
if
args
.
ser_model_type
==
'LayoutLM'
:
preds
=
self
.
model
(
input_ids
=
inputs
[
"input_ids"
],
bbox
=
inputs
[
"bbox"
],
token_type_ids
=
inputs
[
"token_type_ids"
],
attention_mask
=
inputs
[
"attention_mask"
])
elif
args
.
ser_model_type
==
'LayoutXLM'
:
preds
=
self
.
model
(
input_ids
=
inputs
[
"input_ids"
],
bbox
=
inputs
[
"bbox"
],
image
=
inputs
[
"image"
],
token_type_ids
=
inputs
[
"token_type_ids"
],
attention_mask
=
inputs
[
"attention_mask"
])
preds
=
preds
[
0
]
preds
=
outputs
[
0
]
preds
=
postprocess
(
inputs
[
"attention_mask"
],
preds
,
self
.
id2label_map
)
ocr_info
=
merge_preds_list_with_ocr_info
(
ocr_info
,
inputs
[
"segment_offset_id"
],
preds
,
...
...
@@ -118,7 +134,11 @@ if __name__ == "__main__":
"w"
,
encoding
=
'utf-8'
)
as
fout
:
for
idx
,
img_path
in
enumerate
(
infer_imgs
):
print
(
"process: [{}/{}], {}"
.
format
(
idx
,
len
(
infer_imgs
),
img_path
))
save_img_path
=
os
.
path
.
join
(
args
.
output_dir
,
os
.
path
.
splitext
(
os
.
path
.
basename
(
img_path
))[
0
]
+
"_ser.jpg"
)
print
(
"process: [{}/{}], save_result to {}"
.
format
(
idx
,
len
(
infer_imgs
),
save_img_path
))
img
=
cv2
.
imread
(
img_path
)
...
...
@@ -129,7 +149,4 @@ if __name__ == "__main__":
},
ensure_ascii
=
False
)
+
"
\n
"
)
img_res
=
draw_ser_results
(
img
,
result
)
cv2
.
imwrite
(
os
.
path
.
join
(
args
.
output_dir
,
os
.
path
.
splitext
(
os
.
path
.
basename
(
img_path
))[
0
]
+
"_ser.jpg"
),
img_res
)
cv2
.
imwrite
(
save_img_path
,
img_res
)
ppstructure/vqa/losses.py
0 → 100644
浏览文件 @
f01dbb56
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
paddle
import
nn
class
SERLoss
(
nn
.
Layer
):
def
__init__
(
self
,
num_classes
):
super
().
__init__
()
self
.
loss_class
=
nn
.
CrossEntropyLoss
()
self
.
num_classes
=
num_classes
self
.
ignore_index
=
self
.
loss_class
.
ignore_index
def
forward
(
self
,
labels
,
outputs
,
attention_mask
):
if
attention_mask
is
not
None
:
active_loss
=
attention_mask
.
reshape
([
-
1
,
])
==
1
active_outputs
=
outputs
.
reshape
(
[
-
1
,
self
.
num_classes
])[
active_loss
]
active_labels
=
labels
.
reshape
([
-
1
,
])[
active_loss
]
loss
=
self
.
loss_class
(
active_outputs
,
active_labels
)
else
:
loss
=
self
.
loss_class
(
outputs
.
reshape
([
-
1
,
self
.
num_classes
]),
labels
.
reshape
([
-
1
,
]))
return
loss
ppstructure/vqa/train_ser.py
浏览文件 @
f01dbb56
...
...
@@ -29,11 +29,21 @@ import paddle
import
numpy
as
np
from
seqeval.metrics
import
classification_report
,
f1_score
,
precision_score
,
recall_score
from
paddlenlp.transformers
import
LayoutXLMModel
,
LayoutXLMTokenizer
,
LayoutXLMForTokenClassification
from
paddlenlp.transformers
import
LayoutLMModel
,
LayoutLMTokenizer
,
LayoutLMForTokenClassification
from
xfun
import
XFUNDataset
from
utils
import
parse_args
,
get_bio_label_maps
,
print_arguments
,
set_seed
from
eval_ser
import
evaluate
from
losses
import
SERLoss
from
ppocr.utils.logging
import
get_logger
MODELS
=
{
'LayoutXLM'
:
(
LayoutXLMTokenizer
,
LayoutXLMModel
,
LayoutXLMForTokenClassification
),
'LayoutLM'
:
(
LayoutLMTokenizer
,
LayoutLMModel
,
LayoutLMForTokenClassification
)
}
def
train
(
args
):
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
...
...
@@ -44,22 +54,24 @@ def train(args):
print_arguments
(
args
,
logger
)
label2id_map
,
id2label_map
=
get_bio_label_maps
(
args
.
label_map_path
)
pad_token_label_id
=
paddle
.
nn
.
CrossEntropyLoss
().
ignore_index
loss_class
=
SERLoss
(
len
(
label2id_map
))
pad_token_label_id
=
loss_class
.
ignore_index
# dist mode
if
distributed
:
paddle
.
distributed
.
init_parallel_env
()
tokenizer
=
LayoutXLMTokenizer
.
from_pretrained
(
args
.
model_name_or_path
)
tokenizer_class
,
base_model_class
,
model_class
=
MODELS
[
args
.
ser_model_type
]
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
model_name_or_path
)
if
not
args
.
resume
:
model
=
LayoutXLMModel
.
from_pretrained
(
args
.
model_name_or_path
)
model
=
LayoutXLMForTokenClassification
(
model
,
num_classes
=
len
(
label2id_map
),
dropout
=
None
)
base_model
=
base_model_class
.
from_pretrained
(
args
.
model_name_or_path
)
model
=
model_class
(
base_
model
,
num_classes
=
len
(
label2id_map
),
dropout
=
None
)
logger
.
info
(
'train from scratch'
)
else
:
logger
.
info
(
'resume from {}'
.
format
(
args
.
model_name_or_path
))
model
=
LayoutXLMForTokenClassification
.
from_pretrained
(
args
.
model_name_or_path
)
model
=
model_class
.
from_pretrained
(
args
.
model_name_or_path
)
# dist mode
if
distributed
:
...
...
@@ -153,12 +165,19 @@ def train(args):
for
step
,
batch
in
enumerate
(
train_dataloader
):
train_reader_cost
+=
time
.
time
()
-
reader_start
if
args
.
ser_model_type
==
'LayoutLM'
:
if
'image'
in
batch
:
batch
.
pop
(
'image'
)
labels
=
batch
.
pop
(
'labels'
)
train_start
=
time
.
time
()
outputs
=
model
(
**
batch
)
train_run_cost
+=
time
.
time
()
-
train_start
if
args
.
ser_model_type
==
'LayoutXLM'
:
outputs
=
outputs
[
0
]
loss
=
loss_class
(
labels
,
outputs
,
batch
[
'attention_mask'
])
# model outputs are always tuple in ppnlp (see doc)
loss
=
outputs
[
0
]
loss
=
loss
.
mean
()
loss
.
backward
()
tr_loss
+=
loss
.
item
()
...
...
@@ -166,7 +185,7 @@ def train(args):
lr_scheduler
.
step
()
# Update learning rate schedule
optimizer
.
clear_grad
()
global_step
+=
1
total_samples
+=
batch
[
'i
mage
'
].
shape
[
0
]
total_samples
+=
batch
[
'i
nput_ids
'
].
shape
[
0
]
if
rank
==
0
and
step
%
print_step
==
0
:
logger
.
info
(
...
...
@@ -186,9 +205,9 @@ def train(args):
if
rank
==
0
and
args
.
eval_steps
>
0
and
global_step
%
args
.
eval_steps
==
0
and
args
.
evaluate_during_training
:
# Log metrics
# Only evaluate when single GPU otherwise metrics may not average well
results
,
_
=
evaluate
(
args
,
model
,
tokenizer
,
eval_dataloader
,
label2id_map
,
id2label
_map
,
pad_token_label_id
,
logger
)
results
,
_
=
evaluate
(
args
,
model
,
tokenizer
,
loss_class
,
eval_dataloader
,
label2id
_map
,
id2label_map
,
pad_token_label_id
,
logger
)
if
best_metrics
is
None
or
results
[
"f1"
]
>=
best_metrics
[
"f1"
]:
best_metrics
=
copy
.
deepcopy
(
results
)
...
...
@@ -201,7 +220,8 @@ def train(args):
tokenizer
.
save_pretrained
(
output_dir
)
paddle
.
save
(
args
,
os
.
path
.
join
(
output_dir
,
"training_args.bin"
))
logger
.
info
(
"Saving model checkpoint to %s"
,
output_dir
)
logger
.
info
(
"Saving model checkpoint to {}"
.
format
(
output_dir
))
logger
.
info
(
"[epoch {}/{}][iter: {}/{}] results: {}"
.
format
(
epoch_id
,
args
.
num_train_epochs
,
step
,
...
...
@@ -219,7 +239,7 @@ def train(args):
model
.
save_pretrained
(
output_dir
)
tokenizer
.
save_pretrained
(
output_dir
)
paddle
.
save
(
args
,
os
.
path
.
join
(
output_dir
,
"training_args.bin"
))
logger
.
info
(
"Saving model checkpoint to
%s"
,
output_dir
)
logger
.
info
(
"Saving model checkpoint to
{}"
.
format
(
output_dir
)
)
return
global_step
,
tr_loss
/
global_step
...
...
ppstructure/vqa/utils.py
浏览文件 @
f01dbb56
...
...
@@ -350,6 +350,8 @@ def parse_args():
# yapf: disable
parser
.
add_argument
(
"--model_name_or_path"
,
default
=
None
,
type
=
str
,
required
=
True
,)
parser
.
add_argument
(
"--ser_model_type"
,
default
=
'LayoutXLM'
,
type
=
str
)
parser
.
add_argument
(
"--re_model_name_or_path"
,
default
=
None
,
type
=
str
,
required
=
False
,)
parser
.
add_argument
(
"--train_data_dir"
,
default
=
None
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录