未验证 提交 37bb3342 编写于 作者: Y Yuang Liu 提交者: GitHub

add get_loss_scaling to fleet (#32401)

上级 2b68d20b
......@@ -1041,6 +1041,26 @@ class Fleet(object):
# imitate target optimizer retrieval
return self.user_defined_optimizer.clear_grad()
def _get_amp_optimizer(self):
# imitate target optimizer retrieval
amp_optimizer = None
for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
if hasattr(optimizer, 'amp_init'):
amp_optimizer = optimizer
break
if amp_optimizer is None:
if hasattr(self.user_defined_optimizer, 'amp_init'):
amp_optimizer = self.user_defined_optimizer
assert amp_optimizer is not None, \
"amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
return amp_optimizer
def get_loss_scaling(self):
amp_optimizer = self._get_amp_optimizer()
return amp_optimizer.get_loss_scaling()
def amp_init(self,
place,
scope=None,
......@@ -1101,21 +1121,7 @@ class Fleet(object):
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
run_example_code()
"""
# imitate target optimizer retrieval
amp_optimizer = None
for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
if hasattr(optimizer, 'amp_init'):
amp_optimizer = optimizer
break
if amp_optimizer is None:
if hasattr(self.user_defined_optimizer, 'amp_init'):
amp_optimizer = self.user_defined_optimizer
assert amp_optimizer is not None, \
"amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
amp_optimizer = self._get_amp_optimizer()
return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test)
def _final_strategy(self):
......
......@@ -98,6 +98,7 @@ class OptimizerWithMixedPrecision(object):
def get_loss_scaling(self):
"""Return the real-time loss scaling factor.
"""
assert self._loss_scaling is not None, 'Call minimize() before calling get_loss_scaling()'
return self._loss_scaling
def get_scaled_loss(self):
......
......@@ -70,6 +70,8 @@ class TestFleetAMPInit(unittest.TestCase):
optimizer = fleet.distributed_optimizer(optimizer)
optimizer.minimize(cost)
loss_scale = optimizer.get_loss_scaling()
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册