From 37bb3342dd66922c2919481e519866a6a901a597 Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Wed, 21 Apr 2021 15:39:59 +0800
Subject: [PATCH] add get_loss_scaling to fleet (#32401)

---
 .../distributed/fleet/base/fleet_base.py    | 36 +++++++++++--------
 .../contrib/mixed_precision/decorator.py    |  1 +
 .../tests/unittests/test_fleet_amp_init.py  |  2 ++
 3 files changed, 24 insertions(+), 15 deletions(-)

diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index d71f7e77405..5e17794dfea 100644
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -1041,6 +1041,26 @@ class Fleet(object):
         # imitate target optimizer retrieval
         return self.user_defined_optimizer.clear_grad()
 
+    def _get_amp_optimizer(self):
+        # imitate target optimizer retrieval
+        amp_optimizer = None
+        for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
+            if hasattr(optimizer, 'amp_init'):
+                amp_optimizer = optimizer
+                break
+
+        if amp_optimizer is None:
+            if hasattr(self.user_defined_optimizer, 'amp_init'):
+                amp_optimizer = self.user_defined_optimizer
+
+        assert amp_optimizer is not None, \
+            "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
+        return amp_optimizer
+
+    def get_loss_scaling(self):
+        amp_optimizer = self._get_amp_optimizer()
+        return amp_optimizer.get_loss_scaling()
+
     def amp_init(self,
                  place,
                  scope=None,
@@ -1101,21 +1121,7 @@ class Fleet(object):
             if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                 run_example_code()
         """
-
-        # imitate target optimizer retrieval
-        amp_optimizer = None
-        for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
-            if hasattr(optimizer, 'amp_init'):
-                amp_optimizer = optimizer
-                break
-
-        if amp_optimizer is None:
-            if hasattr(self.user_defined_optimizer, 'amp_init'):
-                amp_optimizer = self.user_defined_optimizer
-
-        assert amp_optimizer is not None, \
-            "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
-
+        amp_optimizer = self._get_amp_optimizer()
         return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test)
 
     def _final_strategy(self):
diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py
index d37e90b4695..724f707c2e1 100644
--- a/python/paddle/fluid/contrib/mixed_precision/decorator.py
+++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py
@@ -98,6 +98,7 @@ class OptimizerWithMixedPrecision(object):
     def get_loss_scaling(self):
         """Return the real-time loss scaling factor.
         """
+        assert self._loss_scaling is not None, 'Call minimize() before calling get_loss_scaling()'
         return self._loss_scaling
 
     def get_scaled_loss(self):
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py b/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py
index 869ca41a192..6930a330a7c 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py
@@ -70,6 +70,8 @@ class TestFleetAMPInit(unittest.TestCase):
         optimizer = fleet.distributed_optimizer(optimizer)
         optimizer.minimize(cost)
 
+        loss_scale = optimizer.get_loss_scaling()
+
         place = paddle.CUDAPlace(0)
 
         exe = paddle.static.Executor(place)
--
GitLab
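
For context (not part of the patch itself): a minimal usage sketch of the new fleet-level get_loss_scaling() in static-graph mode. The network definition and variable names below are illustrative assumptions, not taken from the patch; the fleet/AMP calls mirror the unit test above. Note that get_loss_scaling() is forwarded to the underlying mixed-precision optimizer and must be called after minimize(), which is what the added assert in decorator.py enforces.

    # Illustrative sketch (assumed single-card collective setup, not part of the patch).
    import paddle
    import paddle.nn.functional as F
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    fleet.init(is_collective=True)

    # Hypothetical two-class classifier; any static-graph network works here.
    input_x = paddle.static.data(name='x', shape=[None, 32], dtype='float32')
    input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64')
    prediction = paddle.static.nn.fc(x=[input_x], size=2, activation='softmax')
    cost = F.cross_entropy(input=prediction, label=input_y)
    avg_cost = paddle.mean(cost)

    strategy = fleet.DistributedStrategy()
    strategy.amp = True  # turn on the auto mixed precision meta optimizer

    optimizer = paddle.optimizer.Momentum(learning_rate=0.001, momentum=0.9)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    optimizer.minimize(avg_cost)

    # New entry point added by this patch: forwarded to the mixed-precision
    # optimizer's get_loss_scaling(). In static-graph mode this returns the
    # loss-scaling Variable created during minimize().
    loss_scaling = optimizer.get_loss_scaling()
    print(loss_scaling)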