未验证 提交 5dab0b0d 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] fix amp o1 (#46391) (#46481)

上级 5711bbee
...@@ -38,14 +38,18 @@ class AMPState(object): ...@@ -38,14 +38,18 @@ class AMPState(object):
self._op_fp16_dict = { self._op_fp16_dict = {
} # op_id --> True/False. 'True' means that the current op is in fp16 mode. } # op_id --> True/False. 'True' means that the current op is in fp16 mode.
self._var_name_dict = {} # fwd_op_id --> {old_name: cast_name} self._var_name_dict = {} # fwd_op_id --> {old_name: cast_name}
self.is_train = False
def _is_fp16_op(self, op_id): def _is_fp16_op(self, op_id):
return self._op_fp16_dict.get(op_id, None) return self._op_fp16_dict.get(op_id, None)
def _build_stats(self, amp_lists, dist_context): def _build_state(self, amp_lists, dist_context):
ops = self._block.ops ops = self._block.ops
dist_op_context = dist_context.dist_op_context dist_op_context = dist_context.dist_op_context
for op in ops: for op in ops:
if int(op.attr('op_role')) == 257:
self.is_train = True
if int(op.attr('op_role')) == int(OpRole.Forward): if int(op.attr('op_role')) == int(OpRole.Forward):
self._mark_black_white_ops(amp_lists) self._mark_black_white_ops(amp_lists)
elif int(op.attr('op_role')) == int(OpRole.Backward): elif int(op.attr('op_role')) == int(OpRole.Backward):
...@@ -59,6 +63,8 @@ class AMPState(object): ...@@ -59,6 +63,8 @@ class AMPState(object):
elif int(op.attr('op_role')) == int(OpRole.Optimize): elif int(op.attr('op_role')) == int(OpRole.Optimize):
break break
return self.is_train
def _mark_black_white_ops(self, amp_lists): def _mark_black_white_ops(self, amp_lists):
""" """
this function is modified from paddle.fluid.contrib.mixed_precision this function is modified from paddle.fluid.contrib.mixed_precision
...@@ -546,13 +552,15 @@ class AMPPass(PassBase): ...@@ -546,13 +552,15 @@ class AMPPass(PassBase):
set(self.get_attr("custom_black_list")), set(self.get_attr("custom_black_list")),
set(self.get_attr("custom_black_varnames"))) set(self.get_attr("custom_black_varnames")))
with paddle.static.program_guard(main_program, startup_program):
amp_state = AMPState(main_program.global_block()) amp_state = AMPState(main_program.global_block())
amp_state._build_stats(amp_lists, self.dist_context) is_train = amp_state._build_state(amp_lists, self.dist_context)
with paddle.static.program_guard(main_program, startup_program):
amp_state.cast_forward_program(self.dist_context) amp_state.cast_forward_program(self.dist_context)
if is_train:
with paddle.static.program_guard(main_program, startup_program):
amp_state.cast_backward_program(params_grads, self.dist_context) amp_state.cast_backward_program(params_grads, self.dist_context)
# TODO (JZ-LIANG)support cast forward program only when inference
self._init_amp_var() self._init_amp_var()
self._scale_loss() self._scale_loss()
......
...@@ -97,6 +97,7 @@ class TestAMPPass(unittest.TestCase): ...@@ -97,6 +97,7 @@ class TestAMPPass(unittest.TestCase):
3, 3,
batch_size=self.batch_size) batch_size=self.batch_size)
amp_o1_losses = np.array(amp_o1_losses["loss"]) amp_o1_losses = np.array(amp_o1_losses["loss"])
amp_o1_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o1_losses) # self.check_results(mp_losses, amp_o1_losses)
# mp2 amp-o2 training # mp2 amp-o2 training
...@@ -105,6 +106,7 @@ class TestAMPPass(unittest.TestCase): ...@@ -105,6 +106,7 @@ class TestAMPPass(unittest.TestCase):
3, 3,
batch_size=self.batch_size) batch_size=self.batch_size)
amp_o2_losses = np.array(amp_o2_losses["loss"]) amp_o2_losses = np.array(amp_o2_losses["loss"])
amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o2_losses) # self.check_results(mp_losses, amp_o2_losses)
# mp2 amp-o3 training # mp2 amp-o3 training
...@@ -113,6 +115,7 @@ class TestAMPPass(unittest.TestCase): ...@@ -113,6 +115,7 @@ class TestAMPPass(unittest.TestCase):
3, 3,
batch_size=self.batch_size) batch_size=self.batch_size)
amp_o3_losses = np.array(amp_o3_losses["loss"]) amp_o3_losses = np.array(amp_o3_losses["loss"])
amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o3_losses) # self.check_results(mp_losses, amp_o3_losses)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册