未验证 提交 cdb8e50a 编写于 作者: G Guo Sheng 提交者: GitHub

Add amp support for BERT. (#5198)

上级 98683b59
...@@ -74,7 +74,8 @@ python -u ./run_pretrain.py \ ...@@ -74,7 +74,8 @@ python -u ./run_pretrain.py \
--logging_steps 1 \ --logging_steps 1 \
--save_steps 20000 \ --save_steps 20000 \
--max_steps 1000000 \ --max_steps 1000000 \
--n_cards 1 --n_cards 1 \
--use_amp False
``` ```
其中参数释义如下: 其中参数释义如下:
...@@ -92,7 +93,8 @@ python -u ./run_pretrain.py \ ...@@ -92,7 +93,8 @@ python -u ./run_pretrain.py \
- `logging_steps` 表示日志打印间隔。 - `logging_steps` 表示日志打印间隔。
- `save_steps` 表示模型保存及评估间隔。 - `save_steps` 表示模型保存及评估间隔。
- `max_steps` 表示最大训练步数。若训练`num_train_epochs`轮包含的训练步数大于该值,则达到`max_steps`后就提前结束。 - `max_steps` 表示最大训练步数。若训练`num_train_epochs`轮包含的训练步数大于该值,则达到`max_steps`后就提前结束。
- `n_gpu` 表示使用的 GPU 卡数。若希望使用多卡训练,将其设置为指定数目即可;若为0,则使用CPU。 - `n_cards` 表示使用的 GPU 卡数。若希望使用多卡训练,将其设置为指定数目即可。
- `use_amp` 指示是否启用自动混合精度训练。
### 执行Fine-tunning ### 执行Fine-tunning
...@@ -110,7 +112,8 @@ python -u ./run_glue.py \ ...@@ -110,7 +112,8 @@ python -u ./run_glue.py \
--logging_steps 1 \ --logging_steps 1 \
--save_steps 500 \ --save_steps 500 \
--output_dir ./tmp/ \ --output_dir ./tmp/ \
--n_cards 1 --n_cards 1 \
--use_amp False
``` ```
其中参数释义如下: 其中参数释义如下:
...@@ -124,7 +127,8 @@ python -u ./run_glue.py \ ...@@ -124,7 +127,8 @@ python -u ./run_glue.py \
- `logging_steps` 表示日志打印间隔。 - `logging_steps` 表示日志打印间隔。
- `save_steps` 表示模型保存及评估间隔。 - `save_steps` 表示模型保存及评估间隔。
- `output_dir` 表示模型保存路径。 - `output_dir` 表示模型保存路径。
- `n_gpu` 表示使用的 GPU 卡数。若希望使用多卡训练,将其设置为指定数目即可;若为0,则使用CPU。 - `n_cards` 表示使用的 GPU 卡数。若希望使用多卡训练,将其设置为指定数目即可。
- `use_amp` 指示是否启用自动混合精度训练。
基于`bert-base-uncased`在GLUE各评测任务上Fine-tuning后,在验证集上有如下结果: 基于`bert-base-uncased`在GLUE各评测任务上Fine-tuning后,在验证集上有如下结果:
......
...@@ -19,6 +19,7 @@ import sys ...@@ -19,6 +19,7 @@ import sys
import random import random
import time import time
import math import math
import distutils.util
from functools import partial from functools import partial
import numpy as np import numpy as np
...@@ -161,6 +162,14 @@ def parse_args(): ...@@ -161,6 +162,14 @@ def parse_args():
type=str, type=str,
default="gpu", default="gpu",
help="Device for selecting for the training.") help="Device for selecting for the training.")
parser.add_argument("--use_amp",
type=distutils.util.strtobool,
default=False,
help="Enable mixed precision training.")
parser.add_argument("--scale_loss",
type=float,
default=2**15,
help="The value of scale_loss for fp16.")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -380,14 +389,24 @@ def do_train(args): ...@@ -380,14 +389,24 @@ def do_train(args):
metric = metric_class() metric = metric_class()
if args.use_amp:
scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
global_step = 0 global_step = 0
tic_train = time.time() tic_train = time.time()
for epoch in range(args.num_train_epochs): for epoch in range(args.num_train_epochs):
for step, batch in enumerate(train_data_loader): for step, batch in enumerate(train_data_loader):
global_step += 1 global_step += 1
input_ids, segment_ids, labels = batch input_ids, segment_ids, labels = batch
with paddle.amp.auto_cast(
args.use_amp,
custom_white_list=["layer_norm", "softmax", "gelu"]):
logits = model(input_ids, segment_ids) logits = model(input_ids, segment_ids)
loss = loss_fct(logits, labels) loss = loss_fct(logits, labels)
if args.use_amp:
scaler.scale(loss).backward()
scaler.minimize(optimizer, loss)
else:
loss.backward() loss.backward()
optimizer.step() optimizer.step()
lr_scheduler.step() lr_scheduler.step()
......
...@@ -20,6 +20,7 @@ import os ...@@ -20,6 +20,7 @@ import os
import random import random
import time import time
import h5py import h5py
import distutils.util
from functools import partial from functools import partial
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
...@@ -146,6 +147,14 @@ def parse_args(): ...@@ -146,6 +147,14 @@ def parse_args():
type=str, type=str,
default="gpu", default="gpu",
help="Device for selecting for the training.") help="Device for selecting for the training.")
parser.add_argument("--use_amp",
type=distutils.util.strtobool,
default=False,
help="Enable mixed precision training.")
parser.add_argument("--scale_loss",
type=float,
default=2**15,
help="The value of scale_loss for fp16.")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -313,6 +322,8 @@ def do_train(args): ...@@ -313,6 +322,8 @@ def do_train(args):
p.name for n, p in model.named_parameters() p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"]) if not any(nd in n for nd in ["bias", "norm"])
]) ])
if args.use_amp:
scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
pool = ThreadPoolExecutor(1) pool = ThreadPoolExecutor(1)
global_step = 0 global_step = 0
...@@ -370,6 +381,9 @@ def do_train(args): ...@@ -370,6 +381,9 @@ def do_train(args):
(input_ids, segment_ids, input_mask, masked_lm_positions, (input_ids, segment_ids, input_mask, masked_lm_positions,
masked_lm_labels, next_sentence_labels, masked_lm_labels, next_sentence_labels,
masked_lm_scale) = batch masked_lm_scale) = batch
with paddle.amp.auto_cast(
args.use_amp,
custom_white_list=["layer_norm", "softmax", "gelu"]):
prediction_scores, seq_relationship_score = model( prediction_scores, seq_relationship_score = model(
input_ids=input_ids, input_ids=input_ids,
token_type_ids=segment_ids, token_type_ids=segment_ids,
...@@ -378,6 +392,14 @@ def do_train(args): ...@@ -378,6 +392,14 @@ def do_train(args):
loss = criterion(prediction_scores, seq_relationship_score, loss = criterion(prediction_scores, seq_relationship_score,
masked_lm_labels, next_sentence_labels, masked_lm_labels, next_sentence_labels,
masked_lm_scale) masked_lm_scale)
if args.use_amp:
scaler.scale(loss).backward()
scaler.minimize(optimizer, loss)
else:
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.clear_gradients()
if global_step % args.logging_steps == 0: if global_step % args.logging_steps == 0:
if (not args.n_cards > 1 if (not args.n_cards > 1
) or paddle.distributed.get_rank() == 0: ) or paddle.distributed.get_rank() == 0:
...@@ -386,10 +408,6 @@ def do_train(args): ...@@ -386,10 +408,6 @@ def do_train(args):
% (global_step, epoch, step, loss, % (global_step, epoch, step, loss,
args.logging_steps / (time.time() - tic_train))) args.logging_steps / (time.time() - tic_train)))
tic_train = time.time() tic_train = time.time()
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.clear_gradients()
if global_step % args.save_steps == 0: if global_step % args.save_steps == 0:
if (not args.n_cards > 1 if (not args.n_cards > 1
) or paddle.distributed.get_rank() == 0: ) or paddle.distributed.get_rank() == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册