Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
ERNIE
提交
2dc53566
E
ERNIE
项目概览
PaddlePaddle
/
ERNIE
大约 1 年 前同步成功
通知
109
Star
5997
Fork
1270
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
29
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
E
ERNIE
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
29
Issue
29
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2dc53566
编写于
4月 10, 2019
作者:
T
tianxin
提交者:
GitHub
4月 10, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #84 from PaddlePaddle/fix_rand_hang
Yield dev_count times batches for exiting fine-tuning normally
上级
12256da0
49ac5072
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
29 addition
and
4 deletion
+29
-4
BERT/reader/cls.py
BERT/reader/cls.py
+13
-2
BERT/reader/squad.py
BERT/reader/squad.py
+9
-1
BERT/run_classifier.py
BERT/run_classifier.py
+5
-1
BERT/run_squad.py
BERT/run_squad.py
+2
-0
未找到文件。
BERT/reader/cls.py
浏览文件 @
2dc53566
...
...
@@ -118,7 +118,12 @@ class DataProcessor(object):
"""Gets progress for training phase."""
return
self
.
current_train_example
,
self
.
current_train_epoch
def
data_generator
(
self
,
batch_size
,
phase
=
'train'
,
epoch
=
1
,
shuffle
=
True
):
def
data_generator
(
self
,
batch_size
,
phase
=
'train'
,
epoch
=
1
,
dev_count
=
1
,
shuffle
=
True
):
"""
Generate data for train, dev or test.
...
...
@@ -178,6 +183,7 @@ class DataProcessor(object):
yield
batch
,
total_token_num
def
wrapper
():
all_dev_batches
=
[]
for
batch_data
,
total_token_num
in
batch_reader
(
instance_reader
,
batch_size
,
self
.
in_tokens
):
batch_data
=
self
.
generate_batch_data
(
...
...
@@ -188,7 +194,12 @@ class DataProcessor(object):
return_input_mask
=
True
,
return_max_len
=
False
,
return_num_token
=
False
)
yield
batch_data
if
len
(
all_dev_batches
)
<
dev_count
:
all_dev_batches
.
append
(
batch_data
)
else
:
for
batch
in
all_dev_batches
:
yield
batch
all_dev_batches
=
[
batch_data
]
return
wrapper
...
...
BERT/reader/squad.py
浏览文件 @
2dc53566
...
...
@@ -488,6 +488,7 @@ class DataProcessor(object):
batch_size
,
phase
=
'train'
,
shuffle
=
False
,
dev_count
=
1
,
version_2_with_negative
=
False
,
epoch
=
1
):
if
phase
==
'train'
:
...
...
@@ -549,9 +550,10 @@ class DataProcessor(object):
else
:
features
=
self
.
get_features
(
examples
,
is_training
=
False
)
all_dev_batches
=
[]
for
batch_data
,
total_token_num
in
batch_reader
(
features
,
batch_size
,
self
.
_in_tokens
):
yield
prepare_batch_data
(
batch_data
=
prepare_batch_data
(
batch_data
,
total_token_num
,
voc_size
=-
1
,
...
...
@@ -562,6 +564,12 @@ class DataProcessor(object):
return_input_mask
=
True
,
return_max_len
=
False
,
return_num_token
=
False
)
if
len
(
all_dev_batches
)
<
dev_count
:
all_dev_batches
.
append
(
batch_data
)
else
:
for
batch
in
all_dev_batches
:
yield
batch
all_dev_batches
=
[
batch_data
]
return
wrapper
...
...
BERT/run_classifier.py
浏览文件 @
2dc53566
...
...
@@ -148,6 +148,7 @@ def main(args):
batch_size
=
args
.
batch_size
,
phase
=
'train'
,
epoch
=
args
.
epoch
,
dev_count
=
dev_count
,
shuffle
=
True
)
num_train_examples
=
processor
.
get_num_examples
(
phase
=
'train'
)
...
...
@@ -330,6 +331,7 @@ def main(args):
batch_size
=
args
.
batch_size
,
phase
=
'dev'
,
epoch
=
1
,
dev_count
=
1
,
shuffle
=
False
))
evaluate
(
exe
,
test_prog
,
test_pyreader
,
[
loss
.
name
,
accuracy
.
name
,
num_seqs
.
name
],
...
...
@@ -341,6 +343,7 @@ def main(args):
batch_size
=
args
.
batch_size
,
phase
=
'test'
,
epoch
=
1
,
dev_count
=
1
,
shuffle
=
False
))
evaluate
(
exe
,
test_prog
,
test_pyreader
,
[
loss
.
name
,
accuracy
.
name
,
num_seqs
.
name
],
...
...
@@ -355,7 +358,7 @@ def main(args):
if
args
.
do_val
:
test_pyreader
.
decorate_tensor_provider
(
processor
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'dev'
,
epoch
=
1
,
batch_size
=
args
.
batch_size
,
phase
=
'dev'
,
epoch
=
1
,
dev_count
=
1
,
shuffle
=
False
))
print
(
"Final validation result:"
)
evaluate
(
exe
,
test_prog
,
test_pyreader
,
...
...
@@ -368,6 +371,7 @@ def main(args):
batch_size
=
args
.
batch_size
,
phase
=
'test'
,
epoch
=
1
,
dev_count
=
1
,
shuffle
=
False
))
print
(
"Final test result:"
)
evaluate
(
exe
,
test_prog
,
test_pyreader
,
...
...
BERT/run_squad.py
浏览文件 @
2dc53566
...
...
@@ -242,6 +242,7 @@ def train(args):
batch_size
=
args
.
batch_size
,
phase
=
'train'
,
shuffle
=
False
,
dev_count
=
dev_count
,
version_2_with_negative
=
args
.
version_2_with_negative
,
epoch
=
args
.
epoch
)
...
...
@@ -413,6 +414,7 @@ def train(args):
batch_size
=
args
.
batch_size
,
phase
=
'predict'
,
shuffle
=
False
,
dev_count
=
1
,
epoch
=
1
))
predict
(
exe
,
test_prog
,
test_pyreader
,
[
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录