From f781473e35aab0660d040861ea4e96ac255cc3de Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Mon, 14 Nov 2022 11:34:48 +0800 Subject: [PATCH] [AutoParallel] bugfixed for FP16 if cond (#47841) * fixed cond state * fixed cond state --- .../distributed/passes/auto_parallel_fp16.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 4cb0c361fe2..eb45b41c551 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -256,7 +256,10 @@ class FP16State: for op in block.ops: if is_forward_op(op): # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python - if self._is_fp16_op(op.desc.original_id()) or op.type == "cast": + if ( + self._is_fp16_op(op.desc.original_id()) is True + or op.type == "cast" + ): for in_name in op.input_names: if _keep_fp32_input(op, in_name): continue @@ -273,7 +276,7 @@ class FP16State: self.set_var_to_fp16(out_var_name, block) set_op_dtype_to_fp16(op) # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python - elif not self._is_fp16_op(op.desc.original_id()): + elif self._is_fp16_op(op.desc.original_id()) is False: for out_var_name in op.output_arg_names: out_var = block.vars.get(out_var_name) if out_var is None or out_var.type not in _valid_types: @@ -281,7 +284,7 @@ class FP16State: if out_var.dtype == core.VarDesc.VarType.FP16: out_var.desc.set_dtype(core.VarDesc.VarType.FP32) elif is_backward_op(op): - if self._is_fp16_op(op.desc.original_id()): + if self._is_fp16_op(op.desc.original_id()) is True: for out_name in op.output_names: if _keep_fp32_output(op, out_name): continue @@ -289,7 +292,7 @@ class FP16State: self.set_var_to_fp16(out_var_name, block) set_op_dtype_to_fp16(op) # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python - elif not self._is_fp16_op(op.desc.original_id()): + elif self._is_fp16_op(op.desc.original_id()) is False: for out_var_name in op.output_arg_names: out_var = block.vars.get(out_var_name) if out_var is None or out_var.type not in _valid_types: @@ -308,7 +311,7 @@ class FP16State: idx += 1 continue elif is_forward_op(op): - if not self._is_fp16_op(op.desc.original_id()): + if self._is_fp16_op(op.desc.original_id()) is False: num_cast_ops = self._insert_forward_cast_ops( op, idx, @@ -317,7 +320,7 @@ class FP16State: core.VarDesc.VarType.FP32, self.dist_context, ) - elif self._is_fp16_op(op.desc.original_id()): + elif self._is_fp16_op(op.desc.original_id()) is True: num_cast_ops = self._insert_forward_cast_ops( op, idx, @@ -328,7 +331,7 @@ class FP16State: ) elif is_backward_op(op): if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id: - if not self._is_fp16_op(op.desc.original_id()): + if self._is_fp16_op(op.desc.original_id()) is False: num_cast_ops = self._insert_backward_cast_ops( op, idx, @@ -337,7 +340,7 @@ class FP16State: core.VarDesc.VarType.FP32, self.dist_context, ) - elif self._is_fp16_op(op.desc.original_id()): + elif self._is_fp16_op(op.desc.original_id()) is True: num_cast_ops = self._insert_backward_cast_ops( op, idx, -- GitLab