Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
ERNIE
提交
2d278cdc
E
ERNIE
项目概览
PaddlePaddle
/
ERNIE
大约 1 年 前同步成功
通知
109
Star
5997
Fork
1270
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
29
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
E
ERNIE
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
29
Issue
29
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2d278cdc
编写于
9月 29, 2019
作者:
C
chenxuyi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
bugfix: predict_hangs
+ sanity check in `in tokens` mode
上级
abd48478
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
16 addition
and
5 deletion
+16
-5
finetune_args.py
finetune_args.py
+1
-1
run_classifier.py
run_classifier.py
+6
-2
run_mrc.py
run_mrc.py
+2
-0
run_sequence_labeling.py
run_sequence_labeling.py
+7
-2
未找到文件。
finetune_args.py
浏览文件 @
2d278cdc
...
...
@@ -78,7 +78,7 @@ data_g.add_arg("dev_set", str, None, "Path to validation data.")
data_g
.
add_arg
(
"vocab_path"
,
str
,
None
,
"Vocabulary path."
)
data_g
.
add_arg
(
"max_seq_len"
,
int
,
512
,
"Number of words of the longest seqence."
)
data_g
.
add_arg
(
"batch_size"
,
int
,
32
,
"Total examples' number in batch for training. see also --in_tokens."
)
data_g
.
add_arg
(
"predict_batch_size"
,
int
,
8
,
"Total examples' number in batch for predict. see also --in_tokens."
)
data_g
.
add_arg
(
"predict_batch_size"
,
int
,
None
,
"Total examples' number in batch for predict. see also --in_tokens."
)
data_g
.
add_arg
(
"in_tokens"
,
bool
,
False
,
"If set, the batch size will be the maximum number of tokens in one batch. "
"Otherwise, it will be the maximum number of examples in one batch."
)
...
...
run_classifier.py
浏览文件 @
2d278cdc
...
...
@@ -92,6 +92,8 @@ def main(args):
num_train_examples
=
reader
.
get_num_examples
(
args
.
train_set
)
if
args
.
in_tokens
:
if
args
.
batch_size
<
args
.
max_seq_len
:
raise
ValueError
(
'if in_tokens=True, batch_size should greater than max_sqelen, got batch_size:%d seqlen:%d'
%
(
args
.
batch_size
,
args
.
max_seq_len
))
max_train_steps
=
args
.
epoch
*
num_train_examples
//
(
args
.
batch_size
//
args
.
max_seq_len
)
//
dev_count
else
:
...
...
@@ -376,11 +378,12 @@ def main(args):
def
evaluate_wrapper
(
args
,
reader
,
exe
,
test_prog
,
test_pyreader
,
graph_vars
,
epoch
,
steps
):
# evaluate dev set
batch_size
=
args
.
batch_size
if
args
.
predict_batch_size
is
None
else
args
.
predict_batch_size
for
ds
in
args
.
dev_set
.
split
(
','
):
test_pyreader
.
decorate_tensor_provider
(
reader
.
data_generator
(
ds
,
batch_size
=
args
.
predict_
batch_size
,
batch_size
=
batch_size
,
epoch
=
1
,
dev_count
=
1
,
shuffle
=
False
))
...
...
@@ -403,12 +406,13 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
test_sets
=
args
.
test_set
.
split
(
','
)
save_dirs
=
args
.
test_save
.
split
(
','
)
assert
len
(
test_sets
)
==
len
(
save_dirs
)
batch_size
=
args
.
batch_size
if
args
.
predict_batch_size
is
None
else
args
.
predict_batch_size
for
test_f
,
save_f
in
zip
(
test_sets
,
save_dirs
):
test_pyreader
.
decorate_tensor_provider
(
reader
.
data_generator
(
test_f
,
batch_size
=
args
.
predict_
batch_size
,
batch_size
=
batch_size
,
epoch
=
1
,
dev_count
=
1
,
shuffle
=
False
))
...
...
run_mrc.py
浏览文件 @
2d278cdc
...
...
@@ -95,6 +95,8 @@ def main(args):
num_train_examples
=
reader
.
get_num_examples
(
"train"
)
if
args
.
in_tokens
:
if
args
.
batch_size
<
args
.
max_seq_len
:
raise
ValueError
(
'if in_tokens=True, batch_size should greater than max_sqelen, got batch_size:%d seqlen:%d'
%
(
args
.
batch_size
,
args
.
max_seq_len
))
max_train_steps
=
args
.
epoch
*
num_train_examples
//
(
args
.
batch_size
//
args
.
max_seq_len
)
//
dev_count
else
:
...
...
run_sequence_labeling.py
浏览文件 @
2d278cdc
...
...
@@ -85,6 +85,9 @@ def main(args):
num_train_examples
=
reader
.
get_num_examples
(
args
.
train_set
)
if
args
.
in_tokens
:
if
args
.
batch_size
<
args
.
max_seq_len
:
raise
ValueError
(
'if in_tokens=True, batch_size should greater than max_sqelen, got batch_size:%d seqlen:%d'
%
(
args
.
batch_size
,
args
.
max_seq_len
))
max_train_steps
=
args
.
epoch
*
num_train_examples
//
(
args
.
batch_size
//
args
.
max_seq_len
)
//
dev_count
else
:
...
...
@@ -297,11 +300,12 @@ def main(args):
def
evaluate_wrapper
(
reader
,
exe
,
test_prog
,
test_pyreader
,
graph_vars
,
epoch
,
steps
):
# evaluate dev set
batch_size
=
args
.
batch_size
if
args
.
predict_batch_size
is
None
else
args
.
predict_batch_size
for
ds
in
args
.
dev_set
.
split
(
','
):
#single card eval
test_pyreader
.
decorate_tensor_provider
(
reader
.
data_generator
(
ds
,
batch_size
=
args
.
predict_
batch_size
,
batch_size
=
batch_size
,
epoch
=
1
,
dev_count
=
1
,
shuffle
=
False
))
...
...
@@ -318,10 +322,11 @@ def predict_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
save_dirs
=
args
.
test_save
.
split
(
','
)
assert
len
(
test_sets
)
==
len
(
save_dirs
),
'number of test_sets & test_save not match, got %d vs %d'
%
(
len
(
test_sets
),
len
(
save_dirs
))
batch_size
=
args
.
batch_size
if
args
.
predict_batch_size
is
None
else
args
.
predict_batch_size
for
test_f
,
save_f
in
zip
(
test_sets
,
save_dirs
):
test_pyreader
.
decorate_tensor_provider
(
reader
.
data_generator
(
test_f
,
batch_size
=
args
.
predict_
batch_size
,
batch_size
=
batch_size
,
epoch
=
1
,
dev_count
=
1
,
shuffle
=
False
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录