diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py
index 853df2eea7f5d0a7e96f84c5e0b83a685b3606d1..3262505416b1d031ff9e3702fcc3a0af1b443190 100644
--- a/python/paddle/distributed/auto_parallel/partitioner.py
+++ b/python/paddle/distributed/auto_parallel/partitioner.py
@@ -224,13 +224,15 @@ class Partitioner(object):
                 forward_op_id2forward_op[
                     serial_ops[idx].desc.original_id()] = serial_ops[idx]
 
-        appended_grad_times = 0
         # partiiton
+        appended_grad_times = 0
         for idx, op in enumerate(serial_ops):
+            op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
 
             if is_backward_op(op) and (is_forward_op(serial_ops[idx - 1])
                                        or is_loss_op(serial_ops[idx - 1])):
-                appended_grad_times += 1
+                if not op_dist_attr.is_recompute:
+                    appended_grad_times += 1
 
             # partititon input variables
             for serial_input_varname in op.desc.input_arg_names():
@@ -256,7 +258,6 @@ class Partitioner(object):
                         serial_output_varname] = new_varname
 
             # partition op
-            op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
             if is_forward_op(op) or op_dist_attr.is_recompute:
                 kinputs, koutputs = dist_op_context.prepare_context(op)
                 dist_op_forward_impl = _get_dist_op_forward_implement(
diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py
index 128cffcc6e9270d165cc6f4e8c4d1417e027dca9..f65b7591e59727dfa85acdacda744d59869ec975 100644
--- a/python/paddle/distributed/passes/auto_parallel_fp16.py
+++ b/python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -18,14 +18,13 @@ import paddle
 from paddle.framework import core
 from paddle.fluid import unique_name
 from .pass_base import register_pass
-from paddle.distributed.fleet.meta_optimizers.common import OpRole
 from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
 from paddle.distributed.auto_parallel.utils import set_var_dist_attr, naive_set_dist_op_attr_for_program_by_mesh_and_mapping
 from paddle.distributed.auto_parallel.process_group import get_world_process_group
 from paddle.fluid.contrib.mixed_precision.fp16_utils import AutoMixedPrecisionLists
 from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_layer_norm_scale_bias_to_fp32, _need_keep_fp32, _valid_types, _dtype_to_str
 from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
-from paddle.distributed.auto_parallel.utils import is_forward_op, is_backward_op
+from paddle.distributed.auto_parallel.utils import is_forward_op, is_backward_op, OP_ROLE_KEY, OpRole
 from .auto_parallel_amp import AMPPass
 
 world_process_group = get_world_process_group()
@@ -331,6 +330,7 @@ class FP16State(object):
                     attrs={
                         "in_dtype": in_var.dtype,
                         "out_dtype": cast_var.dtype,
+                        OP_ROLE_KEY: OpRole.Forward
                     })
                 naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                     cast_op, ref_mesh, ref_mapping, dist_context)
@@ -409,6 +409,7 @@ class FP16State(object):
                     attrs={
                         "in_dtype": dst_dtype,
                         "out_dtype": src_dtype,
+                        OP_ROLE_KEY: OpRole.Backward
                     })
                 grad.desc.set_dtype(src_dtype)
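
Reviewer note (not part of the patch): the fp16 hunks tag the cast ops the pass inserts with an explicit op role, so later role-based passes classify them correctly. The sketch below illustrates that pattern on a hand-appended `cast` op in a plain static program. It is a minimal, hedged example: names such as `main_prog` and `x_fp16` are made up for illustration, and `OP_ROLE_KEY`/`OpRole` are taken directly from `paddle.framework.core` here, which I believe is what `paddle.distributed.auto_parallel.utils` re-exports in the patched import line.

```python
# Minimal sketch (illustration only, not the pass itself): append a cast op
# and attach an explicit op role, mirroring the attrs added in the diff above.
import paddle
from paddle.framework import core

paddle.enable_static()

# Attribute key ("op_role") and role enum as exposed by fluid core.
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
OpRole = core.op_proto_and_checker_maker.OpRole

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    block = main_prog.global_block()
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    # Hypothetical fp16 copy of x, analogous to the cast vars the pass creates.
    x_fp16 = block.create_var(
        name="x.cast_fp16", shape=x.shape, dtype=core.VarDesc.VarType.FP16)
    cast_op = block.append_op(
        type="cast",
        inputs={"X": [x]},
        outputs={"Out": [x_fp16]},
        attrs={
            "in_dtype": x.dtype,
            "out_dtype": x_fp16.dtype,
            OP_ROLE_KEY: OpRole.Forward  # mark the inserted cast as forward
        })

# The role is stored as a regular op attribute and can be read back;
# without it, passes that split ops by role would fall back to the default.
print(cast_op.attr(OP_ROLE_KEY))
```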