未验证 提交 17a81df6 编写于 作者: W WangXi 提交者: GitHub

fix fleet amp get_loss_scaling (#33935)

上级 42431948
......@@ -124,3 +124,6 @@ class AMPOptimizer(MetaOptimizerBase):
use_fp16_test=False):
return self.wrapped_opt.amp_init(place, scope, test_program,
use_fp16_test)
def get_loss_scaling(self):
return self.wrapped_opt.get_loss_scaling()
......@@ -117,6 +117,7 @@ class TestFleetAMPInit(unittest.TestCase):
optimizer.minimize(cost)
print(fleet._get_applied_meta_list())
loss_scale = optimizer.get_loss_scaling()
place = paddle.CUDAPlace(0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册