Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
ERNIE
提交
85cf2ee1
E
ERNIE
项目概览
PaddlePaddle
/
ERNIE
大约 1 年 前同步成功
通知
109
Star
5997
Fork
1270
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
29
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
E
ERNIE
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
29
Issue
29
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
85cf2ee1
编写于
7月 17, 2019
作者:
Y
Yibing Liu
提交者:
GitHub
7月 17, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #208 from tianxin1860/develop
add classify inference using infer_program
上级
39d4571d
2f2fe7af
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
83 addition
and
48 deletion
+83
-48
ERNIE/README.md
ERNIE/README.md
+1
-1
ERNIE/finetune/classifier.py
ERNIE/finetune/classifier.py
+1
-1
ERNIE/predict_classifier.py
ERNIE/predict_classifier.py
+39
-20
ERNIE/reader/task_reader.py
ERNIE/reader/task_reader.py
+42
-26
未找到文件。
ERNIE/README.md
浏览文件 @
85cf2ee1
...
...
@@ -316,4 +316,4 @@ python -u predict_classifier.py \
实际使用时,需要通过
`init_checkpoint`
指定预测用的模型,通过
`predict_set`
指定待预测的数据文件,通过
`num_labels`
配置分类的类别数目;
**Note**
: predict_set 的数据格式
与 dev_set 和 test_set 的数据格式完全一致,是由 text_a、text_b(可选) 、label 组成的2列/3列 tsv 文件,predict_set 中的 label 列起到占位符的作用,全部置 0 即可
;
**Note**
: predict_set 的数据格式
是由 text_a、text_b(可选) 组成的1列/2列 tsv 文件
;
ERNIE/finetune/classifier.py
浏览文件 @
85cf2ee1
...
...
@@ -65,7 +65,7 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False):
if
is_prediction
:
probs
=
fluid
.
layers
.
softmax
(
logits
)
feed_targets_name
=
[
src_ids
.
name
,
pos_ids
.
name
,
sent
_ids
.
name
,
input_mask
.
name
src_ids
.
name
,
sent_ids
.
name
,
pos
_ids
.
name
,
input_mask
.
name
]
return
pyreader
,
probs
,
feed_targets_name
...
...
ERNIE/predict_classifier.py
浏览文件 @
85cf2ee1
...
...
@@ -37,6 +37,7 @@ parser = argparse.ArgumentParser(__doc__)
model_g
=
ArgumentGroup
(
parser
,
"model"
,
"options to init, resume and save model."
)
model_g
.
add_arg
(
"ernie_config_path"
,
str
,
None
,
"Path to the json file for bert model config."
)
model_g
.
add_arg
(
"init_checkpoint"
,
str
,
None
,
"Init checkpoint to resume training from."
)
model_g
.
add_arg
(
"save_inference_model_path"
,
str
,
"inference_model"
,
"If set, save the inference model to this path."
)
model_g
.
add_arg
(
"use_fp16"
,
bool
,
False
,
"Whether to resume parameters from fp16 checkpoint."
)
model_g
.
add_arg
(
"num_labels"
,
int
,
2
,
"num labels for classify"
)
...
...
@@ -65,7 +66,8 @@ def main(args):
label_map_config
=
args
.
label_map_config
,
max_seq_len
=
args
.
max_seq_len
,
do_lower_case
=
args
.
do_lower_case
,
in_tokens
=
False
)
in_tokens
=
False
,
is_inference
=
True
)
predict_prog
=
fluid
.
Program
()
predict_startup
=
fluid
.
Program
()
...
...
@@ -95,7 +97,25 @@ def main(args):
else
:
raise
ValueError
(
"args 'init_checkpoint' should be set for prediction!"
)
predict_exe
=
fluid
.
Executor
(
place
)
assert
args
.
save_inference_model_path
,
"args save_inference_model_path should be set for prediction"
_
,
ckpt_dir
=
os
.
path
.
split
(
args
.
init_checkpoint
.
rstrip
(
'/'
))
dir_name
=
ckpt_dir
+
'_inference_model'
model_path
=
os
.
path
.
join
(
args
.
save_inference_model_path
,
dir_name
)
print
(
"save inference model to %s"
%
model_path
)
fluid
.
io
.
save_inference_model
(
model_path
,
feed_target_names
,
[
probs
],
exe
,
main_program
=
predict_prog
)
print
(
"load inference model from %s"
%
model_path
)
infer_program
,
feed_target_names
,
probs
=
fluid
.
io
.
load_inference_model
(
model_path
,
exe
)
src_ids
=
feed_target_names
[
0
]
sent_ids
=
feed_target_names
[
1
]
pos_ids
=
feed_target_names
[
2
]
input_mask
=
feed_target_names
[
3
]
predict_data_generator
=
reader
.
data_generator
(
input_file
=
args
.
predict_set
,
...
...
@@ -103,25 +123,24 @@ def main(args):
epoch
=
1
,
shuffle
=
False
)
predict_pyreader
.
decorate_tensor_provider
(
predict_data_generator
)
predict_pyreader
.
start
()
all_results
=
[]
time_begin
=
time
.
time
()
while
True
:
try
:
results
=
predict_exe
.
run
(
program
=
predict_prog
,
fetch_list
=
[
probs
.
name
])
all_results
.
extend
(
results
[
0
])
except
fluid
.
core
.
EOFException
:
predict_pyreader
.
reset
()
break
time_end
=
time
.
time
()
np
.
set_printoptions
(
precision
=
4
,
suppress
=
True
)
print
(
"-------------- prediction results --------------"
)
for
index
,
result
in
enumerate
(
all_results
):
print
(
str
(
index
)
+
'
\t
{}'
.
format
(
result
))
np
.
set_printoptions
(
precision
=
4
,
suppress
=
True
)
index
=
0
for
sample
in
predict_data_generator
():
src_ids_data
=
sample
[
0
]
sent_ids_data
=
sample
[
1
]
pos_ids_data
=
sample
[
2
]
input_mask_data
=
sample
[
3
]
output
=
exe
.
run
(
infer_program
,
feed
=
{
src_ids
:
src_ids_data
,
sent_ids
:
sent_ids_data
,
pos_ids
:
pos_ids_data
,
input_mask
:
input_mask_data
},
fetch_list
=
probs
)
for
single_result
in
output
[
0
]:
print
(
"example_index:{}
\t
{}"
.
format
(
index
,
single_result
))
index
+=
1
if
__name__
==
'__main__'
:
print_arguments
(
args
)
...
...
ERNIE/reader/task_reader.py
浏览文件 @
85cf2ee1
...
...
@@ -28,6 +28,7 @@ class BaseReader(object):
max_seq_len
=
512
,
do_lower_case
=
True
,
in_tokens
=
False
,
is_inference
=
False
,
random_seed
=
None
):
self
.
max_seq_len
=
max_seq_len
self
.
tokenizer
=
tokenization
.
FullTokenizer
(
...
...
@@ -37,6 +38,7 @@ class BaseReader(object):
self
.
cls_id
=
self
.
vocab
[
"[CLS]"
]
self
.
sep_id
=
self
.
vocab
[
"[SEP]"
]
self
.
in_tokens
=
in_tokens
self
.
is_inference
=
is_inference
np
.
random
.
seed
(
random_seed
)
...
...
@@ -141,25 +143,33 @@ class BaseReader(object):
token_ids
=
tokenizer
.
convert_tokens_to_ids
(
tokens
)
position_ids
=
list
(
range
(
len
(
token_ids
)))
if
self
.
label_map
:
label_id
=
self
.
label_map
[
example
.
label
]
if
self
.
is_inference
:
Record
=
namedtuple
(
'Record'
,
[
'token_ids'
,
'text_type_ids'
,
'position_ids'
])
record
=
Record
(
token_ids
=
token_ids
,
text_type_ids
=
text_type_ids
,
position_ids
=
position_ids
)
else
:
label_id
=
example
.
label
Record
=
namedtuple
(
'Record'
,
[
'token_ids'
,
'text_type_ids'
,
'position_ids'
,
'label_id'
,
'qid'
])
qid
=
None
if
"qid"
in
example
.
_fields
:
qid
=
example
.
qid
record
=
Record
(
token_ids
=
token_ids
,
text_type_ids
=
text_type_ids
,
position_ids
=
position_ids
,
label_id
=
label_id
,
qid
=
qid
)
if
self
.
label_map
:
label_id
=
self
.
label_map
[
example
.
label
]
else
:
label_id
=
example
.
label
Record
=
namedtuple
(
'Record'
,
[
'token_ids'
,
'text_type_ids'
,
'position_ids'
,
'label_id'
,
'qid'
])
qid
=
None
if
"qid"
in
example
.
_fields
:
qid
=
example
.
qid
record
=
Record
(
token_ids
=
token_ids
,
text_type_ids
=
text_type_ids
,
position_ids
=
position_ids
,
label_id
=
label_id
,
qid
=
qid
)
return
record
def
_prepare_batch_data
(
self
,
examples
,
batch_size
,
phase
=
None
):
...
...
@@ -235,14 +245,18 @@ class ClassifyReader(BaseReader):
batch_token_ids
=
[
record
.
token_ids
for
record
in
batch_records
]
batch_text_type_ids
=
[
record
.
text_type_ids
for
record
in
batch_records
]
batch_position_ids
=
[
record
.
position_ids
for
record
in
batch_records
]
batch_labels
=
[
record
.
label_id
for
record
in
batch_records
]
batch_labels
=
np
.
array
(
batch_labels
).
astype
(
"int64"
).
reshape
([
-
1
,
1
])
if
batch_records
[
0
].
qid
is
not
None
:
batch_qids
=
[
record
.
qid
for
record
in
batch_records
]
batch_qids
=
np
.
array
(
batch_qids
).
astype
(
"int64"
).
reshape
([
-
1
,
1
])
else
:
batch_qids
=
np
.
array
([]).
astype
(
"int64"
).
reshape
([
-
1
,
1
])
if
not
self
.
is_inference
:
batch_labels
=
[
record
.
label_id
for
record
in
batch_records
]
batch_labels
=
np
.
array
(
batch_labels
).
astype
(
"int64"
).
reshape
(
[
-
1
,
1
])
if
batch_records
[
0
].
qid
is
not
None
:
batch_qids
=
[
record
.
qid
for
record
in
batch_records
]
batch_qids
=
np
.
array
(
batch_qids
).
astype
(
"int64"
).
reshape
(
[
-
1
,
1
])
else
:
batch_qids
=
np
.
array
([]).
astype
(
"int64"
).
reshape
([
-
1
,
1
])
# padding
padded_token_ids
,
input_mask
=
pad_batch_data
(
...
...
@@ -254,8 +268,10 @@ class ClassifyReader(BaseReader):
return_list
=
[
padded_token_ids
,
padded_text_type_ids
,
padded_position_ids
,
input_mask
,
batch_labels
,
batch_qids
input_mask
]
if
not
self
.
is_inference
:
return_list
+=
[
batch_labels
,
batch_qids
]
return
return_list
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录