From 5c91010dd20390543e3c0658908b84bd7ba3c367 Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Tue, 19 Apr 2022 09:25:01 +0800
Subject: [PATCH] [AutoParallel] dist p-norm op (#41805)

* add dist_pnorm op

* update cmakelist

* fix cmakelist
---
 .../auto_parallel/operators/__init__.py       |   1 +
 .../auto_parallel/operators/dist_pnorm.py     | 363 ++++++++++++++++++
 .../unittests/auto_parallel/CMakeLists.txt    |   1 +
 .../auto_parallel/test_dist_pnorm.py          | 127 ++++++
 4 files changed, 492 insertions(+)
 create mode 100644 python/paddle/distributed/auto_parallel/operators/dist_pnorm.py
 create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py

diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py
index c87316ed8f..3f06b34b53 100644
--- a/python/paddle/distributed/auto_parallel/operators/__init__.py
+++ b/python/paddle/distributed/auto_parallel/operators/__init__.py
@@ -28,6 +28,7 @@ from . import dist_check_finite_and_unscale
 from . import dist_update_loss_scaling
 from . import dist_split
 from . import dist_fill_constant_batch_size_like
+from . import dist_pnorm
 from . import dist_slice
 from . import dist_fused_feedforward
 from . import dist_fused_attention
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py b/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py
new file mode 100644
index 0000000000..ce68e20602
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py
@@ -0,0 +1,363 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import paddle
+import paddle.fluid.layers.utils as utils
+
+from .common import DistributedOperatorImplContainer
+from .common import DistributedOperatorImpl
+from .common import register_distributed_operator_impl_container
+from .common import register_distributed_operator_impl
+from .common import set_comm_op_dist_attr_for_program
+from .dist_default import DistributedDefaultImpl0
+from ..reshard import Resharder
+from ..process_group import new_process_group
+from ..utils import is_dim_shard, is_dim_replicate, _get_corresponding_rank
+from ..utils import compute_compatible_dim_mapping, set_dist_op_desc_original_id, _get_comm_group
+from ..dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute
+
+from paddle.fluid import core, unique_name
+from paddle.fluid.framework import Operator
+from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
+
+
+class DistributedPNorm(DistributedOperatorImplContainer):
+    def __init__(self, op_type):
+        super(DistributedPNorm, self).__init__(op_type)
+
+
+register_distributed_operator_impl_container(DistributedPNorm("p_norm"))
+
+
+def _insert_fill_constant_op(block, op_role):
+    """Insert a fill_constant op into the block."""
+    helper = LayerHelper("fill_constant", **locals())
+    with paddle.static.program_guard(block.program):
+        out = helper.create_variable_for_type_inference(dtype="int32")
+        inputs = {}
+        attrs = {'force_cpu': False}
+        attrs['str_value'] = str(int("1"))
+        attrs['value'] = int("1")
+        attrs['dtype'] = out.dtype
+        attrs['op_role'] = op_role
+        utils.get_shape_tensor_inputs(
+            inputs=inputs, attrs=attrs, shape=[0], op_type='fill_constant')
+        fill_constant_op = block.append_op(
+            type='fill_constant',
+            inputs=inputs,
+            outputs={'Out': [out]},
+            attrs=attrs)
+        out.stop_gradient = True
+    return out, fill_constant_op
+
+
+# Row Parallel
+class DistributedPNormImpl(DistributedOperatorImpl):
+    def __init__(self, name):
+        super(DistributedPNormImpl, self).__init__(name)
+        self._forward_implemented = True
+        self._backward_implemented = True
+
+    def is_input_compatible(self, dist_op):
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        x_name = op_desc.input('X')[0]
+        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
+        if is_dim_replicate(x_dims_mapping[0]):
+            return False
+        # All dimensions except the batch dimension must be replicate
+        for mapping in x_dims_mapping[1:]:
+            if is_dim_shard(mapping):
+                return False
+        return True
+
+    def is_output_compatible(self, dist_op):
+        return True
+
+    def is_compatible(self, dist_op):
+        if (not self.is_input_compatible(dist_op)) or \
+            (not self.is_output_compatible(dist_op)):
+            return False
+        return True
+
+    def is_auto_compatible(self, dist_op):
+        if (not self.is_input_compatible(dist_op)) or \
+            (not self.is_output_compatible(dist_op)) or \
+            (not self.is_compatible(dist_op)):
+            return False
+        return True
+
+    def update_dims_mapping(self, dist_op):
+        changed = False
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+
+        batch_dim_mappings = []
+        for arg_name in op_desc.input_arg_names():
+            dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
+            if len(dims_mapping) >= 1:
+                batch_dim_mappings.append(dims_mapping[0])
+        for arg_name in op_desc.output_arg_names():
+            dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
+            if len(dims_mapping) >= 1:
+                batch_dim_mappings.append(dims_mapping[0])
+
+        compatible_dim_mapping = compute_compatible_dim_mapping(
+            batch_dim_mappings)
+        assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
+
+        for arg_name in op_desc.input_arg_names():
+            dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
+            if len(dims_mapping
+                   ) >= 1 and compatible_dim_mapping != dims_mapping[0]:
+                dims_mapping[0] = compatible_dim_mapping
+                changed = True
+        for arg_name in op_desc.output_arg_names():
+            dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
+            if len(dims_mapping
+                   ) >= 1 and compatible_dim_mapping != dims_mapping[0]:
+                dims_mapping[0] = compatible_dim_mapping
+                changed = True
+
+        return changed
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+
+        dist_op_context = ctx.dist_op_context
+        main_block = dist_op_context.work_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
+        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
+        assert op_dist_attr is not None
+
+        # validate inputs / outputs
+        for input_name in src_op.desc.input_names():
+            assert input_name in kwargs, "input [{}] is not given".format(
+                input_name)
+            assert len(kwargs[input_name]) == len(
+                src_op.desc.input(input_name)
+            ), "number of tensors for input [{}] does not match".format(
+                input_name)
+        for output_name in src_op.desc.output_names():
+            assert output_name in kwargs, "output [{}] is not given".format(
+                output_name)
+            assert len(kwargs[output_name]) == len(
+                src_op.desc.output(output_name)
+            ), "number of tensors for output [{}] does not match".format(
+                output_name)
+
+        if rank_id not in op_dist_attr.process_mesh.processes:
+            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
+                                              rank_id)
+
+        X_var = main_block.var(kwargs['X'][0])
+        in_dims_mapping = op_dist_attr.get_input_dims_mapping(X_var.name)
+        for axis in range(len(in_dims_mapping)):
+            if in_dims_mapping[axis] != -1:
+                break
+        process_mesh_shape = op_dist_attr.process_mesh.topology
+        process_mesh_group = op_dist_attr.process_mesh.processes
+        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
+                                      axis, rank_id)
+        group = new_process_group(group_ranks)
+
+        check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'],
+                                 'norm')
+        check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
+                    'norm')
+
+        # 1. insert barrier op
+        ref_process_mesh = op_dist_attr.process_mesh
+        constant_out_dims_mapping = [-1]
+        fill_constant_out, fill_constant_op = _insert_fill_constant_op(
+            main_block, src_op.attr('op_role'))
+        # set fill_constant_out tensor dist_attr
+        constant_out_dist_attr = TensorDistributedAttribute()
+        constant_out_dist_attr.process_mesh = ref_process_mesh
+        constant_out_dist_attr.dims_mapping = constant_out_dims_mapping
+        ctx.set_tensor_dist_attr_for_program(fill_constant_out,
+                                             constant_out_dist_attr)
+        # set fill_constant op dist_attr
+        constant_op_dist_attr = OperatorDistributedAttribute()
+        constant_op_dist_attr.process_mesh = ref_process_mesh
+        constant_op_dist_attr.set_output_dims_mapping(fill_constant_out.name,
+                                                      constant_out_dims_mapping)
+        ctx.set_op_dist_attr_for_program(fill_constant_op,
+                                         constant_op_dist_attr)
+        barrier_op = main_block.append_op(
+            type='barrier',
+            inputs={'X': [fill_constant_out]},
+            outputs={'Out': [fill_constant_out]},
+            attrs={'ring_id': group.id})
+        # set barrier op dist attr
+        set_comm_op_dist_attr_for_program(barrier_op, ref_process_mesh,
+                                          constant_out_dist_attr, ctx)
+
+        # 2. insert c_allgather op
+        # create c_allgather output var
+        allgather_out = main_block.create_var(
+            name=".".join(["c_allgather", X_var.name]),
+            dtype=X_var.dtype,
+            shape=X_var.shape,
+            type=core.VarDesc.VarType.LOD_TENSOR,
+            persistable=False,
+            stop_gradient=X_var.stop_gradient)
+        # set allgather_out tensor dist_attr
+        allgather_out_dist_attr = TensorDistributedAttribute()
+        allgather_out_dist_attr.process_mesh = op_dist_attr.process_mesh
+        allgather_out_dist_attr.dims_mapping = [
+            -1 for i in range(len(allgather_out.shape))
+        ]
+        ctx.set_tensor_dist_attr_for_program(allgather_out,
+                                             allgather_out_dist_attr)
+        c_allgather_op = main_block.append_op(
+            type='c_allgather',
+            inputs={'X': [X_var]},
+            outputs={'Out': [allgather_out]},
+            attrs={
+                'ring_id': group.id,
+                'use_calc_stream': True,
+                'nranks': group.nranks,
+                'op_role': src_op.attr('op_role')
+            })
+        # set c_allgather op dist_attr
+        allgather_op_dist_attr = OperatorDistributedAttribute()
+        allgather_op_dist_attr.process_mesh = op_dist_attr.process_mesh
+        allgather_op_dist_attr.set_input_dims_mapping(X_var.name,
+                                                      in_dims_mapping)
+        allgather_op_dist_attr.set_output_dims_mapping(
+            allgather_out.name, allgather_out_dist_attr.dims_mapping)
+        ctx.set_op_dist_attr_for_program(c_allgather_op, allgather_op_dist_attr)
+
+        # 3. copy p_norm op desc and reset input name
+        # rename input
+        kwargs['X'] = [allgather_out.name]
+        # replicate op in dist program
+        dist_op_desc = main_block.desc.append_op()
+        dist_op_desc.copy_from(src_op.desc)
+        set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
+        for input_name in src_op.desc.input_names():
+            dist_op_desc.set_input(input_name, kwargs[input_name])
+        for output_name in src_op.desc.output_names():
+            dist_op_desc.set_output(output_name, kwargs[output_name])
+        pnorm_op = Operator(main_block, dist_op_desc)
+        op_dist_attr.set_input_dims_mapping(
+            allgather_out.name, allgather_out_dist_attr.dims_mapping)
+        ctx.set_op_dist_attr_for_program(pnorm_op, op_dist_attr)
+
+        main_block._sync_with_cpp()
+
+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+
+        dist_op_context = ctx.dist_op_context
+        main_block = dist_op_context.work_block
+        backward_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
+        op_dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
+        assert op_dist_attr is not None
+
+        # validate inputs / outputs
+        for input_name in backward_op.desc.input_names():
+            assert input_name in kwargs, "input [{}] is not given".format(
+                input_name)
+            assert len(kwargs[input_name]) == len(
+                backward_op.desc.input(input_name)
+            ), "number of tensors for input [{}] does not match".format(
+                input_name)
+        for output_name in backward_op.desc.output_names():
+            assert output_name in kwargs, "output [{}] is not given".format(
+                output_name)
+            assert len(kwargs[output_name]) == len(
+                backward_op.desc.output(output_name)
+            ), "number of tensors for output [{}] does not match".format(
+                output_name)
+
+        X_var = main_block.var(kwargs['X'][0])
+        X_grad_var = main_block.var(kwargs['X@GRAD'][0])
+
+        # 1. copy p_norm_grad op and reset input name and output name
+        new_kwargs = copy.deepcopy(kwargs)
+        new_kwargs['X'] = [".".join(["c_allgather", X_var.name])]
+        new_X_var = main_block.var(new_kwargs['X'][0])
+        new_X_grad = main_block.create_var(
+            name=".".join(["c_allgather", X_grad_var.name]),
+            dtype=X_grad_var.dtype,
+            shape=new_X_var.shape,
+            type=core.VarDesc.VarType.LOD_TENSOR,
+            persistable=False,
+            stop_gradient=X_grad_var.stop_gradient)
+        new_kwargs['X@GRAD'] = [new_X_grad.name]
+        new_X_var_dist_attr = ctx.get_tensor_dist_attr_for_program(new_X_var)
+        ctx.set_tensor_dist_attr_for_program(new_X_grad, new_X_var_dist_attr)
+        # replicate op in dist program with new kwargs
+        dist_op_desc = main_block.desc.append_op()
+        dist_op_desc.copy_from(backward_op.desc)
+        # Refer to the related dist op
+        set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
+        for input_name in backward_op.desc.input_names():
+            dist_op_desc.set_input(input_name, new_kwargs[input_name])
+        for output_name in backward_op.desc.output_names():
+            dist_op_desc.set_output(output_name, new_kwargs[output_name])
+        p_norm_grad_op = Operator(main_block, dist_op_desc)
+        op_dist_attr.set_input_dims_mapping(new_X_var.name,
+                                            new_X_var_dist_attr.dims_mapping)
+        op_dist_attr.set_output_dims_mapping(new_X_grad.name,
+                                             new_X_var_dist_attr.dims_mapping)
+        ctx.set_op_dist_attr_for_program(p_norm_grad_op, op_dist_attr)
+        main_block._sync_with_cpp()
+
+        # 2. insert slice op
+        process_mesh_shape = op_dist_attr.process_mesh.topology
+        process_mesh_group = op_dist_attr.process_mesh.processes
+        dims_mapping = [0] + [-1 for _ in range(len(new_X_grad.shape) - 1)]
+        partition_idx = Resharder.compute_partition_index(
+            rank_id, new_X_grad.shape, dims_mapping, process_mesh_shape,
+            process_mesh_group)
+        slice_starts = []
+        slice_ends = []
+        slices_axes = []
+        for idx, item in enumerate(partition_idx):
+            slice_starts.append(item[0])
+            slice_ends.append(item[1])
+            slices_axes.append(idx)
+
+        infer_flags = list(1 for i in range(len(slices_axes)))
+        attrs = {
+            "axes": slices_axes,
+            "starts": slice_starts,
+            "ends": slice_ends,
+            "infer_flags": infer_flags,
+            "op_role": backward_op.attr('op_role')
+        }
+        slice_op = main_block.append_op(
+            type='slice',
+            inputs={'Input': [new_X_grad]},
+            outputs={'Out': [X_grad_var]},
+            attrs=attrs)
+        X_grad_var_dims_mapping = op_dist_attr.get_output_dims_mapping(
+            X_grad_var.name)
+        slice_op_dist_attr = OperatorDistributedAttribute()
+        slice_op_dist_attr.process_mesh = op_dist_attr.process_mesh
+        slice_op_dist_attr.set_input_dims_mapping(
+            new_X_grad.name, new_X_var_dist_attr.dims_mapping)
+        slice_op_dist_attr.set_output_dims_mapping(X_grad_var.name,
+                                                   X_grad_var_dims_mapping)
+        ctx.set_op_dist_attr_for_program(slice_op, slice_op_dist_attr)
+        main_block._sync_with_cpp()
+
+
+register_distributed_operator_impl("p_norm",
+                                   DistributedPNormImpl("row_parallel"))
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
index 46b1cffa54..3bbc15ca0c 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
@@ -18,6 +18,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
     py_test_modules(test_recorder MODULES test_recorder ENVS ${dist_ENVS})
     py_test_modules(test_trial MODULES test_trial ENVS ${dist_ENVS})
     py_test_modules(test_new_cost_model MODULES test_new_cost_model ENVS ${dist_ENVS})
+    py_test_modules(test_dist_pnorm MODULES test_dist_pnorm ENVS ${dist_ENVS})
     py_test_modules(test_dist_slice MODULES test_dist_slice ENVS ${dist_ENVS})
     py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS})
 endif()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py
new file mode 100644
index 0000000000..946f33b7e4
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle
+import paddle.distributed.auto_parallel as auto
+
+from paddle.fluid import program_guard
+from paddle.fluid.backward import append_backward
+from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
+
+paddle.enable_static()
+
+
+def make_program_dp2():
+    main_program = paddle.fluid.Program()
+    start_program = paddle.fluid.Program()
+    with paddle.static.program_guard(main_program, start_program):
+        x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
+        x.stop_gradient = False
+        auto.shard_tensor(
+            x,
+            dist_attr={
+                "process_mesh": auto.ProcessMesh([0, 1]),
+                "dims_mapping": [0, -1, -1]
+            })
+        tmp_0 = paddle.norm(x, p=2)
+    return main_program, start_program, tmp_0
+
+
+def make_program_serial():
+    main_program = paddle.fluid.Program()
+    start_program = paddle.fluid.Program()
+    with paddle.static.program_guard(main_program, start_program):
+        x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
+        x.stop_gradient = False
+        auto.shard_tensor(
+            x,
+            dist_attr={
+                "process_mesh": auto.ProcessMesh([0]),
+                "dims_mapping": [-1, -1, -1]
+            })
+        tmp_0 = paddle.norm(x, p=2)
+    return main_program, start_program, tmp_0
+
+
+def parallelizer(program_func, rank):
+    from paddle.distributed.auto_parallel.completion import Completer
+    from paddle.distributed.auto_parallel.partitioner import Partitioner
+    from paddle.distributed.auto_parallel.dist_context import DistributedContext
+
+    main_program, start_program, loss = program_func()
+
+    dist_context = DistributedContext()
+    completer = Completer(dist_context)
+    completer.complete_forward_annotation(main_program)
+    dist_context.block_state.parse_forward_blocks(main_program)
+
+    with program_guard(main_program, start_program):
+        params_grads = append_backward(
+            loss, distop_context=dist_context.dist_op_context)
+    completer.complete_backward_annotation(main_program)
+
+    dist_context.block_state.parse_backward_blocks(main_program)
+    partitioner = Partitioner(dist_context, rank)
+    dist_main_prog, _, _ = partitioner.partition(main_program, start_program,
+                                                 [])
+
+    return dist_main_prog, dist_context
+
+
+class TestDistPNorm(unittest.TestCase):
+    def test_dist_pnorm_dp2(self):
+
+        for rank in range(2):
+            dist_main_prog, dist_context = parallelizer(make_program_dp2, rank)
+            ops = dist_main_prog.global_block().ops
+            op_types = []
+            for op in ops:
+                op_types.append(op.type)
+                op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
+                if op.type == "p_norm":
+                    assert op_dist_attr.impl_type == "p_norm"
+                if op.type in ["p_norm", "p_norm_grad"]:
+                    for input_attr in op_dist_attr.inputs_dist_attrs.values():
+                        assert set(input_attr.dims_mapping) == set([-1])
+                    for output_attr in op_dist_attr.outputs_dist_attrs.values():
+                        assert set(output_attr.dims_mapping) == set([-1])
+                if op.type == 'c_allgather':
+                    for input_attr in op_dist_attr.inputs_dist_attrs.values():
+                        assert input_attr.dims_mapping[0] == 0
+                        assert set(input_attr.dims_mapping[1:]) == set([-1])
+                    for output_attr in op_dist_attr.outputs_dist_attrs.values():
+                        assert set(output_attr.dims_mapping) == set([-1])
+                if op.type == 'slice':
+                    for input_attr in op_dist_attr.inputs_dist_attrs.values():
+                        assert set(input_attr.dims_mapping) == set([-1])
+                    for output_attr in op_dist_attr.outputs_dist_attrs.values():
+                        assert output_attr.dims_mapping[0] == 0
+                        assert set(output_attr.dims_mapping[1:]) == set([-1])
+
+            assert op_types == [
+                "fill_constant", "barrier", "c_allgather", "p_norm",
+                "fill_constant", "p_norm_grad", "slice"
+            ]
+
+    def test_dist_pnorm_serial(self):
+        dist_main_prog, dist_context = parallelizer(make_program_serial, 0)
+        ops = dist_main_prog.global_block().ops
+        for op in ops:
+            op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
+            assert op_dist_attr.impl_type == "default"
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab
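
The row-parallel scheme implemented above can be sanity-checked outside Paddle. The sketch below is a minimal NumPy model of what the patch generates for dims_mapping [0, -1, -1]: each rank holds a row shard of x, the forward allgathers the shards before running the ordinary 2-norm, and the backward computes the dense gradient x / ||x||_2 on the gathered tensor and slices out the local rows. The rank count, shapes, and variable names here are illustrative assumptions, not part of the patch.

import numpy as np

# Hypothetical toy setup: 2 ranks, x of shape [4, 5, 6] sharded along dim 0.
np.random.seed(0)
nranks = 2
x_full = np.random.rand(4, 5, 6).astype("float32")
shards = np.split(x_full, nranks, axis=0)  # each rank's local shard

serial_out = np.linalg.norm(x_full)   # serial p_norm (p=2 over all elements)
serial_grad = x_full / serial_out     # d||x||_2 / dx

for rank in range(nranks):
    # forward: c_allgather the shards, then run the unmodified p_norm
    x_allgather = np.concatenate(shards, axis=0)
    out = np.linalg.norm(x_allgather)
    assert np.allclose(out, serial_out)

    # backward: p_norm_grad on the gathered tensor, then slice the local rows
    grad_full = x_allgather / out
    rows = x_full.shape[0] // nranks
    grad_local = grad_full[rank * rows:(rank + 1) * rows]
    assert np.allclose(grad_local, serial_grad[rank * rows:(rank + 1) * rows])

This mirrors the op sequence that test_dist_pnorm_dp2 asserts: fill_constant, barrier, c_allgather, p_norm in the forward pass and fill_constant, p_norm_grad, slice in the backward pass.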