Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
4e553e2b
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4e553e2b
编写于
8月 06, 2019
作者:
Y
Yibing Liu
提交者:
GitHub
8月 06, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use gc, PyReader & compiledprogram for bert (#3035)
上级
80283e6d
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
128 addition
and
121 deletion
+128
-121
PaddleNLP/README.md
PaddleNLP/README.md
+1
-1
PaddleNLP/language_representations_kit/BERT/README.md
PaddleNLP/language_representations_kit/BERT/README.md
+12
-8
PaddleNLP/language_representations_kit/BERT/model/classifier.py
...NLP/language_representations_kit/BERT/model/classifier.py
+20
-15
PaddleNLP/language_representations_kit/BERT/predict_classifier.py
...P/language_representations_kit/BERT/predict_classifier.py
+2
-3
PaddleNLP/language_representations_kit/BERT/run_classifier.py
...leNLP/language_representations_kit/BERT/run_classifier.py
+29
-31
PaddleNLP/language_representations_kit/BERT/run_squad.py
PaddleNLP/language_representations_kit/BERT/run_squad.py
+29
-23
PaddleNLP/language_representations_kit/BERT/train.py
PaddleNLP/language_representations_kit/BERT/train.py
+35
-40
未找到文件。
PaddleNLP/README.md
浏览文件 @
4e553e2b
...
@@ -84,7 +84,7 @@ cd models/PaddleNLP/sentiment_classification
...
@@ -84,7 +84,7 @@ cd models/PaddleNLP/sentiment_classification
-
[
机器翻译
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/neural_machine_translation/transformer
)
-
[
机器翻译
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/neural_machine_translation/transformer
)
### 语义表示与语言模型
### 语义表示与语言模型
-
[
语言表示工具箱
](
https://github.com/PaddlePaddle/
LARK/tree/develop
)
-
[
语言表示工具箱
](
https://github.com/PaddlePaddle/
models/tree/develop/PaddleNLP/language_representations_kit
)
-
[
语言模型
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/language_model
)
-
[
语言模型
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/language_model
)
### 复杂任务
### 复杂任务
...
...
PaddleNLP/language_representations_kit/BERT/README.md
浏览文件 @
4e553e2b
...
@@ -18,6 +18,8 @@
...
@@ -18,6 +18,8 @@
| Model | Layers | Hidden size | Heads |Parameters |
| Model | Layers | Hidden size | Heads |Parameters |
| :------| :------: | :------: |:------: |:------: |
| :------| :------: | :------: |:------: |:------: |
|
[
BERT-Large, Uncased (Whole Word Masking)
](
https://bert-models.bj.bcebos.com/wwm_uncased_L-24_H-1024_A-16.tar.gz
)
| 24 | 1024 | 16 | 340M |
|
[
BERT-Large, Cased (Whole Word Masking)
](
https://bert-models.bj.bcebos.com/wwm_cased_L-24_H-1024_A-16.tar.gz
)
| 24 | 1024 | 16 | 340M |
|
[
BERT-Base, Uncased
](
https://bert-models.bj.bcebos.com/uncased_L-12_H-768_A-12.tar.gz
)
| 12 | 768 |12 |110M |
|
[
BERT-Base, Uncased
](
https://bert-models.bj.bcebos.com/uncased_L-12_H-768_A-12.tar.gz
)
| 12 | 768 |12 |110M |
|
[
BERT-Large, Uncased
](
https://bert-models.bj.bcebos.com/uncased_L-24_H-1024_A-16.tar.gz
)
| 24 | 1024 |16 |340M |
|
[
BERT-Large, Uncased
](
https://bert-models.bj.bcebos.com/uncased_L-24_H-1024_A-16.tar.gz
)
| 24 | 1024 |16 |340M |
|
[
BERT-Base, Cased
](
https://bert-models.bj.bcebos.com/cased_L-12_H-768_A-12.tar.gz
)
|12|768|12|110M|
|
[
BERT-Base, Cased
](
https://bert-models.bj.bcebos.com/cased_L-12_H-768_A-12.tar.gz
)
|12|768|12|110M|
...
@@ -46,7 +48,7 @@
...
@@ -46,7 +48,7 @@
-
[
inference 接口调用示例
](
#inference-接口调用示例
)
-
[
inference 接口调用示例
](
#inference-接口调用示例
)
## 安装
## 安装
本项目依赖于 Paddle Fluid
**1.
3
.1**
,请参考
[
安装指南
](
http://www.paddlepaddle.org/#quick-start
)
进行安装。如果需要进行 TensorFlow 模型到 Paddle Fluid 参数的转换,则需要同时安装 TensorFlow 1.12。
本项目依赖于 Paddle Fluid
**1.
5
.1**
,请参考
[
安装指南
](
http://www.paddlepaddle.org/#quick-start
)
进行安装。如果需要进行 TensorFlow 模型到 Paddle Fluid 参数的转换,则需要同时安装 TensorFlow 1.12。
## 预训练
## 预训练
...
@@ -138,22 +140,24 @@ python -u run_classifier.py --task_name ${TASK_NAME} \
...
@@ -138,22 +140,24 @@ python -u run_classifier.py --task_name ${TASK_NAME} \
--do_train
true
\
--do_train
true
\
--do_val
true
\
--do_val
true
\
--do_test
true
\
--do_test
true
\
--batch_size
819
2
\
--batch_size
3
2
\
--in_tokens
tru
e
\
--in_tokens
fals
e
\
--init_pretraining_params
${
BERT_BASE_PATH
}
/params
\
--init_pretraining_params
${
BERT_BASE_PATH
}
/params
\
--data_dir
${
DATA_PATH
}
\
--data_dir
${
DATA_PATH
}
\
--vocab_path
${
BERT_BASE_PATH
}
/vocab.txt
\
--vocab_path
${
BERT_BASE_PATH
}
/vocab.txt
\
--checkpoints
${
CKPT_PATH
}
\
--checkpoints
${
CKPT_PATH
}
\
--save_steps
1000
\
--save_steps
1000
\
--weight_decay
0.01
\
--weight_decay
0.01
\
--warmup_proportion
0.
0
\
--warmup_proportion
0.
1
\
--validation_steps
25
\
--validation_steps
100
\
--epoch
3
\
--epoch
3
\
--max_seq_len
512
\
--max_seq_len
128
\
--bert_config_path
${
BERT_BASE_PATH
}
/bert_config.json
\
--bert_config_path
${
BERT_BASE_PATH
}
/bert_config.json
\
--learning_rate
1e-4
\
--learning_rate
5e-5
\
--skip_steps
10
\
--skip_steps
10
\
--random_seed
1
--num_iteration_per_drop_scope
10
\
--use_fp16
true
\
--verbose
true
```
```
这里的
`chinese_L-12_H-768_A-12`
即是转换后的中文预训练模型。需要注意的是,BERT on PaddlePaddle 支持按两种方式构建一个 batch 的数据,
`in_tokens`
参数影响
`batch_size`
参数的意义,如果
`in_tokens`
为
`true`
则按照 token 个数构建 batch, 如不设定则按照 example 个数来构建 batch. 训练过程中会输出训练误差、训练速度等信息,训练结束后会输出如下所示的在验证集上的测试结果:
这里的
`chinese_L-12_H-768_A-12`
即是转换后的中文预训练模型。需要注意的是,BERT on PaddlePaddle 支持按两种方式构建一个 batch 的数据,
`in_tokens`
参数影响
`batch_size`
参数的意义,如果
`in_tokens`
为
`true`
则按照 token 个数构建 batch, 如不设定则按照 example 个数来构建 batch. 训练过程中会输出训练误差、训练速度等信息,训练结束后会输出如下所示的在验证集上的测试结果:
...
...
PaddleNLP/language_representations_kit/BERT/model/classifier.py
浏览文件 @
4e553e2b
...
@@ -22,22 +22,27 @@ import paddle.fluid as fluid
...
@@ -22,22 +22,27 @@ import paddle.fluid as fluid
from
model.bert
import
BertModel
from
model.bert
import
BertModel
def
create_model
(
args
,
def
create_model
(
args
,
bert_config
,
num_labels
,
is_prediction
=
False
):
pyreader_name
,
input_fields
=
{
bert_config
,
'names'
:
[
'src_ids'
,
'pos_ids'
,
'sent_ids'
,
'input_mask'
,
'labels'
],
num_labels
,
'shapes'
:
is_prediction
=
False
):
[[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
pyreader
=
fluid
.
layers
.
py_reader
(
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
1
]],
capacity
=
50
,
'dtypes'
:
[
'int64'
,
'int64'
,
'int64'
,
'float32'
,
'int64'
],
shapes
=
[[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
'lod_levels'
:
[
0
,
0
,
0
,
0
,
0
],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
1
]],
}
dtypes
=
[
'int64'
,
'int64'
,
'int64'
,
'float32'
,
'int64'
],
lod_levels
=
[
0
,
0
,
0
,
0
,
0
],
name
=
pyreader_name
,
use_double_buffer
=
True
)
(
src_ids
,
pos_ids
,
sent_ids
,
input_mask
,
inputs
=
[
labels
)
=
fluid
.
layers
.
read_file
(
pyreader
)
fluid
.
layers
.
data
(
name
=
input_fields
[
'names'
][
i
],
shape
=
input_fields
[
'shapes'
][
i
],
dtype
=
input_fields
[
'dtypes'
][
i
],
lod_level
=
input_fields
[
'lod_levels'
][
i
])
for
i
in
range
(
len
(
input_fields
[
'names'
]))
]
(
src_ids
,
pos_ids
,
sent_ids
,
input_mask
,
labels
)
=
inputs
pyreader
=
fluid
.
io
.
PyReader
(
feed_list
=
inputs
,
capacity
=
50
,
iterable
=
False
)
bert
=
BertModel
(
bert
=
BertModel
(
src_ids
=
src_ids
,
src_ids
=
src_ids
,
...
...
PaddleNLP/language_representations_kit/BERT/predict_classifier.py
浏览文件 @
4e553e2b
...
@@ -84,7 +84,6 @@ def main(args):
...
@@ -84,7 +84,6 @@ def main(args):
with
fluid
.
unique_name
.
guard
():
with
fluid
.
unique_name
.
guard
():
predict_pyreader
,
probs
,
feed_target_names
=
create_model
(
predict_pyreader
,
probs
,
feed_target_names
=
create_model
(
args
,
args
,
pyreader_name
=
'predict_reader'
,
bert_config
=
bert_config
,
bert_config
=
bert_config
,
num_labels
=
num_labels
,
num_labels
=
num_labels
,
is_prediction
=
True
)
is_prediction
=
True
)
...
@@ -103,7 +102,7 @@ def main(args):
...
@@ -103,7 +102,7 @@ def main(args):
exe
.
run
(
predict_startup
)
exe
.
run
(
predict_startup
)
if
args
.
init_checkpoint
:
if
args
.
init_checkpoint
:
init_pretraining_params
(
exe
,
args
.
init_checkpoint
,
predict_prog
)
init_pretraining_params
(
exe
,
args
.
init_checkpoint
,
predict_prog
,
args
.
use_fp16
)
else
:
else
:
raise
ValueError
(
"args 'init_checkpoint' should be set for prediction!"
)
raise
ValueError
(
"args 'init_checkpoint' should be set for prediction!"
)
...
@@ -113,7 +112,7 @@ def main(args):
...
@@ -113,7 +112,7 @@ def main(args):
predict_exe
=
fluid
.
ParallelExecutor
(
predict_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
args
.
use_cuda
,
main_program
=
predict_prog
)
use_cuda
=
args
.
use_cuda
,
main_program
=
predict_prog
)
predict_pyreader
.
decorate_
tensor_provide
r
(
predict_pyreader
.
decorate_
batch_generato
r
(
processor
.
data_generator
(
processor
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'test'
,
epoch
=
1
,
shuffle
=
False
))
batch_size
=
args
.
batch_size
,
phase
=
'test'
,
epoch
=
1
,
shuffle
=
False
))
...
...
PaddleNLP/language_representations_kit/BERT/run_classifier.py
浏览文件 @
4e553e2b
...
@@ -193,7 +193,6 @@ def main(args):
...
@@ -193,7 +193,6 @@ def main(args):
with
fluid
.
unique_name
.
guard
():
with
fluid
.
unique_name
.
guard
():
train_pyreader
,
loss
,
probs
,
accuracy
,
num_seqs
=
create_model
(
train_pyreader
,
loss
,
probs
,
accuracy
,
num_seqs
=
create_model
(
args
,
args
,
pyreader_name
=
'train_reader'
,
bert_config
=
bert_config
,
bert_config
=
bert_config
,
num_labels
=
num_labels
)
num_labels
=
num_labels
)
scheduled_lr
=
optimization
(
scheduled_lr
=
optimization
(
...
@@ -219,17 +218,41 @@ def main(args):
...
@@ -219,17 +218,41 @@ def main(args):
print
(
"Theoretical memory usage in training: %.3f - %.3f %s"
%
print
(
"Theoretical memory usage in training: %.3f - %.3f %s"
%
(
lower_mem
,
upper_mem
,
unit
))
(
lower_mem
,
upper_mem
,
unit
))
if
args
.
do_val
or
args
.
do_test
:
if
args
.
do_val
:
dev_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
dev_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
dev_pyreader
,
loss
,
probs
,
accuracy
,
num_seqs
=
create_model
(
args
,
bert_config
=
bert_config
,
num_labels
=
num_labels
)
dev_prog
=
dev_prog
.
clone
(
for_test
=
True
)
dev_pyreader
.
decorate_batch_generator
(
processor
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'dev'
,
epoch
=
1
,
dev_count
=
1
,
shuffle
=
False
),
place
)
if
args
.
do_test
:
test_prog
=
fluid
.
Program
()
test_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
test_prog
,
startup_prog
):
with
fluid
.
program_guard
(
test_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
with
fluid
.
unique_name
.
guard
():
test_pyreader
,
loss
,
probs
,
accuracy
,
num_seqs
=
create_model
(
test_pyreader
,
loss
,
probs
,
accuracy
,
num_seqs
=
create_model
(
args
,
args
,
pyreader_name
=
'test_reader'
,
bert_config
=
bert_config
,
bert_config
=
bert_config
,
num_labels
=
num_labels
)
num_labels
=
num_labels
)
test_prog
=
test_prog
.
clone
(
for_test
=
True
)
test_prog
=
test_prog
.
clone
(
for_test
=
True
)
test_pyreader
.
decorate_batch_generator
(
processor
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'test'
,
epoch
=
1
,
dev_count
=
1
,
shuffle
=
False
),
place
)
exe
.
run
(
startup_prog
)
exe
.
run
(
startup_prog
)
...
@@ -276,7 +299,7 @@ def main(args):
...
@@ -276,7 +299,7 @@ def main(args):
train_compiled_program
=
fluid
.
CompiledProgram
(
train_program
).
with_data_parallel
(
train_compiled_program
=
fluid
.
CompiledProgram
(
train_program
).
with_data_parallel
(
loss_name
=
loss
.
name
,
build_strategy
=
build_strategy
)
loss_name
=
loss
.
name
,
build_strategy
=
build_strategy
)
train_pyreader
.
decorate_
tensor_provider
(
train_data_generator
)
train_pyreader
.
decorate_
batch_generator
(
train_data_generator
,
place
)
if
args
.
do_train
:
if
args
.
do_train
:
...
@@ -350,25 +373,11 @@ def main(args):
...
@@ -350,25 +373,11 @@ def main(args):
throughput
=
[]
throughput
=
[]
# evaluate dev set
# evaluate dev set
if
args
.
do_val
:
if
args
.
do_val
:
test_pyreader
.
decorate_tensor_provider
(
evaluate
(
exe
,
dev_prog
,
dev_pyreader
,
processor
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'dev'
,
epoch
=
1
,
dev_count
=
1
,
shuffle
=
False
))
evaluate
(
exe
,
test_prog
,
test_pyreader
,
[
loss
.
name
,
accuracy
.
name
,
num_seqs
.
name
],
[
loss
.
name
,
accuracy
.
name
,
num_seqs
.
name
],
"dev"
)
"dev"
)
# evaluate test set
# evaluate test set
if
args
.
do_test
:
if
args
.
do_test
:
test_pyreader
.
decorate_tensor_provider
(
processor
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'test'
,
epoch
=
1
,
dev_count
=
1
,
shuffle
=
False
))
evaluate
(
exe
,
test_prog
,
test_pyreader
,
evaluate
(
exe
,
test_prog
,
test_pyreader
,
[
loss
.
name
,
accuracy
.
name
,
num_seqs
.
name
],
[
loss
.
name
,
accuracy
.
name
,
num_seqs
.
name
],
"test"
)
"test"
)
...
@@ -398,23 +407,12 @@ def main(args):
...
@@ -398,23 +407,12 @@ def main(args):
# final eval on dev set
# final eval on dev set
if
args
.
do_val
:
if
args
.
do_val
:
test_pyreader
.
decorate_tensor_provider
(
processor
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'dev'
,
epoch
=
1
,
dev_count
=
1
,
shuffle
=
False
))
print
(
"Final validation result:"
)
print
(
"Final validation result:"
)
evaluate
(
exe
,
test_prog
,
test
_pyreader
,
evaluate
(
exe
,
dev_prog
,
dev
_pyreader
,
[
loss
.
name
,
accuracy
.
name
,
num_seqs
.
name
],
"dev"
)
[
loss
.
name
,
accuracy
.
name
,
num_seqs
.
name
],
"dev"
)
# final eval on test set
# final eval on test set
if
args
.
do_test
:
if
args
.
do_test
:
test_pyreader
.
decorate_tensor_provider
(
processor
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'test'
,
epoch
=
1
,
dev_count
=
1
,
shuffle
=
False
))
print
(
"Final test result:"
)
print
(
"Final test result:"
)
evaluate
(
exe
,
test_prog
,
test_pyreader
,
evaluate
(
exe
,
test_prog
,
test_pyreader
,
[
loss
.
name
,
accuracy
.
name
,
num_seqs
.
name
],
"test"
)
[
loss
.
name
,
accuracy
.
name
,
num_seqs
.
name
],
"test"
)
...
...
PaddleNLP/language_representations_kit/BERT/run_squad.py
浏览文件 @
4e553e2b
...
@@ -92,31 +92,39 @@ run_type_g.add_arg("do_predict", bool, True, "Whether to pe
...
@@ -92,31 +92,39 @@ run_type_g.add_arg("do_predict", bool, True, "Whether to pe
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
# yapf: enable.
# yapf: enable.
def
create_model
(
pyreader_name
,
bert_config
,
is_training
=
False
):
def
create_model
(
bert_config
,
is_training
=
False
):
if
is_training
:
if
is_training
:
pyreader
=
fluid
.
layers
.
py_reader
(
input_fields
=
{
capacity
=
50
,
'names'
:
[
'src_ids'
,
'pos_ids'
,
'sent_ids'
,
'input_mask'
,
'start_positions'
,
'end_positions'
]
,
shapes
=
[[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
'shapes'
:
[[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
1
],
[
-
1
,
1
]],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
1
],
[
-
1
,
1
]],
dtypes
=
[
'dtypes'
:
[
'int64'
,
'int64'
,
'int64'
,
'float32'
,
'int64'
,
'int64'
],
'int64'
,
'int64'
,
'int64'
,
'float32'
,
'int64'
,
'int64'
],
lod_levels
=
[
0
,
0
,
0
,
0
,
0
,
0
],
'lod_levels'
:
[
0
,
0
,
0
,
0
,
0
,
0
],
name
=
pyreader_name
,
}
use_double_buffer
=
True
)
(
src_ids
,
pos_ids
,
sent_ids
,
input_mask
,
start_positions
,
end_positions
)
=
fluid
.
layers
.
read_file
(
pyreader
)
else
:
else
:
pyreader
=
fluid
.
layers
.
py_reader
(
input_fields
=
{
capacity
=
50
,
'names'
:
[
'src_ids'
,
'pos_ids'
,
'sent_ids'
,
'input_mask'
,
'unique_id'
]
,
shapes
=
[[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
'shapes'
:
[[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
1
]],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
1
]],
dtypes
=
[
'int64'
,
'int64'
,
'int64'
,
'float32'
,
'int64'
],
'dtypes'
:
[
lod_levels
=
[
0
,
0
,
0
,
0
,
0
],
'int64'
,
'int64'
,
'int64'
,
'float32'
,
'int64'
],
name
=
pyreader_name
,
'lod_levels'
:
[
0
,
0
,
0
,
0
,
0
],
use_double_buffer
=
True
)
}
(
src_ids
,
pos_ids
,
sent_ids
,
input_mask
,
unique_id
)
=
fluid
.
layers
.
read_file
(
pyreader
)
inputs
=
[
fluid
.
layers
.
data
(
name
=
input_fields
[
'names'
][
i
],
shape
=
input_fields
[
'shapes'
][
i
],
dtype
=
input_fields
[
'dtypes'
][
i
],
lod_level
=
input_fields
[
'lod_levels'
][
i
])
for
i
in
range
(
len
(
input_fields
[
'names'
]))]
pyreader
=
fluid
.
io
.
PyReader
(
feed_list
=
inputs
,
capacity
=
50
,
iterable
=
False
)
if
is_training
:
(
src_ids
,
pos_ids
,
sent_ids
,
input_mask
,
start_positions
,
end_positions
)
=
inputs
else
:
(
src_ids
,
pos_ids
,
sent_ids
,
input_mask
,
unique_id
)
=
inputs
bert
=
BertModel
(
bert
=
BertModel
(
src_ids
=
src_ids
,
src_ids
=
src_ids
,
...
@@ -263,7 +271,6 @@ def train(args):
...
@@ -263,7 +271,6 @@ def train(args):
with
fluid
.
program_guard
(
train_program
,
startup_prog
):
with
fluid
.
program_guard
(
train_program
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
with
fluid
.
unique_name
.
guard
():
train_pyreader
,
loss
,
num_seqs
=
create_model
(
train_pyreader
,
loss
,
num_seqs
=
create_model
(
pyreader_name
=
'train_reader'
,
bert_config
=
bert_config
,
bert_config
=
bert_config
,
is_training
=
True
)
is_training
=
True
)
...
@@ -296,7 +303,6 @@ def train(args):
...
@@ -296,7 +303,6 @@ def train(args):
with
fluid
.
program_guard
(
test_prog
,
startup_prog
):
with
fluid
.
program_guard
(
test_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
with
fluid
.
unique_name
.
guard
():
test_pyreader
,
unique_ids
,
start_logits
,
end_logits
,
num_seqs
=
create_model
(
test_pyreader
,
unique_ids
,
start_logits
,
end_logits
,
num_seqs
=
create_model
(
pyreader_name
=
'test_reader'
,
bert_config
=
bert_config
,
bert_config
=
bert_config
,
is_training
=
False
)
is_training
=
False
)
...
@@ -341,7 +347,7 @@ def train(args):
...
@@ -341,7 +347,7 @@ def train(args):
train_compiled_program
=
fluid
.
CompiledProgram
(
train_program
).
with_data_parallel
(
train_compiled_program
=
fluid
.
CompiledProgram
(
train_program
).
with_data_parallel
(
loss_name
=
loss
.
name
,
exec_strategy
=
exec_strategy
)
loss_name
=
loss
.
name
,
exec_strategy
=
exec_strategy
)
train_pyreader
.
decorate_
tensor_provider
(
train_data_generator
)
train_pyreader
.
decorate_
batch_generator
(
train_data_generator
,
place
)
train_pyreader
.
start
()
train_pyreader
.
start
()
steps
=
0
steps
=
0
...
@@ -402,14 +408,14 @@ def train(args):
...
@@ -402,14 +408,14 @@ def train(args):
break
break
if
args
.
do_predict
:
if
args
.
do_predict
:
test_pyreader
.
decorate_
tensor_provide
r
(
test_pyreader
.
decorate_
batch_generato
r
(
processor
.
data_generator
(
processor
.
data_generator
(
data_path
=
args
.
predict_file
,
data_path
=
args
.
predict_file
,
batch_size
=
args
.
batch_size
,
batch_size
=
args
.
batch_size
,
phase
=
'predict'
,
phase
=
'predict'
,
shuffle
=
False
,
shuffle
=
False
,
dev_count
=
1
,
dev_count
=
1
,
epoch
=
1
))
epoch
=
1
)
,
place
)
predict
(
exe
,
test_prog
,
test_pyreader
,
[
predict
(
exe
,
test_prog
,
test_pyreader
,
[
unique_ids
.
name
,
start_logits
.
name
,
end_logits
.
name
,
num_seqs
.
name
unique_ids
.
name
,
start_logits
.
name
,
end_logits
.
name
,
num_seqs
.
name
...
...
PaddleNLP/language_representations_kit/BERT/train.py
浏览文件 @
4e553e2b
...
@@ -82,21 +82,24 @@ args = parser.parse_args()
...
@@ -82,21 +82,24 @@ args = parser.parse_args()
# yapf: enable.
# yapf: enable.
def
create_model
(
pyreader_name
,
bert_config
):
def
create_model
(
bert_config
):
pyreader
=
fluid
.
layers
.
py_reader
(
input_fields
=
{
capacity
=
70
,
'names'
:
[
'src_ids'
,
'pos_ids'
,
'sent_ids'
,
'input_mask'
,
'mask_label'
,
'mask_pos'
,
'labels'
]
,
shapes
=
[[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
'shapes'
:
[[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
1
],
[
-
1
,
1
],
[
-
1
,
args
.
max_seq_len
,
1
],
[
-
1
,
1
],
[
-
1
,
1
],
[
-
1
,
1
]],
[
-
1
,
1
]],
'dtypes'
:
[
'int64'
,
'int64'
,
'int64'
,
'float32'
,
'int64'
,
'int64'
,
'int64'
],
dtypes
=
[
'lod_levels'
:
[
0
,
0
,
0
,
0
,
0
,
0
,
0
],
'int64'
,
'int64'
,
'int64'
,
'float32'
,
'int64'
,
'int64'
,
'int64'
}
],
lod_levels
=
[
0
,
0
,
0
,
0
,
0
,
0
,
0
],
name
=
pyreader_name
,
use_double_buffer
=
True
)
(
src_ids
,
pos_ids
,
sent_ids
,
input_mask
,
mask_label
,
mask_pos
,
labels
)
=
fluid
.
layers
.
read_file
(
pyreader
)
inputs
=
[
fluid
.
layers
.
data
(
name
=
input_fields
[
'names'
][
i
],
shape
=
input_fields
[
'shapes'
][
i
],
dtype
=
input_fields
[
'dtypes'
][
i
],
lod_level
=
input_fields
[
'lod_levels'
][
i
])
for
i
in
range
(
len
(
input_fields
[
'names'
]))]
(
src_ids
,
pos_ids
,
sent_ids
,
input_mask
,
mask_label
,
mask_pos
,
labels
)
=
inputs
pyreader
=
fluid
.
io
.
PyReader
(
feed_list
=
inputs
,
capacity
=
50
,
iterable
=
False
)
bert
=
BertModel
(
bert
=
BertModel
(
src_ids
=
src_ids
,
src_ids
=
src_ids
,
...
@@ -143,7 +146,7 @@ def predict_wrapper(args,
...
@@ -143,7 +146,7 @@ def predict_wrapper(args,
def
predict
(
exe
=
exe
,
pyreader
=
pyreader
):
def
predict
(
exe
=
exe
,
pyreader
=
pyreader
):
pyreader
.
decorate_
tensor_provide
r
(
data_reader
.
data_generator
())
pyreader
.
decorate_
batch_generato
r
(
data_reader
.
data_generator
())
pyreader
.
start
()
pyreader
.
start
()
cost
=
0
cost
=
0
...
@@ -181,7 +184,7 @@ def test(args):
...
@@ -181,7 +184,7 @@ def test(args):
with
fluid
.
program_guard
(
test_prog
,
test_startup
):
with
fluid
.
program_guard
(
test_prog
,
test_startup
):
with
fluid
.
unique_name
.
guard
():
with
fluid
.
unique_name
.
guard
():
test_pyreader
,
next_sent_acc
,
mask_lm_loss
,
total_loss
=
create_model
(
test_pyreader
,
next_sent_acc
,
mask_lm_loss
,
total_loss
=
create_model
(
pyreader_name
=
'test_reader'
,
bert_config
=
bert_config
)
bert_config
=
bert_config
)
test_prog
=
test_prog
.
clone
(
for_test
=
True
)
test_prog
=
test_prog
.
clone
(
for_test
=
True
)
...
@@ -216,7 +219,7 @@ def train(args):
...
@@ -216,7 +219,7 @@ def train(args):
with
fluid
.
program_guard
(
train_program
,
startup_prog
):
with
fluid
.
program_guard
(
train_program
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
with
fluid
.
unique_name
.
guard
():
train_pyreader
,
next_sent_acc
,
mask_lm_loss
,
total_loss
=
create_model
(
train_pyreader
,
next_sent_acc
,
mask_lm_loss
,
total_loss
=
create_model
(
pyreader_name
=
'train_reader'
,
bert_config
=
bert_config
)
bert_config
=
bert_config
)
scheduled_lr
=
optimization
(
scheduled_lr
=
optimization
(
loss
=
total_loss
,
loss
=
total_loss
,
warmup_steps
=
args
.
warmup_steps
,
warmup_steps
=
args
.
warmup_steps
,
...
@@ -229,17 +232,11 @@ def train(args):
...
@@ -229,17 +232,11 @@ def train(args):
use_fp16
=
args
.
use_fp16
,
use_fp16
=
args
.
use_fp16
,
loss_scaling
=
args
.
loss_scaling
)
loss_scaling
=
args
.
loss_scaling
)
fluid
.
memory_optimize
(
input_program
=
train_program
,
skip_opt_set
=
[
next_sent_acc
.
name
,
mask_lm_loss
.
name
,
total_loss
.
name
])
test_prog
=
fluid
.
Program
()
test_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
test_prog
,
startup_prog
):
with
fluid
.
program_guard
(
test_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
with
fluid
.
unique_name
.
guard
():
test_pyreader
,
next_sent_acc
,
mask_lm_loss
,
total_loss
=
create_model
(
test_pyreader
,
next_sent_acc
,
mask_lm_loss
,
total_loss
=
create_model
(
pyreader_name
=
'test_reader'
,
bert_config
=
bert_config
)
bert_config
=
bert_config
)
test_prog
=
test_prog
.
clone
(
for_test
=
True
)
test_prog
=
test_prog
.
clone
(
for_test
=
True
)
...
@@ -313,18 +310,16 @@ def train(args):
...
@@ -313,18 +310,16 @@ def train(args):
exec_strategy
.
num_threads
=
dev_count
exec_strategy
.
num_threads
=
dev_count
exec_strategy
.
num_iteration_per_drop_scope
=
args
.
num_iteration_per_drop_scope
exec_strategy
.
num_iteration_per_drop_scope
=
args
.
num_iteration_per_drop_scope
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
num_trainers
=
nccl2_num_trainers
build_strategy
.
trainer_id
=
nccl2_trainer_id
# use_ngraph is for CPU only, please refer to README_ngraph.md for details
# use_ngraph is for CPU only, please refer to README_ngraph.md for details
use_ngraph
=
os
.
getenv
(
'FLAGS_use_ngraph'
)
use_ngraph
=
os
.
getenv
(
'FLAGS_use_ngraph'
)
if
not
use_ngraph
:
if
not
use_ngraph
:
train_exe
=
fluid
.
ParallelExecutor
(
train_compiled_program
=
fluid
.
CompiledProgram
(
train_program
).
with_data_parallel
(
use_cuda
=
args
.
use_cuda
,
loss_name
=
total_loss
.
name
,
loss_name
=
total_loss
.
name
,
exec_strategy
=
exec_strategy
,
exec_strategy
=
exec_strategy
,
build_strategy
=
build_strategy
)
main_program
=
train_program
,
num_trainers
=
nccl2_num_trainers
,
trainer_id
=
nccl2_trainer_id
)
else
:
train_exe
=
exe
if
args
.
validation_set_dir
and
args
.
validation_set_dir
!=
""
:
if
args
.
validation_set_dir
and
args
.
validation_set_dir
!=
""
:
predict
=
predict_wrapper
(
predict
=
predict_wrapper
(
...
@@ -337,7 +332,7 @@ def train(args):
...
@@ -337,7 +332,7 @@ def train(args):
next_sent_acc
.
name
,
mask_lm_loss
.
name
,
total_loss
.
name
next_sent_acc
.
name
,
mask_lm_loss
.
name
,
total_loss
.
name
])
])
train_pyreader
.
decorate_
tensor_provide
r
(
data_reader
.
data_generator
())
train_pyreader
.
decorate_
batch_generato
r
(
data_reader
.
data_generator
())
train_pyreader
.
start
()
train_pyreader
.
start
()
steps
=
0
steps
=
0
cost
=
[]
cost
=
[]
...
@@ -351,28 +346,28 @@ def train(args):
...
@@ -351,28 +346,28 @@ def train(args):
if
nccl2_trainer_id
!=
0
:
if
nccl2_trainer_id
!=
0
:
if
use_ngraph
:
if
use_ngraph
:
train_
exe
.
run
(
fetch_list
=
[],
program
=
train_program
)
exe
.
run
(
fetch_list
=
[],
program
=
train_program
)
else
:
else
:
train_exe
.
run
(
fetch_list
=
[]
)
exe
.
run
(
fetch_list
=
[],
program
=
train_compiled_program
)
continue
continue
if
steps
%
skip_steps
!=
0
:
if
steps
%
skip_steps
!=
0
:
if
use_ngraph
:
if
use_ngraph
:
train_
exe
.
run
(
fetch_list
=
[],
program
=
train_program
)
exe
.
run
(
fetch_list
=
[],
program
=
train_program
)
else
:
else
:
train_exe
.
run
(
fetch_list
=
[]
)
exe
.
run
(
fetch_list
=
[],
program
=
train_compiled_program
)
else
:
else
:
if
use_ngraph
:
if
use_ngraph
:
each_next_acc
,
each_mask_lm_cost
,
each_total_cost
,
np_lr
=
train_
exe
.
run
(
each_next_acc
,
each_mask_lm_cost
,
each_total_cost
,
np_lr
=
exe
.
run
(
fetch_list
=
[
fetch_list
=
[
next_sent_acc
.
name
,
mask_lm_loss
.
name
,
total_loss
.
name
,
next_sent_acc
.
name
,
mask_lm_loss
.
name
,
total_loss
.
name
,
scheduled_lr
.
name
],
program
=
train_program
)
scheduled_lr
.
name
],
program
=
train_program
)
else
:
else
:
each_next_acc
,
each_mask_lm_cost
,
each_total_cost
,
np_lr
=
train_
exe
.
run
(
each_next_acc
,
each_mask_lm_cost
,
each_total_cost
,
np_lr
=
exe
.
run
(
fetch_list
=
[
fetch_list
=
[
next_sent_acc
.
name
,
mask_lm_loss
.
name
,
total_loss
.
name
,
next_sent_acc
.
name
,
mask_lm_loss
.
name
,
total_loss
.
name
,
scheduled_lr
.
name
])
scheduled_lr
.
name
]
,
program
=
train_compiled_program
)
acc
.
extend
(
each_next_acc
)
acc
.
extend
(
each_next_acc
)
lm_cost
.
extend
(
each_mask_lm_cost
)
lm_cost
.
extend
(
each_mask_lm_cost
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录