Unverified commit e9820280, authored by Leo Chen, committed by GitHub

add custom op gelu for bert amp training (#5008)

* add custom op gelu

* refine run_pretrain
Parent f07cdf53
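
The commit title references a new custom 'gelu' op; the hunks below only wire it into the AMP white lists. For reference, here is a plain-Python sketch of the exact GELU such an op computes (this is the standard definition, not the commit's C++/CUDA kernel):

    import math

    def gelu(x):
        # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
        return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

    print(gelu(1.0))  # ~0.8413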
...
@@ -257,7 +257,7 @@ def do_train(args):
         ])
     if args.use_amp:
         amp_list = paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
-            custom_white_list=['softmax'])
+            custom_white_list=['softmax', 'layer_norm', 'gelu'])
         optimizer = paddle.fluid.contrib.mixed_precision.decorate(
             optimizer,
             amp_list,
...
...
@@ -228,7 +228,7 @@ def do_train(args):
         ])
     if args.use_amp:
         amp_list = paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
-            custom_white_list=['layer_norm', 'softmax'])
+            custom_white_list=['layer_norm', 'softmax', 'gelu'])
         optimizer = paddle.fluid.contrib.mixed_precision.decorate(
             optimizer,
             amp_list,
...
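
For context, a minimal sketch of how the white list from these hunks is typically wired into static-graph AMP training; the Adam optimizer and the loss-scaling arguments are illustrative assumptions, not part of this commit:

    import paddle.fluid as fluid
    from paddle.fluid.contrib import mixed_precision

    # Ops on the custom white list run in fp16 in addition to the
    # framework's default white-list ops; 'gelu' can only be listed
    # here once the op is registered.
    amp_list = mixed_precision.AutoMixedPrecisionLists(
        custom_white_list=['softmax', 'layer_norm', 'gelu'])

    optimizer = fluid.optimizer.Adam(learning_rate=1e-4)  # illustrative choice
    optimizer = mixed_precision.decorate(
        optimizer,
        amp_list,
        init_loss_scaling=2**15,
        use_dynamic_loss_scaling=True)
    # optimizer.minimize(loss) is then called inside the usual program guard.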