Unverified commit 413d3ee1, authored by Zhen Wang, committed by GitHub

Add pure fp16 training. (#5052)

* add pure fp16 training.

* Add lookup_table in custom_black_list during the pure fp16 training.
Parent 16cd0f6f
@@ -56,6 +56,11 @@ def create_pretraining_dataset(input_file,
                 mask_token_num += 1
         # mask_token_num
         out.append(np.asarray([mask_token_num], dtype=np.float32))
+        if args.use_amp and args.use_pure_fp16:
+            # cast input_mask to fp16
+            out[2] = out[2].astype(np.float16)
+            # cast masked_lm_scale to fp16
+            out[-1] = out[-1].astype(np.float16)
         return out

     train_data_loader = DataLoader(
......
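For reference, the casting step above can be exercised on its own. The sketch below is a minimal standalone version: the batch indices (2 for input_mask, -1 for masked_lm_scale) follow the diff, but the batch is condensed and the shapes are hypothetical.

import numpy as np

# Hypothetical, condensed batch mirroring the collate output above:
# out[2] is input_mask (float32), out[-1] is masked_lm_scale (float32).
out = [
    np.zeros((8, 128), dtype=np.int64),         # input_ids
    np.zeros((8, 128), dtype=np.int64),         # segment_ids
    np.ones((8, 1, 1, 128), dtype=np.float32),  # input_mask
    np.asarray([20], dtype=np.float32),         # masked_lm_scale
]

use_amp = use_pure_fp16 = True
if use_amp and use_pure_fp16:
    # A pure fp16 program expects fp16 feeds for its float inputs,
    # so the float32 arrays are downcast on the host before feeding.
    out[2] = out[2].astype(np.float16)
    out[-1] = out[-1].astype(np.float16)

assert out[2].dtype == np.float16 and out[-1].dtype == np.float16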
@@ -141,6 +141,11 @@ def parse_args():
         type=float,
         default=1.0,
         help="The value of scale_loss for fp16.")
+    parser.add_argument(
+        "--use_pure_fp16",
+        type=distutils.util.strtobool,
+        default=False,
+        help="Whether to use pure fp16 training.")
     args = parser.parse_args()
     return args
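Using `distutils.util.strtobool` as the argparse type lets the flag accept the usual truthy and falsy spellings on the command line. A small standalone sketch (hypothetical script, not part of the diff):

import argparse
import distutils.util

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_pure_fp16",
    type=distutils.util.strtobool,
    default=False,
    help="Whether to use pure fp16 training.")

# strtobool maps "y/yes/true/1" to 1 and "n/no/false/0" to 0, so both
# `--use_pure_fp16 true` and `--use_pure_fp16 0` parse cleanly.
args = parser.parse_args(["--use_pure_fp16", "true"])
print(bool(args.use_pure_fp16))  # True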
@@ -236,15 +241,20 @@ def do_train(args):
         apply_decay_param_fun=lambda x: x in [
             p.name for n, p in model.named_parameters()
             if not any(nd in n for nd in ["bias", "norm"])
-        ])
+        ],
+        multi_precision=args.use_pure_fp16)
     if args.use_amp:
-        amp_list = paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
-            custom_white_list=['layer_norm', 'softmax', 'gelu'])
-        optimizer = paddle.fluid.contrib.mixed_precision.decorate(
+        custom_black_list = (['lookup_table', 'lookup_table_v2']
+                             if args.use_pure_fp16 else None)
+        amp_list = paddle.static.amp.AutoMixedPrecisionLists(
+            custom_white_list=['layer_norm', 'softmax', 'gelu'],
+            custom_black_list=custom_black_list)
+        optimizer = paddle.static.amp.decorate(
             optimizer,
             amp_list,
             init_loss_scaling=args.scale_loss,
-            use_dynamic_loss_scaling=True)
+            use_dynamic_loss_scaling=True,
+            use_pure_fp16=args.use_pure_fp16)
     optimizer.minimize(loss)

     # Define the Executor for running the static model
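For readers less familiar with Paddle's static-graph AMP, here is a condensed, self-contained sketch of the decoration flow above. It assumes the Paddle 2.x `paddle.static.amp` API; the toy network (`x`, a single fc layer) is illustrative only and stands in for BERT.

import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    loss = paddle.mean(paddle.static.nn.fc(x, size=16))

    # multi_precision keeps fp32 master weights alongside fp16 params.
    optimizer = paddle.optimizer.AdamW(
        learning_rate=1e-3, multi_precision=True)
    # Embedding lookups are pinned to fp32 via the black list;
    # layer_norm, softmax and gelu are allowed in fp16 via the white list.
    amp_list = paddle.static.amp.AutoMixedPrecisionLists(
        custom_white_list=['layer_norm', 'softmax', 'gelu'],
        custom_black_list=['lookup_table', 'lookup_table_v2'])
    optimizer = paddle.static.amp.decorate(
        optimizer,
        amp_list,
        init_loss_scaling=1.0,
        use_dynamic_loss_scaling=True,
        use_pure_fp16=True)
    optimizer.minimize(loss)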
@@ -255,6 +265,8 @@ def do_train(args):
     # Use the state dict to update the parameter
     reset_state_dict = reset_program_state_dict(model, state_dict)
     paddle.static.set_program_state(main_program, reset_state_dict)
+    if args.use_amp:
+        optimizer.amp_init(place)
     # Construct the compiled program
     main_program = build_compiled_program(args, main_program, loss)
     global_step = 0
......
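`amp_init` has to run after parameter initialization and the state-dict reset, because for pure fp16 it casts the initialized fp32 parameters to fp16 (fp32 master copies are kept for the update). A sketch of the ordering, continuing the toy program from the previous sketch and assuming a GPU place:

# Continuing the toy program from the previous sketch:
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)

exe.run(startup_prog)        # 1. initialize parameters in fp32
# (the real script overwrites weights here via set_program_state)
optimizer.amp_init(place)    # 2. cast params to fp16 once values are final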
@@ -476,6 +476,7 @@ class BertForPretraining(BertPretrainedModel):
                 position_ids=None,
                 attention_mask=None,
                 masked_positions=None):
+        with paddle.static.amp.fp16_guard():
             outputs = self.bert(
                 input_ids,
                 token_type_ids=token_type_ids,
......
@@ -496,6 +497,7 @@ class BertPretrainingCriterion(paddle.nn.Layer):
     def forward(self, prediction_scores, seq_relationship_score,
                 masked_lm_labels, next_sentence_labels, masked_lm_scale):
+        with paddle.static.amp.fp16_guard():
             masked_lm_loss = paddle.nn.functional.softmax_with_cross_entropy(
                 prediction_scores, masked_lm_labels, ignore_index=-1)
             masked_lm_loss = masked_lm_loss / masked_lm_scale
......
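When pure fp16 decoration runs with fp16 guards enabled, only the ops created inside `paddle.static.amp.fp16_guard()` are cast to fp16; everything outside the guard stays in fp32. A minimal standalone sketch (toy layer, illustrative names only):

import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 32], dtype='float32')
    # Ops built inside the guard are eligible for fp16 execution when the
    # program is later decorated with use_pure_fp16=True.
    with paddle.static.amp.fp16_guard():
        hidden = paddle.static.nn.fc(x, size=32)
        loss = paddle.mean(hidden)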