diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py
index 78f30422e742f1b22634271db86dbf53205d4385..e18cee6d42dca65120884826226301df07d11c24 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_default.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py
@@ -363,7 +363,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                 output_name)
 
         # replicate op in dist program
-        dist_op_desc = main_block.append_op(type='nop').desc
+        dist_op_desc = main_block.desc.append_op()
         dist_op_desc.copy_from(src_op.desc)
         set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
         for input_name in src_op.desc.input_names():
@@ -371,6 +371,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         for output_name in src_op.desc.output_names():
             dist_op_desc.set_output(output_name, kwargs[output_name])
 
+        main_block._sync_with_cpp()
+
         # data parallel synchronization for primtive operators
         from paddle.incubate.autograd import prim_enabled
         if prim_enabled():
@@ -426,6 +428,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                 op_attr.set_input_dims_mapping(param.name, dims_mapping)
                 ctx.set_op_dist_attr_for_program(new_op, op_attr)
 
+        startup_block._sync_with_cpp()
+
     @staticmethod
     def backward(ctx, *args, **kwargs):
 
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py b/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py
index 755dcab4be34f654727d53ac09e25aa333f4c259..3275bddd9b4cc17525f910c801566d4b7f4acebe 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py
@@ -107,13 +107,14 @@ class DistributedReducePrimtiveImpl0(DistributedOperatorImpl):
                 output_name)
 
         # replicate op in dist program
-        dist_op_desc = main_block.append_op(type='nop').desc
+        dist_op_desc = main_block.desc.append_op()
         dist_op_desc.copy_from(src_op.desc)
         set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
         for input_name in src_op.desc.input_names():
             dist_op_desc.set_input(input_name, kwargs[input_name])
         for output_name in src_op.desc.output_names():
             dist_op_desc.set_output(output_name, kwargs[output_name])
+        main_block._sync_with_cpp()
 
         # batch dimension synchronization
         var_name = src_op.output_arg_names[0]
diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py
index ce686fd6a568334b302ca07ea6b991720d0c1f9e..6a767e5afcdf6aa0f9eeae487232771568c2aed4 100644
--- a/python/paddle/distributed/auto_parallel/partitioner.py
+++ b/python/paddle/distributed/auto_parallel/partitioner.py
@@ -25,7 +25,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext, Di
 from .dist_attribute import OperatorDistributedAttribute
 from .process_group import new_process_group
 from .utils import set_dist_op_desc_original_id
-from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op
+from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op, is_optimize_op
 from .operators.common import BACKWARD_ONLY_DIST_OPS
 
 __varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
@@ -263,14 +263,14 @@ class Partitioner(object):
                     dist_op_backward_impl.backward(
                         self._dist_context, **kinputs, **koutputs,
                         **{"grad_var_to_var": grad_var_to_var})
-                elif int(op.attr('op_role')) == 2:
+                elif is_optimize_op(op):
                     kinputs, koutputs = dist_op_context.prepare_context(op)
                     dist_op_impl = get_distributed_operator_impl_container(
                         "default").get_impl(0)
                     dist_op_impl.backward(self._dist_context, **kinputs, **koutputs)
                 else:
                     raise NotImplementedError(
-                        "partitioner only support forward op and backward op, but got {}".
+                        "partitioner only supports forward, backward and optimize ops, but got {}".
                         format(str(op)))
 
     def _is_valid_annotated_program(self, program):
diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
index 42d90b0d4d619bea6f51109cdcd0b4a7a8ac1945..7b198e288c636141f459189407c65de9a60a2019 100644
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -1099,6 +1099,11 @@ def is_backward_op(op):
         int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)
 
 
+def is_optimize_op(op):
+    return OP_ROLE_KEY in op.attr_names and \
+        int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize)
+
+
 def is_loss_op(op):
     return OP_ROLE_KEY in op.attr_names and \
         int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss))
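
Note (not part of the patch): the hunks above switch op replication from main_block.append_op(type='nop') to appending a raw OpDesc on the C++ side, which is why each forward now ends with block._sync_with_cpp() to bring the Python-side Block back in sync. Below is a minimal standalone sketch of that replicate-then-sync pattern using Paddle's static-graph API; the program, the 'scale' op, and the variable names are invented for illustration and do not come from this patch.

import paddle

paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    y = paddle.scale(x, scale=2.0)  # source op we will replicate

block = main_program.global_block()
src_desc = block.ops[-1].desc  # desc of the 'scale' op just appended

# Append a raw OpDesc at the C++ (desc) level and copy the source op into it,
# mirroring what the patch does instead of block.append_op(type='nop').
new_desc = block.desc.append_op()
new_desc.copy_from(src_desc)

# The Python Block does not yet know about the new OpDesc; syncing rebuilds
# block.ops from the C++ descs, which is the role of the _sync_with_cpp()
# calls added in the patch.
block._sync_with_cpp()
print([op.type for op in block.ops])  # the replicated 'scale' op now appears alongside the original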