From 6b8efc458deacb6c65a1bdf728be5251b975c210 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Thu, 19 May 2022 11:12:35 +0800 Subject: [PATCH] [Auto Parallel] Support Primitive operators with Data Parallel (#42709) * auto parallel support primitive op with data parallel * add primitive change * 5 loss 3D cylinder acc aligned * add unittest --- .../distributed/auto_parallel/completion.py | 67 ++++++++ .../auto_parallel/operators/__init__.py | 1 + .../auto_parallel/operators/common.py | 7 +- .../auto_parallel/operators/dist_default.py | 63 +++++++- .../auto_parallel/operators/dist_reduce_p.py | 151 ++++++++++++++++++ .../distributed/auto_parallel/partitioner.py | 5 + .../paddle/distributed/auto_parallel/utils.py | 7 + .../auto_parallel/test_prim_dist_op.py | 106 ++++++++++++ 8 files changed, 398 insertions(+), 9 deletions(-) create mode 100644 python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_prim_dist_op.py diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 8c286c0201..31bdc4cc65 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -1250,3 +1250,70 @@ class Completer: self._dist_context.set_op_dist_attr_for_program( op, op_dist_attr) continue + + def complete_prim_annotation(self, serial_main_program=None): + """ + fill default data parallel annotation for program with primitive operators. + + Arguments: + serial_main_program: partial annotated serial_main_program. + Returns: + serial_main_program: completed annotated serial_main_program. 
+ """ + if serial_main_program is None: + serial_main_program = self._dist_context.serial_main_program + else: + self._dist_context.serial_main_program = serial_main_program + + import time + + start_time = time.time() + self._dist_context._is_initialized = True + + start_time = time.time() + self._dist_context._init_dist_attr_for_program() + + start_time = time.time() + self._init_global_mesh_for_program() + + # Do the validation check and amend some completion + start_time = time.time() + self._dist_context.amend_dist_attr_for_program() + self._dist_context.validate_dist_attr_for_program() + + def _init_global_mesh_for_program(self): + # Copy the dist tensors and dist ops annotated by users from the default context + # global mesh + from paddle.distributed.auto_parallel.process_group import get_world_process_group + world_ranks = get_world_process_group().ranks + + for block in self._dist_context._serial_main_program.blocks: + for tensor in block.vars.values(): + # Copy the distributed tensors in the default context + dist_tensor = self._dist_context.get_dist_tensor_for_program( + tensor) + assert dist_tensor is not None + dist_tensor.dist_attr.process_mesh = world_ranks + for op in block.ops: + # Copy the distributed operators in the default context + dist_op = self._dist_context.get_dist_op_for_program(op) + assert dist_op is not None + dist_op.dist_attr.process_mesh = world_ranks + + # Find the most compatible implemenetations from the distributed operator + op_dist_impls = find_best_compatible_distributed_operator_impl( + dist_op, fwd=True) + if op_dist_impls is not None: + backup_op_dist_attr = copy.deepcopy(dist_op.dist_attr) + for op_dist_impl in op_dist_impls: + dim_changed = op_dist_impl.update_dims_mapping(dist_op) + if op_dist_impl.is_auto_compatible(dist_op): + if op_dist_impl.type == "elementwise": + dist_op.dist_attr.impl_type = "default" + else: + dist_op.dist_attr.impl_type = op_dist_impl.type + # op_dist_attr.impl_type = op_dist_impl.type + 
dist_op.dist_attr.impl_idx = op_dist_impl.idx + break + else: + dist_op.dist_attr = backup_op_dist_attr diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py index 3f06b34b53..3ff4746972 100644 --- a/python/paddle/distributed/auto_parallel/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -32,3 +32,4 @@ from . import dist_pnorm from . import dist_slice from . import dist_fused_feedforward from . import dist_fused_attention +from . import dist_reduce_p diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index 5d43c56827..441eb88a9f 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -24,9 +24,10 @@ BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'} def is_elementwise_op(op_type): - for eltwise_op in _g_elementwise_ops: - if eltwise_op in op_type: - return True + if op_type in _g_elementwise_ops: + return True + if "elementwise" in op_type: + return True return False diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py index 2e47bcd816..6d9b48ea1e 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -18,7 +18,7 @@ from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl, is_parameter_related from ..utils import is_dim_shard from ..utils import is_dim_replicate -from ..utils import is_valid_list_index +from ..utils import is_valid_list_index, is_prim_op from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import 
compute_compatible_and_update_dim_mapping @@ -35,6 +35,55 @@ from ..utils import _get_comm_group, _get_corresponding_rank __op_not_need_param_init__ = ["while", "cond"] +def prim_operator_data_parallel_functor(ctx, src_op): + dist_op_context = ctx.dist_op_context + main_block = dist_op_context.work_block + startup_block = dist_op_context.startup_block + + var_name = src_op.output_arg_names[0] + if var_name in ctx.grads_params: + assert var_name not in ctx.synced_gradient, "in primitive mode, grad {} is already synced".format( + var_name) + ctx.synced_gradient.add(var_name) + sync_group = new_process_group(ctx.data_parallel_group) + + allreduce_op = main_block.append_op( + type='c_allreduce_sum', + inputs={'X': [var_name]}, + outputs={'Out': [var_name]}, + attrs={ + 'ring_id': sync_group.id, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Backward + }) + + param = ctx.grads_params[var_name] + startup_block = dist_op_context.startup_block + new_op = startup_block.append_op( + type='c_broadcast', + inputs={'X': [param]}, + outputs={'Out': [param]}, + attrs={ + 'ring_id': sync_group.id, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Forward + }) + + grad_var = main_block.var(var_name) + dims_mapping = ctx.get_tensor_dist_attr_for_program( + grad_var).dims_mapping + dist_attr = ctx.get_op_dist_attr_for_program(src_op) + process_mesh = dist_attr.process_mesh + op_attr = OperatorDistributedAttribute() + op_attr.process_mesh = process_mesh + op_attr.set_output_dims_mapping(grad_var.name, dims_mapping) + op_attr.set_input_dims_mapping(grad_var.name, dims_mapping) + ctx.set_op_dist_attr_for_program(allreduce_op, op_attr) + + return + + class DistributedDefault(DistributedOperatorImplContainer): def __init__(self, op_type): super(DistributedDefault, self).__init__(op_type) @@ -292,7 +341,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): @staticmethod def forward(ctx, *args, **kwargs): - dist_op_context = ctx.dist_op_context main_block = 
dist_op_context.work_block startup_block = dist_op_context.startup_block @@ -315,7 +363,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): output_name) # replicate op in dist program - dist_op_desc = main_block.desc.append_op() + dist_op_desc = main_block.append_op(type='nop').desc dist_op_desc.copy_from(src_op.desc) set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx) for input_name in src_op.desc.input_names(): @@ -323,7 +371,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): for output_name in src_op.desc.output_names(): dist_op_desc.set_output(output_name, kwargs[output_name]) - main_block._sync_with_cpp() + # data parallel synchronization for primtive operators + from paddle.incubate.autograd import prim_enabled + if prim_enabled(): + assert is_prim_op(src_op) + prim_operator_data_parallel_functor(ctx, src_op) + return # param initialization sync if src_op.type in __op_not_need_param_init__: @@ -373,8 +426,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): op_attr.set_input_dims_mapping(param.name, dims_mapping) ctx.set_op_dist_attr_for_program(new_op, op_attr) - startup_block._sync_with_cpp() - @staticmethod def backward(ctx, *args, **kwargs): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py b/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py new file mode 100644 index 0000000000..755dcab4be --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py @@ -0,0 +1,151 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from .common import DistributedOperatorImplContainer +from .common import DistributedOperatorImpl +from .common import register_distributed_operator_impl_container +from .common import register_distributed_operator_impl, is_parameter_related +from ..utils import is_dim_shard +from ..utils import is_dim_replicate +from ..utils import is_valid_list_index +from ..utils import compute_compatible_dim_mapping +from ..utils import compute_compatible_dims_mapping +from ..utils import compute_compatible_and_update_dim_mapping +from ..utils import set_dist_op_desc_original_id +from ..dist_attribute import OperatorDistributedAttribute +from paddle.fluid import core, unique_name +from paddle.fluid.framework import _non_static_mode +from paddle.fluid.framework import Program, Parameter, Variable, program_guard +from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY +from ..process_group import new_process_group +from ..utils import _get_comm_group, _get_corresponding_rank + + +class DistributedReducePrimtive(DistributedOperatorImplContainer): + def __init__(self, op_type): + super(DistributedReducePrimtive, self).__init__(op_type) + + +register_distributed_operator_impl_container( + DistributedReducePrimtive("reduce_p")) + + +# Batch Dimension Reduce Primitive +class DistributedReducePrimtiveImpl0(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedReducePrimtiveImpl0, self).__init__(name) + self._forward_implemented = True + 
self._backward_implemented = True + + def is_input_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + + return len(op_desc.input_arg_names()) == 1 + + def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + outputs = op_desc.output_arg_names() + + if len(outputs) != 1: + return False + + output_name = outputs[0] + output_var = dist_op.serial_op.block.var(output_name) + if output_var.shape != (1, ): + return False + + return True + + def is_auto_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + + return self.is_input_compatible(dist_op) and self.is_output_compatible( + dist_op) + + def update_dims_mapping(self, dist_op): + changed = False + + return changed + + @staticmethod + def forward(ctx, *args, **kwargs): + + dist_op_context = ctx.dist_op_context + main_block = dist_op_context.work_block + startup_block = dist_op_context.startup_block + src_op = dist_op_context.cur_src_op + rank_id = dist_op_context.rank_id + + # check validation of inputs / outputs + for input_name in src_op.desc.input_names(): + assert input_name in kwargs, "input [{}] is not given".format( + input_name) + assert len(kwargs[input_name]) == len( + src_op.desc.input(input_name) + ), "number of tensor for input [{}] is not match".format(input_name) + for output_name in src_op.desc.output_names(): + assert output_name in kwargs, "input [{}] is not given".format( + output_name) + assert len(kwargs[output_name]) == len( + src_op.desc.output(output_name) + ), "number of tensor for input [{}] is not match".format( + output_name) + + # replicate op in dist program + dist_op_desc = main_block.append_op(type='nop').desc + dist_op_desc.copy_from(src_op.desc) + set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx) + for input_name in src_op.desc.input_names(): + dist_op_desc.set_input(input_name, kwargs[input_name]) + for output_name in 
src_op.desc.output_names(): + dist_op_desc.set_output(output_name, kwargs[output_name]) + + # batch dimension synchronization + var_name = src_op.output_arg_names[0] + sync_group = new_process_group(ctx.data_parallel_group) + allreduce_op = main_block.append_op( + type='c_allreduce_sum', + inputs={'X': [var_name]}, + outputs={'Out': [var_name]}, + attrs={ + 'ring_id': sync_group.id, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Forward + }) + + # dist attr + var = main_block.var(var_name) + tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var) + op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) + new_op_attr = OperatorDistributedAttribute() + new_op_attr.process_mesh = op_dist_attr.process_mesh + new_op_attr.set_output_dims_mapping(var.name, + tensor_dist_attr.dims_mapping) + new_op_attr.set_input_dims_mapping(var.name, + tensor_dist_attr.dims_mapping) + ctx.set_op_dist_attr_for_program(allreduce_op, new_op_attr) + + @staticmethod + def backward(ctx, *args, **kwargs): + raise RuntimeError( + "primitive operator does NOT have backward function, op type: {}". 
+ format(str(ctx.dist_op_context.cur_src_op.type))) + + +register_distributed_operator_impl( + "reduce_p", DistributedReducePrimtiveImpl0("batch_dimension_reduce_p")) diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index fe091cd08b..91a31dd1b9 100644 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -263,6 +263,11 @@ class Partitioner(object): dist_op_backward_impl.backward( self._dist_context, **kinputs, **koutputs, **{"grad_var_to_var": grad_var_to_var}) + elif int(op.attr('op_role')) == 2: + kinputs, koutputs = dist_op_context.prepare_context(op) + dist_op_impl = get_distributed_operator_impl_container( + "default").get_impl(0) + dist_op_impl.backward(self._dist_context, **kinputs, **koutputs) else: raise NotImplementedError( "partitioner only support forward op and backward op, but got {}". diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index ac07b49f45..fbe3a43a79 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1101,6 +1101,10 @@ def is_loss_op(op): int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss)) +def is_prim_op(op): + return op.type.endswith("_p") + + def get_loss_op(block): loss_ops = [] for op in block.ops: @@ -1118,6 +1122,9 @@ def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs): tensor_dist_attr.dims_mapping = dims_mapping # TODO get global mesh group tensor_dist_attr.process_mesh = process_mesh + if "mark_annotated" in kwargs and kwargs["mark_annotated"]: + tensor_dist_attr.mark_annotated("dims_mapping") + tensor_dist_attr.mark_annotated("process_mesh") dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr) return tensor_dist_attr diff --git 
a/python/paddle/fluid/tests/unittests/auto_parallel/test_prim_dist_op.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_prim_dist_op.py new file mode 100644 index 0000000000..f9ab6f37f3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_prim_dist_op.py @@ -0,0 +1,106 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import paddle.distributed.auto_parallel as auto + +from paddle.fluid import program_guard +from paddle.incubate.autograd import prim2orig, enable_prim, prim_enabled +from paddle.fluid.layer_helper import LayerHelper +from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.completion import Completer +from paddle.distributed.auto_parallel.partitioner import Partitioner +from paddle.distributed.auto_parallel.utils import set_var_dist_attr +from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context, set_default_distributed_context + +paddle.enable_static() +enable_prim() +nranks = 2 +rank = 0 + + +class TestPrimDistOp(unittest.TestCase): + def setUp(self): + self.main_program = paddle.static.Program() + self.startup_program = paddle.static.Program() + self.layer_help = LayerHelper('TestPrimDistOp') + + with paddle.static.program_guard(self.main_program, + 
self.startup_program): + self.init_prog() + + def init_prog(self): + # block = self.main_program.global_block() + # block = self.main_program.global_block() + self.w = self.layer_help.create_parameter( + dtype="float", shape=[20], attr=None) + self.w_grad = paddle.static.data( + name='w_grad', shape=[20], dtype='float') + self.tmp1 = paddle.static.data(name='tmp1', shape=[20], dtype='float') + self.tmp2 = paddle.static.data(name='tmp2', shape=[20], dtype='float') + self.batch_reduced = paddle.static.data( + name='batch_reduced', shape=[1], dtype='float') + self.attrs = {} + + default_dist_context = get_default_distributed_context() + _global_process_mesh = auto.ProcessMesh(list(range(nranks))) + tensor_dist_attr = set_var_dist_attr( + default_dist_context, + self.tmp1, [-1], + _global_process_mesh, + mark_annotated=True) + tensor_dist_attr = set_var_dist_attr( + default_dist_context, + self.tmp1, [-1], + _global_process_mesh, + mark_annotated=True) + + op = self.layer_help.append_op( + type="add_p", + inputs={'X': self.tmp1, + 'Y': self.w}, + outputs={'Z': self.w_grad}, + attrs=self.attrs) + + op = self.layer_help.append_op( + type="reduce_p", + inputs={'X': self.tmp2}, + outputs={'Y': self.batch_reduced}, + attrs={"axis": [0]}) + + def test_loss_and_grad_allreduce(self): + + dist_context = DistributedContext(self.main_program, + self.startup_program) + completer = Completer(dist_context) + completer.complete_prim_annotation(self.main_program) + dist_context.block_state.parse_forward_blocks(self.main_program) + dist_context.block_state.parse_backward_blocks(self.main_program) + dist_context.grads_params = dict() + dist_context.grads_params[self.w_grad.name] = self.w.name + dist_context.synced_gradient = set() + dist_context.data_parallel_group = list(range(nranks)) + partitioner = Partitioner(dist_context, rank) + dist_main_prog, dist_startup_prog, _ = partitioner.partition( + self.main_program, self.startup_program, [(self.w, self.w_grad)]) + ops = 
dist_main_prog.global_block().ops + + self.assertTrue(ops[1].type == "c_allreduce_sum") + self.assertTrue(ops[3].type == "c_allreduce_sum") + + +if __name__ == "__main__": + unittest.main() -- GitLab