Unverified commit 1c0afa79, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Merge the python and c++ impls of ProcessMesh (#47503)

* [Auto Parallel] Rename methods of ProcessMesh

* [Auto Parallel] Impl the python process_mesh by the c++ one

* [Auto Parallel] Add some minor modifications

* [Auto Parallel] Rename some methods

* [Auto Parallel] Remove unnecessary codes

* [Auto Parallel] Add back some removed files

* [Auto Parallel] Fix bugs

* [Auto Parallel] Fix a bug

* Update process_mesh.cc

* [Auto Parallel] Fix a bug
上级 3f614f48
@@ -69,6 +69,12 @@ void BindAutoParallel(py::module *m) {
       .def("contains", &ProcessMesh::contains)
       .def(py::self == py::self)
       .def(py::self != py::self)
+      .def("__copy__",
+           [](const ProcessMesh &self) { return ProcessMesh(self); })
+      .def(
+          "__deepcopy__",
+          [](const ProcessMesh &self, py::dict) { return ProcessMesh(self); },
+          py::arg("memo"))
       .def("__str__", &ProcessMesh::to_string);

   py::class_<DeviceCapability>(*m, "DeviceCapability")
@@ -131,7 +137,7 @@ void BindAutoParallel(py::module *m) {
                     const std::vector<std::string> &>(),
            py::arg("name"),
            py::arg("shape"),
-           py::arg("process_ids"),
+           py::arg("device_ids"),
            py::arg("dim_names"))
       .def_property_readonly("name", &DeviceMesh::name)
       .def_property_readonly("shape", &DeviceMesh::shape)
@@ -165,6 +171,8 @@ void BindAutoParallel(py::module *m) {
                     &DeviceMesh::dim_size))
       .def(py::self == py::self)
       .def(py::self != py::self)
+      .def("__copy__",
+           [](const TensorDistAttr &self) { return TensorDistAttr(self); })
       .def(
           "__deepcopy__",
           [](const TensorDistAttr &self, py::dict) {
@@ -256,6 +264,8 @@ void BindAutoParallel(py::module *m) {
       .def("parse_from_string", &OperatorDistAttr::parse_from_string)
       .def(py::self == py::self)
       .def(py::self != py::self)
+      .def("__copy__",
+           [](const OperatorDistAttr &self) { return OperatorDistAttr(self); })
       .def(
           "__deepcopy__",
           [](const OperatorDistAttr &self, py::dict) {
...
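The new __copy__ and __deepcopy__ bindings above plug the C++-backed objects into Python's copy protocol. A minimal sketch of what this enables, assuming the ProcessMesh wrapper forwards to these bindings (illustrative, not taken from the PR's tests):

    import copy
    from paddle.distributed.auto_parallel.process_mesh import ProcessMesh

    mesh = ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
    # The memo dict passed by copy.deepcopy is accepted but unused by the binding.
    mesh_copy = copy.deepcopy(mesh)
    print(mesh_copy == mesh)  # expected True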
@@ -25,7 +25,7 @@ from .dist_attribute import (
 from .dist_context import _node_id
 from .operators import find_compatible_distributed_operator_impls
 from .process_group import get_world_process_group
-from .process_mesh import ProcessMesh
+from .process_mesh import ProcessMesh, compute_compatible_process_mesh
 from .utils import (
     __no_shape_var_type__,
     get_logger,
@@ -34,47 +34,12 @@ from .utils import (
 )

-def compute_compatible_process_mesh(process_mesh_list):
-    """Compute the compatible process mesh given a list of process meshes."""
-    if not process_mesh_list:
-        return None
-
-    def _compute_compatible_process_mesh_two(pm1, pm2):
-        if pm1 is None:
-            return True, pm2
-        if pm2 is None:
-            return True, pm1
-        if pm1 == pm2:
-            return True, pm1
-        if pm1.processes == pm2.processes:
-            if len(pm1.topology) >= len(pm2.topology):
-                return True, pm1
-            else:
-                return True, pm2
-        process_set1 = set(pm1.processes)
-        process_set2 = set(pm2.processes)
-        if process_set1.issubset(process_set2):
-            return True, pm2
-        if process_set2.issubset(process_set1):
-            return True, pm1
-        return False, None
-
-    compatible_result = None
-    for process_mesh in process_mesh_list:
-        compatible, compatible_result = _compute_compatible_process_mesh_two(
-            compatible_result, process_mesh
-        )
-        if not compatible:
-            return None
-    return copy.deepcopy(compatible_result)
-
 def compute_compatible_dim_mapping(dim_mapping_list):
     """Compute the compatible dim mapping given a list of dim mapping."""
     if not dim_mapping_list:
         return None

-    def _compute_compatible_dim_mapping_two(dm1, dm2):
+    def _compute_compatible_dim_mapping_of_two(dm1, dm2):
         if dm1 == -1:
             return True, dm2
         if dm2 == -1:
@@ -85,7 +50,7 @@ def compute_compatible_dim_mapping(dim_mapping_list):
     compatible_result = -1
     for mapping in dim_mapping_list:
-        compatible, compatible_result = _compute_compatible_dim_mapping_two(
+        compatible, compatible_result = _compute_compatible_dim_mapping_of_two(
             compatible_result, mapping
         )
         if not compatible:
@@ -122,9 +87,9 @@ def merge_process_mesh_two(pm1, pm2):
     if pm1 is None and pm2 is None:
         return None
     if pm1 is not None:
-        process_set1 = set(pm1.processes)
+        process_set1 = set(pm1.process_ids)
     if pm2 is not None:
-        process_set2 = set(pm2.processes)
+        process_set2 = set(pm2.process_ids)
     merged_process_set = process_set1.union(process_set2)
     merged_process_mesh = ProcessMesh(list(merged_process_set))
     return merged_process_mesh
@@ -134,11 +99,9 @@ def _validate_dims_mapping(dims_mapping, process_mesh):
     if dims_mapping is None:
         return False
     for i in range(len(dims_mapping)):
-        if dims_mapping[i] < -1 or dims_mapping[i] >= len(
-            process_mesh.topology
-        ):
+        if dims_mapping[i] < -1 or dims_mapping[i] >= len(process_mesh.shape):
             return False
-    for i in range(len(process_mesh.topology)):
+    for i in range(len(process_mesh.shape)):
         if dims_mapping.count(i) > 1:
             return False
     return True
...
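The local compute_compatible_process_mesh helper removed above now lives in process_mesh.py and is pulled in through the import change at the top of this file. A hedged usage sketch based on the removed body (the exact return value depends on the new implementation):

    from paddle.distributed.auto_parallel.process_mesh import (
        ProcessMesh,
        compute_compatible_process_mesh,
    )

    pm1 = ProcessMesh([0, 1])
    pm2 = ProcessMesh([0, 1, 2, 3])
    # pm1's ranks are a subset of pm2's, so the compatible mesh is expected to
    # cover pm2; an incompatible list is expected to yield None.
    print(compute_compatible_process_mesh([pm1, pm2]))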
@@ -80,7 +80,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
     dist_attr = dist_op.dist_attr
     process_mesh = dist_attr.process_mesh
     assert process_mesh, "Process mesh must not be None."
-    processes = process_mesh.processes
+    processes = process_mesh.process_ids
     for process in processes:
         desc = {}
         desc["op"] = op.type
@@ -103,7 +103,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
                 global_sizes = var.shape
                 # NOTE: When support uneven partition, the shard_sizes will be got from dist_attr.
                 shard_sizes = None
-                topology = process_mesh.topology
+                topology = process_mesh.shape
                 shape = DistributedTensor.get_local_sizes(
                     global_sizes,
                     dims_mapping,
@@ -129,7 +129,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
                     )
                     relative_idx = _get_idx_in_axis(
                         processes,
-                        dist_attr.process_mesh.topology,
+                        dist_attr.process_mesh.shape,
                         embedding_row_dim_mapping,
                         process,
                     )
@@ -153,8 +153,8 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
             process_mesh = dist_attr.process_mesh
             global_sizes = var.shape
             shard_sizes = None
-            processes = process_mesh.processes
-            topology = process_mesh.topology
+            processes = process_mesh.process_ids
+            topology = process_mesh.shape
             shape = DistributedTensor.get_local_sizes(
                 global_sizes,
                 dims_mapping,
@@ -170,7 +170,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
         # Modify shape attr according to how output are partitioned
         out_name = var_name_list[0]
         dims_mapping = dist_attr.get_output_dims_mapping(out_name)
-        process_mesh_shape = dist_attr.process_mesh.topology
+        process_mesh_shape = dist_attr.process_mesh.shape
         shape_list = op.attr("shape")
         # Modify target shape
         for idx, axis in enumerate(dims_mapping):
@@ -253,7 +253,7 @@ def build_comm_desc_from_dist_op(
     process_mesh = dist_attr.process_mesh
     assert process_mesh, "Process mesh must not be None."
-    processes = process_mesh.processes
+    processes = process_mesh.process_ids
     op_descs = {}
     for process in processes:
         rank_id = process
@@ -295,7 +295,7 @@ def build_comm_desc_from_dist_op(
                 )
                 global_sizes = var.shape
                 shard_sizes = None
-                topology = process_mesh.topology
+                topology = process_mesh.shape
                 shape = DistributedTensor.get_local_sizes(
                     global_sizes,
                     dims_mapping,
@@ -311,8 +311,8 @@ def build_comm_desc_from_dist_op(
         # Get comm group by parallel_axis or the given group_ranks.
         if parallel_axis is not None:
-            process_mesh_shape = process_mesh.topology
-            process_mesh_group = process_mesh.processes
+            process_mesh_shape = process_mesh.shape
+            process_mesh_group = process_mesh.process_ids
             comm_group_ranks = _get_comm_group(
                 process_mesh_group,
                 process_mesh_shape,
@@ -384,7 +384,7 @@ def build_dp_costs(
     dist_attr = dist_op.dist_attr
     process_mesh = dist_attr.process_mesh
-    processes = process_mesh.processes
+    processes = process_mesh.process_ids
     assert len(var_names) == 1
     vars = dist_op.serial_op.block.vars
     var_name = var_names[0]
@@ -443,7 +443,7 @@ def build_dp_costs(
             )
             global_sizes = var.shape
             shard_sizes = None
-            topology = process_mesh.topology
+            topology = process_mesh.shape
             shape = DistributedTensor.get_local_sizes(
                 global_sizes,
                 dims_mapping,
...
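The repeated get_local_sizes calls above derive a per-rank shape from the global shape, the dims_mapping, and the mesh's shape/process_ids. For intuition only, here is a standalone sketch of the idea (not Paddle's implementation; it assumes even partitioning with no shard_sizes): each tensor dim mapped to a mesh axis is divided by that axis's size, and -1 means replicated.

    def local_sizes_sketch(global_sizes, dims_mapping, mesh_shape):
        # Illustrative helper only; ignores uneven shards.
        local = []
        for size, axis in zip(global_sizes, dims_mapping):
            local.append(size if axis == -1 else size // mesh_shape[axis])
        return local

    # A [8, 1024] tensor on a [2, 2] mesh, sharding dim 0 along mesh axis 0:
    print(local_sizes_sketch([8, 1024], [0, -1], [2, 2]))  # [4, 1024]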
@@ -190,7 +190,7 @@ class CostEstimator:
             # Calc dist op cost
             dist_op = dist_context.get_dist_op_for_program(op)
             op_dist_attr = dist_op.dist_attr
-            processes = op_dist_attr.process_mesh.processes
+            processes = op_dist_attr.process_mesh.process_ids
             container = get_distributed_operator_impl_container(
                 op_dist_attr.impl_type
@@ -273,8 +273,8 @@ class CostEstimator:
         # This estimation will be improved, now reshard and inplace are not considered.
         # Persist var is not free.
         def _convert_pm_and_dm_to_str(process_mesh, dims_mapping):
-            processes = ",".join([str(x) for x in process_mesh.processes])
-            topology = ",".join([str(x) for x in process_mesh.topology])
+            processes = ",".join([str(x) for x in process_mesh.process_ids])
+            topology = ",".join([str(x) for x in process_mesh.shape])
             dims_mapping = ",".join([str(x) for x in dims_mapping])
             result = processes + topology + dims_mapping
             return result
@@ -318,8 +318,8 @@ class CostEstimator:
                     sizes = DistributedTensor.get_local_sizes(
                         global_sizes,
                         input_dims_mapping,
-                        process_mesh.topology,
-                        process_mesh.processes,
+                        process_mesh.shape,
+                        process_mesh.process_ids,
                     )
                     var_info[var_name][key]["memory"] = self._calculate_bytes(
                         sizes, dtype
@@ -346,8 +346,8 @@ class CostEstimator:
                     sizes = DistributedTensor.get_local_sizes(
                         global_sizes,
                         output_dims_mapping,
-                        process_mesh.topology,
-                        process_mesh.processes,
+                        process_mesh.shape,
+                        process_mesh.process_ids,
                     )
                     var_info[var_name][key]["memory"] = self._calculate_bytes(
                         sizes, dtype
@@ -380,7 +380,7 @@ class CostEstimator:
                 # Not used
                 if var_name + key not in has_used_vars:
                     has_used_vars.add(has_used_var)
-                    for process in process_mesh.processes:
+                    for process in process_mesh.process_ids:
                         if process not in memories:
                             memories[process] = 0
                         memories[process] += var_info[var_name][key]["memory"]
@@ -390,7 +390,7 @@ class CostEstimator:
                     if has_used_var not in can_free_vars:
                         can_free_vars.add(has_used_var)
                         if not var.persistable:
-                            for process in process_mesh.processes:
+                            for process in process_mesh.process_ids:
                                 if process not in can_free_memories:
                                     can_free_memories[process] = 0
                                 can_free_memories[process] += var_info[
@@ -409,7 +409,7 @@ class CostEstimator:
                 # Not used
                 if var_name + key not in has_used_vars:
                     has_used_vars.add(has_used_var)
-                    for process in process_mesh.processes:
+                    for process in process_mesh.process_ids:
                         if process not in memories:
                             memories[process] = 0
                         memories[process] += var_info[var_name][key]["memory"]
@@ -419,7 +419,7 @@ class CostEstimator:
                     if has_used_var not in can_free_vars:
                         can_free_vars.add(has_used_var)
                         if not var.persistable:
-                            for process in process_mesh.processes:
+                            for process in process_mesh.process_ids:
                                 if process not in can_free_memories:
                                     can_free_memories[process] = 0
                                 can_free_memories[process] += var_info[
...
@@ -121,32 +121,17 @@ class TensorDistributedAttribute:
         if dist_attr is None:
             return
         assert isinstance(
-            dist_attr, (dict, TensorDistributedAttribute)
+            dist_attr, TensorDistributedAttribute
         ), "The type of dist_attr must be dict or TensorDistributedAttribute."
-        if isinstance(dist_attr, dict):
-            for key, value in dist_attr.items():
-                if key in get_tensor_dist_attr_field_keys():
-                    field_property = TensorDistributedAttribute.__dict__.get(
-                        key, None
-                    )
-                    if field_property:
-                        field_property.fset(self, value)
-                    else:
-                        assert False, "No setter for {} in args {}.".format(
-                            key, dist_attr
-                        )
-        elif isinstance(dist_attr, TensorDistributedAttribute):
-            for key in get_tensor_dist_attr_field_keys():
-                field_property = TensorDistributedAttribute.__dict__.get(
-                    key, None
-                )
-                if field_property:
-                    field_property.fset(self, field_property.fget(dist_attr))
-                else:
-                    assert False, "No setter for {} in args {}.".format(
-                        key, dist_attr
-                    )
-            self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
+        for key in get_tensor_dist_attr_field_keys():
+            field_property = TensorDistributedAttribute.__dict__.get(key, None)
+            if field_property:
+                field_property.fset(self, field_property.fget(dist_attr))
+            else:
+                assert False, "No setter for {} in args {}.".format(
+                    key, dist_attr
+                )
+        self._is_annotated = copy.deepcopy(dist_attr._is_annotated)

     def reset(self, skip_dist_attr_field_names=None):
         if skip_dist_attr_field_names is None or (
@@ -243,7 +228,9 @@ class OperatorDistributedAttribute:
         if process_mesh is not None:
             assert isinstance(
                 process_mesh, (list, ProcessMesh)
-            ), "The type of process_mesh must be list or ProcessMesh."
+            ), "The type of process_mesh must be list or ProcessMesh, but receive {}".format(
+                type(process_mesh)
+            )
             if isinstance(process_mesh, list):
                 process_mesh = ProcessMesh(process_mesh)
             self._process_mesh = copy.deepcopy(process_mesh)
...
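The process_mesh setter shown above still accepts either a plain list of ranks or a ProcessMesh, wraps lists into a ProcessMesh, and now reports the offending type in its assert message. A hedged sketch of that behaviour (attribute names follow the file above; whether the class can be built standalone like this is an assumption):

    from paddle.distributed.auto_parallel.dist_attribute import (
        OperatorDistributedAttribute,
    )

    op_dist_attr = OperatorDistributedAttribute()
    op_dist_attr.process_mesh = [0, 1, 2, 3]  # a list is converted to a ProcessMesh
    print(type(op_dist_attr.process_mesh).__name__)  # expected: ProcessMesh
    # Assigning anything else, e.g. a string, trips the improved assert message.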
@@ -870,8 +870,8 @@ class DistributedContext:
             else:
                 tensor_shape = serial_tensor.shape
             dims_mapping = dist_attr.dims_mapping
-            process_mesh_shape = dist_attr.process_mesh.topology
-            process_mesh_processes = dist_attr.process_mesh.processes
+            process_mesh_shape = dist_attr.process_mesh.shape
+            process_mesh_processes = dist_attr.process_mesh.process_ids
             # If the dimension of tensor is less than the sharding dimension of process mesh,
             # we just amend the dimension mapping to -1. (Is this really OK?)
             for i in range(len(tensor_shape)):
@@ -887,8 +887,8 @@ class DistributedContext:
         for dist_op in self._dist_ops_for_program.values():
             serial_op = dist_op.serial_op
             dist_attr = dist_op.dist_attr
-            process_mesh_shape = dist_attr.process_mesh.topology
-            process_mesh_processes = dist_attr.process_mesh.processes
+            process_mesh_shape = dist_attr.process_mesh.shape
+            process_mesh_processes = dist_attr.process_mesh.process_ids
             for arg_name in serial_op.input_arg_names:
                 if dist_op.get_serial_input(arg_name) is None:
                     tensor_shape = []
...
@@ -159,10 +159,10 @@ class DistributedOperator:
                 return False
             for i in range(len(dims_mapping)):
                 if dims_mapping[i] < -1 or dims_mapping[i] >= len(
-                    self.dist_attr.process_mesh.topology
+                    self.dist_attr.process_mesh.shape
                 ):
                     return False
-            for i in range(len(self.dist_attr.process_mesh.topology)):
+            for i in range(len(self.dist_attr.process_mesh.shape)):
                 if dims_mapping.count(i) > 1:
                     return False
             if self.dist_attr.process_mesh != input_dist_attr.process_mesh:
@@ -179,10 +179,10 @@ class DistributedOperator:
                 return False
             for i in range(len(dims_mapping)):
                 if dims_mapping[i] < -1 or dims_mapping[i] >= len(
-                    self.dist_attr.process_mesh.topology
+                    self.dist_attr.process_mesh.shape
                 ):
                     return False
-            for i in range(len(self.dist_attr.process_mesh.topology)):
+            for i in range(len(self.dist_attr.process_mesh.shape)):
                 if dims_mapping.count(i) > 1:
                     return False
             if self.dist_attr.process_mesh != output_dist_attr.process_mesh:
...
@@ -225,10 +225,10 @@ class DistributedTensor:
             if self.dist_attr.dims_mapping[
                 i
             ] < -1 or self.dist_attr.dims_mapping[i] >= len(
-                self.dist_attr.process_mesh.topology
+                self.dist_attr.process_mesh.shape
             ):
                 return False
-        for i in range(len(self.dist_attr.process_mesh.topology)):
+        for i in range(len(self.dist_attr.process_mesh.shape)):
             if self.dist_attr.dims_mapping.count(i) > 1:
                 return False
         return True
@@ -239,8 +239,8 @@ class DistributedTensor:
         global_sizes = self.serial_tensor.shape
         dims_mapping = self.dist_attr.dims_mapping
         shard_sizes = self.dist_attr.shard_sizes
-        processes = self.dist_attr.process_mesh.processes
-        topology = self.dist_attr.process_mesh.topology
+        processes = self.dist_attr.process_mesh.process_ids
+        topology = self.dist_attr.process_mesh.shape
         local_sizes = DistributedTensor.get_local_sizes(
             global_sizes, dims_mapping, topology, processes, rank, shard_sizes
         )
@@ -256,8 +256,8 @@ class DistributedTensor:
         global_sizes = self.serial_tensor.shape
         dims_mapping = self.dist_attr.dims_mapping
         shard_sizes = self.dist_attr.shard_sizes
-        processes = self.dist_attr.process_mesh.processes
-        topology = self.dist_attr.process_mesh.topology
+        processes = self.dist_attr.process_mesh.process_ids
+        topology = self.dist_attr.process_mesh.shape
         local_offsets = DistributedTensor.get_local_offsets(
             global_sizes,
             dims_mapping,
@@ -282,8 +282,8 @@ class DistributedTensor:
         global_sizes = self.serial_tensor.shape
         dims_mapping = self.dist_attr.dims_mapping
         shard_sizes = self.dist_attr.shard_sizes
-        processes = self.dist_attr.process_mesh.processes
-        topology = self.dist_attr.process_mesh.topology
+        processes = self.dist_attr.process_mesh.process_ids
+        topology = self.dist_attr.process_mesh.shape
         local_shard = DistributedTensor.get_local_shard(
             global_sizes,
             dims_mapping,
...
@@ -324,8 +324,8 @@ class ProgramHelper:
             var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
             dist_attr = {
                 "dims_mapping": var_dist_attr.dims_mapping,
-                "process_shape": var_dist_attr.process_mesh.topology,
-                "process_group": var_dist_attr.process_mesh.processes,
+                "process_shape": var_dist_attr.process_mesh.shape,
+                "process_group": var_dist_attr.process_mesh.process_ids,
             }
             # slice param_value with dist_attr
             # share sliced_param_value with param_tensor in global_scope
...
@@ -275,7 +275,7 @@ def is_parameter_related(varname, block):

 def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
     var_shape = block._var_recursive(src_var.name).shape
-    var_topoloy = src_var_dist_attr.process_mesh.topology
+    var_topoloy = src_var_dist_attr.process_mesh.shape
     var_dims_mapping = src_var_dist_attr.dims_mapping

     complete_shape = []
@@ -287,7 +287,7 @@ def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
             complete_shape.append(new_shape)

     exact_shape = []
-    input_topology = op_input_dist_attr.process_mesh.topology
+    input_topology = op_input_dist_attr.process_mesh.shape
     input_dims_mapping = op_input_dist_attr.dims_mapping
     for idx, shape in enumerate(complete_shape):
         if input_dims_mapping[idx] == -1:
@@ -362,10 +362,10 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank):
     op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
     process_mesh = op_dist_attr.process_mesh
-    mesh_shape = process_mesh.topology
+    mesh_shape = process_mesh.shape
     # FIXME Hack for Pipeline Parallelism where the current operator
     # not belong to the mesh the current rank belong to.
-    if rank not in process_mesh.processes:
+    if rank not in process_mesh.process_ids:
         rank = _get_corresponding_rank(dist_ctx, process_mesh, rank)

     for var_name in act_grad_names:
@@ -376,8 +376,8 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank):
         if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
             group_ranks = _get_comm_group(
-                process_mesh.processes,
-                process_mesh.topology,
+                process_mesh.process_ids,
+                process_mesh.shape,
                 batch_size_axis,
                 rank,
             )
...
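get_data_parallel_group above finds the mesh axis that carries the batch dimension and asks _get_comm_group for the ranks that share it. For intuition only, a standalone sketch of that grouping (not the Paddle helper): the group of a rank along an axis is every process id whose mesh coordinates differ from it only on that axis.

    import numpy as np

    def comm_group_sketch(process_ids, mesh_shape, axis, rank):
        # Illustrative only; assumes `rank` appears exactly once in the mesh.
        mesh = np.array(process_ids).reshape(mesh_shape)
        coord = list(list(zip(*np.where(mesh == rank)))[0])
        group = []
        for i in range(mesh_shape[axis]):
            coord[axis] = i
            group.append(int(mesh[tuple(coord)]))
        return group

    # A [2, 2] mesh over ranks [0, 1, 2, 3]; rank 1's group along axis 0:
    print(comm_group_sketch([0, 1, 2, 3], [2, 2], 0, 1))  # [1, 3]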
@@ -89,7 +89,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
             str(backward_op)
         )

-        assert rank_id in dist_attr.process_mesh.processes
+        assert rank_id in dist_attr.process_mesh.process_ids

         assert 'X' in kwargs, "input [{}] is not given".format('X')
         assert 'Scale' in kwargs, "input [{}] is not given".format('Scale')
@@ -120,7 +120,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
                 rank_id
                 in ctx.get_tensor_dist_attr_for_program(
                     main_block._var_recursive(varname)
-                ).process_mesh.processes
+                ).process_mesh.process_ids
             ):
                 filter_vars.append(varname)
...
@@ -124,7 +124,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         desc_mapping = build_comp_desc_from_dist_op(
             dist_op=dist_op, dist_context=ctx
         )
-        processes = dist_op.dist_attr.process_mesh.processes
+        processes = dist_op.dist_attr.process_mesh.process_ids
         op_type = dist_op.serial_op.type
         cost_mapping = build_comp_costs_from_descs(
             _g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster
@@ -141,7 +141,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         )
         dist_attr = dist_op.dist_attr
         process_mesh = dist_attr.process_mesh
-        processes = process_mesh.processes
+        processes = process_mesh.process_ids
         backward_op = dist_op.serial_op
         op_type = backward_op.type
         cost_mapping = build_comp_costs_from_descs(
@@ -157,7 +157,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                 varname, main_block
             ):
                 var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
-                mesh_shape = process_mesh.topology
+                mesh_shape = process_mesh.shape
                 batch_size_axis = var_dim_mapping[0]
                 if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                     need_gradient_allreduce = True
@@ -172,7 +172,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                     var_dim_mapping = dist_attr.get_input_dims_mapping(
                         varname
                     )
-                    mesh_shape = process_mesh.topology
+                    mesh_shape = process_mesh.shape
                     batch_size_axis = var_dim_mapping[0]
                     parallel_axis = batch_size_axis
                     attrs = {"use_calc_stream": True}
@@ -501,19 +501,19 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                     dims_mapping = param_dist_attr.dims_mapping

                     # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
-                    if rank_id not in process_mesh.processes:
+                    if rank_id not in process_mesh.process_ids:
                         rank_id = _get_corresponding_rank(
                             ctx, process_mesh, rank_id
                         )

                     # NOTE all not splited axis should be presented in mesh
-                    for axis, size in enumerate(process_mesh.topology):
+                    for axis, size in enumerate(process_mesh.shape):
                         if size <= 1 or axis in dims_mapping:
                             pass
                         else:
                             group_ranks = _get_comm_group(
-                                process_mesh.processes,
-                                process_mesh.topology,
+                                process_mesh.process_ids,
+                                process_mesh.shape,
                                 axis,
                                 rank_id,
                             )
...
@@ -67,7 +67,7 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
         desc_mapping = build_comp_desc_from_dist_op(
             dist_op=dist_op, dist_context=ctx
         )
-        processes = dist_op.dist_attr.process_mesh.processes
+        processes = dist_op.dist_attr.process_mesh.process_ids
         op_type = dist_op.serial_op.type
         cost_mapping = build_comp_costs_from_descs(
             _g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster
@@ -84,7 +84,7 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
         )
         dist_attr = dist_op.dist_attr
         process_mesh = dist_attr.process_mesh
-        processes = process_mesh.processes
+        processes = process_mesh.process_ids
         backward_op = dist_op.serial_op
         op_type = backward_op.type
         cost_mapping = build_comp_costs_from_descs(
@@ -100,7 +100,7 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
                 varname, main_block
             ):
                 var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
-                mesh_shape = process_mesh.topology
+                mesh_shape = process_mesh.shape
                 batch_size_axis = var_dim_mapping[0]
                 if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                     need_gradient_allreduce = True
@@ -115,7 +115,7 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
                     var_dim_mapping = dist_attr.get_input_dims_mapping(
                         varname
                     )
-                    mesh_shape = process_mesh.topology
+                    mesh_shape = process_mesh.shape
                     batch_size_axis = var_dim_mapping[0]
                     parallel_axis = batch_size_axis
                     attrs = {"use_calc_stream": True}
...
@@ -175,7 +175,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         desc_mapping = build_comp_desc_from_dist_op(
             dist_op=dist_op, dist_context=ctx
         )
-        processes = dist_op.dist_attr.process_mesh.processes
+        processes = dist_op.dist_attr.process_mesh.process_ids
         # embedding need start_index
         cost_mapping = build_comp_costs_from_descs(
             EmbeddingOpCost, ctx, processes, desc_mapping, cluster
@@ -231,7 +231,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         )
         process_mesh = dist_attr.process_mesh
-        processes = process_mesh.processes
+        processes = process_mesh.process_ids
         comm_op_cost_list = build_comm_costs_from_descs(
             IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
         )
@@ -250,7 +250,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         var_dim_mapping = dist_attr.get_input_dims_mapping(
             backward_op.input("Ids")[0]
         )
-        mesh_shape = process_mesh.topology
+        mesh_shape = process_mesh.shape
         batch_size_axis = var_dim_mapping[0]
         if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
             parallel_axis = batch_size_axis
@@ -390,8 +390,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         ), "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format(
             embedding_row_dim_mapping
         )
-        process_mesh_shape = op_dist_attr.process_mesh.topology
-        process_mesh_group = op_dist_attr.process_mesh.processes
+        process_mesh_shape = op_dist_attr.process_mesh.shape
+        process_mesh_group = op_dist_attr.process_mesh.process_ids

         # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
         if rank_id not in process_mesh_group:
@@ -539,13 +539,13 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
                     dim_mapping = param_dist_attr.dims_mapping

                     # NOTE all not splited axis should be presented in mesh
-                    for axis, size in enumerate(process_mesh.topology):
+                    for axis, size in enumerate(process_mesh.shape):
                         if size <= 1 or axis in dim_mapping:
                             pass
                         else:
                             group_ranks = _get_comm_group(
-                                process_mesh.processes,
-                                process_mesh.topology,
+                                process_mesh.process_ids,
+                                process_mesh.shape,
                                 axis,
                                 rank_id,
                             )
@@ -579,7 +579,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         )

         # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
-        if rank_id not in dist_attr.process_mesh.processes:
+        if rank_id not in dist_attr.process_mesh.process_ids:
             rank_id = _get_corresponding_rank(
                 ctx, dist_attr.process_mesh, rank_id
             )
@@ -623,8 +623,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         ), "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format(
             embedding_row_dim_mapping
         )
-        process_mesh_shape = dist_attr.process_mesh.topology
-        process_mesh_group = dist_attr.process_mesh.processes
+        process_mesh_shape = dist_attr.process_mesh.shape
+        process_mesh_group = dist_attr.process_mesh.process_ids

         # A generalized method to caculate embedding offset using cartisian product
         relative_idx = _get_idx_in_axis(
...
@@ -61,7 +61,7 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
         desc_mapping = build_comp_desc_from_dist_op(
             dist_op=dist_op, dist_context=ctx
         )
-        processes = dist_op.dist_attr.process_mesh.processes
+        processes = dist_op.dist_attr.process_mesh.process_ids
         op_type = dist_op.serial_op.type
         cost_mapping = build_comp_costs_from_descs(
             FillConstantBatchSizeLikeOpCost,
@@ -139,7 +139,7 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
         # modify shape attr according to how output are partitioned
         out_name = op.output('Out')[0]
         dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
-        process_mesh_shape = op_dist_attr.process_mesh.topology
+        process_mesh_shape = op_dist_attr.process_mesh.shape
         shape_list = op.attr("shape")
         # modify target shape
         for idx, axis in enumerate(dims_mapping):
...
@@ -156,7 +156,7 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl):
         rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
-        if rank_id not in op_dist_attr.process_mesh.processes:
+        if rank_id not in op_dist_attr.process_mesh.process_ids:
             rank_id = _get_corresponding_rank(
                 ctx, op_dist_attr.process_mesh, rank_id
             )
@@ -172,8 +172,8 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl):
         ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             qkv_w_col_dim_mapping
         )
-        process_mesh_shape = op_dist_attr.process_mesh.topology
-        process_mesh_group = op_dist_attr.process_mesh.processes
+        process_mesh_shape = op_dist_attr.process_mesh.shape
+        process_mesh_group = op_dist_attr.process_mesh.process_ids

         parallel_axis = qkv_w_col_dim_mapping
         group_ranks = _get_comm_group(
@@ -198,7 +198,7 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl):
         rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
-        if rank_id not in op_dist_attr.process_mesh.processes:
+        if rank_id not in op_dist_attr.process_mesh.process_ids:
             rank_id = _get_corresponding_rank(
                 ctx, op_dist_attr.process_mesh, rank_id
             )
@@ -211,8 +211,8 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl):
         ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             out_w_col_dim_mapping
         )
-        process_mesh_shape = op_dist_attr.process_mesh.topology
-        process_mesh_group = op_dist_attr.process_mesh.processes
+        process_mesh_shape = op_dist_attr.process_mesh.shape
+        process_mesh_group = op_dist_attr.process_mesh.process_ids

         parallel_axis = out_w_col_dim_mapping
         group_ranks = _get_comm_group(
...
@@ -148,7 +148,7 @@ class DistributedFusedFeedForwardImpl(DistributedOperatorImpl):
         rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
-        if rank_id not in op_dist_attr.process_mesh.processes:
+        if rank_id not in op_dist_attr.process_mesh.process_ids:
             rank_id = _get_corresponding_rank(
                 ctx, op_dist_attr.process_mesh, rank_id
             )
@@ -163,8 +163,8 @@ class DistributedFusedFeedForwardImpl(DistributedOperatorImpl):
         ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             linear1_weight_col_dim_mapping
         )
-        process_mesh_shape = op_dist_attr.process_mesh.topology
-        process_mesh_group = op_dist_attr.process_mesh.processes
+        process_mesh_shape = op_dist_attr.process_mesh.shape
+        process_mesh_group = op_dist_attr.process_mesh.process_ids

         parallel_axis = linear1_weight_col_dim_mapping
         group_ranks = _get_comm_group(
@@ -190,7 +190,7 @@ class DistributedFusedFeedForwardImpl(DistributedOperatorImpl):
         rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
-        if rank_id not in op_dist_attr.process_mesh.processes:
+        if rank_id not in op_dist_attr.process_mesh.process_ids:
             rank_id = _get_corresponding_rank(
                 ctx, op_dist_attr.process_mesh, rank_id
             )
@@ -205,8 +205,8 @@ class DistributedFusedFeedForwardImpl(DistributedOperatorImpl):
         ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             linear2_weight_col_dim_mapping
         )
-        process_mesh_shape = op_dist_attr.process_mesh.topology
-        process_mesh_group = op_dist_attr.process_mesh.processes
+        process_mesh_shape = op_dist_attr.process_mesh.shape
+        process_mesh_group = op_dist_attr.process_mesh.process_ids

         parallel_axis = linear2_weight_col_dim_mapping
         group_ranks = _get_comm_group(
...
@@ -299,7 +299,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
     ), "backward op [{}] don't have dist attribute !".format(str(backward_op))

     # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
-    if rank_id not in dist_attr.process_mesh.processes:
+    if rank_id not in dist_attr.process_mesh.process_ids:
         rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id)

     assert 'Y' in kwargs, "input [{}] is not given".format('Y')
@@ -341,8 +341,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
     X_var_dims_mapping = dist_attr.get_input_dims_mapping(X_var.name)
     Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
-    process_mesh_shape = dist_attr.process_mesh.topology
-    process_mesh_group = dist_attr.process_mesh.processes
+    process_mesh_shape = dist_attr.process_mesh.shape
+    process_mesh_group = dist_attr.process_mesh.process_ids

     trans_x = None
     trans_y = None
@@ -532,12 +532,12 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
     process_mesh = param_dist_attr.process_mesh
     dim_mapping = param_dist_attr.dims_mapping

-    for axis, size in enumerate(process_mesh.topology):
+    for axis, size in enumerate(process_mesh.shape):
         if size <= 1 or axis in dim_mapping:
             pass
         else:
             group_ranks = _get_comm_group(
-                process_mesh.processes, process_mesh.topology, axis, rank_id
+                process_mesh.process_ids, process_mesh.shape, axis, rank_id
             )
             sync_group = new_process_group(group_ranks)
@@ -600,7 +600,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
             dist_op=dist_op, dist_context=ctx
         )
         process_mesh = dist_attr.process_mesh
-        processes = process_mesh.processes
+        processes = process_mesh.process_ids
         cost_mapping = build_comp_costs_from_descs(
             MatmulGradOpCost, ctx, processes, desc_mapping, cluster
         )
@@ -631,7 +631,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
             var_dim_mapping = dist_attr.get_input_dims_mapping(
                 backward_op.input("X")[0]
             )
-            mesh_shape = process_mesh.topology
+            mesh_shape = process_mesh.shape
             batch_size_axis = var_dim_mapping[0]
             if (
                 batch_size_axis > -1
@@ -651,7 +651,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
         desc_mapping = build_comp_desc_from_dist_op(
             dist_op=dist_op, dist_context=ctx
         )
-        processes = dist_op.dist_attr.process_mesh.processes
+        processes = dist_op.dist_attr.process_mesh.process_ids
         cost_mapping = build_comp_costs_from_descs(
             MatmulOpCost, ctx, processes, desc_mapping, cluster
         )
@@ -749,7 +749,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
         ), "backward op [{}] don't have dist attribute !".format(str(src_op))

         # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
-        if rank_id not in op_dist_attr.process_mesh.processes:
+        if rank_id not in op_dist_attr.process_mesh.process_ids:
             rank_id = _get_corresponding_rank(
                 ctx, op_dist_attr.process_mesh, rank_id
             )
@@ -791,8 +791,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
         ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             matmul_col_dim_mapping
         )
-        process_mesh_shape = op_dist_attr.process_mesh.topology
-        process_mesh_group = op_dist_attr.process_mesh.processes
+        process_mesh_shape = op_dist_attr.process_mesh.shape
+        process_mesh_group = op_dist_attr.process_mesh.process_ids

         parallel_axis = matmul_col_dim_mapping
         group_ranks = _get_comm_group(
@@ -986,7 +986,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
             parallel_axis=parallel_axis,
         )
         process_mesh = dist_attr.process_mesh
-        processes = process_mesh.processes
+        processes = process_mesh.process_ids
         comm_op_cost_list = build_comm_costs_from_descs(
             IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
         )
@@ -1005,7 +1005,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
            var_dim_mapping = dist_attr.get_input_dims_mapping(
                 backward_op.input("X")[0]
             )
-            mesh_shape = process_mesh.topology
+            mesh_shape = process_mesh.shape
             batch_size_axis = var_dim_mapping[0]
             if (
                 batch_size_axis > -1
@@ -1025,7 +1025,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
         desc_mapping = build_comp_desc_from_dist_op(
             dist_op=dist_op, dist_context=ctx
         )
-        processes = dist_op.dist_attr.process_mesh.processes
+        processes = dist_op.dist_attr.process_mesh.process_ids
         cost_mapping = build_comp_costs_from_descs(
             MatmulOpCost, ctx, processes, desc_mapping, cluster
         )
@@ -1131,7 +1131,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
         ), "backward op [{}] don't have dist attribute !".format(str(src_op))

         # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
-        if rank_id not in op_dist_attr.process_mesh.processes:
+        if rank_id not in op_dist_attr.process_mesh.process_ids:
             rank_id = _get_corresponding_rank(
                 ctx, op_dist_attr.process_mesh, rank_id
             )
@@ -1173,8 +1173,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
         ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             matmul_row_dim_mapping
         )
-        process_mesh_shape = op_dist_attr.process_mesh.topology
-        process_mesh_group = op_dist_attr.process_mesh.processes
+        process_mesh_shape = op_dist_attr.process_mesh.shape
+        process_mesh_group = op_dist_attr.process_mesh.process_ids

         parallel_axis = matmul_row_dim_mapping
         group_ranks = _get_comm_group(
@@ -1329,7 +1329,7 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
             dist_op=dist_op, dist_context=ctx
         )
         process_mesh = dist_attr.process_mesh
-        processes = process_mesh.processes
+        processes = process_mesh.process_ids
         cost_mapping = build_comp_costs_from_descs(
             MatmulGradOpCost, ctx, processes, desc_mapping, cluster
         )
@@ -1339,7 +1339,7 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
            var_dim_mapping = dist_attr.get_input_dims_mapping(
                 backward_op.input("X")[0]
             )
-            mesh_shape = process_mesh.topology
+            mesh_shape = process_mesh.shape
             batch_size_axis = var_dim_mapping[0]
             if (
                 batch_size_axis > -1
@@ -1360,7 +1360,7 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
         desc_mapping = build_comp_desc_from_dist_op(
             dist_op=dist_op, dist_context=ctx
         )
-        processes = dist_op.dist_attr.process_mesh.processes
+        processes = dist_op.dist_attr.process_mesh.process_ids
         cost_mapping = build_comp_costs_from_descs(
             MatmulOpCost, ctx, processes, desc_mapping, cluster
         )
@@ -1479,7 +1479,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
             backward_op.input("Y")[0]
         )
         process_mesh = dist_attr.process_mesh
-        processes = process_mesh.processes
+        processes = process_mesh.process_ids
         # col parallel: matmul + allreduce
         if backward_op.attr("trans_y"):
             Y_var_dim_mapping.reverse()
@@ -1526,7 +1526,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
            var_dim_mapping = dist_attr.get_input_dims_mapping(
                 backward_op.input("X")[0]
             )
-            mesh_shape = process_mesh.topology
+            mesh_shape = process_mesh.shape
             batch_size_axis = var_dim_mapping[0]
             if (
                 batch_size_axis > -1
@@ -1547,7 +1547,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
         comp_desc_mapping = build_comp_desc_from_dist_op(
             dist_op=dist_op, dist_context=ctx
         )
-        processes = dist_op.dist_attr.process_mesh.processes
+        processes = dist_op.dist_attr.process_mesh.process_ids
         comp_cost_mapping = build_comp_costs_from_descs(
             MatmulV2OpCost, ctx, processes, comp_desc_mapping, cluster
         )
@@ -1645,7 +1645,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
         ), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes: if rank_id not in op_dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank( rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id ctx, op_dist_attr.process_mesh, rank_id
) )
...@@ -1687,8 +1687,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1687,8 +1687,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping matmul_col_dim_mapping
) )
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.processes process_mesh_group = op_dist_attr.process_mesh.process_ids
parallel_axis = matmul_col_dim_mapping parallel_axis = matmul_col_dim_mapping
group_ranks = _get_comm_group( group_ranks = _get_comm_group(
...@@ -1869,7 +1869,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1869,7 +1869,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
parallel_axis = Y_var_dim_mapping[0] parallel_axis = Y_var_dim_mapping[0]
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
processes = process_mesh.processes processes = process_mesh.process_ids
# calc comm op cost # calc comm op cost
var_names = [backward_op.input("Out@GRAD")[0]] var_names = [backward_op.input("Out@GRAD")[0]]
attrs = {"use_calc_stream": True, "use_model_parallel": True} attrs = {"use_calc_stream": True, "use_model_parallel": True}
...@@ -1900,7 +1900,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1900,7 +1900,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping( var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0] backward_op.input("X")[0]
) )
mesh_shape = process_mesh.topology mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if ( if (
batch_size_axis > -1 batch_size_axis > -1
...@@ -1920,7 +1920,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1920,7 +1920,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op( desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx dist_op=dist_op, dist_context=ctx
) )
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs( cost_mapping = build_comp_costs_from_descs(
MatmulV2OpCost, ctx, processes, desc_mapping, cluster MatmulV2OpCost, ctx, processes, desc_mapping, cluster
) )
...@@ -2025,7 +2025,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -2025,7 +2025,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
), "backward op [{}] don't have dist attribute !".format(str(src_op)) ), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes: if rank_id not in op_dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank( rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id ctx, op_dist_attr.process_mesh, rank_id
) )
...@@ -2067,8 +2067,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -2067,8 +2067,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping matmul_row_dim_mapping
) )
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.processes process_mesh_group = op_dist_attr.process_mesh.process_ids
parallel_axis = matmul_row_dim_mapping parallel_axis = matmul_row_dim_mapping
group_ranks = _get_comm_group( group_ranks = _get_comm_group(
...@@ -2222,7 +2222,7 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): ...@@ -2222,7 +2222,7 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op( desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx dist_op=dist_op, dist_context=ctx
) )
processes = process_mesh.processes processes = process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs( cost_mapping = build_comp_costs_from_descs(
MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
) )
...@@ -2232,7 +2232,7 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): ...@@ -2232,7 +2232,7 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping( var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0] backward_op.input("X")[0]
) )
mesh_shape = process_mesh.topology mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if ( if (
batch_size_axis > -1 batch_size_axis > -1
...@@ -2253,7 +2253,7 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): ...@@ -2253,7 +2253,7 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op( desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx dist_op=dist_op, dist_context=ctx
) )
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs( cost_mapping = build_comp_costs_from_descs(
MatmulV2OpCost, ctx, processes, desc_mapping, cluster MatmulV2OpCost, ctx, processes, desc_mapping, cluster
) )
...@@ -2386,7 +2386,7 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2386,7 +2386,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
dist_op=dist_op, dist_context=ctx dist_op=dist_op, dist_context=ctx
) )
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
processes = process_mesh.processes processes = process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs( cost_mapping = build_comp_costs_from_descs(
MulGradOpCost, ctx, processes, desc_mapping, cluster MulGradOpCost, ctx, processes, desc_mapping, cluster
) )
...@@ -2417,7 +2417,7 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2417,7 +2417,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping( var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0] backward_op.input("X")[0]
) )
mesh_shape = process_mesh.topology mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if ( if (
batch_size_axis > -1 batch_size_axis > -1
...@@ -2437,7 +2437,7 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2437,7 +2437,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op( desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx dist_op=dist_op, dist_context=ctx
) )
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs( cost_mapping = build_comp_costs_from_descs(
MulOpCost, ctx, processes, desc_mapping, cluster MulOpCost, ctx, processes, desc_mapping, cluster
) )
...@@ -2530,7 +2530,7 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2530,7 +2530,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
), "backward op [{}] don't have dist attribute !".format(str(src_op)) ), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes: if rank_id not in op_dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank( rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id ctx, op_dist_attr.process_mesh, rank_id
) )
...@@ -2566,8 +2566,8 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2566,8 +2566,8 @@ class DistributedMulImpl0(DistributedOperatorImpl):
), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping matmul_col_dim_mapping
) )
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.processes process_mesh_group = op_dist_attr.process_mesh.process_ids
parallel_axis = matmul_col_dim_mapping parallel_axis = matmul_col_dim_mapping
group_ranks = _get_comm_group( group_ranks = _get_comm_group(
...@@ -2778,7 +2778,7 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2778,7 +2778,7 @@ class DistributedMulImpl1(DistributedOperatorImpl):
attrs=attrs, attrs=attrs,
parallel_axis=parallel_axis, parallel_axis=parallel_axis,
) )
processes = process_mesh.processes processes = process_mesh.process_ids
comm_op_cost_list = build_comm_costs_from_descs( comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
) )
...@@ -2797,7 +2797,7 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2797,7 +2797,7 @@ class DistributedMulImpl1(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping( var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0] backward_op.input("X")[0]
) )
mesh_shape = process_mesh.topology mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if ( if (
batch_size_axis > -1 batch_size_axis > -1
...@@ -2817,7 +2817,7 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2817,7 +2817,7 @@ class DistributedMulImpl1(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op( desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx dist_op=dist_op, dist_context=ctx
) )
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs( cost_mapping = build_comp_costs_from_descs(
MulOpCost, ctx, processes, desc_mapping, cluster MulOpCost, ctx, processes, desc_mapping, cluster
) )
...@@ -2919,7 +2919,7 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2919,7 +2919,7 @@ class DistributedMulImpl1(DistributedOperatorImpl):
), "backward op [{}] don't have dist attribute !".format(str(src_op)) ), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes: if rank_id not in op_dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank( rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id ctx, op_dist_attr.process_mesh, rank_id
) )
...@@ -2955,8 +2955,8 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2955,8 +2955,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping matmul_row_dim_mapping
) )
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.processes process_mesh_group = op_dist_attr.process_mesh.process_ids
parallel_axis = matmul_row_dim_mapping parallel_axis = matmul_row_dim_mapping
group_ranks = _get_comm_group( group_ranks = _get_comm_group(
...@@ -3131,7 +3131,7 @@ class DistributedMulImpl2(DistributedOperatorImpl): ...@@ -3131,7 +3131,7 @@ class DistributedMulImpl2(DistributedOperatorImpl):
dist_op=dist_op, dist_context=ctx dist_op=dist_op, dist_context=ctx
) )
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
processes = process_mesh.processes processes = process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs( cost_mapping = build_comp_costs_from_descs(
MulGradOpCost, ctx, processes, desc_mapping, cluster MulGradOpCost, ctx, processes, desc_mapping, cluster
) )
...@@ -3141,7 +3141,7 @@ class DistributedMulImpl2(DistributedOperatorImpl): ...@@ -3141,7 +3141,7 @@ class DistributedMulImpl2(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping( var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0] backward_op.input("X")[0]
) )
mesh_shape = process_mesh.topology mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if ( if (
batch_size_axis > -1 batch_size_axis > -1
...@@ -3162,7 +3162,7 @@ class DistributedMulImpl2(DistributedOperatorImpl): ...@@ -3162,7 +3162,7 @@ class DistributedMulImpl2(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op( desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx dist_op=dist_op, dist_context=ctx
) )
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs( cost_mapping = build_comp_costs_from_descs(
MulOpCost, ctx, processes, desc_mapping, cluster MulOpCost, ctx, processes, desc_mapping, cluster
) )
......
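Every change in the operator-impl hunks above is the same mechanical rename on ProcessMesh: .topology becomes .shape and .processes becomes .process_ids. A quick hedged check of the new attribute names, reusing the values from the ProcessMesh docstring later in this diff (the import path is assumed from the module being modified):

    from paddle.distributed.auto_parallel.process_mesh import ProcessMesh

    mesh = ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
    assert mesh.shape == [2, 3]                    # previously mesh.topology
    assert mesh.process_ids == [2, 4, 5, 0, 1, 3]  # previously mesh.processes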
...@@ -154,7 +154,7 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -154,7 +154,7 @@ class DistributedPNormImpl(DistributedOperatorImpl):
output_name output_name
) )
if rank_id not in op_dist_attr.process_mesh.processes: if rank_id not in op_dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank( rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id ctx, op_dist_attr.process_mesh, rank_id
) )
...@@ -164,8 +164,8 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -164,8 +164,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
for axis in range(len(in_dims_mapping)): for axis in range(len(in_dims_mapping)):
if in_dims_mapping[axis] != -1: if in_dims_mapping[axis] != -1:
break break
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.processes process_mesh_group = op_dist_attr.process_mesh.process_ids
group_ranks = _get_comm_group( group_ranks = _get_comm_group(
process_mesh_group, process_mesh_shape, axis, rank_id process_mesh_group, process_mesh_shape, axis, rank_id
) )
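_get_comm_group is now called with the renamed shape/process_ids attributes. As a rough illustration of the intent only (not Paddle's implementation), such a helper returns the ranks that share the caller's coordinates on every mesh axis except the requested one:

    import numpy as np

    def comm_group_sketch(process_ids, mesh_shape, axis, rank):
        # illustrative stand-in for _get_comm_group: keep all coordinates of `rank`
        # fixed and vary only `axis`
        mesh = np.array(process_ids).reshape(mesh_shape)
        coord = list(zip(*np.where(mesh == rank)))[0]
        index = list(coord)
        index[axis] = slice(None)
        return mesh[tuple(index)].flatten().tolist()

    # a 2 x 3 mesh over ranks 0..5: the group of rank 4 along axis 1 is its row
    assert comm_group_sketch(list(range(6)), [2, 3], 1, 4) == [3, 4, 5]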
...@@ -301,8 +301,8 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -301,8 +301,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
ctx.set_op_dist_attr_for_program(p_norm_grad_op, op_dist_attr) ctx.set_op_dist_attr_for_program(p_norm_grad_op, op_dist_attr)
# 2. insert slice op # 2. insert slice op
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.processes process_mesh_group = op_dist_attr.process_mesh.process_ids
dims_mapping = [0] + [-1 for _ in range(len(new_X_grad.shape) - 1)] dims_mapping = [0] + [-1 for _ in range(len(new_X_grad.shape) - 1)]
from ..reshard import Resharder from ..reshard import Resharder
......
...@@ -67,7 +67,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -67,7 +67,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
shape_list = op.desc.attr("shape") shape_list = op.desc.attr("shape")
# got dist attribute info # got dist attribute info
dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0]) dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0])
process_mesh_shape = dist_attr.process_mesh.topology process_mesh_shape = dist_attr.process_mesh.shape
# modify target shape # modify target shape
for idx, axis in enumerate(dim_mapping): for idx, axis in enumerate(dim_mapping):
...@@ -81,7 +81,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -81,7 +81,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op( desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx dist_op=dist_op, dist_context=ctx
) )
processes = dist_attr.process_mesh.processes processes = dist_attr.process_mesh.process_ids
for key in desc_mapping: for key in desc_mapping:
desc_mapping[key]["shape"] = shape_list desc_mapping[key]["shape"] = shape_list
...@@ -100,7 +100,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -100,7 +100,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
) )
dist_attr = dist_op.dist_attr dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
processes = process_mesh.processes processes = process_mesh.process_ids
op_type = dist_op.serial_op.type op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs( cost_mapping = build_comp_costs_from_descs(
...@@ -119,7 +119,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -119,7 +119,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname) var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis parallel_axis = batch_size_axis
...@@ -266,7 +266,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -266,7 +266,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
# got dist attribute info # got dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.shape
# modify target shape # modify target shape
for idx, axis in enumerate(dim_mapping): for idx, axis in enumerate(dim_mapping):
...@@ -315,7 +315,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -315,7 +315,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
shape_list = op.desc.attr("shape") shape_list = op.desc.attr("shape")
# got dist attribute info # got dist attribute info
dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0]) dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0])
process_mesh_shape = dist_attr.process_mesh.topology process_mesh_shape = dist_attr.process_mesh.shape
# modify target shape # modify target shape
for idx, axis in enumerate(dim_mapping): for idx, axis in enumerate(dim_mapping):
...@@ -329,7 +329,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -329,7 +329,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op( desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx dist_op=dist_op, dist_context=ctx
) )
processes = dist_attr.process_mesh.processes processes = dist_attr.process_mesh.process_ids
for key in desc_mapping: for key in desc_mapping:
desc_mapping[key]["shape"] = shape_list desc_mapping[key]["shape"] = shape_list
...@@ -348,7 +348,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -348,7 +348,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
) )
dist_attr = dist_op.dist_attr dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
processes = process_mesh.processes processes = process_mesh.process_ids
op_type = dist_op.serial_op.type op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs( cost_mapping = build_comp_costs_from_descs(
...@@ -367,7 +367,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -367,7 +367,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname) var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis parallel_axis = batch_size_axis
...@@ -517,7 +517,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -517,7 +517,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
# got dist attribute info # got dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.shape
# modify target shape # modify target shape
for idx, axis in enumerate(dim_mapping): for idx, axis in enumerate(dim_mapping):
...@@ -566,7 +566,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): ...@@ -566,7 +566,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
shape_list = op.desc.attr("shape") shape_list = op.desc.attr("shape")
# got dist attribute info # got dist attribute info
dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0]) dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0])
process_mesh_shape = dist_attr.process_mesh.topology process_mesh_shape = dist_attr.process_mesh.shape
# modify target shape # modify target shape
for idx, axis in enumerate(dim_mapping): for idx, axis in enumerate(dim_mapping):
...@@ -580,7 +580,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): ...@@ -580,7 +580,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op( desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx dist_op=dist_op, dist_context=ctx
) )
processes = dist_attr.process_mesh.processes processes = dist_attr.process_mesh.process_ids
for key in desc_mapping: for key in desc_mapping:
desc_mapping[key]["shape"] = shape_list desc_mapping[key]["shape"] = shape_list
...@@ -599,7 +599,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): ...@@ -599,7 +599,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
) )
dist_attr = dist_op.dist_attr dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
processes = process_mesh.processes processes = process_mesh.process_ids
op_type = dist_op.serial_op.type op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs( cost_mapping = build_comp_costs_from_descs(
...@@ -618,7 +618,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): ...@@ -618,7 +618,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname) var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis parallel_axis = batch_size_axis
...@@ -761,7 +761,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): ...@@ -761,7 +761,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
# got dist attribute info # got dist attribute info
out_dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) out_dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.shape
# modify target shape # modify target shape
for idx, axis in enumerate(out_dim_mapping): for idx, axis in enumerate(out_dim_mapping):
......
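In the three reshape impls above, only the header of the "modify target shape" loop and the renamed process_mesh.shape lookup are visible; the loop body is collapsed out of the hunks. A sketch of the idea under the usual shard rule local_size = global_size / mesh_dim_size (an assumption about the omitted body, with hypothetical values):

    shape_list = [8, 1024, 1024]    # hypothetical target shape of the reshape
    dim_mapping = [0, -1, 1]        # output dims_mapping: dim 0 on mesh axis 0, dim 2 on axis 1
    process_mesh_shape = [2, 4]     # dist_attr.process_mesh.shape

    for idx, axis in enumerate(dim_mapping):
        if axis >= 0:               # -1 means replicated; leave that dim untouched
            shape_list[idx] = shape_list[idx] // process_mesh_shape[axis]

    assert shape_list == [4, 1024, 256]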
...@@ -60,7 +60,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): ...@@ -60,7 +60,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op( desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx dist_op=dist_op, dist_context=ctx
) )
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs( cost_mapping = build_comp_costs_from_descs(
SoftmaxOpCost, ctx, processes, desc_mapping, cluster SoftmaxOpCost, ctx, processes, desc_mapping, cluster
) )
...@@ -76,7 +76,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): ...@@ -76,7 +76,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
) )
dist_attr = dist_op.dist_attr dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
processes = process_mesh.processes processes = process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs( cost_mapping = build_comp_costs_from_descs(
SoftmaxGradOpCost, ctx, processes, desc_mapping, cluster SoftmaxGradOpCost, ctx, processes, desc_mapping, cluster
) )
...@@ -93,7 +93,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): ...@@ -93,7 +93,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname) var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis parallel_axis = batch_size_axis
......
...@@ -140,7 +140,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): ...@@ -140,7 +140,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op( desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx dist_op=dist_op, dist_context=ctx
) )
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.process_ids
op_type = dist_op.serial_op.type op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs( cost_mapping = build_comp_costs_from_descs(
Transpose2OpCost, ctx, processes, desc_mapping, cluster Transpose2OpCost, ctx, processes, desc_mapping, cluster
...@@ -157,7 +157,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): ...@@ -157,7 +157,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
) )
dist_attr = dist_op.dist_attr dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
processes = process_mesh.processes processes = process_mesh.process_ids
op_type = dist_op.serial_op.type op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs( cost_mapping = build_comp_costs_from_descs(
Transpose2GradOpCost, ctx, processes, desc_mapping, cluster Transpose2GradOpCost, ctx, processes, desc_mapping, cluster
...@@ -175,7 +175,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): ...@@ -175,7 +175,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname) var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis parallel_axis = batch_size_axis
......
...@@ -79,7 +79,7 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl): ...@@ -79,7 +79,7 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
str(backward_op) str(backward_op)
) )
assert rank_id in dist_attr.process_mesh.processes assert rank_id in dist_attr.process_mesh.process_ids
assert 'X' in kwargs, "input [{}] is not given".format('X') assert 'X' in kwargs, "input [{}] is not given".format('X')
assert 'FoundInfinite' in kwargs, "input [{}] is not given".format( assert 'FoundInfinite' in kwargs, "input [{}] is not given".format(
...@@ -154,7 +154,7 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl): ...@@ -154,7 +154,7 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
rank_id rank_id
in ctx.get_tensor_dist_attr_for_program( in ctx.get_tensor_dist_attr_for_program(
main_block._var_recursive(varname) main_block._var_recursive(varname)
).process_mesh.processes ).process_mesh.process_ids
): ):
filter_vars.append(varname) filter_vars.append(varname)
......
...@@ -286,7 +286,7 @@ class AutoParallelizer: ...@@ -286,7 +286,7 @@ class AutoParallelizer:
_g_process_group_map.clear() _g_process_group_map.clear()
_g_process_group_map[0] = ProcessGroup(0, []) _g_process_group_map[0] = ProcessGroup(0, [])
for process_mesh in self._dist_context._process_meshes: for process_mesh in self._dist_context._process_meshes:
_g_process_group_map[0].add_ranks(process_mesh.processes) _g_process_group_map[0].add_ranks(process_mesh.process_ids)
return ( return (
dist_optimize_ops, dist_optimize_ops,
dist_params_grads, dist_params_grads,
...@@ -482,7 +482,7 @@ class AutoParallelizer: ...@@ -482,7 +482,7 @@ class AutoParallelizer:
if dist_context is not None: if dist_context is not None:
pg0 = get_process_group(0) pg0 = get_process_group(0)
for process_mesh in dist_context._process_meshes: for process_mesh in dist_context._process_meshes:
pg0.add_ranks(process_mesh.processes) pg0.add_ranks(process_mesh.process_ids)
( (
dist_optimize_ops, dist_optimize_ops,
dist_params_grads, dist_params_grads,
......
...@@ -391,7 +391,7 @@ def _get_dist_shape(var, dist_attr): ...@@ -391,7 +391,7 @@ def _get_dist_shape(var, dist_attr):
var_shape = var.shape var_shape = var.shape
mapping = dist_attr.dims_mapping mapping = dist_attr.dims_mapping
mesh = dist_attr.process_mesh.topology mesh = dist_attr.process_mesh.shape
if mapping == []: if mapping == []:
return var_shape return var_shape
......
...@@ -74,7 +74,7 @@ class PlanFilter: ...@@ -74,7 +74,7 @@ class PlanFilter:
for var_name in op.input_arg_names: for var_name in op.input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(var_name) dims_mapping = op_dist_attr.get_input_dims_mapping(var_name)
if not PlanFilter.check_dims_mapping_for_tensor( if not PlanFilter.check_dims_mapping_for_tensor(
process_mesh.topology, vars[var_name].shape, dims_mapping process_mesh.shape, vars[var_name].shape, dims_mapping
): ):
return False return False
if vars[var_name].is_data and len(dims_mapping) > 1: if vars[var_name].is_data and len(dims_mapping) > 1:
...@@ -85,7 +85,7 @@ class PlanFilter: ...@@ -85,7 +85,7 @@ class PlanFilter:
for var_name in op.output_arg_names: for var_name in op.output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(var_name) dims_mapping = op_dist_attr.get_output_dims_mapping(var_name)
if not PlanFilter.check_dims_mapping_for_tensor( if not PlanFilter.check_dims_mapping_for_tensor(
process_mesh.topology, vars[var_name].shape, dims_mapping process_mesh.shape, vars[var_name].shape, dims_mapping
): ):
return False return False
...@@ -217,13 +217,13 @@ class PlanSpace: ...@@ -217,13 +217,13 @@ class PlanSpace:
for var_name in chain(op.input_arg_names, op.output_arg_names): for var_name in chain(op.input_arg_names, op.output_arg_names):
visited = [ visited = [
False False
for _ in range(len(list(range(-1, len(process_mesh.topology))))) for _ in range(len(list(range(-1, len(process_mesh.shape)))))
] ]
depth = 0 depth = 0
path = [] path = []
dims_mapping_list = [] dims_mapping_list = []
PlanSpace._enum_dims_mapping( PlanSpace._enum_dims_mapping(
process_mesh.topology, process_mesh.shape,
visited, visited,
path, path,
depth, depth,
...@@ -590,7 +590,10 @@ class MCMC(SearchAlgorithm): ...@@ -590,7 +590,10 @@ class MCMC(SearchAlgorithm):
self.serial_program_info, dist_context, self.parallelizer self.serial_program_info, dist_context, self.parallelizer
) )
pipeline_config = ( pipeline_config = (
[process_mesh.processes for process_mesh in pipeline_process_meshes] [
process_mesh.process_ids
for process_mesh in pipeline_process_meshes
]
if pipeline_process_meshes is not None if pipeline_process_meshes is not None
else None else None
) )
...@@ -1060,7 +1063,7 @@ class MCMC(SearchAlgorithm): ...@@ -1060,7 +1063,7 @@ class MCMC(SearchAlgorithm):
# rebuild g_process_group # rebuild g_process_group
pg0 = get_process_group(0) pg0 = get_process_group(0)
for process_mesh in searched_dist_context._process_meshes: for process_mesh in searched_dist_context._process_meshes:
pg0.add_ranks(process_mesh.processes) pg0.add_ranks(process_mesh.process_ids)
end_time = time.time() end_time = time.time()
print( print(
"End MCMC searching: the min cost is {} and the search time is {}s.".format( "End MCMC searching: the min cost is {} and the search time is {}s.".format(
......
...@@ -17,6 +17,7 @@ import copy ...@@ -17,6 +17,7 @@ import copy
import numpy as np import numpy as np
import paddle import paddle
from paddle.fluid import core
# Use to store the previous and current process mesh # Use to store the previous and current process mesh
_g_previous_process_mesh = None _g_previous_process_mesh = None
...@@ -41,12 +42,12 @@ def reset_current_process_mesh(): ...@@ -41,12 +42,12 @@ def reset_current_process_mesh():
_g_current_process_mesh = _g_previous_process_mesh _g_current_process_mesh = _g_previous_process_mesh
class ProcessMesh: class ProcessMesh(core.ProcessMesh):
""" """
The `Processmesh` object describes the topology of the used processes. The `ProcessMesh` object describes the Cartesian topology of the used processes.
Args: Args:
mesh (list|numpy.array): an n-dimensional array describes the toplogy mesh (list|numpy.array): an n-dimensional array describes the topology
of the processes. of the processes.
dim_names (list, optional): the i-th element of this list gives the name of the dim_names (list, optional): the i-th element of this list gives the name of the
i-th dimension of the mesh. i-th dimension of the mesh.
...@@ -58,7 +59,7 @@ class ProcessMesh: ...@@ -58,7 +59,7 @@ class ProcessMesh:
mesh = auto.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"]) mesh = auto.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
assert mesh.shape == [2, 3] assert mesh.shape == [2, 3]
assert mesh.processe_ids == [2, 4, 5, 0, 1, 3] assert mesh.process_ids == [2, 4, 5, 0, 1, 3]
""" """
...@@ -77,6 +78,9 @@ class ProcessMesh: ...@@ -77,6 +78,9 @@ class ProcessMesh:
if isinstance(mesh, list): if isinstance(mesh, list):
mesh = np.array(mesh) mesh = np.array(mesh)
if dim_names is not None and not isinstance(dim_names, list):
raise ValueError('The dim_names must be an instance of list.')
self._mesh = mesh self._mesh = mesh
self._shape = list(self._mesh.shape) self._shape = list(self._mesh.shape)
self._process_ids = self._mesh.flatten().tolist() self._process_ids = self._mesh.flatten().tolist()
...@@ -104,6 +108,11 @@ class ProcessMesh: ...@@ -104,6 +108,11 @@ class ProcessMesh:
self._dim_names self._dim_names
), 'All dim_names {} must be unique.'.format(dim_names) ), 'All dim_names {} must be unique.'.format(dim_names)
# Follow the requirement for using pybind11
core.ProcessMesh.__init__(
self, self._shape, self._process_ids, self._dim_names
)
# Store all process meshes # Store all process meshes
from .dist_context import get_default_distributed_context from .dist_context import get_default_distributed_context
...@@ -113,35 +122,7 @@ class ProcessMesh: ...@@ -113,35 +122,7 @@ class ProcessMesh:
from .process_group import get_process_group from .process_group import get_process_group
pg0 = get_process_group(0) pg0 = get_process_group(0)
pg0.add_ranks(self.processes) pg0.add_ranks(self.process_ids)
@property
def shape(self):
"""
Get the shape of this ProcessMesh.
"""
return self._shape
@property
def process_ids(self):
"""
Get the process ids belonging to this ProcessMesh.
"""
return self._process_ids
@property
def dim_names(self):
"""
Get the dimension names of this ProcessMesh.
"""
return self._dim_names
@property
def ndim(self):
"""
Get the number of dimension of this ProcessMesh.
"""
return len(self._shape)
@property @property
def mesh(self): def mesh(self):
...@@ -150,14 +131,6 @@ class ProcessMesh: ...@@ -150,14 +131,6 @@ class ProcessMesh:
""" """
return self._mesh return self._mesh
@property
def topology(self):
return self._shape
@property
def processes(self):
return self._process_ids
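The Python-level shape, process_ids, dim_names and ndim properties removed above are not lost: after this change they are provided by the core.ProcessMesh base class, while the old topology and processes aliases are dropped for good (hence the call-site renames throughout this PR); only the numpy-backed mesh property stays in Python. A quick check, reusing the docstring values and the assumed import path of this module:

    from paddle.distributed.auto_parallel.process_mesh import ProcessMesh

    mesh = ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
    assert mesh.shape == [2, 3]
    assert mesh.process_ids == [2, 4, 5, 0, 1, 3]
    assert mesh.dim_names == ["x", "y"]
    assert mesh.ndim == 2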
def __getitem__(self, index): def __getitem__(self, index):
if isinstance(index, tuple): if isinstance(index, tuple):
new_dim_names = [] new_dim_names = []
...@@ -207,9 +180,8 @@ class ProcessMesh: ...@@ -207,9 +180,8 @@ class ProcessMesh:
tensor tensor
) )
if dist_tensor is None: if dist_tensor is None:
dist_tensor = DistributedTensor( dist_tensor = DistributedTensor(cur_block.vars[name])
cur_block.vars[name], {"process_mesh": self} dist_tensor.dist_attr.process_mesh = self
)
dist_tensor.dist_attr.mark_annotated("process_mesh") dist_tensor.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_tensor_for_program(dist_tensor) default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
else: else:
...@@ -221,7 +193,8 @@ class ProcessMesh: ...@@ -221,7 +193,8 @@ class ProcessMesh:
op = cur_block.ops[idx] op = cur_block.ops[idx]
dist_op = default_dist_ctx.get_dist_op_for_program(op) dist_op = default_dist_ctx.get_dist_op_for_program(op)
if dist_op is None: if dist_op is None:
dist_op = DistributedOperator(op, {"process_mesh": self}) dist_op = DistributedOperator(op)
dist_op.dist_attr.process_mesh = self
dist_op.dist_attr.mark_annotated("process_mesh") dist_op.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_op_for_program(dist_op) default_dist_ctx.add_dist_op_for_program(dist_op)
else: else:
...@@ -230,6 +203,13 @@ class ProcessMesh: ...@@ -230,6 +203,13 @@ class ProcessMesh:
dist_op.dist_attr.mark_annotated("process_mesh") dist_op.dist_attr.mark_annotated("process_mesh")
reset_current_process_mesh() reset_current_process_mesh()
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
new_process_mesh = ProcessMesh(np.array(self.mesh), self.dim_names)
memo[id(self)] = new_process_mesh
return new_process_mesh
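With __deepcopy__ defined above, a mesh can be deep-copied without sharing state with the original. A small hedged example (assuming __eq__ compares mesh contents rather than identity, which the definition just below suggests):

    import copy
    from paddle.distributed.auto_parallel.process_mesh import ProcessMesh

    mesh = ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
    mesh_copy = copy.deepcopy(mesh)
    assert mesh_copy is not mesh
    assert mesh_copy == mesh
    assert mesh_copy.process_ids == [0, 1, 2, 3]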
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, ProcessMesh): if not isinstance(other, ProcessMesh):
return False return False
...@@ -245,3 +225,51 @@ class ProcessMesh: ...@@ -245,3 +225,51 @@ class ProcessMesh:
self.shape, self.process_ids, self.dim_names self.shape, self.process_ids, self.dim_names
) )
return str return str
def compute_compatible_process_mesh(process_mesh_list):
"""Compute the compatible process mesh given a list of process meshes."""
if not process_mesh_list:
return None
def _compute_compatible_process_mesh_of_two(pm1, pm2):
if pm1 is None:
return True, pm2
if pm2 is None:
return True, pm1
if pm1 == pm2:
return True, pm1
if pm1.process_ids == pm2.process_ids:
if len(pm1.shape) >= len(pm2.shape):
return True, pm1
else:
return True, pm2
process_set1 = set(pm1.process_ids)
process_set2 = set(pm2.process_ids)
if process_set1.issubset(process_set2):
return True, pm2
if process_set2.issubset(process_set1):
return True, pm1
return False, None
compatible_result = None
for process_mesh in process_mesh_list:
compatible, compatible_result = _compute_compatible_process_mesh_of_two(
compatible_result, process_mesh
)
if not compatible:
return None
return copy.deepcopy(compatible_result)
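A hedged usage example for compute_compatible_process_mesh above, spelling out the two rules it encodes (identical ranks: the higher-dimensional mesh wins; subset ranks: the superset mesh wins); the import path is assumed:

    from paddle.distributed.auto_parallel.process_mesh import (
        ProcessMesh,
        compute_compatible_process_mesh,
    )

    pm_flat = ProcessMesh([0, 1, 2, 3])                            # shape [4]
    pm_grid = ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])  # shape [2, 2], same ranks
    pm_sub = ProcessMesh([0, 1])                                   # subset of the ranks

    compat = compute_compatible_process_mesh([pm_flat, pm_grid])
    assert compat.shape == [2, 2]              # same process_ids: keep the mesh with more dims

    compat = compute_compatible_process_mesh([pm_sub, pm_flat])
    assert compat.process_ids == [0, 1, 2, 3]  # subset: keep the superset mesh

    assert compute_compatible_process_mesh([]) is None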
def merge_process_meshes(process_meshes):
"""Merge a list of process meshes."""
merged_process_mesh = None
merged_process_ids = set()
for process_mesh in process_meshes:
if process_mesh is not None:
process_ids = set(process_mesh.process_ids)
merged_process_ids = merged_process_ids.union(process_ids)
if len(merged_process_ids) != 0:
merged_process_mesh = ProcessMesh(list(merged_process_ids))
return merged_process_mesh
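And a matching example for merge_process_meshes: it unions the ranks of all non-None meshes into a flat 1-D mesh, or returns None when nothing was collected (same assumed import path as above):

    from paddle.distributed.auto_parallel.process_mesh import (
        ProcessMesh,
        merge_process_meshes,
    )

    merged = merge_process_meshes(
        [ProcessMesh([0, 1]), None, ProcessMesh([[2, 3], [4, 5]], dim_names=["x", "y"])]
    )
    assert sorted(merged.process_ids) == [0, 1, 2, 3, 4, 5]
    assert merged.ndim == 1        # the merged mesh is always one-dimensional
    assert merge_process_meshes([None]) is None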
...@@ -829,7 +829,7 @@ class Remover: ...@@ -829,7 +829,7 @@ class Remover:
) )
).process_mesh ).process_mesh
) )
if rank_id in process_mesh.processes: if rank_id in process_mesh.process_ids:
need_save.append(var_name) need_save.append(var_name)
if not need_save: if not need_save:
remove_op_idx.append(idx) remove_op_idx.append(idx)
...@@ -845,7 +845,7 @@ class Remover: ...@@ -845,7 +845,7 @@ class Remover:
if op_dist_attr is not None: if op_dist_attr is not None:
op_process_mesh = op_dist_attr.process_mesh op_process_mesh = op_dist_attr.process_mesh
if ( if (
rank_id not in op_process_mesh.processes rank_id not in op_process_mesh.process_ids
and op.type not in not_remove_op_ref and op.type not in not_remove_op_ref
): ):
remove_op_idx.append(idx) remove_op_idx.append(idx)
...@@ -1408,9 +1408,11 @@ class Resharder: ...@@ -1408,9 +1408,11 @@ class Resharder:
op_process_mesh = dist_op.dist_attr.process_mesh op_process_mesh = dist_op.dist_attr.process_mesh
for process_mesh in self.dist_context.process_meshes: for process_mesh in self.dist_context.process_meshes:
if set(process_mesh.processes) & ( if set(process_mesh.process_ids) & (
set(op_process_mesh.processes) set(op_process_mesh.process_ids)
) and len(process_mesh.processes) < len(op_process_mesh.processes): ) and len(process_mesh.process_ids) < len(
op_process_mesh.process_ids
):
process_meshes.append(process_mesh) process_meshes.append(process_mesh)
# it means the process mesh is not a union when process meshes is null # it means the process mesh is not a union when process meshes is null
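The condition above (repeated later in this file) is what marks an op's mesh as a "union" mesh: some registered mesh overlaps it while covering strictly fewer ranks. A hedged restatement, assuming dist_context and op_process_mesh are in scope:

    sub_meshes = [
        pm
        for pm in dist_context.process_meshes
        if set(pm.process_ids) & set(op_process_mesh.process_ids)
        and len(pm.process_ids) < len(op_process_mesh.process_ids)
    ]
    # an empty list means the op's mesh is not a union of smaller meshes
    is_union_mesh = bool(sub_meshes)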
...@@ -1438,13 +1440,13 @@ class Resharder: ...@@ -1438,13 +1440,13 @@ class Resharder:
source_dims_mapping = tensor_dist_attr.dims_mapping source_dims_mapping = tensor_dist_attr.dims_mapping
source_process_mesh = tensor_dist_attr.process_mesh source_process_mesh = tensor_dist_attr.process_mesh
source_process_group = source_process_mesh.processes source_process_group = source_process_mesh.process_ids
source_process_shape = source_process_mesh.topology source_process_shape = source_process_mesh.shape
target_process_mesh = dist_attr[0] target_process_mesh = dist_attr[0]
target_dims_mapping = dist_attr[1] target_dims_mapping = dist_attr[1]
target_process_group = target_process_mesh.processes target_process_group = target_process_mesh.process_ids
target_process_shape = target_process_mesh.topology target_process_shape = target_process_mesh.shape
if source_tensor.shape[0] < 0: if source_tensor.shape[0] < 0:
assert source_tensor.shape[0] == -1 assert source_tensor.shape[0] == -1
...@@ -2141,9 +2143,11 @@ class Resharder: ...@@ -2141,9 +2143,11 @@ class Resharder:
dist_attr = dist_op.dist_attr dist_attr = dist_op.dist_attr
op_process_mesh = dist_attr.process_mesh op_process_mesh = dist_attr.process_mesh
for process_mesh in self.dist_context.process_meshes: for process_mesh in self.dist_context.process_meshes:
if set(process_mesh.processes) & ( if set(process_mesh.process_ids) & (
set(op_process_mesh.processes) set(op_process_mesh.process_ids)
) and len(process_mesh.processes) < len(op_process_mesh.processes): ) and len(process_mesh.process_ids) < len(
op_process_mesh.process_ids
):
process_meshes.append(process_mesh) process_meshes.append(process_mesh)
# it means that the process mesh is not a union when process meshes is none # it means that the process mesh is not a union when process meshes is none
...@@ -2176,12 +2180,12 @@ class Resharder: ...@@ -2176,12 +2180,12 @@ class Resharder:
if process_mesh_count > 1: if process_mesh_count > 1:
global_process_mesh_idx = None global_process_mesh_idx = None
for process_mesh in self.dist_context.process_meshes: for process_mesh in self.dist_context.process_meshes:
for process in process_mesh.processes: for process in process_mesh.process_ids:
processes.add(process) processes.add(process)
for idx, process_mesh in enumerate( for idx, process_mesh in enumerate(
self.dist_context.process_meshes self.dist_context.process_meshes
): ):
if len(set(process_mesh.processes)) == len(processes): if len(set(process_mesh.process_ids)) == len(processes):
global_process_mesh_idx = idx global_process_mesh_idx = idx
break break
...@@ -2191,7 +2195,7 @@ class Resharder: ...@@ -2191,7 +2195,7 @@ class Resharder:
for i, mesh in enumerate(self.dist_context.process_meshes): for i, mesh in enumerate(self.dist_context.process_meshes):
if i == idx: if i == idx:
continue continue
if set(mesh.processes) < set(global_mesh.processes): if set(mesh.process_ids) < set(global_mesh.process_ids):
is_removed = True is_removed = True
if is_removed: if is_removed:
...@@ -2330,8 +2334,8 @@ class Resharder: ...@@ -2330,8 +2334,8 @@ class Resharder:
# deal with union tensor # deal with union tensor
if is_union_process_mesh_tensor: if is_union_process_mesh_tensor:
# if op process mesh is subset of union tensor process mesh, need no reshard # if op process mesh is subset of union tensor process mesh, need no reshard
if set(input_attr[0].processes) <= set( if set(input_attr[0].process_ids) <= set(
dist_tensor.dist_attr.process_mesh.processes dist_tensor.dist_attr.process_mesh.process_ids
): ):
continue continue
...@@ -2525,14 +2529,14 @@ class Resharder: ...@@ -2525,14 +2529,14 @@ class Resharder:
dist_tensor, output_attr, False dist_tensor, output_attr, False
): ):
tensor_processes = set( tensor_processes = set(
tensor_process_mesh.processes tensor_process_mesh.process_ids
) - ( ) - (
set(tensor_process_mesh.processes) set(tensor_process_mesh.process_ids)
& set(output_attr[0].processes) & set(output_attr[0].process_ids)
) )
if tensor_processes: if tensor_processes:
if len(tensor_processes) != len( if len(tensor_processes) != len(
output_attr[0].processes output_attr[0].process_ids
): ):
if dist_tensor.dist_attr.dims_mapping.count( if dist_tensor.dist_attr.dims_mapping.count(
-1 -1
...@@ -2555,13 +2559,15 @@ class Resharder: ...@@ -2555,13 +2559,15 @@ class Resharder:
recv_rank = tensor_process recv_rank = tensor_process
actual_index = index actual_index = index
if index >= len( if index >= len(
output_attr[0].processes output_attr[0].process_ids
): ):
actual_index = ( actual_index = (
index index
- len(output_attr[0].processes) - len(
) % len(output_attr[0].processes) output_attr[0].process_ids
item = output_attr[0].processes[ )
) % len(output_attr[0].process_ids)
item = output_attr[0].process_ids[
actual_index actual_index
] ]
if recv_rank == item: if recv_rank == item:
...@@ -2591,7 +2597,7 @@ class Resharder: ...@@ -2591,7 +2597,7 @@ class Resharder:
tensor_processes tensor_processes
): ):
recv_rank = tensor_process recv_rank = tensor_process
item = output_attr[0].processes[index] item = output_attr[0].process_ids[index]
if recv_rank == item: if recv_rank == item:
continue continue
if self.rank_id == item: if self.rank_id == item:
......
...@@ -114,7 +114,7 @@ def _copy_context(ref_dist_context): ...@@ -114,7 +114,7 @@ def _copy_context(ref_dist_context):
clear_all_process_groups() clear_all_process_groups()
ranks = [] ranks = []
for process_mesh in ref_dist_context._process_meshes: for process_mesh in ref_dist_context._process_meshes:
ranks.extend(process_mesh.processes) ranks.extend(process_mesh.process_ids)
new_process_group(list(set(ranks))) new_process_group(list(set(ranks)))
new_dist_context = DistributedContext() new_dist_context = DistributedContext()
......
...@@ -689,15 +689,15 @@ class ParallelTuner: ...@@ -689,15 +689,15 @@ class ParallelTuner:
assert process_mesh.ndim == 2 assert process_mesh.ndim == 2
dim_of_one = None dim_of_one = None
dim_of_other = None dim_of_other = None
if process_mesh.topology[0] == 1: if process_mesh.shape[0] == 1:
dim_of_one = 0 dim_of_one = 0
dim_of_other = 1 dim_of_other = 1
elif process_mesh.topology[1] == 1: elif process_mesh.shape[1] == 1:
dim_of_one = 1 dim_of_one = 1
dim_of_other = 0 dim_of_other = 0
if dim_of_one is not None: if dim_of_one is not None:
dist_attr.process_mesh = ProcessMesh(process_mesh.processes) dist_attr.process_mesh = ProcessMesh(process_mesh.process_ids)
self._dist_context.add_process_mesh(dist_attr.process_mesh) self._dist_context.add_process_mesh(dist_attr.process_mesh)
for arg_name in dist_attr.inputs_dist_attrs.keys(): for arg_name in dist_attr.inputs_dist_attrs.keys():
...@@ -715,7 +715,7 @@ class ParallelTuner: ...@@ -715,7 +715,7 @@ class ParallelTuner:
dims_mapping = dist_attr.get_input_dims_mapping(arg_name) dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
# dynamic_dims = dist_attr.get_input_dynamic_dims(arg_name) # dynamic_dims = dist_attr.get_input_dynamic_dims(arg_name)
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
process_shape = process_mesh.topology process_shape = process_mesh.shape
tensor = dist_op.get_serial_input(arg_name) tensor = dist_op.get_serial_input(arg_name)
if dims_mapping: if dims_mapping:
tensor_shape = tensor.shape tensor_shape = tensor.shape
...@@ -748,7 +748,7 @@ class ParallelTuner: ...@@ -748,7 +748,7 @@ class ParallelTuner:
dims_mapping = dist_attr.get_output_dims_mapping(arg_name) dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
# dynamic_dims = dist_attr.get_output_dynamic_dims(arg_name) # dynamic_dims = dist_attr.get_output_dynamic_dims(arg_name)
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
process_shape = process_mesh.topology process_shape = process_mesh.shape
tensor = dist_op.get_serial_output(arg_name) tensor = dist_op.get_serial_output(arg_name)
if dims_mapping: if dims_mapping:
...@@ -793,7 +793,7 @@ class ParallelTuner: ...@@ -793,7 +793,7 @@ class ParallelTuner:
input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping( input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
input_name input_name
) )
topology = dist_op.dist_attr.process_mesh.topology topology = dist_op.dist_attr.process_mesh.shape
input_tensor = dist_op.get_serial_input(input_name) input_tensor = dist_op.get_serial_input(input_name)
last_but_one_dim = ( last_but_one_dim = (
input_tensor.shape[-2] // topology[input_dims_mapping[-2]] input_tensor.shape[-2] // topology[input_dims_mapping[-2]]
......
...@@ -100,7 +100,7 @@ def convert_to_dims_mapping(shard_spec, process_mesh): ...@@ -100,7 +100,7 @@ def convert_to_dims_mapping(shard_spec, process_mesh):
for shard in shard_spec: for shard in shard_spec:
if shard is None: if shard is None:
dims_mapping.append(-1) dims_mapping.append(-1)
elif process_mesh.topology[process_mesh.dim_names.index(shard)] == 1: elif process_mesh.shape[process_mesh.dim_names.index(shard)] == 1:
dims_mapping.append(-1) dims_mapping.append(-1)
else: else:
dims_mapping.append(process_mesh.dim_names.index(shard)) dims_mapping.append(process_mesh.dim_names.index(shard))
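A hedged example for the shard_spec conversion above: a dimension named in shard_spec maps to the index of that mesh axis, except that axes of size 1 (and None entries) are treated as replicated, i.e. -1. The import paths are assumed:

    from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
    from paddle.distributed.auto_parallel.utils import convert_to_dims_mapping

    mesh = ProcessMesh([[0], [1]], dim_names=["x", "y"])   # shape [2, 1]
    assert convert_to_dims_mapping(["x", "y", None], mesh) == [0, -1, -1]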
...@@ -429,26 +429,26 @@ def _get_corresponding_rank(dist_context, target_mesh, rank): ...@@ -429,26 +429,26 @@ def _get_corresponding_rank(dist_context, target_mesh, rank):
coordinate = None coordinate = None
for mesh in dist_context.process_meshes: for mesh in dist_context.process_meshes:
if rank in mesh.processes and mesh.topology == target_mesh.topology: if rank in mesh.process_ids and mesh.shape == target_mesh.shape:
coordinate = _linear_idx2coordinate( coordinate = _linear_idx2coordinate(
mesh.topology, mesh.processes.index(rank) mesh.shape, mesh.process_ids.index(rank)
) )
break break
# assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format( # assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format(
# rank) # rank)
if coordinate is not None: if coordinate is not None:
return target_mesh.processes[ return target_mesh.process_ids[
_coordinate2linear_idx(mesh.topology, coordinate) _coordinate2linear_idx(mesh.shape, coordinate)
] ]
else: else:
return target_mesh.processes[0] return target_mesh.process_ids[0]
def _get_unshard_dist_shape(var, dist_attr): def _get_unshard_dist_shape(var, dist_attr):
var_shape = var.shape var_shape = var.shape
mapping = dist_attr.dims_mapping mapping = dist_attr.dims_mapping
mesh = dist_attr.process_mesh.topology mesh = dist_attr.process_mesh.shape
assert len(var_shape) == len( assert len(var_shape) == len(
mapping mapping
), "variable shape [{}] and dim_mapping [{}] is NOT match !".format( ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
...@@ -832,8 +832,8 @@ def get_dist_attr(program, dist_context=None): ...@@ -832,8 +832,8 @@ def get_dist_attr(program, dist_context=None):
process_mesh = tensor_dist_attr.process_mesh process_mesh = tensor_dist_attr.process_mesh
dims_mapping = tensor_dist_attr.dims_mapping dims_mapping = tensor_dist_attr.dims_mapping
dist_attr[var.name] = { dist_attr[var.name] = {
"process_shape": process_mesh.topology, "process_shape": process_mesh.shape,
"process_group": process_mesh.processes, "process_group": process_mesh.process_ids,
"dims_mapping": dims_mapping, "dims_mapping": dims_mapping,
} }
return dist_attr return dist_attr
...@@ -2006,16 +2006,16 @@ def get_input_split_info(cur_rank, var, dist_context): ...@@ -2006,16 +2006,16 @@ def get_input_split_info(cur_rank, var, dist_context):
process_mesh = tensor_dist_attr.process_mesh process_mesh = tensor_dist_attr.process_mesh
dims_mapping = tensor_dist_attr.dims_mapping dims_mapping = tensor_dist_attr.dims_mapping
if cur_rank not in process_mesh.processes: if cur_rank not in process_mesh.process_ids:
rank_id = _get_corresponding_rank(dist_context, process_mesh, cur_rank) rank_id = _get_corresponding_rank(dist_context, process_mesh, cur_rank)
else: else:
rank_id = cur_rank rank_id = cur_rank
batch_size_axis = dims_mapping[0] batch_size_axis = dims_mapping[0]
if batch_size_axis > -1 and process_mesh.topology[batch_size_axis] > 1: if batch_size_axis > -1 and process_mesh.shape[batch_size_axis] > 1:
group_ranks = _get_comm_group( group_ranks = _get_comm_group(
process_mesh.processes, process_mesh.process_ids,
process_mesh.topology, process_mesh.shape,
batch_size_axis, batch_size_axis,
rank_id, rank_id,
) )
......
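The hunks above consistently replace the old `process_mesh.topology` / `process_mesh.processes` accessors with `process_mesh.shape` / `process_mesh.process_ids`. A minimal sketch of the renamed read-only attributes, assuming they keep returning plain Python lists, as the `==`, `in`, and `.index()` uses in these call sites suggest (the import path is the one shown in the test changes later in this diff):

from paddle.distributed.auto_parallel.process_mesh import ProcessMesh

mesh = ProcessMesh([[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"])

# Renamed attributes: shape (was topology) and process_ids (was processes).
print(mesh.shape)        # [2, 3]
print(mesh.process_ids)  # [0, 1, 2, 3, 4, 5]
print(mesh.ndim)         # 2
print(mesh.dim_names)    # ["x", "y"]

# The updated call sites index the shape and search the process ids, e.g.:
rank = 4
rank_in_mesh = rank in mesh.process_ids       # True
linear_index = mesh.process_ids.index(rank)   # 4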
...@@ -164,8 +164,8 @@ class ClipHelper: ...@@ -164,8 +164,8 @@ class ClipHelper:
param = self.params[self.params_name.index(name)] param = self.params[self.params_name.index(name)]
dist_attr = self._get_dist_attr(name) dist_attr = self._get_dist_attr(name)
topology = dist_attr.process_mesh.topology topology = dist_attr.process_mesh.shape
processes = dist_attr.process_mesh.processes processes = dist_attr.process_mesh.process_ids
dims_mapping = dist_attr.dims_mapping dims_mapping = dist_attr.dims_mapping
return _is_about_global_norm( return _is_about_global_norm(
self.rank_id, self.rank_id,
...@@ -188,7 +188,7 @@ class ClipHelper: ...@@ -188,7 +188,7 @@ class ClipHelper:
def _is_local_var(self, name): def _is_local_var(self, name):
dist_attr = self._get_dist_attr(name) dist_attr = self._get_dist_attr(name)
assert dist_attr is not None assert dist_attr is not None
return self.rank_id in dist_attr.process_mesh.processes return self.rank_id in dist_attr.process_mesh.process_ids
def _init_dist_attr(self, op): def _init_dist_attr(self, op):
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistributedAttribute()
......
...@@ -870,13 +870,13 @@ def _inference_data_parallel_group_for_operator(rank_id, op, dist_context): ...@@ -870,13 +870,13 @@ def _inference_data_parallel_group_for_operator(rank_id, op, dist_context):
dist_attr = dist_context.get_op_dist_attr_for_program(op) dist_attr = dist_context.get_op_dist_attr_for_program(op)
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
input_dim_mapping = dist_attr.get_input_dims_mapping(input_name) input_dim_mapping = dist_attr.get_input_dims_mapping(input_name)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.shape
# TODO(JZ-LIANG) replace with specific batch size dimension # TODO(JZ-LIANG) replace with specific batch size dimension
batch_size_axis = input_dim_mapping[0] batch_size_axis = input_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
group_ranks = _get_comm_group( group_ranks = _get_comm_group(
process_mesh.processes, process_mesh.process_ids,
process_mesh.topology, process_mesh.shape,
batch_size_axis, batch_size_axis,
rank_id, rank_id,
) )
......
...@@ -180,7 +180,7 @@ class TestBaseCost(unittest.TestCase): ...@@ -180,7 +180,7 @@ class TestBaseCost(unittest.TestCase):
for op in train_program.global_block().ops: for op in train_program.global_block().ops:
dist_op = dist_context.get_dist_op_for_program(op) dist_op = dist_context.get_dist_op_for_program(op)
if dist_op: if dist_op:
processes = dist_op.dist_attr.process_mesh.processes processes = dist_op.dist_attr.process_mesh.process_ids
comp_descs = build_comp_desc_from_dist_op(dist_op, dist_context) comp_descs = build_comp_desc_from_dist_op(dist_op, dist_context)
self.assertTrue(isinstance(comp_descs, dict) and comp_descs) self.assertTrue(isinstance(comp_descs, dict) and comp_descs)
var_names = None var_names = None
......
...@@ -52,13 +52,13 @@ class MLPLayer(nn.Layer): ...@@ -52,13 +52,13 @@ class MLPLayer(nn.Layer):
out = self.norm(input) out = self.norm(input)
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, auto.ProcessMesh([0, 1], "x"), [None, "x"] self.linear0.weight, auto.ProcessMesh([0, 1], ["x"]), [None, "x"]
) )
out = self.linear0(out) out = self.linear0(out)
out = F.gelu(out, approximate=True) out = F.gelu(out, approximate=True)
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, auto.ProcessMesh([0, 1], "x"), ["x", None] self.linear1.weight, auto.ProcessMesh([0, 1], ["x"]), ["x", None]
) )
out = self.linear1(out) out = self.linear1(out)
......
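The test updates above also reflect that mesh axis names are now passed as a list (`["x"]`, `["d"]`) instead of a bare string. A hedged sketch of the updated constructor, assuming the second positional argument remains the dimension names, as in the `auto.ProcessMesh([0, 1], ["x"])` calls above:

from paddle.distributed.auto_parallel.process_mesh import ProcessMesh

# dim_names is now a list with one name per mesh dimension.
mesh_1d = ProcessMesh([0, 1], ["x"])  # was ProcessMesh([0, 1], "x")
mesh_2d = ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
print(mesh_1d.dim_names)  # ["x"]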
...@@ -98,7 +98,7 @@ class TestDistOpCost(unittest.TestCase): ...@@ -98,7 +98,7 @@ class TestDistOpCost(unittest.TestCase):
): ):
dist_op = dist_context.get_dist_op_for_program(op) dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes processes = op_dist_attr.process_mesh.process_ids
if is_elementwise_op(op.type): if is_elementwise_op(op.type):
container = get_distributed_operator_impl_container( container = get_distributed_operator_impl_container(
"elementwise" "elementwise"
...@@ -205,7 +205,7 @@ class TestDistOpCost(unittest.TestCase): ...@@ -205,7 +205,7 @@ class TestDistOpCost(unittest.TestCase):
for idx, op in enumerate(ops): for idx, op in enumerate(ops):
dist_op = dist_context.get_dist_op_for_program(op) dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes processes = op_dist_attr.process_mesh.process_ids
if is_elementwise_op(op.type): if is_elementwise_op(op.type):
container = get_distributed_operator_impl_container( container = get_distributed_operator_impl_container(
"elementwise" "elementwise"
...@@ -313,7 +313,7 @@ class TestDistOpCost(unittest.TestCase): ...@@ -313,7 +313,7 @@ class TestDistOpCost(unittest.TestCase):
for idx, op in enumerate(ops): for idx, op in enumerate(ops):
dist_op = dist_context.get_dist_op_for_program(op) dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes processes = op_dist_attr.process_mesh.process_ids
if is_elementwise_op(op.type): if is_elementwise_op(op.type):
container = get_distributed_operator_impl_container( container = get_distributed_operator_impl_container(
"elementwise" "elementwise"
...@@ -421,7 +421,7 @@ class TestDistOpCost(unittest.TestCase): ...@@ -421,7 +421,7 @@ class TestDistOpCost(unittest.TestCase):
for idx, op in enumerate(ops): for idx, op in enumerate(ops):
dist_op = dist_context.get_dist_op_for_program(op) dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes processes = op_dist_attr.process_mesh.process_ids
if is_elementwise_op(op.type): if is_elementwise_op(op.type):
container = get_distributed_operator_impl_container( container = get_distributed_operator_impl_container(
"elementwise" "elementwise"
......
...@@ -23,7 +23,11 @@ import paddle.static as static ...@@ -23,7 +23,11 @@ import paddle.static as static
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.dist_context import (
get_default_distributed_context, get_default_distributed_context,
) )
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh from paddle.distributed.auto_parallel.process_mesh import (
ProcessMesh,
compute_compatible_process_mesh,
merge_process_meshes,
)
paddle.enable_static() paddle.enable_static()
...@@ -129,7 +133,7 @@ class TestProcessMesh(unittest.TestCase): ...@@ -129,7 +133,7 @@ class TestProcessMesh(unittest.TestCase):
initializer_range=0.02, initializer_range=0.02,
) )
with ProcessMesh(mesh, "d"): with ProcessMesh(mesh, ["d"]):
out = mlp(input) out = mlp(input)
default_program = paddle.fluid.default_main_program() default_program = paddle.fluid.default_main_program()
...@@ -151,6 +155,67 @@ class TestProcessMesh(unittest.TestCase): ...@@ -151,6 +155,67 @@ class TestProcessMesh(unittest.TestCase):
dist_op.dist_attr.process_mesh, ProcessMesh(mesh) dist_op.dist_attr.process_mesh, ProcessMesh(mesh)
) )
def test_compute_compatible_process_mesh(self):
process_mesh1 = ProcessMesh(
[[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"]
)
compatible_process_mesh = compute_compatible_process_mesh(
[process_mesh1, None]
)
self.assertEqual(compatible_process_mesh, process_mesh1)
compatible_process_mesh = compute_compatible_process_mesh(
[None, process_mesh1]
)
self.assertEqual(compatible_process_mesh, process_mesh1)
process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]])
compatible_process_mesh = compute_compatible_process_mesh(
[process_mesh1, process_mesh2]
)
self.assertEqual(compatible_process_mesh, process_mesh1)
self.assertEqual(compatible_process_mesh, process_mesh2)
process_mesh2 = ProcessMesh([[0, 1, 2, 3, 4, 5]])
compatible_process_mesh = compute_compatible_process_mesh(
[process_mesh1, process_mesh2]
)
self.assertEqual(compatible_process_mesh, process_mesh1)
process_mesh2 = ProcessMesh([[0, 1, 2]])
compatible_process_mesh = compute_compatible_process_mesh(
[process_mesh1, process_mesh2]
)
self.assertEqual(compatible_process_mesh, process_mesh1)
def test_merge_process_meshes(self):
process_mesh1 = ProcessMesh(
[[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"]
)
merged_process_mesh = merge_process_meshes([process_mesh1, None])
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
merged_process_mesh = merge_process_meshes([None, process_mesh1])
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]])
merged_process_mesh = merge_process_meshes(
[process_mesh1, process_mesh2]
)
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
process_mesh2 = ProcessMesh([[0, 1, 2]])
merged_process_mesh = merge_process_meshes(
[process_mesh1, process_mesh2]
)
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
process_mesh2 = ProcessMesh([[6, 7]])
merged_process_mesh = merge_process_meshes(
[process_mesh1, process_mesh2]
)
self.assertEqual(
merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5, 6, 7])
)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -80,7 +80,6 @@ class TestProcessMesh(unittest.TestCase): ...@@ -80,7 +80,6 @@ class TestProcessMesh(unittest.TestCase):
[[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"] [[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"]
) )
merged_process_mesh = merge_process_mesh([process_mesh1, None]) merged_process_mesh = merge_process_mesh([process_mesh1, None])
print(merged_process_mesh)
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5])) self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
merged_process_mesh = merge_process_mesh([None, process_mesh1]) merged_process_mesh = merge_process_mesh([None, process_mesh1])
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5])) self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
......
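The new test cases above exercise the two module-level helpers added to the import list. A small usage sketch that mirrors what those tests assert; the expected results are taken directly from the test expectations in this diff, not from separate documentation:

from paddle.distributed.auto_parallel.process_mesh import (
    ProcessMesh,
    compute_compatible_process_mesh,
    merge_process_meshes,
)

mesh_a = ProcessMesh([[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"])
mesh_b = ProcessMesh([[6, 7]])

# None entries are ignored; a lone mesh is compatible with itself.
assert compute_compatible_process_mesh([mesh_a, None]) == mesh_a

# Merging unions all process ids into a single flat mesh.
assert merge_process_meshes([mesh_a, mesh_b]) == ProcessMesh(
    [0, 1, 2, 3, 4, 5, 6, 7]
)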
...@@ -105,7 +105,7 @@ def initialization_check( ...@@ -105,7 +105,7 @@ def initialization_check(
): ):
if 'mp' in mode: if 'mp' in mode:
group_ranks = _get_comm_group( group_ranks = _get_comm_group(
process_mesh.processes, process_mesh.topology, mp_parallel_axis, 3 process_mesh.process_ids, process_mesh.shape, mp_parallel_axis, 3
) )
mp_ring_id = new_process_group(group_ranks).id mp_ring_id = new_process_group(group_ranks).id
broadcast_ops = [ broadcast_ops = [
...@@ -124,7 +124,7 @@ def initialization_check( ...@@ -124,7 +124,7 @@ def initialization_check(
if 'dp' in mode: if 'dp' in mode:
group_ranks = _get_comm_group( group_ranks = _get_comm_group(
process_mesh.processes, process_mesh.topology, dp_parallel_axis, 3 process_mesh.process_ids, process_mesh.shape, dp_parallel_axis, 3
) )
dp_ring_id = new_process_group(group_ranks).id dp_ring_id = new_process_group(group_ranks).id
nparam = len(serial_startup_prog.all_parameters()) nparam = len(serial_startup_prog.all_parameters())
......
...@@ -958,12 +958,12 @@ class TestGPTPartitioner(unittest.TestCase): ...@@ -958,12 +958,12 @@ class TestGPTPartitioner(unittest.TestCase):
dp_parallel_axis = 0 dp_parallel_axis = 0
group_ranks = _get_comm_group( group_ranks = _get_comm_group(
process_mesh.processes, process_mesh.topology, mp_parallel_axis, 3 process_mesh.process_ids, process_mesh.shape, mp_parallel_axis, 3
) )
mp_ring_id = new_process_group(group_ranks).id mp_ring_id = new_process_group(group_ranks).id
group_ranks = _get_comm_group( group_ranks = _get_comm_group(
process_mesh.processes, process_mesh.topology, dp_parallel_axis, 3 process_mesh.process_ids, process_mesh.shape, dp_parallel_axis, 3
) )
dp_ring_id = new_process_group(group_ranks).id dp_ring_id = new_process_group(group_ranks).id
......