diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index d71f7e774057b161691adbff13fbfc441355e49a..5e17794dfeac1255c05eefc280bdd07943690418 100644
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -1041,6 +1041,26 @@ class Fleet(object):
         # imitate target optimizer retrieval
         return self.user_defined_optimizer.clear_grad()
 
+    def _get_amp_optimizer(self):
+        # imitate target optimizer retrieval
+        amp_optimizer = None
+        for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
+            if hasattr(optimizer, 'amp_init'):
+                amp_optimizer = optimizer
+                break
+
+        if amp_optimizer is None:
+            if hasattr(self.user_defined_optimizer, 'amp_init'):
+                amp_optimizer = self.user_defined_optimizer
+
+        assert amp_optimizer is not None, \
+            "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
+        return amp_optimizer
+
+    def get_loss_scaling(self):
+        amp_optimizer = self._get_amp_optimizer()
+        return amp_optimizer.get_loss_scaling()
+
     def amp_init(self,
                  place,
                  scope=None,
@@ -1101,21 +1121,7 @@ class Fleet(object):
             if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                 run_example_code()
         """
-
-        # imitate target optimizer retrieval
-        amp_optimizer = None
-        for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
-            if hasattr(optimizer, 'amp_init'):
-                amp_optimizer = optimizer
-                break
-
-        if amp_optimizer is None:
-            if hasattr(self.user_defined_optimizer, 'amp_init'):
-                amp_optimizer = self.user_defined_optimizer
-
-        assert amp_optimizer is not None, \
-            "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
-
+        amp_optimizer = self._get_amp_optimizer()
         return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test)
 
     def _final_strategy(self):
diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py
index d37e90b4695d03b5c9caa71c65c8624e558d1065..724f707c2e1f065c63de01e861244f10cf4cf7da 100644
--- a/python/paddle/fluid/contrib/mixed_precision/decorator.py
+++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py
@@ -98,6 +98,7 @@ class OptimizerWithMixedPrecision(object):
     def get_loss_scaling(self):
         """Return the real-time loss scaling factor.
         """
+        assert self._loss_scaling is not None, 'Call minimize() before calling get_loss_scaling()'
         return self._loss_scaling
 
     def get_scaled_loss(self):
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py b/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py
index 869ca41a1923daa099112df95f9b8e3b520883d7..6930a330a7c315780c11fe40cdc0ae90803d4fe6 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py
@@ -70,6 +70,8 @@ class TestFleetAMPInit(unittest.TestCase):
         optimizer = fleet.distributed_optimizer(optimizer)
         optimizer.minimize(cost)
 
+        loss_scale = optimizer.get_loss_scaling()
+
         place = paddle.CUDAPlace(0)
 
         exe = paddle.static.Executor(place)