From a02532b576d41307f1e85b0c029b71b909bd456f Mon Sep 17 00:00:00 2001
From: Yulong Ao
Date: Fri, 29 Oct 2021 11:20:04 +0800
Subject: [PATCH] [Auto Parallel] Improve the interface and the underlying mechanisms (#36617)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* default dist op
* add dist_attr for dist op
* add unit test
* update input name
* update function name
* add unit test
* update CMakeLists.txt for CI
* fix dist_matmul
* fix compile error
* update matmul to matmul_v2
* unify api
* unify api
* todo
* update dist op forward func
* update dist op forward func
* auto parallel backward
* update dist op
* auto parallel backward
* add backward for embedding
* temp1
* temp2
* temp3
* temp4
* backward done1
* backward done2
* backward done3
* dist embedding remove mp mode
* dist matmul remove mp mode
* update dist embedding
* dist op init1
* dist op init 2
* update unit test
* context remove parallel mode
* partitioner remove parallel mode
* update unit test
* a more general method to support varying mesh in pipeline parallel
* support varying mesh in pipeline parallel
* embedding support varying mesh in pipeline parallel
* matmul support varying mesh in pipeline parallel
* default dist op support varying mesh in pipeline parallel
* dist attribute for startup program
* default dist op support varying mesh in pipeline parallel 2
* partitioner support varying mesh in pipeline parallel
* revise logic for auto completion
* revise framework.py
* revise reshard unit test
* revise unit test for parallelize
* chmod
* fixed bug for dist embedding name mapping
* Improve the interface and the underlying mechanisms of auto parallel
* revise completion for backward
* revise completion for update
* revise completion for update
* update unit test
* chmod
* bugfix for grad_op output var's mesh
* Modify code for PR 36744
* Remove unnecessary comments in framework.py
* Remove unnecessary comments in completion.py

Co-authored-by: JZ-LIANG
Co-authored-by: zhaoyingli
Co-authored-by: JZ-LIANG <38102074+JZ-LIANG@users.noreply.github.com>
---
 python/paddle/distributed/__init__.py | 4 -
 .../distributed/auto_parallel/__init__.py | 9 +-
 .../distributed/auto_parallel/attribute.py | 309 -----------
 .../distributed/auto_parallel/completion.py | 360 +++++++------
 .../distributed/auto_parallel/context.py | 495 ------------------
 .../distributed/auto_parallel/cost_model.py | 2 +-
 .../auto_parallel/dist_attribute.py | 436 +++++++++++++++
 .../distributed/auto_parallel/dist_context.py | 427 +++++++++++++++
 .../distributed/auto_parallel/dist_op.py | 243 +++++++++
 .../distributed/auto_parallel/dist_tensor.py | 103 ++++
 .../distributed/auto_parallel/interface.py | 479 ++---------------
 .../auto_parallel/operators/__init__.py | 4 +-
 .../auto_parallel/operators/common.py | 155 +++---
 .../auto_parallel/operators/dist_default.py | 105 ++--
 .../auto_parallel/operators/dist_embedding.py | 108 ++--
 .../auto_parallel/operators/dist_matmul.py | 309 ++++++-----
 .../auto_parallel/operators/dist_reshape.py | 75 ++-
 .../auto_parallel/operators/dist_softmax.py | 28 +-
 .../auto_parallel/operators/dist_transpose.py | 22 +-
 .../distributed/auto_parallel/parallelizer.py | 7 +-
 .../distributed/auto_parallel/partitioner.py | 207 ++++----
 .../{process.py => process_group.py} | 50 +-
 .../distributed/auto_parallel/process_mesh.py | 135 +++++
 .../distributed/auto_parallel/reshard.py | 104 ++--
 .../paddle/distributed/auto_parallel/utils.py | 59 +--
 python/paddle/fluid/framework.py | 17 +-
.../unittests/auto_parallel_data_unshard.py | 64 ++- .../unittests/auto_parallel_parallelizer.py | 15 +- .../tests/unittests/test_auto_parallel_api.py | 197 ++++--- .../test_auto_parallel_completion.py | 408 +++++++++------ .../test_auto_parallel_completion_gpt.py | 129 +++-- .../test_auto_parallel_cost_model.py | 30 +- .../test_auto_parallel_partitioner.py | 292 +++++++---- .../test_auto_parallel_partitioner_gpt.py | 140 +++-- .../unittests/test_auto_parallel_reshard.py | 92 +++- .../test_auto_parallel_reshard_dpmppp.py | 37 +- .../test_auto_parallel_reshard_mppp.py | 80 ++- .../test_auto_parallel_reshard_serial.py | 61 ++- 38 files changed, 3220 insertions(+), 2577 deletions(-) delete mode 100644 python/paddle/distributed/auto_parallel/attribute.py mode change 100755 => 100644 python/paddle/distributed/auto_parallel/completion.py delete mode 100644 python/paddle/distributed/auto_parallel/context.py create mode 100644 python/paddle/distributed/auto_parallel/dist_attribute.py create mode 100755 python/paddle/distributed/auto_parallel/dist_context.py create mode 100644 python/paddle/distributed/auto_parallel/dist_op.py create mode 100644 python/paddle/distributed/auto_parallel/dist_tensor.py rename python/paddle/distributed/auto_parallel/{process.py => process_group.py} (76%) create mode 100644 python/paddle/distributed/auto_parallel/process_mesh.py diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index 20007f76ed5..600327e4a50 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -43,10 +43,6 @@ from .collective import wait # noqa: F401 from .auto_parallel import shard_op # noqa: F401 from .auto_parallel import shard_tensor # noqa: F401 -from .auto_parallel import set_shard_mask # noqa: F401 -from .auto_parallel import set_offload_device # noqa: F401 -from .auto_parallel import set_pipeline_stage # noqa: F401 -from .auto_parallel import ProcessMesh # noqa: F401 from .fleet import BoxPSDataset # noqa: F401 diff --git a/python/paddle/distributed/auto_parallel/__init__.py b/python/paddle/distributed/auto_parallel/__init__.py index 2779a9feb0b..3b5ccaa062f 100644 --- a/python/paddle/distributed/auto_parallel/__init__.py +++ b/python/paddle/distributed/auto_parallel/__init__.py @@ -14,10 +14,11 @@ from .interface import shard_tensor # noqa: F401 from .interface import shard_op # noqa: F401 -from .interface import set_shard_mask # noqa: F401 -from .interface import set_offload_device # noqa: F401 -from .interface import set_pipeline_stage # noqa: F401 -from .interface import ProcessMesh # noqa: F401 +from .process_mesh import ProcessMesh +# from .interface import set_shard_mask # noqa: F401 +# from .interface import set_offload_device # noqa: F401 +# from .interface import set_pipeline_stage # noqa: F401 +# from .interface import ProcessMesh # noqa: F401 from .completion import complete_annotation # noqa: F401 from .completion import complete_backward_annotation # noqa: F401 from .reshard import reshard # noqa: F401 diff --git a/python/paddle/distributed/auto_parallel/attribute.py b/python/paddle/distributed/auto_parallel/attribute.py deleted file mode 100644 index 879e94b8373..00000000000 --- a/python/paddle/distributed/auto_parallel/attribute.py +++ /dev/null @@ -1,309 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License - -import copy -from collections import defaultdict -from paddle.fluid import core - - -class TensorDistributedAttribute: - def __init__(self, owner_tensor, owner_context): - self._owner_tensor = owner_tensor - self._owner_context = owner_context - self._process_mesh = None - self._dims_mapping = None - self._shard_mask = None - self._offload_device = None - self._shape = None - self._is_annotated = {} - self._is_parameter = False - - def get_owner_tensor(self): - return self._owner_tensor - - def get_owner_context(self): - return self._owner_context - - def get_process_mesh(self): - return self._process_mesh - - def set_process_mesh(self, process_mesh): - self._process_mesh = copy.deepcopy(process_mesh) - - def get_dims_mapping(self): - return self._dims_mapping - - def set_dims_mapping(self, dims_mapping): - self._dims_mapping = copy.deepcopy(dims_mapping) - - def get_shard_mask(self): - return self._shard_mask - - def set_shard_mask(self, shard_mask): - self._shard_mask = copy.deepcopy(shard_mask) - - def get_offload_device(self): - return self._offload_device - - def set_offload_device(self, offload_device): - self._offload_device = copy.deepcopy(offload_device) - - def get_shape(self): - return self._shape - - def set_shape(self, shape): - self._shape = copy.deepcopy(shape) - - def is_annotated(self, dist_attr_name): - return self._is_annotated.get(dist_attr_name, False) - - def mark_as_annotated(self, dist_attr_name): - self._is_annotated[dist_attr_name] = True - - def is_parameter(self): - return self._is_parameter - - def mark_as_parameter(self): - self._is_parameter = True - - def is_valid(self): - if self.get_owner_tensor().type == core.VarDesc.VarType.READER: - return True - tensor_shape = self.get_owner_tensor().desc.shape() - if len(tensor_shape) != len(self.get_dims_mapping()): - return False - for i in range(len(self.get_dims_mapping())): - if self.get_dims_mapping()[i] < -1 or self.get_dims_mapping()[ - i] >= len(self.get_process_mesh().topology): - return False - for i in range(len(self.get_process_mesh().topology)): - if self.get_dims_mapping().count(i) > 1: - return False - return True - - def __str__(self): - str = "{{tensor name: {}, tensor id: {}".format( - self.get_owner_tensor().desc.name(), - self.get_owner_tensor().desc.id()) - if self.is_annotated("process_mesh"): - annotated_str = "annotated" - else: - annotated_str = "non-annotated" - str += ", process_mesh ({}): {}".format(annotated_str, - self.get_process_mesh()) - - str += ", is_parameter: {}".format(self._is_parameter) - - if self.is_annotated("dims_mapping"): - annotated_str = "annotated" - else: - annotated_str = "non-annotated" - str += ", dims_mapping ({}): {}".format(annotated_str, - self.get_dims_mapping()) - - if self.is_annotated("shard_mask"): - annotated_str = "annotated" - else: - annotated_str = "non-annotated" - str += ", shard_mask ({}): {}".format(annotated_str, - self.get_shard_mask()) - - if self.is_annotated("offload_device"): - annotated_str = "annotated" - else: - annotated_str = "non-annotated" - str += ", offload_device ({}): {} }}".format(annotated_str, - 
self.get_offload_device()) - return str - - def __deepcopy__(self, memo): - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - for k, v in self.__dict__.items(): - # No need to copy the owner tensor and context - if k == "_owner_tensor" or k == "_owner_context": - setattr(result, k, v) - else: - setattr(result, k, copy.deepcopy(v, memo)) - return result - - -class OperatorDistributedAttribute: - def __init__(self, owner_op, owner_context): - self._owner_op = owner_op - self._owner_context = owner_context - self._process_mesh = None - self._dims_mapping = {} - self._shapes = {} - self._is_annotated = {} - self._is_parameters = {} - self._pipeline_stage = None - self._impl_idx = None - - def get_owner_op(self): - return self._owner_op - - def get_owner_context(self): - return self._owner_context - - def get_process_mesh(self): - return self._process_mesh - - def set_process_mesh(self, process_mesh): - self._process_mesh = copy.deepcopy(process_mesh) - - def get_input_dims_mapping(self, name): - return self._dims_mapping.get("IN_" + name, None) - - def set_input_dims_mapping(self, name, dims_mapping): - self._dims_mapping["IN_" + name] = copy.deepcopy(dims_mapping) - - def get_output_dims_mapping(self, name): - return self._dims_mapping.get("OUT_" + name, None) - - def set_output_dims_mapping(self, name, dims_mapping): - self._dims_mapping["OUT_" + name] = copy.deepcopy(dims_mapping) - - def get_impl_idx(self): - return self._impl_idx - - def set_impl_idx(self, impl_idx): - self._impl_idx = impl_idx - - def get_pipeline_stage(self): - return self._pipeline_stage - - def set_pipeline_stage(self, pipeline_stage): - self._pipeline_stage = copy.deepcopy(pipeline_stage) - - def get_input_shape(self, name): - return self._shapes.get("IN_" + name, None) - - def set_input_shape(self, name, shape): - self._shapes["IN_" + name] = copy.deepcopy(shape) - - def get_output_shape(self, name): - return self._shapes.get("OUT_" + name, None) - - def set_output_shape(self, name, shape): - self._shapes["OUT_" + name] = copy.deepcopy(shape) - - def is_annotated(self, attr_name): - return self._is_annotated.get(attr_name, False) - - def mark_as_annotated(self, attr_name): - self._is_annotated[attr_name] = True - - def is_annotated_input_dims_mapping(self, name): - return self._is_annotated.get("IN_" + name, False) - - def mark_as_annotated_input_dims_mapping(self, name): - self._is_annotated["IN_" + name] = True - - def is_annotated_output_dims_mapping(self, name): - return self._is_annotated.get("OUT_" + name, False) - - def mark_as_annotated_output_dims_mapping(self, name): - self._is_annotated["OUT_" + name] = True - - def is_parameter(self, name): - return self._is_parameters.get(name, False) - - def mark_as_parameter(self, name): - self._is_parameters[name] = True - - def is_valid(self): - if "read" in self.get_owner_op().type: - return True - for name in self.get_owner_op().desc.input_arg_names(): - dims_mapping = self.get_input_dims_mapping(name) - shape = self.get_input_shape(name) - if len(shape) != len(dims_mapping): - return False - for i in range(len(dims_mapping)): - if dims_mapping[i] < -1 or dims_mapping[i] >= len( - self.get_process_mesh().topology): - return False - for i in range(len(self.get_process_mesh().topology)): - if dims_mapping.count(i) > 1: - return False - for name in self.get_owner_op().desc.output_arg_names(): - dims_mapping = self.get_output_dims_mapping(name) - shape = self.get_output_shape(name) - if len(shape) != len(dims_mapping): - return False - for i 
in range(len(dims_mapping)): - if dims_mapping[i] < -1 or dims_mapping[i] >= len( - self.get_process_mesh().topology): - return False - for i in range(len(self.get_process_mesh().topology)): - if dims_mapping.count(i) > 1: - return False - return True - - def __str__(self): - str = "{{op type: {}, op id: {}".format(self.get_owner_op().desc.type(), - self.get_owner_op().desc.id()) - - if self.is_annotated("process_mesh"): - annotated_str = "annotated" - else: - annotated_str = "non-annotated" - str += ", process_mesh ({}): {}".format(annotated_str, - self.get_process_mesh()) - - for arg_name in self.get_owner_op().desc.input_arg_names(): - dims_mapping = self.get_input_dims_mapping(arg_name) - if self.is_annotated_input_dims_mapping(arg_name): - annotated_str = "annotated" - else: - annotated_str = "non-annotated" - if self.is_parameter(arg_name): - is_parameter_str = "parameter" - else: - is_parameter_str = "non-parameter" - str += ", {}'s dims_mapping (input, {}, {}): {}".format( - arg_name, annotated_str, is_parameter_str, dims_mapping) - - for arg_name in self.get_owner_op().desc.output_arg_names(): - dims_mapping = self.get_output_dims_mapping(arg_name) - if self.is_annotated_output_dims_mapping(arg_name): - annotated_str = "annotated" - else: - annotated_str = "non-annotated" - if self.is_parameter(arg_name): - is_parameter_str = "parameter" - else: - is_parameter_str = "non-parameter" - str += ", {}'s dims_mapping (output, {}, {}): {}".format( - arg_name, annotated_str, is_parameter_str, dims_mapping) - - str += ", pipeline stage: {}".format(self._pipeline_stage) - - str += ", dist_impl idx: {} }}".format(self._impl_idx) - - return str - - def __deepcopy__(self, memo): - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - for k, v in self.__dict__.items(): - # No need to copy the owner op and context - if k == "_owner_op" or k == "_owner_context": - setattr(result, k, v) - else: - setattr(result, k, copy.deepcopy(v, memo)) - return result diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py old mode 100755 new mode 100644 index 0097a38e235..934239c0cd6 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -20,10 +20,13 @@ from paddle.fluid import framework from .utils import compute_compatible_process_mesh from .utils import compute_compatible_dim_mapping from .utils import compute_compatible_dims_mapping -from .utils import print_program_with_distributed_attr -from .context import get_default_distributed_context +from .utils import print_program_with_dist_attr from .operators import find_best_compatible_distributed_operator_impl -from .attribute import OperatorDistributedAttribute, TensorDistributedAttribute +from .dist_context import get_default_distributed_context +from .dist_tensor import DistributedTensor +from .dist_op import DistributedOperator +from .dist_attribute import TensorDistributedAttribute +from .dist_attribute import OperatorDistributedAttribute from paddle.distributed.fleet.meta_optimizers.common import OpRole ELEMENTWISE_LIKE_OP_LIST = ["elementwise_add", "gelu", "dropout", "cast"] @@ -43,36 +46,35 @@ def update_tensor_node_process_mesh(dist_context, tensor_node, fwd=True): process meshes are compatible for now. 
""" changed = False - tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( - tensor_node) + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(tensor_node) if tensor_dist_attr.is_annotated("process_mesh"): return changed - tensor_process_mesh = tensor_dist_attr.get_process_mesh() + tensor_process_mesh = tensor_dist_attr.process_mesh if fwd: inputs_process_meshes = [] for pred_op_node in tensor_node.inputs: if pred_op_node.op() is not None: - op_dist_attr = dist_context.get_op_distributed_attr_for_graph( + op_dist_attr = dist_context.get_op_dist_attr_for_graph( pred_op_node) - op_process_mesh = op_dist_attr.get_process_mesh() + op_process_mesh = op_dist_attr.process_mesh inputs_process_meshes.append(op_process_mesh) compatible_process_mesh = compute_compatible_process_mesh( inputs_process_meshes) if compatible_process_mesh is not None and tensor_process_mesh is None: - tensor_dist_attr.set_process_mesh(compatible_process_mesh) + tensor_dist_attr.process_mesh = compatible_process_mesh changed = True else: outputs_process_meshes = [] for succ_op_node in tensor_node.outputs: if succ_op_node.op() is not None: - op_dist_attr = dist_context.get_op_distributed_attr_for_graph( + op_dist_attr = dist_context.get_op_dist_attr_for_graph( succ_op_node) - op_process_mesh = op_dist_attr.get_process_mesh() + op_process_mesh = op_dist_attr.process_mesh outputs_process_meshes.append(op_process_mesh) compatible_process_mesh = compute_compatible_process_mesh( outputs_process_meshes) if compatible_process_mesh is not None and tensor_process_mesh is None: - tensor_dist_attr.set_process_mesh(compatible_process_mesh) + tensor_dist_attr.process_mesh = compatible_process_mesh changed = True return changed @@ -84,43 +86,47 @@ def update_op_node_process_mesh(dist_context, op_node, fwd=True): process meshes are compatible for now. 
""" changed = False - op_dist_attr = dist_context.get_op_distributed_attr_for_graph(op_node) + op_dist_attr = dist_context.get_op_dist_attr_for_graph(op_node) if op_dist_attr.is_annotated("process_mesh"): return changed - op_process_mesh = op_dist_attr.get_process_mesh() + op_process_mesh = op_dist_attr.process_mesh if fwd: inputs_process_meshes = [] for tensor_node in op_node.inputs: if tensor_node.var() is not None: - tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( tensor_node) - tensor_process_mesh = tensor_dist_attr.get_process_mesh() + tensor_process_mesh = tensor_dist_attr.process_mesh inputs_process_meshes.append(tensor_process_mesh) compatible_process_mesh = compute_compatible_process_mesh( inputs_process_meshes) if compatible_process_mesh is not None and op_process_mesh is None: - op_dist_attr.set_process_mesh(compatible_process_mesh) + op_dist_attr.process_mesh = compatible_process_mesh changed = True else: outputs_process_meshes = [] for tensor_node in op_node.outputs: if tensor_node.var() is not None: - tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( tensor_node) - tensor_process_mesh = tensor_dist_attr.get_process_mesh() + tensor_process_mesh = tensor_dist_attr.process_mesh outputs_process_meshes.append(tensor_process_mesh) compatible_process_mesh = compute_compatible_process_mesh( outputs_process_meshes) if compatible_process_mesh is not None and op_process_mesh is None: - op_dist_attr.set_process_mesh(compatible_process_mesh) + op_dist_attr.process_mesh = compatible_process_mesh changed = True return changed -def update_op_dims_mapping_by_default_dist_impl(op_dist_attr): +def update_op_dims_mapping_by_default_dist_impl(dist_context, op_node): """Each operator has a default distributed operator, only allowed to be sharded in batch dimension.""" changed = False - op_desc = op_dist_attr.get_owner_op().desc + if (not op_node.is_op()) or (op_node.op() is None): + return False + op_desc = op_node.op() + dist_op = dist_context.get_dist_op_for_graph(op_node) + op_dist_attr = dist_op.dist_attr # The following statement will be replaced by a more elegent way if op_desc.type() == "shape" or op_desc.type() == "slice": return False @@ -130,7 +136,8 @@ def update_op_dims_mapping_by_default_dist_impl(op_dist_attr): xshape_arg_names = op_desc.output("XShape") batch_dim_mappings = [] for arg_name in op_desc.input_arg_names(): - if op_dist_attr.is_parameter(arg_name): + serial_tensor = dist_op.get_serial_input(arg_name) + if serial_tensor.is_parameter: continue dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) if len(dims_mapping) > 1: @@ -140,7 +147,8 @@ def update_op_dims_mapping_by_default_dist_impl(op_dist_attr): .format(op_desc.type(), idx, mapping) batch_dim_mappings.append(dims_mapping[0]) for arg_name in op_desc.output_arg_names(): - if op_dist_attr.is_parameter(arg_name): + serial_tensor = dist_op.get_serial_output(arg_name) + if serial_tensor.is_parameter: continue dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) if arg_name not in xshape_arg_names: @@ -164,14 +172,16 @@ def update_op_dims_mapping_by_default_dist_impl(op_dist_attr): compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings) assert compatible_dim_mapping is not None, "There is no compatible dim mapping." 
for arg_name in op_desc.input_arg_names(): - if op_dist_attr.is_parameter(arg_name): + serial_tensor = dist_op.get_serial_input(arg_name) + if serial_tensor.is_parameter: continue dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) if compatible_dim_mapping != dims_mapping[0]: dims_mapping[0] = compatible_dim_mapping changed = True for arg_name in op_desc.output_arg_names(): - if op_dist_attr.is_parameter(arg_name): + serial_tensor = dist_op.get_serial_output(arg_name) + if serial_tensor.is_parameter: continue dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) if arg_name not in xshape_arg_names: @@ -186,10 +196,13 @@ def update_op_dims_mapping_by_default_dist_impl(op_dist_attr): return changed -def update_op_dims_mapping_by_elementwise_like_dist_impl(op_dist_attr): +def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_context, op_node): """Element-wise operator can be sharded in any way (but should take care of broadcasting).""" changed = False - op_desc = op_dist_attr.get_owner_op().desc + if (not op_node.is_op()) or (op_node.op() is None): + return False + op_desc = op_node.op() + op_dist_attr = dist_context.get_op_dist_attr_for_graph(op_node) input_arg_names = op_desc.input_arg_names() input_dims_mapping_dict = {} @@ -258,12 +271,11 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): # Skip reader tensor if tensor_desc.type() == core.VarDesc.VarType.READER: return False - tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( - tensor_node) + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(tensor_node) assert tensor_dist_attr is not None if tensor_dist_attr.is_annotated("dims_mapping"): return False - tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() + tensor_dims_mapping = tensor_dist_attr.dims_mapping if fwd: dims_mapping_list = [] for pred_op_node in tensor_node.inputs: @@ -272,7 +284,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): or pred_op_node.op().type() == "create_double_buffer_reader" \ or pred_op_node.op().type() == "read": continue - op_dist_attr = dist_context.get_op_distributed_attr_for_graph( + op_dist_attr = dist_context.get_op_dist_attr_for_graph( pred_op_node) op_dims_mapping = op_dist_attr.get_output_dims_mapping( tensor_desc.name()) @@ -282,7 +294,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): dims_mapping_list) if (compatible_dims_mapping is not None) and \ (compatible_dims_mapping != tensor_dims_mapping): - tensor_dist_attr.set_dims_mapping(compatible_dims_mapping) + tensor_dist_attr.dims_mapping = compatible_dims_mapping changed = True else: dims_mapping_list = [] @@ -292,7 +304,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): or succ_op_node.op().type() == "create_double_buffer_reader" \ or succ_op_node.op().type() == "read": continue - op_dist_attr = dist_context.get_op_distributed_attr_for_graph( + op_dist_attr = dist_context.get_op_dist_attr_for_graph( succ_op_node) op_dims_mapping = op_dist_attr.get_input_dims_mapping( tensor_desc.name()) @@ -302,7 +314,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): dims_mapping_list) if (compatible_dims_mapping is not None) and \ (compatible_dims_mapping != tensor_dims_mapping): - tensor_dist_attr.set_dims_mapping(compatible_dims_mapping) + tensor_dist_attr.dims_mapping = compatible_dims_mapping changed = True return changed @@ -317,7 +329,8 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): 
or op_desc.type() == "create_double_buffer_reader" \ or op_desc.type() == "read": return False - op_dist_attr = dist_context.get_op_distributed_attr_for_graph(op_node) + dist_op = dist_context.get_dist_op_for_graph(op_node) + op_dist_attr = dist_op.dist_attr if fwd: for tensor_node in op_node.inputs: if tensor_node.var() is not None: @@ -327,9 +340,9 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): if op_dist_attr.is_annotated_input_dims_mapping( tensor_desc.name()): continue - tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( tensor_node) - tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() + tensor_dims_mapping = tensor_dist_attr.dims_mapping op_dims_mapping = op_dist_attr.get_input_dims_mapping( tensor_desc.name()) compatible_dims_mapping = compute_compatible_dims_mapping( @@ -341,26 +354,29 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): changed = True # Find the most compatible implemenetations from the distributed operator op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl( - op_desc.type(), op_dist_attr, fwd=True) + op_desc.type(), dist_op, fwd=True) if op_dist_impl is not None: - dim_changed = op_dist_impl.update_dims_mapping(op_dist_attr) + dim_changed = op_dist_impl.update_dims_mapping(dist_op) if dim_changed: changed = True # This statement will be replaced by a good way - if op_dist_impl.is_compatible(op_dist_attr): - op_dist_attr.set_impl_idx(op_dist_impl_idx) + if op_dist_impl.is_compatible(dist_op): + op_dist_attr.impl_type = op_desc.type() + op_dist_attr.impl_idx = op_dist_impl_idx elif is_elementwise_like_op(op_desc.type()): dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl( - op_dist_attr) + dist_context, op_node) if dim_changed: changed = True - op_dist_attr.set_impl_idx(-1) + op_dist_attr.impl_type = "element-wise" + op_dist_attr.impl_idx = -1 else: dim_changed = update_op_dims_mapping_by_default_dist_impl( - op_dist_attr) + dist_context, op_node) if dim_changed: changed = True - op_dist_attr.set_impl_idx(-2) + op_dist_attr.impl_type = "default" + op_dist_attr.impl_idx = -2 else: for tensor_node in op_node.outputs: if tensor_node.var() is not None: @@ -370,9 +386,9 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): if op_dist_attr.is_annotated_output_dims_mapping( tensor_desc.name()): continue - tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( tensor_node) - tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() + tensor_dims_mapping = tensor_dist_attr.dims_mapping op_dims_mapping = op_dist_attr.get_output_dims_mapping( tensor_desc.name()) compatible_dims_mapping = compute_compatible_dims_mapping( @@ -384,26 +400,29 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): changed = True # Find the most compatible implemenetations from the distributed operator op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl( - op_desc.type(), op_dist_attr, fwd=False) + op_desc.type(), dist_op, fwd=False) if op_dist_impl is not None: - dim_changed = op_dist_impl.update_dims_mapping(op_dist_attr) + dim_changed = op_dist_impl.update_dims_mapping(dist_op) if dim_changed: changed = True # This statement will be replaced by a good way - if op_dist_impl.is_compatible(op_dist_attr): - op_dist_attr.set_impl_idx(op_dist_impl_idx) + if 
op_dist_impl.is_compatible(dist_op): + op_dist_attr.impl_type = op_desc.type() + op_dist_attr.impl_idx = op_dist_impl_idx elif is_elementwise_like_op(op_desc.type()): dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl( - op_dist_attr) + dist_context, op_node) if dim_changed: changed = True - op_dist_attr.set_impl_idx(-1) + op_dist_attr.impl_type = "element-wise" + op_dist_attr.impl_idx = -1 else: dim_changed = update_op_dims_mapping_by_default_dist_impl( - op_dist_attr) + dist_context, op_node) if dim_changed: changed = True - op_dist_attr.set_impl_idx(-2) + op_dist_attr.impl_type = "default" + op_dist_attr.impl_idx = -2 return changed @@ -421,18 +440,20 @@ def complete_annotation(program, dist_context=None): # Use the default distribted context for completeion if there is no one if dist_context is None: dist_context = get_default_distributed_context() + dist_context.serial_program = program + else: + dist_context.serial_program = program - # Initialize distributed attributes for all var and op node in program - dist_context.initialize_distributed_attr_for_program(program) + # print_program_with_dist_attr(program, dist_context) - # Convert program to graph - graph = framework.IrGraph(core.Graph(program.desc)) + # Initialize distributed attributes for all var and op node in program + dist_context.init_dist_attr_for_program() # Initialize distributed attributes for all var and op node in graph - dist_context.initialize_distributed_attr_for_graph(graph) + dist_context.init_dist_attr_for_graph() # Complete process mesh for each node - all_nodes = list(graph.all_nodes()) + all_nodes = list(dist_context.serial_graph.all_nodes()) def sort_key_fun(node): first = -1 @@ -498,27 +519,27 @@ def complete_annotation(program, dist_context=None): is_wrong = False for node in all_nodes: if node.is_var() and node.var() is not None: - tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( node) - if tensor_dist_attr.get_process_mesh() is None: + if tensor_dist_attr.process_mesh is None: msg_str = "" for op_node in node.inputs: if op_node.op() is not None: - op_dist_attr = dist_context.get_op_distributed_attr_for_graph( + op_dist_attr = dist_context.get_op_dist_attr_for_graph( op_node) msg_str += "{} [{}], ".format( op_node.op().type(), - op_dist_attr.get_process_mesh()) + op_dist_attr.process_mesh) else: msg_str += "{} [{}], ".format(op_node.name(), None) for op_node in node.outputs: if op_node.op() is not None: - op_dist_attr = dist_context.get_op_distributed_attr_for_graph( + op_dist_attr = dist_context.get_op_dist_attr_for_graph( op_node) msg_str += "{} [{}], ".format( op_node.op().type(), - op_dist_attr.get_process_mesh()) + op_dist_attr.process_mesh) else: msg_str += "{} [{}], ".format(op_node.name(), None) @@ -527,27 +548,26 @@ def complete_annotation(program, dist_context=None): is_wrong = True print(msg_str) if node.is_op() and node.op() is not None: - op_dist_attr = dist_context.get_op_distributed_attr_for_graph( - node) - if op_dist_attr.get_process_mesh() is None: + op_dist_attr = dist_context.get_op_dist_attr_for_graph(node) + if op_dist_attr.process_mesh is None: msg_str = "" for tensor_node in node.inputs: if tensor_node.var() is not None: - tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( tensor_node) msg_str += "{} [{}], ".format( tensor_node.var().name(), - tensor_dist_attr.get_process_mesh()) + 
tensor_dist_attr.process_mesh) else: msg_str += "{} [{}], ".format( tensor_node.name(), None) for tensor_node in node.outputs: if tensor_node.var() is not None: - tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( tensor_node) msg_str += "{} [{}], ".format( tensor_node.var().name(), - tensor_dist_attr.get_process_mesh()) + tensor_dist_attr.process_mesh) else: msg_str += "{} [{}], ".format( tensor_node.name(), None) @@ -592,11 +612,14 @@ def complete_annotation(program, dist_context=None): reach_fix_point = True # Copy the corresponding distributed attribute from graph to program - dist_context.copy_distribute_attr_from_graph_to_program(graph, program) - dist_context.clear_distributed_attr_for_graph() + dist_context.copy_dist_attr_from_graph_to_program() + dist_context.clear_dist_info_for_graph() # Do the validation check and amend some completion - dist_context.amend_distributed_attr_for_program() + dist_context.amend_dist_attr_for_program() + + # print_program_with_dist_attr(program, dist_context) + dist_context.validate_dist_attr_for_program() return program @@ -636,7 +659,7 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): ops = list(auto_parallel_main_prog.global_block().ops) vars = auto_parallel_main_prog.global_block().vars - dist_op_helper = dist_context.get_dist_op_helper() + dist_op_context = dist_context.dist_op_context for idx in range(first_backward_op_idx, len(ops)): @@ -658,45 +681,42 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): forward_var = vars[forward_var_name] # TODO complete other attribte for grad var - tensor_attr = TensorDistributedAttribute(grad_var, dist_context) - process_mesh = dist_context.get_tensor_distributed_attr_for_program( - forward_var).get_process_mesh() - dims_mapping = dist_context.get_tensor_distributed_attr_for_program( - forward_var).get_dims_mapping() - tensor_attr.set_dims_mapping(dims_mapping) - tensor_attr.set_process_mesh(process_mesh) - dist_context.set_tensor_distributed_attr_for_program(grad_var, - tensor_attr) - - op_attr = OperatorDistributedAttribute(ops[idx], dist_context) - op_attr.set_process_mesh(process_mesh) - op_attr.set_output_dims_mapping(grad_var.name, dims_mapping) - dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr) + tensor_dist_attr = TensorDistributedAttribute() + process_mesh = dist_context.get_tensor_dist_attr_for_program( + forward_var).process_mesh + dims_mapping = dist_context.get_tensor_dist_attr_for_program( + forward_var).dims_mapping + tensor_dist_attr.dims_mapping = dims_mapping + tensor_dist_attr.process_mesh = process_mesh + dist_context.set_tensor_dist_attr_for_program(grad_var, + tensor_dist_attr) + + op_dist_attr = OperatorDistributedAttribute() + op_dist_attr.process_mesh = process_mesh + op_dist_attr.set_output_dims_mapping(grad_var.name, dims_mapping) + dist_context.set_op_dist_attr_for_program(ops[idx], op_dist_attr) continue # complete the annotation of grad op (xxx_grad op or sum op) # xxx_grad op will have a corresponding forward op in gradopidx2opidx grad_op = ops[idx] - if grad_op.desc.id() in dist_op_helper.gradopidx2opidx: + if grad_op.desc.id() in dist_op_context.gradopidx2opidx: # TODO support the case where one forward op corresponding to multiple xxx_grad op forward_op = _get_op_by_id( ops[:first_backward_op_idx], - dist_op_helper.gradopidx2opidx[grad_op.desc.id()]) + dist_op_context.gradopidx2opidx[grad_op.desc.id()]) assert 
forward_op is not None # op dist attr - forward_op_attr = dist_context.get_op_distributed_attr_for_program( + forward_op_dist_attr = dist_context.get_op_dist_attr_for_program( forward_op) - forward_op_process_mesh = forward_op_attr.get_process_mesh() - grad_op_attr = OperatorDistributedAttribute(grad_op, dist_context) - grad_op_attr.set_process_mesh(forward_op_process_mesh) + forward_op_process_mesh = forward_op_dist_attr.process_mesh + grad_op_dist_attr = OperatorDistributedAttribute() + grad_op_dist_attr.process_mesh = forward_op_process_mesh # var for output_name in grad_op.desc.output_names(): assert len(grad_op.desc.output(output_name)) in [0, 1] - # if grad_op.type == "cast": - # input_name = "X" - # else: if _is_grad_var_name(output_name): input_name = _get_forward_varname_from_grad_varname( output_name) @@ -711,39 +731,38 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): if len(grad_op.desc.output(output_name)) == 1: assert len(forward_op.desc.input(input_name)) == 1 input_var = vars[forward_op.desc.input(input_name)[0]] - input_var_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + input_var_dist_attr = dist_context.get_tensor_dist_attr_for_program( input_var) assert input_var_dist_attr is not None, "[{}] has not dist attribute".format( input_var.name) - ref_dims_mapping = input_var_dist_attr.get_dims_mapping() + ref_dims_mapping = input_var_dist_attr.dims_mapping # tensor dist attr output_var = vars[grad_op.desc.output(output_name)[0]] - output_var_attr = TensorDistributedAttribute(output_var, - dist_context) - output_var_attr.set_dims_mapping(ref_dims_mapping) - output_var_attr.set_process_mesh(forward_op_process_mesh) - dist_context.set_tensor_distributed_attr_for_program( - output_var, output_var_attr) + output_var_dist_attr = TensorDistributedAttribute() + output_var_dist_attr.dims_mapping = ref_dims_mapping + output_var_dist_attr.process_mesh = forward_op_process_mesh + dist_context.set_tensor_dist_attr_for_program( + output_var, output_var_dist_attr) # op dist attr - grad_op_attr.set_output_dims_mapping(output_var.name, - ref_dims_mapping) + grad_op_dist_attr.set_output_dims_mapping(output_var.name, + ref_dims_mapping) for input_name in grad_op.input_arg_names: input_var = vars[input_name] - input_var_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + input_var_dist_attr = dist_context.get_tensor_dist_attr_for_program( input_var) assert input_var_dist_attr is not None, "[{}] has not dist attribute".format( input_var.name) - ref_dims_mapping = input_var_dist_attr.get_dims_mapping() + ref_dims_mapping = input_var_dist_attr.dims_mapping assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format( input_var.name) - grad_op_attr.set_input_dims_mapping(input_name, - ref_dims_mapping) + grad_op_dist_attr.set_input_dims_mapping(input_name, + ref_dims_mapping) - dist_context.set_op_distributed_attr_for_program(grad_op, - grad_op_attr) + dist_context.set_op_dist_attr_for_program(grad_op, + grad_op_dist_attr) # only sum op for merge mutiple version grad has no a corresponding mapping in gradopidx2opidx else: @@ -755,32 +774,31 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): ref_forward_var_name = _get_forward_varname_from_grad_varname( grad_op.output_arg_names[0]) forward_var = vars[ref_forward_var_name] - ref_forward_var_dims_mapping = dist_context.get_tensor_distributed_attr_for_program( - forward_var).get_dims_mapping() - ref_forward_var_process_mesh = 
dist_context.get_tensor_distributed_attr_for_program( - forward_var).get_process_mesh() + ref_forward_var_dims_mapping = dist_context.get_tensor_dist_attr_for_program( + forward_var).dims_mapping + ref_forward_var_process_mesh = dist_context.get_tensor_dist_attr_for_program( + forward_var).process_mesh # output - tensor_attr = TensorDistributedAttribute( - vars[grad_op.output_arg_names[0]], dist_context) - tensor_attr.set_dims_mapping(ref_forward_var_dims_mapping) - tensor_attr.set_process_mesh(ref_forward_var_process_mesh) - dist_context.set_tensor_distributed_attr_for_program( - vars[grad_op.output_arg_names[0]], tensor_attr) + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.dims_mapping = ref_forward_var_dims_mapping + tensor_dist_attr.process_mesh = ref_forward_var_process_mesh + dist_context.set_tensor_dist_attr_for_program( + vars[grad_op.output_arg_names[0]], tensor_dist_attr) # op - grad_op_attr = OperatorDistributedAttribute(grad_op, dist_context) - grad_op_attr.set_process_mesh(ref_forward_var_process_mesh) + grad_op_dist_attr = OperatorDistributedAttribute() + grad_op_dist_attr.process_mesh = ref_forward_var_process_mesh for var_name in grad_op.input_arg_names: assert _get_forward_varname_from_grad_varname( var_name) == ref_forward_var_name - grad_op_attr.set_input_dims_mapping( + grad_op_dist_attr.set_input_dims_mapping( var_name, ref_forward_var_dims_mapping) - grad_op_attr.set_output_dims_mapping(grad_op.output_arg_names[0], - ref_forward_var_dims_mapping) - dist_context.set_op_distributed_attr_for_program(grad_op, - grad_op_attr) + grad_op_dist_attr.set_output_dims_mapping( + grad_op.output_arg_names[0], ref_forward_var_dims_mapping) + dist_context.set_op_dist_attr_for_program(grad_op, + grad_op_dist_attr) def complete_update_annotation(auto_parallel_main_prog, dist_context): @@ -808,39 +826,40 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context): param = vars[op.input("Param")[0]] grad_var = vars[op.input("Grad")[0]] - param_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + param_dist_attr = dist_context.get_tensor_dist_attr_for_program( param) - grad_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + grad_dist_attr = dist_context.get_tensor_dist_attr_for_program( grad_var) assert param_dist_attr is not None assert grad_dist_attr is not None - assert param_dist_attr.get_dims_mapping( - ) == grad_dist_attr.get_dims_mapping() + assert param_dist_attr.dims_mapping == grad_dist_attr.dims_mapping - ref_process_mesh = dist_context.get_tensor_distributed_attr_for_program( - param).get_process_mesh() + ref_process_mesh = dist_context.get_tensor_dist_attr_for_program( + param).process_mesh assert ref_process_mesh is not None - ref_dims_mapping = dist_context.get_tensor_distributed_attr_for_program( - param).get_dims_mapping() + ref_dims_mapping = dist_context.get_tensor_dist_attr_for_program( + param).dims_mapping assert ref_dims_mapping is not None - op_attr = OperatorDistributedAttribute(op, dist_context) - op_attr.set_process_mesh(ref_process_mesh) - op_attr.set_input_dims_mapping(grad_var.name, ref_dims_mapping) - op_attr.set_input_dims_mapping(param.name, ref_dims_mapping) - op_attr.set_output_dims_mapping(param.name, ref_dims_mapping) + op_dist_attr = OperatorDistributedAttribute() + op_dist_attr.process_mesh = ref_process_mesh + op_dist_attr.set_input_dims_mapping(grad_var.name, + ref_dims_mapping) + op_dist_attr.set_input_dims_mapping(param.name, + ref_dims_mapping) + 
op_dist_attr.set_output_dims_mapping(param.name, + ref_dims_mapping) learning_var = vars[op.input("LearningRate")[0]] - op_attr.set_input_dims_mapping(learning_var.name, [-1]) - op_attr.set_output_dims_mapping(learning_var.name, [-1]) + op_dist_attr.set_input_dims_mapping(learning_var.name, [-1]) + op_dist_attr.set_output_dims_mapping(learning_var.name, [-1]) if not learning_rate_completed: learning_rate_completed = True - var_dist_attr = TensorDistributedAttribute(learning_var, - dist_context) - var_dist_attr.set_process_mesh(ref_process_mesh) - var_dist_attr.set_dims_mapping([-1]) - dist_context.set_tensor_distributed_attr_for_program( - learning_var, var_dist_attr) + var_dist_attr = TensorDistributedAttribute() + var_dist_attr.process_mesh = ref_process_mesh + var_dist_attr.dims_mapping = [-1] + dist_context.set_tensor_dist_attr_for_program(learning_var, + var_dist_attr) for input_name in op.desc.input_names(): @@ -853,24 +872,25 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context): assert len(op.desc.input(input_name)) == 1 input_var = vars[op.desc.input(input_name)[0]] - input_var_attr = TensorDistributedAttribute(input_var, - dist_context) + input_var_attr = TensorDistributedAttribute() if "Beta1Pow" in input_name or "Beta2Pow" in input_name: - input_var_attr.set_dims_mapping([-1]) - op_attr.set_input_dims_mapping(input_var.name, [-1]) - op_attr.set_output_dims_mapping(input_var.name, [-1]) + input_var_attr.dims_mapping = [-1] + op_dist_attr.set_input_dims_mapping(input_var.name, + [-1]) + op_dist_attr.set_output_dims_mapping(input_var.name, + [-1]) else: assert "Moment" in input_name - input_var_attr.set_dims_mapping(ref_dims_mapping) - op_attr.set_input_dims_mapping(input_var.name, - ref_dims_mapping) - op_attr.set_output_dims_mapping(input_var.name, - ref_dims_mapping) - - input_var_attr.set_process_mesh(ref_process_mesh) - dist_context.set_tensor_distributed_attr_for_program( + input_var_attr.dims_mapping = ref_dims_mapping + op_dist_attr.set_input_dims_mapping(input_var.name, + ref_dims_mapping) + op_dist_attr.set_output_dims_mapping(input_var.name, + ref_dims_mapping) + + input_var_attr.process_mesh = ref_process_mesh + dist_context.set_tensor_dist_attr_for_program( input_var, input_var_attr) - dist_context.set_op_distributed_attr_for_program(op, op_attr) + dist_context.set_op_dist_attr_for_program(op, op_dist_attr) continue diff --git a/python/paddle/distributed/auto_parallel/context.py b/python/paddle/distributed/auto_parallel/context.py deleted file mode 100644 index 6785f21351a..00000000000 --- a/python/paddle/distributed/auto_parallel/context.py +++ /dev/null @@ -1,495 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License - -import copy -from collections import defaultdict -from paddle.fluid import framework -from paddle.fluid import core -from .attribute import TensorDistributedAttribute -from .attribute import OperatorDistributedAttribute -from .utils import append_distributed_attr_suffix -from .interface import _g_process_mesh_map - -# There always exists a default context for user. And user can set it to another one. -DEFAULT_DISTRIBUTED_CONTEXT = None - - -def get_default_distributed_context(): - global DEFAULT_DISTRIBUTED_CONTEXT - if DEFAULT_DISTRIBUTED_CONTEXT is None: - dist_context = DistributedContext() - set_default_distributed_context(dist_context) - return DEFAULT_DISTRIBUTED_CONTEXT - - -def set_default_distributed_context(dist_context): - global DEFAULT_DISTRIBUTED_CONTEXT - DEFAULT_DISTRIBUTED_CONTEXT = dist_context - - -class DistributedContext: - """ - DistributedContext is used to collect related distributed information for program and graph. - One auto-parallel run should use its own DistributedContext to avoid interfering other run. - """ - - def __init__(self): - self._is_initialized_for_program = False - self._is_initialized_for_graph = False - self._tensor_distributed_attr_map_for_program = {} - self._op_distributed_attr_map_for_program = {} - self._tensor_distributed_attr_map_for_graph = {} - self._op_distributed_attr_map_for_graph = {} - self._get_dist_op_helper = DistOpHelper() - self._process_mesh = _g_process_mesh_map.get(0, None) - - def is_initialized_for_program(self): - return self._is_initialized_for_program - - def is_initialized_for_graph(self): - return self._is_initialized_for_graph - - def get_tensor_distributed_attr_for_program(self, tensor): - tensor_id = tensor.desc.id() - tensor_dist_attr = self._tensor_distributed_attr_map_for_program.get( - tensor_id, None) - return tensor_dist_attr - - def set_tensor_distributed_attr_for_program(self, tensor, tensor_dist_attr): - tensor_id = tensor.desc.id() - self._tensor_distributed_attr_map_for_program[ - tensor_id] = tensor_dist_attr - - def get_op_distributed_attr_for_program(self, op): - op_id = op.desc.id() - op_dist_attr = self._op_distributed_attr_map_for_program.get(op_id, - None) - return op_dist_attr - - def set_op_distributed_attr_for_program(self, op, op_dist_attr): - op_id = op.desc.id() - self._op_distributed_attr_map_for_program[op_id] = op_dist_attr - - def get_tensor_distributed_attr_for_graph(self, tensor_node): - tensor_node_id = tensor_node.id() - tensor_dist_attr = self._tensor_distributed_attr_map_for_graph.get( - tensor_node_id, None) - return tensor_dist_attr - - def set_tensor_distributed_attr_for_graph(self, tensor_node, - tensor_dist_attr): - tensor_node_id = tensor_node.id() - self._tensor_distributed_attr_map_for_graph[ - tensor_node_id] = tensor_dist_attr - - def get_op_distributed_attr_for_graph(self, op_node): - op_node_id = op_node.id() - op_dist_attr = self._op_distributed_attr_map_for_graph.get(op_node_id, - None) - return op_dist_attr - - def set_op_distributed_attr_for_graph(self, op_node, op_dist_attr): - op_node_id = op_node.id() - self._op_distributed_attr_map_for_graph[op_node_id] = op_dist_attr - - def set_process_mesh(self, process_mesh): - self._process_mesh = process_mesh - - def get_dist_op_helper(self): - return self._get_dist_op_helper - - def initialize_distributed_attr_for_program(self, program): - if self._is_initialized_for_program: - return - for block in program.blocks: 
- for tensor in block.vars.values(): - # Since only tensors have distributed attributes, it's better to make sure var is a tensor - tensor_dist_attr = self.get_tensor_distributed_attr_for_program( - tensor) - if tensor_dist_attr is None: - tensor_dist_attr = TensorDistributedAttribute(tensor, self) - self._copy_distributed_attr_from_tensor_desc( - tensor.desc, tensor_dist_attr) - self.set_tensor_distributed_attr_for_program( - tensor, tensor_dist_attr) - if tensor.type == core.VarDesc.VarType.READER: - tensor_dist_attr.set_shape([]) - else: - tensor_dist_attr.set_shape(tensor.desc.shape()) - if tensor_dist_attr.get_process_mesh() is not None: - tensor_dist_attr.mark_as_annotated("process_mesh") - if tensor_dist_attr.get_dims_mapping() is None: - tensor_dims_mapping = [ - -1 for _ in range(len(tensor_dist_attr.get_shape())) - ] - tensor_dist_attr.set_dims_mapping(tensor_dims_mapping) - else: - tensor_dist_attr.mark_as_annotated("dims_mapping") - if isinstance(tensor, framework.Parameter): - tensor_dist_attr.mark_as_parameter() - for op in block.ops: - op_dist_attr = self.get_op_distributed_attr_for_program(op) - if op_dist_attr is None: - op_dist_attr = OperatorDistributedAttribute(op, self) - self._copy_distributed_attr_from_op_desc(op.desc, - op_dist_attr) - self.set_op_distributed_attr_for_program(op, op_dist_attr) - # Default distributed implementation for all operators - # This will be updated during the completion prcess - op_dist_attr.set_impl_idx(-2) - if op_dist_attr.get_process_mesh() is not None: - op_dist_attr.mark_as_annotated("process_mesh") - for tensor_name in op.input_arg_names: - # There may be a better way to find the tensor by name - if op.type == "create_py_reader" \ - or tensor.type == core.VarDesc.VarType.READER: - op_dist_attr.set_input_shape(tensor_name, []) - else: - tensor = op.block._var_recursive(tensor_name) - op_dist_attr.set_input_shape(tensor_name, - tensor.desc.shape()) - if op_dist_attr.get_input_dims_mapping(tensor_name) is None: - tensor_dims_mapping = [ - -1 - for _ in range( - len(op_dist_attr.get_input_shape(tensor_name))) - ] - op_dist_attr.set_input_dims_mapping(tensor_name, - tensor_dims_mapping) - else: - op_dist_attr.mark_as_annotated_input_dims_mapping( - tensor_name) - if isinstance(tensor, framework.Parameter): - op_dist_attr.mark_as_parameter(tensor_name) - for tensor_name in op.output_arg_names: - tensor = op.block._var_recursive(tensor_name) - if tensor.type == core.VarDesc.VarType.READER: - op_dist_attr.set_output_shape(tensor_name, []) - else: - op_dist_attr.set_output_shape(tensor_name, - tensor.desc.shape()) - if op_dist_attr.get_output_dims_mapping( - tensor_name) is None: - tensor_dims_mapping = [ - -1 - for _ in range( - len( - op_dist_attr.get_output_shape(tensor_name))) - ] - op_dist_attr.set_output_dims_mapping( - tensor_name, tensor_dims_mapping) - else: - op_dist_attr.mark_as_annotated_output_dims_mapping( - tensor_name) - if isinstance(tensor, framework.Parameter): - op_dist_attr.mark_as_parameter(tensor_name) - self._is_initialized_for_program = True - - def finalize_distributed_attr_for_program(self, program): - assert self._is_initialized_for_program, \ - "The program must initialize its distributed attribute before finalization." 
- for block in program.blocks: - for tensor in block.vars.values(): - tensor_dist_attr = self.get_tensor_distributed_attr_for_program( - tensor) - if tensor_dist_attr is not None: - self._store_distributed_attr_to_tensor_desc( - tensor.desc, tensor_dist_attr) - for op in block.ops: - op_dist_attr = self.get_op_distributed_attr_for_program(op) - if op_dist_attr is not None: - self._store_distributed_attr_to_op_desc(op.desc, - op_dist_attr) - - def _copy_distributed_attr_from_tensor_desc(self, desc, dist_attr): - from paddle.distributed.auto_parallel.interface import _g_process_mesh_map - attr_name = append_distributed_attr_suffix("mesh_id") - if desc.has_attr(attr_name): - mesh_id = desc.attr(attr_name) - process_mesh = _g_process_mesh_map[mesh_id] - copied_process_mesh = copy.deepcopy(process_mesh) - dist_attr.set_process_mesh(copied_process_mesh) - attr_name = append_distributed_attr_suffix("dim_mapping") - if desc.has_attr(attr_name): - dims_mapping = desc.attr(attr_name) - copied_dims_mapping = copy.deepcopy(dims_mapping) - dist_attr.set_dims_mapping(copied_dims_mapping) - attr_name = append_distributed_attr_suffix("mask") - if desc.has_attr(attr_name): - shard_mask = desc.attr(attr_name) - copied_shard_mask = copy.deepcopy(shard_mask) - dist_attr.set_shard_mask(copied_shard_mask) - attr_name = append_distributed_attr_suffix("offload_device") - if desc.has_attr(attr_name): - offload_device = desc.attr(attr_name) - copied_offload_device = copy.deepcopy(offload_device) - dist_attr.set_offload_device(copied_offload_device) - - def _copy_distributed_attr_from_op_desc(self, desc, dist_attr): - from paddle.distributed.auto_parallel.interface import _g_process_mesh_map - attr_name = append_distributed_attr_suffix("mesh_id") - if desc.has_attr(attr_name): - mesh_id = desc.attr(attr_name) - process_mesh = _g_process_mesh_map[mesh_id] - copied_process_mesh = copy.deepcopy(process_mesh) - dist_attr.set_process_mesh(copied_process_mesh) - for tensor_name in desc.input_arg_names(): - attr_name = append_distributed_attr_suffix("IN_" + tensor_name) - if desc.has_attr(attr_name): - dims_mapping = desc.attr(attr_name) - copied_dims_mapping = copy.deepcopy(dims_mapping) - dist_attr.set_input_dims_mapping(tensor_name, - copied_dims_mapping) - for tensor_name in desc.output_arg_names(): - attr_name = append_distributed_attr_suffix("OUT_" + tensor_name) - if desc.has_attr(attr_name): - dims_mapping = desc.attr(attr_name) - copied_dims_mapping = copy.deepcopy(dims_mapping) - dist_attr.set_input_dims_mapping(tensor_name, - copied_dims_mapping) - attr_name = append_distributed_attr_suffix("pipeline_stage") - if desc.has_attr(attr_name): - pipeline_stage = desc.attr(attr_name) - copied_pipeline_stage = copy.deepcopy(pipeline_stage) - dist_attr.set_pipeline_stage(copied_pipeline_stage) - - def _store_distributed_attr_to_tensor_desc(self, desc, dist_attr): - process_mesh = dist_attr.get_process_mesh() - if process_mesh is not None: - attr_name = append_distributed_attr_suffix("mesh_id") - desc._set_attr(attr_name, process_mesh._id) - dims_mapping = dist_attr.get_dims_mapping() - if dims_mapping is not None: - attr_name = append_distributed_attr_suffix("dim_mapping") - desc._set_attr(attr_name, dims_mapping) - shard_mask = dist_attr.get_shard_mask() - if shard_mask is not None: - attr_name = append_distributed_attr_suffix("mask") - desc._set_attr(attr_name, shard_mask) - offload_device = dist_attr.get_offload_device() - if offload_device is not None: - attr_name = append_distributed_attr_suffix("offload_device") - 
desc._set_attr(attr_name, offload_device) - - def _store_distributed_attr_to_op_desc(self, desc, dist_attr): - process_mesh = dist_attr.get_process_mesh() - if process_mesh is not None: - attr_name = append_distributed_attr_suffix("mesh_id") - desc._set_attr(attr_name, process_mesh._id) - for tensor_name in desc.input_arg_names(): - dims_mapping = dist_attr.get_input_dims_mapping(tensor_name) - if dims_mapping is not None: - attr_name = append_distributed_attr_suffix("IN_" + tensor_name) - desc._set_attr(attr_name, dims_mapping) - for tensor_name in desc.output_arg_names(): - dims_mapping = dist_attr.get_output_dims_mapping(tensor_name) - if dims_mapping is not None: - attr_name = append_distributed_attr_suffix("OUT_" + tensor_name) - desc._set_attr(attr_name, dims_mapping) - pipeline_stage = dist_attr.get_pipeline_stage() - if pipeline_stage is not None: - attr_name = append_distributed_attr_suffix("pipeline_stage") - desc._set_attr(attr_name, pipeline_stage) - - def initialize_distributed_attr_for_graph(self, graph): - assert self._is_initialized_for_program, \ - "The program must initialize its distributed attribute before its graph." - if self._is_initialized_for_graph: - return - all_nodes = graph.all_nodes() - for node in all_nodes: - if node.is_var() and node.var() is not None: - tensor_desc = node.var() - tensor_id = tensor_desc.id() - tensor_dist_attr = self._tensor_distributed_attr_map_for_program[ - tensor_id] - assert tensor_dist_attr is not None, \ - "Tensor must have a distributed attribute after the initialization for program." - new_tensor_dist_attr = copy.deepcopy(tensor_dist_attr) - self.set_tensor_distributed_attr_for_graph(node, - new_tensor_dist_attr) - - if node.is_op() and node.op() is not None: - op_desc = node.op() - op_id = op_desc.id() - op_dist_attr = self._op_distributed_attr_map_for_program[op_id] - assert op_dist_attr is not None, \ - "Operator must have a distributed attribute after the initialization for program." 
- new_op_dist_attr = copy.deepcopy(op_dist_attr) - self.set_op_distributed_attr_for_graph(node, new_op_dist_attr) - self._is_initialized_for_graph = True - - def clear_distributed_attr_for_program(self): - self._tensor_distributed_attr_map_for_program.clear() - self._op_distributed_attr_map_for_program.clear() - - def clear_distributed_attr_for_graph(self): - self._tensor_distributed_attr_map_for_graph.clear() - self._op_distributed_attr_map_for_graph.clear() - - def copy_distribute_attr_from_graph_to_program(self, graph, program): - assert self._is_initialized_for_program and self._is_initialized_for_graph, \ - "The distribute attributes must be initialized both in its program and graph" - updated_tensors = {} - all_nodes = graph.all_nodes() - for node in all_nodes: - if node.is_var() and node.var() is not None: - tensor_desc = node.var() - tensor_id = tensor_desc.id() - updated = updated_tensors.get(tensor_desc.name(), False) - # If a var has multiples var nodes in graph, only use the first one for now - if not updated: - tensor_dist_attr = self.get_tensor_distributed_attr_for_graph( - node) - new_tensor_dist_attr = copy.deepcopy(tensor_dist_attr) - self._tensor_distributed_attr_map_for_program[ - tensor_id] = new_tensor_dist_attr - updated_tensors[tensor_desc.name()] = True - if node.is_op() and node.op() is not None: - op_desc = node.op() - op_id = op_desc.id() - op_dist_attr = self.get_op_distributed_attr_for_graph(node) - new_op_dist_attr = copy.deepcopy(op_dist_attr) - self._op_distributed_attr_map_for_program[ - op_id] = new_op_dist_attr - - def amend_distributed_attr_for_program(self): - for attr in self._tensor_distributed_attr_map_for_program.values(): - assert attr.is_valid(), \ - "Tensor's distributed attribute {} is not valid".format(attr) - tensor_shape = attr.get_shape() - dims_mapping = attr.get_dims_mapping() - process_mesh_shape = attr.get_process_mesh().topology - # If the dimension of tensor is less than the sharding dimension of process mesh, - # we just amend the dimension mapping to -1. (Is this really OK?) - for i in range(len(tensor_shape)): - if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ - and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: - dims_mapping[i] = -1 - - for attr in self._op_distributed_attr_map_for_program.values(): - assert attr.is_valid(), \ - "Operator's distributed attribute {} is not valid".format(attr) - for arg_name in attr.get_owner_op().desc.input_arg_names(): - tensor_shape = attr.get_input_shape(arg_name) - dims_mapping = attr.get_input_dims_mapping(arg_name) - process_mesh_shape = attr.get_process_mesh().topology - # If the dimension of tensor is less than the sharding dimension of process mesh, - # we just amend the dimension mapping to -1. (Is this really OK?) - for i in range(len(tensor_shape)): - if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ - and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: - dims_mapping[i] = -1 - - for arg_name in attr.get_owner_op().desc.output_arg_names(): - tensor_shape = attr.get_output_shape(arg_name) - dims_mapping = attr.get_output_dims_mapping(arg_name) - process_mesh_shape = attr.get_process_mesh().topology - # If the dimension of tensor is less than the sharding dimension of process mesh, - # we just amend the dimension mapping to -1. (Is this really OK?) 
- for i in range(len(tensor_shape)): - if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ - and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: - dims_mapping[i] = -1 - - -class DistOpHelper: - """ - DistOpHelper is used to create a dist op desc in Program. - Every time to create a new dist op, the context should be updated for it accordingly. - """ - - def __init__(self): - self._dst_main_program = None - self._dst_startup_program = None - self._varname_mapping = None - self._rank_id = None - self._cur_src_op = None - self._cur_dist_attr = None - self.gradopidx2opidx = {} - self.already_init_sync_vars = set() - - def set_dst_main_program(self, prog): - self._dst_main_program = prog - - def get_dst_main_program(self): - return self._dst_main_program - - def set_dst_startup_program(self, prog): - self._dst_startup_program = prog - - def get_dst_startup_program(self): - return self._dst_startup_program - - def set_varname_mapping(self, mapping): - self._varname_mapping = mapping - - def get_varname_mapping(self): - return self._varname_mapping - - def set_rank_id(self, rank_id): - self._rank_id = rank_id - - def get_rank_id(self): - return self._rank_id - - def set_cur_src_op(self, cur_src_op): - self._cur_src_op = cur_src_op - - def get_cur_src_op(self): - return self._cur_src_op - - def prepare_forward_context(self, src_op): - - self.set_cur_src_op(src_op) - - # build input varname mapping - kinputs = {} - for input_name in src_op.desc.input_names(): - varnames = [] - for varname in src_op.desc.input(input_name): - varnames.append(self._varname_mapping[varname]) - kinputs[input_name] = varnames - - # build output varname mapping - koutputs = {} - for output_name in src_op.desc.output_names(): - varnames = [] - for varname in src_op.desc.output(output_name): - varnames.append(self._varname_mapping[varname]) - koutputs[output_name] = varnames - - return kinputs, koutputs - - def prepare_backward_context(self, backward_op): - - self.set_cur_src_op(backward_op) - - # build input varname mapping - kinputs = {} - for input_name in backward_op.desc.input_names(): - varnames = [] - for varname in backward_op.desc.input(input_name): - varnames.append(varname) - kinputs[input_name] = varnames - - # build output varname mapping - koutputs = {} - for output_name in backward_op.desc.output_names(): - varnames = [] - for varname in backward_op.desc.output(output_name): - varnames.append(varname) - koutputs[output_name] = varnames - - return kinputs, koutputs diff --git a/python/paddle/distributed/auto_parallel/cost_model.py b/python/paddle/distributed/auto_parallel/cost_model.py index 3fd438e2a62..b1ff4fb0ba7 100644 --- a/python/paddle/distributed/auto_parallel/cost_model.py +++ b/python/paddle/distributed/auto_parallel/cost_model.py @@ -131,7 +131,7 @@ class TensorCostNode(CostNode): elif node.dtype == paddle.int64: self.dtype_factor *= 8 else: - raise NotImplementedError("{} not counted".format(v.node.dtype)) + raise NotImplementedError("{} not counted".format(node.dtype)) self.batch_size = None if batch_size is not None: diff --git a/python/paddle/distributed/auto_parallel/dist_attribute.py b/python/paddle/distributed/auto_parallel/dist_attribute.py new file mode 100644 index 00000000000..4415448769d --- /dev/null +++ b/python/paddle/distributed/auto_parallel/dist_attribute.py @@ -0,0 +1,436 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import copy +from collections import defaultdict +from paddle.fluid.framework import Variable +from .process_mesh import ProcessMesh + +_g_tensor_dist_attr_field_keys = [ + "process_mesh", "dims_mapping", "shard_sizes", "device_placement" +] + +_g_op_dist_attr_field_keys = ["process_mesh", "impl_type", "impl_idx"] + +_g_op_input_suffix = "@input" + +_g_op_output_suffix = "@output" + + +def get_tensor_dist_attr_field_keys(): + global _g_tensor_dist_attr_field_keys + return _g_tensor_dist_attr_field_keys + + +def get_op_dist_attr_field_keys(): + global _g_op_dist_attr_field_keys + return _g_op_dist_attr_field_keys + + +def append_op_input_suffix(name): + global _g_op_input_suffix + return name + _g_op_input_suffix + + +def append_op_output_suffix(name): + global _g_op_output_suffix + return name + _g_op_output_suffix + + +class TensorDistributedAttribute: + def __init__(self): + # The process mesh of distributed operator attribute must is the same as + # the process meshes of all input and output distributed attributed + self._process_mesh = None + self._dims_mapping = None + self._shard_sizes = None + self._device_placement = None + self._is_annotated = {} + + @property + def process_mesh(self): + return self._process_mesh + + @process_mesh.setter + def process_mesh(self, process_mesh): + if process_mesh is not None: + assert isinstance(process_mesh, (list, ProcessMesh)), \ + "The type of process_mesh must be list or ProcessMesh." + if isinstance(process_mesh, list): + process_mesh = ProcessMesh(process_mesh) + self._process_mesh = copy.deepcopy(process_mesh) + + @property + def dims_mapping(self): + return self._dims_mapping + + @dims_mapping.setter + def dims_mapping(self, dims_mapping): + if dims_mapping is not None: + assert isinstance(dims_mapping, list), \ + "The type of dims_mapping must be list." + assert all(isinstance(x, int) for x in dims_mapping), \ + ("All elements of dims_mapping must be integer") + assert all(x >= -1 for x in dims_mapping), \ + ("All elements of dims_mapping must be greater than or equal to -1.") + self._dims_mapping = copy.deepcopy(dims_mapping) + + @property + def shard_sizes(self): + return self._shard_sizes + + @shard_sizes.setter + def shard_sizes(self, shard_sizes): + if shard_sizes is not None: + self._shard_sizes = copy.deepcopy(shard_sizes) + + @property + def device_placement(self): + return self._device_placement + + @device_placement.setter + def device_placement(self, device_placement): + if device_placement is not None: + self._device_placement = copy.deepcopy(device_placement) + + def init(self, dist_attr): + if dist_attr is None: + return + assert isinstance(dist_attr, (dict, TensorDistributedAttribute)), \ + "The type of dist_attr must be dict or TensorDistributedAttribute." 
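
For reference, the dims_mapping convention used by TensorDistributedAttribute above can be summarized with a small framework-free sketch (editorially added; the helper names are illustrative only and even divisibility is assumed when a dimension is split):

# dims_mapping[i] names the process-mesh dimension that shards tensor dim i;
# -1 means that dimension is replicated on every process.
def infer_local_shape(tensor_shape, dims_mapping, mesh_topology):
    local_shape = []
    for dim_size, mesh_dim in zip(tensor_shape, dims_mapping):
        if mesh_dim == -1:
            local_shape.append(dim_size)
        else:
            local_shape.append(dim_size // mesh_topology[mesh_dim])
    return local_shape

# Mirror of the amendment rule applied later in this patch: a mapping is reset
# to -1 whenever the mapped mesh dimension is larger than the tensor dimension.
def amend_dims_mapping(tensor_shape, dims_mapping, mesh_topology):
    for i, mesh_dim in enumerate(dims_mapping):
        if mesh_dim != -1 and tensor_shape[i] > 0 \
                and mesh_topology[mesh_dim] > tensor_shape[i]:
            dims_mapping[i] = -1
    return dims_mapping

assert infer_local_shape([4, 6], [0, -1], [2, 3]) == [2, 6]
assert amend_dims_mapping([4, 2], [0, 1], [2, 3]) == [0, -1]
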
+ if isinstance(dist_attr, dict): + for key, value in dist_attr.items(): + if key in get_tensor_dist_attr_field_keys(): + field_property = TensorDistributedAttribute.__dict__.get( + key, None) + if field_property: + field_property.fset(self, value) + else: + assert False, "No setter for {} in args {}.".format( + key, dist_attr) + elif isinstance(dist_attr, TensorDistributedAttribute): + for key in get_tensor_dist_attr_field_keys(): + field_property = TensorDistributedAttribute.__dict__.get(key, + None) + if field_property: + field_property.fset(self, field_property.fget(dist_attr)) + else: + assert False, "No setter for {} in args {}.".format( + key, dist_attr) + self._is_annotated = copy.deepcopy(dist_attr._is_annotated) + + def is_annotated(self, dist_attr_field_name): + return self._is_annotated.get(dist_attr_field_name, False) + + def mark_annotated(self, dist_attr_field_name): + self._is_annotated[dist_attr_field_name] = True + + def mark_annotated_as(self, dist_attr): + if dist_attr is None: + return + assert isinstance(dist_attr, (dict, TensorDistributedAttribute)), \ + "The type of dist_attr must be dict or TensorDistributedAttribute." + if isinstance(dist_attr, dict): + for key in dist_attr.keys(): + if key in get_tensor_dist_attr_field_keys(): + self.mark_annotated(key) + elif isinstance(dist_attr, TensorDistributedAttribute): + self._is_annotated = copy.deepcopy(dist_attr._is_annotated) + + def clear_annotated(self): + self._is_annotated.clear() + + def __str__(self): + str = "\n\ttensor_dist_attr = {" + if self.is_annotated("process_mesh"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + str += "\n\t\tprocess_mesh ({}): {},".format(annotated_str, + self.process_mesh) + + if self.is_annotated("dims_mapping"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + str += "\n\t\tdims_mapping ({}): {}".format(annotated_str, + self.dims_mapping) + str += "\n\t}" + return str + + +class OperatorDistributedAttribute: + def __init__(self): + self._process_mesh = None + self._impl_type = None + self._impl_idx = None + self._inputs_dist_attrs = {} + self._outputs_dist_attrs = {} + self._is_annotated = {} + + @property + def process_mesh(self): + return self._process_mesh + + @process_mesh.setter + def process_mesh(self, process_mesh): + if process_mesh is not None: + assert isinstance(process_mesh, (list, ProcessMesh)), \ + "The type of process_mesh must be list or ProcessMesh." 
+ if isinstance(process_mesh, list): + process_mesh = ProcessMesh(process_mesh) + self._process_mesh = copy.deepcopy(process_mesh) + for dist_attr in self._inputs_dist_attrs.values(): + dist_attr.process_mesh = process_mesh + for dist_attr in self._outputs_dist_attrs.values(): + dist_attr.process_mesh = process_mesh + + @property + def impl_type(self): + return self._impl_type + + @impl_type.setter + def impl_type(self, impl_type): + if impl_type is not None: + self._impl_type = impl_type + + @property + def impl_idx(self): + return self._impl_idx + + @impl_idx.setter + def impl_idx(self, impl_idx): + if impl_idx is not None: + self._impl_idx = impl_idx + + @property + def inputs_dist_attrs(self): + return self._inputs_dist_attrs + + @property + def outputs_dist_attrs(self): + return self._outputs_dist_attrs + + def get_input_dist_attr(self, name): + return self._inputs_dist_attrs.get(name, None) + + def set_input_dist_attr(self, name, dist_attr): + dist_attr_object = TensorDistributedAttribute() + dist_attr_object.init(dist_attr) + self._inputs_dist_attrs[name] = dist_attr_object + + def get_output_dist_attr(self, name): + return self._outputs_dist_attrs.get(name, None) + + def set_output_dist_attr(self, name, dist_attr): + dist_attr_object = TensorDistributedAttribute() + dist_attr_object.init(dist_attr) + self._outputs_dist_attrs[name] = dist_attr_object + + def get_input_dims_mapping(self, name): + input_dist_attr = self.get_input_dist_attr(name) + if input_dist_attr: + dims_mapping = input_dist_attr.dims_mapping + else: + dims_mapping = None + return dims_mapping + + def set_input_dims_mapping(self, name, dims_mapping): + input_dist_attr = self.get_input_dist_attr(name) + if input_dist_attr: + input_dist_attr.dims_mapping = dims_mapping + else: + dist_attr = TensorDistributedAttribute() + dist_attr.dims_mapping = dims_mapping + self._inputs_dist_attrs[name] = dist_attr + + def get_output_dims_mapping(self, name): + output_dist_attr = self.get_output_dist_attr(name) + if output_dist_attr: + dims_mapping = output_dist_attr.dims_mapping + else: + dims_mapping = None + return dims_mapping + + def set_output_dims_mapping(self, name, dims_mapping): + output_dist_attr = self.get_output_dist_attr(name) + if output_dist_attr: + output_dist_attr.dims_mapping = dims_mapping + else: + dist_attr = TensorDistributedAttribute() + dist_attr.dims_mapping = dims_mapping + self._outputs_dist_attrs[name] = dist_attr + + def init(self, dist_attr): + if dist_attr is None: + return + assert isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \ + "The type of dist_attr must be dict or OperatorDistributedAttribute." 
+ if isinstance(dist_attr, dict): + for key, value in dist_attr.items(): + if isinstance(key, Variable): + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.init(value) + if dist_attr.get(append_op_input_suffix(key.name), False): + self.set_input_dist_attr(key.name, tensor_dist_attr) + if dist_attr.get(append_op_output_suffix(key.name), False): + self.set_output_dist_attr(key.name, tensor_dist_attr) + else: + if key in get_op_dist_attr_field_keys(): + field_property = OperatorDistributedAttribute.__dict__.get( + key, None) + if field_property: + field_property.fset(self, value) + else: + assert False, "No setter for {} in args {}.".format( + key, dist_attr) + elif isinstance(dist_attr, OperatorDistributedAttribute): + for tensor_name, tensor_dist_attr in dist_attr.inputs_dist_attrs.items( + ): + self.set_input_dist_attr( + tensor_name, dist_attr.get_input_dist_attr(tensor_name)) + for tensor_name, tensor_dist_attr in dist_attr.outputs_dist_attrs.items( + ): + self.set_output_dist_attr( + tensor_name, dist_attr.get_output_dist_attr(tensor_name)) + self._is_annotated = copy.deepcopy(dist_attr._is_annotated) + for key in get_op_dist_attr_field_keys(): + field_property = OperatorDistributedAttribute.__dict__.get(key, + None) + if field_property: + field_property.fset(self, field_property.fget(dist_attr)) + else: + assert False, "No setter for {} in args {}.".format( + key, dist_attr) + # Make sure proscess_meshes in dist op be same + process_meshes = [] + process_meshes.append(self.process_mesh) + for tensor_dist_attr in self.inputs_dist_attrs.values(): + process_meshes.append(tensor_dist_attr.process_mesh) + for tensor_dist_attr in self.outputs_dist_attrs.values(): + process_meshes.append(tensor_dist_attr.process_mesh) + shared_process_mesh = None + for process_mesh in process_meshes: + if process_mesh is not None: + if shared_process_mesh is None: + shared_process_mesh = process_mesh + else: + assert process_mesh == shared_process_mesh, \ + "ProcessMeshes in DistributedOperator must be the same." + self.process_mesh = shared_process_mesh + + def is_annotated(self, attr_name): + return self._is_annotated.get(attr_name, False) + + def mark_annotated(self, attr_name): + if attr_name == "process_mesh": + # Make sure proscess_mesh be annotated consistently + self._is_annotated[attr_name] = True + for tensor_dist_attr in self.inputs_dist_attrs.values(): + tensor_dist_attr.mark_annotated(attr_name) + for tensor_dist_attr in self.outputs_dist_attrs.values(): + tensor_dist_attr.mark_annotated(attr_name) + else: + self._is_annotated[attr_name] = True + + def mark_annotated_as(self, dist_attr): + if dist_attr is None: + return + assert isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \ + "The type of dist_attr must be dict or OperatorDistributedAttribute." 
+ if isinstance(dist_attr, dict): + for key, value in dist_attr.items(): + if isinstance(key, Variable): + input_dist_attr = self.get_input_dist_attr(key.name) + if input_dist_attr is not None: + input_dist_attr.mark_annotated_as(value) + output_dist_attr = self.get_output_dist_attr(key.name) + if output_dist_attr is not None: + output_dist_attr.mark_annotated_as(value) + else: + if key in get_op_dist_attr_field_keys(): + self.mark_annotated(key) + process_mesh_annotated = False + if self.is_annotated("process_mesh"): + process_mesh_annotated = True + for tensor_dist_attr in self.inputs_dist_attrs.values(): + if tensor_dist_attr.is_annotated("process_mesh"): + process_mesh_annotated = True + for tensor_dist_attr in self.outputs_dist_attrs.values(): + if tensor_dist_attr.is_annotated("process_mesh"): + process_mesh_annotated = True + if process_mesh_annotated: + self.mark_annotated("process_mesh") + elif isinstance(dist_attr, OperatorDistributedAttribute): + process_mesh_annotated = False + self._is_annotated = copy.deepcopy(dist_attr._is_annotated) + if self.is_annotated("process_mesh"): + process_mesh_annotated = True + for tensor_name, tensor_dist_attr in dist_attr.inputs_dist_attrs.items( + ): + input_dist_attr = self.get_input_dist_attr(tensor_name) + if input_dist_attr is not None: + input_dist_attr.mark_annotated_as(tensor_dist_attr) + if input_dist_attr.is_annotated("process_mesh"): + process_mesh_annotated = True + for tensor_name, tensor_dist_attr in dist_attr.outputs_dist_attrs.items( + ): + output_dist_attr = self.get_output_dist_attr(tensor_name) + if output_dist_attr is not None: + output_dist_attr.mark_annotated_as(tensor_dist_attr) + if output_dist_attr.is_annotated("process_mesh"): + process_mesh_annotated = True + if process_mesh_annotated: + self.mark_annotated("process_mesh") + + def clear_annotated(self): + self._is_annotated.clear() + for tensor_dist_attr in self.inputs_dist_attrs.values(): + tensor_dist_attr.clear_annotated() + for tensor_dist_attr in self.outputs_dist_attrs.values(): + tensor_dist_attr.clear_annotated() + + def is_annotated_input_dims_mapping(self, name): + input_dist_attr = self.get_input_dist_attr(name) + if input_dist_attr: + return input_dist_attr.is_annotated("dims_mapping") + else: + return False + + def is_annotated_output_dims_mapping(self, name): + output_dist_attr = self.get_output_dist_attr(name) + if output_dist_attr: + return output_dist_attr.is_annotated("dims_mapping") + else: + return False + + def __str__(self): + str = "\n\top_dist_attr = {" + if self.is_annotated("process_mesh"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + str += "\n\t\tprocess_mesh ({}): {},".format(annotated_str, + self.process_mesh) + + for arg_name, tensor_dist_attr in self.inputs_dist_attrs.items(): + str += "\n\t\t{}'s: {},".format(arg_name, tensor_dist_attr) + + for arg_name, tensor_dist_attr in self.outputs_dist_attrs.items(): + str += "\n\t\t{}'s: {},".format(arg_name, tensor_dist_attr) + + str += "\n\t\timpl type: {}, ".format(self._impl_type) + str += "impl idx: {}".format(self._impl_idx) + str += "\n\t}" + return str diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py new file mode 100755 index 00000000000..e3b3ee6a376 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -0,0 +1,427 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
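
Both attribute classes above accept plain dicts in init(), and a nested list passed as process_mesh is promoted to a ProcessMesh. A hedged usage sketch follows (the import path assumes the module layout introduced by this patch; the input name "X" is only a placeholder):

from paddle.distributed.auto_parallel.dist_attribute import (
    TensorDistributedAttribute, OperatorDistributedAttribute)

tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.init({"process_mesh": [[0, 1], [2, 3]], "dims_mapping": [0, -1]})
tensor_dist_attr.mark_annotated_as({"dims_mapping": [0, -1]})
print(tensor_dist_attr)

op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.init({"process_mesh": [[0, 1], [2, 3]],
                   "impl_type": "default", "impl_idx": 0})
# Per-tensor entries may also be keyed by framework Variables in init(); here
# the dims mapping for a placeholder input "X" is set directly instead.
op_dist_attr.set_input_dims_mapping("X", [0, -1])
print(op_dist_attr)
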
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import copy +from collections import defaultdict +from paddle.fluid import framework +from paddle.fluid import core +from .dist_attribute import TensorDistributedAttribute +from .dist_attribute import OperatorDistributedAttribute +from .dist_tensor import DistributedTensor +from .dist_op import DistributedOperator +from .process_mesh import ProcessMesh + +# There always exists a default context for user. And user can set it to another one. +_g_default_distributed_context = None + + +def get_default_distributed_context(): + global _g_default_distributed_context + if _g_default_distributed_context is None: + dist_context = DistributedContext() + set_default_distributed_context(dist_context) + return _g_default_distributed_context + + +def set_default_distributed_context(dist_context): + global _g_default_distributed_context + _g_default_distributed_context = dist_context + + +class DistributedContext: + """ + DistributedContext is used to collect related distributed information for program and graph. + One auto-parallel run should use its own DistributedContext to avoid interfering other run. + """ + + def __init__(self, program=None): + self._serial_program = program + self._serial_graph = None + self._is_initialized_for_program = False + self._is_initialized_for_graph = False + self._dist_tensors_for_program = {} + self._dist_ops_for_program = {} + self._dist_tensors_for_graph = {} + self._dist_ops_for_graph = {} + self._dist_op_context = DistributedOperatorContext() + self._process_meshes = [] + + @property + def serial_program(self): + return self._serial_program + + @property + def serial_graph(self): + return self._serial_graph + + @serial_program.setter + def serial_program(self, program): + assert self._serial_program is None, \ + "This distributed context has already been realted to a serial program" + self._serial_program = program + + @property + def process_meshes(self): + return self._process_meshes + + @property + def dist_op_context(self): + return self._dist_op_context + + def add_process_mesh(self, process_mesh): + assert isinstance(process_mesh, ProcessMesh), \ + 'The type of dim_mapping must be ProcessMesh.' 
+ if process_mesh not in self.process_meshes: + self._process_meshes.append(process_mesh) + + def add_dist_tensor_for_program(self, dist_tensor): + inner_serial_tensor = dist_tensor.serial_tensor + inner_serial_tensor_id = inner_serial_tensor.desc.id() + self._dist_tensors_for_program[inner_serial_tensor_id] = dist_tensor + + def add_dist_op_for_program(self, dist_op): + inner_serial_op = dist_op.serial_op + inner_serial_op_id = inner_serial_op.desc.id() + self._dist_ops_for_program[inner_serial_op_id] = dist_op + + def get_dist_tensor_for_program(self, serial_tensor): + serial_tensor_id = serial_tensor.desc.id() + return self._dist_tensors_for_program.get(serial_tensor_id, None) + + def get_dist_tensor_for_graph(self, serial_tensor_node): + serial_tensor_node_id = serial_tensor_node.id() + return self._dist_tensors_for_graph.get(serial_tensor_node_id, None) + + def get_dist_op_for_program(self, serial_tensor): + serial_tensor_id = serial_tensor.desc.id() + return self._dist_ops_for_program.get(serial_tensor_id, None) + + def get_dist_op_for_graph(self, serial_tensor_node): + serial_tensor_node_id = serial_tensor_node.id() + return self._dist_ops_for_graph.get(serial_tensor_node_id, None) + + def get_tensor_dist_attr_for_program(self, serial_tensor): + serial_tensor_id = serial_tensor.desc.id() + dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None) + if dist_tensor: + return dist_tensor.dist_attr + else: + return None + + def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr): + dist_tensor = DistributedTensor(serial_tensor, dist_attr) + self.add_dist_tensor_for_program(dist_tensor) + + def get_tensor_dist_attr_for_graph(self, serial_tensor_node): + serial_tensor_node_id = serial_tensor_node.id() + dist_tensor = self._dist_tensors_for_graph.get(serial_tensor_node_id, + None) + if dist_tensor: + return dist_tensor.dist_attr + else: + return None + + def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr): + assert serial_tensor_node.is_var() and \ + serial_tensor_node.var() is not None + serial_tensor_id = serial_tensor_node.var().id() + dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None) + assert dist_tensor is not None, \ + "The distributed tensor of the program has not been added to this context." + serial_tensor_node_id = serial_tensor_node.id() + new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor, + dist_attr) + self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor + + def get_op_dist_attr_for_program(self, serial_op): + serial_op_id = serial_op.desc.id() + dist_op = self._dist_ops_for_program.get(serial_op_id, None) + if dist_op: + return dist_op.dist_attr + else: + return None + + def set_op_dist_attr_for_program(self, serial_op, dist_attr): + dist_op = DistributedOperator(serial_op, dist_attr) + self.add_dist_op_for_program(dist_op) + + def get_op_dist_attr_for_graph(self, serial_op_node): + serial_op_node_id = serial_op_node.id() + dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None) + if dist_op: + return dist_op.dist_attr + else: + return None + + def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr): + assert serial_op_node.is_op() and \ + serial_op_node.op() is not None + serial_op_id = serial_op_node.op().id() + dist_op = self._dist_ops_for_program.get(serial_op_id, None) + assert dist_op is not None, \ + "The distributed operator of the program has not been added to this context." 
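
Unlike the deleted attribute.py/context.py code earlier in this patch, which serialized distributed attributes into VarDesc/OpDesc entries with suffixed keys, the new context keeps them in per-context maps keyed by the tensor or operator. A minimal sketch of the program-side getters/setters (names, mesh and shapes are placeholders):

import paddle
from paddle.distributed.auto_parallel.dist_context import DistributedContext

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data(name="x", shape=[4, 6], dtype="float32")

dist_context = DistributedContext()
# The dict is converted into a TensorDistributedAttribute and stored in the
# context, keyed by the tensor's desc id.
dist_context.set_tensor_dist_attr_for_program(
    x, {"process_mesh": [[0, 1], [2, 3]], "dims_mapping": [0, -1]})
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(x)
print(tensor_dist_attr.dims_mapping)  # [0, -1]
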
+ serial_op_node_id = serial_op_node.id() + new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr) + self._dist_ops_for_graph[serial_op_node_id] = new_dist_op + + def init_dist_attr_for_program(self): + assert self._serial_program, \ + "Please set the program of this context before initializing its distribute attributes." + if self._is_initialized_for_program: + return + # Copy the dist tensors and dist ops annotated by users from the default context + default_ctx = get_default_distributed_context() + self._process_meshes = copy.deepcopy(default_ctx.process_meshes) + for block in self._serial_program.blocks: + for tensor in block.vars.values(): + # Copy the distributed tensors in the default context + default_dist_tensor = default_ctx.get_dist_tensor_for_program( + tensor) + if default_dist_tensor and default_ctx is not self: + self.add_dist_tensor_for_program(default_dist_tensor) + current_dist_tensor = self.get_dist_tensor_for_program(tensor) + if current_dist_tensor is None: + dist_tensor = DistributedTensor(tensor) + self.add_dist_tensor_for_program(dist_tensor) + for op in block.ops: + # Copy the distributed operators in the default context + default_dist_op = default_ctx.get_dist_op_for_program(op) + if default_dist_op and default_ctx is not self: + self.add_dist_op_for_program(default_dist_op) + current_dist_op = self.get_dist_op_for_program(op) + if current_dist_op is None: + dist_op = DistributedOperator(op) + self.add_dist_op_for_program(dist_op) + self._is_initialized_for_program = True + + def init_dist_attr_for_graph(self): + assert self._is_initialized_for_program, \ + "The program must be initialized before initializing the distributed attributes for its graph." + if self._is_initialized_for_graph: + return + # Convert program to graph + self._serial_graph = framework.IrGraph( + core.Graph(self._serial_program.desc)) + all_nodes = self._serial_graph.all_nodes() + for node in all_nodes: + if node.is_var() and node.var() is not None: + tensor_desc = node.var() + tensor_id = tensor_desc.id() + dist_tensor = self._dist_tensors_for_program.get(tensor_id, + None) + assert dist_tensor is not None, \ + "Tensor must have a distributed tensor after the initialization for program." + self.set_tensor_dist_attr_for_graph(node, dist_tensor.dist_attr) + if node.is_op() and node.op() is not None: + op_desc = node.op() + op_id = op_desc.id() + dist_op = self._dist_ops_for_program.get(op_id, None) + assert dist_op is not None, \ + "Operator must have a distributed operator after the initialization for program." + self.set_op_dist_attr_for_graph(node, dist_op.dist_attr) + self._is_initialized_for_graph = True + + def clear_dist_info_for_program(self): + self._dist_tensors_for_program.clear() + self._dist_ops_for_program.clear() + + def clear_dist_info_for_graph(self): + self._dist_tensors_for_graph.clear() + self._dist_ops_for_graph.clear() + + def copy_dist_attr_from_graph_to_program(self): + assert self._is_initialized_for_program and self._is_initialized_for_graph, \ + "Both program and graph must be initialized." 
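
Taken together, the methods above define the round trip that a completion pass is expected to drive: build program-side attributes (merging user annotations from the default context), mirror them onto the IR graph, and copy the results back. A sketch of the intended call order (the toy program and layer size are placeholders; the commented steps assume process meshes have already been completed for every tensor and op):

import paddle
from paddle.distributed.auto_parallel.dist_context import DistributedContext

paddle.enable_static()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data(name="x", shape=[4, 16], dtype="float32")
    y = paddle.static.nn.fc(x, size=8)

dist_context = DistributedContext(main_program)
dist_context.init_dist_attr_for_program()   # default attrs + user annotations
dist_context.init_dist_attr_for_graph()     # mirror onto the IrGraph nodes
# ... a completion pass would update the graph-side attributes here ...
dist_context.copy_dist_attr_from_graph_to_program()
# Once every tensor/op carries a process mesh, the attributes can be patched
# up and checked:
# dist_context.amend_dist_attr_for_program()
# assert dist_context.validate_dist_attr_for_program()
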
+ updated_tensors = {} + all_nodes = self._serial_graph.all_nodes() + for node in all_nodes: + if node.is_var() and node.var() is not None: + tensor_desc = node.var() + tensor_id = tensor_desc.id() + updated = updated_tensors.get(tensor_desc.name(), False) + # If a var has multiples var nodes in graph, only use the first one for now + if not updated: + tensor_dist_attr_for_graph = self.get_tensor_dist_attr_for_graph( + node) + dist_tensor_for_program = self._dist_tensors_for_program[ + tensor_id] + dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph + updated_tensors[tensor_desc.name()] = True + if node.is_op() and node.op() is not None: + op_desc = node.op() + op_id = op_desc.id() + op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node) + dist_op_for_program = self._dist_ops_for_program[op_id] + dist_op_for_program.dist_attr = op_dist_attr_for_graph + + def amend_dist_attr_for_program(self): + for dist_tensor in self._dist_tensors_for_program.values(): + serial_tensor = dist_tensor.serial_tensor + dist_attr = dist_tensor.dist_attr + if serial_tensor.type == core.VarDesc.VarType.READER: + tensor_shape = [] + else: + tensor_shape = serial_tensor.shape + dims_mapping = dist_attr.dims_mapping + process_mesh_shape = dist_attr.process_mesh.topology + # If the dimension of tensor is less than the sharding dimension of process mesh, + # we just amend the dimension mapping to -1. (Is this really OK?) + for i in range(len(tensor_shape)): + if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ + and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: + dims_mapping[i] = -1 + + for dist_op in self._dist_ops_for_program.values(): + serial_op = dist_op.serial_op + dist_attr = dist_op.dist_attr + for arg_name in serial_op.input_arg_names: + if dist_op.get_serial_input(arg_name) is None: + tensor_shape = [] + else: + if dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.READER \ + or dist_op.serial_op.type == "create_py_reader": + tensor_shape = [] + else: + tensor_shape = dist_op.get_serial_input(arg_name).shape + dims_mapping = dist_attr.get_input_dims_mapping(arg_name) + process_mesh_shape = dist_attr.process_mesh.topology + # If the dimension of tensor is less than the sharding dimension of process mesh, + # we just amend the dimension mapping to -1. (Is this really OK?) + for i in range(len(tensor_shape)): + if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ + and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: + dims_mapping[i] = -1 + for arg_name in serial_op.output_arg_names: + if dist_op.get_serial_output( + arg_name).type == core.VarDesc.VarType.READER: + tensor_shape = [] + else: + tensor_shape = dist_op.get_serial_output(arg_name).shape + dims_mapping = dist_attr.get_output_dims_mapping(arg_name) + process_mesh_shape = dist_attr.process_mesh.topology + # If the dimension of tensor is less than the sharding dimension of process mesh, + # we just amend the dimension mapping to -1. (Is this really OK?) 
+ for i in range(len(tensor_shape)): + if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ + and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: + dims_mapping[i] = -1 + + def validate_dist_attr_for_program(self): + if not self._is_initialized_for_program: + assert False, \ + "Program must be initialized before validating its distributed attributes" + for block in self.serial_program.blocks: + for tensor in block.vars.values(): + dist_tensor = self.get_dist_tensor_for_program(tensor) + if (dist_tensor is not None) and ( + not dist_tensor.validate_dist_attr()): + assert False, "Tensor {} has a wrong distributed attributes {}.".format( + dist_tensor.serial_tensor.name, dist_tensor.dist_attr) + for op in block.ops: + dist_op = self.get_dist_op_for_program(op) + if (dist_op is not None) and (not dist_op.validate_dist_attr()): + assert False, "Operator {} has a wrong distributed attributes {}.".format( + dist_op.serial_op.type, dist_tensor.dist_attr) + return True + + +class DistributedOperatorContext: + """ + DistributedOperatorContext is used to create a dist op desc in Program. + Every time to create a new dist op, the context should be updated for it accordingly. + """ + + def __init__(self): + self._dst_main_program = None + self._dst_startup_program = None + self._varname_mapping = None + self._rank_id = None + self._cur_src_op = None + self._cur_dist_attr = None + self.gradopidx2opidx = {} + self.already_init_sync_vars = set() + + def set_dst_main_program(self, prog): + self._dst_main_program = prog + + def get_dst_main_program(self): + return self._dst_main_program + + def set_dst_startup_program(self, prog): + self._dst_startup_program = prog + + def get_dst_startup_program(self): + return self._dst_startup_program + + def set_varname_mapping(self, mapping): + self._varname_mapping = mapping + + def get_varname_mapping(self): + return self._varname_mapping + + def set_rank_id(self, rank_id): + self._rank_id = rank_id + + def get_rank_id(self): + return self._rank_id + + def set_cur_src_op(self, cur_src_op): + self._cur_src_op = cur_src_op + + def get_cur_src_op(self): + return self._cur_src_op + + def prepare_forward_context(self, src_op): + + self.set_cur_src_op(src_op) + + # build input varname mapping + kinputs = {} + for input_name in src_op.desc.input_names(): + varnames = [] + for varname in src_op.desc.input(input_name): + varnames.append(self._varname_mapping[varname]) + kinputs[input_name] = varnames + + # build output varname mapping + koutputs = {} + for output_name in src_op.desc.output_names(): + varnames = [] + for varname in src_op.desc.output(output_name): + varnames.append(self._varname_mapping[varname]) + koutputs[output_name] = varnames + + return kinputs, koutputs + + def prepare_backward_context(self, backward_op): + + self.set_cur_src_op(backward_op) + + # build input varname mapping + kinputs = {} + for input_name in backward_op.desc.input_names(): + varnames = [] + for varname in backward_op.desc.input(input_name): + varnames.append(varname) + kinputs[input_name] = varnames + + # build output varname mapping + koutputs = {} + for output_name in backward_op.desc.output_names(): + varnames = [] + for varname in backward_op.desc.output(output_name): + varnames.append(varname) + koutputs[output_name] = varnames + + return kinputs, koutputs diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py new file mode 100644 index 00000000000..aa447d7a423 --- /dev/null +++ 
b/python/paddle/distributed/auto_parallel/dist_op.py @@ -0,0 +1,243 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import copy +from collections import defaultdict +import paddle +from paddle.fluid import core +from paddle.fluid.framework import Variable +from .dist_attribute import TensorDistributedAttribute +from .dist_attribute import OperatorDistributedAttribute +from .dist_attribute import append_op_input_suffix +from .dist_attribute import append_op_output_suffix +from .dist_attribute import get_tensor_dist_attr_field_keys +from .dist_attribute import get_op_dist_attr_field_keys + + +class DistributedOperator: + def __init__(self, serial_op, dist_attr=None): + self._serial_op = serial_op + self._serial_inputs = {} + self._serial_outputs = {} + self._dist_attr = None + # Reuse the dist_attr setter to initialize _dist_attr + self.dist_attr = dist_attr + + @property + def serial_op(self): + return self._serial_op + + @property + def dist_attr(self): + return self._dist_attr + + @dist_attr.setter + def dist_attr(self, dist_attr): + if self._dist_attr is None: + self._dist_attr = OperatorDistributedAttribute() + # Create new dist_attr related to current serial_op + dist_attr = self._filter_dist_attr(dist_attr) + # Append suffix to mark the inputs or outputs + if isinstance(dist_attr, dict): + # Copy the keys since we may add new ones + for key in list(dist_attr.keys()): + if isinstance(key, Variable): + if key.name in self._serial_op.input_arg_names: + dist_attr[append_op_input_suffix(key.name)] = True + if key.name in self._serial_op.output_arg_names: + dist_attr[append_op_output_suffix(key.name)] = True + self._dist_attr.init(dist_attr) + self._init_default_dist_attr() + + def get_serial_input(self, name): + return self._serial_inputs.get(name, None) + + def get_serial_output(self, name): + return self._serial_outputs.get(name, None) + + def _init_default_dist_attr(self): + for tensor_name in self._serial_op.input_arg_names: + if self._serial_op.type == "create_py_reader": + tensor = None + else: + tensor = self._serial_op.block._var_recursive(tensor_name) + self._serial_inputs[tensor_name] = tensor + if tensor is None: + tensor_shape = [] + else: + if tensor.type == core.VarDesc.VarType.READER: + tensor_shape = [] + else: + tensor_shape = tensor.shape + if self._dist_attr.get_input_dims_mapping(tensor_name) is None: + tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))] + self._dist_attr.set_input_dims_mapping(tensor_name, + tensor_dims_mapping) + for tensor_name in self._serial_op.output_arg_names: + tensor = self._serial_op.block._var_recursive(tensor_name) + if tensor.type == core.VarDesc.VarType.READER: + tensor_shape = [] + else: + tensor_shape = tensor.shape + self._serial_outputs[tensor_name] = tensor + if self._dist_attr.get_output_dims_mapping(tensor_name) is None: + tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))] + self._dist_attr.set_output_dims_mapping(tensor_name, + tensor_dims_mapping) 
+ if self._dist_attr.impl_type is None: + self._dist_attr.impl_type = "default" + if self._dist_attr.impl_idx is None: + self._dist_attr.impl_idx = -2 + + def _filter_dist_attr(self, dist_attr): + if dist_attr is None: + return None + new_dist_attr = None + if isinstance(dist_attr, dict): + new_dist_attr = {} + for key, value in dist_attr.items(): + if isinstance(key, Variable): + if key.name in self._serial_op.input_arg_names \ + or key.name in self._serial_op.output_arg_names: + new_dist_attr[key] = value + else: + new_dist_attr[key] = value + elif isinstance(dist_attr, OperatorDistributedAttribute): + new_dist_attr = copy.deepcopy(dist_attr) + new_dist_attr._inputs_dist_attrs.clear() + new_dist_attr._outputs_dist_attrs.clear() + for tensor_name in self._serial_op.input_arg_names: + tensor_dist_attr = dist_attr.get_input_dist_attr(tensor_name) + if tensor_dist_attr: + new_dist_attr.set_input_dist_attr(tensor_name, + tensor_dist_attr) + for tensor_name in self._serial_op.output_arg_names: + tensor_dist_attr = dist_attr.get_output_dist_attr(tensor_name) + if tensor_dist_attr: + new_dist_attr.set_output_dist_attr(tensor_name, + tensor_dist_attr) + else: + assert False, "Cannot recognize the {} parameter.".format(dist_attr) + return new_dist_attr + + def validate_dist_attr(self): + if "read" in self.serial_op.type: + return True + for name in self.serial_op.input_arg_names: + input_dist_attr = self.dist_attr.get_input_dist_attr(name) + dims_mapping = input_dist_attr.dims_mapping + shape = self.get_serial_input(name).shape + if len(shape) != len(dims_mapping): + return False + for i in range(len(dims_mapping)): + if dims_mapping[i] < -1 or dims_mapping[i] >= len( + self.dist_attr.process_mesh.topology): + return False + for i in range(len(self.dist_attr.process_mesh.topology)): + if dims_mapping.count(i) > 1: + return False + if self.dist_attr.process_mesh != input_dist_attr.process_mesh: + return False + + for name in self.serial_op.output_arg_names: + output_dist_attr = self.dist_attr.get_output_dist_attr(name) + dims_mapping = output_dist_attr.dims_mapping + shape = self.get_serial_output(name).shape + if len(shape) != len(dims_mapping): + return False + for i in range(len(dims_mapping)): + if dims_mapping[i] < -1 or dims_mapping[i] >= len( + self.dist_attr.process_mesh.topology): + return False + for i in range(len(self.dist_attr.process_mesh.topology)): + if dims_mapping.count(i) > 1: + return False + if self.dist_attr.process_mesh != output_dist_attr.process_mesh: + return False + return True + + def __str__(self): + str = "{{op type: {}, op id: {}".format(self.serial_op.desc.type(), + self.serial_op.desc.id()) + + # str += ", {}".format(self.dist_attr) + # return str + + if self.dist_attr.is_annotated("process_mesh"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + str += ", process_mesh ({}): {}".format(annotated_str, + self.dist_attr.process_mesh) + + for arg_name in self.serial_op.desc.input_arg_names(): + dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name) + if self.dist_attr.is_annotated_input_dims_mapping(arg_name): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + if self.get_serial_input(arg_name) is not None: + if self.get_serial_input(arg_name).is_parameter: + is_parameter_str = "parameter" + else: + is_parameter_str = "non-parameter" + else: + is_parameter_str = "non-parameter" + str += ", {}'s dims_mapping (input, {}, {}): {}".format( + arg_name, annotated_str, is_parameter_str, dims_mapping) + + for arg_name in 
self.serial_op.desc.output_arg_names(): + dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name) + if self.dist_attr.is_annotated_output_dims_mapping(arg_name): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + if self.get_serial_output(arg_name) is not None: + if self.get_serial_output(arg_name).is_parameter: + is_parameter_str = "parameter" + else: + is_parameter_str = "non-parameter" + else: + is_parameter_str = "non-parameter" + str += ", {}'s dims_mapping (output, {}, {}): {}".format( + arg_name, annotated_str, is_parameter_str, dims_mapping) + + str += ", pipeline stage: {}".format(None) + + str += ", dist_impl idx: {} }}".format(self.dist_attr._impl_idx) + + return str + + +class DistributedModule: + def __init__(self, serial_module, dist_attr=None): + self._serial_module = serial_module + self._dist_attr = dist_attr + + def __call__(self, *args, **kwargs): + from .dist_context import get_default_distributed_context + main_prog = paddle.fluid.default_main_program() + main_block = main_prog.global_block() + op_size = len(main_block.ops) + output = self._serial_module(*args, **kwargs) + new_op_size = len(main_block.ops) + default_dist_ctx = get_default_distributed_context() + for idx in range(op_size, new_op_size): + op = main_block.ops[idx] + dist_op = DistributedOperator(op, self._dist_attr) + dist_op.dist_attr.mark_annotated_as(self._dist_attr) + default_dist_ctx.add_dist_op_for_program(dist_op) + if isinstance(output, Variable): + output = [output] + return list(output) diff --git a/python/paddle/distributed/auto_parallel/dist_tensor.py b/python/paddle/distributed/auto_parallel/dist_tensor.py new file mode 100644 index 00000000000..3b292d7f435 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/dist_tensor.py @@ -0,0 +1,103 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
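
The DistributedModule wrapper above records, for every op its call appends to the block, a DistributedOperator annotated with the shared dist_attr in the default context. A hedged usage sketch (layer, mesh and sizes are placeholders; note that the wrapper always returns a list):

import paddle
from paddle.distributed.auto_parallel.dist_op import DistributedModule
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context

paddle.enable_static()
main_program = paddle.static.default_main_program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data(name="x", shape=[8, 16], dtype="float32")
    dist_linear = DistributedModule(
        paddle.nn.Linear(16, 32), dist_attr={"process_mesh": [[0, 1], [2, 3]]})
    out = dist_linear(x)[0]

default_ctx = get_default_distributed_context()
for op in main_program.global_block().ops:
    dist_op = default_ctx.get_dist_op_for_program(op)
    # Ops created inside the wrapped call carry the annotated process mesh.
    print(op.type, dist_op.dist_attr.process_mesh if dist_op else None)
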
+# See the License for the specific language governing permissions and +# limitations under the License + +import copy +from paddle.fluid import core +from .dist_attribute import TensorDistributedAttribute +from .dist_attribute import get_tensor_dist_attr_field_keys + + +class DistributedTensor: + def __init__(self, serial_tensor, dist_attr=None): + self._serial_tensor = serial_tensor + self._dist_attr = None + self._batch_dim = 0 + # Reuse the dist_attr setter to initialize _dist_attr + self.dist_attr = dist_attr + + @property + def serial_tensor(self): + return self._serial_tensor + + @property + def dist_attr(self): + return self._dist_attr + + @dist_attr.setter + def dist_attr(self, dist_attr): + if self._dist_attr is None: + self._dist_attr = TensorDistributedAttribute() + self._dist_attr.init(dist_attr) + self._init_default_dist_attr() + + def _init_default_dist_attr(self): + if self._dist_attr.dims_mapping is None: + if self.serial_tensor.type == core.VarDesc.VarType.READER: + tensor_shape = [] + else: + tensor_shape = self._serial_tensor.shape + tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))] + self._dist_attr.dims_mapping = tensor_dims_mapping + + def validate_dist_attr(self): + if self.serial_tensor.type == core.VarDesc.VarType.READER: + return True + tensor_shape = self.serial_tensor.shape + if len(tensor_shape) != len(self.dist_attr.dims_mapping): + return False + for i in range(len(self.dist_attr.dims_mapping)): + if self.dist_attr.dims_mapping[ + i] < -1 or self.dist_attr.dims_mapping[i] >= len( + self.dist_attr.process_mesh.topology): + return False + for i in range(len(self.dist_attr.process_mesh.topology)): + if self.dist_attr.dims_mapping.count(i) > 1: + return False + return True + + def __str__(self): + str = "{{tensor name: {}, tensor id: {}".format( + self.serial_tensor.desc.name(), self.serial_tensor.desc.id()) + + # str += ", {}".format(self.dist_attr) + # return str + + if self.dist_attr.is_annotated("process_mesh"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + str += ", process_mesh ({}): {}".format(annotated_str, + self.dist_attr.process_mesh) + + str += ", is_parameter: {}".format(self.serial_tensor.is_parameter) + + if self.dist_attr.is_annotated("dims_mapping"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + str += ", dims_mapping ({}): {}".format(annotated_str, + self.dist_attr.dims_mapping) + + if self.dist_attr.is_annotated("shard_mask"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + str += ", shard_mask ({}): {}".format(annotated_str, None) + + if self.dist_attr.is_annotated("offload_device"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + str += ", offload_device ({}): {} }}".format(annotated_str, None) + return str diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py index 30055c5b763..f12b85c6f2b 100644 --- a/python/paddle/distributed/auto_parallel/interface.py +++ b/python/paddle/distributed/auto_parallel/interface.py @@ -18,293 +18,34 @@ import paddle import paddle.fluid.core as core from paddle.fluid.framework import Variable from paddle.fluid.framework import in_dygraph_mode - -__all__ = [] - -# a map from ProcessMesh ids to the ProcessMesh instances -_g_process_mesh_map = dict() - -# user defined map from logical process ids to physical ones -_user_defined_physical_map = None - - -def _append_attr_suffix(name): - """ - Append auto parallel suffix for 
distributed attribute name. - """ - return name + core.kAutoParallelSuffix() - - -def _remove_attr_suffix(name): - """ - Remove auto parallel suffix from distributed attribute name. - """ - return name.strip(core.kAutoParallelSuffix()) +from .dist_context import get_default_distributed_context +from .dist_tensor import DistributedTensor +from .dist_op import DistributedModule +from .dist_attribute import TensorDistributedAttribute +from .dist_attribute import OperatorDistributedAttribute def _static_mode_check(): if in_dygraph_mode(): - raise RuntimeError("Auto-parallel only supports static mode, " - "please use paddle.enable_static().") - - -def _get_nested_list_shape(nested_list): - """ - Get the shape of a nested_list. - """ - result = [] - while isinstance(nested_list, list): - result.append(len(nested_list)) - nested_list = nested_list[0] - return result - - -def _flatten_nested_list(nested_list): - """ - Get a list of all items in a nested_list. - Ref: https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-a-list-of-lists - """ - result = numpy.array(nested_list).flatten().tolist() - return result - - -class ProcessMesh(object): - r""" - The class `Processmesh` describes the topology of logical processes. - A mesh is an N-dimensional array. The shape of the N-dimensional - array represents the topology of logical processes and every - element of the N-dimensional array represent a logical process. For - example, the 2-dimensional array [[2, 4, 5], [0, 1, 3]] - illustrates six logical processes organized as the topology [2, 3], - i.e., the shape of the 2-dimensional array. With the above topology, - there are two parallel groups, where the first parallel group has a - parallel degree of 2 and the second one has a parallel degree of 3. - And the first logical process is the one with id=2. - - Args: - mesh (list): an N-dimensional array (nested list) describes the toplogy - of logical processes. The shape of the N-dimensional array - represents the topology of logical processes and every - element of the N-dimensional array represents a logical process. - parent (ProcessMesh, optional): the parent ProcessMesh. None means - the ProcessMesh is the root one without parent ProcessMesh. - Default: None. - - Returns: - None - - Raises: - ValueError: If `mesh` is not an instance of list. - - Examples: - .. code-block:: python - - import paddle - import paddle.distributed as dist - - paddle.enable_static() - - mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]]) - assert mesh.parent is None - assert mesh.topology == [2, 3] - assert mesh.process_group == [2, 4, 5, 0, 1, 3] - mesh.set_placement([0, 1, 2, 3, 4, 5]) - - """ - - def __init__(self, mesh, parent=None): - _static_mode_check() - if mesh is None or not isinstance(mesh, list): - raise ValueError('mesh must be an instance of list.') - - self._topology = _get_nested_list_shape(mesh) - self._processes = _flatten_nested_list(mesh) - - # Every element of mesh must be >= 0. - assert min(self._processes) >= 0, ('All elements of mesh must be >= 0.') - - unique_ids = set(self._processes) - assert len(unique_ids) == len(self._processes), ( - 'All elements of mesh must be unique.') - - if parent is None: - # For root ProcessMesh, the ids of logical processes must be range - # from 0 to N-1, where N is the number of logical processes. 
- assert max(self._processes) == len(self._processes) - 1, ( - 'For root ProcessMesh, ids of logical processes must be range ' - 'from 0 to N-1, where N is the number of logical processes.') - - parent_id = core.kNoneProcessMeshIndex() - assert len(_g_process_mesh_map.keys()) == 0, ( - 'The first ProcessMesh must be the root, which has no parent.') - else: - assert len(_g_process_mesh_map.keys()) > 0, ( - 'All ProcessMesh must have a parent except the root one.') - - assert isinstance(parent, ProcessMesh), ( - 'parent must be an instance of ProcessMesh.') - parent_id = parent._desc.id - - # All elements in mesh must belong to its parent - parent_ids = set(parent.process_group) - assert unique_ids <= parent_ids, ( - 'All elements in mesh must belong to its parent.') - - self._desc = core.ProcessMeshDesc(self._topology, self._processes, - parent_id) - - self._id = self._desc.id - self._parent_id = parent_id - assert self._id not in _g_process_mesh_map, ( - "The ProcessMesh with id %d already exists." % self._id) - _g_process_mesh_map[self._id] = self - - @property - def topology(self): - r""" - Get the topology of logical processes belonging to this ProcessMesh. - This is the shape of `mesh` used to initialized this ProcessMesh. - """ - return self._topology - - @property - def process_group(self): - r""" - Get a list of all processes belonging to this ProcessMesh. - """ - return self._processes - - @property - def parent(self): - r""" - Get the parent ProcessMesh. - """ - if self._parent_id == core.kNoneProcessMeshIndex(): return None - assert self._parent_id in _g_process_mesh_map, ( - "parent with id %d does not exist." % self._parent_id) - return _g_process_mesh_map[self._parent_id] - - @property - def ndim(self): - r""" - Get the number of dimension of ProcessMesh. - """ - return len(self._topology) - - def set_placement(self, order): - """ - Set the map from logical processes to physical ones using the - user defined order. - - Args: - order (list): order of the physical process ids. - - Returns: - None - - Examples: - .. code-block:: python - - import paddle - import paddle.distributed as dist - - paddle.enable_static() - - mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]]) - mesh.set_placement([0, 1, 2, 3, 4, 5]) - - """ - assert self.parent is None, ( - "This function can only be called by the root ProcessMesh.") - unique_ids = set(order) - assert isinstance(order, list) - - assert len(unique_ids) == len(order), ( - "All elements in order must be unique.") - assert min(order) == 0 - assert max(order) == len(order) - 1, ( - "All elements in order must be from 0 to N - 1, where N " - "is the number of physical processes.") - - logical_order = self.process_group - global _user_defined_physical_map - assert _user_defined_physical_map is None, ( - "This function can only be called once.") - _user_defined_physical_map = dict() - - assert len(logical_order) == len(order) - for idx, l_id in enumerate(logical_order): - _user_defined_physical_map[l_id] = order[idx] - - def _reset_global_process_mesh_map(self): - """ - Remove all process mesh in _g_process_mesh_map, make it empty. 
- """ - - _g_process_mesh_map = dict() - - def __eq__(self, other): - assert other and isinstance(other, ProcessMesh) - if self.topology != other.topology or self.process_group != other.process_group: - return False - return True - - def __ne__(self, other): - return not self.__eq__(other) - - def __str__(self): - str = "shape {} and process group {}".format(self.topology, - self.process_group) - return str - - def __deepcopy__(self, memo): - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - for k, v in self.__dict__.items(): - # No need to copy the owner tensor and context - if k == "_desc": - setattr(result, k, v) - else: - setattr(result, k, copy.deepcopy(v, memo)) - return result + raise RuntimeError("Auto-parallel only supports static mode for now, " + "please use paddle.enable_static() first.") -def _dim_mapping_checker(tensor, mesh, dim_mapping): - assert isinstance(mesh, - ProcessMesh), 'The type of mesh must be ProcessMesh.' - assert isinstance(dim_mapping, - list), 'The type of dim_mapping must be list.' - assert len(tensor.shape) == len(dim_mapping), ( - 'The number of dimensions ' - 'of tensor must be the same as the length of its corresponding ' - 'dim_mapping.') - mesh_dim = len(mesh.topology) - dim_set = set() - for i in range(len(dim_mapping)): - assert dim_mapping[i] == -1 or ( - dim_mapping[i] < mesh_dim and dim_mapping[i] >= 0), ( - 'Each element ' - 'in dim_mapping must be greater than zero and less than the ' - 'length of its corresponding topology, or it must be -1.') - if dim_mapping[i] >= 0: - assert dim_mapping[i] not in dim_set - dim_set.add(dim_mapping[i]) - - -def shard_tensor(x, mesh, dim_mapping): +def shard_tensor(x, dist_attr=None): """ Add distributed attributes for a tensors. Args: - x (Tensor): the tensor to process. - mesh (ProcessMesh): an instance of ProcessMesh to describe the topology of logical processes. - dim_mapping (list): a list to describe the mapping between `x` and `mesh`, - the dimension `i` of `x` is split across the dimension `dims_mapping[i]`, where -1 means - without parition along the corresponding dimension. + x (Tensor): the tensor to be sharded. + dist_attr (dict): the tensor distributed attributes. The accepted attributes are as follow: + "process_mesh": a nested list an to describe the mesh topology of logical processes. + "dims_mapping": a list to describe the mapping between `x` and `process_mesh`, the dimension + `i` of `x` is split across the dimension `dims_mapping[i]` of `process_mesh`, + where -1 means that tensor dimension is not split. + Both process_mesh and dims_mapping are optional and users can specify as need. Returns: - Tensor: the tensor `x` itself. + Tensor: the tensor `x` annotated with distributed attributes. Examples: .. code-block:: python @@ -314,87 +55,36 @@ def shard_tensor(x, mesh, dim_mapping): paddle.enable_static() - mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]]) - x = paddle.ones([4, 6]) - dist.shard_tensor(x, mesh, [0, -1]) - - """ - _static_mode_check() - _dim_mapping_checker(x, mesh, dim_mapping) - attr_name = _append_attr_suffix('mesh_id') - x._set_attr(attr_name, mesh._id) - attr_name = _append_attr_suffix('dim_mapping') - x._set_attr(attr_name, dim_mapping) - return x - - -def set_shard_mask(x, mask): - """ - Set the mask for a tensor which mask out the tensor from some processes in its mesh. - - Args: - x (Tensor): the tensor to process. - mask (list): a nested list. The shape of `mask` must be the same as the ProcessMesh belonging to - the tensor `x`. 
Every value of `mask` must be one or zero, where one means - the tenor `x` will be put on the corresponding logical process and zero means the tensor `x` - will not be put on the corresponding logical process. - For example, for a ProcessMesh represented by the 2-dimensional - array [[2, 4, 5], [0, 1, 3]], and a `mask` given by the - 2-dimensional [[1, 0, 1], [0, 1, 0]], - then the tensor `x` will only be put on logical processes 2, 5 and 1. - - Returns: - Tensor: the tensor `x` itself. - - Examples: - .. code-block:: python - - import paddle - import paddle.distributed as dist - - paddle.enable_static() - - mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]]) - mask = [[1, 0, 1], [0, 1, 0]] x = paddle.ones([4, 6]) - dist.shard_tensor(x, mesh, [-1, 1]) - dist.set_shard_mask(x, mask) + dist.shard_tensor(x, dist_attr={"process_mesh": [[0, 1], [2, 3]], + "dims_mapping": [0, -1]}) """ _static_mode_check() - assert isinstance(mask, list) - np_mask = numpy.array(mask) - min_ele = numpy.min(np_mask) - max_ele = numpy.max(np_mask) - mesh_attr_name = _append_attr_suffix('mesh_id') - assert x._has_attr(mesh_attr_name), \ - "Please set process mesh for the variable firstly." - assert min_ele >= 0 and max_ele <= 1, "Elements in mask must be 0 or 1." - x_mesh = x.process_mesh - assert x_mesh, "Please set process mesh for the variable firstly." - assert x_mesh.topology == list(np_mask.shape), ( - "The shape of mask " - "must be the same as the shape of its Process Mesh.") - attr_name = _append_attr_suffix('mask') - x._set_attr(attr_name, _flatten_nested_list(mask)) + assert dist_attr is None or isinstance(dist_attr, (dict, TensorDistributedAttribute)), \ + "The type of dist_attr must be None, dict or TensorDistributedAttribute." + dist_tensor = DistributedTensor(x, dist_attr) + dist_tensor.dist_attr.mark_annotated_as(dist_attr) + default_dist_ctx = get_default_distributed_context() + default_dist_ctx.add_dist_tensor_for_program(dist_tensor) return x -def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs): +def shard_op(op_fn, dist_attr=None): """ Call a functioin and add distributed attributes for ops added by the function. Args: - op_fn (callable): a callable object of an API. - mesh (ProcessMesh): an instance of ProcessMesh specifies the topology of logical processes. - dim_mapping_dict (dict): a mapping from tensor's name to its dims_mapping. - The dim_mapping is a list to describe the mapping between a tensor and `mesh`, - the dimension `i` of the tensor is split across the dimension `dim_mapping[i]`, - where -1 means without parition along the corresponding dimension. - kwargs (dict): a dict of parameter passed to the function `op_fn`. + op_fn (callable): a callable operator or module to be sharded. + dist_attr (dict): the operator distributed attributes. The accepted attributes are classified into + two categories. The first category decsribes the distributed attributes shared by all inputs and + outputs, and only `process_mesh` can be specified now. The second category describes distributed + attributes for inputs or outputs same as the `dist_attr` of `shard_tensor`. All of them are + optional and users can specify them as need. Note that `process_mesh` for operators must be the + same as these process_meshes for inputs and outputs. Returns: - list: the outputs of the function `op_fn`. + list: the outputs of the function `op_fn`, which are annotated with distributed attributes. Examples: .. 
code-block:: python @@ -404,100 +94,19 @@ def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs): paddle.enable_static() - mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]]) x = paddle.ones([4, 6]) y = paddle.zeros([4, 6]) - kwargs = {'x': x, 'y': y} - dist.shard_op(paddle.add, mesh, None, **kwargs) - - """ - _static_mode_check() - main_prog = paddle.fluid.default_main_program() - main_block = main_prog.global_block() - op_size = len(main_block.ops) - output = op_fn(**kwargs) - new_op_size = len(main_block.ops) - if dim_mapping_dict is None: - dim_mapping_dict = dict() - else: - assert isinstance(dim_mapping_dict, - dict), 'The type of dim_mapping_dict must be dict.' - for var_name in dim_mapping_dict.keys(): - dim_mapping = dim_mapping_dict[var_name] - tensor = main_block.var(var_name) - _dim_mapping_checker(tensor, mesh, dim_mapping) - for idx in range(op_size, new_op_size): - op = main_block.ops[idx] - attr_name = _append_attr_suffix('mesh_id') - op._set_attr(attr_name, mesh._id) - for var_name in dim_mapping_dict.keys(): - assert var_name in op.output_arg_names + op.input_arg_names - attr_name = _append_attr_suffix(var_name) - if var_name in op.input_arg_names: - # we use the prefix "IN_" to indicates an input argument name - attr_name = "IN_" + attr_name - else: - # we use the prefix "OUT_" to indicates an input argument name - attr_name = "OUT_" + attr_name - op._set_attr(attr_name, dim_mapping_dict[var_name]) - - if isinstance(output, Variable): - output = [output] - return list(output) - - -def set_offload_device(x, device): - """ - Set the device that the tensor `x` will be put on. - - Args: - x (tensor): the tensor to process. - device (str): the device that the tensor `x` will be put on, e.g., 'cpu'. - - Returns: - Tensor: the tensor `x` itself. - - Examples: - .. code-block:: python - - import paddle - import paddle.distributed as dist - - paddle.enable_static() - - x = paddle.ones([4, 6]) - dist.set_offload_device(x, 'cpu') - - """ - _static_mode_check() - assert device == "cpu", "Only 'cpu' is supported for destination device." - attr_name = _append_attr_suffix("offload_device") - x._set_attr(attr_name, device) - return x - - -def set_pipeline_stage(stage): - """ - Set the pipeline stage of the following ops. - - Args: - stage (int): the pipeline stage the following ops belonging to. - - Returns: - None. - - Examples: - .. code-block:: python - - import paddle - import paddle.distributed as dist - - paddle.enable_static() - - dist.set_pipeline_stage(0) + dist_add = dist.shard_op(paddle.add, + dist_attr={ + "process_mesh": [[2, 3, 1], [0, 4, 5]], + x: {"dims_mapping": [-1, 0]}, + y: {"dims_mapping": [0, -1]} + }) + dist_add(x, y) """ - from paddle.fluid.framework import _set_pipeline_stage _static_mode_check() - assert isinstance(stage, int), 'The type of stage must be int.' - _set_pipeline_stage(stage) + assert dist_attr is None or isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \ + "The type of dist_attr must be dict or OperatorDistributedAttribute." 
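# NOTE (illustrative sketch, not part of this patch): a minimal end-to-end use
# of the dist_attr-style interface documented above. The mesh layout and the
# dims_mapping values are placeholders taken from the docstring examples, not
# prescriptive settings.
#
#     import paddle
#     import paddle.distributed as dist
#
#     paddle.enable_static()
#
#     x = paddle.ones([4, 6])
#     y = paddle.zeros([4, 6])
#
#     # Split dim 0 of x across mesh axis 0; -1 leaves a dimension unsplit.
#     dist.shard_tensor(x, dist_attr={"process_mesh": [[0, 1], [2, 3]],
#                                     "dims_mapping": [0, -1]})
#
#     # shard_op wraps the callable and returns a DistributedModule; calling
#     # it adds the ops and annotates them with the given distributed
#     # attributes, keyed per input/output tensor.
#     dist_add = dist.shard_op(paddle.add,
#                              dist_attr={
#                                  "process_mesh": [[0, 1], [2, 3]],
#                                  x: {"dims_mapping": [0, -1]},
#                                  y: {"dims_mapping": [0, -1]}
#                              })
#     dist_add(x, y)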
+ dist_module = DistributedModule(op_fn, dist_attr) + return dist_module diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py index 3b3359b4ebf..d0ddeb1dcc7 100644 --- a/python/paddle/distributed/auto_parallel/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License -from .common import DistributedOperator +from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl -from .common import register_distributed_operator +from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl from .common import find_best_compatible_distributed_operator_impl from . import dist_embedding diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index 5685c40a322..c23de81b591 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License -DISTRIBUTED_OPERATORS = {} +_g_distributed_operator_impl_registries = {} -class DistributedOperator: +class DistributedOperatorImplContainer: def __init__(self): self._impls = [] self._name = None @@ -47,67 +47,60 @@ class DistributedOperatorImpl: def get_name(self): return self._name - def is_process_mesh_compatible(self, op_dist_attr): + def is_input_compatible(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") - def is_input_compatible(self, op_dist_attr): + def is_output_compatible(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") - def is_output_compatible(self, op_dist_attr): - raise NotImplementedError("Please Implement this method in Subclass.") - - def is_compatible(self, op_dist_attr): - return self.is_process_mesh_compatible(op_dist_attr) \ - and self.is_input_compatible(op_dist_attr) \ - and self.is_output_compatible(op_dist_attr) + def is_compatible(self, dist_op): + return self.is_input_compatible(dist_op) and \ + self.is_output_compatible(dist_op) - def update_dims_mapping(self, op_dist_attr): + def update_dims_mapping(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") -def register_distributed_operator(name, dist_op): - global DISTRIBUTED_OPERATORS - DISTRIBUTED_OPERATORS[name] = dist_op +def register_distributed_operator_impl_container(name, dist_op_impl_container): + global _g_distributed_operator_impl_registries + _g_distributed_operator_impl_registries[name] = dist_op_impl_container -def get_distributed_operator(name): - global DISTRIBUTED_OPERATORS - return DISTRIBUTED_OPERATORS.get(name, None) +def get_distributed_operator_impl_container(name): + global _g_distributed_operator_impl_registries + return _g_distributed_operator_impl_registries.get(name, None) def register_distributed_operator_impl(name, dist_impl): - dist_op = get_distributed_operator(name) - if dist_op is not None: - dist_op.register_impl(dist_impl) + dist_op_impl_container = get_distributed_operator_impl_container(name) + if dist_op_impl_container is not None: + dist_op_impl_container.register_impl(dist_impl) else: - assert False, "Must register distributed operator first." 
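# NOTE (illustrative sketch, not part of this patch): how an operator plugs
# into the renamed registry. "my_op" and the class names below are
# hypothetical stand-ins that mirror the container/impl pattern used by
# dist_embedding and dist_matmul later in this patch.
#
#     class DistributedMyOp(DistributedOperatorImplContainer):
#         def __init__(self, name):
#             super(DistributedMyOp, self).__init__()
#             self._name = name
#
#
#     class DistributedMyOpImpl0(DistributedOperatorImpl):
#         def __init__(self, name):
#             super(DistributedMyOpImpl0, self).__init__()
#             self._name = name
#
#         def is_input_compatible(self, dist_op):
#             # dist_op bundles the serial op (dist_op.serial_op) and its
#             # distributed attributes (dist_op.dist_attr).
#             return True
#
#         def is_output_compatible(self, dist_op):
#             return True
#
#         def update_dims_mapping(self, dist_op):
#             # Return True only if some dims_mapping was actually changed.
#             return False
#
#
#     register_distributed_operator_impl_container("my_op",
#                                                  DistributedMyOp("my_op"))
#     register_distributed_operator_impl("my_op",
#                                        DistributedMyOpImpl0("replicate"))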
+ assert False, "Must register distributed operator registry first." def get_distributed_operator_impl(name, impl_idx): - global DISTRIBUTED_OPERATORS - return DISTRIBUTED_OPERATORS[name].get_impl(impl_idx) + global _g_distributed_operator_impl_registries + return _g_distributed_operator_impl_registries[name].get_impl(impl_idx) -def find_best_compatible_distributed_operator_impl(name, op_dist_attr, - fwd=True): +def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True): """ Here just return the first compatible implemention. This will be improved by cost model in the future. """ - dist_op = get_distributed_operator(name) - if dist_op is None: + dist_op_impl_container = get_distributed_operator_impl_container(name) + if dist_op_impl_container is None: return None, -1 compatible_impls = [] - impls = dist_op.get_impls() + impls = dist_op_impl_container.get_impls() if fwd: for idx, impl in enumerate(impls): - if impl.is_process_mesh_compatible(op_dist_attr) \ - and impl.is_input_compatible(op_dist_attr): + if impl.is_input_compatible(dist_op): compatible_impls.append((impl, idx)) else: for idx, impl in enumerate(impls): - if impl.is_process_mesh_compatible(op_dist_attr) \ - and impl.is_output_compatible(op_dist_attr): + if impl.is_output_compatible(dist_op): compatible_impls.append((impl, idx)) if compatible_impls: @@ -118,48 +111,84 @@ def find_best_compatible_distributed_operator_impl(name, op_dist_attr, return best_compatible_impl, idx -def copy_distributed_attr_for_var(src_op_dist_attr, var, src_var): - """ - copy src var's dist_attr to dst var - """ - import copy +# def copy_distributed_attr_for_var(src_op_dist_attr, dst_var, src_var): +# """ +# copy src var's dist_attr to dst var +# """ +# import copy - auto_paralle_context = src_op_dist_attr.get_owner_context() - dist_attr = copy.deepcopy( - auto_paralle_context.get_tensor_distributed_attr_for_program(src_var)) - dist_attr._owner_tensor = var - dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program( - src_var)._owner_context - auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr) +# auto_paralle_context = src_op_dist_attr.get_owner_context() +# dist_attr = copy.deepcopy( +# auto_paralle_context.get_tensor_distributed_attr_for_program(src_var)) +# dist_attr._owner_tensor = var +# dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program( +# src_var)._owner_context +# auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr) -def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr): +def copy_distributed_attr_for_var(dist_context, dst_var, src_var): + """ + copy src var's dist_attr to dst var + """ + dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var) + dist_context.set_tensor_dist_attr_for_program(dst_var, dist_attr) + + +# def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr): +# """ +# copy src op's dist_attr to dst dist op +# """ +# from ..attribute import OperatorDistributedAttribute + +# auto_paralle_context = src_op_dist_attr.get_owner_context() +# op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context) +# auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc, +# op_dist_attr) +# auto_paralle_context.set_op_distributed_attr_for_program(dist_op, +# op_dist_attr) + +# op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh()) +# op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx()) + +# for input_varname in 
dist_op.desc.input_arg_names(): +# input_var = dst_block.var(input_varname) +# tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program( +# input_var) +# tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() +# op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping) + +# for output_varname in dist_op.desc.output_arg_names(): +# output_var = dst_block.var(output_varname) +# tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program( +# output_var) +# tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() +# op_dist_attr.set_output_dims_mapping(output_varname, +# tensor_dims_mapping) + + +def copy_distributed_attr_for_dist_op(dist_context, dist_op, dst_block, + src_op_dist_attr): """ copy src op's dist_attr to dst dist op """ - from ..attribute import OperatorDistributedAttribute + from ..dist_attribute import OperatorDistributedAttribute + # need check dist op attr and its inputs and outputs - auto_paralle_context = src_op_dist_attr.get_owner_context() - op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context) - auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc, - op_dist_attr) - auto_paralle_context.set_op_distributed_attr_for_program(dist_op, - op_dist_attr) - - op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh()) - op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx()) + op_dist_attr = OperatorDistributedAttribute() + op_dist_attr.process_mesh = src_op_dist_attr.process_mesh + op_dist_attr.impl_idx = src_op_dist_attr.impl_idx for input_varname in dist_op.desc.input_arg_names(): input_var = dst_block.var(input_varname) - tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program( + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( input_var) - tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() - op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping) + op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr) for output_varname in dist_op.desc.output_arg_names(): output_var = dst_block.var(output_varname) - tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program( + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( output_var) - tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() - op_dist_attr.set_output_dims_mapping(output_varname, - tensor_dims_mapping) + op_dist_attr.set_output_dist_attr(output_varname, tensor_dist_attr) + + dist_context.set_op_dist_attr_for_program(dist_op, op_dist_attr) + op_dist_attr = dist_context.get_op_dist_attr_for_program(dist_op) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py index cf17b7afb0f..05af1b402b4 100755 --- a/python/paddle/distributed/auto_parallel/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License -from .common import DistributedOperator +from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl -from .common import register_distributed_operator +from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl from ..utils import is_dim_shard from ..utils import is_dim_replicate @@ -22,26 +22,27 @@ from ..utils import is_valid_list_index from ..utils import 
compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping -from ..attribute import OperatorDistributedAttribute +from ..dist_attribute import OperatorDistributedAttribute from paddle.fluid import core, unique_name from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY -from ..process import new_process_group +from ..process_group import new_process_group from ..utils import _get_comm_group, _get_corresponding_rank -class DistributedDefault(DistributedOperator): +class DistributedDefault(DistributedOperatorImplContainer): def __init__(self, name): super(DistributedDefault, self).__init__() self._name = name -register_distributed_operator("default", DistributedDefault("default")) +register_distributed_operator_impl_container("default", + DistributedDefault("default")) -# Replicated Default +# Replicated Default class DistributedDefaultImpl0(DistributedOperatorImpl): def __init__(self, name): super(DistributedDefaultImpl0, self).__init__() @@ -49,29 +50,26 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = True - def is_process_mesh_compatible(self, op_dist_attr): + def is_input_compatible(self, dist_op): raise NotImplementedError("Please Implement this method.") - def is_input_compatible(self, op_dist_attr): + def is_output_compatible(self, dist_op): raise NotImplementedError("Please Implement this method.") - def is_output_compatible(self, op_dist_attr): - raise NotImplementedError("Please Implement this method.") - - def update_dims_mapping(self, op_dist_attr): + def update_dims_mapping(self, dist_op): raise NotImplementedError("Please Implement this method.") @staticmethod def forward(ctx, *args, **kwargs): - dist_op_helper = ctx.get_dist_op_helper() - main_block = dist_op_helper.get_dst_main_program().global_block() - startup_block = dist_op_helper.get_dst_startup_program().global_block() - src_op = dist_op_helper.get_cur_src_op() - varname_mapping = dist_op_helper.get_varname_mapping() - rank_id = dist_op_helper.get_rank_id() + dist_op_context = ctx.dist_op_context + main_block = dist_op_context.get_dst_main_program().global_block() + startup_block = dist_op_context.get_dst_startup_program().global_block() + src_op = dist_op_context.get_cur_src_op() + varname_mapping = dist_op_context.get_varname_mapping() + rank_id = dist_op_context.get_rank_id() - # check validation of inputs / outputs + # check validation of inputs / outputs for input_name in src_op.desc.input_names(): assert input_name in kwargs, "input [{}] is not given".format( input_name) @@ -100,26 +98,26 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): for varname in dist_op_desc.input_arg_names(): if startup_block.has_var(varname) and startup_block.var( varname - ).is_parameter and varname not in dist_op_helper.already_init_sync_vars: - dist_op_helper.already_init_sync_vars.add(varname) + ).is_parameter and varname not in dist_op_context.already_init_sync_vars: + dist_op_context.already_init_sync_vars.add(varname) param = startup_block.var(varname) - param_dist_attr = ctx.get_tensor_distributed_attr_for_program( - param) - process_mesh = param_dist_attr.get_process_mesh() - dims_mapping = 
param_dist_attr.get_dims_mapping() + param_dist_attr = ctx.get_tensor_dist_attr_for_program(param) + process_mesh = param_dist_attr.process_mesh + dims_mapping = param_dist_attr.dims_mapping # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in process_mesh.process_group: - rank_id = _get_corresponding_rank(process_mesh, rank_id) + if rank_id not in process_mesh.processes: + rank_id = _get_corresponding_rank(ctx, process_mesh, + rank_id) - # NOTE all not splited axis should be presented in mesh + # NOTE all not splited axis should be presented in mesh for axis, size in enumerate(process_mesh.topology): if size <= 1 or axis in dims_mapping: pass else: - group_ranks = _get_comm_group( - process_mesh.process_group, process_mesh.topology, - axis, rank_id) + group_ranks = _get_comm_group(process_mesh.processes, + process_mesh.topology, + axis, rank_id) sync_group = new_process_group(group_ranks) new_op = startup_block.append_op( @@ -134,12 +132,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): }) # set distributed attribute - op_attr = OperatorDistributedAttribute(new_op, ctx) - op_attr.set_process_mesh(process_mesh) + op_attr = OperatorDistributedAttribute() + op_attr.process_mesh = process_mesh op_attr.set_output_dims_mapping(param.name, dims_mapping) op_attr.set_input_dims_mapping(param.name, dims_mapping) - ctx.set_op_distributed_attr_for_program(new_op, op_attr) + ctx.set_op_dist_attr_for_program(new_op, op_attr) startup_block._sync_with_cpp() @@ -147,16 +145,16 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): def backward(ctx, *args, **kwargs): # by now the backward function only insert the gradient allreduce for dist op itself - dist_op_helper = ctx.get_dist_op_helper() - main_block = dist_op_helper.get_dst_main_program().global_block() - backward_op = dist_op_helper.get_cur_src_op() - dist_attr = ctx.get_op_distributed_attr_for_program(backward_op) + dist_op_context = ctx.dist_op_context + main_block = dist_op_context.get_dst_main_program().global_block() + backward_op = dist_op_context.get_cur_src_op() + dist_attr = ctx.get_op_dist_attr_for_program(backward_op) assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format( str(backward_op)) - rank_id = dist_op_helper.get_rank_id() + rank_id = dist_op_context.get_rank_id() # check if need gradient allreduce - # if there is a non-gradient & non-parameter input and its batch dimension is splited, + # if there is a non-gradient & non-parameter input and its batch dimension is splited, # we need insert gradient allreduce for the gradient of parameter in its output need_gradient_allreduce = False for input_name in backward_op.desc.input_names(): @@ -165,20 +163,21 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): varname).is_parameter: # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op - process_mesh = dist_attr.get_process_mesh() + process_mesh = dist_attr.process_mesh var_dim_mapping = dist_attr.get_input_dims_mapping(varname) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in process_mesh.process_group: - rank_id = _get_corresponding_rank(process_mesh, rank_id) + if rank_id not in process_mesh.processes: + rank_id = _get_corresponding_rank(ctx, process_mesh, + rank_id) mesh_shape = process_mesh.topology batch_size_axis = var_dim_mapping[0] if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: 
need_gradient_allreduce = True - group_ranks = _get_comm_group( - process_mesh.process_group, process_mesh.topology, - batch_size_axis, rank_id) + group_ranks = _get_comm_group(process_mesh.processes, + process_mesh.topology, + batch_size_axis, rank_id) dp_degree = len(group_ranks) dp_group = new_process_group(group_ranks) break @@ -228,17 +227,17 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): OP_ROLE_KEY: OpRole.Backward }) - dims_mapping = ctx.get_tensor_distributed_attr_for_program( - grad_var).get_dims_mapping() - process_mesh = dist_attr.get_process_mesh() + dims_mapping = ctx.get_tensor_dist_attr_for_program( + grad_var).dims_mapping + process_mesh = dist_attr.process_mesh for op in [allreduce_op, scale_op]: - op_attr = OperatorDistributedAttribute(op, ctx) - op_attr.set_process_mesh(process_mesh) + op_attr = OperatorDistributedAttribute() + op_attr.process_mesh = process_mesh op_attr.set_output_dims_mapping(grad_var.name, dims_mapping) op_attr.set_input_dims_mapping(grad_var.name, dims_mapping) - ctx.set_op_distributed_attr_for_program(op, op_attr) + ctx.set_op_dist_attr_for_program(op, op_attr) main_block._sync_with_cpp() diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index cd6d2255c81..0099d6a09c4 100755 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License -from .common import DistributedOperator +from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl -from .common import register_distributed_operator +from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl from .common import copy_distributed_attr_for_var from .common import copy_distributed_attr_for_dist_op @@ -24,25 +24,26 @@ from ..utils import is_valid_list_index from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping -from ..attribute import OperatorDistributedAttribute +from ..dist_attribute import OperatorDistributedAttribute from paddle.fluid import core, unique_name from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY -from ..process import new_process_group +from ..process_group import new_process_group from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank -class DistributedEmbedding(DistributedOperator): +class DistributedEmbedding(DistributedOperatorImplContainer): def __init__(self, name): super(DistributedEmbedding, self).__init__() self._name = name -register_distributed_operator("lookup_table_v2", - DistributedEmbedding("embedding")) -register_distributed_operator("c_embedding", DistributedEmbedding("embedding")) +register_distributed_operator_impl_container("lookup_table_v2", + DistributedEmbedding("embedding")) +register_distributed_operator_impl_container("c_embedding", + DistributedEmbedding("embedding")) # RowParallel @@ -53,12 +54,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): 
self._forward_implemented = True self._backward_implemented = True - def is_process_mesh_compatible(self, op_dist_attr): - """ No restriction for now. """ - return True - - def is_input_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_input_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr ids_name = op_desc.input('Ids')[0] w_name = op_desc.input('W')[0] ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name) @@ -72,8 +70,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): return False return True - def is_output_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr out_name = op_desc.output('Out')[0] out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) # Other dimensions must be replicate except the batch dimension @@ -82,9 +81,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): return False return True - def update_dims_mapping(self, op_dist_attr): + def update_dims_mapping(self, dist_op): changed = False - op_desc = op_dist_attr.get_owner_op().desc + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr ids_name = op_desc.input('Ids')[0] w_name = op_desc.input('W')[0] out_name = op_desc.output('Out')[0] @@ -111,16 +111,16 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): kwargs: inputname_mapping & outputname_mapping """ - dist_op_helper = ctx.get_dist_op_helper() - main_block = dist_op_helper.get_dst_main_program().global_block() - startup_block = dist_op_helper.get_dst_startup_program().global_block() - src_op = dist_op_helper.get_cur_src_op() - rank_id = dist_op_helper.get_rank_id() - op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) + dist_op_context = ctx.dist_op_context + main_block = dist_op_context.get_dst_main_program().global_block() + startup_block = dist_op_context.get_dst_startup_program().global_block() + src_op = dist_op_context.get_cur_src_op() + rank_id = dist_op_context.get_rank_id() + op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( str(src_op)) - # check validation of inputs / outputs + # check validation of inputs / outputs assert 'Ids' in kwargs, "input [{}] is not given".format('Ids') assert 'W' in kwargs, "input [{}] is not given".format('W') assert 'Out' in kwargs, "output [{}] is not given".format('Out') @@ -147,12 +147,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): Weight_var.name)[0] assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format( embedding_row_dim_mapping) - process_mesh_shape = op_dist_attr.get_process_mesh().topology - process_mesh_group = op_dist_attr.get_process_mesh().process_group + process_mesh_shape = op_dist_attr.process_mesh.topology + process_mesh_group = op_dist_attr.process_mesh.processes # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in process_mesh_group: - rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(), + rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, rank_id) # A generalized method to caculate embedding offset using cartisian product @@ -162,7 +162,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): per_part_size = Weight_var.shape[0] relative_idx = 
relative_idx * per_part_size - # TODO caculate ring id + # TODO caculate ring id parallel_axis = embedding_row_dim_mapping group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, parallel_axis, rank_id) @@ -182,7 +182,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): stop_gradient=Out_var.stop_gradient) # copy Out_var's dist_attr to intermediate_var_0's dist_attr - copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, Out_var) + copy_distributed_attr_for_var(ctx, intermediate_var_0, Out_var) check_variable_and_dtype( Out_var, 'tensor', @@ -208,25 +208,25 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): }) # copy serial op's dist_attr to dist op's dist_attr - copy_distributed_attr_for_dist_op(c_embedding_op, main_block, + copy_distributed_attr_for_dist_op(ctx, c_embedding_op, main_block, op_dist_attr) - copy_distributed_attr_for_dist_op(c_allreduce_sum_op, main_block, + copy_distributed_attr_for_dist_op(ctx, c_allreduce_sum_op, main_block, op_dist_attr) # param initialization sync - assert Weight_var.name not in dist_op_helper.already_init_sync_vars - dist_op_helper.already_init_sync_vars.add(Weight_var.name) + assert Weight_var.name not in dist_op_context.already_init_sync_vars + dist_op_context.already_init_sync_vars.add(Weight_var.name) param = startup_block.var(Weight_var.name) - param_dist_attr = ctx.get_tensor_distributed_attr_for_program(param) - process_mesh = param_dist_attr.get_process_mesh() - dim_mapping = param_dist_attr.get_dims_mapping() + param_dist_attr = ctx.get_tensor_dist_attr_for_program(param) + process_mesh = param_dist_attr.process_mesh + dim_mapping = param_dist_attr.dims_mapping - # NOTE all not splited axis should be presented in mesh + # NOTE all not splited axis should be presented in mesh for axis, size in enumerate(process_mesh.topology): if size <= 1 or axis in dim_mapping: pass else: - group_ranks = _get_comm_group(process_mesh.process_group, + group_ranks = _get_comm_group(process_mesh.processes, process_mesh.topology, axis, rank_id) sync_group = new_process_group(group_ranks) @@ -247,17 +247,17 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): def backward(ctx, *args, **kwargs): # by now the backward function only insert the gradient allreduce for dist op itself - dist_op_helper = ctx.get_dist_op_helper() - main_block = dist_op_helper.get_dst_main_program().global_block() - backward_op = dist_op_helper.get_cur_src_op() - rank_id = dist_op_helper.get_rank_id() - dist_attr = ctx.get_op_distributed_attr_for_program(backward_op) + dist_op_context = ctx.dist_op_context + main_block = dist_op_context.get_dst_main_program().global_block() + backward_op = dist_op_context.get_cur_src_op() + rank_id = dist_op_context.get_rank_id() + dist_attr = ctx.get_op_dist_attr_for_program(backward_op) assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format( str(backward_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in dist_attr.get_process_mesh().process_group: - rank_id = _get_corresponding_rank(dist_attr.get_process_mesh(), + if rank_id not in dist_attr.process_mesh.processes: + rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id) # check if need gradient allreduce @@ -286,14 +286,14 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): kwargs['W@GRAD']) Ids_var = main_block.var(kwargs['Ids'][0]) - process_mesh = dist_attr.get_process_mesh() + process_mesh = dist_attr.process_mesh var_dim_mapping = 
dist_attr.get_input_dims_mapping(Ids_var.name) mesh_shape = process_mesh.topology batch_size_axis = var_dim_mapping[0] if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: need_gradient_allreduce = True - group_ranks = _get_comm_group(process_mesh.process_group, + group_ranks = _get_comm_group(process_mesh.processes, process_mesh.topology, batch_size_axis, rank_id) dp_degree = len(group_ranks) @@ -318,15 +318,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): OP_ROLE_KEY: OpRole.Backward}) main_block._sync_with_cpp() - dims_mapping = ctx.get_tensor_distributed_attr_for_program( - W_Grad_var).get_dims_mapping() - process_mesh = dist_attr.get_process_mesh() + dims_mapping = ctx.get_tensor_dist_attr_for_program( + W_Grad_var).dims_mapping + process_mesh = dist_attr.process_mesh for op in [allreduce_op, scale_op]: - op_attr = OperatorDistributedAttribute(op, ctx) - op_attr.set_process_mesh(process_mesh) + op_attr = OperatorDistributedAttribute() + op_attr.process_mesh = process_mesh op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping) op_attr.set_input_dims_mapping(W_Grad_var.name, dims_mapping) - ctx.set_op_distributed_attr_for_program(op, op_attr) + ctx.set_op_dist_attr_for_program(op, op_attr) register_distributed_operator_impl("lookup_table_v2", diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index 2edbcd2318c..43816ba88af 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License -from .common import DistributedOperator +from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl -from .common import register_distributed_operator +from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl from .common import copy_distributed_attr_for_var from .common import copy_distributed_attr_for_dist_op @@ -24,19 +24,20 @@ from ..utils import is_valid_list_index from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping -from ..attribute import OperatorDistributedAttribute +from ..dist_attribute import OperatorDistributedAttribute from paddle.fluid import core, unique_name from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY -from ..process import new_process_group +from ..process_group import new_process_group from ..utils import _get_comm_group, _get_corresponding_rank -def _update_dims_mapping_for_matmul(op_dist_attr): +def _update_dims_mapping_for_matmul(dist_op): changed = False - op_desc = op_dist_attr.get_owner_op().desc + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] out_name = op_desc.output('Out')[0] @@ -53,7 +54,7 @@ def _update_dims_mapping_for_matmul(op_dist_attr): if y_dims_mapping_len == 1: y_dims_mapping.insert(1, -1) - # Deal with dim > 2 and take care of broadcasting + # Deal with dim > 2 and take care of broadcasting if 
out_dims_mapping_len > 2: broadcast_x_dims_mapping = [] broadcast_y_dims_mapping = [] @@ -95,7 +96,7 @@ def _update_dims_mapping_for_matmul(op_dist_attr): out_dims_mapping[i] = compatible_dims_mapping[i] changed = True - # The following which uses negative index can be work + # The following which uses negative index can be work # when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2 dim_changed = compute_compatible_and_update_dim_mapping( [x_dims_mapping, y_dims_mapping], [-1, -2]) @@ -112,7 +113,7 @@ def _update_dims_mapping_for_matmul(op_dist_attr): if dim_changed: changed = True - # Remove unnecessary dim mapping to make sure the lenght of dims_mapping is same as its tensor + # Remove unnecessary dim mapping to make sure the length of dims_mapping is same as its tensor if x_dims_mapping_len == 1: x_dims_mapping.pop(0) if y_dims_mapping_len == 1: @@ -129,17 +130,17 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): # by now the backward function only insert the gradient allreduce for dist op itself - dist_op_helper = ctx.get_dist_op_helper() - main_block = dist_op_helper.get_dst_main_program().global_block() - backward_op = dist_op_helper.get_cur_src_op() - rank_id = dist_op_helper.get_rank_id() - dist_attr = ctx.get_op_distributed_attr_for_program(backward_op) + dist_op_context = ctx.dist_op_context + main_block = dist_op_context.get_dst_main_program().global_block() + backward_op = dist_op_context.get_cur_src_op() + rank_id = dist_op_context.get_rank_id() + dist_attr = ctx.get_op_dist_attr_for_program(backward_op) assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format( str(backward_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in dist_attr.get_process_mesh().process_group: - rank_id = _get_corresponding_rank(dist_attr.get_process_mesh(), rank_id) + if rank_id not in dist_attr.process_mesh.processes: + rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id) # check if need gradient allreduce need_gradient_allreduce = False @@ -175,13 +176,13 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): assert not X_var.is_parameter, "left operand(X) [{}] of dist matmul should not be parameter".format( X_var.name) - process_mesh = dist_attr.get_process_mesh() + process_mesh = dist_attr.process_mesh var_dim_mapping = dist_attr.get_input_dims_mapping(X_var.name) mesh_shape = process_mesh.topology batch_size_axis = var_dim_mapping[0] if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: need_gradient_allreduce = True - group_ranks = _get_comm_group(process_mesh.process_group, + group_ranks = _get_comm_group(process_mesh.processes, process_mesh.topology, batch_size_axis, rank_id) dp_degree = len(group_ranks) @@ -207,32 +208,32 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): OP_ROLE_KEY: OpRole.Backward}) main_block._sync_with_cpp() - dims_mapping = ctx.get_tensor_distributed_attr_for_program( - Y_Grad_var).get_dims_mapping() - process_mesh = dist_attr.get_process_mesh() + dims_mapping = ctx.get_tensor_dist_attr_for_program( + Y_Grad_var).dims_mapping + process_mesh = dist_attr.process_mesh for op in [allreduce_op, scale_op]: - op_attr = OperatorDistributedAttribute(op, ctx) - op_attr.set_process_mesh(process_mesh) + op_attr = OperatorDistributedAttribute() + op_attr.process_mesh = process_mesh op_attr.set_output_dims_mapping(Y_Grad_var.name, dims_mapping) op_attr.set_input_dims_mapping(Y_Grad_var.name, dims_mapping) 
- ctx.set_op_distributed_attr_for_program(op, op_attr) + ctx.set_op_dist_attr_for_program(op, op_attr) -def _init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, rank_id): +def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id): - assert Weight_var.name not in dist_op_helper.already_init_sync_vars + assert Weight_var.name not in dist_op_context.already_init_sync_vars assert startup_block.has_var(Weight_var.name) - dist_op_helper.already_init_sync_vars.add(Weight_var.name) + dist_op_context.already_init_sync_vars.add(Weight_var.name) param = startup_block.var(Weight_var.name) - param_dist_attr = ctx.get_tensor_distributed_attr_for_program(param) - process_mesh = param_dist_attr.get_process_mesh() - dim_mapping = param_dist_attr.get_dims_mapping() + param_dist_attr = ctx.get_tensor_dist_attr_for_program(param) + process_mesh = param_dist_attr.process_mesh + dim_mapping = param_dist_attr.dims_mapping for axis, size in enumerate(process_mesh.topology): if size <= 1 or axis in dim_mapping: pass else: - group_ranks = _get_comm_group(process_mesh.process_group, + group_ranks = _get_comm_group(process_mesh.processes, process_mesh.topology, axis, rank_id) sync_group = new_process_group(group_ranks) @@ -249,13 +250,14 @@ def _init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, rank_id): startup_block._sync_with_cpp() -class DistributedMatmul(DistributedOperator): +class DistributedMatmul(DistributedOperatorImplContainer): def __init__(self, name): super(DistributedMatmul, self).__init__() self._name = name -register_distributed_operator("matmul", DistributedMatmul("matmul")) +register_distributed_operator_impl_container("matmul", + DistributedMatmul("matmul")) # ColumnParallel @@ -266,12 +268,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = True - def is_process_mesh_compatible(self, op_dist_attr): - """ No restriction for now. 
""" - return True - - def is_input_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_input_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) @@ -286,8 +285,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): return False return True - def is_output_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr out_name = op_desc.output('Out')[0] out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) if is_dim_replicate(out_dims_mapping[-1]): @@ -297,9 +297,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): return False return True - def update_dims_mapping(self, op_dist_attr): + def update_dims_mapping(self, dist_op): changed = False - dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) + dim_changed = _update_dims_mapping_for_matmul(dist_op) if dim_changed: changed = True return changed @@ -310,21 +310,21 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): kwargs: inputname_mapping & outputname_mapping """ - dist_op_helper = ctx.get_dist_op_helper() - main_block = dist_op_helper.get_dst_main_program().global_block() - startup_block = dist_op_helper.get_dst_startup_program().global_block() - src_op = dist_op_helper.get_cur_src_op() - rank_id = dist_op_helper.get_rank_id() - op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) + dist_op_context = ctx.dist_op_context + main_block = dist_op_context.get_dst_main_program().global_block() + startup_block = dist_op_context.get_dst_startup_program().global_block() + src_op = dist_op_context.get_cur_src_op() + rank_id = dist_op_context.get_rank_id() + op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( str(src_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in op_dist_attr.get_process_mesh().process_group: - rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(), + if rank_id not in op_dist_attr.process_mesh.processes: + rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, rank_id) - # check validation of inputs / outputs + # check validation of inputs / outputs for input_name in src_op.desc.input_names(): assert input_name in kwargs, "input [{}] is not given".format( input_name) @@ -348,8 +348,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): Weight_var.name)[1] assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( matmul_col_dim_mapping) - process_mesh_shape = op_dist_attr.get_process_mesh().topology - process_mesh_group = op_dist_attr.get_process_mesh().process_group + process_mesh_shape = op_dist_attr.process_mesh.topology + process_mesh_group = op_dist_attr.process_mesh.processes parallel_axis = matmul_col_dim_mapping group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, @@ -365,7 +365,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): persistable=False, stop_gradient=X_var.stop_gradient) # copy X_var's dist_attr to intermediate_var_0's dist_attr - copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, X_var) + copy_distributed_attr_for_var(ctx, intermediate_var_0, X_var) 
check_variable_and_dtype( X_var, 'tensor', @@ -395,13 +395,14 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): type='matmul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs) # copy serial op's dist_attr to dist op's dist_attr - copy_distributed_attr_for_dist_op(c_identity_op, main_block, + copy_distributed_attr_for_dist_op(ctx, c_identity_op, main_block, + op_dist_attr) + copy_distributed_attr_for_dist_op(ctx, matmul_op, main_block, op_dist_attr) - copy_distributed_attr_for_dist_op(matmul_op, main_block, op_dist_attr) # init param sync if Weight_var.is_parameter: - _init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, + _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id) @staticmethod @@ -417,12 +418,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = True - def is_process_mesh_compatible(self, op_dist_attr): - """ No restriction for now. """ - return True - - def is_input_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_input_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) @@ -438,8 +436,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): return False return True - def is_output_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr out_name = op_desc.output('Out')[0] out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) if is_dim_shard(out_dims_mapping[-1]): @@ -450,9 +449,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): return False return True - def update_dims_mapping(self, op_dist_attr): + def update_dims_mapping(self, dist_op): changed = False - dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) + dim_changed = _update_dims_mapping_for_matmul(dist_op) if dim_changed: changed = True return changed @@ -463,21 +462,21 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): kwargs: inputname_mapping & outputname_mapping """ - dist_op_helper = ctx.get_dist_op_helper() - main_block = dist_op_helper.get_dst_main_program().global_block() - startup_block = dist_op_helper.get_dst_startup_program().global_block() - src_op = dist_op_helper.get_cur_src_op() - rank_id = dist_op_helper.get_rank_id() - op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) + dist_op_context = ctx.dist_op_context + main_block = dist_op_context.get_dst_main_program().global_block() + startup_block = dist_op_context.get_dst_startup_program().global_block() + src_op = dist_op_context.get_cur_src_op() + rank_id = dist_op_context.get_rank_id() + op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( str(src_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in op_dist_attr.get_process_mesh().process_group: - rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(), + if rank_id not in op_dist_attr.process_mesh.processes: + rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, rank_id) - # check validation of inputs / outputs + # check validation of inputs / outputs for input_name in src_op.desc.input_names(): assert input_name in kwargs, "input [{}] 
is not given".format( input_name) @@ -501,8 +500,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): Weight_var.name)[0] assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( matmul_row_dim_mapping) - process_mesh_shape = op_dist_attr.get_process_mesh().topology - process_mesh_group = op_dist_attr.get_process_mesh().process_group + process_mesh_shape = op_dist_attr.process_mesh.topology + process_mesh_group = op_dist_attr.process_mesh.processes parallel_axis = matmul_row_dim_mapping group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, @@ -528,7 +527,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): is_data=False, need_check_feed=Out_var.desc.need_check_feed()) # copy Out_var's dist_attr to intermediate_var_0's dist_attr - copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, Out_var) + copy_distributed_attr_for_var(ctx, intermediate_var_0, Out_var) matmul_op = main_block.append_op( type='matmul', @@ -547,13 +546,14 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): }) # copy serial op's dist_attr to dist op's dist_attr - copy_distributed_attr_for_dist_op(matmul_op, main_block, op_dist_attr) - copy_distributed_attr_for_dist_op(c_allreduce_sum_op, main_block, + copy_distributed_attr_for_dist_op(ctx, matmul_op, main_block, + op_dist_attr) + copy_distributed_attr_for_dist_op(ctx, c_allreduce_sum_op, main_block, op_dist_attr) # init param sync if Weight_var.is_parameter: - _init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, + _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id) @staticmethod @@ -561,18 +561,15 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): _right_operand_parameter_matmul_backward(ctx, *args, **kwargs) -# ReplicateParallel +# ReplicateParallel class DistributedMatmulImpl2(DistributedOperatorImpl): def __init__(self, name): super(DistributedMatmulImpl2, self).__init__() self._name = name - def is_process_mesh_compatible(self, op_dist_attr): - """ No restriction for now. 
""" - return True - - def is_input_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_input_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) @@ -592,8 +589,9 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): return True - def is_output_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr out_name = op_desc.output('Out')[0] out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) @@ -605,9 +603,9 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): return True - def update_dims_mapping(self, op_dist_attr): + def update_dims_mapping(self, dist_op): changed = False - dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) + dim_changed = _update_dims_mapping_for_matmul(dist_op) if dim_changed: changed = True return changed @@ -625,13 +623,14 @@ register_distributed_operator_impl("matmul", DistributedMatmulImpl2("replicate_parallel")) -class DistributedMatmulV2(DistributedOperator): +class DistributedMatmulV2(DistributedOperatorImplContainer): def __init__(self, name): super(DistributedMatmulV2, self).__init__() self._name = name -register_distributed_operator("matmul_v2", DistributedMatmulV2("matmul_v2")) +register_distributed_operator_impl_container("matmul_v2", + DistributedMatmulV2("matmul_v2")) # ColumnParallel @@ -642,12 +641,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = True - def is_process_mesh_compatible(self, op_dist_attr): - """ No restriction for now. 
""" - return True - - def is_input_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_input_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) @@ -662,8 +658,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): return False return True - def is_output_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr out_name = op_desc.output('Out')[0] out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) if is_dim_replicate(out_dims_mapping[-1]): @@ -673,9 +670,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): return False return True - def update_dims_mapping(self, op_dist_attr): + def update_dims_mapping(self, dist_op): changed = False - dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) + dim_changed = _update_dims_mapping_for_matmul(dist_op) if dim_changed: changed = True return changed @@ -686,21 +683,21 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): kwargs: inputname_mapping & outputname_mapping """ - dist_op_helper = ctx.get_dist_op_helper() - main_block = dist_op_helper.get_dst_main_program().global_block() - startup_block = dist_op_helper.get_dst_startup_program().global_block() - src_op = dist_op_helper.get_cur_src_op() - rank_id = dist_op_helper.get_rank_id() - op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) + dist_op_context = ctx.dist_op_context + main_block = dist_op_context.get_dst_main_program().global_block() + startup_block = dist_op_context.get_dst_startup_program().global_block() + src_op = dist_op_context.get_cur_src_op() + rank_id = dist_op_context.get_rank_id() + op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( str(src_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in op_dist_attr.get_process_mesh().process_group: - rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(), + if rank_id not in op_dist_attr.process_mesh.processes: + rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, rank_id) - # check validation of inputs / outputs + # check validation of inputs / outputs for input_name in src_op.desc.input_names(): assert input_name in kwargs, "input [{}] is not given".format( input_name) @@ -724,8 +721,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): Weight_var.name)[1] assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( matmul_col_dim_mapping) - process_mesh_shape = op_dist_attr.get_process_mesh().topology - process_mesh_group = op_dist_attr.get_process_mesh().process_group + process_mesh_shape = op_dist_attr.process_mesh.topology + process_mesh_group = op_dist_attr.process_mesh.processes parallel_axis = matmul_col_dim_mapping group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, @@ -741,7 +738,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): persistable=False, stop_gradient=X_var.stop_gradient) # copy X_var's dist_attr to intermediate_var_0's dist_attr - copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, X_var) + copy_distributed_attr_for_var(ctx, 
intermediate_var_0, X_var) check_variable_and_dtype( X_var, 'tensor', @@ -770,14 +767,14 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): attrs=attrs) # copy serial op's dist_attr to dist op's dist_attr - copy_distributed_attr_for_dist_op(c_identity_op, main_block, + copy_distributed_attr_for_dist_op(ctx, c_identity_op, main_block, op_dist_attr) - copy_distributed_attr_for_dist_op(matmul_v2_op, main_block, + copy_distributed_attr_for_dist_op(ctx, matmul_v2_op, main_block, op_dist_attr) # init param sync if Weight_var.is_parameter: - _init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, + _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id) @staticmethod @@ -793,12 +790,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = True - def is_process_mesh_compatible(self, op_dist_attr): - """ No restriction for now. """ - return True - - def is_input_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_input_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) @@ -814,8 +808,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): return False return True - def is_output_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr out_name = op_desc.output('Out')[0] out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) if is_dim_shard(out_dims_mapping[-1]): @@ -826,9 +821,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): return False return True - def update_dims_mapping(self, op_dist_attr): + def update_dims_mapping(self, dist_op): changed = False - dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) + dim_changed = _update_dims_mapping_for_matmul(dist_op) if dim_changed: changed = True return changed @@ -839,21 +834,21 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): kwargs: inputname_mapping & outputname_mapping """ - dist_op_helper = ctx.get_dist_op_helper() - main_block = dist_op_helper.get_dst_main_program().global_block() - startup_block = dist_op_helper.get_dst_startup_program().global_block() - src_op = dist_op_helper.get_cur_src_op() - rank_id = dist_op_helper.get_rank_id() - op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) + dist_op_context = ctx.dist_op_context + main_block = dist_op_context.get_dst_main_program().global_block() + startup_block = dist_op_context.get_dst_startup_program().global_block() + src_op = dist_op_context.get_cur_src_op() + rank_id = dist_op_context.get_rank_id() + op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( str(src_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in op_dist_attr.get_process_mesh().process_group: - rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(), + if rank_id not in op_dist_attr.process_mesh.processes: + rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh, rank_id) - # check validation of inputs / outputs + # check validation of inputs / outputs for input_name in src_op.desc.input_names(): assert input_name in kwargs, "input [{}] is not given".format( 
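Editor's note: in the column-parallel `matmul_v2` above, the weight is sharded along its columns, each rank's local matmul yields a column slice of the output, and the `c_identity` in front keeps the input-side gradient allreduce correct. A small numpy illustration of why the sharded forward matches the serial result (plain numpy, no Paddle ops):

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((4, 8))      # replicated input activation
    W = rng.standard_normal((8, 6))      # full weight
    W0, W1 = np.split(W, 2, axis=1)      # per-rank column shards of the weight

    Y0, Y1 = X @ W0, X @ W1              # each rank's local matmul_v2
    assert np.allclose(np.concatenate([Y0, Y1], axis=1), X @ W)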
input_name) @@ -877,8 +872,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): Weight_var.name)[0] assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( matmul_row_dim_mapping) - process_mesh_shape = op_dist_attr.get_process_mesh().topology - process_mesh_group = op_dist_attr.get_process_mesh().process_group + process_mesh_shape = op_dist_attr.process_mesh.topology + process_mesh_group = op_dist_attr.process_mesh.processes parallel_axis = matmul_row_dim_mapping group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, @@ -900,7 +895,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): is_data=False, need_check_feed=Out_var.desc.need_check_feed()) # copy Out_var's dist_attr to intermediate_var_0's dist_attr - copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, Out_var) + copy_distributed_attr_for_var(ctx, intermediate_var_0, Out_var) matmul_v2_op = main_block.append_op( type='matmul_v2', @@ -919,14 +914,14 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): }) # copy serial op's dist_attr to dist op's dist_attr - copy_distributed_attr_for_dist_op(matmul_v2_op, main_block, + copy_distributed_attr_for_dist_op(ctx, matmul_v2_op, main_block, op_dist_attr) - copy_distributed_attr_for_dist_op(c_allreduce_sum_op, main_block, + copy_distributed_attr_for_dist_op(ctx, c_allreduce_sum_op, main_block, op_dist_attr) # init param sync if Weight_var.is_parameter: - _init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, + _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id) @staticmethod @@ -934,18 +929,15 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): _right_operand_parameter_matmul_backward(ctx, *args, **kwargs) -# ReplicateParallel +# ReplicateParallel class DistributedMatmulV2Impl2(DistributedOperatorImpl): def __init__(self, name): super(DistributedMatmulV2Impl2, self).__init__() self._name = name - def is_process_mesh_compatible(self, op_dist_attr): - """ No restriction for now. 
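Editor's note: the row-parallel `matmul_v2` shards the weight along its rows and the activation along its last dimension, so each rank produces a partial result and the inserted `c_allreduce_sum` adds them up. The same identity in miniature with numpy (illustrative only):

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(1)
    X = rng.standard_normal((4, 8))
    W = rng.standard_normal((8, 6))
    X0, X1 = np.split(X, 2, axis=1)      # activation sharded on its last dim
    W0, W1 = np.split(W, 2, axis=0)      # weight sharded along its rows

    partial0, partial1 = X0 @ W0, X1 @ W1           # local matmuls
    assert np.allclose(partial0 + partial1, X @ W)  # allreduce_sum recovers X @ W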
""" - return True - - def is_input_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_input_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] y_name = op_desc.input('Y')[0] x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) @@ -965,8 +957,11 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): return True - def is_output_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr out_name = op_desc.output('Out')[0] out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) @@ -978,9 +973,9 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): return True - def update_dims_mapping(self, op_dist_attr): + def update_dims_mapping(self, dist_op): changed = False - dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) + dim_changed = _update_dims_mapping_for_matmul(dist_op) if dim_changed: changed = True return changed diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py index 39e97850b86..8821f3bc657 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License -from .common import DistributedOperator +from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl -from .common import register_distributed_operator +from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl from ..utils import is_dim_shard from ..utils import is_dim_replicate @@ -28,13 +28,14 @@ from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype -class DistributedReshape2(DistributedOperator): +class DistributedReshape2(DistributedOperatorImplContainer): def __init__(self, name): super(DistributedReshape2, self).__init__() self._name = name -register_distributed_operator("reshape2", DistributedReshape2("reshape2")) +register_distributed_operator_impl_container("reshape2", + DistributedReshape2("reshape2")) class DistributedReshapeImpl0(DistributedOperatorImpl): @@ -44,12 +45,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = True - def is_process_mesh_compatible(self, op_dist_attr): - """ No restriction for now. 
""" - return True - - def is_input_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_input_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] out_name = op_desc.output('Out')[0] x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) @@ -60,8 +58,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): return True - def is_output_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] out_name = op_desc.output('Out')[0] x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) @@ -75,9 +74,10 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): return True - def update_dims_mapping(self, op_dist_attr): + def update_dims_mapping(self, dist_op): changed = False - op_desc = op_dist_attr.get_owner_op().desc + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] out_name = op_desc.output('Out')[0] x_shape_name = op_desc.output('XShape')[0] @@ -103,15 +103,15 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): kwargs: inputname_mapping & outputname_mapping """ - dist_op_helper = ctx.get_dist_op_helper() - main_block = dist_op_helper.get_dst_main_program().global_block() - src_op = dist_op_helper.get_cur_src_op() - rank_id = dist_op_helper.get_rank_id() - op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) + dist_op_context = ctx.dist_op_context + main_block = dist_op_context.get_dst_main_program().global_block() + src_op = dist_op_context.get_cur_src_op() + rank_id = dist_op_context.get_rank_id() + op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( str(src_op)) - # check validation of inputs / outputs + # check validation of inputs / outputs for input_name in src_op.desc.input_names(): assert input_name in kwargs, "input [{}] is not given".format( input_name) @@ -139,7 +139,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): # got dist attribute info dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) - process_mesh_shape = op_dist_attr.get_process_mesh().topology + process_mesh_shape = op_dist_attr.process_mesh.topology # modify target shape for idx, axis in enumerate(dim_mapping): @@ -172,12 +172,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): self._forward_implemented = True self._backward_implemented = True - def is_process_mesh_compatible(self, op_dist_attr): - """ No restriction for now. 
""" - return True - - def is_input_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_input_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] out_name = op_desc.output('Out')[0] x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) @@ -191,8 +188,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): return True - def is_output_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] out_name = op_desc.output('Out')[0] x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) @@ -203,9 +201,10 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): return True - def update_dims_mapping(self, op_dist_attr): + def update_dims_mapping(self, dist_op): changed = False - op_desc = op_dist_attr.get_owner_op().desc + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] out_name = op_desc.output('Out')[0] x_shape_name = op_desc.output('XShape')[0] @@ -231,15 +230,15 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): kwargs: inputname_mapping & outputname_mapping """ - dist_op_helper = ctx.get_dist_op_helper() - main_block = dist_op_helper.get_dst_main_program().global_block() - src_op = dist_op_helper.get_cur_src_op() - rank_id = dist_op_helper.get_rank_id() - op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) + dist_op_context = ctx.dist_op_context + main_block = dist_op_context.get_dst_main_program().global_block() + src_op = dist_op_context.get_cur_src_op() + rank_id = dist_op_context.get_rank_id() + op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( str(src_op)) - # check validation of inputs / outputs + # check validation of inputs / outputs for input_name in src_op.desc.input_names(): assert input_name in kwargs, "input [{}] is not given".format( input_name) @@ -267,7 +266,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): # got dist attribute info dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) - process_mesh_shape = op_dist_attr.get_process_mesh().topology + process_mesh_shape = op_dist_attr.process_mesh.topology # modify target shape for idx, axis in enumerate(dim_mapping): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py index 56be75b3bea..c90fc7da89d 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License -from .common import DistributedOperator +from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl -from .common import register_distributed_operator +from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl from ..utils import is_dim_shard from ..utils import is_dim_replicate @@ -24,13 +24,14 @@ from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping -class DistributedSoftmax(DistributedOperator): +class 
DistributedSoftmax(DistributedOperatorImplContainer): def __init__(self, name): super(DistributedSoftmax, self).__init__() self._name = name -register_distributed_operator("softmax", DistributedSoftmax("softmax")) +register_distributed_operator_impl_container("softmax", + DistributedSoftmax("softmax")) class DistributedSoftmaxImpl(DistributedOperatorImpl): @@ -40,12 +41,9 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): self._forward_implemented = False self._backward_implemented = True - def is_process_mesh_compatible(self, op_dist_attr): - """ No restriction for now. """ - return True - - def is_input_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_input_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] axis = op_desc.attr('axis') x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) @@ -58,8 +56,9 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): return True - def is_output_compatible(self, op_dist_attr): - op_desc = op_dist_attr.get_owner_op().desc + def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr out_name = op_desc.output('Out')[0] axis = op_desc.attr('axis') out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) @@ -72,9 +71,10 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): return True - def update_dims_mapping(self, op_dist_attr): + def update_dims_mapping(self, dist_op): changed = False - op_desc = op_dist_attr.get_owner_op().desc + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] out_name = op_desc.output('Out')[0] x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py index 10b8bf2666f..0bfc7d9f4ca 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License -from .common import DistributedOperator +from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl -from .common import register_distributed_operator +from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl from ..utils import is_dim_shard from ..utils import is_dim_replicate @@ -24,13 +24,14 @@ from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping -class DistributedTranspose2(DistributedOperator): +class DistributedTranspose2(DistributedOperatorImplContainer): def __init__(self, name): super(DistributedTranspose2, self).__init__() self._name = name -register_distributed_operator("transpose2", DistributedTranspose2("transpose2")) +register_distributed_operator_impl_container( + "transpose2", DistributedTranspose2("transpose2")) class DistributedTranspose2Impl(DistributedOperatorImpl): @@ -40,19 +41,16 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): self._forward_implemented = False self._backward_implemented = True - def is_process_mesh_compatible(self, op_dist_attr): - """ No restriction for now. 
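Editor's note: `DistributedSoftmaxImpl` only matches when softmax runs along the last dimension and that dimension is replicated, so every rank can compute it locally. A rough sketch of that rule (simplified and illustrative, not the exact Paddle predicate):

.. code-block:: python

    def softmax_locally_computable(axis, dims_mapping):
        # The softmax axis must be the last dim, and that dim must be
        # replicated (-1); otherwise cross-rank communication would be needed.
        ndim = len(dims_mapping)
        if axis != -1 and axis != ndim - 1:
            return False
        return dims_mapping[-1] == -1

    assert softmax_locally_computable(-1, [0, -1, -1]) is True
    assert softmax_locally_computable(-1, [0, -1, 1]) is False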
""" + def is_input_compatible(self, dist_op): return True - def is_input_compatible(self, op_dist_attr): + def is_output_compatible(self, dist_op): return True - def is_output_compatible(self, op_dist_attr): - return True - - def update_dims_mapping(self, op_dist_attr): + def update_dims_mapping(self, dist_op): changed = False - op_desc = op_dist_attr.get_owner_op().desc + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] out_name = op_desc.output('Out')[0] x_shape_name = op_desc.output('XShape')[0] diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index 8f4a4866eb8..3f26f4f5b87 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -15,11 +15,11 @@ import paddle from paddle.distributed.fleet import cloud_utils import paddle.fluid.core as core -from .context import DistributedContext -from .context import get_default_distributed_context +from .dist_context import DistributedContext +from .dist_context import get_default_distributed_context from .completion import complete_annotation, complete_backward_annotation from .partitioner import Partitioner -from .process import get_all_process_groups +from .process_group import get_all_process_groups from .utils import make_data_unshard from .reshard import reshard @@ -70,7 +70,6 @@ class AutoParallelizer: # Annotation completion completed_main_program = complete_annotation( self._original_main_program, self._dist_context) - # Logical partition rank = paddle.distributed.get_rank() partitioner = Partitioner(self._dist_strategy, self._dist_context, rank) diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index c0a91f4b53a..9af194e810f 100755 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -22,15 +22,15 @@ from paddle.fluid import core, unique_name from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.backward import append_backward, _some_in_set_, _append_grad_suffix_ -from paddle.distributed.auto_parallel.operators.common import get_distributed_operator +from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container from paddle.fluid.clip import GradientClipBase, GradientClipByNorm, error_clip_callback, append_gradient_clip_ops, ClipGradByGlobalNorm from paddle.distributed.fleet.base.distributed_strategy import DistributedStrategy -from paddle.distributed.auto_parallel.context import DistributedContext, DistOpHelper +from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY -from .process import new_process_group -from .interface import _g_process_mesh_map -from .attribute import OperatorDistributedAttribute +from .dist_attribute import OperatorDistributedAttribute +from .process_group import new_process_group +from .utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.completion import complete_backward_annotation, complete_update_annotation __varname_not_in_block__ = 
["lod_tensor_blocking_queue_0"] @@ -68,14 +68,14 @@ class Partitioner(object): # auto completion auto.ProcessMesh(shape=[2, 4], process_group=[0, 1, 2, 3, 4, 5, 6, 7]) annotated_main_program = auto.complete_annotation(serial_main_program) - auto_paralle_context = get_default_distributed_context() + dist_context = get_default_distributed_context() # distributed strategy & rank info rank_id = paddle.distributed.get_rank() dist_strategy = fleet.DistributedStrategy() # create partitioner - Partitioner = Partitioner(dist_strategy, auto_paralle_context, rank_id) + Partitioner = Partitioner(dist_strategy, dist_context, rank_id) # create dist program with forward only # for distributed inference, using partitioned_main_prog from here @@ -93,11 +93,11 @@ class Partitioner(object): opt_ops = Partitioner.apply_optimize(optimizer, dist_params_grads, partitioned_main_prog, partitioned_startup_prog) """ - def __init__(self, dist_strategy, auto_parallel_context, rank_id=0): + def __init__(self, dist_strategy, dist_context, rank_id=0): """ Args: dist_strategy (paddle.fleet.distributed_strategy): used to determine the user defined distributed strategy. - auto_parallel_context (paddle.fluid.DistributedContext): used to access the distributed_attr of var & op, every Partitioner object could maintain its own DistributedContext member, and partition program base on that shard scenario. + dist_context (paddle.fluid.DistributedContext): used to access the distributed_attr of var & op, every Partitioner object could maintain its own DistributedContext member, and partition program base on that shard scenario. rank_id (int): global rank id to which the partitioned distributed program belong. """ @@ -106,13 +106,13 @@ class Partitioner(object): "dist_strategy be paddle.fleet.base.DistributedStrategy, got %s here" % type(dist_strategy)) - if not isinstance(auto_parallel_context, DistributedContext): + if not isinstance(dist_context, DistributedContext): raise TypeError( - "auto_parallel_context be paddle.fluid.DistributedContext, got %s here" - % type(auto_parallel_context)) + "dist_context be paddle.fluid.DistributedContext, got %s here" % + type(dist_context)) self._dist_strategy = dist_strategy - self._auto_parallel_context = auto_parallel_context + self._dist_context = dist_context self._rank_id = rank_id self._serial2dist_varname_mapping = {} self._dist_varname_suffix = "" @@ -218,8 +218,8 @@ class Partitioner(object): if not isinstance(startup_program, (Program)): raise TypeError( - "auto_parallel_context be paddle.fluid.framework.program, got %s here" - % type(startup_program)) + "dist_context be paddle.fluid.framework.program, got %s here" % + type(startup_program)) # check if shard annotated serial program valid if not self._is_valid_annotated_program(main_program): @@ -310,13 +310,12 @@ class Partitioner(object): if isinstance(var, Parameter): # TODO if var not belong to this rank, should be filtered serial_main_var = serial_main_block.var(var.name) - dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program( + dist_attr = self._dist_context.get_tensor_dist_attr_for_program( serial_main_var) target_shape = _get_dist_shape(serial_main_var, dist_attr) new_name = var.name + self._dist_varname_suffix temp_varname_map[var.name] = new_name - _partition_parameter(self._auto_parallel_context, - serial_main_var, + _partition_parameter(self._dist_context, serial_main_var, partitioned_startup_global_block, new_name, target_shape) param2shape[new_name] = target_shape @@ -346,24 +345,22 @@ class 
Partitioner(object): assert new_op.desc == new_op_desc output_var = partitioned_startup_global_block.var(output_vars[ 0]) - output_var_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program( + output_var_attr = self._dist_context.get_tensor_dist_attr_for_program( output_var) - op_attr = OperatorDistributedAttribute( - new_op, self._auto_parallel_context) - op_attr.set_process_mesh(output_var_attr.get_process_mesh()) - op_attr.set_output_dims_mapping( - output_var.name, output_var_attr.get_dims_mapping()) - op_attr.set_input_dims_mapping( - output_var.name, output_var_attr.get_dims_mapping()) - self._auto_parallel_context.set_op_distributed_attr_for_program( - new_op, op_attr) + op_attr = OperatorDistributedAttribute() + op_attr.process_mesh = output_var_attr.process_mesh + op_attr.set_output_dims_mapping(output_var.name, + output_var_attr.dims_mapping) + op_attr.set_input_dims_mapping(output_var.name, + output_var_attr.dims_mapping) + self._dist_context.set_op_dist_attr_for_program(new_op, op_attr) # TODO move helper init to a comm place - dist_op_helper = self._auto_parallel_context.get_dist_op_helper() - dist_op_helper.set_dst_main_program(partitioned_main_prog) - dist_op_helper.set_dst_startup_program(partitioned_startup_prog) - dist_op_helper.set_varname_mapping(self._serial2dist_varname_mapping) - dist_op_helper.set_rank_id(self._rank_id) + dist_op_context = self._dist_context.dist_op_context + dist_op_context.set_dst_main_program(partitioned_main_prog) + dist_op_context.set_dst_startup_program(partitioned_startup_prog) + dist_op_context.set_varname_mapping(self._serial2dist_varname_mapping) + dist_op_context.set_rank_id(self._rank_id) # transpile main program for op in serial_ops: @@ -373,8 +370,7 @@ class Partitioner(object): if serial_input_varname not in self._serial2dist_varname_mapping: new_varname = serial_input_varname + self._dist_varname_suffix if serial_main_block.has_var(serial_input_varname): - _partition_var(self._auto_parallel_context, - serial_main_block, + _partition_var(self._dist_context, serial_main_block, partitioned_global_block, serial_input_varname, new_varname) else: @@ -387,28 +383,25 @@ class Partitioner(object): for serial_output_varname in op.desc.output_arg_names(): if serial_output_varname not in self._serial2dist_varname_mapping: new_varname = serial_output_varname + self._dist_varname_suffix - _partition_var(self._auto_parallel_context, - serial_main_block, partitioned_global_block, + _partition_var(self._dist_context, serial_main_block, + partitioned_global_block, serial_output_varname, new_varname) self._serial2dist_varname_mapping[ serial_output_varname] = new_varname # partition op - kinputs, koutputs = dist_op_helper.prepare_forward_context(op) - dist_attr = self._auto_parallel_context.get_op_distributed_attr_for_program( - op) - if _is_dist_op_forward_implement(self._auto_parallel_context, op): - dist_ops = get_distributed_operator(op.type) - dist_op_impl = dist_ops.get_impl(dist_attr.get_impl_idx()) - dist_op_impl.forward(self._auto_parallel_context, **kinputs, - **koutputs) + kinputs, koutputs = dist_op_context.prepare_forward_context(op) + dist_attr = self._dist_context.get_op_dist_attr_for_program(op) + if _is_dist_op_forward_implement(self._dist_context, op): + dist_ops = get_distributed_operator_impl_container(op.type) + dist_op_impl = dist_ops.get_impl(dist_attr.impl_idx) + dist_op_impl.forward(self._dist_context, **kinputs, **koutputs) else: # replicate op - dist_ops = get_distributed_operator("default") + dist_ops = 
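Editor's note: the forward transpilation above is a lookup-and-fallback dispatch: each serial op type is looked up among the registered distributed operator containers, and anything without a specialized forward goes through the "default" (replicate) implementation. A toy registry showing the shape of that dispatch (all names are stand-ins, not Paddle objects):

.. code-block:: python

    _registry = {}

    def register(op_type, impl):
        _registry[op_type] = impl

    def partition_forward(op_type, **op_inputs):
        # Fall back to the default (replicate) implementation when no
        # specialized container is registered for this op type.
        impl = _registry.get(op_type, _registry["default"])
        return impl(**op_inputs)

    register("default", lambda **kw: ("replicate", kw))
    register("matmul_v2", lambda **kw: ("column_parallel", kw))

    assert partition_forward("matmul_v2", X="x0", Y="w0")[0] == "column_parallel"
    assert partition_forward("reshape2", X="x0")[0] == "replicate"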
get_distributed_operator_impl_container("default") dist_op_impl = dist_ops.get_impl(0) - dist_op_impl.forward(self._auto_parallel_context, **kinputs, - **koutputs) + dist_op_impl.forward(self._dist_context, **kinputs, **koutputs) return partitioned_main_prog, partitioned_startup_prog @@ -453,18 +446,18 @@ class Partitioner(object): for param in no_grad_set ] - dist_op_helper = self._auto_parallel_context.get_dist_op_helper() + dist_op_context = self._dist_context.dist_op_context params_and_grads = _auto_backward( dist_loss, dist_startup_program, parameter_list=parameter_list, no_grad_set=no_grad_set, callbacks=callbacks, - distop_context=dist_op_helper) + distop_context=dist_op_context) # backward completion complete_backward_annotation( - dist_main_program, dist_context=self._auto_parallel_context) + dist_main_program, dist_context=self._dist_context) # transpiler backward for dist op # get backward ops @@ -485,31 +478,33 @@ class Partitioner(object): backward_ops = ops[first_backward_op_idx:] for backward_op in backward_ops: # if the backward op has a corresponding forward op - if backward_op.desc.id() in dist_op_helper.gradopidx2opidx: - forward_op_id = dist_op_helper.gradopidx2opidx[ + if backward_op.desc.id() in dist_op_context.gradopidx2opidx: + forward_op_id = dist_op_context.gradopidx2opidx[ backward_op.desc.id()] forward_op = forward_op_id2forward_op[forward_op_id] # TODO backward attr should has _impl_idx - forward_op_dist_attr = self._auto_parallel_context.get_op_distributed_attr_for_program( + forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( forward_op) # TODO use the backward op itself to find the dist op - dist_ops = get_distributed_operator(forward_op.type) - kinputs, koutputs = dist_op_helper.prepare_backward_context( + dist_ops = get_distributed_operator_impl_container( + forward_op.type) + kinputs, koutputs = dist_op_context.prepare_backward_context( backward_op) # TODO use backward op itself to determine impl idx - if _is_dist_op_backward_implement( - self._auto_parallel_context, forward_op): + if _is_dist_op_backward_implement(self._dist_context, + forward_op): dist_op_impl = dist_ops.get_impl( - forward_op_dist_attr.get_impl_idx()) - dist_op_impl.backward(self._auto_parallel_context, - **kinputs, **koutputs) + forward_op_dist_attr.impl_idx) + dist_op_impl.backward(self._dist_context, **kinputs, + **koutputs) else: # replicate op - dist_ops = get_distributed_operator("default") + dist_ops = get_distributed_operator_impl_container( + "default") dist_op_impl = dist_ops.get_impl(0) - dist_op_impl.backward(self._auto_parallel_context, - **kinputs, **koutputs) + dist_op_impl.backward(self._dist_context, **kinputs, + **koutputs) return params_and_grads # replace dist grad ops @@ -524,7 +519,7 @@ class Partitioner(object): # update completion complete_update_annotation( - main_program, dist_context=self._auto_parallel_context) + main_program, dist_context=self._dist_context) return optimize_ops @@ -534,12 +529,11 @@ class Partitioner(object): ops = program.global_block().ops vars_ = program.list_vars() op_dist_attrs = [ - self._auto_parallel_context.get_op_distributed_attr_for_program(op) - for op in ops + self._dist_context.get_op_dist_attr_for_program(op) for op in ops ] var_dist_attrs = [ - self._auto_parallel_context.get_tensor_distributed_attr_for_program( - var) for var in vars_ + self._dist_context.get_tensor_dist_attr_for_program(var) + for var in vars_ ] all_ops_annotated = all(dist_attr is not None @@ -563,8 +557,7 @@ class Partitioner(object): 
def _is_var_distributed(self, var): - dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program( - var) + dist_attr = self._dist_context.get_tensor_dist_attr_for_program(var) assert dist_attr is not None, "dist_attr of var [{}] is None".format( var.name) return _is_distributed(dist_attr) @@ -637,20 +630,20 @@ def _get_no_grad_set(loss, no_grad_set=None): return no_grad_set -def _is_dist_op_forward_implement(auto_paralle_context, op): - dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(op) - dist_ops = get_distributed_operator(op.type) +def _is_dist_op_forward_implement(dist_context, op): + dist_attr = dist_context.get_op_dist_attr_for_program(op) + dist_ops = get_distributed_operator_impl_container(op.type) - return dist_ops and dist_attr.get_impl_idx() >= 0 and dist_ops.get_impl( \ - dist_attr.get_impl_idx())._forward_implemented + return dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl( \ + dist_attr.impl_idx)._forward_implemented -def _is_dist_op_backward_implement(auto_paralle_context, op): - dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(op) - dist_ops = get_distributed_operator(op.type) +def _is_dist_op_backward_implement(dist_context, op): + dist_attr = dist_context.get_op_dist_attr_for_program(op) + dist_ops = get_distributed_operator_impl_container(op.type) - return dist_ops and dist_attr.get_impl_idx() >= 0 and dist_ops.get_impl( \ - dist_attr.get_impl_idx())._backward_implemented + return dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl( \ + dist_attr.impl_idx)._backward_implemented def _auto_backward(loss, @@ -690,8 +683,8 @@ def _auto_backward(loss, def _is_distributed(dist_attr): - mapping = dist_attr.get_dims_mapping() - mesh = dist_attr.get_process_mesh().topology + mapping = dist_attr.dims_mapping + mesh = dist_attr.process_mesh.topology for idx in range(len(mapping)): if mapping[idx] >= 0 and mesh[mapping[idx]] > 1: return True @@ -702,8 +695,8 @@ def _is_distributed(dist_attr): def _get_dist_shape(var, dist_attr): var_shape = var.shape - mapping = dist_attr.get_dims_mapping() - mesh = dist_attr.get_process_mesh().topology + mapping = dist_attr.dims_mapping + mesh = dist_attr.process_mesh.topology assert len(var_shape) == len( mapping ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format( @@ -721,7 +714,7 @@ def _get_dist_shape(var, dist_attr): return new_shape -def _partition_parameter(auto_paralle_context, src_var, dst_block, dst_varname, +def _partition_parameter(dist_context, src_var, dst_block, dst_varname, dst_shape): # NOTE hack to copied Parameter # not initialized parameter, need to initialize it @@ -749,17 +742,13 @@ def _partition_parameter(auto_paralle_context, src_var, dst_block, dst_varname, # distributed_attr_uid = src_var.desc.get_distributed_attr_uid() # param.desc.set_distributed_attr_uid(distributed_attr_uid) dist_attr = copy.deepcopy( - auto_paralle_context.get_tensor_distributed_attr_for_program(src_var)) + dist_context.get_tensor_dist_attr_for_program(src_var)) assert dist_attr is not None - dist_attr._owner_tensor = param - dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program( - src_var)._owner_context - auto_paralle_context.set_tensor_distributed_attr_for_program(param, - dist_attr) + dist_context.set_tensor_dist_attr_for_program(param, dist_attr) -def _partition_intermediate_var(auto_paralle_context, src_var, dst_block, - dst_varname, dst_shape): +def _partition_intermediate_var(dist_context, src_var, dst_block, 
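Editor's note: `_get_dist_shape` shrinks every dimension that is mapped to a mesh axis by that axis' size. A standalone worked example of the same computation (illustrative only):

.. code-block:: python

    def dist_shape(var_shape, dims_mapping, mesh_topology):
        # Each dim mapped to a mesh axis is divided by that axis' size and
        # must divide evenly; -1 means the dim stays replicated.
        assert len(var_shape) == len(dims_mapping)
        local = []
        for size, axis in zip(var_shape, dims_mapping):
            if axis == -1:
                local.append(size)
            else:
                assert size % mesh_topology[axis] == 0
                local.append(size // mesh_topology[axis])
        return local

    # a [1024, 1024] parameter column-sharded over axis 1 of a [2, 4] mesh
    assert dist_shape([1024, 1024], [-1, 1], [2, 4]) == [1024, 256]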
dst_varname, + dst_shape): var = dst_block.create_var( type=src_var.type, name=dst_varname, @@ -776,15 +765,12 @@ def _partition_intermediate_var(auto_paralle_context, src_var, dst_block, # distributed_attr_uid = src_var.desc.get_distributed_attr_uid() # var.desc.set_distributed_attr_uid(distributed_attr_uid) dist_attr = copy.deepcopy( - auto_paralle_context.get_tensor_distributed_attr_for_program(src_var)) + dist_context.get_tensor_dist_attr_for_program(src_var)) assert dist_attr is not None - dist_attr._owner_tensor = var - dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program( - src_var)._owner_context - auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr) + dist_context.set_tensor_dist_attr_for_program(var, dist_attr) -def _partition_var(auto_paralle_context, src_block, dst_block, src_varname, +def _partition_var(dist_context, src_block, dst_block, src_varname, dst_varname): """ partition include: split + replicate @@ -798,16 +784,15 @@ def _partition_var(auto_paralle_context, src_block, dst_block, src_varname, persistable=True, stop_gradient=True) else: - dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program( - src_var) + dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var) target_shape = _get_dist_shape(src_var, dist_attr) if isinstance(src_var, Parameter): - _partition_parameter(auto_paralle_context, src_var, dst_block, - dst_varname, target_shape) + _partition_parameter(dist_context, src_var, dst_block, dst_varname, + target_shape) else: - _partition_intermediate_var(auto_paralle_context, src_var, - dst_block, dst_varname, target_shape) + _partition_intermediate_var(dist_context, src_var, dst_block, + dst_varname, target_shape) def _insert_src_op(src_op, dst_block, varname_mapping): @@ -822,8 +807,7 @@ def _insert_src_op(src_op, dst_block, varname_mapping): dst_block._sync_with_cpp() -def _insert_dist_op(src_op, dst_block, varname_mapping, auto_paralle_context, - rank_id): +def _insert_dist_op(src_op, dst_block, varname_mapping, dist_context, rank_id): # build input varname mapping input_mapping = {} @@ -842,10 +826,9 @@ def _insert_dist_op(src_op, dst_block, varname_mapping, auto_paralle_context, output_mapping[output_name] = varnames # append dist op - dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(src_op) - dist_ops = get_distributed_operator(src_op.type) - append_op_handle = dist_ops.get_impl(dist_attr.get_impl_idx()).forward( - src_op) + dist_attr = dist_context.get_op_dist_attr_for_program(src_op) + dist_ops = get_distributed_operator_impl_container(src_op.type) + append_op_handle = dist_ops.get_impl(dist_attr.impl_idx).forward(src_op) append_op_handle( dst_block, src_op, diff --git a/python/paddle/distributed/auto_parallel/process.py b/python/paddle/distributed/auto_parallel/process_group.py similarity index 76% rename from python/paddle/distributed/auto_parallel/process.py rename to python/paddle/distributed/auto_parallel/process_group.py index b919645b96c..8bbe6f69155 100644 --- a/python/paddle/distributed/auto_parallel/process.py +++ b/python/paddle/distributed/auto_parallel/process_group.py @@ -19,62 +19,32 @@ from ..collective import _new_ring_id from ...fluid.framework import in_dygraph_mode from ...fluid.layers.tensor import fill_constant -LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = None -PROCESSOR_TO_PHYSICAL_PROCESS_MAP = None - - -def get_all_logical_process_set(): - from .interface import _g_process_mesh_map - all_logical_process_set = 
set(_g_process_mesh_map[0].process_group) - return all_logical_process_set - - -def get_logical_process_to_physical_process_map(): - global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP - return LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP - - -def set_logical_process_to_physical_process_map(mapping): - global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP - LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = mapping - - -def get_processor_to_physical_process_map(): - global PROCESSOR_TO_PHYSICAL_PROCESS_MAP - return PROCESSOR_TO_PHYSICAL_PROCESS_MAP - - -def set_processor_to_physical_process_map(mapping): - global PROCESSOR_TO_PHYSICAL_PROCESS_MAP - PROCESSOR_TO_PHYSICAL_PROCESS_MAP = mapping - - -PROCESS_GROUP_MAP = {} +_g_process_group_map = {} def get_all_process_groups(): - global PROCESS_GROUP_MAP - return PROCESS_GROUP_MAP.values() + global _g_process_group_map + return _g_process_group_map.values() def new_process_group(ranks): - global PROCESS_GROUP_MAP - if not PROCESS_GROUP_MAP: + global _g_process_group_map + if not _g_process_group_map: genv = _get_global_env() - PROCESS_GROUP_MAP["global_group"] = ProcessGroup( + _g_process_group_map["global_group"] = ProcessGroup( 0, list(range(genv.world_size))) # A key constructed from ranks is used in the global process group map key = ''.join(map(str, sorted(ranks))) - if key not in PROCESS_GROUP_MAP: - num_groups = len(PROCESS_GROUP_MAP) + if key not in _g_process_group_map: + num_groups = len(_g_process_group_map) # Note: our process group may interfere with the original implementation # so the created group id should start from the original _new_ring_id() group_id = _new_ring_id() + num_groups + 1 pg = ProcessGroup(group_id, ranks) - PROCESS_GROUP_MAP[key] = pg + _g_process_group_map[key] = pg return pg else: - pg = PROCESS_GROUP_MAP[key] + pg = _g_process_group_map[key] return pg diff --git a/python/paddle/distributed/auto_parallel/process_mesh.py b/python/paddle/distributed/auto_parallel/process_mesh.py new file mode 100644 index 00000000000..ecdd77f7ea7 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/process_mesh.py @@ -0,0 +1,135 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy +import copy + + +def _get_nested_list_shape(nested_list): + """ + Get the shape of a nested_list. + """ + result = [] + while isinstance(nested_list, list): + result.append(len(nested_list)) + nested_list = nested_list[0] + return result + + +def _flatten_nested_list(nested_list): + """ + Get a list of all items in a nested_list. + Ref: https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-a-list-of-lists + """ + result = numpy.array(nested_list).flatten().tolist() + return result + + +class ProcessMesh(object): + r""" + The class `Processmesh` describes the topology of logical processes. + A mesh is an N-dimensional array. 
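Editor's note: `new_process_group` caches groups in `_g_process_group_map` keyed by the sorted rank list, so repeated requests for the same ranks reuse one group. A dict-based toy version of that cache; note that joining the ranks without a separator, as the patch does, can collide for multi-digit ranks (e.g. [1, 23] vs. [12, 3]), so this sketch uses an underscore:

.. code-block:: python

    import itertools

    _group_map = {}
    _group_ids = itertools.count(1)

    def new_group(ranks):
        # Key the cache by the sorted ranks so the same set of ranks always
        # reuses one group; the underscore keeps multi-digit keys unambiguous.
        key = '_'.join(map(str, sorted(ranks)))
        if key not in _group_map:
            _group_map[key] = (next(_group_ids), tuple(sorted(ranks)))
        return _group_map[key]

    assert new_group([3, 1, 2]) is new_group([1, 2, 3])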
The shape of the N-dimensional
+    array represents the topology of logical processes and every
+    element of the N-dimensional array represents a logical process. For
+    example, the 2-dimensional array [[2, 4, 5], [0, 1, 3]]
+    illustrates six logical processes organized as the topology [2, 3],
+    i.e., the shape of the 2-dimensional array. With the above topology,
+    there are two parallel groups, where the first parallel group has a
+    parallel degree of 2 and the second one has a parallel degree of 3.
+    And the first logical process is the one with id=2.
+
+    Args:
+        mesh (list): an N-dimensional array (nested list) that describes the topology
+            of logical processes. The shape of the N-dimensional array
+            represents the topology of logical processes and every
+            element of the N-dimensional array represents a logical process.
+
+    Returns:
+        None
+
+    Raises:
+        ValueError: If `mesh` is not an instance of list.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.distributed as dist
+
+            paddle.enable_static()
+
+            mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
+            assert mesh.topology == [2, 3]
+            assert mesh.processes == [2, 4, 5, 0, 1, 3]
+
+    """
+
+    def __init__(self, mesh):
+        if mesh is None or not isinstance(mesh, list):
+            raise ValueError('mesh must be an instance of list.')
+
+        processes = _flatten_nested_list(mesh)
+
+        assert all(isinstance(p, int) for p in processes), \
+            ("All elements of mesh must be integer")
+
+        assert min(processes) >= 0, ('All elements of mesh must be >= 0.')
+
+        unique_processes = set(processes)
+        assert len(unique_processes) == len(processes), (
+            'All elements of mesh must be unique.')
+
+        self._topology = _get_nested_list_shape(mesh)
+        self._processes = processes
+
+        from .dist_context import get_default_distributed_context
+        default_dist_cxt = get_default_distributed_context()
+        default_dist_cxt.add_process_mesh(self)
+
+    @property
+    def topology(self):
+        r"""
+        Get the topology of logical processes belonging to this ProcessMesh.
+        This is the shape of `mesh` used to initialize this ProcessMesh.
+        """
+        return self._topology
+
+    @property
+    def processes(self):
+        r"""
+        Get a list of all processes belonging to this ProcessMesh.
+        """
+        return self._processes
+
+    @property
+    def ndim(self):
+        r"""
+        Get the number of dimensions of ProcessMesh.
+ """ + return len(self._topology) + + def __eq__(self, other): + if not isinstance(other, ProcessMesh): + return False + if self.topology != other.topology or self.processes != other.processes: + return False + return True + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + str = "shape {} and process group {}".format(self.topology, + self.processes) + return str diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index 2d54bf8a788..fb130e9deef 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -22,9 +22,9 @@ from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.framework import Program, OpProtoHolder import paddle.fluid.layers.utils as utils from ..collective import _get_global_env -from .context import DistributedContext -from .attribute import OperatorDistributedAttribute, TensorDistributedAttribute -from .process import new_process_group, ProcessGroup, PROCESS_GROUP_MAP +from .dist_context import DistributedContext +from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute +from .process_group import new_process_group, ProcessGroup, _g_process_group_map class AllGatherOpDesc: @@ -276,20 +276,22 @@ def _is_overlapped(shape_x, shape_y): return overlapped -def _need_reshard(tensor_dist_attr, op_dist_attr): +def _need_reshard(dist_tensor, dist_op): """Judge the tensor whether needs to be resharded.""" is_reshard = False - tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() - tensor_process_mesh = tensor_dist_attr.get_process_mesh() - op_input_dims_mapping = op_dist_attr.get_input_dims_mapping( - tensor_dist_attr.get_owner_tensor().name) - op_process_mesh = op_dist_attr.get_process_mesh() + tensor_dist_attr = dist_tensor.dist_attr + tensor_name = dist_tensor.serial_tensor.name + tensor_dims_mapping = tensor_dist_attr.dims_mapping + tensor_process_mesh = tensor_dist_attr.process_mesh + op_dist_attr = dist_op.dist_attr + op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name) + op_process_mesh = op_dist_attr.process_mesh if all( map(lambda x: x is not None, [ tensor_dims_mapping, tensor_process_mesh, op_input_dims_mapping, op_process_mesh ])): - if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh._id != op_process_mesh._id: + if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh != op_process_mesh: is_reshard = True return is_reshard @@ -305,28 +307,30 @@ def _compute_complete_shape(slice_shape, process_shape, dims_mapping): return complete_shape -def find_op_desc_seq(source_tensor, tensor_dist_attr, op_dist_attr): +def find_op_desc_seq(dist_tensor, dist_op): """ Find the op description sequence to reshard the source tensor for matching the op requirement. Args: - source_tensor (Variable): A tensor with distributed attribute. - tensor_dist_attr (TensorDistributedAttribute): The distributed attribute of tensor. - op_dist_attr (OperatorDistributedAttribute): The distributed attribute of operator. + dist_tensor (DistributedTensor): A distributed tensor. + dist_op (DistributedOperator): A distributed operator. Returns: Dict, the dict represents the required op description sequence corresponding to process, The key of dict is process and value is a list containing op description. 
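Editor's note: `_need_reshard` above asks one question: does the producer's placement (dims_mapping plus process mesh) differ from the placement the consuming op expects for that input? A standalone sketch, with meshes compared by value as `ProcessMesh.__eq__` now allows (illustrative only):

.. code-block:: python

    def need_reshard(tensor_dims_mapping, tensor_mesh, op_dims_mapping, op_mesh):
        # Nothing to do if any placement information is missing; otherwise a
        # mismatch in either dims_mapping or process mesh triggers a reshard.
        if None in (tensor_dims_mapping, tensor_mesh, op_dims_mapping, op_mesh):
            return False
        return tensor_dims_mapping != op_dims_mapping or tensor_mesh != op_mesh

    # produced replicated, consumed sharded on dim 0 of the same mesh -> reshard
    assert need_reshard([-1, -1], [0, 1], [0, -1], [0, 1]) is True
    assert need_reshard([0, -1], [0, 1], [0, -1], [0, 1]) is False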
""" - source_dims_mapping = tensor_dist_attr.get_dims_mapping() - source_process_mesh = tensor_dist_attr.get_process_mesh() - source_process_group = source_process_mesh.process_group + tensor_dist_attr = dist_tensor.dist_attr + source_tensor = dist_tensor.serial_tensor + tensor_name = source_tensor.name + source_dims_mapping = tensor_dist_attr.dims_mapping + source_process_mesh = tensor_dist_attr.process_mesh + source_process_group = source_process_mesh.processes source_process_shape = source_process_mesh.topology - target_process_mesh = op_dist_attr.get_process_mesh() - target_dims_mapping = op_dist_attr.get_input_dims_mapping( - tensor_dist_attr.get_owner_tensor().name) - target_process_group = target_process_mesh.process_group + op_dist_attr = dist_op.dist_attr + target_process_mesh = op_dist_attr.process_mesh + target_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name) + target_process_group = target_process_mesh.processes target_process_shape = target_process_mesh.topology complete_shape = _compute_complete_shape( @@ -662,11 +666,11 @@ def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index, def _init_comm_for_send_recv(): - if not PROCESS_GROUP_MAP: + if not _g_process_group_map: genv = _get_global_env() - PROCESS_GROUP_MAP["global_group"] = ProcessGroup( + _g_process_group_map["global_group"] = ProcessGroup( 0, list(range(genv.world_size))) - PROCESS_GROUP_MAP["global_group"].instantiate() + _g_process_group_map["global_group"].instantiate() HAS_SENT = {} @@ -773,31 +777,29 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op, axes=op_desc.axes, new_var_name=new_name) - tensor_attr = TensorDistributedAttribute(target_tensor, - dist_context) - process_mesh = dist_context.get_op_distributed_attr_for_program( - matched_op).get_process_mesh() - dims_mapping = dist_context.get_op_distributed_attr_for_program( + tensor_attr = TensorDistributedAttribute() + process_mesh = dist_context.get_op_dist_attr_for_program( + matched_op).process_mesh + dims_mapping = dist_context.get_op_dist_attr_for_program( matched_op).get_input_dims_mapping(var_name) - tensor_attr.set_dims_mapping(dims_mapping) - tensor_attr.set_process_mesh(process_mesh) - dist_context.set_tensor_distributed_attr_for_program(target_tensor, - tensor_attr) + tensor_attr.dims_mapping = dims_mapping + tensor_attr.process_mesh = process_mesh + dist_context.set_tensor_dist_attr_for_program(target_tensor, + tensor_attr) # rename op input name according to new name for op in block.ops: for name in op.input_arg_names: - op_dist_attr = dist_context.get_op_distributed_attr_for_program( - op) + op_dist_attr = dist_context.get_op_dist_attr_for_program(op) if name == var_name and op_dist_attr is not None: - op_process_mesh = op_dist_attr.get_process_mesh() + op_process_mesh = op_dist_attr.process_mesh op_input_dims_mapping = op_dist_attr.get_input_dims_mapping( var_name) - if op_process_mesh._id == process_mesh._id and op_input_dims_mapping == dims_mapping: + if op_process_mesh == process_mesh and op_input_dims_mapping == dims_mapping: op.desc._rename_input(name, target_tensor.name) op_dist_attr.set_input_dims_mapping( target_tensor.name, dims_mapping) - op_dist_attr._dims_mapping.pop(name, None) + op_dist_attr.set_input_dist_attr(name, None) def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id): @@ -825,9 +827,9 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id): if op.type == "c_sync_comm_stream": need_save = [] for var_name in 
op.input_arg_names: - process_mesh = dist_context.get_tensor_distributed_attr_for_program( - vars[var_name]).get_process_mesh() - if rank_id in process_mesh.process_group: + process_mesh = dist_context.get_tensor_dist_attr_for_program( + vars[var_name]).process_mesh + if rank_id in process_mesh.processes: need_save.append(var_name) if not need_save: remove_op_idx.append(idx) @@ -839,10 +841,10 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id): continue # judge the other op whether should be removed. - op_dist_attr = dist_context.get_op_distributed_attr_for_program(op) + op_dist_attr = dist_context.get_op_dist_attr_for_program(op) if op_dist_attr is not None: - op_process_mesh = op_dist_attr.get_process_mesh() - if rank_id not in op_process_mesh.process_group and op.type not in not_remove_op_ref: + op_process_mesh = op_dist_attr.process_mesh + if rank_id not in op_process_mesh.processes and op.type not in not_remove_op_ref: remove_op_idx.append(idx) for idx in remove_op_idx[::-1]: @@ -974,20 +976,18 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id, while idx < len(block.ops): pre_op_count = len(block.ops) op = block.ops[idx] - op_dist_attr = dist_context.get_op_distributed_attr_for_program(op) - if op_dist_attr is not None: + dist_op = dist_context.get_dist_op_for_program(op) + if dist_op is not None: idx_offset = 0 for var_name in op.input_arg_names: # skip lod_tensor_blocking_queue_0 if var_name == "lod_tensor_blocking_queue_0": continue var = block.vars[var_name] - tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program( - var) - if tensor_dist_attr is not None and _need_reshard( - tensor_dist_attr, op_dist_attr): - reshard_op_desc = find_op_desc_seq(var, tensor_dist_attr, - op_dist_attr) + dist_tensor = dist_context.get_dist_tensor_for_program(var) + if dist_tensor is not None and _need_reshard(dist_tensor, + dist_op): + reshard_op_desc = find_op_desc_seq(dist_tensor, dist_op) parse_op_desc(auto_parallel_main_prog, rank_id, reshard_op_desc, var_name, op, dist_context) cur_op_count = len(block.ops) diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 813bd481d92..dc3780f2e16 100755 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -15,7 +15,6 @@ import threading import paddle.fluid.core as core import numpy as np -from .interface import _g_process_mesh_map def is_valid_list_index(list, index): @@ -119,34 +118,35 @@ def remove_distributed_attr_suffix(name): def check_distributed_attr_for_program(program, dist_context=None): - from .context import get_default_distributed_context + from .dist_context import get_default_distributed_context if dist_context is None: dist_context = get_default_distributed_context() assert dist_context.is_initialized_for_program(), \ "Distributed attributes must be initialized before check." 
for block in program.blocks: for tensor in block.vars.values(): - tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + dist_tensor = dist_context.get_dist_tensor_for_program(tensor) + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( tensor) - if (tensor_dist_attr is not None) and ( - not tensor_dist_attr.is_valid()): + if (tensor_dist_attr is not None) and (not dist_tensor.is_valid()): return False for op in block.ops: - op_dist_attr = dist_context.get_op_distributed_attr_for_program(op) - if (op_dist_attr is not None) and (not op_dist_attr.is_valid()): + dist_op = dist_context.get_dist_op_for_program(op) + op_dist_attr = dist_context.get_op_dist_attr_for_program(op) + if (op_dist_attr is not None) and (not dist_op.is_valid()): return False return True -def print_program_with_distributed_attr(program, dist_context=None): +def print_program_with_dist_attr(program, dist_context=None): """ This function reuses the original program output ability with a distributed context. Using lock can avoid multiple threads change the default distributed context simultaneously. """ lock = threading.Lock() lock.acquire() - from .context import get_default_distributed_context - from .context import set_default_distributed_context + from .dist_context import get_default_distributed_context + from .dist_context import set_default_distributed_context if dist_context is None: dist_context = get_default_distributed_context() print(program) @@ -233,12 +233,12 @@ def _coordinate2linear_idx(mesh_shape, coordinate): """ # NOTE the following function work based on a strong an assumption - # that the processes in mesh are + # that the processes in mesh are # 1. starts from 0 - # 2. continuous - # it will be wrong if ths above condition doesnot meet, + # 2. continuous + # it will be wrong if the above condition is not met, # e.g. process_mesh = { process_groups = [7, 8, 9,10, 12, 13, 14, 15], mesh = [2, 4]} - # if you want a more general mapping, you should use cartesian product + # if you want a more general mapping, you should use cartesian product assert len(mesh_shape) == len( coordinate @@ -301,31 +301,29 @@ def _linear_idx2coordinate(mesh_shape, linear_idx): return coordinate -def _get_corresponding_rank(target_mesh, rank): +def _get_corresponding_rank(dist_context, target_mesh, rank): # TODO(JZ-LIANG) a hack method to support varying mesh in Pipeline parallelism case. # we assume that all mesh are evenly divide from a parent mesh and should have same size. # to revise this in future.
coordinate = None - for key, mesh in _g_process_mesh_map.items(): - if key == 0: - continue - if rank in mesh.process_group and mesh.topology == target_mesh.topology: + for mesh in dist_context.process_meshes: + if rank in mesh.processes and mesh.topology == target_mesh.topology: coordinate = _linear_idx2coordinate(mesh.topology, - mesh.process_group.index(rank)) + mesh.processes.index(rank)) break assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format( rank) - return target_mesh.process_group[_coordinate2linear_idx(mesh.topology, - coordinate)] + return target_mesh.processes[_coordinate2linear_idx(mesh.topology, + coordinate)] def _get_unshard_dist_shape(var, dist_attr): var_shape = var.shape - mapping = dist_attr.get_dims_mapping() - mesh = dist_attr.get_process_mesh().topology + mapping = dist_attr.dims_mapping + mesh = dist_attr.process_mesh.topology assert len(var_shape) == len( mapping ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format( @@ -341,19 +339,16 @@ def _get_unshard_dist_shape(var, dist_attr): def make_data_unshard(dist_main_prog, dist_startup_prog): - from .context import get_default_distributed_context + from .dist_context import get_default_distributed_context dist_context = get_default_distributed_context() for var in dist_main_prog.list_vars(): if var.is_data: - tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( var) inverse_shape = _get_unshard_dist_shape(var, tensor_dist_attr) var.desc.set_shape(inverse_shape) - dim_mapping = tensor_dist_attr.get_dims_mapping() + dim_mapping = tensor_dist_attr.dims_mapping dim_mapping = [-1] * len(dim_mapping) - tensor_dist_attr.set_dims_mapping(dim_mapping) - dist_context.set_tensor_distributed_attr_for_program( - var, tensor_dist_attr) - var._set_attr('dim_mapping' + core.kAutoParallelSuffix(), - dim_mapping) + tensor_dist_attr.dims_mapping = dim_mapping + dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index c8e7de43361..6b868903c8c 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1308,13 +1308,12 @@ class Variable(object): if self.persistable: var_str = "persist " + var_str - from paddle.distributed.auto_parallel.context import get_default_distributed_context + from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context dist_context = get_default_distributed_context() - var_dist_attr = dist_context.get_tensor_distributed_attr_for_program( - self) - if var_dist_attr is not None: + dist_tensor = dist_context.get_dist_tensor_for_program(self) + if dist_tensor is not None: var_str += ", {name} = {value}".format( - name="dist_attr", value=var_dist_attr) + name="dist_attr", value=dist_tensor) return var_str @@ -2529,12 +2528,12 @@ class Operator(object): if i != len(attr_names) - 1: attrs_str += ", " - from paddle.distributed.auto_parallel.context import get_default_distributed_context + from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context dist_context = get_default_distributed_context() - op_dist_attr = dist_context.get_op_distributed_attr_for_program(self) - if op_dist_attr is not None: + dist_op = dist_context.get_dist_op_for_program(self) + if dist_op is not None: attrs_str += ", {name} = {value}".format( - name="dist_attr", value=op_dist_attr) + name="dist_attr", value=dist_op) if 
outputs_str != "{}": op_str = "{outputs} = {op_type}(inputs={inputs}, {attrs})".\ diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_data_unshard.py b/python/paddle/fluid/tests/unittests/auto_parallel_data_unshard.py index 367d9858626..ed8cb8a23c3 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel_data_unshard.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel_data_unshard.py @@ -36,8 +36,7 @@ class TestDataUnshard(unittest.TestCase): def create_model(train_program, start_program): with paddle.static.program_guard(train_program, start_program): - ROOT_MESH = auto.ProcessMesh([0, 1]) - MESH_0 = auto.ProcessMesh([0, 1], ROOT_MESH) + MESH_0 = auto.ProcessMesh([0, 1]) input = paddle.static.data(name='input', shape=[2, 8]) label = paddle.static.data(name='label', shape=[2, 8]) @@ -47,10 +46,30 @@ class TestDataUnshard(unittest.TestCase): linear0 = nn.Linear(8, 8, weight_attr) linear1 = nn.Linear(8, 8, weight_attr) - auto.shard_tensor(input, MESH_0, dim_mapping=[0, -1]) - auto.shard_tensor(label, MESH_0, dim_mapping=[0, -1]) - auto.shard_tensor(linear0.weight, MESH_0, dim_mapping=[-1, -1]) - auto.shard_tensor(linear1.weight, MESH_0, dim_mapping=[-1, -1]) + auto.shard_tensor( + input, + dist_attr={ + "process_mesh": MESH_0, + "dims_mapping": [0, -1] + }) + auto.shard_tensor( + label, + dist_attr={ + "process_mesh": MESH_0, + "dims_mapping": [0, -1] + }) + auto.shard_tensor( + linear0.weight, + dist_attr={ + "process_mesh": MESH_0, + "dims_mapping": [-1, -1] + }) + auto.shard_tensor( + linear1.weight, + dist_attr={ + "process_mesh": MESH_0, + "dims_mapping": [-1, -1] + }) linear0_out = linear0(input) gelu_out = F.gelu(linear0_out) @@ -105,8 +124,7 @@ class TestDataUnshard(unittest.TestCase): def create_model(train_program, start_program): with paddle.static.program_guard(train_program, start_program): - ROOT_MESH = auto.ProcessMesh([0, 1]) - MESH_0 = auto.ProcessMesh([0, 1], ROOT_MESH) + MESH_0 = auto.ProcessMesh([0, 1]) input = paddle.static.data(name='input', shape=[8, 8]) label = paddle.static.data(name='label', shape=[8, 8]) @@ -116,11 +134,31 @@ class TestDataUnshard(unittest.TestCase): linear0 = nn.Linear(8, 8, weight_attr) linear1 = nn.Linear(8, 8, weight_attr) - auto.shard_tensor(input, MESH_0, dim_mapping=[-1, -1]) - auto.shard_tensor(label, MESH_0, dim_mapping=[-1, -1]) - - auto.shard_tensor(linear0.weight, MESH_0, dim_mapping=[-1, 0]) - auto.shard_tensor(linear1.weight, MESH_0, dim_mapping=[0, -1]) + auto.shard_tensor( + input, + dist_attr={ + "process_mesh": MESH_0, + "dims_mapping": [-1, -1] + }) + auto.shard_tensor( + label, + dist_attr={ + "process_mesh": MESH_0, + "dims_mapping": [-1, -1] + }) + + auto.shard_tensor( + linear0.weight, + dist_attr={ + "process_mesh": MESH_0, + "dims_mapping": [-1, 0] + }) + auto.shard_tensor( + linear1.weight, + dist_attr={ + "process_mesh": MESH_0, + "dims_mapping": [0, -1] + }) linear0_out = linear0(input) gelu_out = F.gelu(linear0_out) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_parallelizer.py b/python/paddle/fluid/tests/unittests/auto_parallel_parallelizer.py index 89880f8c2f4..036b46470a7 100755 --- a/python/paddle/fluid/tests/unittests/auto_parallel_parallelizer.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel_parallelizer.py @@ -24,13 +24,12 @@ import paddle.utils as utils from paddle.fluid import layers from paddle.distributed import fleet import paddle.distributed.auto_parallel as auto -from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr +from 
paddle.distributed.auto_parallel.utils import print_program_with_dist_attr import paddle.fluid.core as core paddle.enable_static() _global_parallel_strategy = None _global_process_mesh = None -ROOT_MESH = auto.ProcessMesh([0, 1]) class MLPLayer(nn.Layer): @@ -78,8 +77,12 @@ def mlp_pretrain_forward(train_program, start_program): label = static.data( name="label", shape=[batch_size, sequence_len, 1], dtype='float32') - auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1, -1]) - auto.set_pipeline_stage(1) + auto.shard_tensor( + input, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, -1, -1] + }) mlp = MLPLayer( hidden_size=hidden_size, @@ -99,7 +102,7 @@ class TestMLPAutoParallelizer(unittest.TestCase): def test_mlp_serial(self): global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) dist_strategy = fleet.DistributedStrategy() dist_strategy.amp = False @@ -131,7 +134,7 @@ class TestMLPAutoParallelizer(unittest.TestCase): for op in block.ops: for attr_name in op.attr_names: self.assertTrue(suffix not in attr_name) - # print_program_with_distributed_attr(distributed_main_program) + # print_program_with_dist_attr(distributed_main_program) self.assertIsNotNone(distributed_startup_program) self.assertIsNotNone(distributed_main_program) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py index 3f1d692b72e..8593e44b3d8 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py @@ -15,128 +15,153 @@ from __future__ import print_function import unittest -import functools -import operator -import numpy as np import paddle import paddle.fluid as fluid -import paddle.fluid.core as core import paddle.nn as nn import paddle.distributed as dist +from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context +from paddle.distributed.auto_parallel.process_mesh import ProcessMesh paddle.enable_static() - -def _flatten_nested_list(nested_list): - result = functools.reduce(operator.iconcat, nested_list, []) - return result - - -def _append_attr_suffix(name): - return name + core.kAutoParallelSuffix() - - -LAST_PP_STAGE = 3 -MASK = [[0, 1, 1], [0, 1, 1]] -MESH = dist.ProcessMesh([[0, 1, 2], [3, 4, 5]]) +process_mesh1 = [0, 1, 2, 3] +process_mesh2 = [[0, 1, 2], [3, 4, 5]] class SimpleNet(nn.Layer): def __init__(self, vocab_size=128, hidden_size=4): super(SimpleNet, self).__init__() - self.mesh = MESH - self.mesh.set_placement([5, 4, 3, 2, 1, 0]) self.word_embeddings = nn.Embedding(vocab_size, hidden_size) self.dense1 = nn.Linear(hidden_size, hidden_size) self.dense2 = nn.Linear(hidden_size, hidden_size // 2) def forward(self, x, y): - x = dist.shard_tensor(x, self.mesh, dim_mapping=[0, -1]) - x = dist.set_shard_mask(x, MASK) + # Test shard_tensor interface with dist_attr arg + x = dist.shard_tensor( + x, + dist_attr={"process_mesh": process_mesh1, + "dims_mapping": [0, -1]}) emb_out = self.word_embeddings(x) - - dist.set_pipeline_stage(LAST_PP_STAGE) - - y = dist.shard_tensor(y, self.mesh, dim_mapping=[0, -1]) - dist.set_offload_device(y, "cpu") + # Test shard_tensor interface with no dist_attr arg + y = dist.shard_tensor(y) linear1 = self.dense1(y) out = self.dense2(linear1) - return x, y, self.mesh + return x, y class TestAutoParallelAPI(unittest.TestCase): def test_api(self):
+ dist_context = get_default_distributed_context() + net = SimpleNet() data1 = fluid.layers.fill_constant(shape=[2, 4], value=1, dtype="int64") data2 = fluid.layers.fill_constant( shape=[2, 4], value=2, dtype="float32") data3 = fluid.layers.fill_constant( shape=[2, 4], value=4, dtype="float32") - x, y, mesh = net.forward(data1, data2) - mesh_attr = _append_attr_suffix('mesh_id') - x_mesh_id = x._get_attr(mesh_attr) - self.assertEqual(x_mesh_id, mesh._id) - x_mesh = x.process_mesh - - allatts = x.attr_names - self.assertEqual(x_mesh, mesh) - shard_mask_attr = _append_attr_suffix('mask') - self.assertEqual( - x._get_attr(shard_mask_attr), _flatten_nested_list(MASK)) - self.assertEqual(x.shard_mask, _flatten_nested_list(MASK)) - offload_attr = _append_attr_suffix('offload_device') - self.assertEqual(y._get_attr(offload_attr), "cpu") - self.assertEqual(y.desc.has_attr(offload_attr), True) - self.assertEqual(y.offload_device, "cpu") - y._remove_attr(offload_attr) - self.assertEqual(y._has_attr(offload_attr), False) - ops = paddle.static.default_main_program().block(0).ops - first_op = ops[0] - last_op = ops[-1] - self.assertEqual(last_op.pipeline_stage, LAST_PP_STAGE) - - DIMS_MAPPING1 = [0, 1] - DIMS_MAPPING2 = [-1, 0] - kwargs = {'x': data2, 'y': data3} - dist.shard_op( + x, y = net.forward(data1, data2) + + dist_x = dist_context.get_dist_tensor_for_program(x) + self.assertEqual(dist_x.dist_attr.process_mesh.processes, process_mesh1) + self.assertEqual(dist_x.dist_attr.dims_mapping, [0, -1]) + self.assertEqual(dist_x.dist_attr.shard_sizes, None) + self.assertEqual(dist_x.dist_attr.device_placement, None) + self.assertTrue(dist_x.dist_attr.is_annotated("process_mesh")) + self.assertTrue(dist_x.dist_attr.is_annotated("dims_mapping")) + self.assertFalse(dist_x.dist_attr.is_annotated("shard_sizes")) + self.assertFalse(dist_x.dist_attr.is_annotated("device_placement")) + + dist_y = dist_context.get_dist_tensor_for_program(y) + self.assertEqual(dist_y.dist_attr.process_mesh, None) + self.assertEqual(dist_y.dist_attr.dims_mapping, [-1, -1]) + self.assertEqual(dist_y.dist_attr.shard_sizes, None) + self.assertEqual(dist_y.dist_attr.device_placement, None) + self.assertFalse(dist_y.dist_attr.is_annotated("process_mesh")) + self.assertFalse(dist_y.dist_attr.is_annotated("dims_mapping")) + self.assertFalse(dist_y.dist_attr.is_annotated("shard_sizes")) + self.assertFalse(dist_y.dist_attr.is_annotated("device_placement")) + + # Test shard_op interface with dist_attr + dims_mapping1 = [0, 1] + dims_mapping2 = [-1, 0] + dist_add = dist.shard_op( paddle.add, - mesh=mesh, - dim_mapping_dict={ - data2.name: DIMS_MAPPING1, - data3.name: DIMS_MAPPING2 - }, - **kwargs) + dist_attr={ + data2: { + "process_mesh": process_mesh2, + "dims_mapping": dims_mapping1 + }, + data3: { + "dims_mapping": dims_mapping2 + } + }) + results = dist_add(data2, data3) ops = paddle.static.default_main_program().block(0).ops last_op = ops[-1] - self.assertEqual(last_op.process_mesh, mesh) - attr_name = "IN_" + data2.name - attr_name = _append_attr_suffix(attr_name) - self.assertEqual(last_op.attr(attr_name), DIMS_MAPPING1) - attr_name = "IN_" + data3.name - attr_name = _append_attr_suffix(attr_name) - self.assertEqual(last_op.attr(attr_name), DIMS_MAPPING2) - - def test_process_mesh(self): - mesh1 = dist.ProcessMesh([[0, 1, 2], [3, 4, 5]], parent=MESH) - mesh2 = dist.ProcessMesh([[0, 1, 2], [3, 4, 5]], parent=mesh1) - mesh3 = dist.ProcessMesh([[0, 1], [2, 3]], parent=mesh1) - mesh4 = dist.ProcessMesh([[2, 3], [4, 5]], parent=mesh1) - - 
self.assertEqual(MESH.parent, None) - self.assertEqual(mesh1.parent, MESH) - self.assertEqual(mesh1._desc.parent, MESH._id) - self.assertEqual(mesh3.parent, mesh1) - self.assertEqual(mesh4.parent, mesh1) - self.assertEqual(mesh1, mesh2) - self.assertNotEqual(mesh3, mesh4) - self.assertEqual(mesh2._id, mesh2._desc.id) - self.assertEqual(mesh3.topology, mesh3._desc.topology) - self.assertEqual(mesh3.topology, [2, 2]) - self.assertEqual(mesh3.process_group, [0, 1, 2, 3]) - self.assertEqual(mesh4.process_group, mesh4._desc.process_group) + dist_op = dist_context.get_dist_op_for_program(last_op) + self.assertEqual(dist_op.dist_attr.process_mesh, + ProcessMesh(process_mesh2)) + self.assertEqual(dist_op.dist_attr.impl_type, "default") + self.assertEqual(dist_op.dist_attr.impl_idx, -2) + self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh")) + + data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name) + self.assertEqual(data2_dist_attr.process_mesh, + dist_op.dist_attr.process_mesh) + self.assertEqual(data2_dist_attr.dims_mapping, dims_mapping1) + self.assertEqual(data2_dist_attr.shard_sizes, None) + self.assertEqual(data2_dist_attr.device_placement, None) + self.assertTrue(data2_dist_attr.is_annotated("process_mesh")) + self.assertTrue(data2_dist_attr.is_annotated("dims_mapping")) + self.assertFalse(data2_dist_attr.is_annotated("shard_sizes")) + self.assertFalse(data2_dist_attr.is_annotated("device_placement")) + + data3_dist_attr = dist_op.dist_attr.get_input_dist_attr(data3.name) + self.assertEqual(data3_dist_attr.process_mesh, + dist_op.dist_attr.process_mesh) + self.assertEqual(data3_dist_attr.dims_mapping, dims_mapping2) + self.assertEqual(data3_dist_attr.shard_sizes, None) + self.assertEqual(data3_dist_attr.device_placement, None) + self.assertTrue(data3_dist_attr.is_annotated("process_mesh")) + self.assertTrue(data3_dist_attr.is_annotated("dims_mapping")) + self.assertFalse(data3_dist_attr.is_annotated("shard_sizes")) + self.assertFalse(data3_dist_attr.is_annotated("device_placement")) + + # Test shard_op interface without dist_attr + dist_add = dist.shard_op(paddle.add) + results = dist_add(data2, data3) + ops = paddle.static.default_main_program().block(0).ops + last_op = ops[-1] + dist_op = dist_context.get_dist_op_for_program(last_op) + self.assertEqual(dist_op.dist_attr.process_mesh, None) + self.assertEqual(dist_op.dist_attr.impl_type, "default") + self.assertEqual(dist_op.dist_attr.impl_idx, -2) + self.assertFalse(dist_op.dist_attr.is_annotated("process_mesh")) + + data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name) + self.assertEqual(data2_dist_attr.process_mesh, + dist_op.dist_attr.process_mesh) + self.assertEqual(data2_dist_attr.dims_mapping, [-1, -1]) + self.assertEqual(data2_dist_attr.shard_sizes, None) + self.assertEqual(data2_dist_attr.device_placement, None) + self.assertFalse(data2_dist_attr.is_annotated("process_mesh")) + self.assertFalse(data2_dist_attr.is_annotated("dims_mapping")) + self.assertFalse(data2_dist_attr.is_annotated("shard_sizes")) + self.assertFalse(data2_dist_attr.is_annotated("device_placement")) + + data3_dist_attr = dist_op.dist_attr.get_input_dist_attr(data3.name) + self.assertEqual(data3_dist_attr.process_mesh, + dist_op.dist_attr.process_mesh) + self.assertEqual(data3_dist_attr.dims_mapping, [-1, -1]) + self.assertEqual(data3_dist_attr.shard_sizes, None) + self.assertEqual(data3_dist_attr.device_placement, None) + self.assertFalse(data3_dist_attr.is_annotated("process_mesh")) +
self.assertFalse(data3_dist_attr.is_annotated("dims_mapping")) + self.assertFalse(data3_dist_attr.is_annotated("shard_sizes")) + self.assertFalse(data3_dist_attr.is_annotated("device_placement")) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py index 21726596ca7..05d71aca5db 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py @@ -28,15 +28,14 @@ from paddle.fluid import layers from paddle.nn.layer.transformer import _convert_param_attr_to_list import paddle.distributed.auto_parallel as auto from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program -from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr +from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix -from paddle.distributed.auto_parallel.context import DistributedContext -from paddle.distributed.auto_parallel.context import set_default_distributed_context +from paddle.distributed.auto_parallel.dist_context import DistributedContext +from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context paddle.enable_static() _global_parallel_strategy = None _global_process_mesh = None _global_process_mesh2 = None -ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) class MLPLayer(nn.Layer): @@ -62,20 +61,43 @@ class MLPLayer(nn.Layer): def forward(self, input): if _global_parallel_strategy == "mp": auto.shard_tensor( - self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) - auto.shard_tensor( - self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) + self.linear0.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) + auto.shard_tensor( + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) - auto.shard_tensor( - self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) + self.linear0.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) + auto.shard_tensor( + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) elif _global_parallel_strategy == "pp": auto.shard_tensor( - self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.linear0.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) auto.shard_tensor( - self.linear1.weight, _global_process_mesh2, - dim_mapping=[1, -1]) + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh2, + "dims_mapping": [1, -1] + }) out = self.norm(input) out = self.linear0(out) @@ -99,10 +121,18 @@ def mlp_pretrain_forward(train_program, start_program): if _global_parallel_strategy == "dp": auto.shard_tensor( - input, _global_process_mesh, dim_mapping=[0, -1, -1]) + input, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - input, _global_process_mesh, dim_mapping=[0, -1, -1]) + input, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1, -1] + }) mlp = 
MLPLayer( hidden_size=hidden_size, @@ -118,8 +148,7 @@ class TestMLPAutoCompletion(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[0, 1, 2, 3], parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) train_program = static.Program() start_program = static.Program() dist_context = DistributedContext() @@ -127,18 +156,15 @@ class TestMLPAutoCompletion(unittest.TestCase): start_program) complete_train_program = auto.complete_annotation(train_program, dist_context) - # print_program_with_distributed_attr(complete_train_program, + # print_program_with_dist_attr(complete_train_program, # dist_context) - self.assertTrue( - check_distributed_attr_for_program(complete_train_program, - dist_context)) + self.assertTrue(dist_context.validate_dist_attr_for_program()) def test_mlp_mp(self): global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[0, 1, 2, 3], parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) train_program = static.Program() start_program = static.Program() @@ -147,81 +173,77 @@ class TestMLPAutoCompletion(unittest.TestCase): start_program) complete_train_program = auto.complete_annotation(train_program, dist_context) - # print_program_with_distributed_attr(complete_train_program, + # print_program_with_dist_attr(complete_train_program, # dist_context) - self.assertTrue( - check_distributed_attr_for_program(complete_train_program, - dist_context)) + self.assertTrue(dist_context.validate_dist_attr_for_program()) def test_mlp_dp_mp(self): global _global_parallel_strategy _global_parallel_strategy = "dp_mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) - - train_program = static.Program() - start_program = static.Program() - dist_context = DistributedContext() - train_program, start_program = mlp_pretrain_forward(train_program, - start_program) - complete_train_program = auto.complete_annotation(train_program, - dist_context) - # print_program_with_distributed_attr(complete_train_program, - # dist_context) - self.assertTrue( - check_distributed_attr_for_program(complete_train_program, - dist_context)) - - def test_mlp_misc(self): - # import pdb - global _global_parallel_strategy - _global_parallel_strategy = "pp" - global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1], [2, 3]], parent=ROOT_MESH) - global _global_process_mesh2 - _global_process_mesh2 = auto.ProcessMesh( - mesh=[[4, 5], [6, 7]], parent=ROOT_MESH) + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) train_program = static.Program() start_program = static.Program() dist_context = DistributedContext() train_program, start_program = mlp_pretrain_forward(train_program, start_program) - # pdb.set_trace() complete_train_program = auto.complete_annotation(train_program, dist_context) - # print_program_with_distributed_attr(complete_train_program, + # print_program_with_dist_attr(complete_train_program, # dist_context) - dist_context.finalize_distributed_attr_for_program( - complete_train_program) - from paddle.distributed.auto_parallel.interface import _g_process_mesh_map - for block in complete_train_program.blocks: - for tensor in block.vars.values(): - desc = tensor.desc - attr_name = append_distributed_attr_suffix("mesh_id") - self.assertIsNotNone(desc.has_attr(attr_name)) - attr_name 
= append_distributed_attr_suffix("dim_mapping") - self.assertIsNotNone(desc.has_attr(attr_name)) - for op in block.ops: - desc = op.desc - attr_name = append_distributed_attr_suffix("mesh_id") - self.assertIsNotNone(desc.has_attr(attr_name)) - for tensor_name in desc.input_arg_names(): - attr_name = append_distributed_attr_suffix("IN_" + - tensor_name) - self.assertIsNotNone(desc.has_attr(attr_name)) - for tensor_name in desc.output_arg_names(): - attr_name = append_distributed_attr_suffix("OUT_" + - tensor_name) - self.assertIsNotNone(desc.has_attr(attr_name)) - set_default_distributed_context(dist_context) - self.assertTrue("dist_attr" in str(complete_train_program)) - with unittest.mock.patch( - "sys.stdout", new_callable=StringIO) as mock_stdout: - print_program_with_distributed_attr(complete_train_program) - self.assertIsNotNone(mock_stdout.getvalue()) + self.assertTrue(dist_context.validate_dist_attr_for_program()) + + # def test_mlp_misc(self): + # # import pdb + # global _global_parallel_strategy + # _global_parallel_strategy = "pp" + # global _global_process_mesh + # _global_process_mesh = auto.ProcessMesh( + # mesh=[[0, 1], [2, 3]]) + # global _global_process_mesh2 + # _global_process_mesh2 = auto.ProcessMesh( + # mesh=[[4, 5], [6, 7]]) + + # train_program = static.Program() + # start_program = static.Program() + # dist_context = DistributedContext() + # train_program, start_program = mlp_pretrain_forward(train_program, + # start_program) + # # pdb.set_trace() + # complete_train_program = auto.complete_annotation(train_program, + # dist_context) + # # print_program_with_dist_attr(complete_train_program, + # # dist_context) + # dist_context.finalize_distributed_attr_for_program( + # complete_train_program) + # from paddle.distributed.auto_parallel.interface import _g_process_mesh_map + # for block in complete_train_program.blocks: + # for tensor in block.vars.values(): + # desc = tensor.desc + # attr_name = append_distributed_attr_suffix("mesh_id") + # self.assertIsNotNone(desc.has_attr(attr_name)) + # attr_name = append_distributed_attr_suffix("dims_mapping") + # self.assertIsNotNone(desc.has_attr(attr_name)) + # for op in block.ops: + # desc = op.desc + # attr_name = append_distributed_attr_suffix("mesh_id") + # self.assertIsNotNone(desc.has_attr(attr_name)) + # for tensor_name in desc.input_arg_names(): + # attr_name = append_distributed_attr_suffix("IN_" + + # tensor_name) + # self.assertIsNotNone(desc.has_attr(attr_name)) + # for tensor_name in desc.output_arg_names(): + # attr_name = append_distributed_attr_suffix("OUT_" + + # tensor_name) + # self.assertIsNotNone(desc.has_attr(attr_name)) + # set_default_distributed_context(dist_context) + # self.assertTrue("dist_attr" in str(complete_train_program)) + # with unittest.mock.patch( + # "sys.stdout", new_callable=StringIO) as mock_stdout: + # print_program_with_dist_attr(complete_train_program) + # self.assertIsNotNone(mock_stdout.getvalue()) class AttentionLayer(nn.Layer): @@ -262,10 +284,18 @@ class AttentionLayer(nn.Layer): def forward(self, input): if _global_parallel_strategy == "dp": auto.shard_tensor( - input, _global_process_mesh, dim_mapping=[0, -1, -1]) + input, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - input, _global_process_mesh, dim_mapping=[0, -1, -1]) + input, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1, -1] + }) q = self.q_proj(input) q = tensor.reshape(x=q, shape=[0, 
0, self.num_heads, self.head_dim]) @@ -276,18 +306,42 @@ class AttentionLayer(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( - self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) - auto.shard_tensor( - self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) - auto.shard_tensor( - self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.q_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) + auto.shard_tensor( + self.k_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) + auto.shard_tensor( + self.v_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.q_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) auto.shard_tensor( - self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.k_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) auto.shard_tensor( - self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.v_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) @@ -320,12 +374,18 @@ class AttentionLayer(nn.Layer): out = self.out_proj(out) if _global_parallel_strategy == "mp": auto.shard_tensor( - self.out_proj.weight, _global_process_mesh, - dim_mapping=[0, -1]) + self.out_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.out_proj.weight, _global_process_mesh, - dim_mapping=[1, -1]) + self.out_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) return out @@ -357,8 +417,7 @@ class TestAttentionAutoCompletion(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[0, 1, 2, 3], parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) train_program = static.Program() start_program = static.Program() dist_context = DistributedContext() @@ -366,18 +425,15 @@ class TestAttentionAutoCompletion(unittest.TestCase): start_program) complete_train_program = auto.complete_annotation(train_program, dist_context) - # print_program_with_distributed_attr(complete_train_program, + # print_program_with_dist_attr(complete_train_program, # dist_context) - self.assertTrue( - check_distributed_attr_for_program(complete_train_program, - dist_context)) + self.assertTrue(dist_context.validate_dist_attr_for_program()) def test_attn_mp(self): global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[0, 1, 2, 3], parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) train_program = static.Program() start_program = static.Program() @@ -386,18 +442,16 @@ class TestAttentionAutoCompletion(unittest.TestCase): start_program) complete_train_program = auto.complete_annotation(train_program, dist_context) - # print_program_with_distributed_attr(complete_train_program, + # print_program_with_dist_attr(complete_train_program, # 
dist_context) - self.assertTrue( - check_distributed_attr_for_program(complete_train_program, - dist_context)) + self.assertTrue(dist_context.validate_dist_attr_for_program()) def test_attn_dp_mp(self): global _global_parallel_strategy _global_parallel_strategy = "dp_mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) train_program = static.Program() start_program = static.Program() @@ -406,11 +460,9 @@ class TestAttentionAutoCompletion(unittest.TestCase): start_program) complete_train_program = auto.complete_annotation(train_program, dist_context) - # print_program_with_distributed_attr(complete_train_program, + # print_program_with_dist_attr(complete_train_program, # dist_context) - self.assertTrue( - check_distributed_attr_for_program(complete_train_program, - dist_context)) + self.assertTrue(dist_context.validate_dist_attr_for_program()) class DecoderLayer(nn.Layer): @@ -486,10 +538,18 @@ class DecoderLayer(nn.Layer): def forward(self, input_ids, position_ids): if _global_parallel_strategy == "dp": auto.shard_tensor( - input_ids, _global_process_mesh, dim_mapping=[0, -1]) + input_ids, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - input_ids, _global_process_mesh, dim_mapping=[0, -1]) + input_ids, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) input_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) @@ -497,13 +557,17 @@ class DecoderLayer(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( self.word_embeddings.weight, - _global_process_mesh, - dim_mapping=[0, -1]) + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.word_embeddings.weight, - _global_process_mesh, - dim_mapping=[1, -1]) + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) embeddings = input_embeddings + position_embeddings embeddings = self.dropout1(embeddings) @@ -521,18 +585,42 @@ class DecoderLayer(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( - self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) - auto.shard_tensor( - self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) - auto.shard_tensor( - self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.q_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) + auto.shard_tensor( + self.k_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) + auto.shard_tensor( + self.v_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.q_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) auto.shard_tensor( - self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.k_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) auto.shard_tensor( - self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.v_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) k = 
tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) @@ -566,12 +654,18 @@ class DecoderLayer(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( - self.out_proj.weight, _global_process_mesh, - dim_mapping=[0, -1]) + self.out_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.out_proj.weight, _global_process_mesh, - dim_mapping=[1, -1]) + self.out_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) # Add residual residual = embeddings + self.dropout2(out) @@ -586,14 +680,30 @@ class DecoderLayer(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( - self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) - auto.shard_tensor( - self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) + self.linear0.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) + auto.shard_tensor( + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.linear0.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) auto.shard_tensor( - self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) # Add residual final = residual + self.dropout3(out3) @@ -631,8 +741,7 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[0, 1, 2, 3], parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) train_program = static.Program() start_program = static.Program() dist_context = DistributedContext() @@ -640,18 +749,15 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): start_program) complete_train_program = auto.complete_annotation(train_program, dist_context) - # print_program_with_distributed_attr(complete_train_program, + # print_program_with_dist_attr(complete_train_program, # dist_context) - self.assertTrue( - check_distributed_attr_for_program(complete_train_program, - dist_context)) + self.assertTrue(dist_context.validate_dist_attr_for_program()) def test_decoder_mp(self): global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[0, 1, 2, 3], parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) train_program = static.Program() start_program = static.Program() @@ -660,18 +766,16 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): start_program) complete_train_program = auto.complete_annotation(train_program, dist_context) - # print_program_with_distributed_attr(complete_train_program, + # print_program_with_dist_attr(complete_train_program, # dist_context) - self.assertTrue( - check_distributed_attr_for_program(complete_train_program, - dist_context)) + self.assertTrue(dist_context.validate_dist_attr_for_program()) def test_decoder_dp_mp(self): global _global_parallel_strategy _global_parallel_strategy = "dp_mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 
3], [4, 5, 6, 7]], parent=ROOT_MESH) + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) train_program = static.Program() start_program = static.Program() @@ -680,11 +784,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): start_program) complete_train_program = auto.complete_annotation(train_program, dist_context) - # print_program_with_distributed_attr(complete_train_program, + # print_program_with_dist_attr(complete_train_program, # dist_context) - self.assertTrue( - check_distributed_attr_for_program(complete_train_program, - dist_context)) + self.assertTrue(dist_context.validate_dist_attr_for_program()) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py index cd87a72a7e6..c2c1e63155c 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py @@ -32,13 +32,12 @@ from paddle.distributed.fleet import fleet import paddle.static as static import paddle.distributed.auto_parallel as auto from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program -from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr -from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr +from paddle.distributed.auto_parallel.dist_context import DistributedContext paddle.enable_static() _global_parallel_strategy = None _global_process_mesh = None -ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) class MultiHeadAttention(nn.Layer): @@ -108,10 +107,18 @@ class MultiHeadAttention(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( - self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.q_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.q_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) @@ -145,19 +152,35 @@ class MultiHeadAttention(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( - self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.k_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.k_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) v = self.v_proj(value) if _global_parallel_strategy == "mp": auto.shard_tensor( - self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.v_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.v_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) @@ -238,12 +261,18 @@ class MultiHeadAttention(nn.Layer): if 
_global_parallel_strategy == "mp": auto.shard_tensor( - self.out_proj.weight, _global_process_mesh, - dim_mapping=[0, -1]) + self.out_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.out_proj.weight, _global_process_mesh, - dim_mapping=[1, -1]) + self.out_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) outs = [out] if self.need_weights: @@ -411,17 +440,33 @@ class TransformerDecoderLayer(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( - self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) if _global_parallel_strategy == "mp": auto.shard_tensor( - self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1]) + self.linear2.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1]) + self.linear2.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) # tgt = self.dropout2( # self.linear2(F.gelu( @@ -485,13 +530,17 @@ class GPTEmbeddings(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( self.word_embeddings.weight, - _global_process_mesh, - dim_mapping=[0, -1]) + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.word_embeddings.weight, - _global_process_mesh, - dim_mapping=[1, -1]) + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) position_embeddings = self.position_embeddings(position_ids) embeddings = input_embedings + position_embeddings @@ -717,10 +766,18 @@ def gpt_pretrain_forward(train_program, start_program): if _global_parallel_strategy == "dp": auto.shard_tensor( - input_ids, _global_process_mesh, dim_mapping=[0, -1]) + input_ids, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - input_ids, _global_process_mesh, dim_mapping=[0, -1]) + input_ids, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) gpt = GPTModel( vocab_size=32768, @@ -753,8 +810,7 @@ class TestGPTAutoCompletion(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[0, 1, 2, 3], parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) train_program = static.Program() start_program = static.Program() @@ -763,18 +819,15 @@ class TestGPTAutoCompletion(unittest.TestCase): start_program) complete_train_program = auto.complete_annotation(train_program, dist_context) - # print_program_with_distributed_attr(complete_train_program, + # print_program_with_dist_attr(complete_train_program, # dist_context) - self.assertTrue( - check_distributed_attr_for_program(complete_train_program, - dist_context)) + self.assertTrue(dist_context.validate_dist_attr_for_program()) def test_gpt_mp(self): 
global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[0, 1, 2, 3], parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) train_program = static.Program() start_program = static.Program() @@ -783,18 +836,16 @@ class TestGPTAutoCompletion(unittest.TestCase): start_program) complete_train_program = auto.complete_annotation(train_program, dist_context) - # print_program_with_distributed_attr(complete_train_program, + # print_program_with_dist_attr(complete_train_program, # dist_context) - self.assertTrue( - check_distributed_attr_for_program(complete_train_program, - dist_context)) + self.assertTrue(dist_context.validate_dist_attr_for_program()) def test_gpt_dp_mp(self): global _global_parallel_strategy _global_parallel_strategy = "dp_mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) train_program = static.Program() start_program = static.Program() @@ -803,11 +854,9 @@ class TestGPTAutoCompletion(unittest.TestCase): start_program) complete_train_program = auto.complete_annotation(train_program, dist_context) - # print_program_with_distributed_attr(complete_train_program, + # print_program_with_dist_attr(complete_train_program, # dist_context) - self.assertTrue( - check_distributed_attr_for_program(complete_train_program, - dist_context)) + self.assertTrue(dist_context.validate_dist_attr_for_program()) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py index 000b1db6138..4c9c01b99e0 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py @@ -23,21 +23,19 @@ import paddle.static as static import paddle.nn.functional as F import paddle.utils as utils import paddle.distributed.auto_parallel as auto -from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed import fleet from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.completion import complete_backward_annotation from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.cost_model import estimate_cost import paddle.fluid.core as core +from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr paddle.enable_static() _global_parallel_strategy = "dp_mp_pp" -ROOT_MESH = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]]) -_global_process_mesh = auto.ProcessMesh( - [[[0, 1], [4, 5]], [[2, 3], [6, 7]]], parent=ROOT_MESH) -PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], parent=ROOT_MESH) -PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], parent=ROOT_MESH) +PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]]) +PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]]) NUM_RANKS = 8 STAGE_0_CNT = 5 STAGE_1_CNT = 10 @@ -70,9 +68,13 @@ class MLPLayer(nn.Layer): def forward(self, input): if self.is_distributed: auto.shard_tensor( - self.linear0.weight, PP_MESH_0, dim_mapping=[-1, 1]) + self.linear0.weight, + dist_attr={"process_mesh": PP_MESH_0, + "dims_mapping": [-1, 1]}) auto.shard_tensor( - self.linear1.weight, PP_MESH_1, dim_mapping=[1, -1]) + self.linear1.weight, + 
dist_attr={"process_mesh": PP_MESH_1, + "dims_mapping": [1, -1]}) out = self.norm(input) out = self.linear0(out) @@ -120,8 +122,14 @@ def mlp_forward(train_program, start_program, is_distributed=True): name="label", shape=[batch_size, 1], dtype='float32') if is_distributed: - auto.shard_tensor(input, PP_MESH_0, dim_mapping=[0, -1]) - auto.shard_tensor(label, PP_MESH_1, dim_mapping=[0, -1]) + auto.shard_tensor( + input, + dist_attr={"process_mesh": PP_MESH_0, + "dims_mapping": [0, -1]}) + auto.shard_tensor( + label, + dist_attr={"process_mesh": PP_MESH_1, + "dims_mapping": [0, -1]}) mlp = MLPLayer( hidden_size=hidden_size, @@ -137,8 +145,6 @@ def mlp_forward(train_program, start_program, is_distributed=True): def get_dist_prog(train_program, startup_program, dist_context, rank_id): - global _global_process_mesh - dist_context.set_process_mesh(_global_process_mesh) loss, train_program, startup_program = mlp_forward(train_program, startup_program) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py index 44a52524401..3a23f9b2611 100755 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py @@ -29,19 +29,17 @@ from paddle.fluid import layers from paddle.nn.layer.transformer import _convert_param_attr_to_list import paddle.distributed.auto_parallel as auto from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program -from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr +from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix -from paddle.distributed.auto_parallel.context import DistributedContext -from paddle.distributed.auto_parallel.context import set_default_distributed_context +from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed import fleet from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.utils import _get_comm_group -from paddle.distributed.auto_parallel.process import new_process_group +from paddle.distributed.auto_parallel.process_group import new_process_group paddle.enable_static() _global_parallel_strategy = None _global_process_mesh = None -ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) def get_programs(annotated_func): @@ -49,7 +47,7 @@ def get_programs(annotated_func): start_program = static.Program() dist_context = DistributedContext() global _global_process_mesh - dist_context.set_process_mesh(_global_process_mesh) + dist_context.process_mesh = _global_process_mesh train_program, start_program = annotated_func(train_program, start_program) complete_train_program = auto.complete_annotation(train_program, dist_context) @@ -95,9 +93,8 @@ def initialization_check(mode, dist_context, dist_startup_prog, serial_startup_prog, var_need_broadcast, process_mesh, mp_parallel_axis, dp_parallel_axis): if 'mp' in mode: - group_ranks = _get_comm_group(process_mesh.process_group, - process_mesh.topology, mp_parallel_axis, - 3) + group_ranks = _get_comm_group( + process_mesh.processes, process_mesh.topology, mp_parallel_axis, 3) mp_ring_id = new_process_group(group_ranks).id broadcast_ops = [ op for op in dist_startup_prog.global_block().ops @@ -110,9 +107,8 @@ def initialization_check(mode, dist_context, 
dist_startup_prog, return False if 'dp' in mode: - group_ranks = _get_comm_group(process_mesh.process_group, - process_mesh.topology, dp_parallel_axis, - 3) + group_ranks = _get_comm_group( + process_mesh.processes, process_mesh.topology, dp_parallel_axis, 3) dp_ring_id = new_process_group(group_ranks).id nparam = len(serial_startup_prog.all_parameters()) nbroadcast_dp = len([ @@ -137,22 +133,21 @@ def initialization_check(mode, dist_context, dist_startup_prog, def get_input_var_dist_attr(op, main_program, dist_context): varname = op.desc.input_arg_names() var = main_program.global_block().var(varname[0]) - dist_attr = dist_context.get_tensor_distributed_attr_for_program(var) + dist_attr = dist_context.get_tensor_dist_attr_for_program(var) return dist_attr def get_output_var_dist_attr(op, main_program, dist_context): varname = op.desc.output_arg_names() var = main_program.global_block().var(varname[0]) - dist_attr = dist_context.get_tensor_distributed_attr_for_program(var) + dist_attr = dist_context.get_tensor_dist_attr_for_program(var) return dist_attr def check_equal_var_dist_attr(serial_dist_attr, dist_attr): equal = True - if serial_dist_attr.get_process_mesh() != dist_attr.get_process_mesh() or \ - serial_dist_attr.is_parameter() != dist_attr.is_parameter() or \ - serial_dist_attr.get_dims_mapping() != dist_attr.get_dims_mapping(): + if serial_dist_attr.process_mesh != dist_attr.process_mesh or \ + serial_dist_attr.dims_mapping != dist_attr.dims_mapping: equal = False return equal @@ -161,36 +156,33 @@ def check_equal_dist_op_attr(dist_context, dist_main_prog, serial_op, dist_ops, dist_op_idx): equal = True # get serial op's process_mesh and impl_idx - serial_op_dist_attr = dist_context.get_op_distributed_attr_for_program( - serial_op) - serial_process_mesh = serial_op_dist_attr.get_process_mesh() - serial_impl_idx = serial_op_dist_attr.get_impl_idx() + serial_op_dist_attr = dist_context.get_op_dist_attr_for_program(serial_op) + serial_process_mesh = serial_op_dist_attr.process_mesh + serial_impl_idx = serial_op_dist_attr.impl_idx # check dist_attr between serial op and dist op for i in dist_op_idx: - op_dist_attr = dist_context.get_op_distributed_attr_for_program( - dist_ops[i]) + op_dist_attr = dist_context.get_op_dist_attr_for_program(dist_ops[i]) for in_varname in dist_ops[i].desc.input_arg_names(): in_var = dist_main_prog.global_block().var(in_varname) - tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( in_var) - tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() + tensor_dims_mapping = tensor_dist_attr.dims_mapping in_var_dims_mapping = op_dist_attr.get_input_dims_mapping( in_varname) if tensor_dims_mapping != in_var_dims_mapping: equal = False for out_varname in dist_ops[i].desc.output_arg_names(): out_var = dist_main_prog.global_block().var(out_varname) - tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( out_var) - tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() + tensor_dims_mapping = tensor_dist_attr.dims_mapping out_var_dims_mapping = op_dist_attr.get_output_dims_mapping( out_varname) if tensor_dims_mapping != out_var_dims_mapping: equal = False - - dist_op_process_mesh = op_dist_attr.get_process_mesh() - dist_op_impl_idx = op_dist_attr.get_impl_idx() + dist_op_process_mesh = op_dist_attr.process_mesh + dist_op_impl_idx = op_dist_attr.impl_idx if serial_op.desc.id() == 
dist_ops[i].desc.id() or \ serial_process_mesh != dist_op_process_mesh or \ serial_impl_idx != dist_op_impl_idx: @@ -242,13 +234,13 @@ def distributed_attr_check_for_program(dist_main_prog, dist_context): have_dist_attr = True for block in dist_main_prog.blocks: for tensor in block.vars.values(): - var_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + var_dist_attr = dist_context.get_tensor_dist_attr_for_program( tensor) if var_dist_attr is None: have_dist_attr = False for op in block.ops: - op_dist_attr = dist_context.get_op_distributed_attr_for_program(op) + op_dist_attr = dist_context.get_op_dist_attr_for_program(op) if op_dist_attr is None: have_dist_attr = False @@ -278,21 +270,43 @@ class MLPLayer(nn.Layer): def forward(self, input): if _global_parallel_strategy == "mp": auto.shard_tensor( - self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.linear0.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) auto.shard_tensor( - self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.linear0.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) auto.shard_tensor( - self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) else: auto.shard_tensor( - self.linear0.weight, _global_process_mesh, - dim_mapping=[-1, -1]) + self.linear0.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, -1] + }) auto.shard_tensor( - self.linear1.weight, _global_process_mesh, - dim_mapping=[-1, -1]) + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, -1] + }) out = self.norm(input) out = self.linear0(out) @@ -316,10 +330,18 @@ def mlp_pretrain_forward(train_program, start_program): if _global_parallel_strategy == "dp": auto.shard_tensor( - input, _global_process_mesh, dim_mapping=[0, -1, -1]) + input, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - input, _global_process_mesh, dim_mapping=[0, -1, -1]) + input, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1, -1] + }) mlp = MLPLayer( hidden_size=hidden_size, @@ -335,8 +357,7 @@ class TestMLPAutoPartitioner(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[0, 1, 2, 3], parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( mlp_pretrain_forward) @@ -372,8 +393,7 @@ class TestMLPAutoPartitioner(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[0, 1, 2, 3], parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( mlp_pretrain_forward) @@ -437,7 +457,7 @@ class TestMLPAutoPartitioner(unittest.TestCase): 
_global_parallel_strategy = "dp_mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( mlp_pretrain_forward) @@ -535,10 +555,18 @@ class AttentionLayer(nn.Layer): def forward(self, input): if _global_parallel_strategy == "dp": auto.shard_tensor( - input, _global_process_mesh, dim_mapping=[0, -1, -1]) + input, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - input, _global_process_mesh, dim_mapping=[0, -1, -1]) + input, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1, -1] + }) q = self.q_proj(input) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) @@ -549,18 +577,42 @@ class AttentionLayer(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( - self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.q_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) auto.shard_tensor( - self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.k_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) auto.shard_tensor( - self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.v_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.q_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) auto.shard_tensor( - self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.k_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) auto.shard_tensor( - self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.v_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) @@ -593,12 +645,18 @@ class AttentionLayer(nn.Layer): out = self.out_proj(out) if _global_parallel_strategy == "mp": auto.shard_tensor( - self.out_proj.weight, _global_process_mesh, - dim_mapping=[0, -1]) + self.out_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.out_proj.weight, _global_process_mesh, - dim_mapping=[1, -1]) + self.out_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) return out @@ -630,8 +688,7 @@ class TestAttentionAutoPartitioner(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[0, 1, 2, 3], parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( attn_pretrain_forward) @@ -666,8 +723,7 @@ class TestAttentionAutoPartitioner(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[0, 1, 2, 3], 
parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( attn_pretrain_forward) @@ -735,7 +791,7 @@ class TestAttentionAutoPartitioner(unittest.TestCase): _global_parallel_strategy = "dp_mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( attn_pretrain_forward) @@ -871,10 +927,18 @@ class DecoderLayer(nn.Layer): def forward(self, input_ids, position_ids): if _global_parallel_strategy == "dp": auto.shard_tensor( - input_ids, _global_process_mesh, dim_mapping=[0, -1]) + input_ids, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - input_ids, _global_process_mesh, dim_mapping=[0, -1]) + input_ids, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) input_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) @@ -882,13 +946,17 @@ class DecoderLayer(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( self.word_embeddings.weight, - _global_process_mesh, - dim_mapping=[0, -1]) + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.word_embeddings.weight, - _global_process_mesh, - dim_mapping=[1, -1]) + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) embeddings = input_embeddings + position_embeddings embeddings = self.dropout1(embeddings) @@ -906,18 +974,42 @@ class DecoderLayer(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( - self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.q_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) auto.shard_tensor( - self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.k_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) auto.shard_tensor( - self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.v_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.q_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) auto.shard_tensor( - self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.k_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) auto.shard_tensor( - self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.v_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) @@ -951,17 +1043,25 @@ class DecoderLayer(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( - self.out_proj.weight, _global_process_mesh, - dim_mapping=[0, -1]) + self.out_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": 
auto.shard_tensor( - self.out_proj.weight, _global_process_mesh, - dim_mapping=[1, -1]) + self.out_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) else: auto.shard_tensor( self.out_proj.weight, - _global_process_mesh, - dim_mapping=[-1, -1]) + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, -1] + }) # Add residual residual = embeddings + self.dropout2(out) @@ -976,14 +1076,30 @@ class DecoderLayer(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( - self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.linear0.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) auto.shard_tensor( - self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.linear0.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) auto.shard_tensor( - self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) # Add residual final = residual + self.dropout3(out3) @@ -1022,7 +1138,7 @@ class TestDecoderLayerPartitioner(unittest.TestCase): _global_parallel_strategy = "dp_mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( decoder_pretrain_forward) @@ -1105,7 +1221,7 @@ class TestDecoderLayerPartitioner(unittest.TestCase): _global_parallel_strategy = "None" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( decoder_pretrain_forward) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py index 3c395fbdf7d..7fcb18db128 100755 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py @@ -32,14 +32,13 @@ from paddle.distributed import fleet import paddle.static as static import paddle.distributed.auto_parallel as auto from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program -from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr -from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr +from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.utils import _get_comm_group -from paddle.distributed.auto_parallel.process import new_process_group +from paddle.distributed.auto_parallel.process_group import new_process_group paddle.enable_static() -ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) _global_parallel_strategy = None _global_process_mesh = None @@ -61,24 +60,27 @@ 
def is_valid_completed_program(dist_context, program): ops = program.global_block().ops vars_ = program.list_vars() for op in ops: - op_dist_attrs = dist_context.get_op_distributed_attr_for_program(op) + op_dist_attrs = dist_context.get_op_dist_attr_for_program(op) if op_dist_attrs == None: return False - if op_dist_attrs.get_process_mesh == None: + if op_dist_attrs.process_mesh == None: return False - if None in op_dist_attrs._dims_mapping.values(): - return False + for tensor_dist_attr in op_dist_attrs.inputs_dist_attrs.values(): + if None == tensor_dist_attr.dims_mapping: + return False + for tensor_dist_attr in op_dist_attrs.outputs_dist_attrs.values(): + if None == tensor_dist_attr.dims_mapping: + return False for var in vars_: - var_dist_attrs = dist_context.get_tensor_distributed_attr_for_program( - var) + var_dist_attrs = dist_context.get_tensor_dist_attr_for_program(var) if var_dist_attrs == None: return False - elif var_dist_attrs.get_process_mesh == None: + elif var_dist_attrs.process_mesh == None: return False - elif var_dist_attrs.get_dims_mapping == None: + elif var_dist_attrs.dims_mapping == None: return False return True @@ -151,10 +153,18 @@ class MultiHeadAttention(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( - self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.q_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.q_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) @@ -188,19 +198,35 @@ class MultiHeadAttention(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( - self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.k_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.k_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) v = self.v_proj(value) if _global_parallel_strategy == "mp": auto.shard_tensor( - self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.v_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.v_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) @@ -281,12 +307,18 @@ class MultiHeadAttention(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( - self.out_proj.weight, _global_process_mesh, - dim_mapping=[0, -1]) + self.out_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.out_proj.weight, _global_process_mesh, - dim_mapping=[1, -1]) + self.out_proj.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) outs = [out] if self.need_weights: @@ -454,17 +486,33 @@ class TransformerDecoderLayer(nn.Layer): if 
_global_parallel_strategy == "mp": auto.shard_tensor( - self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0]) + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1]) + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 1] + }) if _global_parallel_strategy == "mp": auto.shard_tensor( - self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1]) + self.linear2.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1]) + self.linear2.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) # tgt = self.dropout2( # self.linear2(F.gelu( @@ -528,13 +576,17 @@ class GPTEmbeddings(nn.Layer): if _global_parallel_strategy == "mp": auto.shard_tensor( self.word_embeddings.weight, - _global_process_mesh, - dim_mapping=[0, -1]) + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.word_embeddings.weight, - _global_process_mesh, - dim_mapping=[1, -1]) + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [1, -1] + }) position_embeddings = self.position_embeddings(position_ids) embeddings = input_embedings + position_embeddings @@ -760,10 +812,18 @@ def gpt_pretrain_forward(train_program, start_program): if _global_parallel_strategy == "dp": auto.shard_tensor( - input_ids, _global_process_mesh, dim_mapping=[0, -1]) + input_ids, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( - input_ids, _global_process_mesh, dim_mapping=[0, -1]) + input_ids, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) gpt = GPTModel( vocab_size=32768, @@ -798,12 +858,12 @@ class TestGPTPartitioner(unittest.TestCase): global _global_process_mesh _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) train_program = static.Program() start_program = static.Program() dist_context = DistributedContext() - dist_context.set_process_mesh(_global_process_mesh) + dist_context.process_mesh = _global_process_mesh train_program, start_program, loss = gpt_pretrain_forward(train_program, start_program) complete_train_program = auto.complete_annotation(train_program, @@ -833,7 +893,7 @@ class TestGPTPartitioner(unittest.TestCase): opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, auto_parallel_main_prog, auto_parallel_startup_prog) - from paddle.distributed.auto_parallel.context import set_default_distributed_context + from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context set_default_distributed_context(dist_context) with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw: fw.write(str(auto_parallel_main_prog)) @@ -877,14 +937,12 @@ class TestGPTPartitioner(unittest.TestCase): mp_parallel_axis = 1 dp_parallel_axis = 0 - group_ranks = _get_comm_group(process_mesh.process_group, - process_mesh.topology, mp_parallel_axis, - 3) + group_ranks = _get_comm_group( + process_mesh.processes, process_mesh.topology, mp_parallel_axis, 3) 
mp_ring_id = new_process_group(group_ranks).id - group_ranks = _get_comm_group(process_mesh.process_group, - process_mesh.topology, dp_parallel_axis, - 3) + group_ranks = _get_comm_group( + process_mesh.processes, process_mesh.topology, dp_parallel_axis, 3) dp_ring_id = new_process_group(group_ranks).id tensor_parallel_allreduce_vars = sorted([ diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py index fe9b965ed87..0439b9a287c 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py @@ -22,16 +22,16 @@ import paddle.static as static import paddle.nn.functional as F import paddle.utils as utils import paddle.distributed.auto_parallel as auto -from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed import fleet from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.reshard import reshard -from paddle.distributed.auto_parallel.process import PROCESS_GROUP_MAP +from paddle.distributed.auto_parallel.process_group import _g_process_group_map +from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr paddle.enable_static() _global_parallel_strategy = None _global_process_mesh = None -ROOT_MESH = auto.ProcessMesh([0, 1]) PP_MESH_0 = None PP_MESH_1 = None @@ -57,16 +57,30 @@ class MLPLayer(nn.Layer): def forward(self, input): if _global_parallel_strategy == "pp": auto.shard_tensor( - self.linear0.weight, PP_MESH_0, dim_mapping=[-1, -1]) + self.linear0.weight, + dist_attr={ + "process_mesh": PP_MESH_0, + "dims_mapping": [-1, -1] + }) auto.shard_tensor( - self.linear1.weight, PP_MESH_1, dim_mapping=[-1, -1]) + self.linear1.weight, + dist_attr={ + "process_mesh": PP_MESH_1, + "dims_mapping": [-1, -1] + }) else: auto.shard_tensor( - self.linear0.weight, _global_process_mesh, - dim_mapping=[-1, -1]) + self.linear0.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, -1] + }) auto.shard_tensor( - self.linear1.weight, _global_process_mesh, - dim_mapping=[-1, -1]) + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, -1] + }) out = self.norm(input) out = self.linear0(out) @@ -88,12 +102,32 @@ def mlp_forward(train_program, start_program): name="label", shape=[batch_size, 1], dtype='float32') if _global_parallel_strategy == "pp": - auto.shard_tensor(input, PP_MESH_0, dim_mapping=[-1, -1]) - auto.shard_tensor(label, PP_MESH_1, dim_mapping=[-1, -1]) + auto.shard_tensor( + input, + dist_attr={ + "process_mesh": PP_MESH_0, + "dims_mapping": [-1, -1] + }) + auto.shard_tensor( + label, + dist_attr={ + "process_mesh": PP_MESH_1, + "dims_mapping": [-1, -1] + }) elif _global_parallel_strategy == "dp": - auto.shard_tensor(input, _global_process_mesh, dim_mapping=[0, -1]) + auto.shard_tensor( + input, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) else: - auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1]) + auto.shard_tensor( + input, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, -1] + }) mlp = MLPLayer( hidden_size=hidden_size, @@ -108,8 +142,6 @@ def mlp_forward(train_program, start_program): def get_dist_prog(train_program, startup_program, dist_context, rank_id): - global 
_global_process_mesh - dist_context.set_process_mesh(_global_process_mesh) loss, train_program, startup_program = mlp_forward(train_program, startup_program) @@ -136,22 +168,21 @@ def check_backward_dist_attr(dist_context, dist_main_prog, op_need_check): has_dist_attr = True vars = dist_main_prog.global_block().vars - op_dist_attr = dist_context.get_op_distributed_attr_for_program( - op_need_check) - if not op_dist_attr or not op_dist_attr.get_process_mesh(): + op_dist_attr = dist_context.get_op_dist_attr_for_program(op_need_check) + if not op_dist_attr or not op_dist_attr.process_mesh: has_dist_attr = False for var_name in op_need_check.input_arg_names: if not op_dist_attr.get_input_dims_mapping(var_name) or \ - not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_dims_mapping() or \ - not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_process_mesh(): + not dist_context.get_tensor_dist_attr_for_program(vars[var_name]).dims_mapping or \ + not dist_context.get_tensor_dist_attr_for_program(vars[var_name]).process_mesh: has_dist_attr = False break if has_dist_attr: for var_name in op_need_check.output_arg_names: - if not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_dims_mapping() or \ - not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_process_mesh(): + if not dist_context.get_tensor_dist_attr_for_program(vars[var_name]).dims_mapping or \ + not dist_context.get_tensor_dist_attr_for_program(vars[var_name]).process_mesh: has_dist_attr = False break @@ -162,6 +193,7 @@ def check_send_recv_result(dist_main_prog, rank_id): send_result = False recv_result = False ops = dist_main_prog.global_block().ops + if rank_id == 0: for idx, op in enumerate(ops): if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names: @@ -217,7 +249,7 @@ def check_initialization_for_dp(dist_startup_prog): class TestMLPReshard(unittest.TestCase): def test_complete_backward_annotation(self): global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) train_program = paddle.static.Program() startup_program = paddle.static.Program() @@ -231,6 +263,7 @@ class TestMLPReshard(unittest.TestCase): if op.type == "gelu_grad": op_need_check = op break + # print_program_with_dist_attr(dist_main_prog, dist_context) # grad op should have dist attr self.assertTrue( @@ -241,11 +274,11 @@ class TestMLPReshard(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "pp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) global PP_MESH_0 - PP_MESH_0 = auto.ProcessMesh(mesh=[0], parent=ROOT_MESH) + PP_MESH_0 = auto.ProcessMesh(mesh=[0]) global PP_MESH_1 - PP_MESH_1 = auto.ProcessMesh(mesh=[1], parent=ROOT_MESH) + PP_MESH_1 = auto.ProcessMesh(mesh=[1]) train_program = paddle.static.Program() startup_program = paddle.static.Program() @@ -253,9 +286,10 @@ class TestMLPReshard(unittest.TestCase): rank_id = 1 dist_main_prog, dist_startup_prog = get_dist_prog( train_program, startup_program, dist_context, rank_id) - for key in list(PROCESS_GROUP_MAP.keys()): - del PROCESS_GROUP_MAP[key] + for key in list(_g_process_group_map.keys()): + del _g_process_group_map[key] reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) + # print_program_with_dist_attr(dist_main_prog, dist_context) # check send and recv 
result self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) @@ -267,7 +301,7 @@ class TestMLPReshard(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) train_program = paddle.static.Program() startup_program = paddle.static.Program() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py index babc622393c..4bd03a3e1bd 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py @@ -22,18 +22,17 @@ import paddle.static as static import paddle.nn.functional as F import paddle.utils as utils import paddle.distributed.auto_parallel as auto -from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed import fleet from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.reshard import reshard +from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr paddle.enable_static() _global_parallel_strategy = "dp_mp_pp" -ROOT_MESH = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]]) -_global_process_mesh = auto.ProcessMesh( - [[[0, 1], [4, 5]], [[2, 3], [6, 7]]], parent=ROOT_MESH) -PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], parent=ROOT_MESH) -PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], parent=ROOT_MESH) +_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]]) +PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]]) +PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]]) class MLPLayer(nn.Layer): @@ -55,8 +54,14 @@ class MLPLayer(nn.Layer): self.norm = nn.LayerNorm(d_model, epsilon=1e-5) def forward(self, input): - auto.shard_tensor(self.linear0.weight, PP_MESH_0, dim_mapping=[-1, 1]) - auto.shard_tensor(self.linear1.weight, PP_MESH_1, dim_mapping=[1, -1]) + auto.shard_tensor( + self.linear0.weight, + dist_attr={"process_mesh": PP_MESH_0, + "dims_mapping": [-1, 1]}) + auto.shard_tensor( + self.linear1.weight, + dist_attr={"process_mesh": PP_MESH_1, + "dims_mapping": [1, -1]}) out = self.norm(input) out = self.linear0(out) @@ -77,8 +82,14 @@ def mlp_forward(train_program, start_program): label = static.data( name="label", shape=[batch_size, 1], dtype='float32') - auto.shard_tensor(input, PP_MESH_0, dim_mapping=[0, -1]) - auto.shard_tensor(label, PP_MESH_1, dim_mapping=[0, -1]) + auto.shard_tensor( + input, + dist_attr={"process_mesh": PP_MESH_0, + "dims_mapping": [0, -1]}) + auto.shard_tensor( + label, + dist_attr={"process_mesh": PP_MESH_1, + "dims_mapping": [0, -1]}) mlp = MLPLayer( hidden_size=hidden_size, @@ -94,7 +105,7 @@ def mlp_forward(train_program, start_program): def get_dist_prog(train_program, startup_program, dist_context, rank_id): global _global_process_mesh - dist_context.set_process_mesh(_global_process_mesh) + dist_context.process_mesh = _global_process_mesh loss, train_program, startup_program = mlp_forward(train_program, startup_program) @@ -156,10 +167,8 @@ class TestMLPReshard(unittest.TestCase): rank_id = 2 dist_main_prog, dist_startup_prog = get_dist_prog( train_program, startup_program, dist_context, rank_id) - print(dist_main_prog) reshard(dist_main_prog, dist_startup_prog, 
rank_id, dist_context) - print(dist_main_prog) - print(dist_startup_prog) + # print_program_with_dist_attr(dist_main_prog, dist_context) # check send and recv result self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py index 96a8b2a8d7c..ae79712dc79 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py @@ -22,17 +22,17 @@ import paddle.static as static import paddle.nn.functional as F import paddle.utils as utils import paddle.distributed.auto_parallel as auto -from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed import fleet from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.reshard import reshard +from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr paddle.enable_static() _global_parallel_strategy = "mp_pp" -ROOT_MESH = auto.ProcessMesh([[0, 1], [2, 3]]) -_global_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]], parent=ROOT_MESH) -PP_MESH_0 = auto.ProcessMesh([0, 1], parent=ROOT_MESH) -PP_MESH_1 = auto.ProcessMesh([2, 3], parent=ROOT_MESH) +_global_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]]) +PP_MESH_0 = auto.ProcessMesh([0, 1]) +PP_MESH_1 = auto.ProcessMesh([2, 3]) class MLPLayer(nn.Layer): @@ -64,10 +64,21 @@ class MLPLayer(nn.Layer): def forward(self, input): auto.shard_tensor( - self.word_embeddings.weight, PP_MESH_0, dim_mapping=[0, -1]) - auto.shard_tensor(self.linear0.weight, PP_MESH_0, dim_mapping=[-1, 0]) - auto.shard_tensor(self.linear1.weight, PP_MESH_1, dim_mapping=[0, -1]) - auto.shard_tensor(self.linear2.weight, PP_MESH_1, dim_mapping=[0, -1]) + self.word_embeddings.weight, + dist_attr={"process_mesh": PP_MESH_0, + "dims_mapping": [0, -1]}) + auto.shard_tensor( + self.linear0.weight, + dist_attr={"process_mesh": PP_MESH_0, + "dims_mapping": [-1, 0]}) + auto.shard_tensor( + self.linear1.weight, + dist_attr={"process_mesh": PP_MESH_1, + "dims_mapping": [0, -1]}) + auto.shard_tensor( + self.linear2.weight, + dist_attr={"process_mesh": PP_MESH_1, + "dims_mapping": [0, -1]}) w_out = self.word_embeddings(input) out = self.linear0(w_out) gelu_out = F.gelu(out, approximate=True) @@ -88,8 +99,13 @@ def mlp_forward(train_program, start_program): label = static.data( name="label", shape=[batch_size, 1], dtype='float32') - auto.shard_tensor(input, PP_MESH_0, dim_mapping=[-1]) - auto.shard_tensor(label, PP_MESH_1, dim_mapping=[-1, -1]) + auto.shard_tensor( + input, dist_attr={"process_mesh": PP_MESH_0, + "dims_mapping": [-1]}) + auto.shard_tensor( + label, + dist_attr={"process_mesh": PP_MESH_1, + "dims_mapping": [-1, -1]}) mlp = MLPLayer( hidden_size=hidden_size, @@ -105,7 +121,7 @@ def mlp_forward(train_program, start_program): def get_dist_prog(train_program, startup_program, dist_context, rank_id): global _global_process_mesh - dist_context.set_process_mesh(_global_process_mesh) + dist_context.process_mesh = _global_process_mesh loss, train_program, startup_program = mlp_forward(train_program, startup_program) @@ -198,19 +214,41 @@ class TestMLPReshard(unittest.TestCase): def test_allgather(self): train_program = paddle.static.Program() startup_program = paddle.static.Program() - process_mesh = 
auto.ProcessMesh(mesh=[0, 3], parent=ROOT_MESH) + process_mesh = auto.ProcessMesh(mesh=[0, 3]) with static.program_guard(train_program, startup_program): x = paddle.static.data(name="x", shape=[4, 4], dtype='float32') - x = auto.shard_tensor(x, process_mesh, dim_mapping=[0, -1]) + x = auto.shard_tensor( + x, + dist_attr={ + "process_mesh": process_mesh, + "dims_mapping": [0, -1] + }) w = paddle.static.data(name="w", shape=[4, 4], dtype='float32') - w = auto.shard_tensor(w, process_mesh, dim_mapping=[-1, -1]) - - y = paddle.distributed.shard_op(paddle.matmul, process_mesh, { - x.name: [-1, -1], - w.name: [-1, -1] - }, **{"x": x, - "y": w})[0] + w = auto.shard_tensor( + w, + dist_attr={ + "process_mesh": process_mesh, + "dims_mapping": [-1, -1] + }) + + # y = paddle.distributed.shard_op(paddle.matmul, process_mesh, { + # x.name: [-1, -1], + # w.name: [-1, -1] + # }, **{"x": x, + # "y": w})[0] + + y = paddle.distributed.shard_op( + paddle.matmul, + dist_attr={ + "process_mesh": process_mesh, + x: { + "dims_mapping": [-1, -1] + }, + w: { + "dims_mapping": [-1, -1] + } + })(x, w)[0] rank_id = 0 dist_context = DistributedContext() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py index bf2ba9f061f..90dd0111dff 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py @@ -26,16 +26,15 @@ import paddle.static as static import paddle.nn.functional as F import paddle.utils as utils import paddle.distributed.auto_parallel as auto -from paddle.distributed.auto_parallel.context import get_default_distributed_context +from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context from paddle.distributed import fleet from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.reshard import reshard -from paddle.distributed.auto_parallel.process import new_process_group +from paddle.distributed.auto_parallel.process_group import new_process_group paddle.enable_static() _global_parallel_strategy = None _global_process_mesh = None -ROOT_MESH = auto.ProcessMesh([0]) class MLPLayer(nn.Layer): @@ -59,16 +58,30 @@ class MLPLayer(nn.Layer): def forward(self, input): if _global_parallel_strategy == "pp": auto.shard_tensor( - self.linear0.weight, PP_MESH_0, dim_mapping=[-1, -1]) + self.linear0.weight, + dist_attr={ + "process_mesh": PP_MESH_0, + "dims_mapping": [-1, -1] + }) auto.shard_tensor( - self.linear1.weight, PP_MESH_1, dim_mapping=[-1, -1]) + self.linear1.weight, + dist_attr={ + "process_mesh": PP_MESH_1, + "dims_mapping": [-1, -1] + }) else: auto.shard_tensor( - self.linear0.weight, _global_process_mesh, - dim_mapping=[-1, -1]) + self.linear0.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, -1] + }) auto.shard_tensor( - self.linear1.weight, _global_process_mesh, - dim_mapping=[-1, -1]) + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, -1] + }) out = self.norm(input) out = self.linear0(out) @@ -90,12 +103,32 @@ def mlp_forward(train_program, start_program): name="label", shape=[batch_size, 1], dtype='float32') if _global_parallel_strategy == "pp": - auto.shard_tensor(input, PP_MESH_0, dim_mapping=[-1, -1]) - auto.shard_tensor(label, PP_MESH_1, dim_mapping=[-1, -1]) + auto.shard_tensor( + input, + dist_attr={ + "process_mesh": 
PP_MESH_0, + "dims_mapping": [-1, -1] + }) + auto.shard_tensor( + label, + dist_attr={ + "process_mesh": PP_MESH_1, + "dims_mapping": [-1, -1] + }) elif _global_parallel_strategy == "dp": - auto.shard_tensor(input, _global_process_mesh, dim_mapping=[0, -1]) + auto.shard_tensor( + input, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) else: - auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1]) + auto.shard_tensor( + input, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, -1] + }) mlp = MLPLayer( hidden_size=hidden_size, @@ -168,7 +201,7 @@ class TestMLPReshard(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = None global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0], parent=ROOT_MESH) + _global_process_mesh = auto.ProcessMesh(mesh=[0]) train_program = paddle.static.Program() startup_program = paddle.static.Program() -- GitLab
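For readers skimming the hunks above, here is a minimal, standalone sketch of the reworked annotation interface that these tests now exercise. It mirrors the calls added in test_allgather (test_auto_parallel_reshard_mppp.py); the two-process mesh, tensor names, and shapes are illustrative only and are not part of the patch itself.

import paddle
import paddle.static as static
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

main_program = static.Program()
startup_program = static.Program()
# Illustrative two-process mesh; ProcessMesh no longer takes a parent argument.
process_mesh = auto.ProcessMesh(mesh=[0, 1])

with static.program_guard(main_program, startup_program):
    x = static.data(name="x", shape=[4, 4], dtype='float32')
    w = static.data(name="w", shape=[4, 4], dtype='float32')

    # Tensor annotation now takes a single dist_attr dict in place of the
    # old positional process_mesh / dim_mapping arguments.
    x = auto.shard_tensor(
        x,
        dist_attr={"process_mesh": process_mesh,
                   "dims_mapping": [0, -1]})
    w = auto.shard_tensor(
        w,
        dist_attr={"process_mesh": process_mesh,
                   "dims_mapping": [-1, -1]})

    # Operator annotation follows the same dist_attr convention: the keys are
    # the annotated tensors, and the call returns a callable that runs the
    # wrapped op.
    y = paddle.distributed.shard_op(
        paddle.matmul,
        dist_attr={
            "process_mesh": process_mesh,
            x: {"dims_mapping": [-1, -1]},
            w: {"dims_mapping": [-1, -1]},
        })(x, w)[0]

The updated tests also set the default mesh through the dist_context.process_mesh property rather than the removed set_process_mesh() call, so the forward functions themselves only needed the argument-style change shown here.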