diff --git a/BERT/README.md b/BERT/README.md index a6ef69f3a1f00232c5a6926ae06fb10ea8fb93d5..8bdd4fb9a1230c82e75614f9cbe7fc7b0b5642c2 100644 --- a/BERT/README.md +++ b/BERT/README.md @@ -77,6 +77,8 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./train.sh -local y ``` +如果采用 CPU 多核的方式进行预训练,则需要通过环境设置所用 CPU 的核数,例如 `export CPU_NUM=5`,否则会占据所有的CPU。 + 这里需要特别说明的是,参数 `generate_neg_sample` 为 `True` 表示在预训练过程中,`Next Sentence Prediction` 任务的负样本是根据训练数据中的正样本动态生成的,我们给出的样例训练数据 [`demo_wiki_train.gz`](data/train/demo_wiki_train.gz) 只包含 `Next Sentence Prediction` 任务的正样本;如果已事先构造了 `Next Sentence Prediction` 任务的正负样本,则需要将 `generate_neg_sample` 置为 `False`。 预训练任务进行的过程中会输出当前学习率、训练数据所经过的轮数、当前迭代的总步数、训练误差、训练速度等信息,根据 `--validation_steps ${N}` 的配置,每间隔 `N` 步输出模型在验证集的各种指标: @@ -122,8 +124,8 @@ export current_endpoint=192.168.0.17:9185 对于 [GLUE 数据](https://gluebenchmark.com/tasks),请运行这个[脚本](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)予以下载; 对于 XNLI 任务,则需分别下载 [XNLI dev/test set](https://bert-data.bj.bcebos.com/XNLI-1.0.zip) 和 [XNLI machine-translated training set](https://bert-data.bj.bcebos.com/XNLI-MT-1.0.zip),然后解压到同一个目录。以 XNLI 任务为例,启动 Fine-tuning 的方式如下: ```shell -export FLAGS_enable_parallel_graph=1 -export FLAGS_sync_nccl_allreduce=1 +export FLAGS_sync_nccl_allreduce=0 +export FLAGS_eager_delete_tensor_gb=1 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 BERT_BASE_PATH="chinese_L-12_H-768_A-12" @@ -183,8 +185,8 @@ SQuAD v1.1 对于 SQuAD v1.1, 按如下方式启动 Fine-tuning: ```shell -export FLAGS_enable_parallel_graph=1 -export FLAGS_sync_nccl_allreduce=1 +export FLAGS_sync_nccl_allreduce=0 +export FLAGS_eager_delete_tensor_gb=1 export CUDA_VISIBLE_DEVICES=0,1,2,3 BERT_BASE_PATH="uncased_L-12_H-768_A-12" @@ -229,6 +231,8 @@ python ${SQUAD_PATH}/evaluate-v1.1.py ${SQUAD_PATH}/dev-v1.1.json ${CHECKPOINT_P 对于 SQuAD v2.0, 按如下方式启动 Fine-tuning: ```shell +export FLAGS_sync_nccl_allreduce=0 +export FLAGS_eager_delete_tensor_gb=1 export CUDA_VISIBLE_DEVICES=0,1,2,3 BERT_BASE_PATH="uncased_L-12_H-768_A-12" CHECKPOINT_PATH=/path/to/save/checkpoints/ diff --git a/BERT/run_classifier.py b/BERT/run_classifier.py index fd0c592789baaa9c99cd8964ba6eff4b3076dadc..e8583587e64b6d7bf67bdbfaf6150ee1be33502e 100644 --- a/BERT/run_classifier.py +++ b/BERT/run_classifier.py @@ -208,12 +208,6 @@ def main(args): use_fp16=args.use_fp16, loss_scaling=args.loss_scaling) - fluid.memory_optimize( - input_program=train_program, - skip_opt_set=[ - loss.name, probs.name, accuracy.name, num_seqs.name - ]) - if args.verbose: if args.in_tokens: lower_mem, upper_mem, unit = fluid.contrib.memory_usage( @@ -279,22 +273,11 @@ def main(args): train_data_generator = fluid.contrib.reader.distributed_batch_reader( train_data_generator) - train_exe = fluid.ParallelExecutor( - use_cuda=args.use_cuda, - loss_name=loss.name, - exec_strategy=exec_strategy, - build_strategy = build_strategy, - main_program=train_program) + train_compiled_program = fluid.CompiledProgram(train_program).with_data_parallel( + loss_name=loss.name, build_strategy=build_strategy) train_pyreader.decorate_tensor_provider(train_data_generator) - else: - train_exe = None - if args.do_val or args.do_test: - test_exe = fluid.ParallelExecutor( - use_cuda=args.use_cuda, - main_program=test_prog, - share_vars_from=train_exe) if args.do_train: train_pyreader.start() @@ -317,7 +300,7 @@ def main(args): else: fetch_list = [] - outputs = train_exe.run(fetch_list=fetch_list) + outputs = exe.run(train_compiled_program, fetch_list=fetch_list) if steps % args.skip_steps == 0: if warmup_steps <= 0: diff --git a/BERT/run_squad.py b/BERT/run_squad.py index 3d4a23f913a07a46c5b332cf23bd3cd3e97f18df..514b815878cc9862c2b1fa98f69bacebfc176fc0 100644 --- a/BERT/run_squad.py +++ b/BERT/run_squad.py @@ -279,7 +279,6 @@ def train(args): use_fp16=args.use_fp16, loss_scaling=args.loss_scaling) - fluid.memory_optimize(train_program, skip_opt_set=[loss.name, num_seqs.name]) if args.verbose: if args.in_tokens: @@ -301,8 +300,6 @@ def train(args): bert_config=bert_config, is_training=False) - fluid.memory_optimize(test_prog, skip_opt_set=[unique_ids.name, - start_logits.name, end_logits.name, num_seqs.name]) test_prog = test_prog.clone(for_test=True) @@ -341,11 +338,8 @@ def train(args): exec_strategy.num_threads = dev_count exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope - train_exe = fluid.ParallelExecutor( - use_cuda=args.use_cuda, - loss_name=loss.name, - exec_strategy=exec_strategy, - main_program=train_program) + train_compiled_program = fluid.CompiledProgram(train_program).with_data_parallel( + loss_name=loss.name, exec_strategy=exec_strategy) train_pyreader.decorate_tensor_provider(train_data_generator) @@ -366,7 +360,7 @@ def train(args): else: fetch_list = [] - outputs = train_exe.run(fetch_list=fetch_list) + outputs = exe.run(train_compiled_program, fetch_list=fetch_list) if steps % args.skip_steps == 0: if warmup_steps <= 0: