PaddlePaddle / ERNIE
Unverified commit df1cb679
Authored on Mar 06, 2019 by Yibing Liu; committed via GitHub on Mar 06, 2019.
Merge pull request #5 from PaddlePaddle/fix_train_args
Enable batching not in tokens in pretraining
Parents: 8a0753a5, e65ba415

Showing 2 changed files with 33 additions and 16 deletions (+33, -16):

BERT/reader/pretraining.py   +13  -6
BERT/train.py                +20  -10
BERT/reader/pretraining.py (view file @ df1cb679)

@@ -36,6 +36,7 @@ class DataReader(object):
                  data_dir,
                  vocab_path,
                  batch_size=4096,
+                 in_tokens=True,
                  max_seq_len=512,
                  shuffle_files=True,
                  epoch=100,
@@ -46,6 +47,7 @@ class DataReader(object):
         self.vocab = self.load_vocab(vocab_path)
         self.data_dir = data_dir
         self.batch_size = batch_size
+        self.in_tokens = in_tokens
         self.shuffle_files = shuffle_files
         self.epoch = epoch
         self.current_epoch = 0
@@ -60,8 +62,9 @@ class DataReader(object):
         self.mask_id = self.vocab["[MASK]"]
         self.is_test = is_test
         self.generate_neg_sample = generate_neg_sample
-        assert self.batch_size > 100, "Current batch size means total token's number, \
-                                       it should not be set to too small number."
+        if self.in_tokens:
+            assert self.batch_size >= self.max_seq_len, "The number of " \
+                "tokens in batch should not be smaller than max seq length."
         if self.is_test:
             self.epoch = 1
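The assertion now only matters in token mode, where the budget has to fit at least one maximum-length sequence; otherwise batch_reader could never admit such a sample. A tiny illustration with made-up numbers:

```python
# Hypothetical values that would trip the new assert: with a token budget
# smaller than max_seq_len, even an empty batch fails the admission check
# (len(batch) + 1) * max_len <= batch_size for a max-length sequence.
batch_size, max_seq_len = 256, 512
print((0 + 1) * max_seq_len <= batch_size)  # False, so such a sample could never be batched
```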
@@ -245,12 +248,16 @@ class DataReader(object):
                     continue
                 yield sample

-        def batch_reader(reader, batch_size):
+        def batch_reader(reader, batch_size, in_tokens):
             batch, total_token_num, max_len = [], 0, 0
             for parsed_line in reader():
                 token_ids, sent_ids, pos_ids, label = parsed_line
                 max_len = max(max_len, len(token_ids))
-                if (len(batch) + 1) * max_len <= batch_size:
+                if in_tokens:
+                    to_append = (len(batch) + 1) * max_len <= batch_size
+                else:
+                    to_append = len(batch) < batch_size
+                if to_append:
                     batch.append(parsed_line)
                     total_token_num += len(token_ids)
                 else:
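Since batch_reader carries the core of this change, below is a minimal self-contained sketch of the two modes. The toy reader and sequence lengths are invented for illustration; only the admission rule is taken from the diff:

```python
def batch_reader(reader, batch_size, in_tokens):
    """Group samples by a token budget (in_tokens=True) or a fixed count (in_tokens=False)."""
    batch, total_token_num, max_len = [], 0, 0
    for token_ids in reader():
        max_len = max(max_len, len(token_ids))
        if in_tokens:
            # padded size of the would-be batch must stay within the token budget
            to_append = (len(batch) + 1) * max_len <= batch_size
        else:
            to_append = len(batch) < batch_size
        if to_append:
            batch.append(token_ids)
            total_token_num += len(token_ids)
        else:
            yield batch, total_token_num
            batch, total_token_num, max_len = [token_ids], len(token_ids), len(token_ids)
    if batch:
        yield batch, total_token_num

samples = lambda: iter([[0] * 3, [0] * 5, [0] * 2, [0] * 4])  # toy sequences of lengths 3, 5, 2, 4
print(list(batch_reader(samples, batch_size=10, in_tokens=True)))   # batches bounded by 10 padded tokens
print(list(batch_reader(samples, batch_size=2, in_tokens=False)))   # exactly 2 sequences per batch
```

In token mode the padded tensor size stays roughly constant while the number of sequences varies; in example mode it is the other way around.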
@@ -261,8 +268,8 @@ class DataReader(object):
             if len(batch) > 0:
                 yield batch, total_token_num

-        for batch_data, total_token_num in batch_reader(reader,
-                                                         self.batch_size):
+        for batch_data, total_token_num in batch_reader(
+                reader, self.batch_size, self.in_tokens):
             yield prepare_batch_data(
                 batch_data,
                 total_token_num,
BERT/train.py (view file @ df1cb679)

@@ -61,14 +61,15 @@ log_g.add_arg("verbose", bool, False, "Whether to output verbose l
 data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
 data_g.add_arg("data_dir",           str,  "./data/train/",      "Path to training data.")
-data_g.add_arg("validation_set_dir", str,  "./data/validation/", "Path to training data.")
-data_g.add_arg("test_set_dir",       str,  None,                 "Path to training data.")
+data_g.add_arg("validation_set_dir", str,  "./data/validation/", "Path to validation data.")
+data_g.add_arg("test_set_dir",       str,  None,                 "Path to test data.")
 data_g.add_arg("vocab_path",         str,  "./config/vocab.txt", "Vocabulary path.")
-data_g.add_arg("max_seq_len",        int,  512,  "Number of words of the longest seqence.")
-data_g.add_arg("batch_size",         int,  16,   "Total examples' number in batch for training. see also --in_tokens.")
-data_g.add_arg("in_tokens",          bool, False,
+data_g.add_arg("max_seq_len",        int,  512,  "Tokens' number of the longest seqence allowed.")
+data_g.add_arg("batch_size",         int,  8192,
+              "The total number of examples in one batch for training, see also --in_tokens.")
+data_g.add_arg("in_tokens",          bool, True,
               "If set, the batch size will be the maximum number of tokens in one batch. "
               "Otherwise, it will be the maximum number of examples in one batch.")

 run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
 run_type_g.add_arg("is_distributed", bool, False, "If set, then start distributed training.")
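Taken together, the new defaults flip pre-training to token-based batching: batch_size=8192 is now a token budget rather than the old 16 examples per batch. A rough, purely illustrative look at what that budget means for the number of sequences in a batch:

```python
token_budget = 8192  # new default --batch_size with in_tokens=True
for longest_in_batch in (512, 256, 128, 64):
    # batch_reader admits n sequences while n * longest_in_batch <= token_budget
    print(longest_in_batch, token_budget // longest_in_batch)
# 512 -> 16, 256 -> 32, 128 -> 64, 64 -> 128 sequences, versus a flat 16 with in_tokens=False
```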
@@ -128,6 +129,7 @@ def predict_wrapper(args,
                             data_path,
                             vocab_path=args.vocab_path,
                             batch_size=args.batch_size,
+                            in_tokens=args.in_tokens,
                             voc_size=bert_config['vocab_size'],
                             shuffle_files=False,
                             epoch=1,
@@ -250,9 +252,16 @@ def train(args):
     dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))

     print("Device count %d" % dev_count)
-    print("theoretical memory usage: ")
-    print(fluid.contrib.memory_usage(
-        program=train_program, batch_size=args.batch_size // args.max_seq_len))
+    if args.verbose:
+        if args.in_tokens:
+            lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
+                program=train_program,
+                batch_size=args.batch_size // args.max_seq_len)
+        else:
+            lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
+                program=train_program, batch_size=args.batch_size)
+        print("Theoretical memory usage in training: %.3f - %.3f %s" %
+              (lower_mem, upper_mem, unit))

     nccl2_num_trainers = 1
     nccl2_trainer_id = 0
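fluid.contrib.memory_usage expects a count of examples, so the verbose path above divides the token budget by max_seq_len before estimating. The selection boils down to this arithmetic (a stand-alone sketch; the helper name is made up and fluid is not needed to show it):

```python
def estimator_batch_size(batch_size, in_tokens, max_seq_len):
    # In token mode, assume worst-case max-length sequences when estimating memory;
    # otherwise batch_size already is the number of sequences.
    return batch_size // max_seq_len if in_tokens else batch_size

print(estimator_batch_size(8192, True, 512))   # 16
print(estimator_batch_size(16, False, 512))    # 16
```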
@@ -293,6 +302,7 @@ def train(args):
     data_reader = DataReader(
         data_dir=args.data_dir,
         batch_size=args.batch_size,
+        in_tokens=args.in_tokens,
         vocab_path=args.vocab_path,
         voc_size=bert_config['vocab_size'],
         epoch=args.epoch,