From c38b04883e8b3079d8321b5cce03f9ec07df1fd1 Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Mon, 11 Oct 2021 17:45:18 +0800 Subject: [PATCH] add reshard module (#35779) * add reshard module * fix conflict * update reshard module * update and add unitest * update reshard module and unitest * add more unitests --- .../distributed/auto_parallel/__init__.py | 2 + .../distributed/auto_parallel/completion.py | 170 +++ .../distributed/auto_parallel/context.py | 3 + .../auto_parallel/operators/dist_embedding.py | 14 +- .../distributed/auto_parallel/parallelizer.py | 9 +- .../distributed/auto_parallel/reshard.py | 1002 +++++++++++++++++ .../fluid/tests/unittests/CMakeLists.txt | 12 + .../unittests/test_auto_parallel_reshard.py | 287 +++++ .../test_auto_parallel_reshard_dpmppp.py | 173 +++ .../test_auto_parallel_reshard_mppp.py | 231 ++++ .../test_auto_parallel_reshard_serial.py | 184 +++ 11 files changed, 2083 insertions(+), 4 deletions(-) create mode 100644 python/paddle/distributed/auto_parallel/reshard.py create mode 100644 python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py create mode 100644 python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py create mode 100644 python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py create mode 100644 python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py diff --git a/python/paddle/distributed/auto_parallel/__init__.py b/python/paddle/distributed/auto_parallel/__init__.py index 5b0fdc1f1f1..31f92e2575a 100644 --- a/python/paddle/distributed/auto_parallel/__init__.py +++ b/python/paddle/distributed/auto_parallel/__init__.py @@ -19,5 +19,7 @@ from .interface import set_offload_device # noqa: F401 from .interface import set_pipeline_stage # noqa: F401 from .interface import ProcessMesh # noqa: F401 from .completion import complete_annotation # noqa: F401 +from .completion import complete_backward_annotation # noqa: F401 +from .reshard import reshard # noqa: F401 __all__ = [] diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 6e886d09d67..3fdbad6950d 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -23,6 +23,7 @@ from .utils import compute_compatible_dims_mapping from .utils import print_program_with_distributed_attr from .context import get_default_distributed_context from .operators import find_best_compatible_distributed_operator_impl +from .attribute import OperatorDistributedAttribute, TensorDistributedAttribute ELEMENTWISE_LIKE_OP_LIST = ["elementwise_add", "gelu", "dropout", "cast"] @@ -597,3 +598,172 @@ def complete_annotation(program, dist_context=None): dist_context.amend_distributed_attr_for_program() return program + + +def complete_backward_annotation(auto_parallel_main_prog, dist_context): + """Complete the annotation of vars and ops in the backward phase for parallel program.""" + + def _is_grad_var_name(name): + if "@GRAD" in name: + return True + return False + + grad_start_idx = None + for idx, op in enumerate(auto_parallel_main_prog.global_block().ops): + for var_name in op.output_arg_names: + # TODO: use _is_loss_op to judge + if "@GRAD" in var_name and op.type == "fill_constant": + grad_start_idx = idx + break + assert grad_start_idx is not None, "No backward procedure found in this program." 
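+    # Every op at an index >= grad_start_idx belongs to the backward or optimizer
+    # phase: grad_start_idx points at the fill_constant op that writes the initial
+    # value of the loss gradient, and the ops after it are annotated below by
+    # propagating the process meshes and dims mappings of their forward vars.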
+ + ops = list(auto_parallel_main_prog.global_block().ops) + vars = auto_parallel_main_prog.global_block().vars + for idx in range(grad_start_idx, len(ops)): + # complete the loss op + if idx == grad_start_idx: + grad_var = vars[ops[idx].output_arg_names[0]] + grad_var_name = grad_var.name + forward_var_name = grad_var_name[:grad_var_name.find("@GRAD")] + forward_var = vars[forward_var_name] + tensor_attr = TensorDistributedAttribute(grad_var, dist_context) + process_mesh = dist_context.get_tensor_distributed_attr_for_program( + forward_var).get_process_mesh() + dims_mapping = dist_context.get_tensor_distributed_attr_for_program( + forward_var).get_dims_mapping() + tensor_attr.set_dims_mapping(dims_mapping) + tensor_attr.set_process_mesh(process_mesh) + dist_context.set_tensor_distributed_attr_for_program(grad_var, + tensor_attr) + op_attr = OperatorDistributedAttribute(ops[idx], dist_context) + op_attr.set_process_mesh(process_mesh) + dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr) + + # in the data parallel mode, the loss op followed by scale op. + if ops[idx + 1].type == "scale" and grad_var_name in ops[idx + 1].input_arg_names \ + and grad_var_name in ops[idx + 1].output_arg_names: + op_attr = OperatorDistributedAttribute(ops[idx + 1], + dist_context) + op_attr.set_process_mesh(process_mesh) + dist_context.set_op_distributed_attr_for_program(ops[idx + 1], + op_attr) + continue + + # complete the annotation of the optimizer op. + # TODO: use _is_optimizer_op to judge + if "Grad" in ops[idx].input_names and "Param" in ops[idx].input_names: + assert len(ops[idx].input( + "Param")) == 1, "Only support one-to-one now." + assert len(ops[idx].input( + "Grad")) == 1, "Only support one-to-one now." + var = vars[ops[idx].input("Param")[0]] + grad_var = vars[ops[idx].input("Grad")[0]] + process_mesh = dist_context.get_tensor_distributed_attr_for_program( + var).get_process_mesh() + dims_mapping = dist_context.get_tensor_distributed_attr_for_program( + var).get_dims_mapping() + op_attr = OperatorDistributedAttribute(ops[idx], dist_context) + op_attr.set_process_mesh(process_mesh) + op_attr.set_input_dims_mapping(grad_var.name, dims_mapping) + dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr) + continue + + # complete the c_allreduce_sum op for gradient in the data parallel mode. 
+ if ops[idx].type == "c_allreduce_sum" and ops[ + idx].input_arg_names == ops[idx].output_arg_names: + grad_var = vars[ops[idx].output_arg_names[0]] + op_attr = OperatorDistributedAttribute(ops[idx], dist_context) + process_mesh = dist_context.get_tensor_distributed_attr_for_program( + grad_var).get_process_mesh() + op_attr.set_process_mesh(process_mesh) + dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr) + continue + + # complete the annotation of grad op + grad_op = ops[idx] + for i, op in enumerate(ops[:grad_start_idx]): + match_op = None + grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(op.desc, + set(), + []) + grad_op_input = [] + for input_arg_name in grad_op.desc.input_arg_names(): + if "@GRAD" in input_arg_name: + name = input_arg_name[:input_arg_name.find("@GRAD") + 5] + grad_op_input.append(name) + else: + grad_op_input.append(input_arg_name) + + # like sum op: the count of grad op will larger than 1 + if len(grad_op_desc_list) > 1: + for grad_op_desc in grad_op_desc_list: + if grad_op_input == grad_op_desc.input_arg_names() \ + and grad_op.desc.type() == grad_op_desc.type(): + match_op = op + break + elif len(grad_op_desc_list) == 1: + if grad_op_input == grad_op_desc_list[0].input_arg_names() \ + and grad_op.desc.type() == grad_op_desc_list[0].type(): + match_op = op + + if match_op is not None: + op_attr = dist_context.get_op_distributed_attr_for_program(op) + grad_op_attr = OperatorDistributedAttribute(grad_op, + dist_context) + grad_op_attr.set_process_mesh(op_attr.get_process_mesh()) + for var_name in grad_op.input_arg_names: + if "@GRAD" in var_name: + dims_mapping = dist_context.get_tensor_distributed_attr_for_program( + vars[var_name]).get_dims_mapping() + grad_op_attr.set_input_dims_mapping(var_name, + dims_mapping) + else: + dims_mapping = op_attr.get_input_dims_mapping(var_name) + grad_op_attr.set_input_dims_mapping(var_name, + dims_mapping) + dist_context.set_op_distributed_attr_for_program(grad_op, + grad_op_attr) + + for var_name in grad_op.output_arg_names: + if "@GRAD" in var_name: + forward_var = vars[var_name[:var_name.find("@GRAD")]] + tensor_attr = TensorDistributedAttribute(vars[var_name], + dist_context) + process_mesh = grad_op_attr.get_process_mesh() + dims_mapping = grad_op_attr.get_input_dims_mapping( + forward_var.name) + tensor_attr.set_process_mesh(process_mesh) + tensor_attr.set_dims_mapping(dims_mapping) + dist_context.set_tensor_distributed_attr_for_program( + vars[var_name], tensor_attr) + break + + # complete the annotation of sum op for multiple renamed grad var + if grad_op.type == "sum" and all( + map(_is_grad_var_name, grad_op.input_arg_names)): + assert len(grad_op.output_arg_names + ) == 1, "The output count of sum op should be one." 
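+            # The inputs of this sum op are renamed copies of the gradient of one
+            # forward var, so the op and its output var simply inherit that forward
+            # var's dims mapping and process mesh.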
+ grad_op_attr = OperatorDistributedAttribute(grad_op, dist_context) + for var_name in grad_op.input_arg_names: + if "@GRAD" in var_name: + forward_var = vars[var_name[:var_name.find("@GRAD")]] + dims_mapping = dist_context.get_tensor_distributed_attr_for_program( + forward_var).get_dims_mapping() + grad_op_attr.set_input_dims_mapping(var_name, dims_mapping) + for var_name in grad_op.output_arg_names: + forward_var = vars[var_name[:var_name.find("@GRAD")]] + tensor_attr = TensorDistributedAttribute(vars[var_name], + dist_context) + process_mesh = dist_context.get_tensor_distributed_attr_for_program( + forward_var).get_process_mesh() + dims_mapping = dist_context.get_tensor_distributed_attr_for_program( + forward_var).get_dims_mapping() + tensor_attr.set_dims_mapping(dims_mapping) + tensor_attr.set_process_mesh(process_mesh) + dist_context.set_tensor_distributed_attr_for_program( + vars[var_name], tensor_attr) + grad_op_attr.set_process_mesh( + dist_context.get_tensor_distributed_attr_for_program( + forward_var).get_process_mesh()) + dist_context.set_op_distributed_attr_for_program(grad_op, + grad_op_attr) diff --git a/python/paddle/distributed/auto_parallel/context.py b/python/paddle/distributed/auto_parallel/context.py index 4958c5adfae..5e6565aa3d8 100644 --- a/python/paddle/distributed/auto_parallel/context.py +++ b/python/paddle/distributed/auto_parallel/context.py @@ -59,6 +59,9 @@ class DistributedContext: if self._process_mesh.ndim == 1: self._data_parallel_axis = 0 self._model_parallel_axis = 0 + elif self._process_mesh.ndim == 3: + self._data_parallel_axis = 1 + self._model_parallel_axis = 2 else: self._data_parallel_axis = 0 self._model_parallel_axis = 1 diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index 141c3d14a7f..3f8fbf9cc3a 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -146,8 +146,18 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): assert mesh_shape <= 2, "row_parallel_embedding only support 1 or 2 dimensional process mesh, but got {}".format( process_mesh_shape) num_partition = process_mesh_shape[embedding_row_dim_mapping] - # TODO generalize here, support any mesh group + # TODO generalize here, support any mesh group + model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( + )._get_model_parallel_info() if mesh_shape == 1: + if rank_id not in process_mesh_group: + assert len( + process_mesh.topology + ) == 2, " row_parallel_embedding process mapping only support 2 dimensional process mesh, \ + but got {}".format(len(process_mesh.topology)) + rank_id = process_mesh_group[ + process_mesh.process_group.index(rank_id) % + process_mesh_shape[0]] relative_idx = process_mesh_group.index(rank_id) else: relative_idx = rank_id % num_partition @@ -156,8 +166,6 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): relative_idx = relative_idx * per_part_size # TODO caculate ring id - model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( - )._get_model_parallel_info() group_ranks = _get_comm_group(process_mesh.process_group, process_mesh.topology, model_parallel_axis, rank_id) diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index a08da13a39c..2994d35ef92 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ 
b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -17,9 +17,10 @@ from paddle.distributed.fleet import cloud_utils import paddle.fluid.core as core from .context import DistributedContext from .context import get_default_distributed_context -from .completion import complete_annotation +from .completion import complete_annotation, complete_backward_annotation from .partitioner import Partitioner from .process import get_all_process_groups +from .reshard import reshard class AutoParallelizer: @@ -85,10 +86,16 @@ class AutoParallelizer: # instantiate communication by process_mapping. all_process_groups = get_all_process_groups() for process_group in all_process_groups: + if rank not in process_group._ranks: + continue process_group.instantiate() # The last step: remove all distributed attributes to be compatiable # with inference. self._remove_distributed_attrs(partitioned_main_prog) + complete_backward_annotation(partitioned_main_prog, self._dist_context) + reshard(partitioned_main_prog, partitioned_startup_prog, rank, + self._dist_context) + return dist_optimize_ops, dist_params_grads, partitioned_startup_prog, partitioned_main_prog diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py new file mode 100644 index 00000000000..d66d799c6e0 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -0,0 +1,1002 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import copy +from functools import reduce + +import paddle +import paddle.fluid.core as core +from paddle.utils import unique_name +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.framework import Program, OpProtoHolder +import paddle.fluid.layers.utils as utils +from ..collective import _get_global_env +from .context import DistributedContext +from .attribute import OperatorDistributedAttribute, TensorDistributedAttribute +from .process import new_process_group, ProcessGroup, PROCESS_GROUP_MAP + + +class AllGatherOpDesc: + """ + Describe the allgather op in the reshard phase. + + Args: + group (list): Process group. + """ + + def __init__(self, group): + self._group = group + self._desc = "all_gather" + + @property + def group(self): + return self._group + + @property + def desc(self): + return self._desc + + def __repr__(self): + return f"op: {self._desc}, group: {self._group}." + + +class SendOpDesc: + """ + Describe the send op in the reshard phase. + + Args: + partition_index (list): The index of partition in complete tensor. + dst (int): The destination process to receive. 
+ """ + + def __init__(self, partition_index, dst): + self._dst = dst + self._partition_index = partition_index + self._desc = "send" + + @property + def partition_index(self): + return self._partition_index + + @property + def dst(self): + return self._dst + + @property + def desc(self): + return self._desc + + def __repr__(self): + return f"op: {self._desc}, partition_index: {self._partition_index}, dst: {self._dst}." + + +class RecvOpDesc: + """ + Describe the recv op in the reshard op. + + Args: + partition_index (list): The index of partition in complete tensor. + src (int): The source process to send. + """ + + def __init__(self, partition_index, src): + self._src = src + self._partition_index = partition_index + self._desc = "recv" + + @property + def partition_index(self): + return self._partition_index + + @property + def src(self): + return self._src + + @property + def desc(self): + return self._desc + + def __repr__(self): + return f"op: {self._desc}, partition_index: {self._partition_index}, src: {self._src}." + + +class SliceOpDesc: + """ + Describe the slice op in the reshard phase. + + Args: + starts (list): It represents starting indices of corresponding axis in ``axes``. + ends (list): It represents ending indices of corresponding axis in ``axes``. + axes (list): Axes that `starts` and `ends` apply to . + """ + + def __init__(self, starts, ends, axes): + self._starts = starts + self._ends = ends + self._axes = axes + self._desc = "slice" + + @property + def starts(self): + return self._starts + + @property + def ends(self): + return self._ends + + @property + def axes(self): + return self._axes + + @property + def desc(self): + return self._desc + + def __repr__(self): + return f"op: {self._desc}, starts: {self._starts}, ends: {self._ends}, axes: {self._axes}." + + +class ConcatOpDesc: + """ + Describe the concat op in the reshard phase. + + Args: + partition_index_list (list): A list contains all partition index. + """ + + def __init__(self, partition_index_list): + self._partition_index_list = partition_index_list + self._desc = "concat" + + @property + def partition_index_list(self): + return self._partition_index_list + + @property + def desc(self): + return self._desc + + def __repr__(self): + return f"op: {self._desc}, partition_index_list: {self._partition_index_list}." 
+ + +def _compute_partition_shape(complete_shape, dims_mapping, process_shape): + """Compute the shape of partition.""" + partition_shape = [] + for idx, item in enumerate(complete_shape): + if dims_mapping[idx] == -1: + partition_shape.append(item) + else: + partition_shape.append(item // process_shape[dims_mapping[idx]]) + + return partition_shape + + +def _compute_process_index(process, process_group, process_shape): + """Compute the index of process_shape corresponding to the process.""" + relative_process = process_group.index(process) + process_index = [] + product = reduce(lambda x, y: x * y, process_shape) + + for i in range(len(process_shape)): + idx = relative_process // (product // process_shape[i]) + product = product // process_shape[i] + relative_process = relative_process - relative_process // product * product + process_index.append(idx) + + return process_index + + +def _compute_partition_index(process, complete_shape, dims_mapping, + process_shape, process_group): + """Compute the partition index in complete tensor.""" + partition_shape = _compute_partition_shape(complete_shape, dims_mapping, + process_shape) + process_index = _compute_process_index(process, process_group, + process_shape) + partition_index = [] + + for i in range(len(complete_shape)): + if dims_mapping[i] == -1: + partition_index.append([0, partition_shape[i]]) + else: + partition_index.append([ + process_index[dims_mapping[i]] * partition_shape[i], + (process_index[dims_mapping[i]] + 1) * partition_shape[i] + ]) + + return partition_index + + +def _compute_concat_info(partition_index_x, partition_index_y): + """Judge whether two partition can be concatenated and compute concatenated partition index.""" + differ_count = 0 + concat_axis = -1 + first_order = 0 + new_partition = [] + + for idx, item in enumerate(partition_index_x): + if item != partition_index_y[idx]: + differ_count += 1 + if item[1] == partition_index_y[idx][0] and item[ + 0] < partition_index_y[idx][1]: + concat_axis = idx + new_partition.append([item[0], partition_index_y[idx][1]]) + elif item[0] == partition_index_y[idx][1] and item[ + 1] > partition_index_y[idx][0]: + first_order = 1 + concat_axis = idx + new_partition.append([partition_index_y[idx][0], item[1]]) + else: + new_partition.append(item) + + if differ_count == 1: + return concat_axis, first_order, new_partition + else: + return -1, first_order, new_partition + + +def _concat_partitions(partition_index_list, partition_index): + """Concat the given partitions without inserting concat op.""" + if not partition_index_list: + partition_index_list.append(partition_index) + else: + i = 0 + has_concat = False + while i < len(partition_index_list): + concat_axis, _, new_partition = _compute_concat_info( + partition_index_list[i], partition_index) + if concat_axis != -1: + has_concat = True + partition_index_list.pop(i) + _concat_partitions(partition_index_list, new_partition) + break + i += 1 + if not has_concat: + partition_index_list.append(partition_index) + + +def _is_overlapped(shape_x, shape_y): + """Judge whether two partitions intersect on the specified dimension.""" + overlapped = False + if (shape_y[0] <= shape_x[0] < shape_y[1]) or ( + shape_x[0] <= shape_y[0] < shape_x[1]): + overlapped = True + return overlapped + + +def _need_reshard(tensor_dist_attr, op_dist_attr): + """Judge the tensor whether needs to be resharded.""" + is_reshard = False + tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() + tensor_process_mesh = tensor_dist_attr.get_process_mesh() + 
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping( + tensor_dist_attr.get_owner_tensor().name) + op_process_mesh = op_dist_attr.get_process_mesh() + if all( + map(lambda x: x is not None, [ + tensor_dims_mapping, tensor_process_mesh, op_input_dims_mapping, + op_process_mesh + ])): + if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh._id != op_process_mesh._id: + is_reshard = True + return is_reshard + + +def _compute_complete_shape(slice_shape, process_shape, dims_mapping): + """compute the complete shape of the slice tensor with its process mesh and dims mapping""" + complete_shape = [] + for idx, item in enumerate(slice_shape): + if dims_mapping[idx] == -1: + complete_shape.append(item) + else: + complete_shape.append(item * process_shape[dims_mapping[idx]]) + return complete_shape + + +def find_op_desc_seq(source_tensor, tensor_dist_attr, op_dist_attr): + """ + Find the op description sequence to reshard the source tensor for matching the op requirement. + + Args: + source_tensor (Variable): A tensor with distributed attribute. + tensor_dist_attr (TensorDistributedAttribute): The distributed attribute of tensor. + op_dist_attr (OperatorDistributedAttribute): The distributed attribute of operator. + + Returns: + Dict, the dict represents the required op description sequence corresponding to process, The key of dict is + process and value is a list containing op description. + """ + source_dims_mapping = tensor_dist_attr.get_dims_mapping() + source_process_mesh = tensor_dist_attr.get_process_mesh() + source_process_group = source_process_mesh.process_group + source_process_shape = source_process_mesh.topology + + target_process_mesh = op_dist_attr.get_process_mesh() + target_dims_mapping = op_dist_attr.get_input_dims_mapping( + tensor_dist_attr.get_owner_tensor().name) + target_process_group = target_process_mesh.process_group + target_process_shape = target_process_mesh.topology + + complete_shape = _compute_complete_shape( + source_tensor.shape, source_process_shape, source_dims_mapping) + op_desc_seq = {} + + # TODO: if the target process group has the same process with source process group + if set(target_process_group).intersection(set( + source_process_group)) and set(target_process_group).difference( + set(source_process_group)): + pass + + # in the different process group, it will use send, recv, concat and slice op + elif target_process_group != source_process_group: + partition_process_mapping_list = [] + for source_process in source_process_group: + source_partition_index = _compute_partition_index(source_process, complete_shape, source_dims_mapping, \ + source_process_shape, source_process_group) + if not partition_process_mapping_list: + partition_process_mapping_list.append( + [source_partition_index, [source_process], [False]]) + else: + partition_list = list( + [item[0] for item in partition_process_mapping_list]) + process_list = list( + [item[1] for item in partition_process_mapping_list]) + has_used = list( + [item[2] for item in partition_process_mapping_list]) + if partition_list.count(source_partition_index) == 1: + index = partition_list.index(source_partition_index) + process_list[index].append(source_process) + has_used[index].append(False) + else: + partition_process_mapping_list.append( + [source_partition_index, [source_process], [False]]) + + for target_process in target_process_group: + has_sent = [] + target_partition_index = _compute_partition_index( + target_process, complete_shape, target_dims_mapping, + target_process_shape, 
target_process_group) + partition_index_list = [] + all_partition_index_list = [] + for source_process in source_process_group: + source_partition_index = _compute_partition_index( + source_process, complete_shape, source_dims_mapping, + source_process_shape, source_process_group) + to_send_process = None + if all(_ for _ in list(map(_is_overlapped, source_partition_index, target_partition_index))) \ + and source_partition_index not in has_sent: + idx = list([ + item[0] for item in partition_process_mapping_list + ]).index(source_partition_index) + has_used = list( + [item[2] + for item in partition_process_mapping_list])[idx] + process_list = list( + [item[1] + for item in partition_process_mapping_list])[idx] + i = 0 + while i < len(has_used): + if not has_used[i]: + to_send_process = process_list[i] + has_used[i] = True + break + i += 1 + if i == len(has_used): + has_used = list(map(lambda x: False, has_used)) + to_send_process = process_list[0] + has_used[0] = True + assert to_send_process is not None, "Failed to find the send process." + + if to_send_process not in op_desc_seq.keys(): + op_desc_seq[to_send_process] = [] + if target_process not in op_desc_seq.keys(): + op_desc_seq[target_process] = [] + all_partition_index_list.append(source_partition_index) + + # append send and recv op desc + send_op_desc = SendOpDesc(source_partition_index, + target_process) + recv_op_desc = RecvOpDesc(source_partition_index, + to_send_process) + op_desc_seq[to_send_process].append(send_op_desc) + op_desc_seq[target_process].append(recv_op_desc) + has_sent.append(source_partition_index) + _concat_partitions(partition_index_list, + source_partition_index) + + # append concat op desc + op_desc_seq[target_process].append( + ConcatOpDesc(all_partition_index_list)) + + # append slice op desc + slice_starts = [] + slice_ends = [] + slices_axes = [] + concatenated_partition_index = partition_index_list[0] + for idx, item in enumerate(concatenated_partition_index): + slice_starts.append(target_partition_index[idx][0] - item[0]) + slice_ends.append(target_partition_index[idx][1] - item[0]) + slices_axes.append(idx) + op_desc_seq[target_process].append( + SliceOpDesc(slice_starts, slice_ends, slices_axes)) + + # in the same process group, it will use allgahther and slice op + else: + partition_index_list = [] + all_partition_index_list = [] + process_index = [] + for source_process in source_process_group: + source_partition_index = _compute_partition_index( + source_process, complete_shape, source_dims_mapping, + source_process_shape, source_process_group) + if source_partition_index not in partition_index_list: + partition_index_list.append(source_partition_index) + process_index.append( + [[source_process, ], source_partition_index]) + else: + process_index[partition_index_list.index( + source_partition_index)][0].append(source_process) + + for i in range(len(process_index[0][0])): + group = [] + for j in range(len(process_index)): + group.append(process_index[j][0][i]) + if i == 0: + all_partition_index_list.append(process_index[j][1]) + for process in group: + # append slice op desc + slice_starts = [] + slice_ends = [] + slices_axes = [] + target_partition_index = _compute_partition_index( + process, complete_shape, target_dims_mapping, + target_process_shape, target_process_group) + for idx, item in enumerate(target_partition_index): + slice_starts.append(item[0]) + slice_ends.append(item[1]) + slices_axes.append(idx) + + slice_op_desc = SliceOpDesc( + starts=slice_starts, ends=slice_ends, 
axes=slices_axes) + op_desc_seq[process] = [AllGatherOpDesc(group=group), + ConcatOpDesc(partition_index_list=all_partition_index_list), slice_op_desc] \ + if len(group) > 1 else [slice_op_desc] + + return op_desc_seq + + +def _insert_send_op(block, idx, tensor, dst): + """Insert send op into block at the given index.""" + op_type = 'send_v2' + block._insert_op( + idx, + type=op_type, + inputs={'X': [tensor]}, + attrs={ + 'ring_id': 0, + 'peer': dst, + 'use_calc_stream': True, + }) + + +def _insert_recv_op(block, idx, tensor, src): + """Insert recv op into block at the given index.""" + op_type = 'recv_v2' + block._insert_op( + idx, + type=op_type, + inputs={'X': [tensor]}, + outputs={'Out': [tensor]}, + attrs={ + 'ring_id': 0, + 'peer': src, + 'out_shape': tensor.shape, + 'dtype': tensor.dtype, + 'use_calc_stream': True, + }) + + +def _insert_concat_op(block, idx, tensors, axis): + """Insert concat op into block at the given block.""" + inputs = {'X': tensors} + attrs = {} + attrs['axis'] = axis + helper = LayerHelper('concat', **locals()) + with paddle.static.program_guard(block.program): + out = helper.create_variable_for_type_inference( + dtype=helper.input_dtype()) + block._insert_op( + idx, type='concat', inputs=inputs, outputs={'Out': [out]}, attrs=attrs) + return out + + +def _insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name): + """Insert slice op into block at the given block.""" + inputs = {'Input': tensor} + infer_flags = list(1 for i in range(len(axes))) + attrs = { + "axes": axes, + "starts": starts, + "ends": ends, + "infer_flags": infer_flags + } + helper = LayerHelper('slice', **locals()) + out = block.create_var( + name=new_var_name, + dtype=tensor.dtype, + type=core.VarDesc.VarType.LOD_TENSOR) + block._insert_op( + idx, type="slice", inputs=inputs, outputs={'Out': [out]}, attrs=attrs) + return out + + +def _insert_split_op(block, idx, tensor, num_or_sections): + """Insert split op into block at the given index.""" + helper = LayerHelper('split', **locals()) + input_shape = tensor.shape + inputs = {'X': tensor} + attrs = {'num': num_or_sections, "axis": 0} + with paddle.static.program_guard(block.program): + outs = [ + helper.create_variable_for_type_inference( + dtype=helper.input_dtype()) for i in range(num_or_sections) + ] + block._insert_op( + idx, type="split", inputs=inputs, outputs={'Out': outs}, attrs=attrs) + return outs + + +def _insert_allgather_op(block, idx, tensor, ranks): + """Insert allgather op into block at the given index.""" + + def _insert_fill_constant_op(block, idx): + """Insert fill constant op into block at the given index.""" + helper = LayerHelper("fill_constant", **locals()) + with paddle.static.program_guard(block.program): + out = helper.create_variable_for_type_inference(dtype="int32") + inputs = {} + attrs = {'force_cpu': False} + attrs['str_value'] = str(int("1")) + attrs['value'] = int("1") + attrs['dtype'] = out.dtype + utils.get_shape_tensor_inputs( + inputs=inputs, attrs=attrs, shape=[0], op_type='fill_constant') + block._insert_op( + idx, + type='fill_constant', + inputs=inputs, + outputs={'Out': [out]}, + attrs=attrs) + out.stop_gradient = True + return out + + tensor_list = [] + group = new_process_group(ranks) + idx_offset = 0 + + # instant process group before insert allgather op. 
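+    # (The fill_constant, c_allreduce_sum and c_sync_calc_stream ops inserted
+    # below appear to act as a synchronization across ranks on the global ring
+    # before the first c_allgather on the newly created group runs.)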
+ if not group.is_instantiate(): + # insert fill_constant op + fill_constant_out = _insert_fill_constant_op(block, idx) + fill_constant_out.stop_gradient = True + + # insert c_allreduce_sum op + block._insert_op( + idx + 1, + type="c_allreduce_sum", + inputs={'X': [fill_constant_out]}, + outputs={'Out': [fill_constant_out]}, + attrs={'ring_id': 0, + 'use_calc_stream': True}) + + # insert c_sync_calc_stream op + block._insert_op( + idx + 2, + type="c_sync_calc_stream", + inputs={'X': [fill_constant_out]}, + outputs={'Out': [fill_constant_out]}) + idx_offset = 3 + + # insert c_allgather op + op_type = 'c_allgather' + helper = LayerHelper(op_type, **locals()) + with paddle.static.program_guard(block.program): + allgather_out = helper.create_variable_for_type_inference( + dtype=tensor.dtype) + block._insert_op( + idx + idx_offset, + type=op_type, + inputs={'X': [tensor]}, + outputs={'Out': [allgather_out]}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'nranks': group._nranks + }) + idx_offset += 1 + + # insert split op + split_out = _insert_split_op(block, idx + idx_offset, allgather_out, + group._nranks) + idx_offset += 1 + tensor_list.extend(split_out) + return tensor_list, idx_offset + + +def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index, + block, idx): + """Concat the tensors and insert concat op.""" + if not partition_tensor_list: + partition_tensor_list.append((tensor, partition_index)) + else: + i = 0 + has_concat = False + while i < len(partition_tensor_list): + concat_axis, first_order, new_partition = _compute_concat_info( + partition_tensor_list[i][1], partition_index) + if concat_axis != -1: + has_concat = True + _ = _insert_concat_op(block, idx[0], [partition_tensor_list[i][0], tensor], concat_axis) \ + if first_order == 0 else \ + _insert_concat_op(block, idx[0], [tensor, partition_tensor_list[i][0]], concat_axis) + partition_tensor_list.pop(i) + idx[0] += 1 + _concat_partitions_with_op(partition_tensor_list, _, + new_partition, block, idx) + break + i += 1 + if not has_concat: + partition_tensor_list.append((tensor, partition_index)) + + +def _init_comm_for_send_recv(): + if not PROCESS_GROUP_MAP["global_group"].is_instantiate(): + PROCESS_GROUP_MAP["global_group"].instantiate() + + +HAS_SENT = {} +HAS_RECV = {} +HAS_ALLGATHER = {} + + +def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op, + dist_context): + """Parse op desc sequence and insert op in the block""" + global HAS_SENT + global HAS_RECV + global HAS_ALLGATHER + tensor_list = [] + partition_tensor_list = [] + if rank_id not in op_desc_seq.keys(): + return + op_desc_list = op_desc_seq[rank_id] + block = program.global_block() + assert var_name in block.vars.keys( + ), "The {} cannot be found in the {} program.".format(var_name, rank_id) + + idx = None + for index, op in list(enumerate(block.ops)): + if op.desc.id == reshard_op.desc.id: + idx = index + break + assert idx is not None, "The op for reshard cannot be found in the rank {} program.".format( + rank_id) + + matched_op = block.ops[idx] + source_tensor = block.vars[var_name] + for op_desc in op_desc_list: + if isinstance(op_desc, AllGatherOpDesc): # noqa: F401 + if var_name not in HAS_ALLGATHER.keys(): + HAS_ALLGATHER[var_name] = [] + if not HAS_ALLGATHER[var_name] or op_desc.group not in list( + map(lambda x: x[0], HAS_ALLGATHER[var_name])): + tensor_list, idx_offset = _insert_allgather_op( + block, idx, source_tensor, op_desc.group) + idx += idx_offset + tensor_name_list = [var.name for var in 
tensor_list] + HAS_ALLGATHER[var_name].append( + [op_desc.group, tensor_name_list]) + else: + for item in HAS_ALLGATHER[var_name]: + if op_desc.group == item[0]: + tensor_list = [ + program.global_block().vars[var_name] + for var_name in item[1] + ] + break + assert tensor_list, "The result of parsing allgather op should not be None." + + elif isinstance(op_desc, SendOpDesc): + _init_comm_for_send_recv() + if var_name not in HAS_SENT.keys(): + HAS_SENT[var_name] = [] + if op_desc.dst not in HAS_SENT[var_name]: + _insert_send_op(block, idx, source_tensor, op_desc.dst) + idx += 1 + HAS_SENT[var_name].append(op_desc.dst) + + elif isinstance(op_desc, RecvOpDesc): + _init_comm_for_send_recv() + if var_name not in HAS_RECV.keys(): + HAS_RECV[var_name] = {} + if op_desc.src not in HAS_RECV[var_name].keys(): + partition_index = op_desc.partition_index + shape = [] + for index in partition_index: + shape.append(index[1] - index[0]) + recv_tensor = block.create_var( + name=unique_name.generate(var_name + "@recv"), + shape=shape, + dtype=source_tensor.dtype) + _insert_recv_op(block, idx, recv_tensor, op_desc.src) + tensor_list.append(recv_tensor) + idx += 1 + HAS_RECV[var_name][op_desc.src] = recv_tensor + else: + tensor_list.append(HAS_RECV[var_name][op_desc.src]) + + elif isinstance(op_desc, ConcatOpDesc): + partition_index_list = op_desc.partition_index_list + idx_list = [idx] + for index, tensor in enumerate(tensor_list): + _concat_partitions_with_op(partition_tensor_list, tensor, + partition_index_list[index], block, + idx_list) + idx = idx_list[0] + + elif isinstance(op_desc, SliceOpDesc): + assert len(partition_tensor_list) == 1 or not partition_tensor_list + to_slice_tensor = partition_tensor_list[0][0] if len( + partition_tensor_list) == 1 else source_tensor + new_name = unique_name.generate(var_name + "@RESHARD") + target_tensor = _insert_slice_op( + block, + idx, + to_slice_tensor, + starts=op_desc.starts, + ends=op_desc.ends, + axes=op_desc.axes, + new_var_name=new_name) + + tensor_attr = TensorDistributedAttribute(target_tensor, + dist_context) + process_mesh = dist_context.get_op_distributed_attr_for_program( + matched_op).get_process_mesh() + dims_mapping = dist_context.get_op_distributed_attr_for_program( + matched_op).get_input_dims_mapping(var_name) + tensor_attr.set_dims_mapping(dims_mapping) + tensor_attr.set_process_mesh(process_mesh) + dist_context.set_tensor_distributed_attr_for_program(target_tensor, + tensor_attr) + + # rename op input name according to new name + for op in block.ops: + for name in op.input_arg_names: + op_dist_attr = dist_context.get_op_distributed_attr_for_program( + op) + if name == var_name and op_dist_attr is not None: + op_process_mesh = op_dist_attr.get_process_mesh() + op_input_dims_mapping = op_dist_attr.get_input_dims_mapping( + var_name) + if op_process_mesh._id == process_mesh._id and op_input_dims_mapping == dims_mapping: + op.desc._rename_input(name, target_tensor.name) + op_dist_attr.set_input_dims_mapping( + target_tensor.name, dims_mapping) + op_dist_attr._dims_mapping.pop(name, None) + + +def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id): + """Remove no need ops in the main program""" + not_remove_op_ref = [ + "create_py_reader", "create_double_buffer_reader", "read" + ] + remove_op_idx = [] + block = auto_parallel_main_prog.global_block() + ops = block.ops + vars = block.vars + for idx, op in enumerate(ops): + # handle read op in the pipeline scene specially, it will be removed in the future. 
+ if op.type == "read": + dim_list = [] + for var_name in op.output_arg_names: + dim_list.extend(vars[var_name].shape) + for i in range(idx, -1, -1): + if ops[i].type == "create_py_reader": + ops[i]._set_attr("shape_concat", dim_list) + break + continue + + # replace the input and output of c_sync_comm_stream op when in pipeline scene. + if op.type == "c_sync_comm_stream": + need_save = [] + for var_name in op.input_arg_names: + process_mesh = dist_context.get_tensor_distributed_attr_for_program( + vars[var_name]).get_process_mesh() + if rank_id in process_mesh.process_group: + need_save.append(var_name) + if not need_save: + remove_op_idx.append(idx) + continue + + proto = OpProtoHolder.instance().get_op_proto(op.type) + op.desc.set_input(proto.inputs[0].name, need_save) + op.desc.set_output(proto.outputs[0].name, need_save) + continue + + # judge the other op whether should be removed. + op_dist_attr = dist_context.get_op_distributed_attr_for_program(op) + if op_dist_attr is not None: + op_process_mesh = op_dist_attr.get_process_mesh() + if rank_id not in op_process_mesh.process_group and op.type not in not_remove_op_ref: + remove_op_idx.append(idx) + + for idx in remove_op_idx[::-1]: + block._remove_op(idx) + + +def _remove_no_need_vars(auto_parallel_main_prog): + """Remove no need vars in the main program""" + remove_vars = set() + block = auto_parallel_main_prog.global_block() + ops = block.ops + vars = block.vars + need_vars = set() + for op in ops: + for var_name in op.input_arg_names: + if var_name in vars: + need_vars.add(var_name) + for var_name in op.output_arg_names: + if var_name in vars: + need_vars.add(var_name) + for var in vars: + if var not in need_vars: + remove_vars.add(var) + for var in remove_vars: + block._remove_var(var) + + +def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id): + """Remove no need vars and ops in the main program.""" + _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id) + _remove_no_need_vars(auto_parallel_main_prog) + + +def remove_no_need_in_startup(auto_parallel_main_prog, + auto_parallel_startup_prog): + """Remove no need vars and ops in the startup program.""" + main_input_vars = set() + main_ops = auto_parallel_main_prog.global_block().ops + for op in main_ops: + for var_name in op.input_arg_names: + main_input_vars.add(var_name) + + startup_block = auto_parallel_startup_prog.global_block() + startup_output_vars = set() + startup_ops = startup_block.ops + for op in startup_ops: + # skip c_sync_comm_stream op + if op.type == "c_sync_comm_stream": + continue + for var_name in op.output_arg_names: + startup_output_vars.add(var_name) + + need_vars = set() + for var_name in startup_output_vars: + if var_name in main_input_vars: + need_vars.add(var_name) + + startup_ops = startup_block.ops + actual_need_vars = set() + for idx, op in enumerate(startup_ops): + is_need_op = False + if op.type == "c_sync_comm_stream": + continue + for var_name in op.output_arg_names: + if var_name in need_vars: + is_need_op = True + break + if is_need_op: + for var_name in op.output_arg_names: + actual_need_vars.add(var_name) + for var_name in op.input_arg_names: + actual_need_vars.add(var_name) + + remove_vars = set() + for var_name in startup_block.vars: + if var_name not in actual_need_vars: + remove_vars.add(var_name) + for var in remove_vars: + startup_block._remove_var(var) + + remove_op_idx = [] + vars = startup_block.vars + for idx, op in enumerate(startup_block.ops): + is_no_need_op = False + if op.type == 
"c_sync_comm_stream": + var_names = [] + for var_name in op.input_arg_names: + if var_name in vars: + var_names.append(var_name) + if not var_names: + remove_op_idx.append(idx) + else: + proto = OpProtoHolder.instance().get_op_proto(op.type) + op.desc.set_input(proto.inputs[0].name, var_names) + op.desc.set_output(proto.outputs[0].name, var_names) + continue + + for var_name in op.output_arg_names: + if var_name not in vars: + is_no_need_op = True + break + if is_no_need_op: + remove_op_idx.append(idx) + for idx in remove_op_idx[::-1]: + startup_block._remove_op(idx) + + +def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id, + dist_context): + """ + Reshard tensor in the program according to its dist attr and corresponding op dist attr. + + Args: + auto_parallel_main_prog (Program): An auto parallel main program. + auto_parallel_startup_prog (Program): An auto parallel startup program. + rank_id (int): The process id. + """ + assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \ + "but got {}.".format(type(auto_parallel_main_prog)) + assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_startup_prog should be Program, " \ + "but got {}.".format(type(auto_parallel_startup_prog)) + assert isinstance(rank_id, int), "The type of rank_id should be int, " \ + "but got {}.".format(type(rank_id)) + assert isinstance(dist_context, DistributedContext), "The type of dist_context should be DistributedContext, " \ + "but got {}.".format(type(dist_context)) + + block = auto_parallel_main_prog.global_block() + idx = 0 + while idx < len(block.ops): + pre_op_count = len(block.ops) + op = block.ops[idx] + op_dist_attr = dist_context.get_op_distributed_attr_for_program(op) + if op_dist_attr is not None: + idx_offset = 0 + for var_name in op.input_arg_names: + # skip lod_tensor_blocking_queue_0 + if var_name == "lod_tensor_blocking_queue_0": + continue + var = block.vars[var_name] + tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + var) + if tensor_dist_attr is not None and _need_reshard( + tensor_dist_attr, op_dist_attr): + reshard_op_desc = find_op_desc_seq(var, tensor_dist_attr, + op_dist_attr) + parse_op_desc(auto_parallel_main_prog, rank_id, + reshard_op_desc, var_name, op, dist_context) + cur_op_count = len(block.ops) + idx_offset = idx_offset + cur_op_count - pre_op_count + pre_op_count = cur_op_count + idx = idx + idx_offset + 1 + else: + idx += 1 + + # remove no need vars and ops in the main program + remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id) + + # remove no need vars and ops in the startip program + remove_no_need_in_startup(auto_parallel_main_prog, + auto_parallel_startup_prog) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 61a43aeb44e..0c2731bc452 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -86,6 +86,10 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto) list(APPEND MIXED_DIST_TEST_OPS test_fleet_static_mp_layers) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner_gpt) +list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard) +list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_serial) +list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_mppp) +list(APPEND MIXED_DIST_TEST_OPS 
test_auto_parallel_reshard_dpmppp) foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) list(REMOVE_ITEM TEST_OPS ${TEST_OP}) endforeach() @@ -225,6 +229,10 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) LIST(REMOVE_ITEM TEST_OPS test_parallel_margin_cross_entropy) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt) + LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard) + LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_serial) + LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_mppp) + LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_dpmppp) elseif(WITH_GPU) if (${CUDNN_VERSION} VERSION_LESS 7100) LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op) @@ -589,6 +597,10 @@ if(WITH_DISTRIBUTE) py_test_modules(test_fleet_lamb_meta_optimizer MODULES test_fleet_lamb_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_partitioner MODULES test_auto_parallel_partitioner ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_partitioner_gpt MODULES test_auto_parallel_partitioner_gpt ENVS ${dist_ENVS}) + py_test_modules(test_auto_parallel_reshard MODULES test_auto_parallel_reshard ENVS ${dist_ENVS}) + py_test_modules(test_auto_parallel_reshard_serial MODULES test_auto_parallel_reshard_serial ENVS ${dist_ENVS}) + py_test_modules(test_auto_parallel_reshard_mppp MODULES test_auto_parallel_reshard_mppp ENVS ${dist_ENVS}) + py_test_modules(test_auto_parallel_reshard_dpmppp MODULES test_auto_parallel_reshard_dpmppp ENVS ${dist_ENVS}) endif(NOT WIN32) endif(NOT APPLE) if(WITH_DGC) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py new file mode 100644 index 00000000000..89e9b7e817f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py @@ -0,0 +1,287 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
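+
+# These tests exercise complete_backward_annotation and reshard on a small MLP:
+# pipeline parallel (pp) should insert send_v2/recv_v2 between the two stages,
+# while data parallel (dp) should need no send/recv and should broadcast every
+# parameter during initialization.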
+ +from __future__ import print_function + +import unittest + +import paddle +import paddle.nn as nn +import paddle.static as static +import paddle.nn.functional as F +import paddle.utils as utils +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed import fleet +from paddle.distributed.auto_parallel.partitioner import Partitioner +from paddle.distributed.auto_parallel.completion import complete_backward_annotation +from paddle.distributed.auto_parallel.reshard import reshard + +paddle.enable_static() +_global_parallel_strategy = None +_global_process_mesh = None +ROOT_MESH = auto.ProcessMesh([0, 1]) +PP_MESH_0 = None +PP_MESH_1 = None + + +class MLPLayer(nn.Layer): + def __init__(self, + hidden_size=1024, + intermediate_size=4 * 1024, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range)) + bias_attr = None + + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + + def forward(self, input): + if _global_parallel_strategy == "pp": + auto.shard_tensor( + self.linear0.weight, PP_MESH_0, dim_mapping=[-1, -1]) + auto.shard_tensor( + self.linear1.weight, PP_MESH_1, dim_mapping=[-1, -1]) + else: + auto.shard_tensor( + self.linear0.weight, _global_process_mesh, + dim_mapping=[-1, -1]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, + dim_mapping=[-1, -1]) + + out = self.norm(input) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = self.linear1(out) + + return out + + +def mlp_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + hidden_size = 1024 + sequence_len = 512 + input = static.data( + name="input", shape=[batch_size, hidden_size], dtype='float32') + label = static.data( + name="label", shape=[batch_size, 1], dtype='float32') + + if _global_parallel_strategy == "pp": + auto.shard_tensor(input, PP_MESH_0, dim_mapping=[-1, -1]) + auto.shard_tensor(label, PP_MESH_1, dim_mapping=[-1, -1]) + elif _global_parallel_strategy == "dp": + auto.shard_tensor(input, _global_process_mesh, dim_mapping=[0, -1]) + else: + auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1]) + + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + initializer_range=0.02) + + predict = mlp(input) + error_cost = paddle.nn.functional.square_error_cost(predict, label) + loss = paddle.mean(error_cost) + + return loss, train_program, start_program + + +def get_dist_prog(train_program, startup_program, dist_context, rank_id): + global _global_process_mesh + dist_context.set_process_mesh(_global_process_mesh) + loss, train_program, startup_program = mlp_forward(train_program, + startup_program) + + # auto completion + complete_train_program = auto.complete_annotation(train_program, + dist_context) + + dist_strategy = fleet.DistributedStrategy() + partitioner = Partitioner(dist_strategy, dist_context, rank_id) + # logical partition + auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward( + complete_train_program, startup_program) + dist_params_grads = partitioner.apply_backward( + loss, 
complete_train_program, startup_program, auto_parallel_main_prog, + auto_parallel_startup_prog) + optimizer = paddle.fluid.optimizer.AdamOptimizer() + opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, + auto_parallel_main_prog, + auto_parallel_startup_prog) + return auto_parallel_main_prog, auto_parallel_startup_prog + + +def check_backward_dist_attr(dist_context, dist_main_prog, op_need_check): + has_dist_attr = True + vars = dist_main_prog.global_block().vars + + op_dist_attr = dist_context.get_op_distributed_attr_for_program( + op_need_check) + if not op_dist_attr or not op_dist_attr.get_process_mesh(): + has_dist_attr = False + + for var_name in op_need_check.input_arg_names: + if not op_dist_attr.get_input_dims_mapping(var_name) or \ + not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_dims_mapping() or \ + not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_process_mesh(): + has_dist_attr = False + break + + if has_dist_attr: + for var_name in op_need_check.output_arg_names: + if not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_dims_mapping() or \ + not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_process_mesh(): + has_dist_attr = False + break + + return has_dist_attr + + +def check_send_recv_result(dist_main_prog, rank_id): + send_result = False + recv_result = False + ops = dist_main_prog.global_block().ops + if rank_id == 0: + for idx, op in enumerate(ops): + if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names: + send_result = True + if op.type == "recv_v2" and "gelu_0.tmp_0@GRAD" in op.output_arg_names[ + 0]: + recv_result = True + else: + for idx, op in enumerate(ops): + if op.type == "send_v2" and "gelu_0.tmp_0@GRAD" in op.input_arg_names: + send_result = True + if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[ + 0]: + recv_result = True + + return send_result and recv_result + + +def check_initialization(dist_startup_prog, rank_id): + if rank_id == 0: + need_check_params = [ + "layer_norm_0.b_0", "layer_norm_0.w_0", "linear_0.w_0", + "linear_0.b_0" + ] + else: + need_check_params = ['linear_1.w_0', 'linear_1.b_0'] + + params = [] + for var_name, var in dist_startup_prog.global_block().vars.items(): + if var.is_parameter: + params.append(var_name) + + return params == need_check_params + + +def check_initialization_for_dp(dist_startup_prog): + need_check_params = [ + "layer_norm_0.b_0", "layer_norm_0.w_0", "linear_0.w_0", "linear_0.b_0" + ] + ['linear_1.w_0', 'linear_1.b_0'] + params = [] + for var_name, var in dist_startup_prog.global_block().vars.items(): + if var.is_parameter: + params.append(var_name) + broadcast_varnames = [] + for op in dist_startup_prog.global_block().ops: + if op.type == "c_broadcast": + broadcast_varnames.append(op.output_arg_names[0]) + + return params == need_check_params == broadcast_varnames + + +class TestMLPReshard(unittest.TestCase): + def test_complete_backward_annotation(self): + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH) + + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + dist_context = DistributedContext() + rank_id = 0 + dist_main_prog, dist_startup_prog = get_dist_prog( + train_program, startup_program, dist_context, 0) + complete_backward_annotation(dist_main_prog, dist_context) + + op_need_check = None + for op in dist_main_prog.global_block().ops: + if op.type == "gelu_grad": + 
op_need_check = op + break + + # grad op should have dist attr + self.assertTrue( + check_backward_dist_attr(dist_context, dist_main_prog, + op_need_check)) + + def test_mlp_pp(self): + global _global_parallel_strategy + _global_parallel_strategy = "pp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH) + global PP_MESH_0 + PP_MESH_0 = auto.ProcessMesh(mesh=[0], parent=ROOT_MESH) + global PP_MESH_1 + PP_MESH_1 = auto.ProcessMesh(mesh=[1], parent=ROOT_MESH) + + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + dist_context = DistributedContext() + rank_id = 1 + dist_main_prog, dist_startup_prog = get_dist_prog( + train_program, startup_program, dist_context, rank_id) + complete_backward_annotation(dist_main_prog, dist_context) + reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) + + # check send and recv result + self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) + + # parameter initialization of every rank should be different in the pipeline scene + self.assertTrue(check_initialization(dist_startup_prog, rank_id)) + + def test_mlp_dp(self): + global _global_parallel_strategy + _global_parallel_strategy = "dp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH) + + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + dist_context = DistributedContext() + rank_id = 0 + dist_main_prog, dist_startup_prog = get_dist_prog( + train_program, startup_program, dist_context, rank_id) + complete_backward_annotation(dist_main_prog, dist_context) + reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) + # send and recv should not exist in dp scene. + self.assertFalse(check_send_recv_result(dist_main_prog, rank_id)) + + # all parameters should be initialized in dp scene + self.assertTrue(check_initialization_for_dp(dist_startup_prog)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py new file mode 100644 index 00000000000..1e134eebfd2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py @@ -0,0 +1,173 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
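+
+# Reshard test for a combined data-parallel + model-parallel + pipeline-parallel
+# (dp_mp_pp) setup: the MLP's two linear layers are placed on different pipeline
+# meshes ([[0, 1], [4, 5]] and [[2, 3], [6, 7]]) and their weights are
+# additionally split along the model-parallel axis.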
+ +from __future__ import print_function + +import unittest + +import paddle +import paddle.nn as nn +import paddle.static as static +import paddle.nn.functional as F +import paddle.utils as utils +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed import fleet +from paddle.distributed.auto_parallel.partitioner import Partitioner +from paddle.distributed.auto_parallel.completion import complete_backward_annotation +from paddle.distributed.auto_parallel.reshard import reshard + +paddle.enable_static() +_global_parallel_strategy = "dp_mp_pp" +ROOT_MESH = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]]) +_global_process_mesh = auto.ProcessMesh( + [[[0, 1], [4, 5]], [[2, 3], [6, 7]]], parent=ROOT_MESH) +PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], parent=ROOT_MESH) +PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], parent=ROOT_MESH) + + +class MLPLayer(nn.Layer): + def __init__(self, + hidden_size=1024, + intermediate_size=4 * 1024, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range)) + bias_attr = None + + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + + def forward(self, input): + auto.shard_tensor(self.linear0.weight, PP_MESH_0, dim_mapping=[-1, 1]) + auto.shard_tensor(self.linear1.weight, PP_MESH_1, dim_mapping=[1, -1]) + + out = self.norm(input) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = self.linear1(out) + + return out + + +def mlp_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + hidden_size = 1024 + sequence_len = 512 + input = static.data( + name="input", shape=[batch_size, hidden_size], dtype='float32') + label = static.data( + name="label", shape=[batch_size, 1], dtype='float32') + + auto.shard_tensor(input, PP_MESH_0, dim_mapping=[0, -1]) + auto.shard_tensor(label, PP_MESH_1, dim_mapping=[0, -1]) + + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + initializer_range=0.02) + + predict = mlp(input) + error_cost = paddle.nn.functional.square_error_cost(predict, label) + loss = paddle.mean(error_cost) + + return loss, train_program, start_program + + +def get_dist_prog(train_program, startup_program, dist_context, rank_id): + global _global_process_mesh + dist_context.set_process_mesh(_global_process_mesh) + loss, train_program, startup_program = mlp_forward(train_program, + startup_program) + + # auto completion + complete_train_program = auto.complete_annotation(train_program, + dist_context) + + dist_strategy = fleet.DistributedStrategy() + partitioner = Partitioner(dist_strategy, dist_context, rank_id) + # logical partition + auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward( + complete_train_program, startup_program) + dist_params_grads = partitioner.apply_backward( + loss, complete_train_program, startup_program, auto_parallel_main_prog, + auto_parallel_startup_prog) + optimizer = paddle.fluid.optimizer.AdamOptimizer() + opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, + auto_parallel_main_prog, + auto_parallel_startup_prog) 
+ return auto_parallel_main_prog, auto_parallel_startup_prog + + +def check_send_recv_result(dist_main_prog, rank_id): + send_result = False + recv_result = False + ops = dist_main_prog.global_block().ops + if rank_id in [0, 1, 4, 5]: + for idx, op in enumerate(ops): + if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names: + send_result = True + if op.type == "recv_v2" and "gelu_0.tmp_0@GRAD" in op.output_arg_names[ + 0]: + recv_result = True + else: + for idx, op in enumerate(ops): + if op.type == "send_v2" and "gelu_0.tmp_0@GRAD" in op.input_arg_names: + send_result = True + if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[ + 0]: + recv_result = True + + return send_result and recv_result + + +def check_initialization_for_dpmppp(dist_startup_prog): + broadcast_varnames = [] + for op in dist_startup_prog.global_block().ops: + if op.type == "c_broadcast": + broadcast_varnames.append(op.output_arg_names[0]) + result = len(broadcast_varnames) > 0 + return result + + +class TestMLPReshard(unittest.TestCase): + def test_mlp_dpmppp(self): + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + dist_context = DistributedContext() + rank_id = 2 + dist_main_prog, dist_startup_prog = get_dist_prog( + train_program, startup_program, dist_context, rank_id) + print(dist_main_prog) + complete_backward_annotation(dist_main_prog, dist_context) + reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) + print(dist_main_prog) + print(dist_startup_prog) + # check send and recv result + self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) + + # check parameter initialization + self.assertTrue(check_initialization_for_dpmppp(dist_startup_prog)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py new file mode 100644 index 00000000000..5a10a218345 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py @@ -0,0 +1,231 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest + +import paddle +import paddle.nn as nn +import paddle.static as static +import paddle.nn.functional as F +import paddle.utils as utils +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed import fleet +from paddle.distributed.auto_parallel.partitioner import Partitioner +from paddle.distributed.auto_parallel.completion import complete_backward_annotation +from paddle.distributed.auto_parallel.reshard import reshard + +paddle.enable_static() +_global_parallel_strategy = "mp_pp" +ROOT_MESH = auto.ProcessMesh([[0, 1], [2, 3]]) +_global_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]], parent=ROOT_MESH) +PP_MESH_0 = auto.ProcessMesh([0, 1], parent=ROOT_MESH) +PP_MESH_1 = auto.ProcessMesh([2, 3], parent=ROOT_MESH) + + +class MLPLayer(nn.Layer): + def __init__(self, + hidden_size=1024, + intermediate_size=4 * 1024, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range)) + bias_attr = None + + self.word_embeddings = nn.Embedding( + hidden_size, + hidden_size, + weight_attr=paddle.ParamAttr( + name="word_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range))) + + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + self.linear2 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + + def forward(self, input): + auto.shard_tensor( + self.word_embeddings.weight, PP_MESH_0, dim_mapping=[0, -1]) + auto.shard_tensor(self.linear0.weight, PP_MESH_0, dim_mapping=[-1, 0]) + auto.shard_tensor(self.linear1.weight, PP_MESH_1, dim_mapping=[0, -1]) + auto.shard_tensor(self.linear2.weight, PP_MESH_1, dim_mapping=[0, -1]) + w_out = self.word_embeddings(input) + out = self.linear0(w_out) + gelu_out = F.gelu(out, approximate=True) + out = self.linear1(gelu_out) + out1 = self.linear2(gelu_out) + out = out + out1 + + return out + + +def mlp_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + hidden_size = 1024 + sequence_len = 512 + input = static.data(name="input", shape=[batch_size], dtype='int32') + label = static.data( + name="label", shape=[batch_size, 1], dtype='float32') + + auto.shard_tensor(input, PP_MESH_0, dim_mapping=[-1]) + auto.shard_tensor(label, PP_MESH_1, dim_mapping=[-1, -1]) + + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + initializer_range=0.02) + + predict = mlp(input) + error_cost = paddle.nn.functional.square_error_cost(predict, label) + loss = paddle.mean(error_cost) + + return loss, train_program, start_program + + +def get_dist_prog(train_program, startup_program, dist_context, rank_id): + global _global_process_mesh + dist_context.set_process_mesh(_global_process_mesh) + loss, train_program, startup_program = mlp_forward(train_program, + startup_program) + + # auto completion + complete_train_program = auto.complete_annotation(train_program, + dist_context) + + dist_strategy = fleet.DistributedStrategy() + partitioner = Partitioner(dist_strategy, dist_context, rank_id) + # logical partition + auto_parallel_main_prog, auto_parallel_startup_prog = 
partitioner.transpile_forward( + complete_train_program, startup_program) + dist_params_grads = partitioner.apply_backward( + loss, complete_train_program, startup_program, auto_parallel_main_prog, + auto_parallel_startup_prog) + optimizer = paddle.fluid.optimizer.AdamOptimizer() + opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, + auto_parallel_main_prog, + auto_parallel_startup_prog) + return auto_parallel_main_prog, auto_parallel_startup_prog + + +def check_send_recv_result(dist_main_prog, rank_id): + send_result = False + recv_result = False + ops = dist_main_prog.global_block().ops + if rank_id in [0, 1]: + for idx, op in enumerate(ops): + if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names: + send_result = True + if op.type == "recv_v2" and "gelu_0.tmp_0@GRAD" in op.output_arg_names[ + 0]: + recv_result = True + else: + for idx, op in enumerate(ops): + if op.type == "send_v2" and "gelu_0.tmp_0@GRAD" in op.input_arg_names[ + 0]: + send_result = True + if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[ + 0]: + recv_result = True + + return send_result and recv_result + + +def check_initialization_for_mppp(dist_startup_prog, rank_id): + if rank_id in [0, 1]: + need_check_params = [] + else: + need_check_params = ["linear_1.b_0", "linear_2.b_0"] + broadcast_varnames = [] + for op in dist_startup_prog.global_block().ops: + if op.type == "c_broadcast": + broadcast_varnames.append(op.output_arg_names[0]) + + return need_check_params == broadcast_varnames + + +def check_allgather(dist_main_program): + allgather_out = "x@RESHARD_0" + var_result = False + op_result = False + vars = dist_main_program.global_block().vars + if allgather_out in vars and vars[allgather_out].shape == (4, 4): + var_result = True + for op in dist_main_program.global_block().ops: + if op.type == "matmul_v2": + if allgather_out in op.input_arg_names: + op_result = True + return var_result and op_result + + +class TestMLPReshard(unittest.TestCase): + def test_mlp_mppp(self): + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + dist_context = DistributedContext() + rank_id = 2 + dist_main_prog, dist_startup_prog = get_dist_prog( + train_program, startup_program, dist_context, rank_id) + complete_backward_annotation(dist_main_prog, dist_context) + reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) + + # check send and recv result + self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) + + # parameter which not been sliced should be the same in the mp scene + self.assertTrue( + check_initialization_for_mppp(dist_startup_prog, rank_id)) + + def test_allgather(self): + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + process_mesh = auto.ProcessMesh(mesh=[0, 3], parent=ROOT_MESH) + with static.program_guard(train_program, startup_program): + x = paddle.static.data(name="x", shape=[4, 4], dtype='float32') + x = auto.shard_tensor(x, process_mesh, dim_mapping=[0, -1]) + + w = paddle.static.data(name="w", shape=[4, 4], dtype='float32') + w = auto.shard_tensor(w, process_mesh, dim_mapping=[-1, -1]) + + y = paddle.distributed.shard_op(paddle.matmul, process_mesh, { + x.name: [-1, -1], + w.name: [-1, -1] + }, **{"x": x, + "y": w})[0] + + rank_id = 0 + dist_context = DistributedContext() + dist_strategy = fleet.DistributedStrategy() + partitioner = Partitioner(dist_strategy, dist_context, rank_id) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + 
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward( + complete_train_program, startup_program) + reshard(auto_parallel_main_prog, startup_program, rank_id, dist_context) + # the x should not be slice + self.assertTrue(check_allgather(auto_parallel_main_prog)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py new file mode 100644 index 00000000000..bf2ba9f061f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py @@ -0,0 +1,184 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import os +if os.getenv("CUDA_VISIBLE_DEVICES", None) is None: + os.environ["CUDA_VISIBLE_DEVICES"] = '0' + +import paddle +import paddle.nn as nn +import paddle.static as static +import paddle.nn.functional as F +import paddle.utils as utils +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.context import get_default_distributed_context +from paddle.distributed import fleet +from paddle.distributed.auto_parallel.partitioner import Partitioner +from paddle.distributed.auto_parallel.reshard import reshard +from paddle.distributed.auto_parallel.process import new_process_group + +paddle.enable_static() +_global_parallel_strategy = None +_global_process_mesh = None +ROOT_MESH = auto.ProcessMesh([0]) + + +class MLPLayer(nn.Layer): + def __init__(self, + hidden_size=1024, + intermediate_size=4 * 1024, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range)) + bias_attr = None + + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + + def forward(self, input): + if _global_parallel_strategy == "pp": + auto.shard_tensor( + self.linear0.weight, PP_MESH_0, dim_mapping=[-1, -1]) + auto.shard_tensor( + self.linear1.weight, PP_MESH_1, dim_mapping=[-1, -1]) + else: + auto.shard_tensor( + self.linear0.weight, _global_process_mesh, + dim_mapping=[-1, -1]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, + dim_mapping=[-1, -1]) + + out = self.norm(input) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = self.linear1(out) + + return out + + +def mlp_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + hidden_size = 1024 + sequence_len = 512 + input = static.data( + name="input", shape=[batch_size, hidden_size], dtype='float32') + label = static.data( + name="label", 
shape=[batch_size, 1], dtype='float32') + + if _global_parallel_strategy == "pp": + auto.shard_tensor(input, PP_MESH_0, dim_mapping=[-1, -1]) + auto.shard_tensor(label, PP_MESH_1, dim_mapping=[-1, -1]) + elif _global_parallel_strategy == "dp": + auto.shard_tensor(input, _global_process_mesh, dim_mapping=[0, -1]) + else: + auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1]) + + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + initializer_range=0.02) + + predict = mlp(input) + error_cost = paddle.nn.functional.square_error_cost(predict, label) + loss = paddle.mean(error_cost) + + return loss, train_program, start_program + + +def get_dist_prog_with_parallelizer(train_program, startup_program, + dist_context): + global _global_process_mesh + + dist_strategy = fleet.DistributedStrategy() + dist_strategy.amp = False + dist_strategy.pipeline = False + dist_strategy.recompute = False + + # init parallel optimizer + dist_strategy.semi_auto = True + fleet.init(is_collective=True, strategy=dist_strategy) + + loss, train_program, startup_program = mlp_forward(train_program, + startup_program) + + optimizer = paddle.fluid.optimizer.AdamOptimizer( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + optimizer = fleet.distributed_optimizer(optimizer) + + # fake a comm group + pg = new_process_group([3, 4]) + _, _, distributed_startup_program, distributed_main_program = optimizer.minimize( + loss, startup_program) + + return distributed_main_program, distributed_startup_program + + +def check_send_recv_result(dist_main_prog, rank_id): + send_result = False + recv_result = False + ops = dist_main_prog.global_block().ops + if rank_id == 0: + for idx, op in enumerate(ops): + if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names: + send_result = True + if op.type == "recv_v2" and "gelu_0.tmp_0@GRAD" in op.output_arg_names[ + 0]: + recv_result = True + else: + for idx, op in enumerate(ops): + if op.type == "send_v2" and "gelu_0.tmp_0@GRAD" in op.input_arg_names: + send_result = True + if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[ + 0]: + recv_result = True + + return send_result and recv_result + + +class TestMLPReshard(unittest.TestCase): + def test_mlp_serial(self): + global _global_parallel_strategy + _global_parallel_strategy = None + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh(mesh=[0], parent=ROOT_MESH) + + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + dist_context = get_default_distributed_context() + rank_id = 0 + dist_main_prog, dist_startup_prog = get_dist_prog_with_parallelizer( + train_program, startup_program, dist_context) + # send and recv should not exist in serial scene. + self.assertFalse(check_send_recv_result(dist_main_prog, rank_id)) + + +if __name__ == "__main__": + unittest.main() -- GitLab
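
For orientation, every test in this patch exercises the same end-to-end flow that the new reshard module enables: annotate a serial program, partition it for one rank, complete the backward annotation, and finally call reshard to insert whatever cross-mesh communication (send_v2/recv_v2, slice, or allgather) the rank-local program needs. The sketch below condenses the get_dist_prog helpers above into one script; the single-linear model, the [0, 1] mesh, and rank_id = 0 are illustrative choices only, not requirements of the API, and the imports assume the auto_parallel package as it stands at this commit.

    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F
    import paddle.static as static
    import paddle.distributed.auto_parallel as auto
    from paddle.distributed import fleet
    from paddle.distributed.auto_parallel.context import DistributedContext
    from paddle.distributed.auto_parallel.partitioner import Partitioner
    from paddle.distributed.auto_parallel.completion import complete_backward_annotation
    from paddle.distributed.auto_parallel.reshard import reshard

    paddle.enable_static()
    ROOT_MESH = auto.ProcessMesh([0, 1])                 # illustrative 2-process mesh
    MESH = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH)
    rank_id = 0                                          # illustrative rank

    train_program = static.Program()
    startup_program = static.Program()
    dist_context = DistributedContext()
    dist_context.set_process_mesh(MESH)

    with static.program_guard(train_program, startup_program):
        x = static.data(name="x", shape=[4, 1024], dtype='float32')
        label = static.data(name="label", shape=[4, 1], dtype='float32')
        auto.shard_tensor(x, MESH, dim_mapping=[0, -1])  # shard the batch dimension (dp-style)
        linear = nn.Linear(1024, 1)
        loss = paddle.mean(F.square_error_cost(linear(x), label))

    # 1. complete the forward annotation of the serial program
    complete_train_program = auto.complete_annotation(train_program, dist_context)

    # 2. partition the annotated program for this rank
    dist_strategy = fleet.DistributedStrategy()
    partitioner = Partitioner(dist_strategy, dist_context, rank_id)
    dist_main_prog, dist_startup_prog = partitioner.transpile_forward(
        complete_train_program, startup_program)
    dist_params_grads = partitioner.apply_backward(
        loss, complete_train_program, startup_program, dist_main_prog,
        dist_startup_prog)
    partitioner.apply_optimize(paddle.fluid.optimizer.AdamOptimizer(),
                               dist_params_grads, dist_main_prog,
                               dist_startup_prog)

    # 3. complete the backward annotation, then reshard the rank-local program
    complete_backward_annotation(dist_main_prog, dist_context)
    reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)

In this pure data-parallel setup the reshard pass inserts no send_v2/recv_v2 pairs, which is exactly what test_mlp_dp asserts; with a pipeline-style mesh such as PP_MESH_0/PP_MESH_1 above, the same call is what places the send/recv ops between stages (and, where a tensor must be regathered, the allgather path that test_allgather checks).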