From 1c0afa7922048f664f710138410637293a6f3d3d Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Mon, 26 Dec 2022 11:27:35 +0800 Subject: [PATCH] [Auto Parallel] Merge the python and c++ impls of ProcessMesh (#47503) * [Auto Parallel] Rename methods of ProcessMesh * [Auto Parallel] Impl the python process_mesh by the c++ one * [Auto Parallel] Add some minor modifications * [Auto Parallel] Rename some methods * [Auto Parallel] Remove unnecessary codes * [Auto Parallel] Add back some removed files * [Auto Parallel] Fix bugs * [Auto Parallel] Fix a bug * Update process_mesh.cc * [Auto Parallel] Fix a bug --- paddle/fluid/pybind/auto_parallel_py.cc | 12 +- .../distributed/auto_parallel/completion.py | 51 ++------ .../auto_parallel/cost/base_cost.py | 24 ++-- .../auto_parallel/cost/estimate_cost.py | 22 ++-- .../auto_parallel/dist_attribute.py | 37 ++---- .../distributed/auto_parallel/dist_context.py | 8 +- .../distributed/auto_parallel/dist_op.py | 8 +- .../distributed/auto_parallel/dist_tensor.py | 16 +-- .../distributed/auto_parallel/helper.py | 4 +- .../auto_parallel/operators/common.py | 12 +- .../dist_check_finite_and_unscale.py | 4 +- .../auto_parallel/operators/dist_default.py | 16 +-- .../auto_parallel/operators/dist_eltwise.py | 8 +- .../auto_parallel/operators/dist_embedding.py | 22 ++-- .../dist_fill_constant_batch_size_like.py | 4 +- .../operators/dist_fused_attention.py | 12 +- .../operators/dist_fused_feedforward.py | 12 +- .../auto_parallel/operators/dist_matmul.py | 100 +++++++-------- .../auto_parallel/operators/dist_pnorm.py | 10 +- .../auto_parallel/operators/dist_reshape.py | 30 ++--- .../auto_parallel/operators/dist_softmax.py | 6 +- .../auto_parallel/operators/dist_transpose.py | 6 +- .../operators/dist_update_loss_scaling.py | 4 +- .../distributed/auto_parallel/parallelizer.py | 4 +- .../distributed/auto_parallel/partitioner.py | 2 +- .../distributed/auto_parallel/planner.py | 15 ++- .../distributed/auto_parallel/process_mesh.py | 118 +++++++++++------- .../distributed/auto_parallel/reshard.py | 58 +++++---- .../auto_parallel/tuner/optimization_tuner.py | 2 +- .../auto_parallel/tuner/parallel_tuner.py | 12 +- .../paddle/distributed/auto_parallel/utils.py | 26 ++-- .../passes/auto_parallel_grad_clip.py | 6 +- .../passes/auto_parallel_sharding.py | 6 +- .../unittests/auto_parallel/test_base_cost.py | 2 +- .../test_conditional_block_reshard.py | 4 +- .../auto_parallel/test_dist_op_cost.py | 8 +- .../auto_parallel/test_process_mesh.py | 69 +++++++++- .../auto_parallel/test_process_mesh_v2.py | 1 - .../test_auto_parallel_partitioner.py | 4 +- .../test_auto_parallel_partitioner_gpt.py | 4 +- 40 files changed, 415 insertions(+), 354 deletions(-) diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index 089f5da5abc..a24001819fe 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -69,6 +69,12 @@ void BindAutoParallel(py::module *m) { .def("contains", &ProcessMesh::contains) .def(py::self == py::self) .def(py::self != py::self) + .def("__copy__", + [](const ProcessMesh &self) { return ProcessMesh(self); }) + .def( + "__deepcopy__", + [](const ProcessMesh &self, py::dict) { return ProcessMesh(self); }, + py::arg("memo")) .def("__str__", &ProcessMesh::to_string); py::class_(*m, "DeviceCapability") @@ -131,7 +137,7 @@ void BindAutoParallel(py::module *m) { const std::vector &>(), py::arg("name"), py::arg("shape"), - py::arg("process_ids"), + py::arg("device_ids"), py::arg("dim_names")) .def_property_readonly("name", &DeviceMesh::name) .def_property_readonly("shape", &DeviceMesh::shape) @@ -165,6 +171,8 @@ void BindAutoParallel(py::module *m) { &DeviceMesh::dim_size)) .def(py::self == py::self) .def(py::self != py::self) + .def("__copy__", + [](const TensorDistAttr &self) { return TensorDistAttr(self); }) .def( "__deepcopy__", [](const TensorDistAttr &self, py::dict) { @@ -256,6 +264,8 @@ void BindAutoParallel(py::module *m) { .def("parse_from_string", &OperatorDistAttr::parse_from_string) .def(py::self == py::self) .def(py::self != py::self) + .def("__copy__", + [](const OperatorDistAttr &self) { return OperatorDistAttr(self); }) .def( "__deepcopy__", [](const OperatorDistAttr &self, py::dict) { diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 4f2b2e79874..75533720c95 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -25,7 +25,7 @@ from .dist_attribute import ( from .dist_context import _node_id from .operators import find_compatible_distributed_operator_impls from .process_group import get_world_process_group -from .process_mesh import ProcessMesh +from .process_mesh import ProcessMesh, compute_compatible_process_mesh from .utils import ( __no_shape_var_type__, get_logger, @@ -34,47 +34,12 @@ from .utils import ( ) -def compute_compatible_process_mesh(process_mesh_list): - """Compute the compatible process mesh given a list of process meshes.""" - if not process_mesh_list: - return None - - def _compute_compatible_process_mesh_two(pm1, pm2): - if pm1 is None: - return True, pm2 - if pm2 is None: - return True, pm1 - if pm1 == pm2: - return True, pm1 - if pm1.processes == pm2.processes: - if len(pm1.topology) >= len(pm2.topology): - return True, pm1 - else: - return True, pm2 - process_set1 = set(pm1.processes) - process_set2 = set(pm2.processes) - if process_set1.issubset(process_set2): - return True, pm2 - if process_set2.issubset(process_set1): - return True, pm1 - return False, None - - compatible_result = None - for process_mesh in process_mesh_list: - compatible, compatible_result = _compute_compatible_process_mesh_two( - compatible_result, process_mesh - ) - if not compatible: - return None - return copy.deepcopy(compatible_result) - - def compute_compatible_dim_mapping(dim_mapping_list): """Compute the compatible dim mapping given a list of dim mapping.""" if not dim_mapping_list: return None - def _compute_compatible_dim_mapping_two(dm1, dm2): + def _compute_compatible_dim_mapping_of_two(dm1, dm2): if dm1 == -1: return True, dm2 if dm2 == -1: @@ -85,7 +50,7 @@ def compute_compatible_dim_mapping(dim_mapping_list): compatible_result = -1 for mapping in dim_mapping_list: - compatible, compatible_result = _compute_compatible_dim_mapping_two( + compatible, compatible_result = _compute_compatible_dim_mapping_of_two( compatible_result, mapping ) if not compatible: @@ -122,9 +87,9 @@ def merge_process_mesh_two(pm1, pm2): if pm1 is None and pm2 is None: return None if pm1 is not None: - process_set1 = set(pm1.processes) + process_set1 = set(pm1.process_ids) if pm2 is not None: - process_set2 = set(pm2.processes) + process_set2 = set(pm2.process_ids) merged_process_set = process_set1.union(process_set2) merged_process_mesh = ProcessMesh(list(merged_process_set)) return merged_process_mesh @@ -134,11 +99,9 @@ def _validate_dims_mapping(dims_mapping, process_mesh): if dims_mapping is None: return False for i in range(len(dims_mapping)): - if dims_mapping[i] < -1 or dims_mapping[i] >= len( - process_mesh.topology - ): + if dims_mapping[i] < -1 or dims_mapping[i] >= len(process_mesh.shape): return False - for i in range(len(process_mesh.topology)): + for i in range(len(process_mesh.shape)): if dims_mapping.count(i) > 1: return False return True diff --git a/python/paddle/distributed/auto_parallel/cost/base_cost.py b/python/paddle/distributed/auto_parallel/cost/base_cost.py index 3de1b46453d..cb1f2654a2b 100644 --- a/python/paddle/distributed/auto_parallel/cost/base_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/base_cost.py @@ -80,7 +80,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context): dist_attr = dist_op.dist_attr process_mesh = dist_attr.process_mesh assert process_mesh, "Process mesh must not be None." - processes = process_mesh.processes + processes = process_mesh.process_ids for process in processes: desc = {} desc["op"] = op.type @@ -103,7 +103,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context): global_sizes = var.shape # NOTE: When support uneven partition, the shard_sizes will be got from dist_attr. shard_sizes = None - topology = process_mesh.topology + topology = process_mesh.shape shape = DistributedTensor.get_local_sizes( global_sizes, dims_mapping, @@ -129,7 +129,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context): ) relative_idx = _get_idx_in_axis( processes, - dist_attr.process_mesh.topology, + dist_attr.process_mesh.shape, embedding_row_dim_mapping, process, ) @@ -153,8 +153,8 @@ def build_comp_desc_from_dist_op(dist_op, dist_context): process_mesh = dist_attr.process_mesh global_sizes = var.shape shard_sizes = None - processes = process_mesh.processes - topology = process_mesh.topology + processes = process_mesh.process_ids + topology = process_mesh.shape shape = DistributedTensor.get_local_sizes( global_sizes, dims_mapping, @@ -170,7 +170,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context): # Modify shape attr according to how output are partitioned out_name = var_name_list[0] dims_mapping = dist_attr.get_output_dims_mapping(out_name) - process_mesh_shape = dist_attr.process_mesh.topology + process_mesh_shape = dist_attr.process_mesh.shape shape_list = op.attr("shape") # Modify target shape for idx, axis in enumerate(dims_mapping): @@ -253,7 +253,7 @@ def build_comm_desc_from_dist_op( process_mesh = dist_attr.process_mesh assert process_mesh, "Process mesh must not be None." - processes = process_mesh.processes + processes = process_mesh.process_ids op_descs = {} for process in processes: rank_id = process @@ -295,7 +295,7 @@ def build_comm_desc_from_dist_op( ) global_sizes = var.shape shard_sizes = None - topology = process_mesh.topology + topology = process_mesh.shape shape = DistributedTensor.get_local_sizes( global_sizes, dims_mapping, @@ -311,8 +311,8 @@ def build_comm_desc_from_dist_op( # Get comm group by parallel_axis or the given group_ranks. if parallel_axis is not None: - process_mesh_shape = process_mesh.topology - process_mesh_group = process_mesh.processes + process_mesh_shape = process_mesh.shape + process_mesh_group = process_mesh.process_ids comm_group_ranks = _get_comm_group( process_mesh_group, process_mesh_shape, @@ -384,7 +384,7 @@ def build_dp_costs( dist_attr = dist_op.dist_attr process_mesh = dist_attr.process_mesh - processes = process_mesh.processes + processes = process_mesh.process_ids assert len(var_names) == 1 vars = dist_op.serial_op.block.vars var_name = var_names[0] @@ -443,7 +443,7 @@ def build_dp_costs( ) global_sizes = var.shape shard_sizes = None - topology = process_mesh.topology + topology = process_mesh.shape shape = DistributedTensor.get_local_sizes( global_sizes, dims_mapping, diff --git a/python/paddle/distributed/auto_parallel/cost/estimate_cost.py b/python/paddle/distributed/auto_parallel/cost/estimate_cost.py index 7d6f8d8474c..f5c1172cef5 100644 --- a/python/paddle/distributed/auto_parallel/cost/estimate_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/estimate_cost.py @@ -190,7 +190,7 @@ class CostEstimator: # Calc dist op cost dist_op = dist_context.get_dist_op_for_program(op) op_dist_attr = dist_op.dist_attr - processes = op_dist_attr.process_mesh.processes + processes = op_dist_attr.process_mesh.process_ids container = get_distributed_operator_impl_container( op_dist_attr.impl_type @@ -273,8 +273,8 @@ class CostEstimator: # This estimation will be improved, now reshard and inplace are not considered. # Persist var is not free. def _convert_pm_and_dm_to_str(process_mesh, dims_mapping): - processes = ",".join([str(x) for x in process_mesh.processes]) - topology = ",".join([str(x) for x in process_mesh.topology]) + processes = ",".join([str(x) for x in process_mesh.process_ids]) + topology = ",".join([str(x) for x in process_mesh.shape]) dims_mapping = ",".join([str(x) for x in dims_mapping]) result = processes + topology + dims_mapping return result @@ -318,8 +318,8 @@ class CostEstimator: sizes = DistributedTensor.get_local_sizes( global_sizes, input_dims_mapping, - process_mesh.topology, - process_mesh.processes, + process_mesh.shape, + process_mesh.process_ids, ) var_info[var_name][key]["memory"] = self._calculate_bytes( sizes, dtype @@ -346,8 +346,8 @@ class CostEstimator: sizes = DistributedTensor.get_local_sizes( global_sizes, output_dims_mapping, - process_mesh.topology, - process_mesh.processes, + process_mesh.shape, + process_mesh.process_ids, ) var_info[var_name][key]["memory"] = self._calculate_bytes( sizes, dtype @@ -380,7 +380,7 @@ class CostEstimator: # Not used if var_name + key not in has_used_vars: has_used_vars.add(has_used_var) - for process in process_mesh.processes: + for process in process_mesh.process_ids: if process not in memories: memories[process] = 0 memories[process] += var_info[var_name][key]["memory"] @@ -390,7 +390,7 @@ class CostEstimator: if has_used_var not in can_free_vars: can_free_vars.add(has_used_var) if not var.persistable: - for process in process_mesh.processes: + for process in process_mesh.process_ids: if process not in can_free_memories: can_free_memories[process] = 0 can_free_memories[process] += var_info[ @@ -409,7 +409,7 @@ class CostEstimator: # Not used if var_name + key not in has_used_vars: has_used_vars.add(has_used_var) - for process in process_mesh.processes: + for process in process_mesh.process_ids: if process not in memories: memories[process] = 0 memories[process] += var_info[var_name][key]["memory"] @@ -419,7 +419,7 @@ class CostEstimator: if has_used_var not in can_free_vars: can_free_vars.add(has_used_var) if not var.persistable: - for process in process_mesh.processes: + for process in process_mesh.process_ids: if process not in can_free_memories: can_free_memories[process] = 0 can_free_memories[process] += var_info[ diff --git a/python/paddle/distributed/auto_parallel/dist_attribute.py b/python/paddle/distributed/auto_parallel/dist_attribute.py index 8635818bb34..c464c89c5f7 100644 --- a/python/paddle/distributed/auto_parallel/dist_attribute.py +++ b/python/paddle/distributed/auto_parallel/dist_attribute.py @@ -121,32 +121,17 @@ class TensorDistributedAttribute: if dist_attr is None: return assert isinstance( - dist_attr, (dict, TensorDistributedAttribute) + dist_attr, TensorDistributedAttribute ), "The type of dist_attr must be dict or TensorDistributedAttribute." - if isinstance(dist_attr, dict): - for key, value in dist_attr.items(): - if key in get_tensor_dist_attr_field_keys(): - field_property = TensorDistributedAttribute.__dict__.get( - key, None - ) - if field_property: - field_property.fset(self, value) - else: - assert False, "No setter for {} in args {}.".format( - key, dist_attr - ) - elif isinstance(dist_attr, TensorDistributedAttribute): - for key in get_tensor_dist_attr_field_keys(): - field_property = TensorDistributedAttribute.__dict__.get( - key, None + for key in get_tensor_dist_attr_field_keys(): + field_property = TensorDistributedAttribute.__dict__.get(key, None) + if field_property: + field_property.fset(self, field_property.fget(dist_attr)) + else: + assert False, "No setter for {} in args {}.".format( + key, dist_attr ) - if field_property: - field_property.fset(self, field_property.fget(dist_attr)) - else: - assert False, "No setter for {} in args {}.".format( - key, dist_attr - ) - self._is_annotated = copy.deepcopy(dist_attr._is_annotated) + self._is_annotated = copy.deepcopy(dist_attr._is_annotated) def reset(self, skip_dist_attr_field_names=None): if skip_dist_attr_field_names is None or ( @@ -243,7 +228,9 @@ class OperatorDistributedAttribute: if process_mesh is not None: assert isinstance( process_mesh, (list, ProcessMesh) - ), "The type of process_mesh must be list or ProcessMesh." + ), "The type of process_mesh must be list or ProcessMesh, but receive {}".format( + type(process_mesh) + ) if isinstance(process_mesh, list): process_mesh = ProcessMesh(process_mesh) self._process_mesh = copy.deepcopy(process_mesh) diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index 7ac260af150..934696e00ee 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -870,8 +870,8 @@ class DistributedContext: else: tensor_shape = serial_tensor.shape dims_mapping = dist_attr.dims_mapping - process_mesh_shape = dist_attr.process_mesh.topology - process_mesh_processes = dist_attr.process_mesh.processes + process_mesh_shape = dist_attr.process_mesh.shape + process_mesh_processes = dist_attr.process_mesh.process_ids # If the dimension of tensor is less than the sharding dimension of process mesh, # we just amend the dimension mapping to -1. (Is this really OK?) for i in range(len(tensor_shape)): @@ -887,8 +887,8 @@ class DistributedContext: for dist_op in self._dist_ops_for_program.values(): serial_op = dist_op.serial_op dist_attr = dist_op.dist_attr - process_mesh_shape = dist_attr.process_mesh.topology - process_mesh_processes = dist_attr.process_mesh.processes + process_mesh_shape = dist_attr.process_mesh.shape + process_mesh_processes = dist_attr.process_mesh.process_ids for arg_name in serial_op.input_arg_names: if dist_op.get_serial_input(arg_name) is None: tensor_shape = [] diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py index 484bf45111d..4ad796e0b27 100644 --- a/python/paddle/distributed/auto_parallel/dist_op.py +++ b/python/paddle/distributed/auto_parallel/dist_op.py @@ -159,10 +159,10 @@ class DistributedOperator: return False for i in range(len(dims_mapping)): if dims_mapping[i] < -1 or dims_mapping[i] >= len( - self.dist_attr.process_mesh.topology + self.dist_attr.process_mesh.shape ): return False - for i in range(len(self.dist_attr.process_mesh.topology)): + for i in range(len(self.dist_attr.process_mesh.shape)): if dims_mapping.count(i) > 1: return False if self.dist_attr.process_mesh != input_dist_attr.process_mesh: @@ -179,10 +179,10 @@ class DistributedOperator: return False for i in range(len(dims_mapping)): if dims_mapping[i] < -1 or dims_mapping[i] >= len( - self.dist_attr.process_mesh.topology + self.dist_attr.process_mesh.shape ): return False - for i in range(len(self.dist_attr.process_mesh.topology)): + for i in range(len(self.dist_attr.process_mesh.shape)): if dims_mapping.count(i) > 1: return False if self.dist_attr.process_mesh != output_dist_attr.process_mesh: diff --git a/python/paddle/distributed/auto_parallel/dist_tensor.py b/python/paddle/distributed/auto_parallel/dist_tensor.py index 9c55998857d..4c789ca50e3 100644 --- a/python/paddle/distributed/auto_parallel/dist_tensor.py +++ b/python/paddle/distributed/auto_parallel/dist_tensor.py @@ -225,10 +225,10 @@ class DistributedTensor: if self.dist_attr.dims_mapping[ i ] < -1 or self.dist_attr.dims_mapping[i] >= len( - self.dist_attr.process_mesh.topology + self.dist_attr.process_mesh.shape ): return False - for i in range(len(self.dist_attr.process_mesh.topology)): + for i in range(len(self.dist_attr.process_mesh.shape)): if self.dist_attr.dims_mapping.count(i) > 1: return False return True @@ -239,8 +239,8 @@ class DistributedTensor: global_sizes = self.serial_tensor.shape dims_mapping = self.dist_attr.dims_mapping shard_sizes = self.dist_attr.shard_sizes - processes = self.dist_attr.process_mesh.processes - topology = self.dist_attr.process_mesh.topology + processes = self.dist_attr.process_mesh.process_ids + topology = self.dist_attr.process_mesh.shape local_sizes = DistributedTensor.get_local_sizes( global_sizes, dims_mapping, topology, processes, rank, shard_sizes ) @@ -256,8 +256,8 @@ class DistributedTensor: global_sizes = self.serial_tensor.shape dims_mapping = self.dist_attr.dims_mapping shard_sizes = self.dist_attr.shard_sizes - processes = self.dist_attr.process_mesh.processes - topology = self.dist_attr.process_mesh.topology + processes = self.dist_attr.process_mesh.process_ids + topology = self.dist_attr.process_mesh.shape local_offsets = DistributedTensor.get_local_offsets( global_sizes, dims_mapping, @@ -282,8 +282,8 @@ class DistributedTensor: global_sizes = self.serial_tensor.shape dims_mapping = self.dist_attr.dims_mapping shard_sizes = self.dist_attr.shard_sizes - processes = self.dist_attr.process_mesh.processes - topology = self.dist_attr.process_mesh.topology + processes = self.dist_attr.process_mesh.process_ids + topology = self.dist_attr.process_mesh.shape local_shard = DistributedTensor.get_local_shard( global_sizes, dims_mapping, diff --git a/python/paddle/distributed/auto_parallel/helper.py b/python/paddle/distributed/auto_parallel/helper.py index f156eba1b0e..68741eb1211 100644 --- a/python/paddle/distributed/auto_parallel/helper.py +++ b/python/paddle/distributed/auto_parallel/helper.py @@ -324,8 +324,8 @@ class ProgramHelper: var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var) dist_attr = { "dims_mapping": var_dist_attr.dims_mapping, - "process_shape": var_dist_attr.process_mesh.topology, - "process_group": var_dist_attr.process_mesh.processes, + "process_shape": var_dist_attr.process_mesh.shape, + "process_group": var_dist_attr.process_mesh.process_ids, } # slice param_value with dist_attr # share sliced_param_value with param_tensor in global_scope diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index 7eace81155c..ff1518f9b8f 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -275,7 +275,7 @@ def is_parameter_related(varname, block): def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr): var_shape = block._var_recursive(src_var.name).shape - var_topoloy = src_var_dist_attr.process_mesh.topology + var_topoloy = src_var_dist_attr.process_mesh.shape var_dims_mapping = src_var_dist_attr.dims_mapping complete_shape = [] @@ -287,7 +287,7 @@ def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr): complete_shape.append(new_shape) exact_shape = [] - input_topology = op_input_dist_attr.process_mesh.topology + input_topology = op_input_dist_attr.process_mesh.shape input_dims_mapping = op_input_dist_attr.dims_mapping for idx, shape in enumerate(complete_shape): if input_dims_mapping[idx] == -1: @@ -362,10 +362,10 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank): op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op) process_mesh = op_dist_attr.process_mesh - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape # FIXME Hack for Pipeline Parallelism where the current operator # not belong to the mesh the current rank belong to. - if rank not in process_mesh.processes: + if rank not in process_mesh.process_ids: rank = _get_corresponding_rank(dist_ctx, process_mesh, rank) for var_name in act_grad_names: @@ -376,8 +376,8 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank): if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: group_ranks = _get_comm_group( - process_mesh.processes, - process_mesh.topology, + process_mesh.process_ids, + process_mesh.shape, batch_size_axis, rank, ) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py index 1c2e4890736..6a681be1a37 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py @@ -89,7 +89,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): str(backward_op) ) - assert rank_id in dist_attr.process_mesh.processes + assert rank_id in dist_attr.process_mesh.process_ids assert 'X' in kwargs, "input [{}] is not given".format('X') assert 'Scale' in kwargs, "input [{}] is not given".format('Scale') @@ -120,7 +120,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): rank_id in ctx.get_tensor_dist_attr_for_program( main_block._var_recursive(varname) - ).process_mesh.processes + ).process_mesh.process_ids ): filter_vars.append(varname) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py index 402c4cd74b4..fe607356339 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -124,7 +124,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_op.dist_attr.process_mesh.processes + processes = dist_op.dist_attr.process_mesh.process_ids op_type = dist_op.serial_op.type cost_mapping = build_comp_costs_from_descs( _g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster @@ -141,7 +141,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ) dist_attr = dist_op.dist_attr process_mesh = dist_attr.process_mesh - processes = process_mesh.processes + processes = process_mesh.process_ids backward_op = dist_op.serial_op op_type = backward_op.type cost_mapping = build_comp_costs_from_descs( @@ -157,7 +157,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): varname, main_block ): var_dim_mapping = dist_attr.get_input_dims_mapping(varname) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: need_gradient_allreduce = True @@ -172,7 +172,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): var_dim_mapping = dist_attr.get_input_dims_mapping( varname ) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] parallel_axis = batch_size_axis attrs = {"use_calc_stream": True} @@ -501,19 +501,19 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): dims_mapping = param_dist_attr.dims_mapping # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in process_mesh.processes: + if rank_id not in process_mesh.process_ids: rank_id = _get_corresponding_rank( ctx, process_mesh, rank_id ) # NOTE all not splited axis should be presented in mesh - for axis, size in enumerate(process_mesh.topology): + for axis, size in enumerate(process_mesh.shape): if size <= 1 or axis in dims_mapping: pass else: group_ranks = _get_comm_group( - process_mesh.processes, - process_mesh.topology, + process_mesh.process_ids, + process_mesh.shape, axis, rank_id, ) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py b/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py index e0e1b1213f6..3e0924d143f 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py @@ -67,7 +67,7 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_op.dist_attr.process_mesh.processes + processes = dist_op.dist_attr.process_mesh.process_ids op_type = dist_op.serial_op.type cost_mapping = build_comp_costs_from_descs( _g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster @@ -84,7 +84,7 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl): ) dist_attr = dist_op.dist_attr process_mesh = dist_attr.process_mesh - processes = process_mesh.processes + processes = process_mesh.process_ids backward_op = dist_op.serial_op op_type = backward_op.type cost_mapping = build_comp_costs_from_descs( @@ -100,7 +100,7 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl): varname, main_block ): var_dim_mapping = dist_attr.get_input_dims_mapping(varname) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: need_gradient_allreduce = True @@ -115,7 +115,7 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl): var_dim_mapping = dist_attr.get_input_dims_mapping( varname ) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] parallel_axis = batch_size_axis attrs = {"use_calc_stream": True} diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index 9619d12681a..27751a75d87 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -175,7 +175,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_op.dist_attr.process_mesh.processes + processes = dist_op.dist_attr.process_mesh.process_ids # embedding need start_index cost_mapping = build_comp_costs_from_descs( EmbeddingOpCost, ctx, processes, desc_mapping, cluster @@ -231,7 +231,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ) process_mesh = dist_attr.process_mesh - processes = process_mesh.processes + processes = process_mesh.process_ids comm_op_cost_list = build_comm_costs_from_descs( IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster ) @@ -250,7 +250,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): var_dim_mapping = dist_attr.get_input_dims_mapping( backward_op.input("Ids")[0] ) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: parallel_axis = batch_size_axis @@ -390,8 +390,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ), "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format( embedding_row_dim_mapping ) - process_mesh_shape = op_dist_attr.process_mesh.topology - process_mesh_group = op_dist_attr.process_mesh.processes + process_mesh_shape = op_dist_attr.process_mesh.shape + process_mesh_group = op_dist_attr.process_mesh.process_ids # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in process_mesh_group: @@ -539,13 +539,13 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): dim_mapping = param_dist_attr.dims_mapping # NOTE all not splited axis should be presented in mesh - for axis, size in enumerate(process_mesh.topology): + for axis, size in enumerate(process_mesh.shape): if size <= 1 or axis in dim_mapping: pass else: group_ranks = _get_comm_group( - process_mesh.processes, - process_mesh.topology, + process_mesh.process_ids, + process_mesh.shape, axis, rank_id, ) @@ -579,7 +579,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in dist_attr.process_mesh.processes: + if rank_id not in dist_attr.process_mesh.process_ids: rank_id = _get_corresponding_rank( ctx, dist_attr.process_mesh, rank_id ) @@ -623,8 +623,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ), "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format( embedding_row_dim_mapping ) - process_mesh_shape = dist_attr.process_mesh.topology - process_mesh_group = dist_attr.process_mesh.processes + process_mesh_shape = dist_attr.process_mesh.shape + process_mesh_group = dist_attr.process_mesh.process_ids # A generalized method to caculate embedding offset using cartisian product relative_idx = _get_idx_in_axis( diff --git a/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py b/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py index 014aa0e98d8..7aa3fecfc05 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py @@ -61,7 +61,7 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_op.dist_attr.process_mesh.processes + processes = dist_op.dist_attr.process_mesh.process_ids op_type = dist_op.serial_op.type cost_mapping = build_comp_costs_from_descs( FillConstantBatchSizeLikeOpCost, @@ -139,7 +139,7 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl): # modify shape attr according to how output are partitioned out_name = op.output('Out')[0] dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - process_mesh_shape = op_dist_attr.process_mesh.topology + process_mesh_shape = op_dist_attr.process_mesh.shape shape_list = op.attr("shape") # modify target shape for idx, axis in enumerate(dims_mapping): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_fused_attention.py b/python/paddle/distributed/auto_parallel/operators/dist_fused_attention.py index b4cf4da452e..14e5e9956ee 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_fused_attention.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_fused_attention.py @@ -156,7 +156,7 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl): rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - if rank_id not in op_dist_attr.process_mesh.processes: + if rank_id not in op_dist_attr.process_mesh.process_ids: rank_id = _get_corresponding_rank( ctx, op_dist_attr.process_mesh, rank_id ) @@ -172,8 +172,8 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl): ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( qkv_w_col_dim_mapping ) - process_mesh_shape = op_dist_attr.process_mesh.topology - process_mesh_group = op_dist_attr.process_mesh.processes + process_mesh_shape = op_dist_attr.process_mesh.shape + process_mesh_group = op_dist_attr.process_mesh.process_ids parallel_axis = qkv_w_col_dim_mapping group_ranks = _get_comm_group( @@ -198,7 +198,7 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl): rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - if rank_id not in op_dist_attr.process_mesh.processes: + if rank_id not in op_dist_attr.process_mesh.process_ids: rank_id = _get_corresponding_rank( ctx, op_dist_attr.process_mesh, rank_id ) @@ -211,8 +211,8 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl): ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( out_w_col_dim_mapping ) - process_mesh_shape = op_dist_attr.process_mesh.topology - process_mesh_group = op_dist_attr.process_mesh.processes + process_mesh_shape = op_dist_attr.process_mesh.shape + process_mesh_group = op_dist_attr.process_mesh.process_ids parallel_axis = out_w_col_dim_mapping group_ranks = _get_comm_group( diff --git a/python/paddle/distributed/auto_parallel/operators/dist_fused_feedforward.py b/python/paddle/distributed/auto_parallel/operators/dist_fused_feedforward.py index 2b7f38d4754..84da09acfd0 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_fused_feedforward.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_fused_feedforward.py @@ -148,7 +148,7 @@ class DistributedFusedFeedForwardImpl(DistributedOperatorImpl): rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - if rank_id not in op_dist_attr.process_mesh.processes: + if rank_id not in op_dist_attr.process_mesh.process_ids: rank_id = _get_corresponding_rank( ctx, op_dist_attr.process_mesh, rank_id ) @@ -163,8 +163,8 @@ class DistributedFusedFeedForwardImpl(DistributedOperatorImpl): ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( linear1_weight_col_dim_mapping ) - process_mesh_shape = op_dist_attr.process_mesh.topology - process_mesh_group = op_dist_attr.process_mesh.processes + process_mesh_shape = op_dist_attr.process_mesh.shape + process_mesh_group = op_dist_attr.process_mesh.process_ids parallel_axis = linear1_weight_col_dim_mapping group_ranks = _get_comm_group( @@ -190,7 +190,7 @@ class DistributedFusedFeedForwardImpl(DistributedOperatorImpl): rank_id = dist_op_context.rank_id op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - if rank_id not in op_dist_attr.process_mesh.processes: + if rank_id not in op_dist_attr.process_mesh.process_ids: rank_id = _get_corresponding_rank( ctx, op_dist_attr.process_mesh, rank_id ) @@ -205,8 +205,8 @@ class DistributedFusedFeedForwardImpl(DistributedOperatorImpl): ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( linear2_weight_col_dim_mapping ) - process_mesh_shape = op_dist_attr.process_mesh.topology - process_mesh_group = op_dist_attr.process_mesh.processes + process_mesh_shape = op_dist_attr.process_mesh.shape + process_mesh_group = op_dist_attr.process_mesh.process_ids parallel_axis = linear2_weight_col_dim_mapping group_ranks = _get_comm_group( diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index 9249837d513..e3d17b26b68 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -299,7 +299,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ), "backward op [{}] don't have dist attribute !".format(str(backward_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in dist_attr.process_mesh.processes: + if rank_id not in dist_attr.process_mesh.process_ids: rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id) assert 'Y' in kwargs, "input [{}] is not given".format('Y') @@ -341,8 +341,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): X_var_dims_mapping = dist_attr.get_input_dims_mapping(X_var.name) Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name) - process_mesh_shape = dist_attr.process_mesh.topology - process_mesh_group = dist_attr.process_mesh.processes + process_mesh_shape = dist_attr.process_mesh.shape + process_mesh_group = dist_attr.process_mesh.process_ids trans_x = None trans_y = None @@ -532,12 +532,12 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id): process_mesh = param_dist_attr.process_mesh dim_mapping = param_dist_attr.dims_mapping - for axis, size in enumerate(process_mesh.topology): + for axis, size in enumerate(process_mesh.shape): if size <= 1 or axis in dim_mapping: pass else: group_ranks = _get_comm_group( - process_mesh.processes, process_mesh.topology, axis, rank_id + process_mesh.process_ids, process_mesh.shape, axis, rank_id ) sync_group = new_process_group(group_ranks) @@ -600,7 +600,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): dist_op=dist_op, dist_context=ctx ) process_mesh = dist_attr.process_mesh - processes = process_mesh.processes + processes = process_mesh.process_ids cost_mapping = build_comp_costs_from_descs( MatmulGradOpCost, ctx, processes, desc_mapping, cluster ) @@ -631,7 +631,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): var_dim_mapping = dist_attr.get_input_dims_mapping( backward_op.input("X")[0] ) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if ( batch_size_axis > -1 @@ -651,7 +651,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_op.dist_attr.process_mesh.processes + processes = dist_op.dist_attr.process_mesh.process_ids cost_mapping = build_comp_costs_from_descs( MatmulOpCost, ctx, processes, desc_mapping, cluster ) @@ -749,7 +749,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ), "backward op [{}] don't have dist attribute !".format(str(src_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in op_dist_attr.process_mesh.processes: + if rank_id not in op_dist_attr.process_mesh.process_ids: rank_id = _get_corresponding_rank( ctx, op_dist_attr.process_mesh, rank_id ) @@ -791,8 +791,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( matmul_col_dim_mapping ) - process_mesh_shape = op_dist_attr.process_mesh.topology - process_mesh_group = op_dist_attr.process_mesh.processes + process_mesh_shape = op_dist_attr.process_mesh.shape + process_mesh_group = op_dist_attr.process_mesh.process_ids parallel_axis = matmul_col_dim_mapping group_ranks = _get_comm_group( @@ -986,7 +986,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): parallel_axis=parallel_axis, ) process_mesh = dist_attr.process_mesh - processes = process_mesh.processes + processes = process_mesh.process_ids comm_op_cost_list = build_comm_costs_from_descs( IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster ) @@ -1005,7 +1005,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): var_dim_mapping = dist_attr.get_input_dims_mapping( backward_op.input("X")[0] ) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if ( batch_size_axis > -1 @@ -1025,7 +1025,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_op.dist_attr.process_mesh.processes + processes = dist_op.dist_attr.process_mesh.process_ids cost_mapping = build_comp_costs_from_descs( MatmulOpCost, ctx, processes, desc_mapping, cluster ) @@ -1131,7 +1131,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ), "backward op [{}] don't have dist attribute !".format(str(src_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in op_dist_attr.process_mesh.processes: + if rank_id not in op_dist_attr.process_mesh.process_ids: rank_id = _get_corresponding_rank( ctx, op_dist_attr.process_mesh, rank_id ) @@ -1173,8 +1173,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( matmul_row_dim_mapping ) - process_mesh_shape = op_dist_attr.process_mesh.topology - process_mesh_group = op_dist_attr.process_mesh.processes + process_mesh_shape = op_dist_attr.process_mesh.shape + process_mesh_group = op_dist_attr.process_mesh.process_ids parallel_axis = matmul_row_dim_mapping group_ranks = _get_comm_group( @@ -1329,7 +1329,7 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): dist_op=dist_op, dist_context=ctx ) process_mesh = dist_attr.process_mesh - processes = process_mesh.processes + processes = process_mesh.process_ids cost_mapping = build_comp_costs_from_descs( MatmulGradOpCost, ctx, processes, desc_mapping, cluster ) @@ -1339,7 +1339,7 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): var_dim_mapping = dist_attr.get_input_dims_mapping( backward_op.input("X")[0] ) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if ( batch_size_axis > -1 @@ -1360,7 +1360,7 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_op.dist_attr.process_mesh.processes + processes = dist_op.dist_attr.process_mesh.process_ids cost_mapping = build_comp_costs_from_descs( MatmulOpCost, ctx, processes, desc_mapping, cluster ) @@ -1479,7 +1479,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): backward_op.input("Y")[0] ) process_mesh = dist_attr.process_mesh - processes = process_mesh.processes + processes = process_mesh.process_ids # col parallel: matmul + allreduce if backward_op.attr("trans_y"): Y_var_dim_mapping.reverse() @@ -1526,7 +1526,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): var_dim_mapping = dist_attr.get_input_dims_mapping( backward_op.input("X")[0] ) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if ( batch_size_axis > -1 @@ -1547,7 +1547,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): comp_desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_op.dist_attr.process_mesh.processes + processes = dist_op.dist_attr.process_mesh.process_ids comp_cost_mapping = build_comp_costs_from_descs( MatmulV2OpCost, ctx, processes, comp_desc_mapping, cluster ) @@ -1645,7 +1645,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ), "backward op [{}] don't have dist attribute !".format(str(src_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in op_dist_attr.process_mesh.processes: + if rank_id not in op_dist_attr.process_mesh.process_ids: rank_id = _get_corresponding_rank( ctx, op_dist_attr.process_mesh, rank_id ) @@ -1687,8 +1687,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( matmul_col_dim_mapping ) - process_mesh_shape = op_dist_attr.process_mesh.topology - process_mesh_group = op_dist_attr.process_mesh.processes + process_mesh_shape = op_dist_attr.process_mesh.shape + process_mesh_group = op_dist_attr.process_mesh.process_ids parallel_axis = matmul_col_dim_mapping group_ranks = _get_comm_group( @@ -1869,7 +1869,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): parallel_axis = Y_var_dim_mapping[0] process_mesh = dist_attr.process_mesh - processes = process_mesh.processes + processes = process_mesh.process_ids # calc comm op cost var_names = [backward_op.input("Out@GRAD")[0]] attrs = {"use_calc_stream": True, "use_model_parallel": True} @@ -1900,7 +1900,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): var_dim_mapping = dist_attr.get_input_dims_mapping( backward_op.input("X")[0] ) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if ( batch_size_axis > -1 @@ -1920,7 +1920,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_op.dist_attr.process_mesh.processes + processes = dist_op.dist_attr.process_mesh.process_ids cost_mapping = build_comp_costs_from_descs( MatmulV2OpCost, ctx, processes, desc_mapping, cluster ) @@ -2025,7 +2025,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ), "backward op [{}] don't have dist attribute !".format(str(src_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in op_dist_attr.process_mesh.processes: + if rank_id not in op_dist_attr.process_mesh.process_ids: rank_id = _get_corresponding_rank( ctx, op_dist_attr.process_mesh, rank_id ) @@ -2067,8 +2067,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( matmul_row_dim_mapping ) - process_mesh_shape = op_dist_attr.process_mesh.topology - process_mesh_group = op_dist_attr.process_mesh.processes + process_mesh_shape = op_dist_attr.process_mesh.shape + process_mesh_group = op_dist_attr.process_mesh.process_ids parallel_axis = matmul_row_dim_mapping group_ranks = _get_comm_group( @@ -2222,7 +2222,7 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = process_mesh.processes + processes = process_mesh.process_ids cost_mapping = build_comp_costs_from_descs( MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster ) @@ -2232,7 +2232,7 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): var_dim_mapping = dist_attr.get_input_dims_mapping( backward_op.input("X")[0] ) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if ( batch_size_axis > -1 @@ -2253,7 +2253,7 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_op.dist_attr.process_mesh.processes + processes = dist_op.dist_attr.process_mesh.process_ids cost_mapping = build_comp_costs_from_descs( MatmulV2OpCost, ctx, processes, desc_mapping, cluster ) @@ -2386,7 +2386,7 @@ class DistributedMulImpl0(DistributedOperatorImpl): dist_op=dist_op, dist_context=ctx ) process_mesh = dist_attr.process_mesh - processes = process_mesh.processes + processes = process_mesh.process_ids cost_mapping = build_comp_costs_from_descs( MulGradOpCost, ctx, processes, desc_mapping, cluster ) @@ -2417,7 +2417,7 @@ class DistributedMulImpl0(DistributedOperatorImpl): var_dim_mapping = dist_attr.get_input_dims_mapping( backward_op.input("X")[0] ) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if ( batch_size_axis > -1 @@ -2437,7 +2437,7 @@ class DistributedMulImpl0(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_op.dist_attr.process_mesh.processes + processes = dist_op.dist_attr.process_mesh.process_ids cost_mapping = build_comp_costs_from_descs( MulOpCost, ctx, processes, desc_mapping, cluster ) @@ -2530,7 +2530,7 @@ class DistributedMulImpl0(DistributedOperatorImpl): ), "backward op [{}] don't have dist attribute !".format(str(src_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in op_dist_attr.process_mesh.processes: + if rank_id not in op_dist_attr.process_mesh.process_ids: rank_id = _get_corresponding_rank( ctx, op_dist_attr.process_mesh, rank_id ) @@ -2566,8 +2566,8 @@ class DistributedMulImpl0(DistributedOperatorImpl): ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( matmul_col_dim_mapping ) - process_mesh_shape = op_dist_attr.process_mesh.topology - process_mesh_group = op_dist_attr.process_mesh.processes + process_mesh_shape = op_dist_attr.process_mesh.shape + process_mesh_group = op_dist_attr.process_mesh.process_ids parallel_axis = matmul_col_dim_mapping group_ranks = _get_comm_group( @@ -2778,7 +2778,7 @@ class DistributedMulImpl1(DistributedOperatorImpl): attrs=attrs, parallel_axis=parallel_axis, ) - processes = process_mesh.processes + processes = process_mesh.process_ids comm_op_cost_list = build_comm_costs_from_descs( IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster ) @@ -2797,7 +2797,7 @@ class DistributedMulImpl1(DistributedOperatorImpl): var_dim_mapping = dist_attr.get_input_dims_mapping( backward_op.input("X")[0] ) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if ( batch_size_axis > -1 @@ -2817,7 +2817,7 @@ class DistributedMulImpl1(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_op.dist_attr.process_mesh.processes + processes = dist_op.dist_attr.process_mesh.process_ids cost_mapping = build_comp_costs_from_descs( MulOpCost, ctx, processes, desc_mapping, cluster ) @@ -2919,7 +2919,7 @@ class DistributedMulImpl1(DistributedOperatorImpl): ), "backward op [{}] don't have dist attribute !".format(str(src_op)) # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism - if rank_id not in op_dist_attr.process_mesh.processes: + if rank_id not in op_dist_attr.process_mesh.process_ids: rank_id = _get_corresponding_rank( ctx, op_dist_attr.process_mesh, rank_id ) @@ -2955,8 +2955,8 @@ class DistributedMulImpl1(DistributedOperatorImpl): ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( matmul_row_dim_mapping ) - process_mesh_shape = op_dist_attr.process_mesh.topology - process_mesh_group = op_dist_attr.process_mesh.processes + process_mesh_shape = op_dist_attr.process_mesh.shape + process_mesh_group = op_dist_attr.process_mesh.process_ids parallel_axis = matmul_row_dim_mapping group_ranks = _get_comm_group( @@ -3131,7 +3131,7 @@ class DistributedMulImpl2(DistributedOperatorImpl): dist_op=dist_op, dist_context=ctx ) process_mesh = dist_attr.process_mesh - processes = process_mesh.processes + processes = process_mesh.process_ids cost_mapping = build_comp_costs_from_descs( MulGradOpCost, ctx, processes, desc_mapping, cluster ) @@ -3141,7 +3141,7 @@ class DistributedMulImpl2(DistributedOperatorImpl): var_dim_mapping = dist_attr.get_input_dims_mapping( backward_op.input("X")[0] ) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if ( batch_size_axis > -1 @@ -3162,7 +3162,7 @@ class DistributedMulImpl2(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_op.dist_attr.process_mesh.processes + processes = dist_op.dist_attr.process_mesh.process_ids cost_mapping = build_comp_costs_from_descs( MulOpCost, ctx, processes, desc_mapping, cluster ) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py b/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py index 766256aaa50..53e2278bd6c 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py @@ -154,7 +154,7 @@ class DistributedPNormImpl(DistributedOperatorImpl): output_name ) - if rank_id not in op_dist_attr.process_mesh.processes: + if rank_id not in op_dist_attr.process_mesh.process_ids: rank_id = _get_corresponding_rank( ctx, op_dist_attr.process_mesh, rank_id ) @@ -164,8 +164,8 @@ class DistributedPNormImpl(DistributedOperatorImpl): for axis in range(len(in_dims_mapping)): if in_dims_mapping[axis] != -1: break - process_mesh_shape = op_dist_attr.process_mesh.topology - process_mesh_group = op_dist_attr.process_mesh.processes + process_mesh_shape = op_dist_attr.process_mesh.shape + process_mesh_group = op_dist_attr.process_mesh.process_ids group_ranks = _get_comm_group( process_mesh_group, process_mesh_shape, axis, rank_id ) @@ -301,8 +301,8 @@ class DistributedPNormImpl(DistributedOperatorImpl): ctx.set_op_dist_attr_for_program(p_norm_grad_op, op_dist_attr) # 2. insert slice op - process_mesh_shape = op_dist_attr.process_mesh.topology - process_mesh_group = op_dist_attr.process_mesh.processes + process_mesh_shape = op_dist_attr.process_mesh.shape + process_mesh_group = op_dist_attr.process_mesh.process_ids dims_mapping = [0] + [-1 for _ in range(len(new_X_grad.shape) - 1)] from ..reshard import Resharder diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py index 43b6f984fa6..d3a344965c4 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py @@ -67,7 +67,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): shape_list = op.desc.attr("shape") # got dist attribute info dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0]) - process_mesh_shape = dist_attr.process_mesh.topology + process_mesh_shape = dist_attr.process_mesh.shape # modify target shape for idx, axis in enumerate(dim_mapping): @@ -81,7 +81,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_attr.process_mesh.processes + processes = dist_attr.process_mesh.process_ids for key in desc_mapping: desc_mapping[key]["shape"] = shape_list @@ -100,7 +100,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ) dist_attr = dist_op.dist_attr process_mesh = dist_attr.process_mesh - processes = process_mesh.processes + processes = process_mesh.process_ids op_type = dist_op.serial_op.type cost_mapping = build_comp_costs_from_descs( @@ -119,7 +119,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op var_dim_mapping = dist_attr.get_input_dims_mapping(varname) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: parallel_axis = batch_size_axis @@ -266,7 +266,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): # got dist attribute info dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) - process_mesh_shape = op_dist_attr.process_mesh.topology + process_mesh_shape = op_dist_attr.process_mesh.shape # modify target shape for idx, axis in enumerate(dim_mapping): @@ -315,7 +315,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): shape_list = op.desc.attr("shape") # got dist attribute info dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0]) - process_mesh_shape = dist_attr.process_mesh.topology + process_mesh_shape = dist_attr.process_mesh.shape # modify target shape for idx, axis in enumerate(dim_mapping): @@ -329,7 +329,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_attr.process_mesh.processes + processes = dist_attr.process_mesh.process_ids for key in desc_mapping: desc_mapping[key]["shape"] = shape_list @@ -348,7 +348,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ) dist_attr = dist_op.dist_attr process_mesh = dist_attr.process_mesh - processes = process_mesh.processes + processes = process_mesh.process_ids op_type = dist_op.serial_op.type cost_mapping = build_comp_costs_from_descs( @@ -367,7 +367,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op var_dim_mapping = dist_attr.get_input_dims_mapping(varname) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: parallel_axis = batch_size_axis @@ -517,7 +517,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): # got dist attribute info dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) - process_mesh_shape = op_dist_attr.process_mesh.topology + process_mesh_shape = op_dist_attr.process_mesh.shape # modify target shape for idx, axis in enumerate(dim_mapping): @@ -566,7 +566,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): shape_list = op.desc.attr("shape") # got dist attribute info dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0]) - process_mesh_shape = dist_attr.process_mesh.topology + process_mesh_shape = dist_attr.process_mesh.shape # modify target shape for idx, axis in enumerate(dim_mapping): @@ -580,7 +580,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_attr.process_mesh.processes + processes = dist_attr.process_mesh.process_ids for key in desc_mapping: desc_mapping[key]["shape"] = shape_list @@ -599,7 +599,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): ) dist_attr = dist_op.dist_attr process_mesh = dist_attr.process_mesh - processes = process_mesh.processes + processes = process_mesh.process_ids op_type = dist_op.serial_op.type cost_mapping = build_comp_costs_from_descs( @@ -618,7 +618,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op var_dim_mapping = dist_attr.get_input_dims_mapping(varname) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: parallel_axis = batch_size_axis @@ -761,7 +761,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): # got dist attribute info out_dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) - process_mesh_shape = op_dist_attr.process_mesh.topology + process_mesh_shape = op_dist_attr.process_mesh.shape # modify target shape for idx, axis in enumerate(out_dim_mapping): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py index 97b21ba9203..4592a05045d 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py @@ -60,7 +60,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_op.dist_attr.process_mesh.processes + processes = dist_op.dist_attr.process_mesh.process_ids cost_mapping = build_comp_costs_from_descs( SoftmaxOpCost, ctx, processes, desc_mapping, cluster ) @@ -76,7 +76,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): ) dist_attr = dist_op.dist_attr process_mesh = dist_attr.process_mesh - processes = process_mesh.processes + processes = process_mesh.process_ids cost_mapping = build_comp_costs_from_descs( SoftmaxGradOpCost, ctx, processes, desc_mapping, cluster ) @@ -93,7 +93,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op var_dim_mapping = dist_attr.get_input_dims_mapping(varname) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: parallel_axis = batch_size_axis diff --git a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py index a4ab19c36cc..b49debc6ad7 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py @@ -140,7 +140,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): desc_mapping = build_comp_desc_from_dist_op( dist_op=dist_op, dist_context=ctx ) - processes = dist_op.dist_attr.process_mesh.processes + processes = dist_op.dist_attr.process_mesh.process_ids op_type = dist_op.serial_op.type cost_mapping = build_comp_costs_from_descs( Transpose2OpCost, ctx, processes, desc_mapping, cluster @@ -157,7 +157,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): ) dist_attr = dist_op.dist_attr process_mesh = dist_attr.process_mesh - processes = process_mesh.processes + processes = process_mesh.process_ids op_type = dist_op.serial_op.type cost_mapping = build_comp_costs_from_descs( Transpose2GradOpCost, ctx, processes, desc_mapping, cluster @@ -175,7 +175,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op var_dim_mapping = dist_attr.get_input_dims_mapping(varname) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape batch_size_axis = var_dim_mapping[0] if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: parallel_axis = batch_size_axis diff --git a/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py b/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py index c4f1794b46f..5877cf37b8d 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py @@ -79,7 +79,7 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl): str(backward_op) ) - assert rank_id in dist_attr.process_mesh.processes + assert rank_id in dist_attr.process_mesh.process_ids assert 'X' in kwargs, "input [{}] is not given".format('X') assert 'FoundInfinite' in kwargs, "input [{}] is not given".format( @@ -154,7 +154,7 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl): rank_id in ctx.get_tensor_dist_attr_for_program( main_block._var_recursive(varname) - ).process_mesh.processes + ).process_mesh.process_ids ): filter_vars.append(varname) diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index 25fc93995f9..0815ed1cd53 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -286,7 +286,7 @@ class AutoParallelizer: _g_process_group_map.clear() _g_process_group_map[0] = ProcessGroup(0, []) for process_mesh in self._dist_context._process_meshes: - _g_process_group_map[0].add_ranks(process_mesh.processes) + _g_process_group_map[0].add_ranks(process_mesh.process_ids) return ( dist_optimize_ops, dist_params_grads, @@ -482,7 +482,7 @@ class AutoParallelizer: if dist_context is not None: pg0 = get_process_group(0) for process_mesh in dist_context._process_meshes: - pg0.add_ranks(process_mesh.processes) + pg0.add_ranks(process_mesh.process_ids) ( dist_optimize_ops, dist_params_grads, diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index 1a2571b2ceb..696d9382681 100644 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -391,7 +391,7 @@ def _get_dist_shape(var, dist_attr): var_shape = var.shape mapping = dist_attr.dims_mapping - mesh = dist_attr.process_mesh.topology + mesh = dist_attr.process_mesh.shape if mapping == []: return var_shape diff --git a/python/paddle/distributed/auto_parallel/planner.py b/python/paddle/distributed/auto_parallel/planner.py index 7ac776bbc52..a264e0294d2 100755 --- a/python/paddle/distributed/auto_parallel/planner.py +++ b/python/paddle/distributed/auto_parallel/planner.py @@ -74,7 +74,7 @@ class PlanFilter: for var_name in op.input_arg_names: dims_mapping = op_dist_attr.get_input_dims_mapping(var_name) if not PlanFilter.check_dims_mapping_for_tensor( - process_mesh.topology, vars[var_name].shape, dims_mapping + process_mesh.shape, vars[var_name].shape, dims_mapping ): return False if vars[var_name].is_data and len(dims_mapping) > 1: @@ -85,7 +85,7 @@ class PlanFilter: for var_name in op.output_arg_names: dims_mapping = op_dist_attr.get_output_dims_mapping(var_name) if not PlanFilter.check_dims_mapping_for_tensor( - process_mesh.topology, vars[var_name].shape, dims_mapping + process_mesh.shape, vars[var_name].shape, dims_mapping ): return False @@ -217,13 +217,13 @@ class PlanSpace: for var_name in chain(op.input_arg_names, op.output_arg_names): visited = [ False - for _ in range(len(list(range(-1, len(process_mesh.topology))))) + for _ in range(len(list(range(-1, len(process_mesh.shape))))) ] depth = 0 path = [] dims_mapping_list = [] PlanSpace._enum_dims_mapping( - process_mesh.topology, + process_mesh.shape, visited, path, depth, @@ -590,7 +590,10 @@ class MCMC(SearchAlgorithm): self.serial_program_info, dist_context, self.parallelizer ) pipeline_config = ( - [process_mesh.processes for process_mesh in pipeline_process_meshes] + [ + process_mesh.process_ids + for process_mesh in pipeline_process_meshes + ] if pipeline_process_meshes is not None else None ) @@ -1060,7 +1063,7 @@ class MCMC(SearchAlgorithm): # rebuild g_process_group pg0 = get_process_group(0) for process_mesh in searched_dist_context._process_meshes: - pg0.add_ranks(process_mesh.processes) + pg0.add_ranks(process_mesh.process_ids) end_time = time.time() print( "End MCMC searching: the min cost is {} and the search time is {}s.".format( diff --git a/python/paddle/distributed/auto_parallel/process_mesh.py b/python/paddle/distributed/auto_parallel/process_mesh.py index 2ccd188dde4..27bb0a79ac2 100644 --- a/python/paddle/distributed/auto_parallel/process_mesh.py +++ b/python/paddle/distributed/auto_parallel/process_mesh.py @@ -17,6 +17,7 @@ import copy import numpy as np import paddle +from paddle.fluid import core # Use to store the previous and current process mesh _g_previous_process_mesh = None @@ -41,12 +42,12 @@ def reset_current_process_mesh(): _g_current_process_mesh = _g_previous_process_mesh -class ProcessMesh: +class ProcessMesh(core.ProcessMesh): """ - The `Processmesh` object describes the topology of the used processes. + The `ProcessMesh` object describes the Cartesian topology of the used processes. Args: - mesh (list|numpy.array): an n-dimensional array describes the toplogy + mesh (list|numpy.array): an n-dimensional array describes the topology of the processes. dim_names (list, optional): the i-th element of this list gives the name of the i-th dimension of the mesh. @@ -58,7 +59,7 @@ class ProcessMesh: mesh = auto.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"]) assert mesh.shape == [2, 3] - assert mesh.processe_ids == [2, 4, 5, 0, 1, 3] + assert mesh.process_ids == [2, 4, 5, 0, 1, 3] """ @@ -77,6 +78,9 @@ class ProcessMesh: if isinstance(mesh, list): mesh = np.array(mesh) + if dim_names is not None and not isinstance(dim_names, list): + raise ValueError('The dim_names must be an instance of list.') + self._mesh = mesh self._shape = list(self._mesh.shape) self._process_ids = self._mesh.flatten().tolist() @@ -104,6 +108,11 @@ class ProcessMesh: self._dim_names ), 'All dim_names {} must be unique.'.format(dim_names) + # Follow the requirement for using pybind11 + core.ProcessMesh.__init__( + self, self._shape, self._process_ids, self._dim_names + ) + # Store all process meshes from .dist_context import get_default_distributed_context @@ -113,35 +122,7 @@ class ProcessMesh: from .process_group import get_process_group pg0 = get_process_group(0) - pg0.add_ranks(self.processes) - - @property - def shape(self): - """ - Get the shape of this ProcessMesh. - """ - return self._shape - - @property - def process_ids(self): - """ - Get the process ids belonging to this ProcessMesh. - """ - return self._process_ids - - @property - def dim_names(self): - """ - Get the dimension names of this ProcessMesh. - """ - return self._dim_names - - @property - def ndim(self): - """ - Get the number of dimension of this ProcessMesh. - """ - return len(self._shape) + pg0.add_ranks(self.process_ids) @property def mesh(self): @@ -150,14 +131,6 @@ class ProcessMesh: """ return self._mesh - @property - def topology(self): - return self._shape - - @property - def processes(self): - return self._process_ids - def __getitem__(self, index): if isinstance(index, tuple): new_dim_names = [] @@ -207,9 +180,8 @@ class ProcessMesh: tensor ) if dist_tensor is None: - dist_tensor = DistributedTensor( - cur_block.vars[name], {"process_mesh": self} - ) + dist_tensor = DistributedTensor(cur_block.vars[name]) + dist_tensor.dist_attr.process_mesh = self dist_tensor.dist_attr.mark_annotated("process_mesh") default_dist_ctx.add_dist_tensor_for_program(dist_tensor) else: @@ -221,7 +193,8 @@ class ProcessMesh: op = cur_block.ops[idx] dist_op = default_dist_ctx.get_dist_op_for_program(op) if dist_op is None: - dist_op = DistributedOperator(op, {"process_mesh": self}) + dist_op = DistributedOperator(op) + dist_op.dist_attr.process_mesh = self dist_op.dist_attr.mark_annotated("process_mesh") default_dist_ctx.add_dist_op_for_program(dist_op) else: @@ -230,6 +203,13 @@ class ProcessMesh: dist_op.dist_attr.mark_annotated("process_mesh") reset_current_process_mesh() + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + new_process_mesh = ProcessMesh(np.array(self.mesh), self.dim_names) + memo[id(self)] = new_process_mesh + return new_process_mesh + def __eq__(self, other): if not isinstance(other, ProcessMesh): return False @@ -245,3 +225,51 @@ class ProcessMesh: self.shape, self.process_ids, self.dim_names ) return str + + +def compute_compatible_process_mesh(process_mesh_list): + """Compute the compatible process mesh given a list of process meshes.""" + if not process_mesh_list: + return None + + def _compute_compatible_process_mesh_of_two(pm1, pm2): + if pm1 is None: + return True, pm2 + if pm2 is None: + return True, pm1 + if pm1 == pm2: + return True, pm1 + if pm1.process_ids == pm2.process_ids: + if len(pm1.shape) >= len(pm2.shape): + return True, pm1 + else: + return True, pm2 + process_set1 = set(pm1.process_ids) + process_set2 = set(pm2.process_ids) + if process_set1.issubset(process_set2): + return True, pm2 + if process_set2.issubset(process_set1): + return True, pm1 + return False, None + + compatible_result = None + for process_mesh in process_mesh_list: + compatible, compatible_result = _compute_compatible_process_mesh_of_two( + compatible_result, process_mesh + ) + if not compatible: + return None + return copy.deepcopy(compatible_result) + + +def merge_process_meshes(process_meshes): + """Merge a list of process meshes.""" + merged_process_mesh = None + merged_process_ids = set() + for process_mesh in process_meshes: + if process_mesh is not None: + process_ids = set(process_mesh.process_ids) + merged_process_ids = merged_process_ids.union(process_ids) + if len(merged_process_ids) != 0: + merged_process_mesh = ProcessMesh(list(merged_process_ids)) + return merged_process_mesh diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index 62e383c72e9..3dda3fb3899 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -829,7 +829,7 @@ class Remover: ) ).process_mesh ) - if rank_id in process_mesh.processes: + if rank_id in process_mesh.process_ids: need_save.append(var_name) if not need_save: remove_op_idx.append(idx) @@ -845,7 +845,7 @@ class Remover: if op_dist_attr is not None: op_process_mesh = op_dist_attr.process_mesh if ( - rank_id not in op_process_mesh.processes + rank_id not in op_process_mesh.process_ids and op.type not in not_remove_op_ref ): remove_op_idx.append(idx) @@ -1408,9 +1408,11 @@ class Resharder: op_process_mesh = dist_op.dist_attr.process_mesh for process_mesh in self.dist_context.process_meshes: - if set(process_mesh.processes) & ( - set(op_process_mesh.processes) - ) and len(process_mesh.processes) < len(op_process_mesh.processes): + if set(process_mesh.process_ids) & ( + set(op_process_mesh.process_ids) + ) and len(process_mesh.process_ids) < len( + op_process_mesh.process_ids + ): process_meshes.append(process_mesh) # it means the process mesh is not a union when process meshes is null @@ -1438,13 +1440,13 @@ class Resharder: source_dims_mapping = tensor_dist_attr.dims_mapping source_process_mesh = tensor_dist_attr.process_mesh - source_process_group = source_process_mesh.processes - source_process_shape = source_process_mesh.topology + source_process_group = source_process_mesh.process_ids + source_process_shape = source_process_mesh.shape target_process_mesh = dist_attr[0] target_dims_mapping = dist_attr[1] - target_process_group = target_process_mesh.processes - target_process_shape = target_process_mesh.topology + target_process_group = target_process_mesh.process_ids + target_process_shape = target_process_mesh.shape if source_tensor.shape[0] < 0: assert source_tensor.shape[0] == -1 @@ -2141,9 +2143,11 @@ class Resharder: dist_attr = dist_op.dist_attr op_process_mesh = dist_attr.process_mesh for process_mesh in self.dist_context.process_meshes: - if set(process_mesh.processes) & ( - set(op_process_mesh.processes) - ) and len(process_mesh.processes) < len(op_process_mesh.processes): + if set(process_mesh.process_ids) & ( + set(op_process_mesh.process_ids) + ) and len(process_mesh.process_ids) < len( + op_process_mesh.process_ids + ): process_meshes.append(process_mesh) # it means that the process mesh is not a union when process meshes is none @@ -2176,12 +2180,12 @@ class Resharder: if process_mesh_count > 1: global_process_mesh_idx = None for process_mesh in self.dist_context.process_meshes: - for process in process_mesh.processes: + for process in process_mesh.process_ids: processes.add(process) for idx, process_mesh in enumerate( self.dist_context.process_meshes ): - if len(set(process_mesh.processes)) == len(processes): + if len(set(process_mesh.process_ids)) == len(processes): global_process_mesh_idx = idx break @@ -2191,7 +2195,7 @@ class Resharder: for i, mesh in enumerate(self.dist_context.process_meshes): if i == idx: continue - if set(mesh.processes) < set(global_mesh.processes): + if set(mesh.process_ids) < set(global_mesh.process_ids): is_removed = True if is_removed: @@ -2330,8 +2334,8 @@ class Resharder: # deal with union tensor if is_union_process_mesh_tensor: # if op process mesh is subset of union tensor process mesh, need no reshard - if set(input_attr[0].processes) <= set( - dist_tensor.dist_attr.process_mesh.processes + if set(input_attr[0].process_ids) <= set( + dist_tensor.dist_attr.process_mesh.process_ids ): continue @@ -2525,14 +2529,14 @@ class Resharder: dist_tensor, output_attr, False ): tensor_processes = set( - tensor_process_mesh.processes + tensor_process_mesh.process_ids ) - ( - set(tensor_process_mesh.processes) - & set(output_attr[0].processes) + set(tensor_process_mesh.process_ids) + & set(output_attr[0].process_ids) ) if tensor_processes: if len(tensor_processes) != len( - output_attr[0].processes + output_attr[0].process_ids ): if dist_tensor.dist_attr.dims_mapping.count( -1 @@ -2555,13 +2559,15 @@ class Resharder: recv_rank = tensor_process actual_index = index if index >= len( - output_attr[0].processes + output_attr[0].process_ids ): actual_index = ( index - - len(output_attr[0].processes) - ) % len(output_attr[0].processes) - item = output_attr[0].processes[ + - len( + output_attr[0].process_ids + ) + ) % len(output_attr[0].process_ids) + item = output_attr[0].process_ids[ actual_index ] if recv_rank == item: @@ -2591,7 +2597,7 @@ class Resharder: tensor_processes ): recv_rank = tensor_process - item = output_attr[0].processes[index] + item = output_attr[0].process_ids[index] if recv_rank == item: continue if self.rank_id == item: diff --git a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py index c3de081c752..e09f55d91ec 100644 --- a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py @@ -114,7 +114,7 @@ def _copy_context(ref_dist_context): clear_all_process_groups() ranks = [] for process_mesh in ref_dist_context._process_meshes: - ranks.extend(process_mesh.processes) + ranks.extend(process_mesh.process_ids) new_process_group(list(set(ranks))) new_dist_context = DistributedContext() diff --git a/python/paddle/distributed/auto_parallel/tuner/parallel_tuner.py b/python/paddle/distributed/auto_parallel/tuner/parallel_tuner.py index e1d8217a99a..f856d7590f7 100644 --- a/python/paddle/distributed/auto_parallel/tuner/parallel_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/parallel_tuner.py @@ -689,15 +689,15 @@ class ParallelTuner: assert process_mesh.ndim == 2 dim_of_one = None dim_of_other = None - if process_mesh.topology[0] == 1: + if process_mesh.shape[0] == 1: dim_of_one = 0 dim_of_other = 1 - elif process_mesh.topology[1] == 1: + elif process_mesh.shape[1] == 1: dim_of_one = 1 dim_of_other = 0 if dim_of_one is not None: - dist_attr.process_mesh = ProcessMesh(process_mesh.processes) + dist_attr.process_mesh = ProcessMesh(process_mesh.process_ids) self._dist_context.add_process_mesh(dist_attr.process_mesh) for arg_name in dist_attr.inputs_dist_attrs.keys(): @@ -715,7 +715,7 @@ class ParallelTuner: dims_mapping = dist_attr.get_input_dims_mapping(arg_name) # dynamic_dims = dist_attr.get_input_dynamic_dims(arg_name) process_mesh = dist_attr.process_mesh - process_shape = process_mesh.topology + process_shape = process_mesh.shape tensor = dist_op.get_serial_input(arg_name) if dims_mapping: tensor_shape = tensor.shape @@ -748,7 +748,7 @@ class ParallelTuner: dims_mapping = dist_attr.get_output_dims_mapping(arg_name) # dynamic_dims = dist_attr.get_output_dynamic_dims(arg_name) process_mesh = dist_attr.process_mesh - process_shape = process_mesh.topology + process_shape = process_mesh.shape tensor = dist_op.get_serial_output(arg_name) if dims_mapping: @@ -793,7 +793,7 @@ class ParallelTuner: input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping( input_name ) - topology = dist_op.dist_attr.process_mesh.topology + topology = dist_op.dist_attr.process_mesh.shape input_tensor = dist_op.get_serial_input(input_name) last_but_one_dim = ( input_tensor.shape[-2] // topology[input_dims_mapping[-2]] diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 4d474569fb3..80721b0c7fb 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -100,7 +100,7 @@ def convert_to_dims_mapping(shard_spec, process_mesh): for shard in shard_spec: if shard is None: dims_mapping.append(-1) - elif process_mesh.topology[process_mesh.dim_names.index(shard)] == 1: + elif process_mesh.shape[process_mesh.dim_names.index(shard)] == 1: dims_mapping.append(-1) else: dims_mapping.append(process_mesh.dim_names.index(shard)) @@ -429,26 +429,26 @@ def _get_corresponding_rank(dist_context, target_mesh, rank): coordinate = None for mesh in dist_context.process_meshes: - if rank in mesh.processes and mesh.topology == target_mesh.topology: + if rank in mesh.process_ids and mesh.shape == target_mesh.shape: coordinate = _linear_idx2coordinate( - mesh.topology, mesh.processes.index(rank) + mesh.shape, mesh.process_ids.index(rank) ) break # assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format( # rank) if coordinate is not None: - return target_mesh.processes[ - _coordinate2linear_idx(mesh.topology, coordinate) + return target_mesh.process_ids[ + _coordinate2linear_idx(mesh.shape, coordinate) ] else: - return target_mesh.processes[0] + return target_mesh.process_ids[0] def _get_unshard_dist_shape(var, dist_attr): var_shape = var.shape mapping = dist_attr.dims_mapping - mesh = dist_attr.process_mesh.topology + mesh = dist_attr.process_mesh.shape assert len(var_shape) == len( mapping ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format( @@ -832,8 +832,8 @@ def get_dist_attr(program, dist_context=None): process_mesh = tensor_dist_attr.process_mesh dims_mapping = tensor_dist_attr.dims_mapping dist_attr[var.name] = { - "process_shape": process_mesh.topology, - "process_group": process_mesh.processes, + "process_shape": process_mesh.shape, + "process_group": process_mesh.process_ids, "dims_mapping": dims_mapping, } return dist_attr @@ -2006,16 +2006,16 @@ def get_input_split_info(cur_rank, var, dist_context): process_mesh = tensor_dist_attr.process_mesh dims_mapping = tensor_dist_attr.dims_mapping - if cur_rank not in process_mesh.processes: + if cur_rank not in process_mesh.process_ids: rank_id = _get_corresponding_rank(dist_context, process_mesh, cur_rank) else: rank_id = cur_rank batch_size_axis = dims_mapping[0] - if batch_size_axis > -1 and process_mesh.topology[batch_size_axis] > 1: + if batch_size_axis > -1 and process_mesh.shape[batch_size_axis] > 1: group_ranks = _get_comm_group( - process_mesh.processes, - process_mesh.topology, + process_mesh.process_ids, + process_mesh.shape, batch_size_axis, rank_id, ) diff --git a/python/paddle/distributed/passes/auto_parallel_grad_clip.py b/python/paddle/distributed/passes/auto_parallel_grad_clip.py index 7258eca661d..d209d13eefd 100644 --- a/python/paddle/distributed/passes/auto_parallel_grad_clip.py +++ b/python/paddle/distributed/passes/auto_parallel_grad_clip.py @@ -164,8 +164,8 @@ class ClipHelper: param = self.params[self.params_name.index(name)] dist_attr = self._get_dist_attr(name) - topology = dist_attr.process_mesh.topology - processes = dist_attr.process_mesh.processes + topology = dist_attr.process_mesh.shape + processes = dist_attr.process_mesh.process_ids dims_mapping = dist_attr.dims_mapping return _is_about_global_norm( self.rank_id, @@ -188,7 +188,7 @@ class ClipHelper: def _is_local_var(self, name): dist_attr = self._get_dist_attr(name) assert dist_attr is not None - return self.rank_id in dist_attr.process_mesh.processes + return self.rank_id in dist_attr.process_mesh.process_ids def _init_dist_attr(self, op): op_dist_attr = OperatorDistributedAttribute() diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index c001a93d789..84992aa903b 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -870,13 +870,13 @@ def _inference_data_parallel_group_for_operator(rank_id, op, dist_context): dist_attr = dist_context.get_op_dist_attr_for_program(op) process_mesh = dist_attr.process_mesh input_dim_mapping = dist_attr.get_input_dims_mapping(input_name) - mesh_shape = process_mesh.topology + mesh_shape = process_mesh.shape # TODO(JZ-LIANG) replace with specific batch size dimension batch_size_axis = input_dim_mapping[0] if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: group_ranks = _get_comm_group( - process_mesh.processes, - process_mesh.topology, + process_mesh.process_ids, + process_mesh.shape, batch_size_axis, rank_id, ) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_base_cost.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_base_cost.py index 6f1b53c3317..b8a0f4467ac 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_base_cost.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_base_cost.py @@ -180,7 +180,7 @@ class TestBaseCost(unittest.TestCase): for op in train_program.global_block().ops: dist_op = dist_context.get_dist_op_for_program(op) if dist_op: - processes = dist_op.dist_attr.process_mesh.processes + processes = dist_op.dist_attr.process_mesh.process_ids comp_descs = build_comp_desc_from_dist_op(dist_op, dist_context) self.assertTrue(isinstance(comp_descs, dict) and comp_descs) var_names = None diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_conditional_block_reshard.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_conditional_block_reshard.py index 85f475848e5..de71bae51d4 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_conditional_block_reshard.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_conditional_block_reshard.py @@ -52,13 +52,13 @@ class MLPLayer(nn.Layer): out = self.norm(input) auto.shard_tensor( - self.linear0.weight, auto.ProcessMesh([0, 1], "x"), [None, "x"] + self.linear0.weight, auto.ProcessMesh([0, 1], ["x"]), [None, "x"] ) out = self.linear0(out) out = F.gelu(out, approximate=True) auto.shard_tensor( - self.linear1.weight, auto.ProcessMesh([0, 1], "x"), ["x", None] + self.linear1.weight, auto.ProcessMesh([0, 1], ["x"]), ["x", None] ) out = self.linear1(out) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py index d01f2792d2a..4eb0408976a 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py @@ -98,7 +98,7 @@ class TestDistOpCost(unittest.TestCase): ): dist_op = dist_context.get_dist_op_for_program(op) op_dist_attr = dist_op.dist_attr - processes = op_dist_attr.process_mesh.processes + processes = op_dist_attr.process_mesh.process_ids if is_elementwise_op(op.type): container = get_distributed_operator_impl_container( "elementwise" @@ -205,7 +205,7 @@ class TestDistOpCost(unittest.TestCase): for idx, op in enumerate(ops): dist_op = dist_context.get_dist_op_for_program(op) op_dist_attr = dist_op.dist_attr - processes = op_dist_attr.process_mesh.processes + processes = op_dist_attr.process_mesh.process_ids if is_elementwise_op(op.type): container = get_distributed_operator_impl_container( "elementwise" @@ -313,7 +313,7 @@ class TestDistOpCost(unittest.TestCase): for idx, op in enumerate(ops): dist_op = dist_context.get_dist_op_for_program(op) op_dist_attr = dist_op.dist_attr - processes = op_dist_attr.process_mesh.processes + processes = op_dist_attr.process_mesh.process_ids if is_elementwise_op(op.type): container = get_distributed_operator_impl_container( "elementwise" @@ -421,7 +421,7 @@ class TestDistOpCost(unittest.TestCase): for idx, op in enumerate(ops): dist_op = dist_context.get_dist_op_for_program(op) op_dist_attr = dist_op.dist_attr - processes = op_dist_attr.process_mesh.processes + processes = op_dist_attr.process_mesh.process_ids if is_elementwise_op(op.type): container = get_distributed_operator_impl_container( "elementwise" diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh.py index 6a154000dd1..30527d247a7 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh.py @@ -23,7 +23,11 @@ import paddle.static as static from paddle.distributed.auto_parallel.dist_context import ( get_default_distributed_context, ) -from paddle.distributed.auto_parallel.process_mesh import ProcessMesh +from paddle.distributed.auto_parallel.process_mesh import ( + ProcessMesh, + compute_compatible_process_mesh, + merge_process_meshes, +) paddle.enable_static() @@ -129,7 +133,7 @@ class TestProcessMesh(unittest.TestCase): initializer_range=0.02, ) - with ProcessMesh(mesh, "d"): + with ProcessMesh(mesh, ["d"]): out = mlp(input) default_program = paddle.fluid.default_main_program() @@ -151,6 +155,67 @@ class TestProcessMesh(unittest.TestCase): dist_op.dist_attr.process_mesh, ProcessMesh(mesh) ) + def test_compute_compatible_process_mesh(self): + process_mesh1 = ProcessMesh( + [[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"] + ) + compatible_process_mesh = compute_compatible_process_mesh( + [process_mesh1, None] + ) + self.assertEqual(compatible_process_mesh, process_mesh1) + compatible_process_mesh = compute_compatible_process_mesh( + [None, process_mesh1] + ) + self.assertEqual(compatible_process_mesh, process_mesh1) + + process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]]) + compatible_process_mesh = compute_compatible_process_mesh( + [process_mesh1, process_mesh2] + ) + self.assertEqual(compatible_process_mesh, process_mesh1) + self.assertEqual(compatible_process_mesh, process_mesh2) + + process_mesh2 = ProcessMesh([[0, 1, 2, 3, 4, 5]]) + compatible_process_mesh = compute_compatible_process_mesh( + [process_mesh1, process_mesh2] + ) + self.assertEqual(compatible_process_mesh, process_mesh1) + + process_mesh2 = ProcessMesh([[0, 1, 2]]) + compatible_process_mesh = compute_compatible_process_mesh( + [process_mesh1, process_mesh2] + ) + self.assertEqual(compatible_process_mesh, process_mesh1) + + def test_merge_process_meshes(self): + process_mesh1 = ProcessMesh( + [[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"] + ) + merged_process_mesh = merge_process_meshes([process_mesh1, None]) + self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5])) + merged_process_mesh = merge_process_meshes([None, process_mesh1]) + self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5])) + + process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]]) + merged_process_mesh = merge_process_meshes( + [process_mesh1, process_mesh2] + ) + self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5])) + + process_mesh2 = ProcessMesh([[0, 1, 2]]) + merged_process_mesh = merge_process_meshes( + [process_mesh1, process_mesh2] + ) + self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5])) + + process_mesh2 = ProcessMesh([[6, 7]]) + merged_process_mesh = merge_process_meshes( + [process_mesh1, process_mesh2] + ) + self.assertEqual( + merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5, 6, 7]) + ) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh_v2.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh_v2.py index 649e2044e56..03ec95c7187 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh_v2.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh_v2.py @@ -80,7 +80,6 @@ class TestProcessMesh(unittest.TestCase): [[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"] ) merged_process_mesh = merge_process_mesh([process_mesh1, None]) - print(merged_process_mesh) self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5])) merged_process_mesh = merge_process_mesh([None, process_mesh1]) self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5])) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py index 8300aaa418c..2de97b63a51 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py @@ -105,7 +105,7 @@ def initialization_check( ): if 'mp' in mode: group_ranks = _get_comm_group( - process_mesh.processes, process_mesh.topology, mp_parallel_axis, 3 + process_mesh.process_ids, process_mesh.shape, mp_parallel_axis, 3 ) mp_ring_id = new_process_group(group_ranks).id broadcast_ops = [ @@ -124,7 +124,7 @@ def initialization_check( if 'dp' in mode: group_ranks = _get_comm_group( - process_mesh.processes, process_mesh.topology, dp_parallel_axis, 3 + process_mesh.process_ids, process_mesh.shape, dp_parallel_axis, 3 ) dp_ring_id = new_process_group(group_ranks).id nparam = len(serial_startup_prog.all_parameters()) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py index a41b79d4eff..f4fa7874882 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py @@ -958,12 +958,12 @@ class TestGPTPartitioner(unittest.TestCase): dp_parallel_axis = 0 group_ranks = _get_comm_group( - process_mesh.processes, process_mesh.topology, mp_parallel_axis, 3 + process_mesh.process_ids, process_mesh.shape, mp_parallel_axis, 3 ) mp_ring_id = new_process_group(group_ranks).id group_ranks = _get_comm_group( - process_mesh.processes, process_mesh.topology, dp_parallel_axis, 3 + process_mesh.process_ids, process_mesh.shape, dp_parallel_axis, 3 ) dp_ring_id = new_process_group(group_ranks).id -- GitLab