未验证 提交 f781473e 编写于 作者: J JZ-LIANG 提交者: GitHub

[AutoParallel] bugfixed for FP16 if cond (#47841)

* fixed cond state

* fixed cond state
上级 f50de679
...@@ -256,7 +256,10 @@ class FP16State: ...@@ -256,7 +256,10 @@ class FP16State:
for op in block.ops: for op in block.ops:
if is_forward_op(op): if is_forward_op(op):
# NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
if self._is_fp16_op(op.desc.original_id()) or op.type == "cast": if (
self._is_fp16_op(op.desc.original_id()) is True
or op.type == "cast"
):
for in_name in op.input_names: for in_name in op.input_names:
if _keep_fp32_input(op, in_name): if _keep_fp32_input(op, in_name):
continue continue
...@@ -273,7 +276,7 @@ class FP16State: ...@@ -273,7 +276,7 @@ class FP16State:
self.set_var_to_fp16(out_var_name, block) self.set_var_to_fp16(out_var_name, block)
set_op_dtype_to_fp16(op) set_op_dtype_to_fp16(op)
# NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
elif not self._is_fp16_op(op.desc.original_id()): elif self._is_fp16_op(op.desc.original_id()) is False:
for out_var_name in op.output_arg_names: for out_var_name in op.output_arg_names:
out_var = block.vars.get(out_var_name) out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types: if out_var is None or out_var.type not in _valid_types:
...@@ -281,7 +284,7 @@ class FP16State: ...@@ -281,7 +284,7 @@ class FP16State:
if out_var.dtype == core.VarDesc.VarType.FP16: if out_var.dtype == core.VarDesc.VarType.FP16:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32) out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
elif is_backward_op(op): elif is_backward_op(op):
if self._is_fp16_op(op.desc.original_id()): if self._is_fp16_op(op.desc.original_id()) is True:
for out_name in op.output_names: for out_name in op.output_names:
if _keep_fp32_output(op, out_name): if _keep_fp32_output(op, out_name):
continue continue
...@@ -289,7 +292,7 @@ class FP16State: ...@@ -289,7 +292,7 @@ class FP16State:
self.set_var_to_fp16(out_var_name, block) self.set_var_to_fp16(out_var_name, block)
set_op_dtype_to_fp16(op) set_op_dtype_to_fp16(op)
# NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
elif not self._is_fp16_op(op.desc.original_id()): elif self._is_fp16_op(op.desc.original_id()) is False:
for out_var_name in op.output_arg_names: for out_var_name in op.output_arg_names:
out_var = block.vars.get(out_var_name) out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types: if out_var is None or out_var.type not in _valid_types:
...@@ -308,7 +311,7 @@ class FP16State: ...@@ -308,7 +311,7 @@ class FP16State:
idx += 1 idx += 1
continue continue
elif is_forward_op(op): elif is_forward_op(op):
if not self._is_fp16_op(op.desc.original_id()): if self._is_fp16_op(op.desc.original_id()) is False:
num_cast_ops = self._insert_forward_cast_ops( num_cast_ops = self._insert_forward_cast_ops(
op, op,
idx, idx,
...@@ -317,7 +320,7 @@ class FP16State: ...@@ -317,7 +320,7 @@ class FP16State:
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP32,
self.dist_context, self.dist_context,
) )
elif self._is_fp16_op(op.desc.original_id()): elif self._is_fp16_op(op.desc.original_id()) is True:
num_cast_ops = self._insert_forward_cast_ops( num_cast_ops = self._insert_forward_cast_ops(
op, op,
idx, idx,
...@@ -328,7 +331,7 @@ class FP16State: ...@@ -328,7 +331,7 @@ class FP16State:
) )
elif is_backward_op(op): elif is_backward_op(op):
if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id: if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
if not self._is_fp16_op(op.desc.original_id()): if self._is_fp16_op(op.desc.original_id()) is False:
num_cast_ops = self._insert_backward_cast_ops( num_cast_ops = self._insert_backward_cast_ops(
op, op,
idx, idx,
...@@ -337,7 +340,7 @@ class FP16State: ...@@ -337,7 +340,7 @@ class FP16State:
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP32,
self.dist_context, self.dist_context,
) )
elif self._is_fp16_op(op.desc.original_id()): elif self._is_fp16_op(op.desc.original_id()) is True:
num_cast_ops = self._insert_backward_cast_ops( num_cast_ops = self._insert_backward_cast_ops(
op, op,
idx, idx,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册