From 797bd40d093189ce3c9f24fcd0f59bbe2878b2ca Mon Sep 17 00:00:00 2001 From: JZ-LIANG <38102074+JZ-LIANG@users.noreply.github.com> Date: Wed, 20 Oct 2021 10:23:35 +0800 Subject: [PATCH] [Auto Parallel] Generalization for Partition and Completion (#35735) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * default dist op * add dist_attr for dist op * add unitest * update inputname * update function name * add unitest * update CMakeLists.txt for CI * fix dis_matmul * fix compile error * update matmul to matmul_v2 * unify api * unify api * todo * update distop forward func * update distop forward func * auto parallel backward * update dist op * autoparallel backward * add backward for embedding * temp1 * temp2 * temp3 * temp4 * backward done1 * backward done2 * backward done3 * dist embedding remove mp mode * dist matmul remove mp mode * update dist embedding 『 * dist op init1 * dist op init 2 * update unitest * context remove parallel mode * partitioner remove parallel mode * update unitest * a more general method to support varying mesh in pipeline parallel * support varying mesh in pipeline parallel * embedding support varying mesh in pipeline parallel * matmul support varying mesh in pipeline parallel * default dist op support varying mesh in pipeline parallel * dist attribute for startup program * default dist op support varying mesh in pipeline parallel 2 * partitoner support varying mesh in pipeline parallel * revise logic for auto compeletion * revise framework.py * revise reshard unitest * revise unitest for parallelize * chmod * fixed bug for dist embedding name mapping Co-authored-by: zhaoyingli --- .../distributed/auto_parallel/completion.py | 269 +++--- .../distributed/auto_parallel/context.py | 125 ++- .../auto_parallel/operators/__init__.py | 1 + .../auto_parallel/operators/common.py | 6 +- .../auto_parallel/operators/dist_default.py | 247 +++++ .../auto_parallel/operators/dist_embedding.py | 331 ++++--- .../auto_parallel/operators/dist_matmul.py | 911 +++++++++++------- .../auto_parallel/operators/dist_reshape.py | 288 +++--- .../auto_parallel/operators/dist_softmax.py | 6 + .../auto_parallel/operators/dist_transpose.py | 6 + .../distributed/auto_parallel/parallelizer.py | 4 +- .../distributed/auto_parallel/partitioner.py | 414 ++++---- .../paddle/distributed/auto_parallel/utils.py | 45 +- python/paddle/fluid/backward.py | 13 +- .../fluid/tests/unittests/CMakeLists.txt | 3 + .../unittests/auto_parallel_parallelizer.py | 140 +++ .../test_auto_parallel_parallelizer.py | 126 +-- .../test_auto_parallel_partitioner.py | 100 +- .../test_auto_parallel_partitioner_gpt.py | 30 +- .../unittests/test_auto_parallel_reshard.py | 7 +- .../test_auto_parallel_reshard_dpmppp.py | 2 - .../test_auto_parallel_reshard_mppp.py | 2 - 22 files changed, 1896 insertions(+), 1180 deletions(-) create mode 100755 python/paddle/distributed/auto_parallel/operators/dist_default.py mode change 100644 => 100755 python/paddle/distributed/auto_parallel/operators/dist_embedding.py create mode 100755 python/paddle/fluid/tests/unittests/auto_parallel_parallelizer.py diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 3fdbad6950d..855eb656bd9 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -24,6 +24,7 @@ from .utils import print_program_with_distributed_attr from .context import get_default_distributed_context from 
.operators import find_best_compatible_distributed_operator_impl from .attribute import OperatorDistributedAttribute, TensorDistributedAttribute +from paddle.distributed.fleet.meta_optimizers.common import OpRole ELEMENTWISE_LIKE_OP_LIST = ["elementwise_add", "gelu", "dropout", "cast"] @@ -600,7 +601,7 @@ def complete_annotation(program, dist_context=None): return program -def complete_backward_annotation(auto_parallel_main_prog, dist_context): +def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): """Complete the annotation of vars and ops in the backward phase for parallel program.""" def _is_grad_var_name(name): @@ -608,24 +609,44 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context): return True return False - grad_start_idx = None + def _get_forward_varname_from_grad_varname(grad_var_name): + assert _is_grad_var_name( + grad_var_name), "[{}] is not a grad varnme.".format(grad_var_name) + return grad_var_name[:grad_var_name.find("@GRAD")] + + def _get_op_by_id(ops, id): + for op in ops: + if op.desc.id() == id: + return op + return None + + if dist_context is None: + dist_context = get_default_distributed_context() + + grad_start_idx = -1 for idx, op in enumerate(auto_parallel_main_prog.global_block().ops): - for var_name in op.output_arg_names: - # TODO: use _is_loss_op to judge - if "@GRAD" in var_name and op.type == "fill_constant": - grad_start_idx = idx - break - assert grad_start_idx is not None, "No backward procedure found in this program." + if int(op.attr('op_role')) == int( + int(core.op_proto_and_checker_maker.OpRole.Backward) | int( + core.op_proto_and_checker_maker.OpRole.Loss)): + assert op.type == "fill_constant" + grad_start_idx = idx + break + + assert grad_start_idx >= 0, "No backward procedure found in this program." ops = list(auto_parallel_main_prog.global_block().ops) vars = auto_parallel_main_prog.global_block().vars + for idx in range(grad_start_idx, len(ops)): - # complete the loss op + + # complete the initial grad loss op if idx == grad_start_idx: grad_var = vars[ops[idx].output_arg_names[0]] - grad_var_name = grad_var.name - forward_var_name = grad_var_name[:grad_var_name.find("@GRAD")] + forward_var_name = _get_forward_varname_from_grad_varname( + grad_var.name) forward_var = vars[forward_var_name] + + # TODO complete other attribte for grad var tensor_attr = TensorDistributedAttribute(grad_var, dist_context) process_mesh = dist_context.get_tensor_distributed_attr_for_program( forward_var).get_process_mesh() @@ -635,39 +656,31 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context): tensor_attr.set_process_mesh(process_mesh) dist_context.set_tensor_distributed_attr_for_program(grad_var, tensor_attr) + op_attr = OperatorDistributedAttribute(ops[idx], dist_context) op_attr.set_process_mesh(process_mesh) dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr) - - # in the data parallel mode, the loss op followed by scale op. - if ops[idx + 1].type == "scale" and grad_var_name in ops[idx + 1].input_arg_names \ - and grad_var_name in ops[idx + 1].output_arg_names: - op_attr = OperatorDistributedAttribute(ops[idx + 1], - dist_context) - op_attr.set_process_mesh(process_mesh) - dist_context.set_op_distributed_attr_for_program(ops[idx + 1], - op_attr) continue - # complete the annotation of the optimizer op. 
- # TODO: use _is_optimizer_op to judge - if "Grad" in ops[idx].input_names and "Param" in ops[idx].input_names: - assert len(ops[idx].input( - "Param")) == 1, "Only support one-to-one now." - assert len(ops[idx].input( - "Grad")) == 1, "Only support one-to-one now." - var = vars[ops[idx].input("Param")[0]] - grad_var = vars[ops[idx].input("Grad")[0]] + # TODO remove this when dist op handle its own grad scale + # in the data parallel mode, the loss op followed by scale op. + if ops[idx].type == "scale" and idx == grad_start_idx + 1: + assert grad_var.name in ops[ + idx].input_arg_names and grad_var.name in ops[ + idx].output_arg_names + grad_var = vars[ops[idx].output_arg_names[0]] + forward_var_name = _get_forward_varname_from_grad_varname( + grad_var.name) + forward_var = vars[forward_var_name] process_mesh = dist_context.get_tensor_distributed_attr_for_program( - var).get_process_mesh() - dims_mapping = dist_context.get_tensor_distributed_attr_for_program( - var).get_dims_mapping() + forward_var).get_process_mesh() op_attr = OperatorDistributedAttribute(ops[idx], dist_context) op_attr.set_process_mesh(process_mesh) - op_attr.set_input_dims_mapping(grad_var.name, dims_mapping) dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr) continue + # TODO remove this when dist op handle its own communication + # TODO should distinguish the dp allreduce and mp allreduce # complete the c_allreduce_sum op for gradient in the data parallel mode. if ops[idx].type == "c_allreduce_sum" and ops[ idx].input_arg_names == ops[idx].output_arg_names: @@ -679,91 +692,123 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context): dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr) continue - # complete the annotation of grad op + # complete the annotation of grad op (xxx_grad op or sum op) grad_op = ops[idx] - for i, op in enumerate(ops[:grad_start_idx]): - match_op = None - grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(op.desc, - set(), - []) - grad_op_input = [] - for input_arg_name in grad_op.desc.input_arg_names(): - if "@GRAD" in input_arg_name: - name = input_arg_name[:input_arg_name.find("@GRAD") + 5] - grad_op_input.append(name) - else: - grad_op_input.append(input_arg_name) - - # like sum op: the count of grad op will larger than 1 - if len(grad_op_desc_list) > 1: - for grad_op_desc in grad_op_desc_list: - if grad_op_input == grad_op_desc.input_arg_names() \ - and grad_op.desc.type() == grad_op_desc.type(): - match_op = op - break - elif len(grad_op_desc_list) == 1: - if grad_op_input == grad_op_desc_list[0].input_arg_names() \ - and grad_op.desc.type() == grad_op_desc_list[0].type(): - match_op = op - - if match_op is not None: - op_attr = dist_context.get_op_distributed_attr_for_program(op) - grad_op_attr = OperatorDistributedAttribute(grad_op, - dist_context) - grad_op_attr.set_process_mesh(op_attr.get_process_mesh()) - for var_name in grad_op.input_arg_names: - if "@GRAD" in var_name: - dims_mapping = dist_context.get_tensor_distributed_attr_for_program( - vars[var_name]).get_dims_mapping() - grad_op_attr.set_input_dims_mapping(var_name, - dims_mapping) - else: - dims_mapping = op_attr.get_input_dims_mapping(var_name) - grad_op_attr.set_input_dims_mapping(var_name, - dims_mapping) - dist_context.set_op_distributed_attr_for_program(grad_op, - grad_op_attr) - - for var_name in grad_op.output_arg_names: - if "@GRAD" in var_name: - forward_var = vars[var_name[:var_name.find("@GRAD")]] - tensor_attr = TensorDistributedAttribute(vars[var_name], 
- dist_context) - process_mesh = grad_op_attr.get_process_mesh() - dims_mapping = grad_op_attr.get_input_dims_mapping( - forward_var.name) - tensor_attr.set_process_mesh(process_mesh) - tensor_attr.set_dims_mapping(dims_mapping) - dist_context.set_tensor_distributed_attr_for_program( - vars[var_name], tensor_attr) - break - - # complete the annotation of sum op for multiple renamed grad var - if grad_op.type == "sum" and all( - map(_is_grad_var_name, grad_op.input_arg_names)): - assert len(grad_op.output_arg_names - ) == 1, "The output count of sum op should be one." + + # xxx_grad op will have a corresponding forward op in gradopidx2opidx + dist_op_helper = dist_context.get_dist_op_helper() + if grad_op.desc.id() in dist_op_helper.gradopidx2opidx: + # TODO support the case where one forward op corresponding to multiple xxx_grad op + forward_op = _get_op_by_id( + ops[:grad_start_idx], + dist_op_helper.gradopidx2opidx[grad_op.desc.id()]) + assert forward_op is not None + + # op dist attr + forward_op_attr = dist_context.get_op_distributed_attr_for_program( + forward_op) grad_op_attr = OperatorDistributedAttribute(grad_op, dist_context) + grad_op_attr.set_process_mesh(forward_op_attr.get_process_mesh()) + for var_name in grad_op.input_arg_names: if "@GRAD" in var_name: - forward_var = vars[var_name[:var_name.find("@GRAD")]] dims_mapping = dist_context.get_tensor_distributed_attr_for_program( - forward_var).get_dims_mapping() + vars[var_name]).get_dims_mapping() + grad_op_attr.set_input_dims_mapping(var_name, dims_mapping) + else: + dims_mapping = forward_op_attr.get_input_dims_mapping( + var_name) + # TODO fixed here + if dims_mapping == None: + dims_mapping = forward_op_attr.get_output_dims_mapping( + var_name) + assert dims_mapping is not None, "[{}]'s dims_mapping is None".format( + var_name) grad_op_attr.set_input_dims_mapping(var_name, dims_mapping) + dist_context.set_op_distributed_attr_for_program(grad_op, + grad_op_attr) + # var dist attr for var_name in grad_op.output_arg_names: - forward_var = vars[var_name[:var_name.find("@GRAD")]] - tensor_attr = TensorDistributedAttribute(vars[var_name], - dist_context) - process_mesh = dist_context.get_tensor_distributed_attr_for_program( - forward_var).get_process_mesh() - dims_mapping = dist_context.get_tensor_distributed_attr_for_program( - forward_var).get_dims_mapping() - tensor_attr.set_dims_mapping(dims_mapping) - tensor_attr.set_process_mesh(process_mesh) - dist_context.set_tensor_distributed_attr_for_program( - vars[var_name], tensor_attr) - grad_op_attr.set_process_mesh( - dist_context.get_tensor_distributed_attr_for_program( - forward_var).get_process_mesh()) + if _is_grad_var_name(var_name): + + forward_var_name = _get_forward_varname_from_grad_varname( + var_name) + forward_var = vars[forward_var_name] + tensor_attr = TensorDistributedAttribute(vars[var_name], + dist_context) + process_mesh = grad_op_attr.get_process_mesh() + dims_mapping = grad_op_attr.get_input_dims_mapping( + forward_var_name) + tensor_attr.set_process_mesh(process_mesh) + tensor_attr.set_dims_mapping(dims_mapping) + dist_context.set_tensor_distributed_attr_for_program( + vars[var_name], tensor_attr) + + # only sum op for merge mutiple version grad has no a corresponding mapping in gradopidx2opidx + else: + assert grad_op.type == "sum", "got unexpect op [{}]".format( + str(grad_op.type)) + assert all(map(_is_grad_var_name, grad_op.input_arg_names)) + assert len(grad_op.output_arg_names) == 1 + + ref_forward_var_name = _get_forward_varname_from_grad_varname( + 
grad_op.output_arg_names[0]) + forward_var = vars[ref_forward_var_name] + ref_forward_var_dims_mapping = dist_context.get_tensor_distributed_attr_for_program( + forward_var).get_dims_mapping() + ref_forward_var_process_mesh = dist_context.get_tensor_distributed_attr_for_program( + forward_var).get_process_mesh() + + # output + tensor_attr = TensorDistributedAttribute( + vars[grad_op.output_arg_names[0]], dist_context) + tensor_attr.set_dims_mapping(ref_forward_var_dims_mapping) + tensor_attr.set_process_mesh(ref_forward_var_process_mesh) + dist_context.set_tensor_distributed_attr_for_program( + vars[grad_op.output_arg_names[0]], tensor_attr) + + # op + grad_op_attr = OperatorDistributedAttribute(grad_op, dist_context) + grad_op_attr.set_process_mesh(ref_forward_var_process_mesh) + for var_name in grad_op.input_arg_names: + assert _get_forward_varname_from_grad_varname( + var_name) == ref_forward_var_name + grad_op_attr.set_input_dims_mapping( + var_name, ref_forward_var_dims_mapping) dist_context.set_op_distributed_attr_for_program(grad_op, grad_op_attr) + + +def complete_update_annotation(auto_parallel_main_prog, dist_context): + """Complete the annotation of vars and ops in the update phase for parallel program.""" + + if dist_context is None: + dist_context = get_default_distributed_context() + + ops = list(auto_parallel_main_prog.global_block().ops) + vars = auto_parallel_main_prog.global_block().vars + + for idx in range(len(ops)): + + # complete the annotation of the optimizer op. + # TODO to add attribute for moment var + if int(ops[idx].attr('op_role')) == int(OpRole.Optimize): + if "Grad" in ops[idx].input_names and "Param" in ops[ + idx].input_names: + assert len(ops[idx].input( + "Param")) == 1, "Only support one-to-one now." + assert len(ops[idx].input( + "Grad")) == 1, "Only support one-to-one now." 
+ param = vars[ops[idx].input("Param")[0]] + grad_var = vars[ops[idx].input("Grad")[0]] + process_mesh = dist_context.get_tensor_distributed_attr_for_program( + param).get_process_mesh() + dims_mapping = dist_context.get_tensor_distributed_attr_for_program( + param).get_dims_mapping() + op_attr = OperatorDistributedAttribute(ops[idx], dist_context) + op_attr.set_process_mesh(process_mesh) + op_attr.set_input_dims_mapping(grad_var.name, dims_mapping) + op_attr.set_input_dims_mapping(param.name, dims_mapping) + dist_context.set_op_distributed_attr_for_program(ops[idx], + op_attr) + continue diff --git a/python/paddle/distributed/auto_parallel/context.py b/python/paddle/distributed/auto_parallel/context.py index 5e6565aa3d8..6785f21351a 100644 --- a/python/paddle/distributed/auto_parallel/context.py +++ b/python/paddle/distributed/auto_parallel/context.py @@ -51,23 +51,8 @@ class DistributedContext: self._op_distributed_attr_map_for_program = {} self._tensor_distributed_attr_map_for_graph = {} self._op_distributed_attr_map_for_graph = {} - # The following is a hard code and will be removed in the future - self._data_parallel_axis = None - self._model_parallel_axis = None + self._get_dist_op_helper = DistOpHelper() self._process_mesh = _g_process_mesh_map.get(0, None) - if self._process_mesh is not None: - if self._process_mesh.ndim == 1: - self._data_parallel_axis = 0 - self._model_parallel_axis = 0 - elif self._process_mesh.ndim == 3: - self._data_parallel_axis = 1 - self._model_parallel_axis = 2 - else: - self._data_parallel_axis = 0 - self._model_parallel_axis = 1 - else: - self._data_parallel_axis = -1 - self._model_parallel_axis = -1 def is_initialized_for_program(self): return self._is_initialized_for_program @@ -120,16 +105,9 @@ class DistributedContext: def set_process_mesh(self, process_mesh): self._process_mesh = process_mesh - if self._process_mesh is not None: - if self._process_mesh.ndim == 1: - self._data_parallel_axis = 0 - self._model_parallel_axis = 0 - else: - self._data_parallel_axis = 0 - self._model_parallel_axis = 1 - else: - self._data_parallel_axis = -1 - self._model_parallel_axis = -1 + + def get_dist_op_helper(self): + return self._get_dist_op_helper def initialize_distributed_attr_for_program(self, program): if self._is_initialized_for_program: @@ -425,10 +403,93 @@ class DistributedContext: and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: dims_mapping[i] = -1 - def _get_data_parallel_info(self): - # This function is a hard code, and will be obsoleted in the future - return self._data_parallel_axis, self._process_mesh - def _get_model_parallel_info(self): - # This function is a hard code, and will be obsoleted in the future - return self._model_parallel_axis, self._process_mesh +class DistOpHelper: + """ + DistOpHelper is used to create a dist op desc in Program. + Every time to create a new dist op, the context should be updated for it accordingly. 
+ """ + + def __init__(self): + self._dst_main_program = None + self._dst_startup_program = None + self._varname_mapping = None + self._rank_id = None + self._cur_src_op = None + self._cur_dist_attr = None + self.gradopidx2opidx = {} + self.already_init_sync_vars = set() + + def set_dst_main_program(self, prog): + self._dst_main_program = prog + + def get_dst_main_program(self): + return self._dst_main_program + + def set_dst_startup_program(self, prog): + self._dst_startup_program = prog + + def get_dst_startup_program(self): + return self._dst_startup_program + + def set_varname_mapping(self, mapping): + self._varname_mapping = mapping + + def get_varname_mapping(self): + return self._varname_mapping + + def set_rank_id(self, rank_id): + self._rank_id = rank_id + + def get_rank_id(self): + return self._rank_id + + def set_cur_src_op(self, cur_src_op): + self._cur_src_op = cur_src_op + + def get_cur_src_op(self): + return self._cur_src_op + + def prepare_forward_context(self, src_op): + + self.set_cur_src_op(src_op) + + # build input varname mapping + kinputs = {} + for input_name in src_op.desc.input_names(): + varnames = [] + for varname in src_op.desc.input(input_name): + varnames.append(self._varname_mapping[varname]) + kinputs[input_name] = varnames + + # build output varname mapping + koutputs = {} + for output_name in src_op.desc.output_names(): + varnames = [] + for varname in src_op.desc.output(output_name): + varnames.append(self._varname_mapping[varname]) + koutputs[output_name] = varnames + + return kinputs, koutputs + + def prepare_backward_context(self, backward_op): + + self.set_cur_src_op(backward_op) + + # build input varname mapping + kinputs = {} + for input_name in backward_op.desc.input_names(): + varnames = [] + for varname in backward_op.desc.input(input_name): + varnames.append(varname) + kinputs[input_name] = varnames + + # build output varname mapping + koutputs = {} + for output_name in backward_op.desc.output_names(): + varnames = [] + for varname in backward_op.desc.output(output_name): + varnames.append(varname) + koutputs[output_name] = varnames + + return kinputs, koutputs diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py index 14ded477cb7..3b3359b4ebf 100644 --- a/python/paddle/distributed/auto_parallel/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -22,3 +22,4 @@ from . import dist_matmul from . import dist_reshape from . import dist_softmax from . import dist_transpose +from . 
import dist_default diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index 1b0b05d3954..5685c40a322 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -36,10 +36,12 @@ class DistributedOperatorImpl: self._forward_implemented = False self._backward_implemented = False - def forward(self, dist_ctx, *args, **kwargs): + @staticmethod + def forward(dist_ctx, *args, **kwargs): raise NotImplementedError("Please Implement this method in Subclass.") - def backward(self, dist_ctx, *grad_outputs): + @staticmethod + def backward(dist_ctx, *grad_outputs, **kwargs): raise NotImplementedError("Please Implement this method in Subclass.") def get_name(self): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py new file mode 100755 index 00000000000..cf17b7afb0f --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -0,0 +1,247 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from .common import DistributedOperator +from .common import DistributedOperatorImpl +from .common import register_distributed_operator +from .common import register_distributed_operator_impl +from ..utils import is_dim_shard +from ..utils import is_dim_replicate +from ..utils import is_valid_list_index +from ..utils import compute_compatible_dim_mapping +from ..utils import compute_compatible_dims_mapping +from ..utils import compute_compatible_and_update_dim_mapping +from ..attribute import OperatorDistributedAttribute +from paddle.fluid import core, unique_name +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.framework import Program, Parameter, Variable, program_guard +from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY +from ..process import new_process_group +from ..utils import _get_comm_group, _get_corresponding_rank + + +class DistributedDefault(DistributedOperator): + def __init__(self, name): + super(DistributedDefault, self).__init__() + self._name = name + + +register_distributed_operator("default", DistributedDefault("default")) + + +# Replicated Default +class DistributedDefaultImpl0(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedDefaultImpl0, self).__init__() + self._name = name + self._forward_implemented = True + self._backward_implemented = True + + def is_process_mesh_compatible(self, op_dist_attr): + raise NotImplementedError("Please Implement this method.") + + def is_input_compatible(self, op_dist_attr): + raise NotImplementedError("Please Implement this method.") + + def is_output_compatible(self, op_dist_attr): + raise NotImplementedError("Please Implement this method.") + + def 
update_dims_mapping(self, op_dist_attr): + raise NotImplementedError("Please Implement this method.") + + @staticmethod + def forward(ctx, *args, **kwargs): + + dist_op_helper = ctx.get_dist_op_helper() + main_block = dist_op_helper.get_dst_main_program().global_block() + startup_block = dist_op_helper.get_dst_startup_program().global_block() + src_op = dist_op_helper.get_cur_src_op() + varname_mapping = dist_op_helper.get_varname_mapping() + rank_id = dist_op_helper.get_rank_id() + + # check validation of inputs / outputs + for input_name in src_op.desc.input_names(): + assert input_name in kwargs, "input [{}] is not given".format( + input_name) + assert len(kwargs[input_name]) == len( + src_op.desc.input(input_name) + ), "number of tensor for input [{}] is not match".format(input_name) + for output_name in src_op.desc.output_names(): + assert output_name in kwargs, "input [{}] is not given".format( + output_name) + assert len(kwargs[output_name]) == len( + src_op.desc.output(output_name) + ), "number of tensor for input [{}] is not match".format( + output_name) + + # replicate op in dist program + dist_op_desc = main_block.desc.append_op() + dist_op_desc.copy_from(src_op.desc) + for input_name in src_op.desc.input_names(): + dist_op_desc.set_input(input_name, kwargs[input_name]) + for output_name in src_op.desc.output_names(): + dist_op_desc.set_output(output_name, kwargs[output_name]) + + main_block._sync_with_cpp() + + # param initialization sync + for varname in dist_op_desc.input_arg_names(): + if startup_block.has_var(varname) and startup_block.var( + varname + ).is_parameter and varname not in dist_op_helper.already_init_sync_vars: + dist_op_helper.already_init_sync_vars.add(varname) + param = startup_block.var(varname) + param_dist_attr = ctx.get_tensor_distributed_attr_for_program( + param) + process_mesh = param_dist_attr.get_process_mesh() + dims_mapping = param_dist_attr.get_dims_mapping() + + # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism + if rank_id not in process_mesh.process_group: + rank_id = _get_corresponding_rank(process_mesh, rank_id) + + # NOTE all not splited axis should be presented in mesh + for axis, size in enumerate(process_mesh.topology): + if size <= 1 or axis in dims_mapping: + pass + else: + group_ranks = _get_comm_group( + process_mesh.process_group, process_mesh.topology, + axis, rank_id) + sync_group = new_process_group(group_ranks) + + new_op = startup_block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': sync_group.id, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Forward + }) + + # set distributed attribute + op_attr = OperatorDistributedAttribute(new_op, ctx) + op_attr.set_process_mesh(process_mesh) + op_attr.set_output_dims_mapping(param.name, + dims_mapping) + op_attr.set_input_dims_mapping(param.name, dims_mapping) + ctx.set_op_distributed_attr_for_program(new_op, op_attr) + + startup_block._sync_with_cpp() + + @staticmethod + def backward(ctx, *args, **kwargs): + + # by now the backward function only insert the gradient allreduce for dist op itself + dist_op_helper = ctx.get_dist_op_helper() + main_block = dist_op_helper.get_dst_main_program().global_block() + backward_op = dist_op_helper.get_cur_src_op() + dist_attr = ctx.get_op_distributed_attr_for_program(backward_op) + assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format( + str(backward_op)) + rank_id = dist_op_helper.get_rank_id() + + # check if 
need gradient allreduce + # if there is a non-gradient & non-parameter input and its batch dimension is splited, + # we need insert gradient allreduce for the gradient of parameter in its output + need_gradient_allreduce = False + for input_name in backward_op.desc.input_names(): + for varname in backward_op.desc.input(input_name): + if "@GRAD" not in varname and not main_block.var( + varname).is_parameter: + + # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op + process_mesh = dist_attr.get_process_mesh() + var_dim_mapping = dist_attr.get_input_dims_mapping(varname) + + # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism + if rank_id not in process_mesh.process_group: + rank_id = _get_corresponding_rank(process_mesh, rank_id) + + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: + need_gradient_allreduce = True + group_ranks = _get_comm_group( + process_mesh.process_group, process_mesh.topology, + batch_size_axis, rank_id) + dp_degree = len(group_ranks) + dp_group = new_process_group(group_ranks) + break + + if need_gradient_allreduce: + allreduce_vars = [] + for input_name in backward_op.desc.input_names(): + for varname in backward_op.desc.input(input_name): + if "@GRAD" not in varname and main_block.var( + varname).is_parameter: + assert len( + backward_op.desc.input(input_name) + ) == 1, "parameter input to grad op should be length 1, but got [{}]".format( + backward_op.desc.input(input_name)) + + assert varname + "@GRAD" in backward_op.desc.output_arg_names( + ), "parameter's grad [{}] not found in the grad op's output".format( + varname + "@GRAD") + assert len( + backward_op.desc.output(input_name + "@GRAD") + ) == 1, "parameter grad of grad op should be length 1, but got [{}]".format( + backward_op.desc.output(input_name + "@GRAD")) + allreduce_vars.append( + backward_op.desc.output(input_name + "@GRAD")[0]) + + if len(allreduce_vars) > 0: + + for varname in allreduce_vars: + + grad_var = main_block.var(varname) + allreduce_op = main_block.append_op( + type='c_allreduce_sum', + inputs={'X': [grad_var]}, + outputs={'Out': [grad_var]}, + attrs={ + 'ring_id': dp_group.id, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Backward + }) + + scale_op = main_block.append_op( + type='scale', + inputs={'X': grad_var}, + outputs={'Out': grad_var}, + attrs={ + 'scale': 1.0 / dp_degree, + OP_ROLE_KEY: OpRole.Backward + }) + + dims_mapping = ctx.get_tensor_distributed_attr_for_program( + grad_var).get_dims_mapping() + process_mesh = dist_attr.get_process_mesh() + for op in [allreduce_op, scale_op]: + op_attr = OperatorDistributedAttribute(op, ctx) + op_attr.set_process_mesh(process_mesh) + op_attr.set_output_dims_mapping(grad_var.name, + dims_mapping) + op_attr.set_input_dims_mapping(grad_var.name, + dims_mapping) + ctx.set_op_distributed_attr_for_program(op, op_attr) + + main_block._sync_with_cpp() + + +register_distributed_operator_impl( + "default", DistributedDefaultImpl0("replicate_parallel")) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py old mode 100644 new mode 100755 index 3f8fbf9cc3a..cd6d2255c81 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -24,12 +24,14 @@ from ..utils import 
is_valid_list_index from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping +from ..attribute import OperatorDistributedAttribute from paddle.fluid import core, unique_name from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from ..process import new_process_group -from ..utils import _get_comm_group +from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank class DistributedEmbedding(DistributedOperator): @@ -40,6 +42,7 @@ class DistributedEmbedding(DistributedOperator): register_distributed_operator("lookup_table_v2", DistributedEmbedding("embedding")) +register_distributed_operator("c_embedding", DistributedEmbedding("embedding")) # RowParallel @@ -48,7 +51,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): super(DistributedEmbeddingImpl, self).__init__() self._name = name self._forward_implemented = True - self._backward_implemented = False + self._backward_implemented = True def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. """ @@ -102,127 +105,231 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): return changed - def forward(self, serial_op): - def static_handle(dst_block, - src_op, - op_dist_attr, - input_name_mapping, - output_name_mapping, - rank_id=0): - assert len( - input_name_mapping - ) == 2, "row_parallel_embedding take 2 inputs variable but got {}".format( - input_name_mapping) - assert len( - output_name_mapping - ) == 1, "row_parallel_embedding take 2 inputs variable but got {}".format( - output_name_mapping) - assert len( - input_name_mapping['Ids'] - ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( - input_name_mapping['Ids']) - assert len( - input_name_mapping['W'] - ) == 1, "row_parallel_embedding input W take 1 variable but got {}".format( - input_name_mapping['W']) - assert len( - output_name_mapping['Out'] - ) == 1, "row_parallel_embedding input Out take 1 variable but got {}".format( - input_name_mapping['Out']) - - Ids_var = dst_block.var(input_name_mapping['Ids'][0]) - Weight_var = dst_block.var(input_name_mapping['W'][0]) - Out_var = dst_block.var(output_name_mapping['Out'][0]) - - # got dist attribute info - embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping( - Weight_var.name)[0] - process_mesh_shape = op_dist_attr.get_process_mesh().topology - process_mesh_group = op_dist_attr.get_process_mesh().process_group - - # caculate embedding offset - # TODO generalize here, using cartisian product to allow any dimensional mesh shape - mesh_shape = len(process_mesh_shape) - assert mesh_shape <= 2, "row_parallel_embedding only support 1 or 2 dimensional process mesh, but got {}".format( - process_mesh_shape) - num_partition = process_mesh_shape[embedding_row_dim_mapping] - # TODO generalize here, support any mesh group - model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( - )._get_model_parallel_info() - if mesh_shape == 1: - if rank_id not in process_mesh_group: - assert len( - process_mesh.topology - ) == 2, " row_parallel_embedding process mapping only support 2 dimensional process mesh, \ - but got {}".format(len(process_mesh.topology)) - rank_id = process_mesh_group[ - 
process_mesh.process_group.index(rank_id) % - process_mesh_shape[0]] - relative_idx = process_mesh_group.index(rank_id) + @staticmethod + def forward(ctx, *args, **kwargs): + """ + kwargs: inputname_mapping & outputname_mapping + """ + + dist_op_helper = ctx.get_dist_op_helper() + main_block = dist_op_helper.get_dst_main_program().global_block() + startup_block = dist_op_helper.get_dst_startup_program().global_block() + src_op = dist_op_helper.get_cur_src_op() + rank_id = dist_op_helper.get_rank_id() + op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) + assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( + str(src_op)) + + # check validation of inputs / outputs + assert 'Ids' in kwargs, "input [{}] is not given".format('Ids') + assert 'W' in kwargs, "input [{}] is not given".format('W') + assert 'Out' in kwargs, "output [{}] is not given".format('Out') + + assert len( + kwargs['Ids'] + ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( + kwargs['Ids']) + assert len( + kwargs['W'] + ) == 1, "row_parallel_embedding input W take 1 variable but got {}".format( + kwargs['W']) + assert len( + kwargs['Out'] + ) == 1, "row_parallel_embedding output Out take 1 variable but got {}".format( + kwargs['Out']) + + Ids_var = main_block.var(kwargs['Ids'][0]) + Weight_var = main_block.var(kwargs['W'][0]) + Out_var = main_block.var(kwargs['Out'][0]) + + # got dist attribute info + embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping( + Weight_var.name)[0] + assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format( + embedding_row_dim_mapping) + process_mesh_shape = op_dist_attr.get_process_mesh().topology + process_mesh_group = op_dist_attr.get_process_mesh().process_group + + # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism + if rank_id not in process_mesh_group: + rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(), + rank_id) + + # A generalized method to caculate embedding offset using cartisian product + relative_idx = _get_idx_in_axis(process_mesh_group, process_mesh_shape, + embedding_row_dim_mapping, rank_id) + + per_part_size = Weight_var.shape[0] + relative_idx = relative_idx * per_part_size + + # TODO caculate ring id + parallel_axis = embedding_row_dim_mapping + group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, + parallel_axis, rank_id) + group = new_process_group(group_ranks) + + # append op + check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'], + 'c_embedding') + + intermediate_var_0 = main_block.create_var( + name=unique_name.generate_with_ignorable_key(".".join( + ["c_embedding", 'tmp'])), + dtype=Weight_var.dtype, + shape=Out_var.shape, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=Out_var.stop_gradient) + + # copy Out_var's dist_attr to intermediate_var_0's dist_attr + copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, Out_var) + + check_variable_and_dtype( + Out_var, 'tensor', + ['float16', 'float32', 'float64', 'int32', 'int64'], + 'c_allreduce_sum') + + c_embedding_op = main_block.append_op( + type='c_embedding', + inputs={'Ids': [Ids_var], + 'W': [Weight_var]}, + outputs={'Out': [intermediate_var_0]}, + attrs={"start_index": relative_idx}) + + # use_model_parallel + c_allreduce_sum_op = main_block.append_op( + type='c_allreduce_sum', + inputs={'X': [intermediate_var_0]}, + outputs={'Out': 
[Out_var]}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'use_model_parallel': True, + }) + + # copy serial op's dist_attr to dist op's dist_attr + copy_distributed_attr_for_dist_op(c_embedding_op, main_block, + op_dist_attr) + copy_distributed_attr_for_dist_op(c_allreduce_sum_op, main_block, + op_dist_attr) + + # param initialization sync + assert Weight_var.name not in dist_op_helper.already_init_sync_vars + dist_op_helper.already_init_sync_vars.add(Weight_var.name) + param = startup_block.var(Weight_var.name) + param_dist_attr = ctx.get_tensor_distributed_attr_for_program(param) + process_mesh = param_dist_attr.get_process_mesh() + dim_mapping = param_dist_attr.get_dims_mapping() + + # NOTE all not splited axis should be presented in mesh + for axis, size in enumerate(process_mesh.topology): + if size <= 1 or axis in dim_mapping: + pass else: - relative_idx = rank_id % num_partition + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, axis, + rank_id) + sync_group = new_process_group(group_ranks) + + startup_block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': sync_group.id, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Forward + }) + startup_block._sync_with_cpp() + + @staticmethod + def backward(ctx, *args, **kwargs): + + # by now the backward function only insert the gradient allreduce for dist op itself + dist_op_helper = ctx.get_dist_op_helper() + main_block = dist_op_helper.get_dst_main_program().global_block() + backward_op = dist_op_helper.get_cur_src_op() + rank_id = dist_op_helper.get_rank_id() + dist_attr = ctx.get_op_distributed_attr_for_program(backward_op) + assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format( + str(backward_op)) - per_part_size = Weight_var.shape[0] - relative_idx = relative_idx * per_part_size + # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism + if rank_id not in dist_attr.get_process_mesh().process_group: + rank_id = _get_corresponding_rank(dist_attr.get_process_mesh(), + rank_id) + + # check if need gradient allreduce + need_gradient_allreduce = False + + assert 'Ids' in kwargs, "input [{}] is not given".format('Ids') + assert 'W' in kwargs, "input [{}] is not given".format('W') + assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out') + assert 'W@GRAD' in kwargs, "output [{}] is not given".format('W@GRAD') + + assert len( + kwargs['Ids'] + ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( + kwargs['Ids']) + assert len( + kwargs['W'] + ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( + kwargs['W']) + assert len( + kwargs['Out@GRAD'] + ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( + kwargs['Out']) + assert len( + kwargs['W@GRAD'] + ) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format( + kwargs['W@GRAD']) + + Ids_var = main_block.var(kwargs['Ids'][0]) + process_mesh = dist_attr.get_process_mesh() + var_dim_mapping = dist_attr.get_input_dims_mapping(Ids_var.name) + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: + need_gradient_allreduce = True - # TODO caculate ring id group_ranks = _get_comm_group(process_mesh.process_group, process_mesh.topology, - model_parallel_axis, rank_id) - group = new_process_group(group_ranks) - - # append op - 
check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'], - 'c_embedding') - - intermediate_var_0 = dst_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ["c_embedding", 'tmp'])), - dtype=Weight_var.dtype, - shape=Out_var.shape, - type=core.VarDesc.VarType.LOD_TENSOR, - persistable=False, - stop_gradient=Out_var.stop_gradient) - # copy Out_var's dist_attr to intermediate_var_0's dist_attr - copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, - Out_var) - - check_variable_and_dtype( - Out_var, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], - 'c_allreduce_sum') - - c_embedding_op = dst_block.append_op( - type='c_embedding', - inputs={'Ids': [Ids_var], - 'W': [Weight_var]}, - outputs={'Out': [intermediate_var_0]}, - attrs={"start_index": relative_idx}) - - # use_model_parallel - c_allreduce_sum_op = dst_block.append_op( + batch_size_axis, rank_id) + dp_degree = len(group_ranks) + dp_group = new_process_group(group_ranks) + + if need_gradient_allreduce: + W_Grad_var = main_block.var(kwargs['W@GRAD'][0]) + allreduce_op = main_block.append_op( type='c_allreduce_sum', - inputs={'X': [intermediate_var_0]}, - outputs={'Out': [Out_var]}, + inputs={'X': [W_Grad_var]}, + outputs={'Out': [W_Grad_var]}, attrs={ - 'ring_id': group.id, + 'ring_id': dp_group.id, 'use_calc_stream': True, - 'use_model_parallel': True, + OP_ROLE_KEY: OpRole.Backward }) + scale_op = main_block.append_op( + type='scale', + inputs={'X': W_Grad_var}, + outputs={'Out': W_Grad_var}, + attrs={'scale': 1.0 / dp_degree, + OP_ROLE_KEY: OpRole.Backward}) + main_block._sync_with_cpp() - # copy serial op's dist_attr to dist op's dist_attr - copy_distributed_attr_for_dist_op(c_embedding_op, dst_block, - op_dist_attr) - copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block, - op_dist_attr) - - if in_dygraph_mode(): - raise NotImplementedError( - "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( - "matmul", 0)) - else: - return static_handle + dims_mapping = ctx.get_tensor_distributed_attr_for_program( + W_Grad_var).get_dims_mapping() + process_mesh = dist_attr.get_process_mesh() + for op in [allreduce_op, scale_op]: + op_attr = OperatorDistributedAttribute(op, ctx) + op_attr.set_process_mesh(process_mesh) + op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping) + op_attr.set_input_dims_mapping(W_Grad_var.name, dims_mapping) + ctx.set_op_distributed_attr_for_program(op, op_attr) register_distributed_operator_impl("lookup_table_v2", DistributedEmbeddingImpl("row_parallel")) +register_distributed_operator_impl("c_embedding", + DistributedEmbeddingImpl("row_parallel")) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index 10a01dc57ed..2edbcd2318c 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -24,12 +24,14 @@ from ..utils import is_valid_list_index from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping +from ..attribute import OperatorDistributedAttribute from paddle.fluid import core, unique_name from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from 
paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from ..process import new_process_group -from ..utils import _get_comm_group +from ..utils import _get_comm_group, _get_corresponding_rank def _update_dims_mapping_for_matmul(op_dist_attr): @@ -123,6 +125,130 @@ def _update_dims_mapping_for_matmul(op_dist_attr): return changed +def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): + + # by now the backward function only insert the gradient allreduce for dist op itself + + dist_op_helper = ctx.get_dist_op_helper() + main_block = dist_op_helper.get_dst_main_program().global_block() + backward_op = dist_op_helper.get_cur_src_op() + rank_id = dist_op_helper.get_rank_id() + dist_attr = ctx.get_op_distributed_attr_for_program(backward_op) + assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format( + str(backward_op)) + + # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism + if rank_id not in dist_attr.get_process_mesh().process_group: + rank_id = _get_corresponding_rank(dist_attr.get_process_mesh(), rank_id) + + # check if need gradient allreduce + need_gradient_allreduce = False + + assert 'Y' in kwargs, "input [{}] is not given".format('Y') + assert 'X' in kwargs, "input [{}] is not given".format('X') + assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out@GRAD') + assert 'Y@GRAD' in kwargs, "output [{}] is not given".format('Y@GRAD') + assert 'X@GRAD' in kwargs, "output [{}] is not given".format('X@GRAD') + + assert len( + kwargs['Y'] + ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( + kwargs['Y']) + assert len( + kwargs['X'] + ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( + kwargs['X']) + assert len( + kwargs['Out@GRAD'] + ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( + kwargs['Out']) + assert len( + kwargs['Y@GRAD'] + ) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format( + kwargs['Y@GRAD']) + assert len( + kwargs['X@GRAD'] + ) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format( + kwargs['X@GRAD']) + + X_var = main_block.var(kwargs['X'][0]) + assert not X_var.is_parameter, "left operand(X) [{}] of dist matmul should not be parameter".format( + X_var.name) + + process_mesh = dist_attr.get_process_mesh() + var_dim_mapping = dist_attr.get_input_dims_mapping(X_var.name) + mesh_shape = process_mesh.topology + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: + need_gradient_allreduce = True + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, batch_size_axis, + rank_id) + dp_degree = len(group_ranks) + dp_group = new_process_group(group_ranks) + + Y_var = main_block.var(kwargs['Y'][0]) + if need_gradient_allreduce and Y_var.is_parameter: + Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0]) + allreduce_op = main_block.append_op( + type='c_allreduce_sum', + inputs={'X': [Y_Grad_var]}, + outputs={'Out': [Y_Grad_var]}, + attrs={ + 'ring_id': dp_group.id, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Backward + }) + scale_op = main_block.append_op( + type='scale', + inputs={'X': Y_Grad_var}, + outputs={'Out': Y_Grad_var}, + attrs={'scale': 1.0 / dp_degree, + OP_ROLE_KEY: OpRole.Backward}) + main_block._sync_with_cpp() + + dims_mapping = ctx.get_tensor_distributed_attr_for_program( + Y_Grad_var).get_dims_mapping() + process_mesh = 
dist_attr.get_process_mesh() + for op in [allreduce_op, scale_op]: + op_attr = OperatorDistributedAttribute(op, ctx) + op_attr.set_process_mesh(process_mesh) + op_attr.set_output_dims_mapping(Y_Grad_var.name, dims_mapping) + op_attr.set_input_dims_mapping(Y_Grad_var.name, dims_mapping) + ctx.set_op_distributed_attr_for_program(op, op_attr) + + +def _init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, rank_id): + + assert Weight_var.name not in dist_op_helper.already_init_sync_vars + assert startup_block.has_var(Weight_var.name) + dist_op_helper.already_init_sync_vars.add(Weight_var.name) + param = startup_block.var(Weight_var.name) + param_dist_attr = ctx.get_tensor_distributed_attr_for_program(param) + process_mesh = param_dist_attr.get_process_mesh() + dim_mapping = param_dist_attr.get_dims_mapping() + + for axis, size in enumerate(process_mesh.topology): + if size <= 1 or axis in dim_mapping: + pass + else: + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, axis, rank_id) + sync_group = new_process_group(group_ranks) + + startup_block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': sync_group.id, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Forward + }) + startup_block._sync_with_cpp() + + class DistributedMatmul(DistributedOperator): def __init__(self, name): super(DistributedMatmul, self).__init__() @@ -138,7 +264,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): super(DistributedMatmulImpl0, self).__init__() self._name = name self._forward_implemented = True - self._backward_implemented = False + self._backward_implemented = True def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. """ @@ -178,101 +304,109 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): changed = True return changed - def forward(self, serial_op): - def static_handle(dst_block, - src_op, - op_dist_attr, - input_name_mapping, - output_name_mapping, - rank_id=0): - assert len( - input_name_mapping - ) == 2, "col_parallel_linear take 2 inputs variable but got {}".format( - input_name_mapping) - assert len( - output_name_mapping - ) == 1, "col_parallel_linear take 2 inputs variable but got {}".format( - output_name_mapping) - assert len( - input_name_mapping['X'] - ) == 1, "col_parallel_linear input X take 1 variable but got {}".format( - input_name_mapping['X']) - assert len( - input_name_mapping['Y'] - ) == 1, "col_parallel_linear input Y take 1 variable but got {}".format( - input_name_mapping['Y']) - assert len( - output_name_mapping['Out'] - ) == 1, "col_parallel_linear input Out take 1 variable but got {}".format( - input_name_mapping['Out']) - X_var = dst_block.var(input_name_mapping['X'][0]) - Weight_var = dst_block.var(input_name_mapping['Y'][0]) - Out_var = dst_block.var(output_name_mapping['Out'][0]) - - # TODO infer logic comm presentation - model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( - )._get_model_parallel_info() - group_ranks = _get_comm_group(process_mesh.process_group, - process_mesh.topology, - model_parallel_axis, rank_id) - group = new_process_group(group_ranks) - - intermediate_var_0 = dst_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ["c_identity", 'tmp'])), - dtype=X_var.dtype, - shape=X_var.shape, - type=core.VarDesc.VarType.LOD_TENSOR, - persistable=False, - stop_gradient=X_var.stop_gradient) - # copy X_var's dist_attr to intermediate_var_0's dist_attr - 
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, - X_var) - - check_variable_and_dtype( - X_var, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], - '_c_identity') - - c_identity_op = dst_block.append_op( - type='c_identity', - inputs={'X': [X_var]}, - outputs={'Out': intermediate_var_0}, - attrs={ - 'ring_id': group.id, - 'use_calc_stream': True, - 'use_model_parallel': True, - }) - - check_variable_and_dtype(intermediate_var_0, 'x', - ['float16', 'float32', 'float64'], - 'linear') - check_dtype(intermediate_var_0.dtype, 'dtype', - ['float16', 'float32', 'float64'], 'linear') - attrs = { - 'transpose_X': False, - 'transpose_Y': False, - 'alpha': 1, - } - inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} - matmul_op = dst_block.append_op( - type='matmul', - inputs=inputs, - outputs={'Out': Out_var}, - attrs=attrs) - - # copy serial op's dist_attr to dist op's dist_attr - copy_distributed_attr_for_dist_op(c_identity_op, dst_block, - op_dist_attr) - copy_distributed_attr_for_dist_op(matmul_op, dst_block, - op_dist_attr) - - if in_dygraph_mode(): - raise NotImplementedError( - "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( - "matmul", 0)) - else: - return static_handle + @staticmethod + def forward(ctx, *args, **kwargs): + """ + kwargs: inputname_mapping & outputname_mapping + """ + + dist_op_helper = ctx.get_dist_op_helper() + main_block = dist_op_helper.get_dst_main_program().global_block() + startup_block = dist_op_helper.get_dst_startup_program().global_block() + src_op = dist_op_helper.get_cur_src_op() + rank_id = dist_op_helper.get_rank_id() + op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) + assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( + str(src_op)) + + # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism + if rank_id not in op_dist_attr.get_process_mesh().process_group: + rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(), + rank_id) + + # check validation of inputs / outputs + for input_name in src_op.desc.input_names(): + assert input_name in kwargs, "input [{}] is not given".format( + input_name) + assert len(kwargs[input_name]) == len( + src_op.desc.input(input_name) + ), "number of tensor for input [{}] is not match".format(input_name) + for output_name in src_op.desc.output_names(): + assert output_name in kwargs, "input [{}] is not given".format( + output_name) + assert len(kwargs[output_name]) == len( + src_op.desc.output(output_name) + ), "number of tensor for input [{}] is not match".format( + output_name) + + X_var = main_block.var(kwargs['X'][0]) + Weight_var = main_block.var(kwargs['Y'][0]) + Out_var = main_block.var(kwargs['Out'][0]) + + # TODO infer logic comm presentation + matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( + Weight_var.name)[1] + assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( + matmul_col_dim_mapping) + process_mesh_shape = op_dist_attr.get_process_mesh().topology + process_mesh_group = op_dist_attr.get_process_mesh().process_group + + parallel_axis = matmul_col_dim_mapping + group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, + parallel_axis, rank_id) + group = new_process_group(group_ranks) + + intermediate_var_0 = main_block.create_var( + name=unique_name.generate_with_ignorable_key(".".join( + ["c_identity", 'tmp'])), + dtype=X_var.dtype, + shape=X_var.shape, + 
type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=X_var.stop_gradient) + # copy X_var's dist_attr to intermediate_var_0's dist_attr + copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, X_var) + + check_variable_and_dtype( + X_var, 'tensor', + ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity') + + c_identity_op = main_block.append_op( + type='c_identity', + inputs={'X': [X_var]}, + outputs={'Out': intermediate_var_0}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'use_model_parallel': True, + }) + + check_variable_and_dtype(intermediate_var_0, 'x', + ['float16', 'float32', 'float64'], 'linear') + check_dtype(intermediate_var_0.dtype, 'dtype', + ['float16', 'float32', 'float64'], 'linear') + attrs = { + 'transpose_X': False, + 'transpose_Y': False, + 'alpha': 1, + } + inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} + matmul_op = main_block.append_op( + type='matmul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs) + + # copy serial op's dist_attr to dist op's dist_attr + copy_distributed_attr_for_dist_op(c_identity_op, main_block, + op_dist_attr) + copy_distributed_attr_for_dist_op(matmul_op, main_block, op_dist_attr) + + # init param sync + if Weight_var.is_parameter: + _init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, + rank_id) + + @staticmethod + def backward(ctx, *args, **kwargs): + _right_operand_parameter_matmul_backward(ctx, *args, **kwargs) # RowParallel @@ -281,7 +415,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): super(DistributedMatmulImpl1, self).__init__() self._name = name self._forward_implemented = True - self._backward_implemented = False + self._backward_implemented = True def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. 
""" @@ -323,95 +457,108 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): changed = True return changed - def forward(self, serial_op): - def static_handle(dst_block, - src_op, - op_dist_attr, - input_name_mapping, - output_name_mapping, - rank_id=0): - assert len( - input_name_mapping - ) == 2, "col_parallel_linear take 2 inputs variable but got {}".format( - input_name_mapping) - assert len( - output_name_mapping - ) == 1, "col_parallel_linear take 2 inputs variable but got {}".format( - output_name_mapping) - assert len( - input_name_mapping['X'] - ) == 1, "col_parallel_linear input X take 1 variable but got {}".format( - input_name_mapping['X']) - assert len( - input_name_mapping['Y'] - ) == 1, "col_parallel_linear input Y take 1 variable but got {}".format( - input_name_mapping['Y']) - assert len( - output_name_mapping['Out'] - ) == 1, "col_parallel_linear input Out take 1 variable but got {}".format( - input_name_mapping['Out']) - X_var = dst_block.var(input_name_mapping['X'][0]) - Weight_var = dst_block.var(input_name_mapping['Y'][0]) - Out_var = dst_block.var(output_name_mapping['Out'][0]) - - # TODO infer logic comm presentation - model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( - )._get_model_parallel_info() - group_ranks = _get_comm_group(process_mesh.process_group, - process_mesh.topology, - model_parallel_axis, rank_id) - group = new_process_group(group_ranks) - - check_variable_and_dtype( - X_var, 'x', ['float16', 'float32', 'float64'], 'linear') - check_dtype(X_var.dtype, 'dtype', - ['float16', 'float32', 'float64'], 'linear') - attrs = { - 'transpose_X': False, - 'transpose_Y': False, - 'alpha': 1, - } - inputs = {'X': X_var, 'Y': Weight_var} - intermediate_var_0 = dst_block.create_var( - shape=Out_var.shape, - dtype=Out_var.dtype, - type=Out_var.type, - lod_level=Out_var.lod_level, - persistable=False, - is_data=False, - need_check_feed=Out_var.desc.need_check_feed()) - # copy Out_var's dist_attr to intermediate_var_0's dist_attr - copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, - Out_var) - - matmul_op = dst_block.append_op( - type='matmul', - inputs=inputs, - outputs={'Out': intermediate_var_0}, - attrs=attrs) - - c_allreduce_sum_op = dst_block.append_op( - type='c_allreduce_sum', - inputs={'X': intermediate_var_0}, - outputs={'Out': Out_var}, - attrs={ - 'ring_id': group.id, - 'use_calc_stream': True, - 'use_model_parallel': True - }) - - # copy serial op's dist_attr to dist op's dist_attr - copy_distributed_attr_for_dist_op(matmul_op, dst_block, - op_dist_attr) - copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block, - op_dist_attr) - - if in_dygraph_mode(): - raise NotImplementedError( - "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( - "matmul", 0)) - else: - return static_handle + @staticmethod + def forward(ctx, *args, **kwargs): + """ + kwargs: inputname_mapping & outputname_mapping + """ + + dist_op_helper = ctx.get_dist_op_helper() + main_block = dist_op_helper.get_dst_main_program().global_block() + startup_block = dist_op_helper.get_dst_startup_program().global_block() + src_op = dist_op_helper.get_cur_src_op() + rank_id = dist_op_helper.get_rank_id() + op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) + assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( + str(src_op)) + + # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism + if rank_id not in op_dist_attr.get_process_mesh().process_group: + rank_id = 
_get_corresponding_rank(op_dist_attr.get_process_mesh(), + rank_id) + + # check validation of inputs / outputs + for input_name in src_op.desc.input_names(): + assert input_name in kwargs, "input [{}] is not given".format( + input_name) + assert len(kwargs[input_name]) == len( + src_op.desc.input(input_name) + ), "number of tensor for input [{}] is not match".format(input_name) + for output_name in src_op.desc.output_names(): + assert output_name in kwargs, "input [{}] is not given".format( + output_name) + assert len(kwargs[output_name]) == len( + src_op.desc.output(output_name) + ), "number of tensor for input [{}] is not match".format( + output_name) + + X_var = main_block.var(kwargs['X'][0]) + Weight_var = main_block.var(kwargs['Y'][0]) + Out_var = main_block.var(kwargs['Out'][0]) + + # TODO infer logic comm presentation + matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( + Weight_var.name)[0] + assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( + matmul_row_dim_mapping) + process_mesh_shape = op_dist_attr.get_process_mesh().topology + process_mesh_group = op_dist_attr.get_process_mesh().process_group + + parallel_axis = matmul_row_dim_mapping + group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, + parallel_axis, rank_id) + group = new_process_group(group_ranks) + + check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'], + 'linear') + check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], + 'linear') + attrs = { + 'transpose_X': False, + 'transpose_Y': False, + 'alpha': 1, + } + inputs = {'X': X_var, 'Y': Weight_var} + intermediate_var_0 = main_block.create_var( + shape=Out_var.shape, + dtype=Out_var.dtype, + type=Out_var.type, + lod_level=Out_var.lod_level, + persistable=False, + is_data=False, + need_check_feed=Out_var.desc.need_check_feed()) + # copy Out_var's dist_attr to intermediate_var_0's dist_attr + copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, Out_var) + + matmul_op = main_block.append_op( + type='matmul', + inputs=inputs, + outputs={'Out': intermediate_var_0}, + attrs=attrs) + + c_allreduce_sum_op = main_block.append_op( + type='c_allreduce_sum', + inputs={'X': intermediate_var_0}, + outputs={'Out': Out_var}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'use_model_parallel': True + }) + + # copy serial op's dist_attr to dist op's dist_attr + copy_distributed_attr_for_dist_op(matmul_op, main_block, op_dist_attr) + copy_distributed_attr_for_dist_op(c_allreduce_sum_op, main_block, + op_dist_attr) + + # init param sync + if Weight_var.is_parameter: + _init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, + rank_id) + + @staticmethod + def backward(ctx, *args, **kwargs): + _right_operand_parameter_matmul_backward(ctx, *args, **kwargs) # ReplicateParallel @@ -465,6 +612,10 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): changed = True return changed + @staticmethod + def backward(ctx, *args, **kwargs): + _right_operand_parameter_matmul_backward(ctx, *args, **kwargs) + register_distributed_operator_impl("matmul", DistributedMatmulImpl0("column_parallel")) @@ -489,7 +640,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): super(DistributedMatmulV2Impl0, self).__init__() self._name = name self._forward_implemented = True - self._backward_implemented = False + self._backward_implemented = True def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. 
""" @@ -529,97 +680,109 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): changed = True return changed - def forward(self, serial_op): - def static_handle(dst_block, - src_op, - op_dist_attr, - input_name_mapping, - output_name_mapping, - rank_id=0): - assert len( - input_name_mapping - ) == 2, "col_parallel_linear take 2 inputs variable but got {}".format( - input_name_mapping) - assert len( - output_name_mapping - ) == 1, "col_parallel_linear take 2 inputs variable but got {}".format( - output_name_mapping) - assert len( - input_name_mapping['X'] - ) == 1, "col_parallel_linear input X take 1 variable but got {}".format( - input_name_mapping['X']) - assert len( - input_name_mapping['Y'] - ) == 1, "col_parallel_linear input Y take 1 variable but got {}".format( - input_name_mapping['Y']) - assert len( - output_name_mapping['Out'] - ) == 1, "col_parallel_linear input Out take 1 variable but got {}".format( - input_name_mapping['Out']) - X_var = dst_block.var(input_name_mapping['X'][0]) - Weight_var = dst_block.var(input_name_mapping['Y'][0]) - Out_var = dst_block.var(output_name_mapping['Out'][0]) - - # TODO infer logic comm presentation - model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( - )._get_model_parallel_info() - group_ranks = _get_comm_group(process_mesh.process_group, - process_mesh.topology, - model_parallel_axis, rank_id) - group = new_process_group(group_ranks) - - intermediate_var_0 = dst_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ["c_identity", 'tmp'])), - dtype=X_var.dtype, - shape=X_var.shape, - type=core.VarDesc.VarType.LOD_TENSOR, - persistable=False, - stop_gradient=X_var.stop_gradient) - # copy X_var's dist_attr to intermediate_var_0's dist_attr - copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, - X_var) - - check_variable_and_dtype( - X_var, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], - '_c_identity') - - c_identity_op = dst_block.append_op( - type='c_identity', - inputs={'X': [X_var]}, - outputs={'Out': intermediate_var_0}, - attrs={ - 'ring_id': group.id, - 'use_calc_stream': True, - 'use_model_parallel': True, - }) - - check_variable_and_dtype(intermediate_var_0, 'x', - ['float16', 'float32', 'float64'], - 'linear') - check_dtype(intermediate_var_0.dtype, 'dtype', - ['float16', 'float32', 'float64'], 'linear') - attrs = {'trans_x': False, 'trans_y': False} - inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} - matmul_v2_op = dst_block.append_op( - type='matmul_v2', - inputs=inputs, - outputs={'Out': Out_var}, - attrs=attrs) - - # copy serial op's dist_attr to dist op's dist_attr - copy_distributed_attr_for_dist_op(c_identity_op, dst_block, - op_dist_attr) - copy_distributed_attr_for_dist_op(matmul_v2_op, dst_block, - op_dist_attr) - - if in_dygraph_mode(): - raise NotImplementedError( - "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( - "matmul", 0)) - else: - return static_handle + @staticmethod + def forward(ctx, *args, **kwargs): + """ + kwargs: inputname_mapping & outputname_mapping + """ + + dist_op_helper = ctx.get_dist_op_helper() + main_block = dist_op_helper.get_dst_main_program().global_block() + startup_block = dist_op_helper.get_dst_startup_program().global_block() + src_op = dist_op_helper.get_cur_src_op() + rank_id = dist_op_helper.get_rank_id() + op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) + assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( + str(src_op)) + + # FIXME (JZ-LIANG) 
Remove this hack to support any op mesh group for Pipeline Parallelism + if rank_id not in op_dist_attr.get_process_mesh().process_group: + rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(), + rank_id) + + # check validation of inputs / outputs + for input_name in src_op.desc.input_names(): + assert input_name in kwargs, "input [{}] is not given".format( + input_name) + assert len(kwargs[input_name]) == len( + src_op.desc.input(input_name) + ), "number of tensor for input [{}] is not match".format(input_name) + for output_name in src_op.desc.output_names(): + assert output_name in kwargs, "input [{}] is not given".format( + output_name) + assert len(kwargs[output_name]) == len( + src_op.desc.output(output_name) + ), "number of tensor for input [{}] is not match".format( + output_name) + + X_var = main_block.var(kwargs['X'][0]) + Weight_var = main_block.var(kwargs['Y'][0]) + Out_var = main_block.var(kwargs['Out'][0]) + + # TODO infer logic comm presentation + matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping( + Weight_var.name)[1] + assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( + matmul_col_dim_mapping) + process_mesh_shape = op_dist_attr.get_process_mesh().topology + process_mesh_group = op_dist_attr.get_process_mesh().process_group + + parallel_axis = matmul_col_dim_mapping + group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, + parallel_axis, rank_id) + group = new_process_group(group_ranks) + + intermediate_var_0 = main_block.create_var( + name=unique_name.generate_with_ignorable_key(".".join( + ["c_identity", 'tmp'])), + dtype=X_var.dtype, + shape=X_var.shape, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=X_var.stop_gradient) + # copy X_var's dist_attr to intermediate_var_0's dist_attr + copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, X_var) + + check_variable_and_dtype( + X_var, 'tensor', + ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity') + + c_identity_op = main_block.append_op( + type='c_identity', + inputs={'X': [X_var]}, + outputs={'Out': intermediate_var_0}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'use_model_parallel': True, + }) + + check_variable_and_dtype(intermediate_var_0, 'x', + ['float16', 'float32', 'float64'], 'linear') + check_dtype(intermediate_var_0.dtype, 'dtype', + ['float16', 'float32', 'float64'], 'linear') + attrs = {'trans_x': False, 'trans_y': False} + inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} + matmul_v2_op = main_block.append_op( + type='matmul_v2', + inputs=inputs, + outputs={'Out': Out_var}, + attrs=attrs) + + # copy serial op's dist_attr to dist op's dist_attr + copy_distributed_attr_for_dist_op(c_identity_op, main_block, + op_dist_attr) + copy_distributed_attr_for_dist_op(matmul_v2_op, main_block, + op_dist_attr) + + # init param sync + if Weight_var.is_parameter: + _init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, + rank_id) + + @staticmethod + def backward(ctx, *args, **kwargs): + _right_operand_parameter_matmul_backward(ctx, *args, **kwargs) # RowParallel @@ -628,7 +791,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): super(DistributedMatmulV2Impl1, self).__init__() self._name = name self._forward_implemented = True - self._backward_implemented = False + self._backward_implemented = True def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. 
""" @@ -670,91 +833,105 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): changed = True return changed - def forward(self, serial_op): - def static_handle(dst_block, - src_op, - op_dist_attr, - input_name_mapping, - output_name_mapping, - rank_id=0): - assert len( - input_name_mapping - ) == 2, "col_parallel_linear take 2 inputs variable but got {}".format( - input_name_mapping) - assert len( - output_name_mapping - ) == 1, "col_parallel_linear take 2 inputs variable but got {}".format( - output_name_mapping) - assert len( - input_name_mapping['X'] - ) == 1, "col_parallel_linear input X take 1 variable but got {}".format( - input_name_mapping['X']) - assert len( - input_name_mapping['Y'] - ) == 1, "col_parallel_linear input Y take 1 variable but got {}".format( - input_name_mapping['Y']) - assert len( - output_name_mapping['Out'] - ) == 1, "col_parallel_linear input Out take 1 variable but got {}".format( - input_name_mapping['Out']) - X_var = dst_block.var(input_name_mapping['X'][0]) - Weight_var = dst_block.var(input_name_mapping['Y'][0]) - Out_var = dst_block.var(output_name_mapping['Out'][0]) - - # TODO infer logic comm presentation - model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( - )._get_model_parallel_info() - group_ranks = _get_comm_group(process_mesh.process_group, - process_mesh.topology, - model_parallel_axis, rank_id) - group = new_process_group(group_ranks) - - check_variable_and_dtype( - X_var, 'x', ['float16', 'float32', 'float64'], 'linear') - check_dtype(X_var.dtype, 'dtype', - ['float16', 'float32', 'float64'], 'linear') - attrs = {'trans_x': False, 'trans_y': False} - inputs = {'X': X_var, 'Y': Weight_var} - intermediate_var_0 = dst_block.create_var( - shape=Out_var.shape, - dtype=Out_var.dtype, - type=Out_var.type, - lod_level=Out_var.lod_level, - persistable=False, - is_data=False, - need_check_feed=Out_var.desc.need_check_feed()) - # copy Out_var's dist_attr to intermediate_var_0's dist_attr - copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, - Out_var) - - matmul_v2_op = dst_block.append_op( - type='matmul_v2', - inputs=inputs, - outputs={'Out': intermediate_var_0}, - attrs=attrs) - - c_allreduce_sum_op = dst_block.append_op( - type='c_allreduce_sum', - inputs={'X': intermediate_var_0}, - outputs={'Out': Out_var}, - attrs={ - 'ring_id': group.id, - 'use_calc_stream': True, - 'use_model_parallel': True - }) - - # copy serial op's dist_attr to dist op's dist_attr - copy_distributed_attr_for_dist_op(matmul_v2_op, dst_block, - op_dist_attr) - copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block, - op_dist_attr) - - if in_dygraph_mode(): - raise NotImplementedError( - "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( - "matmul", 0)) - else: - return static_handle + @staticmethod + def forward(ctx, *args, **kwargs): + """ + kwargs: inputname_mapping & outputname_mapping + """ + + dist_op_helper = ctx.get_dist_op_helper() + main_block = dist_op_helper.get_dst_main_program().global_block() + startup_block = dist_op_helper.get_dst_startup_program().global_block() + src_op = dist_op_helper.get_cur_src_op() + rank_id = dist_op_helper.get_rank_id() + op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) + assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( + str(src_op)) + + # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism + if rank_id not in op_dist_attr.get_process_mesh().process_group: + rank_id = 
_get_corresponding_rank(op_dist_attr.get_process_mesh(), + rank_id) + + # check validation of inputs / outputs + for input_name in src_op.desc.input_names(): + assert input_name in kwargs, "input [{}] is not given".format( + input_name) + assert len(kwargs[input_name]) == len( + src_op.desc.input(input_name) + ), "number of tensor for input [{}] is not match".format(input_name) + for output_name in src_op.desc.output_names(): + assert output_name in kwargs, "input [{}] is not given".format( + output_name) + assert len(kwargs[output_name]) == len( + src_op.desc.output(output_name) + ), "number of tensor for input [{}] is not match".format( + output_name) + + X_var = main_block.var(kwargs['X'][0]) + Weight_var = main_block.var(kwargs['Y'][0]) + Out_var = main_block.var(kwargs['Out'][0]) + + # TODO infer logic comm presentation + matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping( + Weight_var.name)[0] + assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( + matmul_row_dim_mapping) + process_mesh_shape = op_dist_attr.get_process_mesh().topology + process_mesh_group = op_dist_attr.get_process_mesh().process_group + + parallel_axis = matmul_row_dim_mapping + group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, + parallel_axis, rank_id) + group = new_process_group(group_ranks) + + check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'], + 'linear') + check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], + 'linear') + attrs = {'trans_x': False, 'trans_y': False} + inputs = {'X': X_var, 'Y': Weight_var} + intermediate_var_0 = main_block.create_var( + shape=Out_var.shape, + dtype=Out_var.dtype, + type=Out_var.type, + lod_level=Out_var.lod_level, + persistable=False, + is_data=False, + need_check_feed=Out_var.desc.need_check_feed()) + # copy Out_var's dist_attr to intermediate_var_0's dist_attr + copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, Out_var) + + matmul_v2_op = main_block.append_op( + type='matmul_v2', + inputs=inputs, + outputs={'Out': intermediate_var_0}, + attrs=attrs) + + c_allreduce_sum_op = main_block.append_op( + type='c_allreduce_sum', + inputs={'X': intermediate_var_0}, + outputs={'Out': Out_var}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'use_model_parallel': True + }) + + # copy serial op's dist_attr to dist op's dist_attr + copy_distributed_attr_for_dist_op(matmul_v2_op, main_block, + op_dist_attr) + copy_distributed_attr_for_dist_op(c_allreduce_sum_op, main_block, + op_dist_attr) + + # init param sync + if Weight_var.is_parameter: + _init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, + rank_id) + + @staticmethod + def backward(ctx, *args, **kwargs): + _right_operand_parameter_matmul_backward(ctx, *args, **kwargs) # ReplicateParallel @@ -808,6 +985,10 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): changed = True return changed + @staticmethod + def backward(ctx, *args, **kwargs): + _right_operand_parameter_matmul_backward(ctx, *args, **kwargs) + register_distributed_operator_impl("matmul_v2", DistributedMatmulV2Impl0("column_parallel")) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py index e7fbe9cfeba..39e97850b86 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py @@ -42,7 +42,7 @@ class 
DistributedReshapeImpl0(DistributedOperatorImpl): super(DistributedReshapeImpl0, self).__init__() self._name = name self._forward_implemented = True - self._backward_implemented = False + self._backward_implemented = True def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. """ @@ -97,82 +97,72 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): return changed - def forward(self, serial_op): - def static_handle(dst_block, - src_op, - op_dist_attr, - input_name_mapping, - output_name_mapping, - rank_id=0): - assert len( - input_name_mapping - ) == 3, "Dist op of Reshape take 3 inputs variable but got {}".format( - input_name_mapping) - assert len( - output_name_mapping - ) == 2, "Dist op of Reshape take 2 inputs variable but got {}".format( - output_name_mapping) - assert len( - input_name_mapping['X'] - ) == 1, "Dist op of Reshape input X take 1 variable but got {}".format( - input_name_mapping['X']) - assert len( - input_name_mapping['ShapeTensor'] - ) <= 1, "Dist op of Reshape input ShapeTensor take 0 or 1 variable but got {}".format( - input_name_mapping['ShapeTensor']) - assert len( - input_name_mapping['Shape'] - ) <= 1, "Dist op of Reshape input Shape take 0 or 1 variable but got {}".format( - input_name_mapping['Shape']) - assert len( - output_name_mapping['Out'] - ) == 1, "Dist op of Reshape input Out take 1 variable but got {}".format( - input_name_mapping['Out']) - assert len( - output_name_mapping['XShape'] - ) == 1, "Dist op of Reshape input XShape take 1 variable but got {}".format( - input_name_mapping['XShape']) - - X_var = dst_block.var(input_name_mapping['X'][0]) - Out_var = dst_block.var(output_name_mapping['Out'][0]) - XShape_var = dst_block.var(output_name_mapping['XShape'][0]) - shape_list = src_op.desc.attr("shape") - ShapeTensor_var_list = [] - for name in input_name_mapping['ShapeTensor']: - ShapeTensor_var_list.append(name) - Shape_var_list = [] - for name in input_name_mapping['Shape']: - Shape_var_list.append(name) - - # got dist attribute info - dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) - process_mesh_shape = op_dist_attr.get_process_mesh().topology - - # modify target shape - for idx, axis in enumerate(dim_mapping): - if axis >= 0: - if len(shape_list) > idx: - shape_list[idx] = shape_list[idx] // process_mesh_shape[ - axis] - - # create op - new_op_desc = dst_block.desc.append_op() - new_op_desc.copy_from(src_op.desc) - new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list) - new_op_desc.set_input('Shape', Shape_var_list) - new_op_desc.set_input('X', [X_var.name]) - new_op_desc.set_output('XShape', [XShape_var.name]) - new_op_desc.set_output('Out', [Out_var.name]) - new_op_desc._set_attr('shape', shape_list) - - dst_block._sync_with_cpp() - - if in_dygraph_mode(): - raise NotImplementedError( - "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( - "matmul", 0)) - else: - return static_handle + @staticmethod + def forward(ctx, *args, **kwargs): + """ + kwargs: inputname_mapping & outputname_mapping + """ + + dist_op_helper = ctx.get_dist_op_helper() + main_block = dist_op_helper.get_dst_main_program().global_block() + src_op = dist_op_helper.get_cur_src_op() + rank_id = dist_op_helper.get_rank_id() + op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) + assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( + str(src_op)) + + # check validation of inputs / outputs + for input_name in src_op.desc.input_names(): + assert input_name in kwargs, 
"input [{}] is not given".format( + input_name) + assert len(kwargs[input_name]) == len( + src_op.desc.input(input_name) + ), "number of tensor for input [{}] is not match".format(input_name) + for output_name in src_op.desc.output_names(): + assert output_name in kwargs, "input [{}] is not given".format( + output_name) + assert len(kwargs[output_name]) == len( + src_op.desc.output(output_name) + ), "number of tensor for input [{}] is not match".format( + output_name) + + X_var = main_block.var(kwargs['X'][0]) + Out_var = main_block.var(kwargs['Out'][0]) + XShape_var = main_block.var(kwargs['XShape'][0]) + shape_list = src_op.desc.attr("shape") + ShapeTensor_var_list = [] + for name in kwargs['ShapeTensor']: + ShapeTensor_var_list.append(name) + Shape_var_list = [] + for name in kwargs['Shape']: + Shape_var_list.append(name) + + # got dist attribute info + dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) + process_mesh_shape = op_dist_attr.get_process_mesh().topology + + # modify target shape + for idx, axis in enumerate(dim_mapping): + if axis >= 0: + if len(shape_list) > idx: + shape_list[idx] = shape_list[idx] // process_mesh_shape[ + axis] + + # create op + new_op_desc = main_block.desc.append_op() + new_op_desc.copy_from(src_op.desc) + new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list) + new_op_desc.set_input('Shape', Shape_var_list) + new_op_desc.set_input('X', [X_var.name]) + new_op_desc.set_output('XShape', [XShape_var.name]) + new_op_desc.set_output('Out', [Out_var.name]) + new_op_desc._set_attr('shape', shape_list) + + main_block._sync_with_cpp() + + @staticmethod + def backward(ctx, *args, **kwargs): + pass class DistributedReshapeImpl1(DistributedOperatorImpl): @@ -180,7 +170,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): super(DistributedReshapeImpl1, self).__init__() self._name = name self._forward_implemented = True - self._backward_implemented = False + self._backward_implemented = True def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. 
""" @@ -235,82 +225,72 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): return changed - def forward(self, serial_op): - def static_handle(dst_block, - src_op, - op_dist_attr, - input_name_mapping, - output_name_mapping, - rank_id=0): - assert len( - input_name_mapping - ) == 3, "Dist op of Reshape take 3 inputs variable but got {}".format( - input_name_mapping) - assert len( - output_name_mapping - ) == 2, "Dist op of Reshape take 2 inputs variable but got {}".format( - output_name_mapping) - assert len( - input_name_mapping['X'] - ) == 1, "Dist op of Reshape input X take 1 variable but got {}".format( - input_name_mapping['X']) - assert len( - input_name_mapping['ShapeTensor'] - ) <= 1, "Dist op of Reshape input ShapeTensor take 0 or 1 variable but got {}".format( - input_name_mapping['ShapeTensor']) - assert len( - input_name_mapping['Shape'] - ) <= 1, "Dist op of Reshape input Shape take 0 or 1 variable but got {}".format( - input_name_mapping['Shape']) - assert len( - output_name_mapping['Out'] - ) == 1, "Dist op of Reshape input Out take 1 variable but got {}".format( - input_name_mapping['Out']) - assert len( - output_name_mapping['XShape'] - ) == 1, "Dist op of Reshape input XShape take 1 variable but got {}".format( - input_name_mapping['XShape']) - - X_var = dst_block.var(input_name_mapping['X'][0]) - Out_var = dst_block.var(output_name_mapping['Out'][0]) - XShape_var = dst_block.var(output_name_mapping['XShape'][0]) - shape_list = src_op.desc.attr("shape") - ShapeTensor_var_list = [] - for name in input_name_mapping['ShapeTensor']: - ShapeTensor_var_list.append(name) - Shape_var_list = [] - for name in input_name_mapping['Shape']: - Shape_var_list.append(name) - - # got dist attribute info - dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) - process_mesh_shape = op_dist_attr.get_process_mesh().topology - - # modify target shape - for idx, axis in enumerate(dim_mapping): - if axis >= 0: - if len(shape_list) > idx: - shape_list[idx] = shape_list[idx] // process_mesh_shape[ - axis] - - # create op - new_op_desc = dst_block.desc.append_op() - new_op_desc.copy_from(src_op.desc) - new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list) - new_op_desc.set_input('Shape', Shape_var_list) - new_op_desc.set_input('X', [X_var.name]) - new_op_desc.set_output('XShape', [XShape_var.name]) - new_op_desc.set_output('Out', [Out_var.name]) - new_op_desc._set_attr('shape', shape_list) - - dst_block._sync_with_cpp() - - if in_dygraph_mode(): - raise NotImplementedError( - "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( - "matmul", 0)) - else: - return static_handle + @staticmethod + def forward(ctx, *args, **kwargs): + """ + kwargs: inputname_mapping & outputname_mapping + """ + + dist_op_helper = ctx.get_dist_op_helper() + main_block = dist_op_helper.get_dst_main_program().global_block() + src_op = dist_op_helper.get_cur_src_op() + rank_id = dist_op_helper.get_rank_id() + op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) + assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( + str(src_op)) + + # check validation of inputs / outputs + for input_name in src_op.desc.input_names(): + assert input_name in kwargs, "input [{}] is not given".format( + input_name) + assert len(kwargs[input_name]) == len( + src_op.desc.input(input_name) + ), "number of tensor for input [{}] is not match".format(input_name) + for output_name in src_op.desc.output_names(): + assert output_name in kwargs, "input [{}] is not 
given".format( + output_name) + assert len(kwargs[output_name]) == len( + src_op.desc.output(output_name) + ), "number of tensor for input [{}] is not match".format( + output_name) + + X_var = main_block.var(kwargs['X'][0]) + Out_var = main_block.var(kwargs['Out'][0]) + XShape_var = main_block.var(kwargs['XShape'][0]) + shape_list = src_op.desc.attr("shape") + ShapeTensor_var_list = [] + for name in kwargs['ShapeTensor']: + ShapeTensor_var_list.append(name) + Shape_var_list = [] + for name in kwargs['Shape']: + Shape_var_list.append(name) + + # got dist attribute info + dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) + process_mesh_shape = op_dist_attr.get_process_mesh().topology + + # modify target shape + for idx, axis in enumerate(dim_mapping): + if axis >= 0: + if len(shape_list) > idx: + shape_list[idx] = shape_list[idx] // process_mesh_shape[ + axis] + + # create op + new_op_desc = main_block.desc.append_op() + new_op_desc.copy_from(src_op.desc) + new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list) + new_op_desc.set_input('Shape', Shape_var_list) + new_op_desc.set_input('X', [X_var.name]) + new_op_desc.set_output('XShape', [XShape_var.name]) + new_op_desc.set_output('Out', [Out_var.name]) + new_op_desc._set_attr('shape', shape_list) + + main_block._sync_with_cpp() + + @staticmethod + def backward(ctx, *args, **kwargs): + pass register_distributed_operator_impl("reshape2", diff --git a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py index dc78bdee1fb..56be75b3bea 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py @@ -37,6 +37,8 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): def __init__(self, name): super(DistributedSoftmaxImpl, self).__init__() self._name = name + self._forward_implemented = False + self._backward_implemented = True def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. """ @@ -86,6 +88,10 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): return changed + @staticmethod + def backward(ctx, *args, **kwargs): + pass + register_distributed_operator_impl( "softmax", DistributedSoftmaxImpl("replicate_last_axis")) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py index c2ca4d85fdf..10b8bf2666f 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py @@ -37,6 +37,8 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): def __init__(self, name): super(DistributedTranspose2Impl, self).__init__() self._name = name + self._forward_implemented = False + self._backward_implemented = True def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. 
""" @@ -82,6 +84,10 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): return changed + @staticmethod + def backward(ctx, *args, **kwargs): + pass + register_distributed_operator_impl( "transpose2", DistributedTranspose2Impl("same_mapping_transpose")) diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index 1437dbb2f90..8f4a4866eb8 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -94,10 +94,8 @@ class AutoParallelizer: # The last step: remove all distributed attributes to be compatiable # with inference. self._remove_distributed_attrs(partitioned_main_prog) - - complete_backward_annotation(partitioned_main_prog, self._dist_context) - make_data_unshard(partitioned_main_prog, partitioned_startup_prog) + reshard(partitioned_main_prog, partitioned_startup_prog, rank, self._dist_context) diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index b67f1e1ab97..c0a91f4b53a 100755 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -23,15 +23,15 @@ from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.backward import append_backward, _some_in_set_, _append_grad_suffix_ from paddle.distributed.auto_parallel.operators.common import get_distributed_operator -from paddle.distributed.auto_parallel.operators.common import find_best_compatible_distributed_operator_impl from paddle.fluid.clip import GradientClipBase, GradientClipByNorm, error_clip_callback, append_gradient_clip_ops, ClipGradByGlobalNorm from paddle.distributed.fleet.base.distributed_strategy import DistributedStrategy -from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed.auto_parallel.context import DistributedContext, DistOpHelper from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from .process import new_process_group from .interface import _g_process_mesh_map -from .utils import _get_comm_group +from .attribute import OperatorDistributedAttribute +from paddle.distributed.auto_parallel.completion import complete_backward_annotation, complete_update_annotation __varname_not_in_block__ = ["lod_tensor_blocking_queue_0"] @@ -122,16 +122,6 @@ class Partitioner(object): # should be set to False self._compatible_with_auto_backward = True - # data parallelism - self._enable_data_parallel = False - self._dp_degree = 0 - self._dp_group = None - - # tensor parallelism - self._enable_tensor_parallel = False - self._tp_degree = 0 - self._tp_group = None - def transpile_forward(self, serial_main_program, serial_startup_program): """ take serial forward programs with shard annotation, create a new distributed forward programs based on the serial ones. 
@@ -236,9 +226,6 @@ class Partitioner(object): raise RuntimeError( "Not all vars or ops are annotated in main program !") - # determine parallelism mode - self._determine_parallel_mode(main_program) - # dist op & partition vars new_main_prog, new_startup_program = self._dist_var_op_forward_transpile( main_program, startup_program) @@ -270,11 +257,6 @@ class Partitioner(object): self._sharding_backward_transpile(new_main_prog, new_startup_program) - # Data Parallel pass - if self._enable_data_parallel: - self._gradient_sync_transpile(dist_main_program, - dist_startup_program) - return params_grads def apply_optimize_impl(self, user_define_optimizer, params_grads, @@ -311,9 +293,78 @@ class Partitioner(object): partitioned_main_prog = fluid.Program() partitioned_global_block = partitioned_main_prog.global_block() - serial_global_block = serial_main_program.global_block() + serial_main_block = serial_main_program.global_block() serial_ops = serial_main_program.global_block().ops + # transpile startup program + if serial_startup_program == None: + partitioned_startup_prog = None + else: + partitioned_startup_prog = fluid.Program() + # create parameter + partitioned_startup_global_block = partitioned_startup_prog.global_block( + ) + param2shape = {} + temp_varname_map = {} + for var in serial_startup_program.list_vars(): + if isinstance(var, Parameter): + # TODO if var not belong to this rank, should be filtered + serial_main_var = serial_main_block.var(var.name) + dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program( + serial_main_var) + target_shape = _get_dist_shape(serial_main_var, dist_attr) + new_name = var.name + self._dist_varname_suffix + temp_varname_map[var.name] = new_name + _partition_parameter(self._auto_parallel_context, + serial_main_var, + partitioned_startup_global_block, + new_name, target_shape) + param2shape[new_name] = target_shape + + # copy initializer + for op in serial_startup_program.global_block().ops: + # TODO if var not belong to this rank, should be filtered + output_vars = op.desc.output_arg_names() + assert len( + output_vars + ) == 1, "initializer should output only ONE variable, but got [{}]".format( + str(op.desc)) + assert temp_varname_map[output_vars[ + 0]] in param2shape, "try to initialize [{}] which is not a Parameter".format( + output_vars[0]) + new_op_desc = partitioned_startup_global_block.desc.append_op() + new_op_desc.copy_from(op.desc) + new_op_desc._rename_output(output_vars[0], + temp_varname_map[output_vars[0]]) + new_op_desc._set_attr( + "shape", param2shape[temp_varname_map[output_vars[0]]]) + partitioned_startup_global_block._sync_with_cpp() + + # set distribute atrribute + new_op = partitioned_startup_global_block.ops[-1] + assert new_op.type == new_op_desc.type() + assert new_op.desc == new_op_desc + output_var = partitioned_startup_global_block.var(output_vars[ + 0]) + output_var_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program( + output_var) + op_attr = OperatorDistributedAttribute( + new_op, self._auto_parallel_context) + op_attr.set_process_mesh(output_var_attr.get_process_mesh()) + op_attr.set_output_dims_mapping( + output_var.name, output_var_attr.get_dims_mapping()) + op_attr.set_input_dims_mapping( + output_var.name, output_var_attr.get_dims_mapping()) + self._auto_parallel_context.set_op_distributed_attr_for_program( + new_op, op_attr) + + # TODO move helper init to a comm place + dist_op_helper = self._auto_parallel_context.get_dist_op_helper() + 
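# NOTE the DistOpHelper populated below is the shared context read by every dist op implementation: it carries the destination main/startup programs, the serial-to-dist varname mapping, the current rank id, and the gradopidx2opidx map that later matches each backward op to its forward dist op +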
dist_op_helper.set_dst_main_program(partitioned_main_prog) + dist_op_helper.set_dst_startup_program(partitioned_startup_prog) + dist_op_helper.set_varname_mapping(self._serial2dist_varname_mapping) + dist_op_helper.set_rank_id(self._rank_id) + # transpile main program for op in serial_ops: @@ -321,9 +372,9 @@ class Partitioner(object): for serial_input_varname in op.desc.input_arg_names(): if serial_input_varname not in self._serial2dist_varname_mapping: new_varname = serial_input_varname + self._dist_varname_suffix - if serial_global_block.has_var(serial_input_varname): + if serial_main_block.has_var(serial_input_varname): _partition_var(self._auto_parallel_context, - serial_global_block, + serial_main_block, partitioned_global_block, serial_input_varname, new_varname) else: @@ -337,118 +388,27 @@ class Partitioner(object): if serial_output_varname not in self._serial2dist_varname_mapping: new_varname = serial_output_varname + self._dist_varname_suffix _partition_var(self._auto_parallel_context, - serial_global_block, - partitioned_global_block, + serial_main_block, partitioned_global_block, serial_output_varname, new_varname) self._serial2dist_varname_mapping[ serial_output_varname] = new_varname # partition op - if _found_match_dist_op(self._auto_parallel_context, op): - # replace with corresponding dist op - _insert_dist_op(op, partitioned_global_block, - self._serial2dist_varname_mapping, - self._auto_parallel_context, self._rank_id) + kinputs, koutputs = dist_op_helper.prepare_forward_context(op) + dist_attr = self._auto_parallel_context.get_op_distributed_attr_for_program( + op) + if _is_dist_op_forward_implement(self._auto_parallel_context, op): + dist_ops = get_distributed_operator(op.type) + dist_op_impl = dist_ops.get_impl(dist_attr.get_impl_idx()) + dist_op_impl.forward(self._auto_parallel_context, **kinputs, + **koutputs) + else: # replicate op - _insert_src_op(op, partitioned_global_block, - self._serial2dist_varname_mapping) - - # transpile startup program - if serial_startup_program == None: - partitioned_startup_prog = None - else: - partitioned_startup_prog = fluid.Program() - # create parameter - partitioned_startup_global_block = partitioned_startup_prog.global_block( - ) - param2shape = {} - for var in partitioned_main_prog.list_vars(): - if isinstance(var, Parameter): - _partition_parameter(self._auto_parallel_context, var, - partitioned_startup_global_block, - var.name, var.shape) - param2shape[var.name] = var.shape - - # copy initializer - for op in serial_startup_program.global_block().ops: - output_vars = op.desc.output_arg_names() - assert len( - output_vars - ) == 1, "initializer should output only ONE variable, but got [{}]".format( - str(op.desc)) - assert self._serial2dist_varname_mapping[output_vars[ - 0]] in param2shape, "try to initialize [{}] which is not a Parameter".format( - output_vars[0]) - new_op_desc = partitioned_startup_global_block.desc.append_op() - new_op_desc.copy_from(op.desc) - new_op_desc._rename_output( - output_vars[0], - self._serial2dist_varname_mapping[output_vars[0]]) - new_op_desc._set_attr("shape", param2shape[ - self._serial2dist_varname_mapping[output_vars[0]]]) - partitioned_startup_global_block._sync_with_cpp() - - # MP broadcast not split parameter - # NOTE Theoretically, the MP param init broadcast should be handled by - # each dist op itself. but if we insert the broadcast op at that moment, the broadcast - # will before the initializer, which lead to a undertermined case. 
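The parameter broadcast that used to live in this removed block is now issued by each dist op itself: implementations that own a parameter call _init_param_sync while building the startup program (see the calls added in dist_matmul above). A minimal sketch of what such a per-parameter sync amounts to, assuming a plain c_broadcast from the group root; the real helper is defined outside this hunk, so the name and signature below are illustrative only:

from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY


def _init_param_sync_sketch(param, startup_block, ring_id):
    # broadcast the freshly initialized (and possibly sharded) parameter
    # from rank 0 of its parallel group
    startup_block.append_op(
        type='c_broadcast',
        inputs={'X': param},
        outputs={'Out': param},
        attrs={
            'ring_id': ring_id,
            'root': 0,
            'use_calc_stream': True,
            OP_ROLE_KEY: OpRole.Forward
        })
    startup_block._sync_with_cpp()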
- if self._enable_tensor_parallel: - param_to_sync = [] - for param in partitioned_startup_prog.all_parameters(): - if not self._is_var_distributed(param): - param_to_sync.append(param) - # FIXME the ring id should be set by autoparallel.mapping module - # it should be determined by dp groups butfixed it here for hacking - partitioned_startup_global_block.append_op( - type='c_broadcast', - inputs={'X': param}, - outputs={'Out': param}, - attrs={ - 'ring_id': self._tp_group.id, - 'root': 0, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Forward - }) - partitioned_startup_global_block.append_op( - type='c_sync_comm_stream', - inputs={'X': param_to_sync}, - outputs={'Out': param_to_sync}, - attrs={ - 'ring_id': self._tp_group.id, - OP_ROLE_KEY: OpRole.Forward - }) - partitioned_startup_global_block._sync_with_cpp() - - # DP init param broadcast - if self._enable_data_parallel: - # parameters initialization synchronization - param_to_sync = [] - - for param in partitioned_startup_global_block.all_parameters(): - param_to_sync.append(param) - - # FIXME the ring id should be set by autoparallel.mapping module - # it should be determined by dp groups butfixed it here for hacking - partitioned_startup_global_block.append_op( - type='c_broadcast', - inputs={'X': param}, - outputs={'Out': param}, - attrs={ - 'ring_id': self._dp_group.id, - 'root': 0, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Forward - }) - partitioned_startup_global_block.append_op( - type='c_sync_comm_stream', - inputs={'X': param_to_sync}, - outputs={'Out': param_to_sync}, - attrs={ - 'ring_id': self._dp_group.id, - OP_ROLE_KEY: OpRole.Forward - }) - partitioned_startup_global_block._sync_with_cpp() + dist_ops = get_distributed_operator("default") + dist_op_impl = dist_ops.get_impl(0) + dist_op_impl.forward(self._auto_parallel_context, **kinputs, + **koutputs) return partitioned_main_prog, partitioned_startup_prog @@ -493,12 +453,65 @@ class Partitioner(object): for param in no_grad_set ] - return _auto_backward( + dist_op_helper = self._auto_parallel_context.get_dist_op_helper() + params_and_grads = _auto_backward( dist_loss, dist_startup_program, parameter_list=parameter_list, no_grad_set=no_grad_set, - callbacks=callbacks) + callbacks=callbacks, + distop_context=dist_op_helper) + + # backward completion + complete_backward_annotation( + dist_main_program, dist_context=self._auto_parallel_context) + + # transpiler backward for dist op + # get backward ops + ops = dist_main_program.global_block().ops + first_backward_op_idx = -1 + forward_op_id2forward_op = {} + for idx in range(len(ops)): + if is_forward_op(ops[idx]): + forward_op_id2forward_op[ops[idx].desc.id()] = ops[idx] + + if int(ops[idx].attr('op_role')) == int(OpRole.Backward): + first_backward_op_idx = idx + break + assert first_backward_op_idx >= 0, "not found backward ops in program" + assert len(forward_op_id2forward_op + ) > 0, "not found forward ops in program" + + backward_ops = ops[first_backward_op_idx:] + for backward_op in backward_ops: + # if the backward op has a corresponding forward op + if backward_op.desc.id() in dist_op_helper.gradopidx2opidx: + forward_op_id = dist_op_helper.gradopidx2opidx[ + backward_op.desc.id()] + forward_op = forward_op_id2forward_op[forward_op_id] + # TODO backward attr should has _impl_idx + forward_op_dist_attr = self._auto_parallel_context.get_op_distributed_attr_for_program( + forward_op) + # TODO use the backward op itself to find the dist op + dist_ops = get_distributed_operator(forward_op.type) + kinputs, koutputs 
= dist_op_helper.prepare_backward_context( + backward_op) + + # TODO use backward op itself to determine impl idx + if _is_dist_op_backward_implement( + self._auto_parallel_context, forward_op): + dist_op_impl = dist_ops.get_impl( + forward_op_dist_attr.get_impl_idx()) + dist_op_impl.backward(self._auto_parallel_context, + **kinputs, **koutputs) + else: + # replicate op + dist_ops = get_distributed_operator("default") + dist_op_impl = dist_ops.get_impl(0) + dist_op_impl.backward(self._auto_parallel_context, + **kinputs, **koutputs) + + return params_and_grads # replace dist grad ops else: raise RuntimeError("transpile NOT implemented !") @@ -509,6 +522,10 @@ class Partitioner(object): with program_guard(main_program, startup_program): optimize_ops = user_define_optimizer.apply_gradients(params_grads) + # update completion + complete_update_annotation( + main_program, dist_context=self._auto_parallel_context) + return optimize_ops def _is_valid_annotated_program(self, program): @@ -544,47 +561,6 @@ class Partitioner(object): return dist_var - def _determine_parallel_mode(self, program): - """ - determine the parallelism that is enabled - NOTE a hard rule and should be updated in future - """ - - for param in program.all_parameters(): - if self._is_var_distributed(param): - self._enable_tensor_parallel = True - break - - for var in program.list_vars(): - var_dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program( - var) - if not var_dist_attr.is_parameter(): - mapping = var_dist_attr.get_dims_mapping() - mesh = var_dist_attr.get_process_mesh().topology - if mapping and mapping[0] >= 0 and mesh[mapping[0]] > 1: - self._enable_data_parallel = True - break - - # tensor parallelism - if self._enable_tensor_parallel: - model_parallel_axis, process_mesh = self._auto_parallel_context._get_model_parallel_info( - ) - group_ranks = _get_comm_group(process_mesh.process_group, - process_mesh.topology, - model_parallel_axis, self._rank_id) - self._tp_degree = len(group_ranks) - self._tp_group = new_process_group(group_ranks) - - # data parallelism - data_parallel_axis, process_mesh = self._auto_parallel_context._get_data_parallel_info( - ) - if self._enable_data_parallel: - group_ranks = _get_comm_group(process_mesh.process_group, - process_mesh.topology, - data_parallel_axis, self._rank_id) - self._dp_degree = len(group_ranks) - self._dp_group = new_process_group(group_ranks) - def _is_var_distributed(self, var): dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program( @@ -629,68 +605,6 @@ class Partitioner(object): """ raise RuntimeError("sharding transpile is NOT implemented !") - def _gradient_sync_transpile(self, main_program, startup_program): - """ - append the gradient allreduce ops for all parameters' grad in case of Data Parallel - """ - - # scale loss by dp degree - main_global_block = main_program.global_block() - for idx, op in reversed(list(enumerate(main_global_block.ops))): - if is_loss_grad_op(op): - loss_grad_var = main_global_block.vars[op.output_arg_names[0]] - main_global_block._insert_op_without_sync( - idx + 1, - type='scale', - inputs={'X': loss_grad_var}, - outputs={'Out': loss_grad_var}, - attrs={ - 'scale': 1.0 / self._dp_degree, - OP_ROLE_KEY: OpRole.Backward - }) - break - main_global_block._sync_with_cpp() - - # gradient synchronization - # NOTE naive gradient sync without overlapping - # so there is not need to sync between calc and comm - # collecting grad var - grad_to_sync = [] - for idx, op in 
reversed(list(enumerate(main_global_block.ops))): - if is_backward_op(op) and \ - OP_ROLE_VAR_KEY in op.attr_names: - op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY] - if len(op_role_var) != 0: - assert len(op_role_var) % 2 == 0 - for i in range(0, len(op_role_var), 2): - param, reduced_grad = op_role_var[i], op_role_var[i + 1] - assert (reduced_grad not in grad_to_sync) - grad_to_sync.append(reduced_grad) - if is_optimizer_op(op): - first_optimize_op_idx = idx - - # insert allreduce - for grad in grad_to_sync: - # FIXME the ring id should be set by autoparallel.mapping module - # it should be determined by dp groups butfixed it here for hacking - main_global_block.append_op( - type='c_allreduce_sum', - inputs={'X': grad}, - outputs={'Out': grad}, - attrs={ - 'ring_id': self._dp_group.id, - 'root': 0, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Backward - }) - main_global_block.append_op( - type='c_sync_comm_stream', - inputs={'X': grad_to_sync}, - outputs={'Out': grad_to_sync}, - attrs={'ring_id': self._dp_group.id, - OP_ROLE_KEY: OpRole.Backward}) - main_global_block._sync_with_cpp() - def _get_no_grad_set_name(no_grad_set): no_grad_set_name = set() @@ -723,7 +637,7 @@ def _get_no_grad_set(loss, no_grad_set=None): return no_grad_set -def _found_match_dist_op(auto_paralle_context, op): +def _is_dist_op_forward_implement(auto_paralle_context, op): dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(op) dist_ops = get_distributed_operator(op.type) @@ -731,11 +645,20 @@ def _found_match_dist_op(auto_paralle_context, op): dist_attr.get_impl_idx())._forward_implemented +def _is_dist_op_backward_implement(auto_paralle_context, op): + dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(op) + dist_ops = get_distributed_operator(op.type) + + return dist_ops and dist_attr.get_impl_idx() >= 0 and dist_ops.get_impl( \ + dist_attr.get_impl_idx())._backward_implemented + + def _auto_backward(loss, startup_program=None, parameter_list=None, no_grad_set=None, - callbacks=None): + callbacks=None, + distop_context=None): """ modification is inplaced """ @@ -753,9 +676,14 @@ def _auto_backward(loss, loss.shape) program = loss.block.program + with program_guard(program, startup_program): - params_grads = append_backward(loss, parameter_list, act_no_grad_set, - callbacks) + params_grads = append_backward( + loss, + parameter_list, + act_no_grad_set, + callbacks, + distop_context=distop_context) return params_grads @@ -822,6 +750,7 @@ def _partition_parameter(auto_paralle_context, src_var, dst_block, dst_varname, # param.desc.set_distributed_attr_uid(distributed_attr_uid) dist_attr = copy.deepcopy( auto_paralle_context.get_tensor_distributed_attr_for_program(src_var)) + assert dist_attr is not None dist_attr._owner_tensor = param dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program( src_var)._owner_context @@ -848,6 +777,7 @@ def _partition_intermediate_var(auto_paralle_context, src_var, dst_block, # var.desc.set_distributed_attr_uid(distributed_attr_uid) dist_attr = copy.deepcopy( auto_paralle_context.get_tensor_distributed_attr_for_program(src_var)) + assert dist_attr is not None dist_attr._owner_tensor = var dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program( src_var)._owner_context @@ -923,3 +853,11 @@ def _insert_dist_op(src_op, dst_block, varname_mapping, auto_paralle_context, input_mapping, output_mapping, rank_id=rank_id) + + +def is_forward_op(op): + role1 = 
int(core.op_proto_and_checker_maker.OpRole.Forward) | int( + core.op_proto_and_checker_maker.OpRole.Loss) + role2 = int(core.op_proto_and_checker_maker.OpRole.Forward) + op_role = int(op.attr('op_role')) + return op_role == role2 or op_role == role1 diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index a81ff699189..813bd481d92 100755 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -15,6 +15,7 @@ import threading import paddle.fluid.core as core import numpy as np +from .interface import _g_process_mesh_map def is_valid_list_index(list, index): @@ -171,7 +172,9 @@ def _get_comm_group(processes, shape, axis, rank): """ # NOTE _linear_idx2coordinate assume processes mesh start with 0 and continuous - # tricks to support processes mesh when it is not start with 0 or continuous + # trick to support a processes mesh that does not start with 0 or is not continuous + assert rank in processes, "rank [{}] is NOT in processes group {}".format( + rank, processes) rank_relatvie = processes.index(rank) coordinate = _linear_idx2coordinate(shape, rank_relatvie) coordinates_in_group = [coordinate[:] for i in range(shape[axis])] @@ -189,6 +192,25 @@ return sorted(ranks_in_group) +def _get_idx_in_axis(processes, shape, axis, rank): + """ + Given a rank and the processes mesh the rank belongs to, + compute the index of the rank in given axis. + + Example: 27 processes managed in a 3-Dimensional mesh with shape of [3, 3, 3]. + the index of rank 22 in each axis is: + in axis 0: 1 + in axis 1: 1 + in axis 2: 2 + """ + + # NOTE _linear_idx2coordinate assumes the processes mesh starts with 0 and is continuous + # trick to support a processes mesh that does not start with 0 or is not continuous + rank_relative = processes.index(rank) + coordinate = _linear_idx2coordinate(shape, rank_relative) + return coordinate[axis] + + def _coordinate2linear_idx(mesh_shape, coordinate): """ convert a coordinate in multidimensional mesh space into a scala idx in linear space. @@ -279,6 +301,27 @@ def _linear_idx2coordinate(mesh_shape, linear_idx): return coordinate +def _get_corresponding_rank(target_mesh, rank): + + # TODO(JZ-LIANG) a hack method to support varying meshes in the pipeline parallelism case. + # we assume that all meshes are evenly divided from a parent mesh and have the same size. + # to be revised in the future.
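+ # example: with topology [2, 2], rank 6 in mesh [4, 5, 6, 7] has local index 2,
+ # so it maps to the process at the same coordinate of target_mesh, e.g.
+ # target_mesh.process_group[2] == 2 when the target mesh is [0, 1, 2, 3]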
+ + coordinate = None + for key, mesh in _g_process_mesh_map.items(): + if key == 0: + continue + if rank in mesh.process_group and mesh.topology == target_mesh.topology: + coordinate = _linear_idx2coordinate(mesh.topology, + mesh.process_group.index(rank)) + break + + assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format( + rank) + return target_mesh.process_group[_coordinate2linear_idx(mesh.topology, + coordinate)] + + def _get_unshard_dist_shape(var, dist_attr): var_shape = var.shape mapping = dist_attr.get_dims_mapping() diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index d62f7b59411..9ea407c760f 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -1051,7 +1051,8 @@ def _append_backward_ops_(block, grad_to_var, callbacks=None, input_grad_names_set=None, - op_path_dict=None): + op_path_dict=None, + distop_context=None): """ Create all grad ops, and insert them into given block @@ -1108,6 +1109,10 @@ def _append_backward_ops_(block, # Getting op's corresponding grad_op grad_op_desc, op_grad_to_var = core.get_grad_op_desc( op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list) + if distop_context is not None: + for op_desc in grad_op_desc: + assert op_desc.id() not in distop_context.gradopidx2opidx + distop_context.gradopidx2opidx[op_desc.id()] = op.desc.id() # Set device for grad_op according to forward Op device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() @@ -1402,7 +1407,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callbacks=None, - checkpoints=None): + checkpoints=None, + distop_context=None): """ :api_attr: Static Graph @@ -1617,7 +1623,8 @@ def append_backward(loss, grad_to_var, callbacks, input_grad_names_set=input_grad_names_set, - op_path_dict=op_path_dict) + op_path_dict=op_path_dict, + distop_context=distop_context, ) grad_info_map = dict() diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 90f59758a2f..745e7118522 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -32,6 +32,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel) list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_parallel) list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel) list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel) +list(APPEND DIST_TEST_OPS test_auto_parallel_parallelizer) list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers) list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper) list(APPEND DIST_TEST_OPS test_parallel_class_center_sample) @@ -221,6 +222,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_pipeline_parallel) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_tensor_parallel) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel) + list(REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers) LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision) LIST(REMOVE_ITEM TEST_OPS test_mixed_precision) @@ -1002,6 +1004,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_parallel_dygraph_pipeline_parallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200) set_tests_properties(test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120) + 
+    set_tests_properties(test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120)
     set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120)
     set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120)
     set_tests_properties(test_parallel_class_center_sample PROPERTIES TIMEOUT 120)
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_parallelizer.py b/python/paddle/fluid/tests/unittests/auto_parallel_parallelizer.py
new file mode 100755
index 00000000000..89880f8c2f4
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel_parallelizer.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+
+import paddle
+import paddle.nn as nn
+import paddle.static as static
+import paddle.nn.functional as F
+import paddle.utils as utils
+from paddle.fluid import layers
+from paddle.distributed import fleet
+import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
+import paddle.fluid.core as core
+
+paddle.enable_static()
+_global_parallel_strategy = None
+_global_process_mesh = None
+ROOT_MESH = auto.ProcessMesh([0, 1])
+
+
+class MLPLayer(nn.Layer):
+    def __init__(self,
+                 hidden_size=1024,
+                 intermediate_size=4 * 1024,
+                 dropout_ratio=0.1,
+                 initializer_range=0.02):
+        super(MLPLayer, self).__init__()
+        d_model = hidden_size
+        dim_feedforward = intermediate_size
+        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
+            mean=0.0, std=initializer_range))
+        bias_attr = None
+
+        self.linear0 = nn.Linear(
+            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
+        self.linear1 = nn.Linear(
+            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
+        self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
+        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
+        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
+
+    def forward(self, input):
+        out = self.norm(input)
+        out = self.linear0(out)
+        out = F.gelu(out, approximate=True)
+        out = self.linear1(out)
+        out = self.dropout(out)
+        out = self.linear2(out)
+
+        return out
+
+
+def mlp_pretrain_forward(train_program, start_program):
+    with static.program_guard(train_program,
+                              start_program), utils.unique_name.guard():
+        batch_size = 4
+        hidden_size = 1024
+        sequence_len = 512
+        input = static.data(
+            name="input",
+            shape=[batch_size, sequence_len, hidden_size],
+            dtype='float32')
+        label = static.data(
+            name="label", shape=[batch_size, sequence_len, 1], dtype='float32')
+
+        auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1, -1])
+        auto.set_pipeline_stage(1)
+
+        mlp = MLPLayer(
+            hidden_size=hidden_size,
+            intermediate_size=4 * hidden_size,
+            dropout_ratio=0.1,
+            initializer_range=0.02)
+
+        predict = mlp(input)
+
+        cost = layers.cross_entropy(input=predict, label=label)
+        avg_cost = layers.mean(x=cost)
+
+    return avg_cost,
train_program, start_program + + +class TestMLPAutoParallelizer(unittest.TestCase): + def test_mlp_serial(self): + + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH) + + dist_strategy = fleet.DistributedStrategy() + dist_strategy.amp = False + dist_strategy.pipeline = False + dist_strategy.recompute = False + + # init parallel optimizer + dist_strategy.semi_auto = True + + fleet.init(is_collective=True, strategy=dist_strategy) + + train_program = static.Program() + start_program = static.Program() + loss, train_program, start_program = mlp_pretrain_forward(train_program, + start_program) + + optimizer = paddle.fluid.optimizer.AdamOptimizer( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + + optimizer = fleet.distributed_optimizer(optimizer) + _, _, distributed_startup_program, distributed_main_program = optimizer.minimize( + loss, start_program) + suffix = core.kAutoParallelSuffix() + for block in distributed_main_program.blocks: + for op in block.ops: + for attr_name in op.attr_names: + self.assertTrue(suffix not in attr_name) + # print_program_with_distributed_attr(distributed_main_program) + self.assertIsNotNone(distributed_startup_program) + self.assertIsNotNone(distributed_main_program) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_parallelizer.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_parallelizer.py index a92e1e2f338..7147716c74c 100755 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_parallelizer.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_parallelizer.py @@ -15,130 +15,16 @@ from __future__ import print_function import unittest +import paddle.fluid as fluid -# The following statements are used to satisfy fleet initialization -import os -if os.getenv("CUDA_VISIBLE_DEVICES", None) is None: - os.environ["CUDA_VISIBLE_DEVICES"] = '0' +from test_parallel_dygraph_dataparallel import TestMultipleGpus -import paddle -import paddle.nn as nn -import paddle.static as static -import paddle.nn.functional as F -import paddle.utils as utils -from paddle.fluid import layers -from paddle.distributed import fleet -import paddle.distributed.auto_parallel as auto -from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr -import paddle.fluid.core as core -paddle.enable_static() -_global_parallel_strategy = None -_global_process_mesh = None -ROOT_MESH = auto.ProcessMesh([0, 1]) +class TestParallelizer(TestMultipleGpus): - -class MLPLayer(nn.Layer): - def __init__(self, - hidden_size=1024, - intermediate_size=4 * 1024, - dropout_ratio=0.1, - initializer_range=0.02): - super(MLPLayer, self).__init__() - d_model = hidden_size - dim_feedforward = intermediate_size - weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( - mean=0.0, std=initializer_range)) - bias_attr = None - - self.linear0 = nn.Linear( - d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) - self.linear1 = nn.Linear( - dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) - self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr) - self.norm = nn.LayerNorm(d_model, epsilon=1e-5) - self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") - - def forward(self, input): - out = self.norm(input) - out = self.linear0(out) - out = F.gelu(out, approximate=True) - out = self.linear1(out) - out = self.dropout(out) - out = self.linear2(out) - - return out - - 
-def mlp_pretrain_forward(train_program, start_program): - with static.program_guard(train_program, - start_program), utils.unique_name.guard(): - batch_size = 4 - hidden_size = 1024 - sequence_len = 512 - input = static.data( - name="input", - shape=[batch_size, sequence_len, hidden_size], - dtype='float32') - label = static.data( - name="label", shape=[batch_size, sequence_len, 1], dtype='float32') - - auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1, -1]) - auto.set_pipeline_stage(1) - - mlp = MLPLayer( - hidden_size=hidden_size, - intermediate_size=4 * hidden_size, - dropout_ratio=0.1, - initializer_range=0.02) - - predict = mlp(input) - - cost = layers.cross_entropy(input=predict, label=label) - avg_cost = layers.mean(x=cost) - - return avg_cost, train_program, start_program - - -class TestMLPAutoParallelizer(unittest.TestCase): - def test_mlp_serial(self): - - global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH) - - dist_strategy = fleet.DistributedStrategy() - dist_strategy.amp = False - dist_strategy.pipeline = False - dist_strategy.recompute = False - - # init parallel optimizer - dist_strategy.semi_auto = True - - fleet.init(is_collective=True, strategy=dist_strategy) - - train_program = static.Program() - start_program = static.Program() - loss, train_program, start_program = mlp_pretrain_forward(train_program, - start_program) - - optimizer = paddle.fluid.optimizer.AdamOptimizer( - learning_rate=0.00001, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - grad_clip=None) - - optimizer = fleet.distributed_optimizer(optimizer) - _, _, distributed_startup_program, distributed_main_program = optimizer.minimize( - loss, start_program) - suffix = core.kAutoParallelSuffix() - for block in distributed_main_program.blocks: - for op in block.ops: - for attr_name in op.attr_names: - self.assertTrue(suffix not in attr_name) - # print_program_with_distributed_attr(distributed_main_program) - self.assertIsNotNone(distributed_startup_program) - self.assertIsNotNone(distributed_main_program) + # check sharding logic as well as the accuracy with single mode + def test_parallelizer_logic(self): + self.run_mnist_2gpu('auto_parallel_parallelizer.py') if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py index 29ba863c962..44a52524401 100755 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py @@ -92,9 +92,9 @@ def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit): def initialization_check(mode, dist_context, dist_startup_prog, - serial_startup_prog, var_need_broadcast): + serial_startup_prog, var_need_broadcast, process_mesh, + mp_parallel_axis, dp_parallel_axis): if 'mp' in mode: - mp_parallel_axis, process_mesh = dist_context._get_model_parallel_info() group_ranks = _get_comm_group(process_mesh.process_group, process_mesh.topology, mp_parallel_axis, 3) @@ -110,7 +110,6 @@ def initialization_check(mode, dist_context, dist_startup_prog, return False if 'dp' in mode: - dp_parallel_axis, process_mesh = dist_context._get_data_parallel_info() group_ranks = _get_comm_group(process_mesh.process_group, process_mesh.topology, dp_parallel_axis, 3) @@ -359,9 +358,15 @@ class TestMLPAutoPartitioner(unittest.TestCase): # parameter initialization var_need_broadcast = [] self.assertTrue( - 
initialization_check(_global_parallel_strategy, dist_context, - dist_startup_prog, serial_startup_prog, - var_need_broadcast)) + initialization_check( + _global_parallel_strategy, + dist_context, + dist_startup_prog, + serial_startup_prog, + var_need_broadcast, + _global_process_mesh, + mp_parallel_axis=None, + dp_parallel_axis=0)) def test_mlp_mp(self): global _global_parallel_strategy @@ -406,9 +411,15 @@ class TestMLPAutoPartitioner(unittest.TestCase): var_need_broadcast = sorted( ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0']) self.assertTrue( - initialization_check(_global_parallel_strategy, dist_context, - dist_startup_prog, serial_startup_prog, - var_need_broadcast)) + initialization_check( + _global_parallel_strategy, + dist_context, + dist_startup_prog, + serial_startup_prog, + var_need_broadcast, + _global_process_mesh, + mp_parallel_axis=0, + dp_parallel_axis=None)) # check var and op all have dist_attr in dist_main_program self.assertTrue( @@ -464,9 +475,15 @@ class TestMLPAutoPartitioner(unittest.TestCase): var_need_broadcast = sorted( ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0']) self.assertTrue( - initialization_check(_global_parallel_strategy, dist_context, - dist_startup_prog, serial_startup_prog, - var_need_broadcast)) + initialization_check( + _global_parallel_strategy, + dist_context, + dist_startup_prog, + serial_startup_prog, + var_need_broadcast, + _global_process_mesh, + mp_parallel_axis=1, + dp_parallel_axis=0)) # check var and op all have dist_attr in dist_main_program self.assertTrue( @@ -635,9 +652,15 @@ class TestAttentionAutoPartitioner(unittest.TestCase): # parameter initialization var_need_broadcast = [] self.assertTrue( - initialization_check(_global_parallel_strategy, dist_context, - dist_startup_prog, serial_startup_prog, - var_need_broadcast)) + initialization_check( + _global_parallel_strategy, + dist_context, + dist_startup_prog, + serial_startup_prog, + var_need_broadcast, + _global_process_mesh, + mp_parallel_axis=None, + dp_parallel_axis=0)) def test_attn_mp(self): global _global_parallel_strategy @@ -686,9 +709,15 @@ class TestAttentionAutoPartitioner(unittest.TestCase): # parameter initialization var_need_broadcast = ['linear_3.b_0'] self.assertTrue( - initialization_check(_global_parallel_strategy, dist_context, - dist_startup_prog, serial_startup_prog, - var_need_broadcast)) + initialization_check( + _global_parallel_strategy, + dist_context, + dist_startup_prog, + serial_startup_prog, + var_need_broadcast, + _global_process_mesh, + mp_parallel_axis=0, + dp_parallel_axis=None)) # check var and op all have dist_attr in dist_main_program self.assertTrue( @@ -748,9 +777,15 @@ class TestAttentionAutoPartitioner(unittest.TestCase): # parameter initialization var_need_broadcast = ['linear_3.b_0'] self.assertTrue( - initialization_check(_global_parallel_strategy, dist_context, - dist_startup_prog, serial_startup_prog, - var_need_broadcast)) + initialization_check( + _global_parallel_strategy, + dist_context, + dist_startup_prog, + serial_startup_prog, + var_need_broadcast, + _global_process_mesh, + mp_parallel_axis=1, + dp_parallel_axis=0)) # check var and op all have dist_attr in dist_main_program self.assertTrue( @@ -1043,9 +1078,15 @@ class TestDecoderLayerPartitioner(unittest.TestCase): 'layer_norm_0.w_0', 'linear_5.b_0' ]) self.assertTrue( - initialization_check(_global_parallel_strategy, dist_context, - dist_startup_prog, serial_startup_prog, - var_need_broadcast)) + initialization_check( + _global_parallel_strategy, + 
dist_context, + dist_startup_prog, + serial_startup_prog, + var_need_broadcast, + _global_process_mesh, + mp_parallel_axis=1, + dp_parallel_axis=0)) # check var and op all have dist_attr in dist_main_program self.assertTrue( @@ -1117,7 +1158,16 @@ class TestDecoderLayerPartitioner(unittest.TestCase): 'fill_constant', 'gaussian_random', 'fill_constant', 'gaussian_random', 'fill_constant', 'gaussian_random', 'fill_constant', 'gaussian_random', 'fill_constant', - 'gaussian_random', 'fill_constant', 'fill_constant', 'fill_constant' + 'gaussian_random', 'fill_constant', 'fill_constant', + 'fill_constant', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_broadcast' ] self.assertTrue(dist_ops == ref_ops) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py index 16cbad3ef6f..11b3338bc67 100755 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py @@ -521,7 +521,7 @@ class GPTModel(nn.Layer): def __init__(self, vocab_size, hidden_size=768, - num_hidden_layers=12, + num_hidden_layers=4, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", @@ -787,6 +787,14 @@ class TestGPTPartitioner(unittest.TestCase): dist_params_grads = partitioner.apply_backward( loss, complete_train_program, start_program, auto_parallel_main_prog, auto_parallel_startup_prog) + + with open("./test_auto_parallel_partitioner_serial_main_new.txt", + "w") as fw: + fw.write(str(train_program)) + with open("./test_auto_parallel_partitioner_serial_startup_new.txt", + "w") as fw: + fw.write(str(start_program)) + optimizer = paddle.fluid.optimizer.AdamOptimizer( learning_rate=0.00001, beta1=0.9, @@ -796,7 +804,17 @@ class TestGPTPartitioner(unittest.TestCase): opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, auto_parallel_main_prog, auto_parallel_startup_prog) - + from paddle.distributed.auto_parallel.context import set_default_distributed_context + set_default_distributed_context(dist_context) + with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw: + fw.write(str(auto_parallel_main_prog)) + with open("./test_auto_parallel_partitioner_startup_new.txt1", + "w") as fw: + fw.write(str(auto_parallel_startup_prog)) + # with open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw: + # from paddle.distributed.auto_parallel.completion import complete_backward_annotation + # complete_backward_annotation(auto_parallel_main_prog) + # fw.write(str(auto_parallel_main_prog)) nrank = 4 # col parallel weights = [ @@ -826,16 +844,20 @@ class TestGPTPartitioner(unittest.TestCase): 'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_8.tmp_2' ] - mp_parallel_axis, process_mesh = dist_context._get_model_parallel_info() + process_mesh = _global_process_mesh + mp_parallel_axis = 1 + dp_parallel_axis = 0 + group_ranks = _get_comm_group(process_mesh.process_group, process_mesh.topology, mp_parallel_axis, 3) 
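+        # group_ranks holds the ranks that differ from rank 3 only along
+        # mp_parallel_axis, i.e. rank 3's model-parallel group; the ring id of
+        # the process group built from it is used by the checks below.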
mp_ring_id = new_process_group(group_ranks).id - dp_parallel_axis, process_mesh = dist_context._get_data_parallel_info() + group_ranks = _get_comm_group(process_mesh.process_group, process_mesh.topology, dp_parallel_axis, 3) dp_ring_id = new_process_group(group_ranks).id + tensor_parallel_allreduce_vars = sorted([ op.desc.output_arg_names()[0].split("@")[0] for op in auto_parallel_main_prog.global_block().ops diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py index da82e56d4a1..fe9b965ed87 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py @@ -25,7 +25,6 @@ import paddle.distributed.auto_parallel as auto from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed import fleet from paddle.distributed.auto_parallel.partitioner import Partitioner -from paddle.distributed.auto_parallel.completion import complete_backward_annotation from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.process import PROCESS_GROUP_MAP @@ -211,7 +210,8 @@ def check_initialization_for_dp(dist_startup_prog): if op.type == "c_broadcast": broadcast_varnames.append(op.output_arg_names[0]) - return params == need_check_params == broadcast_varnames + return sorted(params) == sorted(need_check_params) == sorted( + broadcast_varnames) class TestMLPReshard(unittest.TestCase): @@ -225,7 +225,6 @@ class TestMLPReshard(unittest.TestCase): rank_id = 0 dist_main_prog, dist_startup_prog = get_dist_prog( train_program, startup_program, dist_context, 0) - complete_backward_annotation(dist_main_prog, dist_context) op_need_check = None for op in dist_main_prog.global_block().ops: @@ -254,7 +253,6 @@ class TestMLPReshard(unittest.TestCase): rank_id = 1 dist_main_prog, dist_startup_prog = get_dist_prog( train_program, startup_program, dist_context, rank_id) - complete_backward_annotation(dist_main_prog, dist_context) for key in list(PROCESS_GROUP_MAP.keys()): del PROCESS_GROUP_MAP[key] reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) @@ -277,7 +275,6 @@ class TestMLPReshard(unittest.TestCase): rank_id = 0 dist_main_prog, dist_startup_prog = get_dist_prog( train_program, startup_program, dist_context, rank_id) - complete_backward_annotation(dist_main_prog, dist_context) reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) # send and recv should not exist in dp scene. 
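+        # in pure data parallel every rank keeps a full replica of the
+        # activations, so reshard has no cross-rank tensor movement to insert.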
self.assertFalse(check_send_recv_result(dist_main_prog, rank_id)) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py index 1e134eebfd2..babc622393c 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py @@ -25,7 +25,6 @@ import paddle.distributed.auto_parallel as auto from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed import fleet from paddle.distributed.auto_parallel.partitioner import Partitioner -from paddle.distributed.auto_parallel.completion import complete_backward_annotation from paddle.distributed.auto_parallel.reshard import reshard paddle.enable_static() @@ -158,7 +157,6 @@ class TestMLPReshard(unittest.TestCase): dist_main_prog, dist_startup_prog = get_dist_prog( train_program, startup_program, dist_context, rank_id) print(dist_main_prog) - complete_backward_annotation(dist_main_prog, dist_context) reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) print(dist_main_prog) print(dist_startup_prog) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py index 5a10a218345..96a8b2a8d7c 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py @@ -25,7 +25,6 @@ import paddle.distributed.auto_parallel as auto from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed import fleet from paddle.distributed.auto_parallel.partitioner import Partitioner -from paddle.distributed.auto_parallel.completion import complete_backward_annotation from paddle.distributed.auto_parallel.reshard import reshard paddle.enable_static() @@ -187,7 +186,6 @@ class TestMLPReshard(unittest.TestCase): rank_id = 2 dist_main_prog, dist_startup_prog = get_dist_prog( train_program, startup_program, dist_context, rank_id) - complete_backward_annotation(dist_main_prog, dist_context) reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) # check send and recv result -- GitLab
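For reference, a minimal self-contained sketch of the mesh-index arithmetic that _get_comm_group and _get_idx_in_axis in utils.py rely on. It assumes a row-major linearization of the process mesh purely for illustration; the names below are hypothetical and not part of the patch.

import numpy as np


def comm_group(processes, shape, axis, rank):
    # coordinate of `rank` in the mesh, assuming row-major linearization
    coord = list(np.unravel_index(processes.index(rank), shape))
    group = []
    # sweep the chosen axis while keeping all other coordinates fixed
    for i in range(shape[axis]):
        coord[axis] = i
        group.append(processes[int(np.ravel_multi_index(coord, shape))])
    return sorted(group)


# hypothetical 2 x 2 mesh: axis 0 = data parallel, axis 1 = model parallel
procs = [0, 1, 2, 3]
print(comm_group(procs, (2, 2), axis=1, rank=3))  # model-parallel group of rank 3 -> [2, 3]
print(comm_group(procs, (2, 2), axis=0, rank=3))  # data-parallel group of rank 3 -> [1, 3]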