diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 2a4977847b1821417ab6ed8cecfa5de5611b7470..178edc0fe88c569a242c630b5b211568109fa49d 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -1059,6 +1059,8 @@ class Fleet(object): return amp_optimizer def get_loss_scaling(self): + """Return the real-time loss scaling factor. + """ amp_optimizer = self._get_amp_optimizer() return amp_optimizer.get_loss_scaling() diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py index 724f707c2e1f065c63de01e861244f10cf4cf7da..16cba2ce36b20ef2d8b97c046b52a8df64fe0d49 100644 --- a/python/paddle/fluid/contrib/mixed_precision/decorator.py +++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py @@ -98,7 +98,7 @@ class OptimizerWithMixedPrecision(object): def get_loss_scaling(self): """Return the real-time loss scaling factor. """ - assert self._loss_scaling is not None, 'Call minimize() before calling get_loss_scaling()' + assert self._loss_scaling is not None, 'Please call minimize() before calling get_loss_scaling().' return self._loss_scaling def get_scaled_loss(self):