Unverified commit f781473e, authored by JZ-LIANG, committed by GitHub

[AutoParallel] bugfixed for FP16 if cond (#47841)

* fixed cond state
Parent f50de679
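What the fix addresses: the pass's _is_fp16_op lookup is effectively tri-state, and the old truthiness checks collapsed two of the states. Below is a minimal runnable sketch of the distinction, assuming (as the new "is True" / "is False" checks imply) that _is_fp16_op returns True or False for ops the pass has marked and None for ops it never recorded; the map and op ids here are hypothetical:

# Hypothetical op-id -> fp16 flag map; unmarked ops are simply absent.
op_fp16_dict = {101: True, 102: False}

def _is_fp16_op(op_id):
    # Tri-state lookup: True, False, or None when the op was never marked.
    return op_fp16_dict.get(op_id, None)

unmarked_id = 103  # not in the map, so the lookup yields None

# Old check: truthiness treats None like False, so an unmarked op
# wrongly fell into the FP32 branch.
assert not _is_fp16_op(unmarked_id)  # passes for False AND None alike

# Fixed check: identity comparisons keep the three states apart, so an
# unmarked op matches neither branch and is skipped.
assert _is_fp16_op(unmarked_id) is not True
assert _is_fp16_op(unmarked_id) is not False

Every assert above passes as written; the diff below applies the same identity-comparison pattern throughout FP16State.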
@@ -256,7 +256,10 @@ class FP16State:
         for op in block.ops:
             if is_forward_op(op):
                 # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
-                if self._is_fp16_op(op.desc.original_id()) or op.type == "cast":
+                if (
+                    self._is_fp16_op(op.desc.original_id()) is True
+                    or op.type == "cast"
+                ):
                     for in_name in op.input_names:
                         if _keep_fp32_input(op, in_name):
                             continue
@@ -273,7 +276,7 @@ class FP16State:
                             self.set_var_to_fp16(out_var_name, block)
                     set_op_dtype_to_fp16(op)
                 # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
-                elif not self._is_fp16_op(op.desc.original_id()):
+                elif self._is_fp16_op(op.desc.original_id()) is False:
                     for out_var_name in op.output_arg_names:
                         out_var = block.vars.get(out_var_name)
                         if out_var is None or out_var.type not in _valid_types:
@@ -281,7 +284,7 @@ class FP16State:
                         if out_var.dtype == core.VarDesc.VarType.FP16:
                             out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
             elif is_backward_op(op):
-                if self._is_fp16_op(op.desc.original_id()):
+                if self._is_fp16_op(op.desc.original_id()) is True:
                     for out_name in op.output_names:
                         if _keep_fp32_output(op, out_name):
                             continue
@@ -289,7 +292,7 @@ class FP16State:
                             self.set_var_to_fp16(out_var_name, block)
                     set_op_dtype_to_fp16(op)
                 # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python
-                elif not self._is_fp16_op(op.desc.original_id()):
+                elif self._is_fp16_op(op.desc.original_id()) is False:
                     for out_var_name in op.output_arg_names:
                         out_var = block.vars.get(out_var_name)
                         if out_var is None or out_var.type not in _valid_types:
@@ -308,7 +311,7 @@ class FP16State:
                 idx += 1
                 continue
             elif is_forward_op(op):
-                if not self._is_fp16_op(op.desc.original_id()):
+                if self._is_fp16_op(op.desc.original_id()) is False:
                     num_cast_ops = self._insert_forward_cast_ops(
                         op,
                         idx,
@@ -317,7 +320,7 @@ class FP16State:
                         core.VarDesc.VarType.FP32,
                         self.dist_context,
                     )
-                elif self._is_fp16_op(op.desc.original_id()):
+                elif self._is_fp16_op(op.desc.original_id()) is True:
                     num_cast_ops = self._insert_forward_cast_ops(
                         op,
                         idx,
@@ -328,7 +331,7 @@ class FP16State:
                     )
             elif is_backward_op(op):
                 if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
-                    if not self._is_fp16_op(op.desc.original_id()):
+                    if self._is_fp16_op(op.desc.original_id()) is False:
                         num_cast_ops = self._insert_backward_cast_ops(
                             op,
                             idx,
@@ -337,7 +340,7 @@ class FP16State:
                             core.VarDesc.VarType.FP32,
                             self.dist_context,
                         )
-                    elif self._is_fp16_op(op.desc.original_id()):
+                    elif self._is_fp16_op(op.desc.original_id()) is True:
                         num_cast_ops = self._insert_backward_cast_ops(
                             op,
                             idx,
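Net effect across the hunks above: "is True" routes an op into the FP16-cast path, "is False" into the FP32-cast path, and a None lookup (an op the pass never marked) now falls through both branches and is left untouched, which appears to be the intended behavior for the conditional-block case this PR targets.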