From d03b0b16e80d20655a19272484108df7112039a5 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Thu, 22 Apr 2021 15:46:14 +0800 Subject: [PATCH] Add fleet get_loss_scaling doc and update alert message (#32419) --- python/paddle/distributed/fleet/base/fleet_base.py | 2 ++ python/paddle/fluid/contrib/mixed_precision/decorator.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 2a4977847b..178edc0fe8 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -1059,6 +1059,8 @@ class Fleet(object): return amp_optimizer def get_loss_scaling(self): + """Return the real-time loss scaling factor. + """ amp_optimizer = self._get_amp_optimizer() return amp_optimizer.get_loss_scaling() diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py index 724f707c2e..16cba2ce36 100644 --- a/python/paddle/fluid/contrib/mixed_precision/decorator.py +++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py @@ -98,7 +98,7 @@ class OptimizerWithMixedPrecision(object): def get_loss_scaling(self): """Return the real-time loss scaling factor. """ - assert self._loss_scaling is not None, 'Call minimize() before calling get_loss_scaling()' + assert self._loss_scaling is not None, 'Please call minimize() before calling get_loss_scaling().' return self._loss_scaling def get_scaled_loss(self): -- GitLab