Unverified commit 413d3ee1, authored by Zhen Wang and committed by GitHub

Add pure fp16 training. (#5052)

* add pure fp16 training.

* Add lookup_table in custom_black_list during the pure fp16 training.
Parent 16cd0f6f
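This commit wires --use_pure_fp16 through the data pipeline, the optimizer setup, and the model's forward/loss code. The snippet below is a minimal, self-contained sketch (not part of this diff) of how the pieces fit together in static-graph mode: paddle.static.amp.decorate with use_pure_fp16, multi_precision master weights, fp16_guard, and optimizer.amp_init. The toy network, shapes, and hyperparameters are made up, and a CUDA device is assumed.

# Minimal sketch (not part of this diff) of pure fp16 training in Paddle
# static-graph mode; the toy FC network and shapes are illustrative only,
# and a CUDA device is assumed.
import paddle

paddle.enable_static()

main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    y = paddle.static.data(name='y', shape=[None, 1], dtype='float32')
    # Ops built under fp16_guard are the ones run in fp16 under pure fp16 mode.
    with paddle.static.amp.fp16_guard():
        hidden = paddle.static.nn.fc(x, size=32)
        pred = paddle.static.nn.fc(hidden, size=1)
    loss = paddle.mean(paddle.square(pred - y))

    # multi_precision keeps fp32 master weights while the fp16 model trains.
    optimizer = paddle.optimizer.AdamW(learning_rate=1e-3, multi_precision=True)
    amp_list = paddle.static.amp.AutoMixedPrecisionLists(
        custom_black_list=['lookup_table', 'lookup_table_v2'])
    optimizer = paddle.static.amp.decorate(
        optimizer,
        amp_list,
        init_loss_scaling=1.0,
        use_dynamic_loss_scaling=True,
        use_pure_fp16=True)
    optimizer.minimize(loss)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup_prog)
# Pure fp16 requires casting the initialized fp32 parameters to fp16;
# this mirrors the optimizer.amp_init(place) call added to do_train() below.
optimizer.amp_init(place)

Note the ordering: amp_init runs only after the startup program has executed (and, in the PR, after set_program_state restores the pretrained weights), because it converts parameters that must already be initialized.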
@@ -56,6 +56,11 @@ def create_pretraining_dataset(input_file,
                 mask_token_num += 1
         # mask_token_num
         out.append(np.asarray([mask_token_num], dtype=np.float32))
+        if args.use_amp and args.use_pure_fp16:
+            # cast input_mask to fp16
+            out[2] = out[2].astype(np.float16)
+            # cast masked_lm_scale to fp16
+            out[-1] = out[-1].astype(np.float16)
         return out
 
     train_data_loader = DataLoader(
......
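The collate-time casts added above can be illustrated in isolation; in the sketch below the batch shapes are made up and the args handling is reduced to two booleans.

# Hypothetical illustration of the casts in the collate function above: when
# --use_amp and --use_pure_fp16 are both set, float32 side inputs such as
# input_mask (out[2]) and masked_lm_scale (out[-1]) are converted to fp16
# on the numpy side before being fed to the fp16 program.
import numpy as np

input_mask = np.ones([8, 1, 1, 128], dtype=np.float32)   # made-up batch shape
masked_lm_scale = np.asarray([20.0], dtype=np.float32)    # mask_token_num
use_amp, use_pure_fp16 = True, True                       # stand-ins for args
if use_amp and use_pure_fp16:
    input_mask = input_mask.astype(np.float16)
    masked_lm_scale = masked_lm_scale.astype(np.float16)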
@@ -141,6 +141,11 @@ def parse_args():
         type=float,
         default=1.0,
         help="The value of scale_loss for fp16.")
+    parser.add_argument(
+        "--use_pure_fp16",
+        type=distutils.util.strtobool,
+        default=False,
+        help="Whether to use pure fp16 training.")
     args = parser.parse_args()
     return args
@@ -236,15 +241,20 @@ def do_train(args):
         apply_decay_param_fun=lambda x: x in [
             p.name for n, p in model.named_parameters()
             if not any(nd in n for nd in ["bias", "norm"])
-        ])
+        ],
+        multi_precision=args.use_pure_fp16)
     if args.use_amp:
-        amp_list = paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
-            custom_white_list=['layer_norm', 'softmax', 'gelu'])
-        optimizer = paddle.fluid.contrib.mixed_precision.decorate(
+        custom_black_list=(['lookup_table', 'lookup_table_v2']
+                           if args.use_pure_fp16 else None)
+        amp_list = paddle.static.amp.AutoMixedPrecisionLists(
+            custom_white_list=['layer_norm', 'softmax', 'gelu'],
+            custom_black_list=custom_black_list)
+        optimizer = paddle.static.amp.decorate(
             optimizer,
             amp_list,
             init_loss_scaling=args.scale_loss,
-            use_dynamic_loss_scaling=True)
+            use_dynamic_loss_scaling=True,
+            use_pure_fp16=args.use_pure_fp16)
     optimizer.minimize(loss)
 
     # Define the Executor for running the static model
@@ -255,6 +265,8 @@ def do_train(args):
     # Use the state dict to update the parameter
     reset_state_dict = reset_program_state_dict(model, state_dict)
     paddle.static.set_program_state(main_program, reset_state_dict)
+    if args.use_amp:
+        optimizer.amp_init(place)
     # Construct the compiled program
     main_program = build_compiled_program(args, main_program, loss)
     global_step = 0
......
@@ -476,15 +476,16 @@ class BertForPretraining(BertPretrainedModel):
                 position_ids=None,
                 attention_mask=None,
                 masked_positions=None):
-        outputs = self.bert(
-            input_ids,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            attention_mask=attention_mask)
-        sequence_output, pooled_output = outputs[:2]
-        prediction_scores, seq_relationship_score = self.cls(
-            sequence_output, pooled_output, masked_positions)
-        return prediction_scores, seq_relationship_score
+        with paddle.static.amp.fp16_guard():
+            outputs = self.bert(
+                input_ids,
+                token_type_ids=token_type_ids,
+                position_ids=position_ids,
+                attention_mask=attention_mask)
+            sequence_output, pooled_output = outputs[:2]
+            prediction_scores, seq_relationship_score = self.cls(
+                sequence_output, pooled_output, masked_positions)
+            return prediction_scores, seq_relationship_score
 
 
 class BertPretrainingCriterion(paddle.nn.Layer):
@@ -496,9 +497,10 @@ class BertPretrainingCriterion(paddle.nn.Layer):
     def forward(self, prediction_scores, seq_relationship_score,
                 masked_lm_labels, next_sentence_labels, masked_lm_scale):
-        masked_lm_loss = paddle.nn.functional.softmax_with_cross_entropy(
-            prediction_scores, masked_lm_labels, ignore_index=-1)
-        masked_lm_loss = masked_lm_loss / masked_lm_scale
-        next_sentence_loss = paddle.nn.functional.softmax_with_cross_entropy(
-            seq_relationship_score, next_sentence_labels)
+        with paddle.static.amp.fp16_guard():
+            masked_lm_loss = paddle.nn.functional.softmax_with_cross_entropy(
+                prediction_scores, masked_lm_labels, ignore_index=-1)
+            masked_lm_loss = masked_lm_loss / masked_lm_scale
+            next_sentence_loss = paddle.nn.functional.softmax_with_cross_entropy(
+                seq_relationship_score, next_sentence_labels)
         return paddle.sum(masked_lm_loss) + paddle.mean(next_sentence_loss)