Unverified commit 62e732cf authored by taixiurong and committed by GitHub

support xpu in bert (#5044)

Parent 16693577
@@ -11,6 +11,7 @@
- Supports BERT GPU single-machine and distributed pre-training
- Supports BERT GPU multi-card Fine-tuning
- Supports BERT XPU single-machine Fine-tuning
- Provides a BERT inference API demo for easy deployment in production environments with diverse hardware
2) Supports FP16/FP32 mixed-precision training and Fine-tuning, reducing GPU memory overhead and accelerating training;
@@ -105,6 +106,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
If pre-training is run on multiple CPU cores, the number of cores to use must be set via an environment variable, e.g. `export CPU_NUM=5`; otherwise all CPU cores will be occupied.
Note in particular that setting `generate_neg_sample` to `True` means that, during pre-training, negative samples for the `Next Sentence Prediction` task are generated dynamically from the positive samples in the training data; the sample training data we provide, [`demo_wiki_train.gz`](data/train/demo_wiki_train.gz), contains only positive samples for the `Next Sentence Prediction` task. If positive and negative samples for `Next Sentence Prediction` have already been constructed in advance, `generate_neg_sample` should be set to `False`.
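For intuition, dynamically generating `Next Sentence Prediction` negatives amounts to occasionally replacing the true next sentence with a randomly drawn one. The following is a minimal illustrative sketch, not the reader implementation in this repo; `sentence_pairs` is a hypothetical list of `(sentence, next_sentence)` tuples:

```python
import random

def generate_nsp_samples(sentence_pairs, neg_ratio=0.5):
    """Yield (sent_a, sent_b, is_next) triples, creating negatives on the fly."""
    candidates = [b for _, b in sentence_pairs]
    for sent_a, sent_b in sentence_pairs:
        if random.random() < neg_ratio:
            # Negative sample: pair sent_a with a random sentence instead of its true successor.
            # A real reader would also make sure the replacement differs from sent_b.
            yield sent_a, random.choice(candidates), 0
        else:
            yield sent_a, sent_b, 1
```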
While pre-training runs, information such as the current learning rate, the number of epochs of training data consumed, the total number of steps so far, the training loss, and the training speed is printed. According to the `--validation_steps ${N}` setting, the model's metrics on the validation set are printed every `N` steps:
@@ -183,6 +185,47 @@ python -u run_classifier.py --task_name ${TASK_NAME} \
--verbose true
```
Taking the XNLI task as an example, XPU Fine-tuning can be launched as follows:
```shell
export FLAGS_sync_nccl_allreduce=0
export FLAGS_eager_delete_tensor_gb=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export FLAGS_selected_xpus=0
export XPUSIM_DEVICE_MODEL=KUNLUN1
export XPU_PADDLE_TRAIN_L3_SIZE=13631488
export XPU_PADDLE_MAIN_STREAM=0
BERT_BASE_PATH="chinese_L-12_H-768_A-12"
TASK_NAME='XNLI'
DATA_PATH=/path/to/xnli/data/
CKPT_PATH=/path/to/save/checkpoints/
python -u run_classifier.py --task_name ${TASK_NAME} \
--use_cuda false \
--use_xpu true \
--do_train true \
--do_val true \
--do_test true \
--batch_size 16 \
--in_tokens false \
--init_pretraining_params ${BERT_BASE_PATH}/params \
--data_dir ${DATA_PATH} \
--vocab_path ${BERT_BASE_PATH}/vocab.txt \
--checkpoints ${CKPT_PATH} \
--save_steps 1000 \
--weight_decay 0.01 \
--warmup_proportion 0.1 \
--validation_steps 100 \
--epoch 3 \
--max_seq_len 128 \
--bert_config_path ${BERT_BASE_PATH}/bert_config.json \
--learning_rate 5e-5 \
--skip_steps 10 \
--num_iteration_per_drop_scope 10 \
--verbose true
```
Here `chinese_L-12_H-768_A-12` is the converted Chinese pre-trained model. Note that BERT on PaddlePaddle supports two ways of constructing a batch of data, and the `in_tokens` parameter changes the meaning of the `batch_size` parameter: if `in_tokens` is `true`, batches are built by token count; otherwise, batches are built by example count (an illustrative sketch of the two modes follows after this hunk). During training, the training loss, training speed, and other information are printed; when training finishes, test results on the validation set are printed as shown below:
```
......
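To make the `in_tokens` distinction described above concrete, here is a minimal sketch of the two batching modes. It is illustrative only, not the repo's reader; `examples` stands for a hypothetical list of token-id sequences:

```python
def batch_by_examples(examples, batch_size):
    # in_tokens == false: batch_size counts examples per batch.
    for i in range(0, len(examples), batch_size):
        yield examples[i:i + batch_size]

def batch_by_tokens(examples, batch_size):
    # in_tokens == true: batch_size caps the total number of tokens per batch.
    batch, num_tokens = [], 0
    for ex in examples:
        if batch and num_tokens + len(ex) > batch_size:
            yield batch
            batch, num_tokens = [], 0
        batch.append(ex)
        num_tokens += len(ex)
    if batch:
        yield batch
```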
@@ -38,7 +38,7 @@ import reader.cls as reader
from model.bert import BertConfig
from model.classifier import create_model
from optimization import optimization
from utils.args import ArgumentGroup, print_arguments, check_cuda, check_xpu, check_version
from utils.init import init_pretraining_params, init_checkpoint
from utils.cards import get_cards
import dist_utils
@@ -101,6 +101,7 @@ run_type_g.add_arg("is_profiler", int, 0, "the profiler
run_type_g.add_arg("max_iter", int, 0, "the max batch nums to train. (used for benchmark)")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
run_type_g.add_arg("use_xpu", bool, True, "If set, use XPU for training.")
run_type_g.add_arg("use_fast_executor", bool, False, "If set, use fast parallel executor (in experiment).")
run_type_g.add_arg("shuffle", bool, True, "")
run_type_g.add_arg("num_iteration_per_drop_scope", int, 1, "The iteration intervals to clean up temporary variables.")
@@ -148,10 +149,17 @@ def get_device_num():
def main(args):
    bert_config = BertConfig(args.bert_config_path)
    bert_config.print_config()

    if args.use_xpu:
        paddle.enable_static()

    if args.use_cuda:
        place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
        dev_count = get_device_num()
    elif args.use_xpu:
        # XPU fine-tuning uses a single device, selected via FLAGS_selected_xpus.
        xpu_id = int(os.getenv('FLAGS_selected_xpus', '0'))
        place = fluid.XPUPlace(xpu_id)
        dev_count = len([place])
    else:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
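For context, here is a minimal self-contained sketch of how a place chosen this way is typically consumed in static-graph mode. It is not part of this diff and only uses the standard `paddle.fluid` APIs already imported by this file:

```python
import os
import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()  # static-graph mode, as run_classifier.py uses

# Pick a device place: XPU if this build supports it, otherwise CPU.
if fluid.is_compiled_with_xpu():
    place = fluid.XPUPlace(int(os.getenv('FLAGS_selected_xpus', '0')))
else:
    place = fluid.CPUPlace()

# Build a trivial static program and run it on the selected device.
x = fluid.data(name='x', shape=[None, 4], dtype='float32')
y = fluid.layers.fc(input=x, size=2)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
out, = exe.run(feed={'x': np.ones((3, 4), dtype='float32')}, fetch_list=[y])
print(out.shape)  # (3, 2)
```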
@@ -311,8 +319,12 @@ def main(args):
            train_data_generator = fluid.contrib.reader.distributed_batch_reader(
                train_data_generator)

        if args.use_xpu:
            # XPU fine-tuning runs the plain Program on a single device.
            train_compiled_program = train_program
        else:
            train_compiled_program = fluid.CompiledProgram(train_program).with_data_parallel(
                loss_name=loss.name, build_strategy=build_strategy)

        train_data_loader.set_batch_generator(train_data_generator, place)
@@ -449,5 +461,6 @@ if __name__ == '__main__':
    paddle.enable_static()
    print_arguments(args)
    check_cuda(args.use_cuda)
    check_xpu(args.use_xpu)
    check_version()
    main(args)
@@ -61,6 +61,16 @@ def check_cuda(use_cuda, err = \
    except Exception as e:
        pass

def check_xpu(use_xpu, err = \
    "\nYou can not set use_xpu = True in the model because you are using paddlepaddle-cpu or paddlepaddle-gpu.\n \
     Please: 1. Install paddlepaddle-xpu to run your models on XPU or 2. Set use_xpu = False to run models on CPU.\n"
    ):
    try:
        if use_xpu == True and fluid.is_compiled_with_xpu() == False:
            print(err)
            sys.exit(1)
    except Exception as e:
        pass

def check_version():
    """
......