未验证 提交 fbfd1402 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] fix amp o1 (#46391)

上级 5437bd97
......@@ -38,14 +38,18 @@ class AMPState(object):
self._op_fp16_dict = {
} # op_id --> True/False. 'True' means that the current op is in fp16 mode.
self._var_name_dict = {} # fwd_op_id --> {old_name: cast_name}
self.is_train = False
def _is_fp16_op(self, op_id):
return self._op_fp16_dict.get(op_id, None)
def _build_stats(self, amp_lists, dist_context):
def _build_state(self, amp_lists, dist_context):
ops = self._block.ops
dist_op_context = dist_context.dist_op_context
for op in ops:
if int(op.attr('op_role')) == 257:
self.is_train = True
if int(op.attr('op_role')) == int(OpRole.Forward):
self._mark_black_white_ops(amp_lists)
elif int(op.attr('op_role')) == int(OpRole.Backward):
......@@ -59,6 +63,8 @@ class AMPState(object):
elif int(op.attr('op_role')) == int(OpRole.Optimize):
break
return self.is_train
def _mark_black_white_ops(self, amp_lists):
"""
this function is modified from paddle.fluid.contrib.mixed_precision
......@@ -546,13 +552,15 @@ class AMPPass(PassBase):
set(self.get_attr("custom_black_list")),
set(self.get_attr("custom_black_varnames")))
with paddle.static.program_guard(main_program, startup_program):
amp_state = AMPState(main_program.global_block())
amp_state._build_stats(amp_lists, self.dist_context)
is_train = amp_state._build_state(amp_lists, self.dist_context)
with paddle.static.program_guard(main_program, startup_program):
amp_state.cast_forward_program(self.dist_context)
if is_train:
with paddle.static.program_guard(main_program, startup_program):
amp_state.cast_backward_program(params_grads, self.dist_context)
# TODO (JZ-LIANG)support cast forward program only when inference
self._init_amp_var()
self._scale_loss()
......
......@@ -97,6 +97,7 @@ class TestAMPPass(unittest.TestCase):
3,
batch_size=self.batch_size)
amp_o1_losses = np.array(amp_o1_losses["loss"])
amp_o1_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o1_losses)
# mp2 amp-o2 training
......@@ -105,6 +106,7 @@ class TestAMPPass(unittest.TestCase):
3,
batch_size=self.batch_size)
amp_o2_losses = np.array(amp_o2_losses["loss"])
amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o2_losses)
# mp2 amp-o3 training
......@@ -113,6 +115,7 @@ class TestAMPPass(unittest.TestCase):
3,
batch_size=self.batch_size)
amp_o3_losses = np.array(amp_o3_losses["loss"])
amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o3_losses)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册