From cdb8e50ac32fe1e118a4678671496a51511d4cb4 Mon Sep 17 00:00:00 2001
From: Guo Sheng
Date: Thu, 4 Feb 2021 12:37:19 +0800
Subject: [PATCH] Add amp support for BERT. (#5198)

---
 .../examples/language_model/bert/README.md    | 12 ++++--
 .../examples/language_model/bert/run_glue.py  | 27 ++++++++++--
 .../language_model/bert/run_pretrain.py       | 42 +++++++++++++------
 3 files changed, 61 insertions(+), 20 deletions(-)

diff --git a/PaddleNLP/examples/language_model/bert/README.md b/PaddleNLP/examples/language_model/bert/README.md
index 895552a6..ed15dffe 100644
--- a/PaddleNLP/examples/language_model/bert/README.md
+++ b/PaddleNLP/examples/language_model/bert/README.md
@@ -74,7 +74,8 @@ python -u ./run_pretrain.py \
     --logging_steps 1 \
     --save_steps 20000 \
     --max_steps 1000000 \
-    --n_cards 1
+    --n_cards 1 \
+    --use_amp False
 ```
 
 其中参数释义如下:
@@ -92,7 +93,8 @@ python -u ./run_pretrain.py \
 - `logging_steps` 表示日志打印间隔。
 - `save_steps` 表示模型保存及评估间隔。
 - `max_steps` 表示最大训练步数。若训练`num_train_epochs`轮包含的训练步数大于该值,则达到`max_steps`后就提前结束。
-- `n_gpu` 表示使用的 GPU 卡数。若希望使用多卡训练,将其设置为指定数目即可;若为0,则使用CPU。
+- `n_cards` 表示使用的 GPU 卡数。若希望使用多卡训练,将其设置为指定数目即可。
+- `use_amp` 指示是否启用自动混合精度训练。
 
 ### 执行Fine-tunning
 
@@ -110,7 +112,8 @@ python -u ./run_glue.py \
     --logging_steps 1 \
     --save_steps 500 \
     --output_dir ./tmp/ \
-    --n_cards 1
+    --n_cards 1 \
+    --use_amp False
 ```
 
 其中参数释义如下:
@@ -124,7 +127,8 @@ python -u ./run_glue.py \
 - `logging_steps` 表示日志打印间隔。
 - `save_steps` 表示模型保存及评估间隔。
 - `output_dir` 表示模型保存路径。
-- `n_gpu` 表示使用的 GPU 卡数。若希望使用多卡训练,将其设置为指定数目即可;若为0,则使用CPU。
+- `n_cards` 表示使用的 GPU 卡数。若希望使用多卡训练,将其设置为指定数目即可。
+- `use_amp` 指示是否启用自动混合精度训练。
 
 基于`bert-base-uncased`在GLUE各评测任务上Fine-tuning后,在验证集上有如下结果:
 
diff --git a/PaddleNLP/examples/language_model/bert/run_glue.py b/PaddleNLP/examples/language_model/bert/run_glue.py
index b0712c63..f0d36bc5 100644
--- a/PaddleNLP/examples/language_model/bert/run_glue.py
+++ b/PaddleNLP/examples/language_model/bert/run_glue.py
@@ -19,6 +19,7 @@ import sys
 import random
 import time
 import math
+import distutils.util
 from functools import partial
 
 import numpy as np
@@ -161,6 +162,14 @@ def parse_args():
         type=str,
         default="gpu",
         help="Device for selecting for the training.")
+    parser.add_argument("--use_amp",
+                        type=distutils.util.strtobool,
+                        default=False,
+                        help="Enable mixed precision training.")
+    parser.add_argument("--scale_loss",
+                        type=float,
+                        default=2**15,
+                        help="The value of scale_loss for fp16.")
     args = parser.parse_args()
     return args
 
@@ -380,16 +389,26 @@ def do_train(args):
 
     metric = metric_class()
 
+    if args.use_amp:
+        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
+
     global_step = 0
     tic_train = time.time()
     for epoch in range(args.num_train_epochs):
         for step, batch in enumerate(train_data_loader):
             global_step += 1
             input_ids, segment_ids, labels = batch
-            logits = model(input_ids, segment_ids)
-            loss = loss_fct(logits, labels)
-            loss.backward()
-            optimizer.step()
+            with paddle.amp.auto_cast(
+                    args.use_amp,
+                    custom_white_list=["layer_norm", "softmax", "gelu"]):
+                logits = model(input_ids, segment_ids)
+                loss = loss_fct(logits, labels)
+            if args.use_amp:
+                scaler.scale(loss).backward()
+                scaler.minimize(optimizer, loss)
+            else:
+                loss.backward()
+                optimizer.step()
             lr_scheduler.step()
             optimizer.clear_gradients()
             if global_step % args.logging_steps == 0:
diff --git a/PaddleNLP/examples/language_model/bert/run_pretrain.py b/PaddleNLP/examples/language_model/bert/run_pretrain.py
index 6327b2e5..57978bff 100644
--- a/PaddleNLP/examples/language_model/bert/run_pretrain.py
+++ b/PaddleNLP/examples/language_model/bert/run_pretrain.py
@@ -20,6 +20,7 @@ import os
 import random
 import time
 import h5py
+import distutils.util
 from functools import partial
 from concurrent.futures import ThreadPoolExecutor
 
@@ -146,6 +147,14 @@ def parse_args():
         type=str,
         default="gpu",
         help="Device for selecting for the training.")
+    parser.add_argument("--use_amp",
+                        type=distutils.util.strtobool,
+                        default=False,
+                        help="Enable mixed precision training.")
+    parser.add_argument("--scale_loss",
+                        type=float,
+                        default=2**15,
+                        help="The value of scale_loss for fp16.")
     args = parser.parse_args()
     return args
 
@@ -313,6 +322,8 @@ def do_train(args):
         p.name for n, p in model.named_parameters()
         if not any(nd in n for nd in ["bias", "norm"])
     ])
+    if args.use_amp:
+        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
 
     pool = ThreadPoolExecutor(1)
     global_step = 0
@@ -370,14 +381,25 @@ def do_train(args):
                 (input_ids, segment_ids, input_mask, masked_lm_positions,
                  masked_lm_labels, next_sentence_labels,
                  masked_lm_scale) = batch
-                prediction_scores, seq_relationship_score = model(
-                    input_ids=input_ids,
-                    token_type_ids=segment_ids,
-                    attention_mask=input_mask,
-                    masked_positions=masked_lm_positions)
-                loss = criterion(prediction_scores, seq_relationship_score,
-                                 masked_lm_labels, next_sentence_labels,
-                                 masked_lm_scale)
+                with paddle.amp.auto_cast(
+                        args.use_amp,
+                        custom_white_list=["layer_norm", "softmax", "gelu"]):
+                    prediction_scores, seq_relationship_score = model(
+                        input_ids=input_ids,
+                        token_type_ids=segment_ids,
+                        attention_mask=input_mask,
+                        masked_positions=masked_lm_positions)
+                    loss = criterion(prediction_scores, seq_relationship_score,
+                                     masked_lm_labels, next_sentence_labels,
+                                     masked_lm_scale)
+                if args.use_amp:
+                    scaler.scale(loss).backward()
+                    scaler.minimize(optimizer, loss)
+                else:
+                    loss.backward()
+                    optimizer.step()
+                lr_scheduler.step()
+                optimizer.clear_gradients()
                 if global_step % args.logging_steps == 0:
                     if (not args.n_cards > 1
                         ) or paddle.distributed.get_rank() == 0:
@@ -386,10 +408,6 @@
                             % (global_step, epoch, step, loss,
                                args.logging_steps / (time.time() - tic_train)))
                     tic_train = time.time()
-                loss.backward()
-                optimizer.step()
-                lr_scheduler.step()
-                optimizer.clear_gradients()
                 if global_step % args.save_steps == 0:
                     if (not args.n_cards > 1
                         ) or paddle.distributed.get_rank() == 0:
-- 
GitLab
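
Note: below is a minimal, self-contained sketch of the AMP training pattern this patch wires into `run_glue.py` and `run_pretrain.py` (a `paddle.amp.GradScaler` plus an `auto_cast` context around the forward pass, controlled by `--use_amp` and `--scale_loss`). The toy `Linear` model, random data, and optimizer settings are illustrative placeholders, not code from the repository, and float16 kernels are only actually exercised when running on a GPU place.

```python
import paddle

# Placeholder model, loss, and data standing in for the BERT setup (hypothetical).
model = paddle.nn.Linear(16, 2)
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
loss_fct = paddle.nn.CrossEntropyLoss()

# Mirrors the --scale_loss default (2**15): initial multiplier that keeps fp16
# gradients from underflowing.
scaler = paddle.amp.GradScaler(init_loss_scaling=2**15)

x = paddle.randn([4, 16])
labels = paddle.randint(0, 2, [4])

for _ in range(3):
    # White-listed ops (layer_norm/softmax/gelu in the patch) may run in float16
    # inside auto_cast; other ops stay in float32.
    with paddle.amp.auto_cast(True,
                              custom_white_list=["layer_norm", "softmax", "gelu"]):
        logits = model(x)
        loss = loss_fct(logits, labels)
    scaler.scale(loss).backward()     # backprop the scaled loss
    scaler.minimize(optimizer, loss)  # unscale gradients and apply the update
    optimizer.clear_grad()
```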