From fdedf90921f4e4228915e9aa4507a8b690c6dfee Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Sat, 29 Jan 2022 11:58:59 +0800 Subject: [PATCH] Auto parallel/qkv fuse (#39080) * support qkv fuse * support qkv fuse * update completion * update completion * update dist_split * rerun ci * is_auto_compatible added * is_auto_compatible added --- .../distributed/auto_parallel/completion.py | 29 ++++- .../auto_parallel/operators/__init__.py | 1 + .../auto_parallel/operators/dist_split.py | 115 ++++++++++++++++++ .../paddle/distributed/auto_parallel/utils.py | 2 +- 4 files changed, 144 insertions(+), 3 deletions(-) create mode 100644 python/paddle/distributed/auto_parallel/operators/dist_split.py diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 54491f9e6c1..45ea9a3c9dd 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -325,10 +325,8 @@ class Completer: def complete_forward_annotation(self, serial_main_program): """ Complete annotation for the partial annotated serial_main_program. - Arguments: serial_main_program: partial annotated serial_main_program. - Returns: serial_main_program: completed annotated serial_main_program. """ @@ -443,6 +441,33 @@ class Completer: dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()]) assert forward_op is not None + if grad_op.type == "concat" and forward_op.type == "split": + forward_op_dist_attr = dist_context.get_op_dist_attr_for_program( + forward_op) + output_var = vars[grad_op.desc.output('Out')[0]] + split_input_var_name = forward_op.input("X")[0] + ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping( + split_input_var_name) + ref_mesh = forward_op_dist_attr.process_mesh + + grad_op_dist_attr = OperatorDistributedAttribute() + for input_name in grad_op.input_arg_names: + grad_op_dist_attr.set_input_dims_mapping( + input_name, ref_dims_mapping) + + output_var_dist_attr = TensorDistributedAttribute() + output_var_dist_attr.dims_mapping = ref_dims_mapping + output_var_dist_attr.process_mesh = ref_mesh + dist_context.set_tensor_dist_attr_for_program( + output_var, output_var_dist_attr) + + grad_op_dist_attr.set_output_dims_mapping(output_var.name, + ref_dims_mapping) + grad_op_dist_attr.process_mesh = ref_mesh + dist_context.set_op_dist_attr_for_program(grad_op, + grad_op_dist_attr) + continue + # op dist attr forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( forward_op) diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py index ea743df8d64..9f84df2d896 100644 --- a/python/paddle/distributed/auto_parallel/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -26,3 +26,4 @@ from . import dist_default from . import dist_eltwise from . import dist_check_finite_and_unscale from . import dist_update_loss_scaling +from . import dist_split diff --git a/python/paddle/distributed/auto_parallel/operators/dist_split.py b/python/paddle/distributed/auto_parallel/operators/dist_split.py new file mode 100644 index 00000000000..289da80e1a7 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_split.py @@ -0,0 +1,115 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .common import DistributedOperatorImplContainer +from .common import DistributedOperatorImpl +from .common import register_distributed_operator_impl_container +from .common import register_distributed_operator_impl +from ..utils import is_dim_shard +from ..utils import is_valid_list_index +from ..utils import compute_compatible_dim_mapping +from ..utils import compute_compatible_dims_mapping +from ..utils import compute_compatible_and_update_dim_mapping +from .dist_default import DistributedDefaultImpl0 + + +class DistributedSplit(DistributedOperatorImplContainer): + def __init__(self, op_type): + super(DistributedSplit, self).__init__(op_type) + + +register_distributed_operator_impl_container(DistributedSplit("split")) + + +class DistributedSplitImpl(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedSplitImpl, self).__init__(name) + self._forward_implemented = True + self._backward_implemented = True + + def is_input_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + x_name = op_desc.input('X')[0] + axis = op_desc.attr('axis') + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + + if is_dim_shard(x_dims_mapping[axis]): + return False + + return True + + def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + out_names = op_desc.output('Out') + axis = op_desc.attr('axis') + for out_name in out_names: + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + if is_dim_shard(out_dims_mapping[axis]): + return False + + return True + + def is_compatible(self, dist_op): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): + return False + + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + x_name = op_desc.input('X')[0] + axis = op_desc.attr('axis') + out_names = op_desc.output('Out') + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + for out_name in out_names: + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + if x_dims_mapping != out_dims_mapping: + return False + + return True + + def update_dims_mapping(self, dist_op): + changed = False + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + x_name = op_desc.input('X')[0] + out_names = op_desc.output('Out') + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + + for out_name in out_names: + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + for i in range(len(x_dims_mapping)): + dim_changed = compute_compatible_and_update_dim_mapping( + [x_dims_mapping, out_dims_mapping], [i, i]) + if dim_changed: + changed = True + + return changed + + def is_auto_compatible(self, dist_op): + raise NotImplementedError( + "Auto Search is not supported by dist split yet.") + + @staticmethod + def forward(ctx, *args, **kwargs): + DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + + @staticmethod + def backward(ctx, *args, **kwargs): + DistributedDefaultImpl0.backward(ctx, *args, **kwargs) + + +register_distributed_operator_impl("split", + DistributedSplitImpl("replicate_in_axis")) diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index f81291fa64f..75e0ae251ef 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1002,7 +1002,7 @@ def set_grad_var_shape(program, dist_context): if op.type in ["check_finite_and_unscale", "update_loss_scaling"]: break - if op.type in ["sum"]: + if op.type in ["sum", "concat"]: continue if int(op.attr('op_role')) == int(OpRole.Backward): op_dist_attr = dist_context.get_op_dist_attr_for_program(op) -- GitLab