From 17a81df6621f80796e2dda14faa16f66bffd2234 Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 2 Jul 2021 20:32:04 +0800 Subject: [PATCH] fix fleet amp get_loss_scaling (#33935) --- .../paddle/distributed/fleet/meta_optimizers/amp_optimizer.py | 3 +++ python/paddle/fluid/tests/unittests/test_fleet_amp_init.py | 1 + 2 files changed, 4 insertions(+) diff --git a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py index 9ffb47789ee..e3a781424e6 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py @@ -124,3 +124,6 @@ class AMPOptimizer(MetaOptimizerBase): use_fp16_test=False): return self.wrapped_opt.amp_init(place, scope, test_program, use_fp16_test) + + def get_loss_scaling(self): + return self.wrapped_opt.get_loss_scaling() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py b/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py index 6930a330a7c..a9a6b9c0660 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py @@ -117,6 +117,7 @@ class TestFleetAMPInit(unittest.TestCase): optimizer.minimize(cost) print(fleet._get_applied_meta_list()) + loss_scale = optimizer.get_loss_scaling() place = paddle.CUDAPlace(0) -- GitLab