未验证 提交 4e553e2b 编写于 作者: Y Yibing Liu 提交者: GitHub

Use gc, PyReader & compiledprogram for bert (#3035)

上级 80283e6d
...@@ -84,7 +84,7 @@ cd models/PaddleNLP/sentiment_classification ...@@ -84,7 +84,7 @@ cd models/PaddleNLP/sentiment_classification
- [机器翻译](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/neural_machine_translation/transformer) - [机器翻译](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/neural_machine_translation/transformer)
### 语义表示与语言模型 ### 语义表示与语言模型
- [语言表示工具箱](https://github.com/PaddlePaddle/LARK/tree/develop) - [语言表示工具箱](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/language_representations_kit)
- [语言模型](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/language_model) - [语言模型](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/language_model)
### 复杂任务 ### 复杂任务
......
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
| Model | Layers | Hidden size | Heads |Parameters | | Model | Layers | Hidden size | Heads |Parameters |
| :------| :------: | :------: |:------: |:------: | | :------| :------: | :------: |:------: |:------: |
| [BERT-Large, Uncased (Whole Word Masking)](https://bert-models.bj.bcebos.com/wwm_uncased_L-24_H-1024_A-16.tar.gz)| 24 | 1024 | 16 | 340M |
| [BERT-Large, Cased (Whole Word Masking)](https://bert-models.bj.bcebos.com/wwm_cased_L-24_H-1024_A-16.tar.gz)| 24 | 1024 | 16 | 340M |
| [BERT-Base, Uncased](https://bert-models.bj.bcebos.com/uncased_L-12_H-768_A-12.tar.gz) | 12 | 768 |12 |110M | | [BERT-Base, Uncased](https://bert-models.bj.bcebos.com/uncased_L-12_H-768_A-12.tar.gz) | 12 | 768 |12 |110M |
| [BERT-Large, Uncased](https://bert-models.bj.bcebos.com/uncased_L-24_H-1024_A-16.tar.gz) | 24 | 1024 |16 |340M | | [BERT-Large, Uncased](https://bert-models.bj.bcebos.com/uncased_L-24_H-1024_A-16.tar.gz) | 24 | 1024 |16 |340M |
|[BERT-Base, Cased](https://bert-models.bj.bcebos.com/cased_L-12_H-768_A-12.tar.gz)|12|768|12|110M| |[BERT-Base, Cased](https://bert-models.bj.bcebos.com/cased_L-12_H-768_A-12.tar.gz)|12|768|12|110M|
...@@ -46,7 +48,7 @@ ...@@ -46,7 +48,7 @@
- [inference 接口调用示例](#inference-接口调用示例) - [inference 接口调用示例](#inference-接口调用示例)
## 安装 ## 安装
本项目依赖于 Paddle Fluid **1.3.1**,请参考[安装指南](http://www.paddlepaddle.org/#quick-start)进行安装。如果需要进行 TensorFlow 模型到 Paddle Fluid 参数的转换,则需要同时安装 TensorFlow 1.12。 本项目依赖于 Paddle Fluid **1.5.1**,请参考[安装指南](http://www.paddlepaddle.org/#quick-start)进行安装。如果需要进行 TensorFlow 模型到 Paddle Fluid 参数的转换,则需要同时安装 TensorFlow 1.12。
## 预训练 ## 预训练
...@@ -138,22 +140,24 @@ python -u run_classifier.py --task_name ${TASK_NAME} \ ...@@ -138,22 +140,24 @@ python -u run_classifier.py --task_name ${TASK_NAME} \
--do_train true \ --do_train true \
--do_val true \ --do_val true \
--do_test true \ --do_test true \
--batch_size 8192 \ --batch_size 32 \
--in_tokens true \ --in_tokens false \
--init_pretraining_params ${BERT_BASE_PATH}/params \ --init_pretraining_params ${BERT_BASE_PATH}/params \
--data_dir ${DATA_PATH} \ --data_dir ${DATA_PATH} \
--vocab_path ${BERT_BASE_PATH}/vocab.txt \ --vocab_path ${BERT_BASE_PATH}/vocab.txt \
--checkpoints ${CKPT_PATH} \ --checkpoints ${CKPT_PATH} \
--save_steps 1000 \ --save_steps 1000 \
--weight_decay 0.01 \ --weight_decay 0.01 \
--warmup_proportion 0.0 \ --warmup_proportion 0.1 \
--validation_steps 25 \ --validation_steps 100 \
--epoch 3 \ --epoch 3 \
--max_seq_len 512 \ --max_seq_len 128 \
--bert_config_path ${BERT_BASE_PATH}/bert_config.json \ --bert_config_path ${BERT_BASE_PATH}/bert_config.json \
--learning_rate 1e-4 \ --learning_rate 5e-5 \
--skip_steps 10 \ --skip_steps 10 \
--random_seed 1 --num_iteration_per_drop_scope 10 \
--use_fp16 true \
--verbose true
``` ```
这里的 `chinese_L-12_H-768_A-12` 即是转换后的中文预训练模型。需要注意的是,BERT on PaddlePaddle 支持按两种方式构建一个 batch 的数据,`in_tokens` 参数影响 `batch_size` 参数的意义,如果 `in_tokens``true` 则按照 token 个数构建 batch, 如不设定则按照 example 个数来构建 batch. 训练过程中会输出训练误差、训练速度等信息,训练结束后会输出如下所示的在验证集上的测试结果: 这里的 `chinese_L-12_H-768_A-12` 即是转换后的中文预训练模型。需要注意的是,BERT on PaddlePaddle 支持按两种方式构建一个 batch 的数据,`in_tokens` 参数影响 `batch_size` 参数的意义,如果 `in_tokens``true` 则按照 token 个数构建 batch, 如不设定则按照 example 个数来构建 batch. 训练过程中会输出训练误差、训练速度等信息,训练结束后会输出如下所示的在验证集上的测试结果:
......
...@@ -22,22 +22,27 @@ import paddle.fluid as fluid ...@@ -22,22 +22,27 @@ import paddle.fluid as fluid
from model.bert import BertModel from model.bert import BertModel
def create_model(args, def create_model(args, bert_config, num_labels, is_prediction=False):
pyreader_name, input_fields = {
bert_config, 'names': ['src_ids', 'pos_ids', 'sent_ids', 'input_mask', 'labels'],
num_labels, 'shapes':
is_prediction=False): [[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
pyreader = fluid.layers.py_reader( [-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1], [-1, 1]],
capacity=50, 'dtypes': ['int64', 'int64', 'int64', 'float32', 'int64'],
shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1], 'lod_levels': [0, 0, 0, 0, 0],
[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1], [-1, 1]], }
dtypes=['int64', 'int64', 'int64', 'float32', 'int64'],
lod_levels=[0, 0, 0, 0, 0],
name=pyreader_name,
use_double_buffer=True)
(src_ids, pos_ids, sent_ids, input_mask, inputs = [
labels) = fluid.layers.read_file(pyreader) fluid.layers.data(
name=input_fields['names'][i],
shape=input_fields['shapes'][i],
dtype=input_fields['dtypes'][i],
lod_level=input_fields['lod_levels'][i])
for i in range(len(input_fields['names']))
]
(src_ids, pos_ids, sent_ids, input_mask, labels) = inputs
pyreader = fluid.io.PyReader(feed_list=inputs, capacity=50, iterable=False)
bert = BertModel( bert = BertModel(
src_ids=src_ids, src_ids=src_ids,
......
...@@ -84,7 +84,6 @@ def main(args): ...@@ -84,7 +84,6 @@ def main(args):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
predict_pyreader, probs, feed_target_names = create_model( predict_pyreader, probs, feed_target_names = create_model(
args, args,
pyreader_name='predict_reader',
bert_config=bert_config, bert_config=bert_config,
num_labels=num_labels, num_labels=num_labels,
is_prediction=True) is_prediction=True)
...@@ -103,7 +102,7 @@ def main(args): ...@@ -103,7 +102,7 @@ def main(args):
exe.run(predict_startup) exe.run(predict_startup)
if args.init_checkpoint: if args.init_checkpoint:
init_pretraining_params(exe, args.init_checkpoint, predict_prog) init_pretraining_params(exe, args.init_checkpoint, predict_prog, args.use_fp16)
else: else:
raise ValueError("args 'init_checkpoint' should be set for prediction!") raise ValueError("args 'init_checkpoint' should be set for prediction!")
...@@ -113,7 +112,7 @@ def main(args): ...@@ -113,7 +112,7 @@ def main(args):
predict_exe = fluid.ParallelExecutor( predict_exe = fluid.ParallelExecutor(
use_cuda=args.use_cuda, main_program=predict_prog) use_cuda=args.use_cuda, main_program=predict_prog)
predict_pyreader.decorate_tensor_provider( predict_pyreader.decorate_batch_generator(
processor.data_generator( processor.data_generator(
batch_size=args.batch_size, phase='test', epoch=1, shuffle=False)) batch_size=args.batch_size, phase='test', epoch=1, shuffle=False))
......
...@@ -193,7 +193,6 @@ def main(args): ...@@ -193,7 +193,6 @@ def main(args):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
train_pyreader, loss, probs, accuracy, num_seqs = create_model( train_pyreader, loss, probs, accuracy, num_seqs = create_model(
args, args,
pyreader_name='train_reader',
bert_config=bert_config, bert_config=bert_config,
num_labels=num_labels) num_labels=num_labels)
scheduled_lr = optimization( scheduled_lr = optimization(
...@@ -219,17 +218,41 @@ def main(args): ...@@ -219,17 +218,41 @@ def main(args):
print("Theoretical memory usage in training: %.3f - %.3f %s" % print("Theoretical memory usage in training: %.3f - %.3f %s" %
(lower_mem, upper_mem, unit)) (lower_mem, upper_mem, unit))
if args.do_val or args.do_test: if args.do_val:
dev_prog = fluid.Program()
with fluid.program_guard(dev_prog, startup_prog):
with fluid.unique_name.guard():
dev_pyreader, loss, probs, accuracy, num_seqs = create_model(
args,
bert_config=bert_config,
num_labels=num_labels)
dev_prog = dev_prog.clone(for_test=True)
dev_pyreader.decorate_batch_generator(
processor.data_generator(
batch_size=args.batch_size,
phase='dev',
epoch=1,
dev_count=1,
shuffle=False), place)
if args.do_test:
test_prog = fluid.Program() test_prog = fluid.Program()
with fluid.program_guard(test_prog, startup_prog): with fluid.program_guard(test_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
test_pyreader, loss, probs, accuracy, num_seqs = create_model( test_pyreader, loss, probs, accuracy, num_seqs = create_model(
args, args,
pyreader_name='test_reader',
bert_config=bert_config, bert_config=bert_config,
num_labels=num_labels) num_labels=num_labels)
test_prog = test_prog.clone(for_test=True) test_prog = test_prog.clone(for_test=True)
test_pyreader.decorate_batch_generator(
processor.data_generator(
batch_size=args.batch_size,
phase='test',
epoch=1,
dev_count=1,
shuffle=False), place)
exe.run(startup_prog) exe.run(startup_prog)
...@@ -276,7 +299,7 @@ def main(args): ...@@ -276,7 +299,7 @@ def main(args):
train_compiled_program = fluid.CompiledProgram(train_program).with_data_parallel( train_compiled_program = fluid.CompiledProgram(train_program).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy) loss_name=loss.name, build_strategy=build_strategy)
train_pyreader.decorate_tensor_provider(train_data_generator) train_pyreader.decorate_batch_generator(train_data_generator, place)
if args.do_train: if args.do_train:
...@@ -350,25 +373,11 @@ def main(args): ...@@ -350,25 +373,11 @@ def main(args):
throughput = [] throughput = []
# evaluate dev set # evaluate dev set
if args.do_val: if args.do_val:
test_pyreader.decorate_tensor_provider( evaluate(exe, dev_prog, dev_pyreader,
processor.data_generator(
batch_size=args.batch_size,
phase='dev',
epoch=1,
dev_count=1,
shuffle=False))
evaluate(exe, test_prog, test_pyreader,
[loss.name, accuracy.name, num_seqs.name], [loss.name, accuracy.name, num_seqs.name],
"dev") "dev")
# evaluate test set # evaluate test set
if args.do_test: if args.do_test:
test_pyreader.decorate_tensor_provider(
processor.data_generator(
batch_size=args.batch_size,
phase='test',
epoch=1,
dev_count=1,
shuffle=False))
evaluate(exe, test_prog, test_pyreader, evaluate(exe, test_prog, test_pyreader,
[loss.name, accuracy.name, num_seqs.name], [loss.name, accuracy.name, num_seqs.name],
"test") "test")
...@@ -398,23 +407,12 @@ def main(args): ...@@ -398,23 +407,12 @@ def main(args):
# final eval on dev set # final eval on dev set
if args.do_val: if args.do_val:
test_pyreader.decorate_tensor_provider(
processor.data_generator(
batch_size=args.batch_size, phase='dev', epoch=1, dev_count=1,
shuffle=False))
print("Final validation result:") print("Final validation result:")
evaluate(exe, test_prog, test_pyreader, evaluate(exe, dev_prog, dev_pyreader,
[loss.name, accuracy.name, num_seqs.name], "dev") [loss.name, accuracy.name, num_seqs.name], "dev")
# final eval on test set # final eval on test set
if args.do_test: if args.do_test:
test_pyreader.decorate_tensor_provider(
processor.data_generator(
batch_size=args.batch_size,
phase='test',
epoch=1,
dev_count=1,
shuffle=False))
print("Final test result:") print("Final test result:")
evaluate(exe, test_prog, test_pyreader, evaluate(exe, test_prog, test_pyreader,
[loss.name, accuracy.name, num_seqs.name], "test") [loss.name, accuracy.name, num_seqs.name], "test")
......
...@@ -92,31 +92,39 @@ run_type_g.add_arg("do_predict", bool, True, "Whether to pe ...@@ -92,31 +92,39 @@ run_type_g.add_arg("do_predict", bool, True, "Whether to pe
args = parser.parse_args() args = parser.parse_args()
# yapf: enable. # yapf: enable.
def create_model(pyreader_name, bert_config, is_training=False): def create_model(bert_config, is_training=False):
if is_training: if is_training:
pyreader = fluid.layers.py_reader( input_fields = {
capacity=50, 'names': ['src_ids', 'pos_ids', 'sent_ids', 'input_mask', 'start_positions', 'end_positions'],
shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1], 'shapes': [[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
[-1, args.max_seq_len, 1], [-1, 1], [-1, 1]], [-1, args.max_seq_len, 1], [-1, 1], [-1, 1]],
dtypes=[ 'dtypes': [
'int64', 'int64', 'int64', 'float32', 'int64', 'int64'], 'int64', 'int64', 'int64', 'float32', 'int64', 'int64'],
lod_levels=[0, 0, 0, 0, 0, 0], 'lod_levels': [0, 0, 0, 0, 0, 0],
name=pyreader_name, }
use_double_buffer=True)
(src_ids, pos_ids, sent_ids, input_mask, start_positions,
end_positions) = fluid.layers.read_file(pyreader)
else: else:
pyreader = fluid.layers.py_reader( input_fields = {
capacity=50, 'names': ['src_ids', 'pos_ids', 'sent_ids', 'input_mask', 'unique_id'],
shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1], 'shapes': [[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
[-1, args.max_seq_len, 1], [-1, 1]], [-1, args.max_seq_len, 1], [-1, 1]],
dtypes=['int64', 'int64', 'int64', 'float32', 'int64'], 'dtypes': [
lod_levels=[0, 0, 0, 0, 0], 'int64', 'int64', 'int64', 'float32', 'int64'],
name=pyreader_name, 'lod_levels': [0, 0, 0, 0, 0],
use_double_buffer=True) }
(src_ids, pos_ids, sent_ids, input_mask, unique_id) = fluid.layers.read_file(pyreader)
inputs = [fluid.layers.data(name=input_fields['names'][i],
shape=input_fields['shapes'][i],
dtype=input_fields['dtypes'][i],
lod_level=input_fields['lod_levels'][i]) for i in range(len(input_fields['names']))]
pyreader = fluid.io.PyReader(feed_list=inputs, capacity=50, iterable=False)
if is_training:
(src_ids, pos_ids, sent_ids, input_mask, start_positions, end_positions) = inputs
else:
(src_ids, pos_ids, sent_ids, input_mask, unique_id) = inputs
bert = BertModel( bert = BertModel(
src_ids=src_ids, src_ids=src_ids,
...@@ -263,7 +271,6 @@ def train(args): ...@@ -263,7 +271,6 @@ def train(args):
with fluid.program_guard(train_program, startup_prog): with fluid.program_guard(train_program, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
train_pyreader, loss, num_seqs = create_model( train_pyreader, loss, num_seqs = create_model(
pyreader_name='train_reader',
bert_config=bert_config, bert_config=bert_config,
is_training=True) is_training=True)
...@@ -296,7 +303,6 @@ def train(args): ...@@ -296,7 +303,6 @@ def train(args):
with fluid.program_guard(test_prog, startup_prog): with fluid.program_guard(test_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
test_pyreader, unique_ids, start_logits, end_logits, num_seqs = create_model( test_pyreader, unique_ids, start_logits, end_logits, num_seqs = create_model(
pyreader_name='test_reader',
bert_config=bert_config, bert_config=bert_config,
is_training=False) is_training=False)
...@@ -341,7 +347,7 @@ def train(args): ...@@ -341,7 +347,7 @@ def train(args):
train_compiled_program = fluid.CompiledProgram(train_program).with_data_parallel( train_compiled_program = fluid.CompiledProgram(train_program).with_data_parallel(
loss_name=loss.name, exec_strategy=exec_strategy) loss_name=loss.name, exec_strategy=exec_strategy)
train_pyreader.decorate_tensor_provider(train_data_generator) train_pyreader.decorate_batch_generator(train_data_generator, place)
train_pyreader.start() train_pyreader.start()
steps = 0 steps = 0
...@@ -402,14 +408,14 @@ def train(args): ...@@ -402,14 +408,14 @@ def train(args):
break break
if args.do_predict: if args.do_predict:
test_pyreader.decorate_tensor_provider( test_pyreader.decorate_batch_generator(
processor.data_generator( processor.data_generator(
data_path=args.predict_file, data_path=args.predict_file,
batch_size=args.batch_size, batch_size=args.batch_size,
phase='predict', phase='predict',
shuffle=False, shuffle=False,
dev_count=1, dev_count=1,
epoch=1)) epoch=1), place)
predict(exe, test_prog, test_pyreader, [ predict(exe, test_prog, test_pyreader, [
unique_ids.name, start_logits.name, end_logits.name, num_seqs.name unique_ids.name, start_logits.name, end_logits.name, num_seqs.name
......
...@@ -82,21 +82,24 @@ args = parser.parse_args() ...@@ -82,21 +82,24 @@ args = parser.parse_args()
# yapf: enable. # yapf: enable.
def create_model(pyreader_name, bert_config): def create_model(bert_config):
pyreader = fluid.layers.py_reader( input_fields = {
capacity=70, 'names': ['src_ids', 'pos_ids', 'sent_ids', 'input_mask', 'mask_label', 'mask_pos', 'labels'],
shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1], 'shapes': [[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
[-1, args.max_seq_len, 1], [-1, 1], [-1, 1], [-1, args.max_seq_len, 1], [-1, 1], [-1, 1], [-1, 1]],
[-1, 1]], 'dtypes': ['int64', 'int64', 'int64', 'float32', 'int64', 'int64', 'int64'],
dtypes=[ 'lod_levels': [0, 0, 0, 0, 0, 0, 0],
'int64', 'int64', 'int64', 'float32', 'int64', 'int64', 'int64' }
],
lod_levels=[0, 0, 0, 0, 0, 0, 0],
name=pyreader_name,
use_double_buffer=True)
(src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels) = fluid.layers.read_file(pyreader) inputs = [fluid.layers.data(name=input_fields['names'][i],
shape=input_fields['shapes'][i],
dtype=input_fields['dtypes'][i],
lod_level=input_fields['lod_levels'][i]) for i in range(len(input_fields['names']))]
(src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels) = inputs
pyreader = fluid.io.PyReader(feed_list=inputs, capacity=50, iterable=False)
bert = BertModel( bert = BertModel(
src_ids=src_ids, src_ids=src_ids,
...@@ -143,7 +146,7 @@ def predict_wrapper(args, ...@@ -143,7 +146,7 @@ def predict_wrapper(args,
def predict(exe=exe, pyreader=pyreader): def predict(exe=exe, pyreader=pyreader):
pyreader.decorate_tensor_provider(data_reader.data_generator()) pyreader.decorate_batch_generator(data_reader.data_generator())
pyreader.start() pyreader.start()
cost = 0 cost = 0
...@@ -181,7 +184,7 @@ def test(args): ...@@ -181,7 +184,7 @@ def test(args):
with fluid.program_guard(test_prog, test_startup): with fluid.program_guard(test_prog, test_startup):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
test_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model( test_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model(
pyreader_name='test_reader', bert_config=bert_config) bert_config=bert_config)
test_prog = test_prog.clone(for_test=True) test_prog = test_prog.clone(for_test=True)
...@@ -216,7 +219,7 @@ def train(args): ...@@ -216,7 +219,7 @@ def train(args):
with fluid.program_guard(train_program, startup_prog): with fluid.program_guard(train_program, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
train_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model( train_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model(
pyreader_name='train_reader', bert_config=bert_config) bert_config=bert_config)
scheduled_lr = optimization( scheduled_lr = optimization(
loss=total_loss, loss=total_loss,
warmup_steps=args.warmup_steps, warmup_steps=args.warmup_steps,
...@@ -229,17 +232,11 @@ def train(args): ...@@ -229,17 +232,11 @@ def train(args):
use_fp16=args.use_fp16, use_fp16=args.use_fp16,
loss_scaling=args.loss_scaling) loss_scaling=args.loss_scaling)
fluid.memory_optimize(
input_program=train_program,
skip_opt_set=[
next_sent_acc.name, mask_lm_loss.name, total_loss.name
])
test_prog = fluid.Program() test_prog = fluid.Program()
with fluid.program_guard(test_prog, startup_prog): with fluid.program_guard(test_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
test_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model( test_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model(
pyreader_name='test_reader', bert_config=bert_config) bert_config=bert_config)
test_prog = test_prog.clone(for_test=True) test_prog = test_prog.clone(for_test=True)
...@@ -313,18 +310,16 @@ def train(args): ...@@ -313,18 +310,16 @@ def train(args):
exec_strategy.num_threads = dev_count exec_strategy.num_threads = dev_count
exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope
build_strategy = fluid.BuildStrategy()
build_strategy.num_trainers = nccl2_num_trainers
build_strategy.trainer_id = nccl2_trainer_id
# use_ngraph is for CPU only, please refer to README_ngraph.md for details # use_ngraph is for CPU only, please refer to README_ngraph.md for details
use_ngraph = os.getenv('FLAGS_use_ngraph') use_ngraph = os.getenv('FLAGS_use_ngraph')
if not use_ngraph: if not use_ngraph:
train_exe = fluid.ParallelExecutor( train_compiled_program = fluid.CompiledProgram(train_program).with_data_parallel(
use_cuda=args.use_cuda, loss_name=total_loss.name,
loss_name=total_loss.name, exec_strategy=exec_strategy,
exec_strategy=exec_strategy, build_strategy=build_strategy)
main_program=train_program,
num_trainers=nccl2_num_trainers,
trainer_id=nccl2_trainer_id)
else:
train_exe = exe
if args.validation_set_dir and args.validation_set_dir != "": if args.validation_set_dir and args.validation_set_dir != "":
predict = predict_wrapper( predict = predict_wrapper(
...@@ -337,7 +332,7 @@ def train(args): ...@@ -337,7 +332,7 @@ def train(args):
next_sent_acc.name, mask_lm_loss.name, total_loss.name next_sent_acc.name, mask_lm_loss.name, total_loss.name
]) ])
train_pyreader.decorate_tensor_provider(data_reader.data_generator()) train_pyreader.decorate_batch_generator(data_reader.data_generator())
train_pyreader.start() train_pyreader.start()
steps = 0 steps = 0
cost = [] cost = []
...@@ -351,28 +346,28 @@ def train(args): ...@@ -351,28 +346,28 @@ def train(args):
if nccl2_trainer_id != 0: if nccl2_trainer_id != 0:
if use_ngraph: if use_ngraph:
train_exe.run(fetch_list=[], program=train_program) exe.run(fetch_list=[], program=train_program)
else: else:
train_exe.run(fetch_list=[]) exe.run(fetch_list=[], program=train_compiled_program)
continue continue
if steps % skip_steps != 0: if steps % skip_steps != 0:
if use_ngraph: if use_ngraph:
train_exe.run(fetch_list=[], program=train_program) exe.run(fetch_list=[], program=train_program)
else: else:
train_exe.run(fetch_list=[]) exe.run(fetch_list=[], program=train_compiled_program)
else: else:
if use_ngraph: if use_ngraph:
each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = train_exe.run( each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = exe.run(
fetch_list=[ fetch_list=[
next_sent_acc.name, mask_lm_loss.name, total_loss.name, next_sent_acc.name, mask_lm_loss.name, total_loss.name,
scheduled_lr.name], program=train_program) scheduled_lr.name], program=train_program)
else: else:
each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = train_exe.run( each_next_acc, each_mask_lm_cost, each_total_cost, np_lr = exe.run(
fetch_list=[ fetch_list=[
next_sent_acc.name, mask_lm_loss.name, total_loss.name, next_sent_acc.name, mask_lm_loss.name, total_loss.name,
scheduled_lr.name]) scheduled_lr.name], program=train_compiled_program)
acc.extend(each_next_acc) acc.extend(each_next_acc)
lm_cost.extend(each_mask_lm_cost) lm_cost.extend(each_mask_lm_cost)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册