Unverified commit 37bb3342, authored by Yuang Liu, committed by GitHub

add get_loss_scaling to fleet (#32401)
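For context, the sketch below shows how the new `get_loss_scaling` entry point is expected to be used after this change. It is a minimal illustration only: the network, data shapes, optimizer choice, and `amp_configs` values are placeholder assumptions and are not part of this commit; only the final `get_loss_scaling()` call comes from the code added here (the test diff at the bottom of this page exercises the same call).

```python
# Minimal usage sketch (assumptions: single-process collective setup with a
# CUDA build of PaddlePaddle; the network and AMP settings are placeholders).
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

# Enable the AMP meta-optimizer; without it, _get_amp_optimizer() asserts.
strategy = fleet.DistributedStrategy()
strategy.amp = True
strategy.amp_configs = {"init_loss_scaling": 32768.0}

x = paddle.static.data(name='x', shape=[None, 10], dtype='float32')
y = paddle.static.data(name='y', shape=[None, 1], dtype='float32')
pred = paddle.static.nn.fc(x, size=1)
cost = paddle.mean(paddle.nn.functional.square_error_cost(pred, y))

optimizer = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(cost)

# New in this commit: read the current loss-scaling factor through fleet.
loss_scale = optimizer.get_loss_scaling()
```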

Parent 2b68d20b
@@ -1041,6 +1041,26 @@ class Fleet(object):
         # imitate target optimizer retrieval
         return self.user_defined_optimizer.clear_grad()
+
+    def _get_amp_optimizer(self):
+        # imitate target optimizer retrieval
+        amp_optimizer = None
+        for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
+            if hasattr(optimizer, 'amp_init'):
+                amp_optimizer = optimizer
+                break
+        if amp_optimizer is None:
+            if hasattr(self.user_defined_optimizer, 'amp_init'):
+                amp_optimizer = self.user_defined_optimizer
+        assert amp_optimizer is not None, \
+            "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
+        return amp_optimizer
+
+    def get_loss_scaling(self):
+        amp_optimizer = self._get_amp_optimizer()
+        return amp_optimizer.get_loss_scaling()
     def amp_init(self,
                  place,
                  scope=None,
@@ -1101,21 +1121,7 @@ class Fleet(object):
                 if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                     run_example_code()
         """
+        amp_optimizer = self._get_amp_optimizer()
-        # imitate target optimizer retrieval
-        amp_optimizer = None
-        for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
-            if hasattr(optimizer, 'amp_init'):
-                amp_optimizer = optimizer
-                break
-        if amp_optimizer is None:
-            if hasattr(self.user_defined_optimizer, 'amp_init'):
-                amp_optimizer = self.user_defined_optimizer
-        assert amp_optimizer is not None, \
-            "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
         return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test)

     def _final_strategy(self):
...
@@ -98,6 +98,7 @@ class OptimizerWithMixedPrecision(object):
     def get_loss_scaling(self):
         """Return the real-time loss scaling factor.
         """
+        assert self._loss_scaling is not None, 'Call minimize() before calling get_loss_scaling()'
         return self._loss_scaling

     def get_scaled_loss(self):
...
@@ -70,6 +70,8 @@ class TestFleetAMPInit(unittest.TestCase):
         optimizer = fleet.distributed_optimizer(optimizer)
         optimizer.minimize(cost)
+
+        loss_scale = optimizer.get_loss_scaling()

         place = paddle.CUDAPlace(0)
         exe = paddle.static.Executor(place)
...