Unverified commit 1c0afa79 authored by Yulong Ao, committed by GitHub

[Auto Parallel] Merge the python and c++ impls of ProcessMesh (#47503)

* [Auto Parallel] Rename methods of ProcessMesh

* [Auto Parallel] Impl the python process_mesh by the c++ one

* [Auto Parallel] Add some minor modifications

* [Auto Parallel] Rename some methods

* [Auto Parallel] Remove unnecessary codes

* [Auto Parallel] Add back some removed files

* [Auto Parallel] Fix bugs

* [Auto Parallel] Fix a bug

* Update process_mesh.cc

* [Auto Parallel] Fix a bug
Parent 3f614f48
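For Python callers the net effect of this merge is an attribute rename across auto parallel: ProcessMesh.topology becomes ProcessMesh.shape and ProcessMesh.processes becomes ProcessMesh.process_ids, with the Python class now backed by the C++ core.ProcessMesh. A minimal sketch of the renamed accessors, assuming the import path of the process_mesh.py module touched at the bottom of this diff; the concrete values are illustrative:

# Sketch only: the renamed ProcessMesh accessors after this PR.
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh

mesh = ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
assert mesh.shape == [2, 3]                    # formerly mesh.topology
assert mesh.process_ids == [2, 4, 5, 0, 1, 3]  # formerly mesh.processes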
......@@ -69,6 +69,12 @@ void BindAutoParallel(py::module *m) {
.def("contains", &ProcessMesh::contains)
.def(py::self == py::self)
.def(py::self != py::self)
.def("__copy__",
[](const ProcessMesh &self) { return ProcessMesh(self); })
.def(
"__deepcopy__",
[](const ProcessMesh &self, py::dict) { return ProcessMesh(self); },
py::arg("memo"))
.def("__str__", &ProcessMesh::to_string);
py::class_<DeviceCapability>(*m, "DeviceCapability")
......@@ -131,7 +137,7 @@ void BindAutoParallel(py::module *m) {
const std::vector<std::string> &>(),
py::arg("name"),
py::arg("shape"),
py::arg("process_ids"),
py::arg("device_ids"),
py::arg("dim_names"))
.def_property_readonly("name", &DeviceMesh::name)
.def_property_readonly("shape", &DeviceMesh::shape)
......@@ -165,6 +171,8 @@ void BindAutoParallel(py::module *m) {
&DeviceMesh::dim_size))
.def(py::self == py::self)
.def(py::self != py::self)
.def("__copy__",
[](const TensorDistAttr &self) { return TensorDistAttr(self); })
.def(
"__deepcopy__",
[](const TensorDistAttr &self, py::dict) {
......@@ -256,6 +264,8 @@ void BindAutoParallel(py::module *m) {
.def("parse_from_string", &OperatorDistAttr::parse_from_string)
.def(py::self == py::self)
.def(py::self != py::self)
.def("__copy__",
[](const OperatorDistAttr &self) { return OperatorDistAttr(self); })
.def(
"__deepcopy__",
[](const OperatorDistAttr &self, py::dict) {
......
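The __copy__ and __deepcopy__ bindings added above let the pybind-exported objects participate in Python's copy protocol. A small sketch of what this enables, assuming paddle.fluid.core exposes ProcessMesh with the three-argument constructor (shape, process_ids, dim_names) used by the new process_mesh.py; the values are illustrative:

# Sketch only: copy.deepcopy now dispatches to the bound __deepcopy__ and
# returns an equal but distinct C++-backed ProcessMesh.
import copy
from paddle.fluid import core

pm = core.ProcessMesh([2, 2], [0, 1, 2, 3], ["x", "y"])
pm_clone = copy.deepcopy(pm)
assert pm_clone == pm and pm_clone is not pm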
......@@ -25,7 +25,7 @@ from .dist_attribute import (
from .dist_context import _node_id
from .operators import find_compatible_distributed_operator_impls
from .process_group import get_world_process_group
from .process_mesh import ProcessMesh
from .process_mesh import ProcessMesh, compute_compatible_process_mesh
from .utils import (
__no_shape_var_type__,
get_logger,
......@@ -34,47 +34,12 @@ from .utils import (
)
def compute_compatible_process_mesh(process_mesh_list):
"""Compute the compatible process mesh given a list of process meshes."""
if not process_mesh_list:
return None
def _compute_compatible_process_mesh_two(pm1, pm2):
if pm1 is None:
return True, pm2
if pm2 is None:
return True, pm1
if pm1 == pm2:
return True, pm1
if pm1.processes == pm2.processes:
if len(pm1.topology) >= len(pm2.topology):
return True, pm1
else:
return True, pm2
process_set1 = set(pm1.processes)
process_set2 = set(pm2.processes)
if process_set1.issubset(process_set2):
return True, pm2
if process_set2.issubset(process_set1):
return True, pm1
return False, None
compatible_result = None
for process_mesh in process_mesh_list:
compatible, compatible_result = _compute_compatible_process_mesh_two(
compatible_result, process_mesh
)
if not compatible:
return None
return copy.deepcopy(compatible_result)
def compute_compatible_dim_mapping(dim_mapping_list):
"""Compute the compatible dim mapping given a list of dim mapping."""
if not dim_mapping_list:
return None
def _compute_compatible_dim_mapping_two(dm1, dm2):
def _compute_compatible_dim_mapping_of_two(dm1, dm2):
if dm1 == -1:
return True, dm2
if dm2 == -1:
......@@ -85,7 +50,7 @@ def compute_compatible_dim_mapping(dim_mapping_list):
compatible_result = -1
for mapping in dim_mapping_list:
compatible, compatible_result = _compute_compatible_dim_mapping_two(
compatible, compatible_result = _compute_compatible_dim_mapping_of_two(
compatible_result, mapping
)
if not compatible:
......@@ -122,9 +87,9 @@ def merge_process_mesh_two(pm1, pm2):
if pm1 is None and pm2 is None:
return None
if pm1 is not None:
process_set1 = set(pm1.processes)
process_set1 = set(pm1.process_ids)
if pm2 is not None:
process_set2 = set(pm2.processes)
process_set2 = set(pm2.process_ids)
merged_process_set = process_set1.union(process_set2)
merged_process_mesh = ProcessMesh(list(merged_process_set))
return merged_process_mesh
......@@ -134,11 +99,9 @@ def _validate_dims_mapping(dims_mapping, process_mesh):
if dims_mapping is None:
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
process_mesh.topology
):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(process_mesh.shape):
return False
for i in range(len(process_mesh.topology)):
for i in range(len(process_mesh.shape)):
if dims_mapping.count(i) > 1:
return False
return True
......
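compute_compatible_process_mesh moves out of the completer and is now imported from process_mesh (see the import change above), and the pairwise dim-mapping helper is renamed to _compute_compatible_dim_mapping_of_two. Its rule, with the branches elided by this hunk assumed to treat equal mappings as compatible and everything else as incompatible, is sketched below (-1 means the tensor axis is not sharded):

# Standalone sketch of the pairwise dim-mapping compatibility rule.
def compatible_dim_mapping_of_two(dm1, dm2):
    if dm1 == -1:
        return True, dm2
    if dm2 == -1:
        return True, dm1
    if dm1 == dm2:
        return True, dm1
    return False, None

assert compatible_dim_mapping_of_two(-1, 0) == (True, 0)
assert compatible_dim_mapping_of_two(0, 1) == (False, None)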
......@@ -80,7 +80,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
assert process_mesh, "Process mesh must not be None."
processes = process_mesh.processes
processes = process_mesh.process_ids
for process in processes:
desc = {}
desc["op"] = op.type
......@@ -103,7 +103,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
global_sizes = var.shape
# NOTE: When uneven partition is supported, the shard_sizes will be obtained from dist_attr.
shard_sizes = None
topology = process_mesh.topology
topology = process_mesh.shape
shape = DistributedTensor.get_local_sizes(
global_sizes,
dims_mapping,
......@@ -129,7 +129,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
)
relative_idx = _get_idx_in_axis(
processes,
dist_attr.process_mesh.topology,
dist_attr.process_mesh.shape,
embedding_row_dim_mapping,
process,
)
......@@ -153,8 +153,8 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
process_mesh = dist_attr.process_mesh
global_sizes = var.shape
shard_sizes = None
processes = process_mesh.processes
topology = process_mesh.topology
processes = process_mesh.process_ids
topology = process_mesh.shape
shape = DistributedTensor.get_local_sizes(
global_sizes,
dims_mapping,
......@@ -170,7 +170,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
# Modify shape attr according to how outputs are partitioned
out_name = var_name_list[0]
dims_mapping = dist_attr.get_output_dims_mapping(out_name)
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_shape = dist_attr.process_mesh.shape
shape_list = op.attr("shape")
# Modify target shape
for idx, axis in enumerate(dims_mapping):
......@@ -253,7 +253,7 @@ def build_comm_desc_from_dist_op(
process_mesh = dist_attr.process_mesh
assert process_mesh, "Process mesh must not be None."
processes = process_mesh.processes
processes = process_mesh.process_ids
op_descs = {}
for process in processes:
rank_id = process
......@@ -295,7 +295,7 @@ def build_comm_desc_from_dist_op(
)
global_sizes = var.shape
shard_sizes = None
topology = process_mesh.topology
topology = process_mesh.shape
shape = DistributedTensor.get_local_sizes(
global_sizes,
dims_mapping,
......@@ -311,8 +311,8 @@ def build_comm_desc_from_dist_op(
# Get comm group by parallel_axis or the given group_ranks.
if parallel_axis is not None:
process_mesh_shape = process_mesh.topology
process_mesh_group = process_mesh.processes
process_mesh_shape = process_mesh.shape
process_mesh_group = process_mesh.process_ids
comm_group_ranks = _get_comm_group(
process_mesh_group,
process_mesh_shape,
......@@ -384,7 +384,7 @@ def build_dp_costs(
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
processes = process_mesh.process_ids
assert len(var_names) == 1
vars = dist_op.serial_op.block.vars
var_name = var_names[0]
......@@ -443,7 +443,7 @@ def build_dp_costs(
)
global_sizes = var.shape
shard_sizes = None
topology = process_mesh.topology
topology = process_mesh.shape
shape = DistributedTensor.get_local_sizes(
global_sizes,
dims_mapping,
......
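Most renames in this file feed DistributedTensor.get_local_sizes(global_sizes, dims_mapping, topology/shape, processes/process_ids, ...). For the even-partition case (shard_sizes is None) the quantity it computes is, in essence, the following sketch; the helper name and the divisibility assumption are mine, not Paddle's:

# Sketch only: each tensor axis mapped to a mesh axis is divided by that mesh
# axis' size; unmapped axes (-1) keep their global size.
def local_sizes(global_sizes, dims_mapping, mesh_shape):
    local = []
    for size, mapping in zip(global_sizes, dims_mapping):
        local.append(size if mapping == -1 else size // mesh_shape[mapping])
    return local

# An [8, 1024] tensor sharded over the first axis of a [2, 4] mesh:
assert local_sizes([8, 1024], [0, -1], [2, 4]) == [4, 1024]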
......@@ -190,7 +190,7 @@ class CostEstimator:
# Calc dist op cost
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes
processes = op_dist_attr.process_mesh.process_ids
container = get_distributed_operator_impl_container(
op_dist_attr.impl_type
......@@ -273,8 +273,8 @@ class CostEstimator:
# This estimation will be improved; reshard and inplace are not considered yet.
# Persistable vars are not free.
def _convert_pm_and_dm_to_str(process_mesh, dims_mapping):
processes = ",".join([str(x) for x in process_mesh.processes])
topology = ",".join([str(x) for x in process_mesh.topology])
processes = ",".join([str(x) for x in process_mesh.process_ids])
topology = ",".join([str(x) for x in process_mesh.shape])
dims_mapping = ",".join([str(x) for x in dims_mapping])
result = processes + topology + dims_mapping
return result
......@@ -318,8 +318,8 @@ class CostEstimator:
sizes = DistributedTensor.get_local_sizes(
global_sizes,
input_dims_mapping,
process_mesh.topology,
process_mesh.processes,
process_mesh.shape,
process_mesh.process_ids,
)
var_info[var_name][key]["memory"] = self._calculate_bytes(
sizes, dtype
......@@ -346,8 +346,8 @@ class CostEstimator:
sizes = DistributedTensor.get_local_sizes(
global_sizes,
output_dims_mapping,
process_mesh.topology,
process_mesh.processes,
process_mesh.shape,
process_mesh.process_ids,
)
var_info[var_name][key]["memory"] = self._calculate_bytes(
sizes, dtype
......@@ -380,7 +380,7 @@ class CostEstimator:
# Not used
if var_name + key not in has_used_vars:
has_used_vars.add(has_used_var)
for process in process_mesh.processes:
for process in process_mesh.process_ids:
if process not in memories:
memories[process] = 0
memories[process] += var_info[var_name][key]["memory"]
......@@ -390,7 +390,7 @@ class CostEstimator:
if has_used_var not in can_free_vars:
can_free_vars.add(has_used_var)
if not var.persistable:
for process in process_mesh.processes:
for process in process_mesh.process_ids:
if process not in can_free_memories:
can_free_memories[process] = 0
can_free_memories[process] += var_info[
......@@ -409,7 +409,7 @@ class CostEstimator:
# Not used
if var_name + key not in has_used_vars:
has_used_vars.add(has_used_var)
for process in process_mesh.processes:
for process in process_mesh.process_ids:
if process not in memories:
memories[process] = 0
memories[process] += var_info[var_name][key]["memory"]
......@@ -419,7 +419,7 @@ class CostEstimator:
if has_used_var not in can_free_vars:
can_free_vars.add(has_used_var)
if not var.persistable:
for process in process_mesh.processes:
for process in process_mesh.process_ids:
if process not in can_free_memories:
can_free_memories[process] = 0
can_free_memories[process] += var_info[
......
......@@ -121,32 +121,17 @@ class TensorDistributedAttribute:
if dist_attr is None:
return
assert isinstance(
dist_attr, (dict, TensorDistributedAttribute)
dist_attr, TensorDistributedAttribute
), "The type of dist_attr must be dict or TensorDistributedAttribute."
if isinstance(dist_attr, dict):
for key, value in dist_attr.items():
if key in get_tensor_dist_attr_field_keys():
field_property = TensorDistributedAttribute.__dict__.get(
key, None
)
if field_property:
field_property.fset(self, value)
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr
)
elif isinstance(dist_attr, TensorDistributedAttribute):
for key in get_tensor_dist_attr_field_keys():
field_property = TensorDistributedAttribute.__dict__.get(
key, None
for key in get_tensor_dist_attr_field_keys():
field_property = TensorDistributedAttribute.__dict__.get(key, None)
if field_property:
field_property.fset(self, field_property.fget(dist_attr))
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr
)
if field_property:
field_property.fset(self, field_property.fget(dist_attr))
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr
)
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
def reset(self, skip_dist_attr_field_names=None):
if skip_dist_attr_field_names is None or (
......@@ -243,7 +228,9 @@ class OperatorDistributedAttribute:
if process_mesh is not None:
assert isinstance(
process_mesh, (list, ProcessMesh)
), "The type of process_mesh must be list or ProcessMesh."
), "The type of process_mesh must be list or ProcessMesh, but receive {}".format(
type(process_mesh)
)
if isinstance(process_mesh, list):
process_mesh = ProcessMesh(process_mesh)
self._process_mesh = copy.deepcopy(process_mesh)
......
......@@ -870,8 +870,8 @@ class DistributedContext:
else:
tensor_shape = serial_tensor.shape
dims_mapping = dist_attr.dims_mapping
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_processes = dist_attr.process_mesh.processes
process_mesh_shape = dist_attr.process_mesh.shape
process_mesh_processes = dist_attr.process_mesh.process_ids
# If a tensor dimension is smaller than the mesh dimension it is sharded along,
# we just amend its dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
......@@ -887,8 +887,8 @@ class DistributedContext:
for dist_op in self._dist_ops_for_program.values():
serial_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_processes = dist_attr.process_mesh.processes
process_mesh_shape = dist_attr.process_mesh.shape
process_mesh_processes = dist_attr.process_mesh.process_ids
for arg_name in serial_op.input_arg_names:
if dist_op.get_serial_input(arg_name) is None:
tensor_shape = []
......
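The amendment described in the DistributedContext hunk above replicates any tensor axis that is too small for the mesh axis it is mapped to. A minimal sketch of that rule, under the assumption that "too small" means the mesh axis size exceeds the tensor dim size:

def amend_dims_mapping(dims_mapping, tensor_shape, mesh_shape):
    # Sketch only: fall back to replication (-1) when the mapped mesh axis is
    # larger than the tensor axis, mirroring the comment in the hunk above.
    for i, dim in enumerate(dims_mapping):
        if dim != -1 and tensor_shape[i] > 0 and mesh_shape[dim] > tensor_shape[i]:
            dims_mapping[i] = -1
    return dims_mapping

assert amend_dims_mapping([0, -1], [1, 1024], [4, 2]) == [-1, -1]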
......@@ -159,10 +159,10 @@ class DistributedOperator:
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology
self.dist_attr.process_mesh.shape
):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
for i in range(len(self.dist_attr.process_mesh.shape)):
if dims_mapping.count(i) > 1:
return False
if self.dist_attr.process_mesh != input_dist_attr.process_mesh:
......@@ -179,10 +179,10 @@ class DistributedOperator:
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology
self.dist_attr.process_mesh.shape
):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
for i in range(len(self.dist_attr.process_mesh.shape)):
if dims_mapping.count(i) > 1:
return False
if self.dist_attr.process_mesh != output_dist_attr.process_mesh:
......
......@@ -225,10 +225,10 @@ class DistributedTensor:
if self.dist_attr.dims_mapping[
i
] < -1 or self.dist_attr.dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology
self.dist_attr.process_mesh.shape
):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
for i in range(len(self.dist_attr.process_mesh.shape)):
if self.dist_attr.dims_mapping.count(i) > 1:
return False
return True
......@@ -239,8 +239,8 @@ class DistributedTensor:
global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes
processes = self.dist_attr.process_mesh.processes
topology = self.dist_attr.process_mesh.topology
processes = self.dist_attr.process_mesh.process_ids
topology = self.dist_attr.process_mesh.shape
local_sizes = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes, rank, shard_sizes
)
......@@ -256,8 +256,8 @@ class DistributedTensor:
global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes
processes = self.dist_attr.process_mesh.processes
topology = self.dist_attr.process_mesh.topology
processes = self.dist_attr.process_mesh.process_ids
topology = self.dist_attr.process_mesh.shape
local_offsets = DistributedTensor.get_local_offsets(
global_sizes,
dims_mapping,
......@@ -282,8 +282,8 @@ class DistributedTensor:
global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes
processes = self.dist_attr.process_mesh.processes
topology = self.dist_attr.process_mesh.topology
processes = self.dist_attr.process_mesh.process_ids
topology = self.dist_attr.process_mesh.shape
local_shard = DistributedTensor.get_local_shard(
global_sizes,
dims_mapping,
......
......@@ -324,8 +324,8 @@ class ProgramHelper:
var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
dist_attr = {
"dims_mapping": var_dist_attr.dims_mapping,
"process_shape": var_dist_attr.process_mesh.topology,
"process_group": var_dist_attr.process_mesh.processes,
"process_shape": var_dist_attr.process_mesh.shape,
"process_group": var_dist_attr.process_mesh.process_ids,
}
# slice param_value with dist_attr
# share sliced_param_value with param_tensor in global_scope
......
......@@ -275,7 +275,7 @@ def is_parameter_related(varname, block):
def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
var_shape = block._var_recursive(src_var.name).shape
var_topoloy = src_var_dist_attr.process_mesh.topology
var_topoloy = src_var_dist_attr.process_mesh.shape
var_dims_mapping = src_var_dist_attr.dims_mapping
complete_shape = []
......@@ -287,7 +287,7 @@ def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
complete_shape.append(new_shape)
exact_shape = []
input_topology = op_input_dist_attr.process_mesh.topology
input_topology = op_input_dist_attr.process_mesh.shape
input_dims_mapping = op_input_dist_attr.dims_mapping
for idx, shape in enumerate(complete_shape):
if input_dims_mapping[idx] == -1:
......@@ -362,10 +362,10 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank):
op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
process_mesh = op_dist_attr.process_mesh
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
# FIXME Hack for Pipeline Parallelism where the current operator
# does not belong to the mesh the current rank belongs to.
if rank not in process_mesh.processes:
if rank not in process_mesh.process_ids:
rank = _get_corresponding_rank(dist_ctx, process_mesh, rank)
for var_name in act_grad_names:
......@@ -376,8 +376,8 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank):
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
group_ranks = _get_comm_group(
process_mesh.processes,
process_mesh.topology,
process_mesh.process_ids,
process_mesh.shape,
batch_size_axis,
rank,
)
......
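The data-parallel allreduce path above builds its group with _get_comm_group(process_mesh.process_ids, process_mesh.shape, batch_size_axis, rank). Conceptually it selects the ranks that share the current rank's coordinates on every mesh axis except the chosen one; a hedged NumPy sketch of that idea (not Paddle's implementation) follows:

import numpy as np

def comm_group(process_ids, mesh_shape, axis, rank):
    # Keep the ranks whose mesh coordinates match `rank` everywhere but `axis`.
    mesh = np.array(process_ids).reshape(mesh_shape)
    coord = list(zip(*np.where(mesh == rank)))[0]
    index = list(coord)
    index[axis] = slice(None)
    return mesh[tuple(index)].tolist()

# On the 2x2 mesh [[0, 1], [2, 3]], rank 3's group along axis 0 is [1, 3]
# and along axis 1 is [2, 3].
assert comm_group([0, 1, 2, 3], [2, 2], 0, 3) == [1, 3]
assert comm_group([0, 1, 2, 3], [2, 2], 1, 3) == [2, 3]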
......@@ -89,7 +89,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
str(backward_op)
)
assert rank_id in dist_attr.process_mesh.processes
assert rank_id in dist_attr.process_mesh.process_ids
assert 'X' in kwargs, "input [{}] is not given".format('X')
assert 'Scale' in kwargs, "input [{}] is not given".format('Scale')
......@@ -120,7 +120,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
rank_id
in ctx.get_tensor_dist_attr_for_program(
main_block._var_recursive(varname)
).process_mesh.processes
).process_mesh.process_ids
):
filter_vars.append(varname)
......
......@@ -124,7 +124,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
processes = dist_op.dist_attr.process_mesh.process_ids
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(
_g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster
......@@ -141,7 +141,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
processes = process_mesh.process_ids
backward_op = dist_op.serial_op
op_type = backward_op.type
cost_mapping = build_comp_costs_from_descs(
......@@ -157,7 +157,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
varname, main_block
):
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
......@@ -172,7 +172,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping(
varname
)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
......@@ -501,19 +501,19 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
dims_mapping = param_dist_attr.dims_mapping
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh.processes:
if rank_id not in process_mesh.process_ids:
rank_id = _get_corresponding_rank(
ctx, process_mesh, rank_id
)
# NOTE: all axes that are not split should be present in the mesh
for axis, size in enumerate(process_mesh.topology):
for axis, size in enumerate(process_mesh.shape):
if size <= 1 or axis in dims_mapping:
pass
else:
group_ranks = _get_comm_group(
process_mesh.processes,
process_mesh.topology,
process_mesh.process_ids,
process_mesh.shape,
axis,
rank_id,
)
......
......@@ -67,7 +67,7 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
processes = dist_op.dist_attr.process_mesh.process_ids
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(
_g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster
......@@ -84,7 +84,7 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
processes = process_mesh.process_ids
backward_op = dist_op.serial_op
op_type = backward_op.type
cost_mapping = build_comp_costs_from_descs(
......@@ -100,7 +100,7 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
varname, main_block
):
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
......@@ -115,7 +115,7 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping(
varname
)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
......
......@@ -175,7 +175,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
processes = dist_op.dist_attr.process_mesh.process_ids
# embedding needs start_index
cost_mapping = build_comp_costs_from_descs(
EmbeddingOpCost, ctx, processes, desc_mapping, cluster
......@@ -231,7 +231,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
processes = process_mesh.process_ids
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
......@@ -250,7 +250,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("Ids")[0]
)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
......@@ -390,8 +390,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
), "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format(
embedding_row_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.process_ids
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh_group:
......@@ -539,13 +539,13 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
dim_mapping = param_dist_attr.dims_mapping
# NOTE: all axes that are not split should be present in the mesh
for axis, size in enumerate(process_mesh.topology):
for axis, size in enumerate(process_mesh.shape):
if size <= 1 or axis in dim_mapping:
pass
else:
group_ranks = _get_comm_group(
process_mesh.processes,
process_mesh.topology,
process_mesh.process_ids,
process_mesh.shape,
axis,
rank_id,
)
......@@ -579,7 +579,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
)
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in dist_attr.process_mesh.processes:
if rank_id not in dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank(
ctx, dist_attr.process_mesh, rank_id
)
......@@ -623,8 +623,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
), "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format(
embedding_row_dim_mapping
)
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_group = dist_attr.process_mesh.processes
process_mesh_shape = dist_attr.process_mesh.shape
process_mesh_group = dist_attr.process_mesh.process_ids
# A generalized method to calculate the embedding offset using the Cartesian product
relative_idx = _get_idx_in_axis(
......
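The row-parallel embedding above derives each rank's row offset from _get_idx_in_axis(processes, mesh_shape, embedding_row_dim_mapping, process), i.e. the rank's coordinate along the sharded mesh axis. A hedged sketch of that lookup, assuming the mesh is laid out row-major (the helper name is mine):

import numpy as np

def idx_in_axis(process_ids, mesh_shape, axis, process):
    # Coordinate of `process` along `axis` in the row-major process mesh.
    mesh = np.array(process_ids).reshape(mesh_shape)
    return int(np.argwhere(mesh == process)[0][axis])

# Rank 5 sits at column 1 of a [2, 4] mesh laid out from rank 0.
assert idx_in_axis(list(range(8)), [2, 4], 1, 5) == 1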
......@@ -61,7 +61,7 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
processes = dist_op.dist_attr.process_mesh.process_ids
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(
FillConstantBatchSizeLikeOpCost,
......@@ -139,7 +139,7 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
# modify shape attr according to how outputs are partitioned
out_name = op.output('Out')[0]
dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_shape = op_dist_attr.process_mesh.shape
shape_list = op.attr("shape")
# modify target shape
for idx, axis in enumerate(dims_mapping):
......
......@@ -156,7 +156,7 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl):
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
if rank_id not in op_dist_attr.process_mesh.processes:
if rank_id not in op_dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
......@@ -172,8 +172,8 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl):
), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
qkv_w_col_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.process_ids
parallel_axis = qkv_w_col_dim_mapping
group_ranks = _get_comm_group(
......@@ -198,7 +198,7 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl):
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
if rank_id not in op_dist_attr.process_mesh.processes:
if rank_id not in op_dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
......@@ -211,8 +211,8 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl):
), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
out_w_col_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.process_ids
parallel_axis = out_w_col_dim_mapping
group_ranks = _get_comm_group(
......
......@@ -148,7 +148,7 @@ class DistributedFusedFeedForwardImpl(DistributedOperatorImpl):
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
if rank_id not in op_dist_attr.process_mesh.processes:
if rank_id not in op_dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
......@@ -163,8 +163,8 @@ class DistributedFusedFeedForwardImpl(DistributedOperatorImpl):
), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
linear1_weight_col_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.process_ids
parallel_axis = linear1_weight_col_dim_mapping
group_ranks = _get_comm_group(
......@@ -190,7 +190,7 @@ class DistributedFusedFeedForwardImpl(DistributedOperatorImpl):
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
if rank_id not in op_dist_attr.process_mesh.processes:
if rank_id not in op_dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
......@@ -205,8 +205,8 @@ class DistributedFusedFeedForwardImpl(DistributedOperatorImpl):
), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
linear2_weight_col_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.process_ids
parallel_axis = linear2_weight_col_dim_mapping
group_ranks = _get_comm_group(
......
......@@ -299,7 +299,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
), "backward op [{}] don't have dist attribute !".format(str(backward_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in dist_attr.process_mesh.processes:
if rank_id not in dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id)
assert 'Y' in kwargs, "input [{}] is not given".format('Y')
......@@ -341,8 +341,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
X_var_dims_mapping = dist_attr.get_input_dims_mapping(X_var.name)
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_group = dist_attr.process_mesh.processes
process_mesh_shape = dist_attr.process_mesh.shape
process_mesh_group = dist_attr.process_mesh.process_ids
trans_x = None
trans_y = None
......@@ -532,12 +532,12 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
process_mesh = param_dist_attr.process_mesh
dim_mapping = param_dist_attr.dims_mapping
for axis, size in enumerate(process_mesh.topology):
for axis, size in enumerate(process_mesh.shape):
if size <= 1 or axis in dim_mapping:
pass
else:
group_ranks = _get_comm_group(
process_mesh.processes, process_mesh.topology, axis, rank_id
process_mesh.process_ids, process_mesh.shape, axis, rank_id
)
sync_group = new_process_group(group_ranks)
......@@ -600,7 +600,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
dist_op=dist_op, dist_context=ctx
)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
processes = process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs(
MatmulGradOpCost, ctx, processes, desc_mapping, cluster
)
......@@ -631,7 +631,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if (
batch_size_axis > -1
......@@ -651,7 +651,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
processes = dist_op.dist_attr.process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs(
MatmulOpCost, ctx, processes, desc_mapping, cluster
)
......@@ -749,7 +749,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes:
if rank_id not in op_dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
......@@ -791,8 +791,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.process_ids
parallel_axis = matmul_col_dim_mapping
group_ranks = _get_comm_group(
......@@ -986,7 +986,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
parallel_axis=parallel_axis,
)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
processes = process_mesh.process_ids
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
......@@ -1005,7 +1005,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if (
batch_size_axis > -1
......@@ -1025,7 +1025,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
processes = dist_op.dist_attr.process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs(
MatmulOpCost, ctx, processes, desc_mapping, cluster
)
......@@ -1131,7 +1131,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes:
if rank_id not in op_dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
......@@ -1173,8 +1173,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.process_ids
parallel_axis = matmul_row_dim_mapping
group_ranks = _get_comm_group(
......@@ -1329,7 +1329,7 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
dist_op=dist_op, dist_context=ctx
)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
processes = process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs(
MatmulGradOpCost, ctx, processes, desc_mapping, cluster
)
......@@ -1339,7 +1339,7 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if (
batch_size_axis > -1
......@@ -1360,7 +1360,7 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
processes = dist_op.dist_attr.process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs(
MatmulOpCost, ctx, processes, desc_mapping, cluster
)
......@@ -1479,7 +1479,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
backward_op.input("Y")[0]
)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
processes = process_mesh.process_ids
# col parallel: matmul + allreduce
if backward_op.attr("trans_y"):
Y_var_dim_mapping.reverse()
......@@ -1526,7 +1526,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if (
batch_size_axis > -1
......@@ -1547,7 +1547,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
comp_desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
processes = dist_op.dist_attr.process_mesh.process_ids
comp_cost_mapping = build_comp_costs_from_descs(
MatmulV2OpCost, ctx, processes, comp_desc_mapping, cluster
)
......@@ -1645,7 +1645,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes:
if rank_id not in op_dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
......@@ -1687,8 +1687,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.process_ids
parallel_axis = matmul_col_dim_mapping
group_ranks = _get_comm_group(
......@@ -1869,7 +1869,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
parallel_axis = Y_var_dim_mapping[0]
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
processes = process_mesh.process_ids
# calc comm op cost
var_names = [backward_op.input("Out@GRAD")[0]]
attrs = {"use_calc_stream": True, "use_model_parallel": True}
......@@ -1900,7 +1900,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if (
batch_size_axis > -1
......@@ -1920,7 +1920,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
processes = dist_op.dist_attr.process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs(
MatmulV2OpCost, ctx, processes, desc_mapping, cluster
)
......@@ -2025,7 +2025,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes:
if rank_id not in op_dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
......@@ -2067,8 +2067,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.process_ids
parallel_axis = matmul_row_dim_mapping
group_ranks = _get_comm_group(
......@@ -2222,7 +2222,7 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = process_mesh.processes
processes = process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs(
MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
)
......@@ -2232,7 +2232,7 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if (
batch_size_axis > -1
......@@ -2253,7 +2253,7 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
processes = dist_op.dist_attr.process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs(
MatmulV2OpCost, ctx, processes, desc_mapping, cluster
)
......@@ -2386,7 +2386,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
dist_op=dist_op, dist_context=ctx
)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
processes = process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs(
MulGradOpCost, ctx, processes, desc_mapping, cluster
)
......@@ -2417,7 +2417,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if (
batch_size_axis > -1
......@@ -2437,7 +2437,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
processes = dist_op.dist_attr.process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs(
MulOpCost, ctx, processes, desc_mapping, cluster
)
......@@ -2530,7 +2530,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes:
if rank_id not in op_dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
......@@ -2566,8 +2566,8 @@ class DistributedMulImpl0(DistributedOperatorImpl):
), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.process_ids
parallel_axis = matmul_col_dim_mapping
group_ranks = _get_comm_group(
......@@ -2778,7 +2778,7 @@ class DistributedMulImpl1(DistributedOperatorImpl):
attrs=attrs,
parallel_axis=parallel_axis,
)
processes = process_mesh.processes
processes = process_mesh.process_ids
comm_op_cost_list = build_comm_costs_from_descs(
IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
)
......@@ -2797,7 +2797,7 @@ class DistributedMulImpl1(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if (
batch_size_axis > -1
......@@ -2817,7 +2817,7 @@ class DistributedMulImpl1(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
processes = dist_op.dist_attr.process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs(
MulOpCost, ctx, processes, desc_mapping, cluster
)
......@@ -2919,7 +2919,7 @@ class DistributedMulImpl1(DistributedOperatorImpl):
), "backward op [{}] don't have dist attribute !".format(str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.process_mesh.processes:
if rank_id not in op_dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
......@@ -2955,8 +2955,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping
)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.process_ids
parallel_axis = matmul_row_dim_mapping
group_ranks = _get_comm_group(
......@@ -3131,7 +3131,7 @@ class DistributedMulImpl2(DistributedOperatorImpl):
dist_op=dist_op, dist_context=ctx
)
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
processes = process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs(
MulGradOpCost, ctx, processes, desc_mapping, cluster
)
......@@ -3141,7 +3141,7 @@ class DistributedMulImpl2(DistributedOperatorImpl):
var_dim_mapping = dist_attr.get_input_dims_mapping(
backward_op.input("X")[0]
)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if (
batch_size_axis > -1
......@@ -3162,7 +3162,7 @@ class DistributedMulImpl2(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
processes = dist_op.dist_attr.process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs(
MulOpCost, ctx, processes, desc_mapping, cluster
)
......
......@@ -154,7 +154,7 @@ class DistributedPNormImpl(DistributedOperatorImpl):
output_name
)
if rank_id not in op_dist_attr.process_mesh.processes:
if rank_id not in op_dist_attr.process_mesh.process_ids:
rank_id = _get_corresponding_rank(
ctx, op_dist_attr.process_mesh, rank_id
)
......@@ -164,8 +164,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
for axis in range(len(in_dims_mapping)):
if in_dims_mapping[axis] != -1:
break
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.process_ids
group_ranks = _get_comm_group(
process_mesh_group, process_mesh_shape, axis, rank_id
)
......@@ -301,8 +301,8 @@ class DistributedPNormImpl(DistributedOperatorImpl):
ctx.set_op_dist_attr_for_program(p_norm_grad_op, op_dist_attr)
# 2. insert slice op
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
process_mesh_shape = op_dist_attr.process_mesh.shape
process_mesh_group = op_dist_attr.process_mesh.process_ids
dims_mapping = [0] + [-1 for _ in range(len(new_X_grad.shape) - 1)]
from ..reshard import Resharder
......
......@@ -67,7 +67,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
shape_list = op.desc.attr("shape")
# got dist attribute info
dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0])
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_shape = dist_attr.process_mesh.shape
# modify target shape
for idx, axis in enumerate(dim_mapping):
......@@ -81,7 +81,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_attr.process_mesh.processes
processes = dist_attr.process_mesh.process_ids
for key in desc_mapping:
desc_mapping[key]["shape"] = shape_list
......@@ -100,7 +100,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
processes = process_mesh.process_ids
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(
......@@ -119,7 +119,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
# NOTE: the dims mapping of the backward op's input var should follow the input var itself rather than the corresponding varname of the forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
......@@ -266,7 +266,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
# got dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_shape = op_dist_attr.process_mesh.shape
# modify target shape
for idx, axis in enumerate(dim_mapping):
......@@ -315,7 +315,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
shape_list = op.desc.attr("shape")
# got dist attribute info
dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0])
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_shape = dist_attr.process_mesh.shape
# modify target shape
for idx, axis in enumerate(dim_mapping):
......@@ -329,7 +329,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_attr.process_mesh.processes
processes = dist_attr.process_mesh.process_ids
for key in desc_mapping:
desc_mapping[key]["shape"] = shape_list
......@@ -348,7 +348,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
processes = process_mesh.process_ids
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(
......@@ -367,7 +367,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
# NOTE: the dims mapping of the backward op's input var should follow the input var itself rather than the corresponding varname of the forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
......@@ -517,7 +517,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
# got dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_shape = op_dist_attr.process_mesh.shape
# modify target shape
for idx, axis in enumerate(dim_mapping):
......@@ -566,7 +566,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
shape_list = op.desc.attr("shape")
# got dist attribute info
dim_mapping = dist_attr.get_output_dims_mapping(op.output("Out")[0])
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_shape = dist_attr.process_mesh.shape
# modify target shape
for idx, axis in enumerate(dim_mapping):
......@@ -580,7 +580,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_attr.process_mesh.processes
processes = dist_attr.process_mesh.process_ids
for key in desc_mapping:
desc_mapping[key]["shape"] = shape_list
......@@ -599,7 +599,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
processes = process_mesh.process_ids
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(
......@@ -618,7 +618,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
# NOTE: the dims mapping of the backward op's input var should follow the input var itself rather than the corresponding varname of the forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
......@@ -761,7 +761,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
# got dist attribute info
out_dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_shape = op_dist_attr.process_mesh.shape
# modify target shape
for idx, axis in enumerate(out_dim_mapping):
......
......@@ -60,7 +60,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
processes = dist_op.dist_attr.process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs(
SoftmaxOpCost, ctx, processes, desc_mapping, cluster
)
......@@ -76,7 +76,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
processes = process_mesh.process_ids
cost_mapping = build_comp_costs_from_descs(
SoftmaxGradOpCost, ctx, processes, desc_mapping, cluster
)
......@@ -93,7 +93,7 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
# NOTE: the dims mapping of the backward op's input var should follow the input var itself rather than the corresponding varname of the forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
......
......@@ -140,7 +140,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
desc_mapping = build_comp_desc_from_dist_op(
dist_op=dist_op, dist_context=ctx
)
processes = dist_op.dist_attr.process_mesh.processes
processes = dist_op.dist_attr.process_mesh.process_ids
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(
Transpose2OpCost, ctx, processes, desc_mapping, cluster
......@@ -157,7 +157,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
)
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
processes = process_mesh.process_ids
op_type = dist_op.serial_op.type
cost_mapping = build_comp_costs_from_descs(
Transpose2GradOpCost, ctx, processes, desc_mapping, cluster
......@@ -175,7 +175,7 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
# NOTE: the dims mapping of the backward op's input var should follow the input var itself rather than the corresponding varname of the forward op
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
......
......@@ -79,7 +79,7 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
str(backward_op)
)
assert rank_id in dist_attr.process_mesh.processes
assert rank_id in dist_attr.process_mesh.process_ids
assert 'X' in kwargs, "input [{}] is not given".format('X')
assert 'FoundInfinite' in kwargs, "input [{}] is not given".format(
......@@ -154,7 +154,7 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
rank_id
in ctx.get_tensor_dist_attr_for_program(
main_block._var_recursive(varname)
).process_mesh.processes
).process_mesh.process_ids
):
filter_vars.append(varname)
......
......@@ -286,7 +286,7 @@ class AutoParallelizer:
_g_process_group_map.clear()
_g_process_group_map[0] = ProcessGroup(0, [])
for process_mesh in self._dist_context._process_meshes:
_g_process_group_map[0].add_ranks(process_mesh.processes)
_g_process_group_map[0].add_ranks(process_mesh.process_ids)
return (
dist_optimize_ops,
dist_params_grads,
......@@ -482,7 +482,7 @@ class AutoParallelizer:
if dist_context is not None:
pg0 = get_process_group(0)
for process_mesh in dist_context._process_meshes:
pg0.add_ranks(process_mesh.processes)
pg0.add_ranks(process_mesh.process_ids)
(
dist_optimize_ops,
dist_params_grads,
......
......@@ -391,7 +391,7 @@ def _get_dist_shape(var, dist_attr):
var_shape = var.shape
mapping = dist_attr.dims_mapping
mesh = dist_attr.process_mesh.topology
mesh = dist_attr.process_mesh.shape
if mapping == []:
return var_shape
......
......@@ -74,7 +74,7 @@ class PlanFilter:
for var_name in op.input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(var_name)
if not PlanFilter.check_dims_mapping_for_tensor(
process_mesh.topology, vars[var_name].shape, dims_mapping
process_mesh.shape, vars[var_name].shape, dims_mapping
):
return False
if vars[var_name].is_data and len(dims_mapping) > 1:
......@@ -85,7 +85,7 @@ class PlanFilter:
for var_name in op.output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(var_name)
if not PlanFilter.check_dims_mapping_for_tensor(
process_mesh.topology, vars[var_name].shape, dims_mapping
process_mesh.shape, vars[var_name].shape, dims_mapping
):
return False
......@@ -217,13 +217,13 @@ class PlanSpace:
for var_name in chain(op.input_arg_names, op.output_arg_names):
visited = [
False
for _ in range(len(list(range(-1, len(process_mesh.topology)))))
for _ in range(len(list(range(-1, len(process_mesh.shape)))))
]
depth = 0
path = []
dims_mapping_list = []
PlanSpace._enum_dims_mapping(
process_mesh.topology,
process_mesh.shape,
visited,
path,
depth,
......@@ -590,7 +590,10 @@ class MCMC(SearchAlgorithm):
self.serial_program_info, dist_context, self.parallelizer
)
pipeline_config = (
[process_mesh.processes for process_mesh in pipeline_process_meshes]
[
process_mesh.process_ids
for process_mesh in pipeline_process_meshes
]
if pipeline_process_meshes is not None
else None
)
......@@ -1060,7 +1063,7 @@ class MCMC(SearchAlgorithm):
# rebuild g_process_group
pg0 = get_process_group(0)
for process_mesh in searched_dist_context._process_meshes:
pg0.add_ranks(process_mesh.processes)
pg0.add_ranks(process_mesh.process_ids)
end_time = time.time()
print(
"End MCMC searching: the min cost is {} and the search time is {}s.".format(
......
......@@ -17,6 +17,7 @@ import copy
import numpy as np
import paddle
from paddle.fluid import core
# Use to store the previous and current process mesh
_g_previous_process_mesh = None
......@@ -41,12 +42,12 @@ def reset_current_process_mesh():
_g_current_process_mesh = _g_previous_process_mesh
class ProcessMesh:
class ProcessMesh(core.ProcessMesh):
"""
The `Processmesh` object describes the topology of the used processes.
The `ProcessMesh` object describes the Cartesian topology of the used processes.
Args:
mesh (list|numpy.array): an n-dimensional array describes the toplogy
mesh (list|numpy.array): an n-dimensional array describes the topology
of the processes.
dim_names (list, optional): the i-th element of this list gives the name of the
i-th dimension of the mesh.
......@@ -58,7 +59,7 @@ class ProcessMesh:
mesh = auto.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
assert mesh.shape == [2, 3]
assert mesh.processe_ids == [2, 4, 5, 0, 1, 3]
assert mesh.process_ids == [2, 4, 5, 0, 1, 3]
"""
......@@ -77,6 +78,9 @@ class ProcessMesh:
if isinstance(mesh, list):
mesh = np.array(mesh)
if dim_names is not None and not isinstance(dim_names, list):
raise ValueError('The dim_names must be an instance of list.')
self._mesh = mesh
self._shape = list(self._mesh.shape)
self._process_ids = self._mesh.flatten().tolist()
......@@ -104,6 +108,11 @@ class ProcessMesh:
self._dim_names
), 'All dim_names {} must be unique.'.format(dim_names)
# Follow the requirement for using pybind11
core.ProcessMesh.__init__(
self, self._shape, self._process_ids, self._dim_names
)
# Store all process meshes
from .dist_context import get_default_distributed_context
......@@ -113,35 +122,7 @@ class ProcessMesh:
from .process_group import get_process_group
pg0 = get_process_group(0)
pg0.add_ranks(self.processes)
@property
def shape(self):
"""
Get the shape of this ProcessMesh.
"""
return self._shape
@property
def process_ids(self):
"""
Get the process ids belonging to this ProcessMesh.
"""
return self._process_ids
@property
def dim_names(self):
"""
Get the dimension names of this ProcessMesh.
"""
return self._dim_names
@property
def ndim(self):
"""
Get the number of dimension of this ProcessMesh.
"""
return len(self._shape)
pg0.add_ranks(self.process_ids)
@property
def mesh(self):
......@@ -150,14 +131,6 @@ class ProcessMesh:
"""
return self._mesh
@property
def topology(self):
return self._shape
@property
def processes(self):
return self._process_ids
def __getitem__(self, index):
if isinstance(index, tuple):
new_dim_names = []
......@@ -207,9 +180,8 @@ class ProcessMesh:
tensor
)
if dist_tensor is None:
dist_tensor = DistributedTensor(
cur_block.vars[name], {"process_mesh": self}
)
dist_tensor = DistributedTensor(cur_block.vars[name])
dist_tensor.dist_attr.process_mesh = self
dist_tensor.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
else:
......@@ -221,7 +193,8 @@ class ProcessMesh:
op = cur_block.ops[idx]
dist_op = default_dist_ctx.get_dist_op_for_program(op)
if dist_op is None:
dist_op = DistributedOperator(op, {"process_mesh": self})
dist_op = DistributedOperator(op)
dist_op.dist_attr.process_mesh = self
dist_op.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_op_for_program(dist_op)
else:
......@@ -230,6 +203,13 @@ class ProcessMesh:
dist_op.dist_attr.mark_annotated("process_mesh")
reset_current_process_mesh()
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
new_process_mesh = ProcessMesh(np.array(self.mesh), self.dim_names)
memo[id(self)] = new_process_mesh
return new_process_mesh
def __eq__(self, other):
if not isinstance(other, ProcessMesh):
return False
......@@ -245,3 +225,51 @@ class ProcessMesh:
self.shape, self.process_ids, self.dim_names
)
return str
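Editor's note: the `__deepcopy__` defined above memoizes on `id(self)`, so shared references stay shared within one deep copy; a brief sketch, again assuming a build with this change:

import copy
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh

pm = ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
copies = copy.deepcopy([pm, pm])
assert copies[0] == pm          # mesh and dim_names are preserved
assert copies[0] is not pm      # but a new object is created
assert copies[0] is copies[1]   # the memo dict keeps shared references shared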
def compute_compatible_process_mesh(process_mesh_list):
"""Compute the compatible process mesh given a list of process meshes."""
if not process_mesh_list:
return None
def _compute_compatible_process_mesh_of_two(pm1, pm2):
if pm1 is None:
return True, pm2
if pm2 is None:
return True, pm1
if pm1 == pm2:
return True, pm1
if pm1.process_ids == pm2.process_ids:
if len(pm1.shape) >= len(pm2.shape):
return True, pm1
else:
return True, pm2
process_set1 = set(pm1.process_ids)
process_set2 = set(pm2.process_ids)
if process_set1.issubset(process_set2):
return True, pm2
if process_set2.issubset(process_set1):
return True, pm1
return False, None
compatible_result = None
for process_mesh in process_mesh_list:
compatible, compatible_result = _compute_compatible_process_mesh_of_two(
compatible_result, process_mesh
)
if not compatible:
return None
return copy.deepcopy(compatible_result)
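Editor's note: in short, the compatibility rules are: `None` is compatible with anything, identical process ids keep the mesh with more dimensions (ties keep the earlier one), and a subset of process ids yields the superset mesh. A sketch mirroring the unit tests later in this diff:

from paddle.distributed.auto_parallel.process_mesh import (
    ProcessMesh,
    compute_compatible_process_mesh,
)

pm_2d = ProcessMesh([[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"])
pm_1d = ProcessMesh([[0, 1, 2, 3, 4, 5]])
pm_sub = ProcessMesh([[0, 1, 2]])

assert compute_compatible_process_mesh([pm_2d, None]) == pm_2d    # None matches anything
assert compute_compatible_process_mesh([pm_2d, pm_1d]) == pm_2d   # same ids: keep pm_2d
assert compute_compatible_process_mesh([pm_2d, pm_sub]) == pm_2d  # subset: superset wins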
def merge_process_meshes(process_meshes):
"""Merge a list of process meshes."""
merged_process_mesh = None
merged_process_ids = set()
for process_mesh in process_meshes:
if process_mesh is not None:
process_ids = set(process_mesh.process_ids)
merged_process_ids = merged_process_ids.union(process_ids)
if len(merged_process_ids) != 0:
merged_process_mesh = ProcessMesh(list(merged_process_ids))
return merged_process_mesh
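Editor's note: `merge_process_meshes` simply unions the process ids into a flat 1-D mesh; mirroring the unit tests later in this diff:

from paddle.distributed.auto_parallel.process_mesh import (
    ProcessMesh,
    merge_process_meshes,
)

pm_a = ProcessMesh([[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"])
pm_b = ProcessMesh([[6, 7]])

assert merge_process_meshes([pm_a, None]) == ProcessMesh([0, 1, 2, 3, 4, 5])
assert merge_process_meshes([pm_a, pm_b]) == ProcessMesh([0, 1, 2, 3, 4, 5, 6, 7])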
......@@ -829,7 +829,7 @@ class Remover:
)
).process_mesh
)
if rank_id in process_mesh.processes:
if rank_id in process_mesh.process_ids:
need_save.append(var_name)
if not need_save:
remove_op_idx.append(idx)
......@@ -845,7 +845,7 @@ class Remover:
if op_dist_attr is not None:
op_process_mesh = op_dist_attr.process_mesh
if (
rank_id not in op_process_mesh.processes
rank_id not in op_process_mesh.process_ids
and op.type not in not_remove_op_ref
):
remove_op_idx.append(idx)
......@@ -1408,9 +1408,11 @@ class Resharder:
op_process_mesh = dist_op.dist_attr.process_mesh
for process_mesh in self.dist_context.process_meshes:
if set(process_mesh.processes) & (
set(op_process_mesh.processes)
) and len(process_mesh.processes) < len(op_process_mesh.processes):
if set(process_mesh.process_ids) & (
set(op_process_mesh.process_ids)
) and len(process_mesh.process_ids) < len(
op_process_mesh.process_ids
):
process_meshes.append(process_mesh)
# it means the process mesh is not a union when process meshes is null
......@@ -1438,13 +1440,13 @@ class Resharder:
source_dims_mapping = tensor_dist_attr.dims_mapping
source_process_mesh = tensor_dist_attr.process_mesh
source_process_group = source_process_mesh.processes
source_process_shape = source_process_mesh.topology
source_process_group = source_process_mesh.process_ids
source_process_shape = source_process_mesh.shape
target_process_mesh = dist_attr[0]
target_dims_mapping = dist_attr[1]
target_process_group = target_process_mesh.processes
target_process_shape = target_process_mesh.topology
target_process_group = target_process_mesh.process_ids
target_process_shape = target_process_mesh.shape
if source_tensor.shape[0] < 0:
assert source_tensor.shape[0] == -1
......@@ -2141,9 +2143,11 @@ class Resharder:
dist_attr = dist_op.dist_attr
op_process_mesh = dist_attr.process_mesh
for process_mesh in self.dist_context.process_meshes:
if set(process_mesh.processes) & (
set(op_process_mesh.processes)
) and len(process_mesh.processes) < len(op_process_mesh.processes):
if set(process_mesh.process_ids) & (
set(op_process_mesh.process_ids)
) and len(process_mesh.process_ids) < len(
op_process_mesh.process_ids
):
process_meshes.append(process_mesh)
# it means that the process mesh is not a union when process meshes is none
......@@ -2176,12 +2180,12 @@ class Resharder:
if process_mesh_count > 1:
global_process_mesh_idx = None
for process_mesh in self.dist_context.process_meshes:
for process in process_mesh.processes:
for process in process_mesh.process_ids:
processes.add(process)
for idx, process_mesh in enumerate(
self.dist_context.process_meshes
):
if len(set(process_mesh.processes)) == len(processes):
if len(set(process_mesh.process_ids)) == len(processes):
global_process_mesh_idx = idx
break
......@@ -2191,7 +2195,7 @@ class Resharder:
for i, mesh in enumerate(self.dist_context.process_meshes):
if i == idx:
continue
if set(mesh.processes) < set(global_mesh.processes):
if set(mesh.process_ids) < set(global_mesh.process_ids):
is_removed = True
if is_removed:
......@@ -2330,8 +2334,8 @@ class Resharder:
# deal with union tensor
if is_union_process_mesh_tensor:
# if op process mesh is subset of union tensor process mesh, need no reshard
if set(input_attr[0].processes) <= set(
dist_tensor.dist_attr.process_mesh.processes
if set(input_attr[0].process_ids) <= set(
dist_tensor.dist_attr.process_mesh.process_ids
):
continue
......@@ -2525,14 +2529,14 @@ class Resharder:
dist_tensor, output_attr, False
):
tensor_processes = set(
tensor_process_mesh.processes
tensor_process_mesh.process_ids
) - (
set(tensor_process_mesh.processes)
& set(output_attr[0].processes)
set(tensor_process_mesh.process_ids)
& set(output_attr[0].process_ids)
)
if tensor_processes:
if len(tensor_processes) != len(
output_attr[0].processes
output_attr[0].process_ids
):
if dist_tensor.dist_attr.dims_mapping.count(
-1
......@@ -2555,13 +2559,15 @@ class Resharder:
recv_rank = tensor_process
actual_index = index
if index >= len(
output_attr[0].processes
output_attr[0].process_ids
):
actual_index = (
index
- len(output_attr[0].processes)
) % len(output_attr[0].processes)
item = output_attr[0].processes[
- len(
output_attr[0].process_ids
)
) % len(output_attr[0].process_ids)
item = output_attr[0].process_ids[
actual_index
]
if recv_rank == item:
......@@ -2591,7 +2597,7 @@ class Resharder:
tensor_processes
):
recv_rank = tensor_process
item = output_attr[0].processes[index]
item = output_attr[0].process_ids[index]
if recv_rank == item:
continue
if self.rank_id == item:
......
......@@ -114,7 +114,7 @@ def _copy_context(ref_dist_context):
clear_all_process_groups()
ranks = []
for process_mesh in ref_dist_context._process_meshes:
ranks.extend(process_mesh.processes)
ranks.extend(process_mesh.process_ids)
new_process_group(list(set(ranks)))
new_dist_context = DistributedContext()
......
......@@ -689,15 +689,15 @@ class ParallelTuner:
assert process_mesh.ndim == 2
dim_of_one = None
dim_of_other = None
if process_mesh.topology[0] == 1:
if process_mesh.shape[0] == 1:
dim_of_one = 0
dim_of_other = 1
elif process_mesh.topology[1] == 1:
elif process_mesh.shape[1] == 1:
dim_of_one = 1
dim_of_other = 0
if dim_of_one is not None:
dist_attr.process_mesh = ProcessMesh(process_mesh.processes)
dist_attr.process_mesh = ProcessMesh(process_mesh.process_ids)
self._dist_context.add_process_mesh(dist_attr.process_mesh)
for arg_name in dist_attr.inputs_dist_attrs.keys():
......@@ -715,7 +715,7 @@ class ParallelTuner:
dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
# dynamic_dims = dist_attr.get_input_dynamic_dims(arg_name)
process_mesh = dist_attr.process_mesh
process_shape = process_mesh.topology
process_shape = process_mesh.shape
tensor = dist_op.get_serial_input(arg_name)
if dims_mapping:
tensor_shape = tensor.shape
......@@ -748,7 +748,7 @@ class ParallelTuner:
dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
# dynamic_dims = dist_attr.get_output_dynamic_dims(arg_name)
process_mesh = dist_attr.process_mesh
process_shape = process_mesh.topology
process_shape = process_mesh.shape
tensor = dist_op.get_serial_output(arg_name)
if dims_mapping:
......@@ -793,7 +793,7 @@ class ParallelTuner:
input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
input_name
)
topology = dist_op.dist_attr.process_mesh.topology
topology = dist_op.dist_attr.process_mesh.shape
input_tensor = dist_op.get_serial_input(input_name)
last_but_one_dim = (
input_tensor.shape[-2] // topology[input_dims_mapping[-2]]
......
......@@ -100,7 +100,7 @@ def convert_to_dims_mapping(shard_spec, process_mesh):
for shard in shard_spec:
if shard is None:
dims_mapping.append(-1)
elif process_mesh.topology[process_mesh.dim_names.index(shard)] == 1:
elif process_mesh.shape[process_mesh.dim_names.index(shard)] == 1:
dims_mapping.append(-1)
else:
dims_mapping.append(process_mesh.dim_names.index(shard))
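Editor's note: the rule in this hunk, restated in a self-contained form (`to_dims_mapping` is a hypothetical stand-in for the real helper, which takes a ProcessMesh): a shard of `None`, or a named mesh dimension of size 1, maps to -1; otherwise to the index of that dimension.

def to_dims_mapping(shard_spec, dim_names, mesh_shape):
    dims_mapping = []
    for shard in shard_spec:
        if shard is None or mesh_shape[dim_names.index(shard)] == 1:
            dims_mapping.append(-1)  # replicated along this tensor axis
        else:
            dims_mapping.append(dim_names.index(shard))
    return dims_mapping

assert to_dims_mapping(["x", None], ["x", "y"], [2, 3]) == [0, -1]
assert to_dims_mapping([None, "y"], ["x", "y"], [2, 1]) == [-1, -1]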
......@@ -429,26 +429,26 @@ def _get_corresponding_rank(dist_context, target_mesh, rank):
coordinate = None
for mesh in dist_context.process_meshes:
if rank in mesh.processes and mesh.topology == target_mesh.topology:
if rank in mesh.process_ids and mesh.shape == target_mesh.shape:
coordinate = _linear_idx2coordinate(
mesh.topology, mesh.processes.index(rank)
mesh.shape, mesh.process_ids.index(rank)
)
break
# assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format(
# rank)
if coordinate is not None:
return target_mesh.processes[
_coordinate2linear_idx(mesh.topology, coordinate)
return target_mesh.process_ids[
_coordinate2linear_idx(mesh.shape, coordinate)
]
else:
return target_mesh.processes[0]
return target_mesh.process_ids[0]
def _get_unshard_dist_shape(var, dist_attr):
var_shape = var.shape
mapping = dist_attr.dims_mapping
mesh = dist_attr.process_mesh.topology
mesh = dist_attr.process_mesh.shape
assert len(var_shape) == len(
mapping
), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
......@@ -832,8 +832,8 @@ def get_dist_attr(program, dist_context=None):
process_mesh = tensor_dist_attr.process_mesh
dims_mapping = tensor_dist_attr.dims_mapping
dist_attr[var.name] = {
"process_shape": process_mesh.topology,
"process_group": process_mesh.processes,
"process_shape": process_mesh.shape,
"process_group": process_mesh.process_ids,
"dims_mapping": dims_mapping,
}
return dist_attr
......@@ -2006,16 +2006,16 @@ def get_input_split_info(cur_rank, var, dist_context):
process_mesh = tensor_dist_attr.process_mesh
dims_mapping = tensor_dist_attr.dims_mapping
if cur_rank not in process_mesh.processes:
if cur_rank not in process_mesh.process_ids:
rank_id = _get_corresponding_rank(dist_context, process_mesh, cur_rank)
else:
rank_id = cur_rank
batch_size_axis = dims_mapping[0]
if batch_size_axis > -1 and process_mesh.topology[batch_size_axis] > 1:
if batch_size_axis > -1 and process_mesh.shape[batch_size_axis] > 1:
group_ranks = _get_comm_group(
process_mesh.processes,
process_mesh.topology,
process_mesh.process_ids,
process_mesh.shape,
batch_size_axis,
rank_id,
)
......
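Editor's note: `_get_comm_group` is called throughout with `(process_ids, shape, axis, rank_id)`. Its implementation is not part of this diff; assuming it returns the ranks that share every mesh coordinate with `rank_id` except along `axis`, a hedged numpy sketch (`comm_group` is a hypothetical stand-in):

import numpy as np

def comm_group(process_ids, mesh_shape, axis, rank_id):
    mesh = np.array(process_ids).reshape(mesh_shape)
    coord = list(zip(*np.where(mesh == rank_id)))[0]  # coordinate of rank_id in the mesh
    index = list(coord)
    index[axis] = slice(None)                         # vary only the given axis
    return mesh[tuple(index)].tolist()

# on mesh [[0, 1], [2, 3]], the group of rank 3 along axis 0 is the column [1, 3]
assert comm_group([0, 1, 2, 3], [2, 2], 0, 3) == [1, 3]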
......@@ -164,8 +164,8 @@ class ClipHelper:
param = self.params[self.params_name.index(name)]
dist_attr = self._get_dist_attr(name)
topology = dist_attr.process_mesh.topology
processes = dist_attr.process_mesh.processes
topology = dist_attr.process_mesh.shape
processes = dist_attr.process_mesh.process_ids
dims_mapping = dist_attr.dims_mapping
return _is_about_global_norm(
self.rank_id,
......@@ -188,7 +188,7 @@ class ClipHelper:
def _is_local_var(self, name):
dist_attr = self._get_dist_attr(name)
assert dist_attr is not None
return self.rank_id in dist_attr.process_mesh.processes
return self.rank_id in dist_attr.process_mesh.process_ids
def _init_dist_attr(self, op):
op_dist_attr = OperatorDistributedAttribute()
......
......@@ -870,13 +870,13 @@ def _inference_data_parallel_group_for_operator(rank_id, op, dist_context):
dist_attr = dist_context.get_op_dist_attr_for_program(op)
process_mesh = dist_attr.process_mesh
input_dim_mapping = dist_attr.get_input_dims_mapping(input_name)
mesh_shape = process_mesh.topology
mesh_shape = process_mesh.shape
# TODO(JZ-LIANG) replace with specific batch size dimension
batch_size_axis = input_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
group_ranks = _get_comm_group(
process_mesh.processes,
process_mesh.topology,
process_mesh.process_ids,
process_mesh.shape,
batch_size_axis,
rank_id,
)
......
......@@ -180,7 +180,7 @@ class TestBaseCost(unittest.TestCase):
for op in train_program.global_block().ops:
dist_op = dist_context.get_dist_op_for_program(op)
if dist_op:
processes = dist_op.dist_attr.process_mesh.processes
processes = dist_op.dist_attr.process_mesh.process_ids
comp_descs = build_comp_desc_from_dist_op(dist_op, dist_context)
self.assertTrue(isinstance(comp_descs, dict) and comp_descs)
var_names = None
......
......@@ -52,13 +52,13 @@ class MLPLayer(nn.Layer):
out = self.norm(input)
auto.shard_tensor(
self.linear0.weight, auto.ProcessMesh([0, 1], "x"), [None, "x"]
self.linear0.weight, auto.ProcessMesh([0, 1], ["x"]), [None, "x"]
)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
auto.shard_tensor(
self.linear1.weight, auto.ProcessMesh([0, 1], "x"), ["x", None]
self.linear1.weight, auto.ProcessMesh([0, 1], ["x"]), ["x", None]
)
out = self.linear1(out)
......
......@@ -98,7 +98,7 @@ class TestDistOpCost(unittest.TestCase):
):
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes
processes = op_dist_attr.process_mesh.process_ids
if is_elementwise_op(op.type):
container = get_distributed_operator_impl_container(
"elementwise"
......@@ -205,7 +205,7 @@ class TestDistOpCost(unittest.TestCase):
for idx, op in enumerate(ops):
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes
processes = op_dist_attr.process_mesh.process_ids
if is_elementwise_op(op.type):
container = get_distributed_operator_impl_container(
"elementwise"
......@@ -313,7 +313,7 @@ class TestDistOpCost(unittest.TestCase):
for idx, op in enumerate(ops):
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes
processes = op_dist_attr.process_mesh.process_ids
if is_elementwise_op(op.type):
container = get_distributed_operator_impl_container(
"elementwise"
......@@ -421,7 +421,7 @@ class TestDistOpCost(unittest.TestCase):
for idx, op in enumerate(ops):
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes
processes = op_dist_attr.process_mesh.process_ids
if is_elementwise_op(op.type):
container = get_distributed_operator_impl_container(
"elementwise"
......
......@@ -23,7 +23,11 @@ import paddle.static as static
from paddle.distributed.auto_parallel.dist_context import (
get_default_distributed_context,
)
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.process_mesh import (
ProcessMesh,
compute_compatible_process_mesh,
merge_process_meshes,
)
paddle.enable_static()
......@@ -129,7 +133,7 @@ class TestProcessMesh(unittest.TestCase):
initializer_range=0.02,
)
with ProcessMesh(mesh, "d"):
with ProcessMesh(mesh, ["d"]):
out = mlp(input)
default_program = paddle.fluid.default_main_program()
......@@ -151,6 +155,67 @@ class TestProcessMesh(unittest.TestCase):
dist_op.dist_attr.process_mesh, ProcessMesh(mesh)
)
def test_compute_compatible_process_mesh(self):
process_mesh1 = ProcessMesh(
[[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"]
)
compatible_process_mesh = compute_compatible_process_mesh(
[process_mesh1, None]
)
self.assertEqual(compatible_process_mesh, process_mesh1)
compatible_process_mesh = compute_compatible_process_mesh(
[None, process_mesh1]
)
self.assertEqual(compatible_process_mesh, process_mesh1)
process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]])
compatible_process_mesh = compute_compatible_process_mesh(
[process_mesh1, process_mesh2]
)
self.assertEqual(compatible_process_mesh, process_mesh1)
self.assertEqual(compatible_process_mesh, process_mesh2)
process_mesh2 = ProcessMesh([[0, 1, 2, 3, 4, 5]])
compatible_process_mesh = compute_compatible_process_mesh(
[process_mesh1, process_mesh2]
)
self.assertEqual(compatible_process_mesh, process_mesh1)
process_mesh2 = ProcessMesh([[0, 1, 2]])
compatible_process_mesh = compute_compatible_process_mesh(
[process_mesh1, process_mesh2]
)
self.assertEqual(compatible_process_mesh, process_mesh1)
def test_merge_process_meshes(self):
process_mesh1 = ProcessMesh(
[[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"]
)
merged_process_mesh = merge_process_meshes([process_mesh1, None])
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
merged_process_mesh = merge_process_meshes([None, process_mesh1])
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]])
merged_process_mesh = merge_process_meshes(
[process_mesh1, process_mesh2]
)
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
process_mesh2 = ProcessMesh([[0, 1, 2]])
merged_process_mesh = merge_process_meshes(
[process_mesh1, process_mesh2]
)
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
process_mesh2 = ProcessMesh([[6, 7]])
merged_process_mesh = merge_process_meshes(
[process_mesh1, process_mesh2]
)
self.assertEqual(
merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5, 6, 7])
)
if __name__ == "__main__":
unittest.main()
......@@ -80,7 +80,6 @@ class TestProcessMesh(unittest.TestCase):
[[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"]
)
merged_process_mesh = merge_process_mesh([process_mesh1, None])
print(merged_process_mesh)
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
merged_process_mesh = merge_process_mesh([None, process_mesh1])
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
......
......@@ -105,7 +105,7 @@ def initialization_check(
):
if 'mp' in mode:
group_ranks = _get_comm_group(
process_mesh.processes, process_mesh.topology, mp_parallel_axis, 3
process_mesh.process_ids, process_mesh.shape, mp_parallel_axis, 3
)
mp_ring_id = new_process_group(group_ranks).id
broadcast_ops = [
......@@ -124,7 +124,7 @@ def initialization_check(
if 'dp' in mode:
group_ranks = _get_comm_group(
process_mesh.processes, process_mesh.topology, dp_parallel_axis, 3
process_mesh.process_ids, process_mesh.shape, dp_parallel_axis, 3
)
dp_ring_id = new_process_group(group_ranks).id
nparam = len(serial_startup_prog.all_parameters())
......
......@@ -958,12 +958,12 @@ class TestGPTPartitioner(unittest.TestCase):
dp_parallel_axis = 0
group_ranks = _get_comm_group(
process_mesh.processes, process_mesh.topology, mp_parallel_axis, 3
process_mesh.process_ids, process_mesh.shape, mp_parallel_axis, 3
)
mp_ring_id = new_process_group(group_ranks).id
group_ranks = _get_comm_group(
process_mesh.processes, process_mesh.topology, dp_parallel_axis, 3
process_mesh.process_ids, process_mesh.shape, dp_parallel_axis, 3
)
dp_ring_id = new_process_group(group_ranks).id
......