Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
ERNIE
提交
841e53a8
E
ERNIE
项目概览
PaddlePaddle
/
ERNIE
1 年多 前同步成功
通知
111
Star
5997
Fork
1271
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
29
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
E
ERNIE
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
29
Issue
29
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
841e53a8
编写于
7月 11, 2019
作者:
T
tianxin
提交者:
GitHub
7月 11, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #199 from tianxin1860/develop
add predict_classifier.py
上级
7babcff1
b27f8e0c
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
150 addition
and
0 deletion
+150
-0
ERNIE/README.md
ERNIE/README.md
+22
-0
ERNIE/predict_classifier.py
ERNIE/predict_classifier.py
+128
-0
未找到文件。
ERNIE/README.md
浏览文件 @
841e53a8
...
...
@@ -295,3 +295,25 @@ python -u ernir_encoder.py \
#### 如何获取输入句子中每个 token 经过 ERNIE 编码后的 Embedding 表示?
[
解决方案同上
](
#如何获取输入句子经过-ERNIE-编码后的-Embedding-表示?
)
#### 如何利用 finetune 得到的模型对新数据进行批量预测?
我们以分类任务为例,给出了分类任务进行批量预测的脚本, 使用示例如下:
```
python -u predict_classifier.py \
--use_cuda true \
--batch_size 32 \
--vocab_path config/vocab.txt \
--init_checkpoint "./checkpoints/step_100" \
--do_lower_case true \
--max_seq_len 128 \
--ernie_config_path config/ernie_config.json \
--do_predict true \
--predict_set ${TASK_DATA_PATH}/lcqmc/test.tsv \
--num_labels 2
```
实际使用时,需要通过
`init_checkpoint`
指定预测用的模型,通过
`predict_set`
指定待预测的数据文件,通过
`num_labels`
配置分类的类别数目;
**Note**
: predict_set 的数据格式与 dev_set 和 test_set 的数据格式完全一致,是由 text_a、text_b(可选) 、label 组成的2列/3列 tsv 文件,predict_set 中的 label 列起到占位符的作用,全部置 0 即可;
ERNIE/predict_classifier.py
0 → 100644
浏览文件 @
841e53a8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Load classifier's checkpoint to do prediction or save inference model."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
time
import
argparse
import
numpy
as
np
import
multiprocessing
import
paddle.fluid
as
fluid
from
reader.task_reader
import
ClassifyReader
from
model.ernie
import
ErnieConfig
from
finetune.classifier
import
create_model
from
utils.args
import
ArgumentGroup
,
print_arguments
from
utils.init
import
init_pretraining_params
from
finetune_args
import
parser
# yapf: disable
parser
=
argparse
.
ArgumentParser
(
__doc__
)
model_g
=
ArgumentGroup
(
parser
,
"model"
,
"options to init, resume and save model."
)
model_g
.
add_arg
(
"ernie_config_path"
,
str
,
None
,
"Path to the json file for bert model config."
)
model_g
.
add_arg
(
"init_checkpoint"
,
str
,
None
,
"Init checkpoint to resume training from."
)
model_g
.
add_arg
(
"use_fp16"
,
bool
,
False
,
"Whether to resume parameters from fp16 checkpoint."
)
model_g
.
add_arg
(
"num_labels"
,
int
,
2
,
"num labels for classify"
)
data_g
=
ArgumentGroup
(
parser
,
"data"
,
"Data paths, vocab paths and data processing options."
)
data_g
.
add_arg
(
"predict_set"
,
str
,
None
,
"Predict set file"
)
data_g
.
add_arg
(
"vocab_path"
,
str
,
None
,
"Vocabulary path."
)
data_g
.
add_arg
(
"label_map_config"
,
str
,
None
,
"Label_map_config json file."
)
data_g
.
add_arg
(
"max_seq_len"
,
int
,
128
,
"Number of words of the longest seqence."
)
data_g
.
add_arg
(
"batch_size"
,
int
,
32
,
"Total examples' number in batch for training. see also --in_tokens."
)
data_g
.
add_arg
(
"do_lower_case"
,
bool
,
True
,
"Whether to lower case the input text. Should be True for uncased models and False for cased models."
)
run_type_g
=
ArgumentGroup
(
parser
,
"run_type"
,
"running type options."
)
run_type_g
.
add_arg
(
"use_cuda"
,
bool
,
True
,
"If set, use GPU for training."
)
run_type_g
.
add_arg
(
"do_prediction"
,
bool
,
True
,
"Whether to do prediction on test set."
)
args
=
parser
.
parse_args
()
# yapf: enable.
def
main
(
args
):
ernie_config
=
ErnieConfig
(
args
.
ernie_config_path
)
ernie_config
.
print_config
()
reader
=
ClassifyReader
(
vocab_path
=
args
.
vocab_path
,
label_map_config
=
args
.
label_map_config
,
max_seq_len
=
args
.
max_seq_len
,
do_lower_case
=
args
.
do_lower_case
,
in_tokens
=
False
)
predict_prog
=
fluid
.
Program
()
predict_startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
predict_prog
,
predict_startup
):
with
fluid
.
unique_name
.
guard
():
predict_pyreader
,
probs
,
feed_target_names
=
create_model
(
args
,
pyreader_name
=
'predict_reader'
,
ernie_config
=
ernie_config
,
is_prediction
=
True
)
predict_prog
=
predict_prog
.
clone
(
for_test
=
True
)
if
args
.
use_cuda
:
place
=
fluid
.
CUDAPlace
(
0
)
dev_count
=
fluid
.
core
.
get_cuda_device_count
()
else
:
place
=
fluid
.
CPUPlace
()
dev_count
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
()))
place
=
fluid
.
CUDAPlace
(
0
)
if
args
.
use_cuda
==
True
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
predict_startup
)
if
args
.
init_checkpoint
:
init_pretraining_params
(
exe
,
args
.
init_checkpoint
,
predict_prog
)
else
:
raise
ValueError
(
"args 'init_checkpoint' should be set for prediction!"
)
predict_exe
=
fluid
.
Executor
(
place
)
predict_data_generator
=
reader
.
data_generator
(
input_file
=
args
.
predict_set
,
batch_size
=
args
.
batch_size
,
epoch
=
1
,
shuffle
=
False
)
predict_pyreader
.
decorate_tensor_provider
(
predict_data_generator
)
predict_pyreader
.
start
()
all_results
=
[]
time_begin
=
time
.
time
()
while
True
:
try
:
results
=
predict_exe
.
run
(
program
=
predict_prog
,
fetch_list
=
[
probs
.
name
])
all_results
.
extend
(
results
[
0
])
except
fluid
.
core
.
EOFException
:
predict_pyreader
.
reset
()
break
time_end
=
time
.
time
()
np
.
set_printoptions
(
precision
=
4
,
suppress
=
True
)
print
(
"-------------- prediction results --------------"
)
for
index
,
result
in
enumerate
(
all_results
):
print
(
str
(
index
)
+
'
\t
{}'
.
format
(
result
))
if
__name__
==
'__main__'
:
print_arguments
(
args
)
main
(
args
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录