PaddlePaddle / PaddleSlim
Commit 53a34ed2 (unverified)
Authored on Dec 05, 2022 by ceci3; committed via GitHub on Dec 05, 2022

add uie (#1520)

* add uie
* update

Parent: 94c0e670

Showing 4 changed files with 816 additions and 0 deletions (+816 −0)
- example/auto_compression/nlp/README.md (+15 −0)
- example/auto_compression/nlp/configs/uie/uie_base.yaml (+22 −0)
- example/auto_compression/nlp/paddle_inference_eval_uie.py (+401 −0)
- example/auto_compression/nlp/run_uie.py (+378 −0)
example/auto_compression/nlp/README.md

@@ -30,6 +30,12 @@

| ERNIE 3.0-Medium | Base model | 75.35 | 57.45 | 60.17 | 81.16 | 77.19 | 80.59 | 79.70 | 73.09 |
| ERNIE 3.0-Medium | Pruning + quantization-aware training | 74.17 | 56.84 | 59.75 | 80.54 | 76.03 | 76.97 | 80.80 | 72.16 |

| Model | Strategy | Reimbursement-ticket dataset |
|:------:|:------:|:------:|
| UIE-base | Base model | [91.83](https://bj.bcebos.com/v1/paddle-slim-models/act/uie_base.tar) |
| UIE-base | Quantization-aware training | [95.80](https://bj.bcebos.com/v1/paddle-slim-models/act/uie_base_qat_model.tar) |

Note: the UIE accuracy is obtained by fine-tuning on a 5-shot dataset (5 labeled samples per class); the higher accuracy after compression is likely due to overfitting on this dataset.

The average accuracy across tasks and the speedup of each model are compared below:

| Model | Strategy | Accuracy (avg) | Latency<sup><small>FP32</small></sup> | Latency<sup><small>FP16</small></sup> | Latency<sup><small>INT8</small></sup> | Speedup |
|:-------:|:--------:|:----------:|:------------:|:------:|:------:|:------:|

@@ -37,6 +43,8 @@

| PP-MiniLM | Pruning + post-training quantization | 71.85 | - | - | 15.76ms | 5.99x |
| ERNIE 3.0-Medium | Base model | 73.09 | 89.71ms | 20.76ms | - | - |
| ERNIE 3.0-Medium | Pruning + quantization-aware training | 72.16 | - | - | 14.08ms | 6.37x |
| UIE-base | Base model | 91.83 | 42.66ms | 14.23ms | - | - |
| UIE-base | Quantization-aware training | 95.80 | - | - | 10.94ms | 3.90x |

Benchmark environment:
- Hardware: a single NVIDIA Tesla T4 GPU

@@ -86,6 +94,7 @@ pip install paddlenlp

|:------:|:------:|:------:|:------:|:------:|:-----------:|:------:|:------:|
| PP-MiniLM | [afqmc](https://bj.bcebos.com/v1/paddle-slim-models/act/afqmc.tar) | [tnews](https://bj.bcebos.com/v1/paddle-slim-models/act/tnews.tar) | [iflytek](https://bj.bcebos.com/v1/paddle-slim-models/act/iflytek.tar) | [cmnli](https://bj.bcebos.com/v1/paddle-slim-models/act/cmnli.tar) | [ocnli](https://bj.bcebos.com/v1/paddle-slim-models/act/ocnli.tar) | [cluewsc2020](https://bj.bcebos.com/v1/paddle-slim-models/act/cluewsc.tar) | [csl](https://bj.bcebos.com/v1/paddle-slim-models/act/csl.tar) |
| ERNIE 3.0-Medium | [afqmc](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/AFQMC.tar) | [tnews](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/TNEWS.tar) | [iflytek](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/IFLYTEK.tar) | [cmnli](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/CMNLI.tar) | [ocnli](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/OCNLI.tar) | [cluewsc2020](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/CLUEWSC2020.tar) | [csl](https://bj.bcebos.com/v1/paddle-slim-models/act/NLP/ernie3.0-medium/fp32_models/CSL.tar) |
| UIE-base | [reimbursement ticket](https://bj.bcebos.com/v1/paddle-slim-models/act/uie_base.tar) |

Take the model link from the table above and download the inference model files with the following commands:

@@ -109,6 +118,12 @@ export CUDA_VISIBLE_DEVICES=0

python run.py --config_path='./configs/pp-minilm/auto/afqmc.yaml' --save_dir='./save_afqmc_pruned/'
```

Auto-compression of the UIE family of models is launched with the run_uie.py script, which calls the `paddleslim.auto_compression.AutoCompression` interface to compress the model automatically. Set the training parameters in the config file, passing in the task name, model type, dataset name and compression parameters; once configured, the model can be trained with distillation and quantization:

```shell
export CUDA_VISIBLE_DEVICES=0
python run_uie.py --config_path='./configs/uie/uie_base.yaml' --save_dir='./save_uie_qat/'
```

To only validate model accuracy (of the original or of the compressed model), launch the `run.py` script with the model directory `model_dir` in the config file changed to the directory saved after compression, e.g. `./save_afqmc_pruned`, and add `--eval True` to the command:

```shell
export CUDA_VISIBLE_DEVICES=0
...
```
example/auto_compression/nlp/configs/uie/uie_base.yaml (new file, mode 100644)

```yaml
Global:
  model_dir: ./UIE
  model_filename: inference.pdmodel
  params_filename: inference.pdiparams
  batch_size: 1
  max_seq_length: 512
  train_data: ./data/train.txt
  dev_data: ./data/dev.txt

TrainConfig:
  epochs: 200
  eval_iter: 100
  learning_rate: 1.0e-5
  optimizer_builder:
    optimizer:
      type: AdamW
    weight_decay: 0.01

QuantAware:
  onnx_format: True

Distillation:
  alpha: 1.0
  loss: l2
```
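For orientation, here is a sketch of how the files referenced by the `Global` section above could be laid out before launching run_uie.py. The download URL is the UIE-base link from the README table; the name of the directory inside the tar archive and the 5-shot `train.txt`/`dev.txt` files are assumptions, not part of this commit.

```shell
# Sketch only: fetch the UIE-base inference model and arrange it to match
# Global.model_dir / model_filename / params_filename in uie_base.yaml.
wget https://bj.bcebos.com/v1/paddle-slim-models/act/uie_base.tar
tar -xf uie_base.tar
mv uie_base ./UIE        # assumed name of the extracted directory
mkdir -p data            # expects ./data/train.txt and ./data/dev.txt (5-shot labeled data)
```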
example/auto_compression/nlp/paddle_inference_eval_uie.py (new file, mode 100644)

```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
import json
import sys
from functools import partial
import distutils.util

import numpy as np
import paddle
from paddle import inference
from paddle.metric import Metric, Accuracy, Precision, Recall
from paddlenlp.transformers import AutoModelForTokenClassification, AutoTokenizer
from paddlenlp.datasets import load_dataset
from paddlenlp.data import Stack, Tuple, Pad, Dict
from paddlenlp.metrics import SpanEvaluator


def parse_args():
    """
    parse_args func
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        default="./afqmc",
        type=str,
        required=True,
        help="The path prefix of inference model to be used.", )
    parser.add_argument(
        "--model_filename",
        type=str,
        default="inference.pdmodel",
        help="model file name")
    parser.add_argument(
        "--params_filename",
        type=str,
        default="inference.pdiparams",
        help="params file name")
    parser.add_argument(
        "--dev_data",
        default="./data/dev.txt",
        type=str,
        help="The data file of validation.", )
    parser.add_argument(
        "--device",
        default="gpu",
        choices=["gpu", "cpu"],
        help="Device selected for inference.", )
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size for predict.", )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.", )
    parser.add_argument(
        "--perf_warmup_steps",
        default=20,
        type=int,
        help="Warmup steps for performance test.", )
    parser.add_argument(
        "--use_trt",
        action="store_true",
        help="Whether to use inference engin TensorRT.", )
    parser.add_argument(
        "--precision",
        type=str,
        default="fp32",
        choices=["fp32", "fp16", "int8"],
        help="The precision of inference. It can be 'fp32', 'fp16' or 'int8'. Default is 'fp16'.",
    )
    parser.add_argument(
        "--use_mkldnn",
        type=bool,
        default=False,
        help="Whether use mkldnn or not.")
    parser.add_argument(
        "--cpu_threads", type=int, default=1, help="Num of cpu threads.")
    args = parser.parse_args()
    return args


def map_offset(ori_offset, offset_mapping):
    """
    map ori offset to token offset
    """
    for index, span in enumerate(offset_mapping):
        if span[0] <= ori_offset < span[1]:
            return index
    return -1


def _convert_example(example, tokenizer, max_seq_length=128):
    encoded_inputs = tokenizer(
        text=[example["prompt"]],
        text_pair=[example["content"]],
        truncation=True,
        max_seq_len=max_seq_length,
        pad_to_max_seq_len=True,
        return_attention_mask=True,
        return_position_ids=True,
        return_dict=False,
        return_offsets_mapping=True)
    encoded_inputs = encoded_inputs[0]
    offset_mapping = [list(x) for x in encoded_inputs["offset_mapping"]]
    bias = 0
    for index in range(1, len(offset_mapping)):
        mapping = offset_mapping[index]
        if mapping[0] == 0 and mapping[1] == 0 and bias == 0:
            bias = offset_mapping[index - 1][1] + 1  # Includes [SEP] token
        if mapping[0] == 0 and mapping[1] == 0:
            continue
        offset_mapping[index][0] += bias
        offset_mapping[index][1] += bias
    start_ids = [0.0 for x in range(max_seq_length)]
    end_ids = [0.0 for x in range(max_seq_length)]
    for item in example["result_list"]:
        start = map_offset(item["start"] + bias, offset_mapping)
        end = map_offset(item["end"] - 1 + bias, offset_mapping)
        start_ids[start] = 1.0
        end_ids[end] = 1.0
    tokenized_output = {
        "input_ids": encoded_inputs["input_ids"],
        "token_type_ids": encoded_inputs["token_type_ids"],
        "start_ids": start_ids,
        "end_ids": end_ids
    }
    return tokenized_output


class Predictor(object):
    """
    Inference Predictor class
    """

    def __init__(self, predictor, input_handles, output_handles):
        self.predictor = predictor
        self.input_handles = input_handles
        self.output_handles = output_handles

    @classmethod
    def create_predictor(cls, args):
        """
        create_predictor func
        """
        cls.rerun_flag = False
        config = paddle.inference.Config(
            os.path.join(args.model_path, args.model_filename),
            os.path.join(args.model_path, args.params_filename))
        if args.device == "gpu":
            # set GPU configs accordingly
            config.enable_use_gpu(100, 0)
            cls.device = paddle.set_device("gpu")
        else:
            config.disable_gpu()
            config.set_cpu_math_library_num_threads(args.cpu_threads)
            config.switch_ir_optim()
            if args.use_mkldnn:
                config.enable_mkldnn()
                if args.precision == "int8":
                    config.enable_mkldnn_int8()

        precision_map = {
            "int8": inference.PrecisionType.Int8,
            "fp32": inference.PrecisionType.Float32,
            "fp16": inference.PrecisionType.Half,
        }
        if args.precision in precision_map.keys() and args.use_trt:
            config.enable_tensorrt_engine(
                workspace_size=1 << 30,
                max_batch_size=args.batch_size,
                min_subgraph_size=5,
                precision_mode=precision_map[args.precision],
                use_static=True,
                use_calib_mode=False, )

            dynamic_shape_file = os.path.join(args.model_path,
                                              "dynamic_shape.txt")
            if os.path.exists(dynamic_shape_file):
                config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file,
                                                           True)
                print("trt set dynamic shape done!")
            else:
                config.collect_shape_range_info(dynamic_shape_file)
                print("Start collect dynamic shape...")
                cls.rerun_flag = True

        predictor = paddle.inference.create_predictor(config)
        input_handles = [
            predictor.get_input_handle(name)
            for name in predictor.get_input_names()
        ]
        output_handles = [
            predictor.get_output_handle(name)
            for name in predictor.get_output_names()
        ]
        return cls(predictor, input_handles, output_handles)

    def predict_batch(self, data):
        """
        predict from batch func
        """
        for input_field, input_handle in zip(data, self.input_handles):
            input_handle.copy_from_cpu(input_field)
        self.predictor.run()
        output = [
            output_handle.copy_to_cpu()
            for output_handle in self.output_handles
        ]
        return output

    def _convert_predict_batch(self, args, data, tokenizer, batchify_fn):
        examples = []
        for example in data:
            example = _convert_example(
                example, tokenizer, max_seq_length=args.max_seq_length)
            examples.append(example)
        return examples

    def predict(self, dataset, tokenizer, batchify_fn, args):
        """
        predict func
        """
        batches = [
            dataset[idx:idx + args.batch_size]
            for idx in range(0, len(dataset), args.batch_size)
        ]
        for i, batch in enumerate(batches):
            examples = self._convert_predict_batch(args, batch, tokenizer,
                                                   batchify_fn)
            input_ids, segment_ids, start_ids, end_ids = batchify_fn(examples)
            output = self.predict_batch([input_ids, segment_ids])
            if i > args.perf_warmup_steps:
                break
        if self.rerun_flag:
            return

        metric = SpanEvaluator()
        metric.reset()
        predict_time = 0.0
        for i, batch in enumerate(batches):
            examples = self._convert_predict_batch(args, batch, tokenizer,
                                                   batchify_fn)
            input_ids, segment_ids, start_ids, end_ids = batchify_fn(examples)
            start_time = time.time()
            output = self.predict_batch([input_ids, segment_ids])
            end_time = time.time()
            predict_time += end_time - start_time
            start_ids = paddle.to_tensor(np.array(start_ids))
            end_ids = paddle.to_tensor(np.array(end_ids))
            start_prob = paddle.to_tensor(output[0])
            end_prob = paddle.to_tensor(output[1])
            num_correct, num_infer, num_label = metric.compute(
                start_prob, end_prob, start_ids, end_ids)
            metric.update(num_correct, num_infer, num_label)

        sequences_num = i * args.batch_size
        print(
            "[benchmark]batch size: {} Inference time per batch: {}ms, qps: {}.".
            format(args.batch_size,
                   round(predict_time * 1000 / i, 2),
                   round(sequences_num / predict_time, 2), ))
        precision, recall, f1 = metric.accumulate()
        print("[benchmark]f1: %s. \n" % (f1), end="")
        sys.stdout.flush()


def reader_proprecess(data_path, max_seq_len=128):
    """
    read json
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            json_line = json.loads(line)
            content = json_line['content'].strip()
            prompt = json_line['prompt']
            # Model Input is aslike: [CLS] Prompt [SEP] Content [SEP]
            # It include three summary tokens.
            if max_seq_len <= len(prompt) + 3:
                raise ValueError(
                    "The value of max_seq_len is too small, please set a larger value"
                )
            max_content_len = max_seq_len - len(prompt) - 3
            if len(content) <= max_content_len:
                yield json_line
            else:
                result_list = json_line['result_list']
                json_lines = []
                accumulate = 0
                while True:
                    cur_result_list = []
                    for result in result_list:
                        if result['start'] + 1 <= max_content_len < result[
                                'end']:
                            max_content_len = result['start']
                            break

                    cur_content = content[:max_content_len]
                    res_content = content[max_content_len:]

                    while True:
                        if len(result_list) == 0:
                            break
                        elif result_list[0]['end'] <= max_content_len:
                            if result_list[0]['end'] > 0:
                                cur_result = result_list.pop(0)
                                cur_result_list.append(cur_result)
                            else:
                                cur_result_list = [
                                    result for result in result_list
                                ]
                                break
                        else:
                            break

                    json_line = {
                        'content': cur_content,
                        'result_list': cur_result_list,
                        'prompt': prompt
                    }
                    json_lines.append(json_line)

                    for result in result_list:
                        if result['end'] <= 0:
                            break
                        result['start'] -= max_content_len
                        result['end'] -= max_content_len
                    accumulate += max_content_len
                    max_content_len = max_seq_len - len(prompt) - 3
                    if len(res_content) == 0:
                        break
                    elif len(res_content) < max_content_len:
                        json_line = {
                            'content': res_content,
                            'result_list': result_list,
                            'prompt': prompt
                        }
                        json_lines.append(json_line)
                        break
                    else:
                        content = res_content

                for json_line in json_lines:
                    yield json_line


def main():
    """
    main func
    """
    paddle.seed(42)
    args = parse_args()
    if args.use_mkldnn:
        paddle.set_device("cpu")

    predictor = Predictor.create_predictor(args)
    dev_ds = load_dataset(
        reader_proprecess, data_path=args.dev_data, lazy=False)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    batchify_fn = lambda samples, fn=Dict({
        'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
        'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
        'start_ids': Stack(dtype="int64"),
        'end_ids': Stack(dtype="int64")
    }): fn(samples)

    predictor.predict(dev_ds, tokenizer, batchify_fn, args)
    if predictor.rerun_flag:
        print(
            "***** Collect dynamic shape done, Please rerun the program to get correct results. *****"
        )


if __name__ == "__main__":
    paddle.set_device("cpu")
    main()
```
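The README in this commit does not show an invocation of this benchmark script, so the following is only an illustrative sketch; the flags come from the `parse_args` definitions above, while the paths `./save_uie_qat` and `./data/dev.txt` are taken from the README example and the YAML config respectively.

```shell
# Illustrative: benchmark the quantized UIE model with TensorRT INT8.
# With --use_trt, the first run only collects dynamic shapes
# (dynamic_shape.txt) and asks you to rerun for correct results.
python paddle_inference_eval_uie.py \
    --model_path=./save_uie_qat \
    --model_filename=inference.pdmodel \
    --params_filename=inference.pdiparams \
    --dev_data=./data/dev.txt \
    --max_seq_length=512 \
    --device=gpu \
    --use_trt \
    --precision=int8
```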
example/auto_compression/nlp/run_uie.py (new file, mode 100644)

```python
import os
import sys
import argparse
import json
import functools
from functools import partial

import numpy as np
import shutil
import paddle
import paddle.nn as nn
from paddle.io import Dataset, BatchSampler, DataLoader
from paddlenlp.transformers import AutoModelForTokenClassification, AutoTokenizer
from paddlenlp.datasets import load_dataset
from paddlenlp.data import Stack, Tuple, Pad, Dict
from paddlenlp.data.sampler import SamplerHelper
from paddlenlp.metrics import SpanEvaluator
from paddleslim.common import load_config
from paddleslim.auto_compression.compressor import AutoCompression


def argsparser():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--config_path',
        type=str,
        default=None,
        help="path of compression strategy config.",
        required=True)
    parser.add_argument(
        '--save_dir',
        type=str,
        default='output',
        help="directory to save compressed model.")
    parser.add_argument(
        '--eval',
        type=bool,
        default=False,
        help="whether validate the model only.")
    return parser


def map_offset(ori_offset, offset_mapping):
    """
    map ori offset to token offset
    """
    for index, span in enumerate(offset_mapping):
        if span[0] <= ori_offset < span[1]:
            return index
    return -1


def convert_example(example,
                    tokenizer,
                    max_seq_len,
                    multilingual=True,
                    is_test=False):
    """
    example: {
        title
        prompt
        content
        result_list
    }
    """
    encoded_inputs = tokenizer(
        text=[example["prompt"]],
        text_pair=[example["content"]],
        truncation=True,
        max_seq_len=max_seq_len,
        pad_to_max_seq_len=True,
        return_attention_mask=True,
        return_position_ids=True,
        return_dict=False,
        return_offsets_mapping=True)
    encoded_inputs = encoded_inputs[0]
    offset_mapping = [list(x) for x in encoded_inputs["offset_mapping"]]
    bias = 0
    for index in range(1, len(offset_mapping)):
        mapping = offset_mapping[index]
        if mapping[0] == 0 and mapping[1] == 0 and bias == 0:
            bias = offset_mapping[index - 1][1] + 1  # Includes [SEP] token
        if mapping[0] == 0 and mapping[1] == 0:
            continue
        offset_mapping[index][0] += bias
        offset_mapping[index][1] += bias
    start_ids = [0.0 for x in range(max_seq_len)]
    end_ids = [0.0 for x in range(max_seq_len)]
    for item in example["result_list"]:
        start = map_offset(item["start"] + bias, offset_mapping)
        end = map_offset(item["end"] - 1 + bias, offset_mapping)
        start_ids[start] = 1.0
        end_ids[end] = 1.0
    if multilingual:
        if not is_test:
            tokenized_output = {
                "input_ids": encoded_inputs["input_ids"],
                "token_type_ids": encoded_inputs["token_type_ids"],
                "start_ids": start_ids,
                "end_ids": end_ids
            }
        else:
            tokenized_output = {
                "input_ids": encoded_inputs["input_ids"],
                "token_type_ids": encoded_inputs["token_type_ids"],
            }
    else:
        if not is_test:
            tokenized_output = {
                "input_ids": encoded_inputs["input_ids"],
                "token_type_ids": encoded_inputs["token_type_ids"],
                "pos_ids": encoded_inputs["position_ids"],
                "att_mask": encoded_inputs["attention_mask"],
                "start_ids": start_ids,
                "end_ids": end_ids
            }
        else:
            tokenized_output = {
                "input_ids": encoded_inputs["input_ids"],
                "token_type_ids": encoded_inputs["token_type_ids"],
                "pos_ids": encoded_inputs["position_ids"],
                "att_mask": encoded_inputs["attention_mask"],
            }
    return tokenized_output


def create_data_holder(multilingual=True):
    """
    Define the input data holder for the glue task.
    """
    return_list = []
    input_ids = paddle.static.data(
        name="input_ids", shape=[-1, -1], dtype="int64")
    return_list = [input_ids]
    token_type_ids = paddle.static.data(
        name="token_type_ids", shape=[-1, -1], dtype="int64")
    return_list.append(token_type_ids)
    if not multilingual:
        position_ids = paddle.static.data(
            name="pos_ids", shape=[-1, -1], dtype="int64")
        attention_mask = paddle.static.data(
            name="att_mask", shape=[-1, -1], dtype="int64")
        return_list.append(position_ids)
        return_list.append(attention_mask)
    start_ids = paddle.static.data(
        name="start_ids", shape=[-1, 1], dtype="float32")
    end_ids = paddle.static.data(
        name="end_ids", shape=[-1, 1], dtype="float32")
    return_list.append(start_ids)
    return_list.append(end_ids)
    return return_list


def reader_proprecess(data_path, max_seq_len=512):
    """
    read json
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            json_line = json.loads(line)
            content = json_line['content'].strip()
            prompt = json_line['prompt']
            # Model Input is aslike: [CLS] Prompt [SEP] Content [SEP]
            # It include three summary tokens.
            if max_seq_len <= len(prompt) + 3:
                raise ValueError(
                    "The value of max_seq_len is too small, please set a larger value"
                )
            max_content_len = max_seq_len - len(prompt) - 3
            if len(content) <= max_content_len:
                yield json_line
            else:
                result_list = json_line['result_list']
                json_lines = []
                accumulate = 0
                while True:
                    cur_result_list = []
                    for result in result_list:
                        if result['start'] + 1 <= max_content_len < result[
                                'end']:
                            max_content_len = result['start']
                            break

                    cur_content = content[:max_content_len]
                    res_content = content[max_content_len:]

                    while True:
                        if len(result_list) == 0:
                            break
                        elif result_list[0]['end'] <= max_content_len:
                            if result_list[0]['end'] > 0:
                                cur_result = result_list.pop(0)
                                cur_result_list.append(cur_result)
                            else:
                                cur_result_list = [
                                    result for result in result_list
                                ]
                                break
                        else:
                            break

                    json_line = {
                        'content': cur_content,
                        'result_list': cur_result_list,
                        'prompt': prompt
                    }
                    json_lines.append(json_line)

                    for result in result_list:
                        if result['end'] <= 0:
                            break
                        result['start'] -= max_content_len
                        result['end'] -= max_content_len
                    accumulate += max_content_len
                    max_content_len = max_seq_len - len(prompt) - 3
                    if len(res_content) == 0:
                        break
                    elif len(res_content) < max_content_len:
                        json_line = {
                            'content': res_content,
                            'result_list': result_list,
                            'prompt': prompt
                        }
                        json_lines.append(json_line)
                        break
                    else:
                        content = res_content

                for json_line in json_lines:
                    yield json_line


def reader():
    train_ds = load_dataset(
        reader_proprecess,
        data_path=global_config['train_data'],
        max_seq_len=global_config['max_seq_length'],
        lazy=False)
    dev_ds = load_dataset(
        reader_proprecess,
        data_path=global_config['dev_data'],
        max_seq_len=global_config['max_seq_length'],
        lazy=False)

    tokenizer = AutoTokenizer.from_pretrained(global_config['model_dir'])

    trans_fn = partial(
        convert_example,
        tokenizer=tokenizer,
        max_seq_len=global_config['max_seq_length'],
        is_test=True)
    train_ds = train_ds.map(trans_fn)

    dev_trans_fn = partial(
        convert_example,
        tokenizer=tokenizer,
        max_seq_len=global_config['max_seq_length'],
        is_test=False)
    dev_ds = dev_ds.map(dev_trans_fn)

    batchify_fn = lambda samples, fn=Dict({
        'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
        'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
    }): fn(samples)
    dev_batchify_fn = lambda samples, fn=Dict({
        'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
        'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
        'start_ids': Stack(dtype="int64"),
        'end_ids': Stack(dtype="int64")
    }): fn(samples)

    [input_ids, token_type_ids, start_ids, end_ids] = create_data_holder()

    train_batch_sampler = paddle.io.BatchSampler(
        dataset=train_ds, batch_size=global_config['batch_size'], shuffle=True)
    train_dataloader = paddle.io.DataLoader(
        train_ds,
        batch_sampler=train_batch_sampler,
        return_list=False,
        feed_list=[input_ids, token_type_ids],
        collate_fn=batchify_fn)

    dev_batch_sampler = paddle.io.BatchSampler(
        dataset=dev_ds, batch_size=global_config['batch_size'], shuffle=False)
    eval_dataloader = paddle.io.DataLoader(
        dev_ds,
        batch_sampler=dev_batch_sampler,
        return_list=False,
        feed_list=[input_ids, token_type_ids, start_ids, end_ids],
        collate_fn=dev_batchify_fn)

    return train_dataloader, eval_dataloader


def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
    metric.reset()
    for data in eval_dataloader():
        logits = exe.run(compiled_test_program,
                         feed={
                             'input_ids': data[0]['input_ids'],
                             'token_type_ids': data[0]['token_type_ids'],
                         },
                         fetch_list=test_fetch_list)
        paddle.disable_static()
        start_ids = paddle.to_tensor(np.array(data[0]['start_ids']))
        end_ids = paddle.to_tensor(np.array(data[0]['end_ids']))
        start_prob = paddle.to_tensor(logits[0])
        end_prob = paddle.to_tensor(logits[1])
        num_correct, num_infer, num_label = metric.compute(
            start_prob, end_prob, start_ids, end_ids)
        metric.update(num_correct, num_infer, num_label)
        paddle.enable_static()
    precision, recall, f1 = metric.accumulate()
    return f1


def apply_decay_param_fun(name):
    if name.find("bias") > -1:
        return True
    elif name.find("b_0") > -1:
        return True
    elif name.find("norm") > -1:
        return True
    else:
        return False


def main():
    all_config = load_config(args.config_path)

    global global_config
    assert "Global" in all_config, "Key Global not found in config file."
    global_config = all_config["Global"]

    if 'TrainConfig' in all_config:
        all_config['TrainConfig']['optimizer_builder'][
            'apply_decay_param_fun'] = apply_decay_param_fun

    global train_dataloader, eval_dataloader
    train_dataloader, eval_dataloader = reader()

    global metric
    metric = SpanEvaluator()

    ac = AutoCompression(
        model_dir=global_config['model_dir'],
        model_filename=global_config['model_filename'],
        params_filename=global_config['params_filename'],
        save_dir=args.save_dir,
        config=all_config,
        train_dataloader=train_dataloader,
        eval_callback=eval_function,
        eval_dataloader=eval_dataloader)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    for file_name in os.listdir(global_config['model_dir']):
        if 'json' in file_name or 'txt' in file_name:
            shutil.copy(
                os.path.join(global_config['model_dir'], file_name),
                args.save_dir)

    ac.compress()


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    args = parser.parse_args()
    main()
```
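Taken together, the two new scripts support a compress-then-benchmark flow. The sketch below simply chains the run_uie.py command documented in the README with the evaluation script above; whether the FP16/INT8 numbers in the README table were produced in exactly this way is an assumption.

```shell
# 1) Quantization-aware compression of UIE-base (command from the README).
export CUDA_VISIBLE_DEVICES=0
python run_uie.py --config_path='./configs/uie/uie_base.yaml' --save_dir='./save_uie_qat/'

# 2) Compare the original and compressed models (sketch; flags as defined
#    in paddle_inference_eval_uie.py).
python paddle_inference_eval_uie.py --model_path=./UIE --dev_data=./data/dev.txt \
    --max_seq_length=512 --device=gpu --use_trt --precision=fp16
python paddle_inference_eval_uie.py --model_path=./save_uie_qat --dev_data=./data/dev.txt \
    --max_seq_length=512 --device=gpu --use_trt --precision=int8
```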