提交 539750e3 编写于 作者: S ShawnXuan

bert reduce_mean on fp32

上级 0e816bb5
......@@ -39,6 +39,7 @@ def PreTrain(
type_vocab_size=16,
max_predictions_per_seq=20,
initializer_range=0.02,
use_fp16=False,
):
backbone = bert_util.BertBackbone(
input_ids_blob=input_ids_blob,
......@@ -81,6 +82,9 @@ def PreTrain(
initializer_range=initializer_range,
)
with flow.scope.namespace("cls-loss"):
if use_fp16:
lm_loss = flow.reduce_mean(lm_loss)
ns_loss = flow.reduce_mean(ns_loss)
total_loss = lm_loss + ns_loss
return total_loss, lm_loss, ns_loss
......
......@@ -89,6 +89,7 @@ def PretrainJob():
type_vocab_size=args.type_vocab_size,
max_predictions_per_seq=args.max_predictions_per_seq,
initializer_range=0.02,
use_fp16=args.use_fp16,
)
opt = CreateOptimizer(args)
opt.minimize(total_loss)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册