Unverified commit a622b701, authored by JZ-LIANG, committed by GitHub

[Auto Parallel] Logical Partition & Dist Op (#35117)

* support shard reader

* support shard reader

* add parallel mode

* update process mesh

* add method to compute comm_group

* implement dist_embedding forward func

* implement dist matmul forward func

* implement dist reshape forward func

* add transpiler framework

* add transpiler forward

* implement transpiler forward

* implement transpiler backward & update

* add process

* add unit test

* chmod

* chmod

* chmod

* update unit test

* add unit test for gpt

* remove unused print

* rename transpiler --> partitioner

* rename transpiler --> partitioner

* chmod

* chmod

* bug fixed

* remove amp function

* update case for dp mode

* update case for dp mode
Parent 280d7421
......@@ -14,6 +14,7 @@
import copy
from collections import defaultdict
from paddle.fluid import core
class TensorDistributedAttribute:
......@@ -77,6 +78,8 @@ class TensorDistributedAttribute:
self._is_parameter = True
def is_valid(self):
if self.get_owner_tensor().type == core.VarDesc.VarType.READER:
return True
tensor_shape = self.get_owner_tensor().desc.shape()
if len(tensor_shape) != len(self.get_dims_mapping()):
return False
......@@ -222,6 +225,8 @@ class OperatorDistributedAttribute:
self._is_parameters[name] = True
def is_valid(self):
if "read" in self.get_owner_op().type:
return True
for name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(name)
shape = self.get_input_shape(name)
......
......@@ -15,9 +15,11 @@
import copy
from collections import defaultdict
from paddle.fluid import framework
from paddle.fluid import core
from .attribute import TensorDistributedAttribute
from .attribute import OperatorDistributedAttribute
from .utils import append_distributed_attr_suffix
from .interface import _g_process_mesh_map
# There is always a default context for the user, and the user can set it to another one.
DEFAULT_DISTRIBUTED_CONTEXT = None
......@@ -49,6 +51,20 @@ class DistributedContext:
self._op_distributed_attr_map_for_program = {}
self._tensor_distributed_attr_map_for_graph = {}
self._op_distributed_attr_map_for_graph = {}
# The following is hard-coded and will be removed in the future
self._data_parallel_axis = None
self._model_parallel_axis = None
self._process_mesh = _g_process_mesh_map.get(0, None)
if self._process_mesh is not None:
if self._process_mesh.ndim == 1:
self._data_parallel_axis = 0
self._model_parallel_axis = 0
else:
self._data_parallel_axis = 0
self._model_parallel_axis = 1
else:
self._data_parallel_axis = -1
self._model_parallel_axis = -1
def is_initialized_for_program(self):
return self._is_initialized_for_program
......@@ -99,6 +115,19 @@ class DistributedContext:
op_node_id = op_node.id()
self._op_distributed_attr_map_for_graph[op_node_id] = op_dist_attr
def set_process_mesh(self, process_mesh):
self._process_mesh = process_mesh
if self._process_mesh is not None:
if self._process_mesh.ndim == 1:
self._data_parallel_axis = 0
self._model_parallel_axis = 0
else:
self._data_parallel_axis = 0
self._model_parallel_axis = 1
else:
self._data_parallel_axis = -1
self._model_parallel_axis = -1
def initialize_distributed_attr_for_program(self, program):
if self._is_initialized_for_program:
return
......@@ -377,3 +406,11 @@ class DistributedContext:
if dims_mapping[i] != -1 and process_mesh_shape[
dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
def _get_data_parallel_info(self):
# This function is hard-coded and will be removed in the future
return self._data_parallel_axis, self._process_mesh
def _get_model_parallel_info(self):
# This function is hard-coded and will be removed in the future
return self._model_parallel_axis, self._process_mesh
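# A minimal, paddle-free sketch of the hard-coded convention used above: a 1-D process
# mesh uses axis 0 for both the data-parallel and model-parallel roles, a 2-D mesh uses
# axis 0 for data parallelism and axis 1 for model parallelism, and no mesh yields -1
# for both. The helper name `_infer_parallel_axes` and `mesh_ndim` are illustrative only.
def _infer_parallel_axes(mesh_ndim):
    """Return (data_parallel_axis, model_parallel_axis) for a mesh of `mesh_ndim` dims."""
    if mesh_ndim is None:
        return -1, -1        # no process mesh registered yet
    if mesh_ndim == 1:
        return 0, 0          # single axis serves both roles
    return 0, 1              # 2-D mesh: dp on axis 0, mp on axis 1

assert _infer_parallel_axes(None) == (-1, -1)
assert _infer_parallel_axes(1) == (0, 0)
assert _infer_parallel_axes(2) == (0, 1)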
......@@ -184,6 +184,13 @@ class ProcessMesh(object):
"parent with id %d does not exist." % self._parent_id)
return _g_process_mesh_map[self._parent_id]
@property
def ndim(self):
r"""
Get the number of dimensions of the ProcessMesh.
"""
return len(self._topology)
def set_placement(self, order):
"""
Set the map from logical processes to physical ones using the
......@@ -229,6 +236,13 @@ class ProcessMesh(object):
for idx, l_id in enumerate(logical_order):
_user_defined_physical_map[l_id] = order[idx]
def _reset_global_process_mesh_map(self):
"""
Remove all process mesh in _g_process_mesh_map, make it empty.
"""
_g_process_mesh_map = dict()
def __eq__(self, other):
assert other and isinstance(other, ProcessMesh)
if self.topology != other.topology or self.process_group != other.process_group:
......
......@@ -33,6 +33,8 @@ class DistributedOperator:
class DistributedOperatorImpl:
def __init__(self):
self._name = None
self._forward_implemented = False
self._backward_implemented = False
def forward(self, dist_ctx, *args, **kwargs):
raise NotImplementedError("Please Implement this method in Subclass.")
......
......@@ -22,6 +22,12 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from ..process import new_process_group
from ..utils import _get_comm_group
class DistributedEmbedding(DistributedOperator):
......@@ -39,6 +45,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedEmbeddingImpl, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
......@@ -92,6 +100,110 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 2, "row_parallel_embedding takes 2 input variables but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 1, "row_parallel_embedding takes 1 output variable but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['Ids']
) == 1, "row_parallel_embedding input Ids takes 1 variable but got {}".format(
input_name_mapping['Ids'])
assert len(
input_name_mapping['W']
) == 1, "row_parallel_embedding input W takes 1 variable but got {}".format(
input_name_mapping['W'])
assert len(
output_name_mapping['Out']
) == 1, "row_parallel_embedding output Out takes 1 variable but got {}".format(
output_name_mapping['Out'])
Ids_var = dst_block.var(input_name_mapping['Ids'][0])
Weight_var = dst_block.var(input_name_mapping['W'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
# get dist attribute info
embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[0]
process_mesh_shape = op_dist_attr.get_process_mesh().topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group
# calculate embedding offset
# TODO generalize here, use cartesian product to allow any-dimensional mesh shape
mesh_shape = len(process_mesh_shape)
assert mesh_shape <= 2, "row_parallel_embedding only supports a 1- or 2-dimensional process mesh, but got {}".format(
process_mesh_shape)
num_partition = process_mesh_shape[embedding_row_dim_mapping]
# TODO generalize here, support any mesh group
if mesh_shape == 1:
relative_idx = process_mesh_group.index(rank_id)
else:
relative_idx = rank_id % num_partition
per_part_size = Weight_var.shape[0]
relative_idx = relative_idx * per_part_size
# TODO calculate ring id
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, rank_id)
group = new_process_group(group_ranks)
# append op
check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'],
'c_embedding')
intermediate_var_0 = dst_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_embedding", 'tmp'])),
dtype=Weight_var.dtype,
shape=Out_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_var.stop_gradient)
check_variable_and_dtype(
Out_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'c_allreduce_sum')
dst_block.append_op(
type='c_embedding',
inputs={'Ids': [Ids_var],
'W': [Weight_var]},
outputs={'Out': [intermediate_var_0]},
attrs={"start_index": relative_idx})
# use_model_parallel
dst_block.append_op(
type='c_allreduce_sum',
inputs={'X': [intermediate_var_0]},
outputs={'Out': [Out_var]},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
})
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"matmul", 0))
else:
return static_handle
register_distributed_operator_impl("lookup_table_v2",
DistributedEmbeddingImpl("row_parallel"))
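# A small numpy sketch (not the real c_embedding / c_allreduce_sum kernels) of the
# row-parallel lookup assembled above: each rank holds a contiguous slice of the
# embedding rows starting at rank * per_part_size (the `start_index` attribute),
# looks up only ids falling in its slice (contributing zeros otherwise), and the
# per-rank partials are summed across ranks. All values below are made up.
import numpy as np

vocab_size, hidden, num_parts = 8, 4, 2
per_part_size = vocab_size // num_parts
full_table = np.arange(vocab_size * hidden, dtype=np.float32).reshape(vocab_size, hidden)
ids = np.array([1, 6, 3])

def local_lookup(rank):
    start = rank * per_part_size                      # plays the role of start_index
    shard = full_table[start:start + per_part_size]   # this rank's rows of the table
    out = np.zeros((len(ids), hidden), dtype=np.float32)
    for i, idx in enumerate(ids):
        if start <= idx < start + per_part_size:
            out[i] = shard[idx - start]
    return out

# summing the per-rank partials (the allreduce_sum step) recovers the full lookup
summed = sum(local_lookup(r) for r in range(num_parts))
assert np.array_equal(summed, full_table[ids])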
......@@ -22,6 +22,12 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from ..process import new_process_group
from ..utils import _get_comm_group
def _update_dims_mapping_for_matmul(op_dist_attr):
......@@ -37,7 +43,6 @@ def _update_dims_mapping_for_matmul(op_dist_attr):
y_dims_mapping_len = len(y_dims_mapping)
out_dims_mapping_len = len(out_dims_mapping)
# print("before", x_dims_mapping, y_dims_mapping, out_dims_mapping)
# Add dim mapping to make sure the length of dims_mapping is at least 2
if x_dims_mapping_len == 1:
x_dims_mapping.insert(0, -1)
......@@ -109,7 +114,6 @@ def _update_dims_mapping_for_matmul(op_dist_attr):
if y_dims_mapping_len == 1:
y_dims_mapping.pop(1)
# print("after", x_dims_mapping, y_dims_mapping, out_dims_mapping)
assert len(x_dims_mapping) == x_dims_mapping_len
assert len(y_dims_mapping) == y_dims_mapping_len
assert len(out_dims_mapping) == out_dims_mapping_len
......@@ -131,6 +135,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl0, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
......@@ -170,12 +176,101 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
changed = True
return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 2, "col_parallel_linear takes 2 input variables but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 1, "col_parallel_linear takes 1 output variable but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['X']
) == 1, "col_parallel_linear input X takes 1 variable but got {}".format(
input_name_mapping['X'])
assert len(
input_name_mapping['Y']
) == 1, "col_parallel_linear input Y takes 1 variable but got {}".format(
input_name_mapping['Y'])
assert len(
output_name_mapping['Out']
) == 1, "col_parallel_linear output Out takes 1 variable but got {}".format(
output_name_mapping['Out'])
X_var = dst_block.var(input_name_mapping['X'][0])
Weight_var = dst_block.var(input_name_mapping['Y'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
# TODO infer logical comm representation
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, rank_id)
group = new_process_group(group_ranks)
intermediate_var_0 = dst_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_identity", 'tmp'])),
dtype=X_var.dtype,
shape=X_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=X_var.stop_gradient)
check_variable_and_dtype(
X_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_identity')
dst_block.append_op(
type='c_identity',
inputs={'X': [X_var]},
outputs={'Out': intermediate_var_0},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
})
check_variable_and_dtype(intermediate_var_0, 'x',
['float16', 'float32', 'float64'],
'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
attrs = {
'transpose_X': False,
'transpose_Y': False,
'alpha': 1,
}
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
dst_block.append_op(
type='matmul',
inputs=inputs,
outputs={'Out': Out_var},
attrs=attrs)
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"matmul", 0))
else:
return static_handle
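# A numpy sketch of the column-parallel matmul built above: the input X is first made
# identical on every rank (the role of c_identity), each rank multiplies it by its own
# column shard of Y, and the per-rank outputs are the column shards of X @ Y.
# Illustration only, not the Paddle ops themselves.
import numpy as np

np.random.seed(0)
X = np.random.rand(3, 4).astype("float32")       # replicated input
Y = np.random.rand(4, 6).astype("float32")       # full weight, split by columns
num_parts = 2
Y_shards = np.split(Y, num_parts, axis=1)        # each rank owns Y[:, shard]

partials = [X @ Y_shards[r] for r in range(num_parts)]
assert np.allclose(np.concatenate(partials, axis=1), X @ Y)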
# RowParallel
class DistributedMatmulImpl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl1, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
......@@ -217,6 +312,86 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
changed = True
return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 2, "row_parallel_linear takes 2 input variables but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 1, "row_parallel_linear takes 1 output variable but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['X']
) == 1, "row_parallel_linear input X takes 1 variable but got {}".format(
input_name_mapping['X'])
assert len(
input_name_mapping['Y']
) == 1, "row_parallel_linear input Y takes 1 variable but got {}".format(
input_name_mapping['Y'])
assert len(
output_name_mapping['Out']
) == 1, "row_parallel_linear output Out takes 1 variable but got {}".format(
output_name_mapping['Out'])
X_var = dst_block.var(input_name_mapping['X'][0])
Weight_var = dst_block.var(input_name_mapping['Y'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
# TODO infer logical comm representation
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, rank_id)
group = new_process_group(group_ranks)
check_variable_and_dtype(
X_var, 'x', ['float16', 'float32', 'float64'], 'linear')
check_dtype(X_var.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
attrs = {
'transpose_X': False,
'transpose_Y': False,
'alpha': 1,
}
inputs = {'X': X_var, 'Y': Weight_var}
intermediate_var_0 = dst_block.create_var(
shape=Out_var.shape,
dtype=Out_var.dtype,
type=Out_var.type,
lod_level=Out_var.lod_level,
persistable=False,
is_data=False,
need_check_feed=Out_var.desc.need_check_feed())
dst_block.append_op(
type='matmul',
inputs=inputs,
outputs={'Out': intermediate_var_0},
attrs=attrs)
dst_block.append_op(
type='c_allreduce_sum',
inputs={'X': intermediate_var_0},
outputs={'Out': Out_var},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True
})
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"matmul", 0))
else:
return static_handle
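# The row-parallel counterpart as a numpy sketch: X arrives column-sharded, Y is
# sharded by rows, each rank computes a partial product of its matching shards, and
# the c_allreduce_sum above adds the partials into the full result. Illustration only.
import numpy as np

np.random.seed(0)
X = np.random.rand(3, 4).astype("float32")
Y = np.random.rand(4, 6).astype("float32")
num_parts = 2
X_shards = np.split(X, num_parts, axis=1)        # columns of X
Y_shards = np.split(Y, num_parts, axis=0)        # matching rows of Y

partials = [X_shards[r] @ Y_shards[r] for r in range(num_parts)]
assert np.allclose(sum(partials), X @ Y)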
# ReplicateParallel
class DistributedMatmulImpl2(DistributedOperatorImpl):
......
......@@ -22,6 +22,10 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
class DistributedReshape2(DistributedOperator):
......@@ -37,6 +41,8 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReshapeImpl0, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
......@@ -91,11 +97,90 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 3, "Dist op of Reshape takes 3 input variables but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 2, "Dist op of Reshape takes 2 output variables but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['X']
) == 1, "Dist op of Reshape input X takes 1 variable but got {}".format(
input_name_mapping['X'])
assert len(
input_name_mapping['ShapeTensor']
) <= 1, "Dist op of Reshape input ShapeTensor takes 0 or 1 variable but got {}".format(
input_name_mapping['ShapeTensor'])
assert len(
input_name_mapping['Shape']
) <= 1, "Dist op of Reshape input Shape takes 0 or 1 variable but got {}".format(
input_name_mapping['Shape'])
assert len(
output_name_mapping['Out']
) == 1, "Dist op of Reshape output Out takes 1 variable but got {}".format(
output_name_mapping['Out'])
assert len(
output_name_mapping['XShape']
) == 1, "Dist op of Reshape output XShape takes 1 variable but got {}".format(
output_name_mapping['XShape'])
X_var = dst_block.var(input_name_mapping['X'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
XShape_var = dst_block.var(output_name_mapping['XShape'][0])
shape_list = src_op.desc.attr("shape")
ShapeTensor_var_list = []
for name in input_name_mapping['ShapeTensor']:
ShapeTensor_var_list.append(name)
Shape_var_list = []
for name in input_name_mapping['Shape']:
Shape_var_list.append(name)
# get dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.get_process_mesh().topology
# modify target shape
for idx, axis in enumerate(dim_mapping):
if axis >= 0:
if len(shape_list) > idx:
shape_list[idx] = shape_list[idx] // process_mesh_shape[
axis]
# create op
new_op_desc = dst_block.desc.append_op()
new_op_desc.copy_from(src_op.desc)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
new_op_desc.set_input('Shape', Shape_var_list)
new_op_desc.set_input('X', [X_var.name])
new_op_desc.set_output('XShape', [XShape_var.name])
new_op_desc.set_output('Out', [Out_var.name])
new_op_desc._set_attr('shape', shape_list)
dst_block._sync_with_cpp()
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"matmul", 0))
else:
return static_handle
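# A condensed, standalone restatement of the shape rewrite performed above: for every
# output axis that is sharded (dims_mapping entry >= 0), the target shape handed to the
# local reshape2 is divided by the size of the mesh dimension it maps to. The helper
# name `shard_target_shape` and the sample values are illustrative only.
def shard_target_shape(shape_list, dims_mapping, process_mesh_shape):
    shape_list = list(shape_list)
    for idx, axis in enumerate(dims_mapping):
        if axis >= 0 and idx < len(shape_list):
            shape_list[idx] = shape_list[idx] // process_mesh_shape[axis]
    return shape_list

# global target shape [8, 16, 32], sharded on axis 0 (mesh dim 0, size 2) and
# axis 2 (mesh dim 1, size 4), becomes the per-rank shape [4, 16, 8]
assert shard_target_shape([8, 16, 32], [0, -1, 1], [2, 4]) == [4, 16, 8]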
class DistributedReshapeImpl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReshapeImpl1, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
......@@ -150,6 +235,83 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 3, "Dist op of Reshape takes 3 input variables but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 2, "Dist op of Reshape takes 2 output variables but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['X']
) == 1, "Dist op of Reshape input X takes 1 variable but got {}".format(
input_name_mapping['X'])
assert len(
input_name_mapping['ShapeTensor']
) <= 1, "Dist op of Reshape input ShapeTensor takes 0 or 1 variable but got {}".format(
input_name_mapping['ShapeTensor'])
assert len(
input_name_mapping['Shape']
) <= 1, "Dist op of Reshape input Shape takes 0 or 1 variable but got {}".format(
input_name_mapping['Shape'])
assert len(
output_name_mapping['Out']
) == 1, "Dist op of Reshape output Out takes 1 variable but got {}".format(
output_name_mapping['Out'])
assert len(
output_name_mapping['XShape']
) == 1, "Dist op of Reshape output XShape takes 1 variable but got {}".format(
output_name_mapping['XShape'])
X_var = dst_block.var(input_name_mapping['X'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
XShape_var = dst_block.var(output_name_mapping['XShape'][0])
shape_list = src_op.desc.attr("shape")
ShapeTensor_var_list = []
for name in input_name_mapping['ShapeTensor']:
ShapeTensor_var_list.append(name)
Shape_var_list = []
for name in input_name_mapping['Shape']:
Shape_var_list.append(name)
# get dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.get_process_mesh().topology
# modify target shape
for idx, axis in enumerate(dim_mapping):
if axis >= 0:
if len(shape_list) > idx:
shape_list[idx] = shape_list[idx] // process_mesh_shape[
axis]
# create op
new_op_desc = dst_block.desc.append_op()
new_op_desc.copy_from(src_op.desc)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
new_op_desc.set_input('Shape', Shape_var_list)
new_op_desc.set_input('X', [X_var.name])
new_op_desc.set_output('XShape', [XShape_var.name])
new_op_desc.set_output('Out', [Out_var.name])
new_op_desc._set_attr('shape', shape_list)
dst_block._sync_with_cpp()
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"matmul", 0))
else:
return static_handle
register_distributed_operator_impl("reshape2",
DistributedReshapeImpl0("add_one_dim_back"))
......
......@@ -47,7 +47,6 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
x_name = op_desc.input('X')[0]
axis = op_desc.attr('axis')
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
# print("softmax axis", axis)
if axis != -1 and axis != len(x_dims_mapping) - 1:
return False
......
This diff is collapsed.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import paddle
import paddle.fluid.core as core
from ..collective import _get_global_env
from ..collective import _new_ring_id
from ...fluid.framework import in_dygraph_mode
from ...fluid.layers.tensor import fill_constant
LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = None
PROCESSOR_TO_PHYSICAL_PROCESS_MAP = None
def get_all_logical_process_set():
from .interface import _g_process_mesh_map
all_logical_process_set = set(_g_process_mesh_map[0].process_group)
return all_logical_process_set
def get_logical_process_to_physical_process_map():
global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
return LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
def set_logical_process_to_physical_process_map(mapping):
global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = mapping
def get_processor_to_physical_process_map():
global PROCESSOR_TO_PHYSICAL_PROCESS_MAP
return PROCESSOR_TO_PHYSICAL_PROCESS_MAP
def set_processor_to_physical_process_map(mapping):
global PROCESSOR_TO_PHYSICAL_PROCESS_MAP
PROCESSOR_TO_PHYSICAL_PROCESS_MAP = mapping
PROCESS_GROUP_MAP = {}
def get_all_process_groups():
global PROCESS_GROUP_MAP
return PROCESS_GROUP_MAP.values()
def new_process_group(ranks):
global PROCESS_GROUP_MAP
if not PROCESS_GROUP_MAP:
genv = _get_global_env()
PROCESS_GROUP_MAP["global_group"] = ProcessGroup(
0, list(range(genv.world_size)))
# A key constructed from ranks is used in the global process group map
key = ''.join(map(str, sorted(ranks)))
if key not in PROCESS_GROUP_MAP:
num_groups = len(PROCESS_GROUP_MAP)
# Note: our process group may interfere with the original implementation
# so the created group id should start from the original _new_ring_id()
group_id = _new_ring_id() + num_groups + 1
pg = ProcessGroup(group_id, ranks)
PROCESS_GROUP_MAP[key] = pg
return pg
else:
pg = PROCESS_GROUP_MAP[key]
return pg
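# A paddle-free sketch of the caching behaviour of new_process_group above: groups are
# keyed by their sorted rank list, so requesting the same set of ranks twice yields the
# same cached object. The real function also seeds a "global_group" entry and offsets
# group ids from _new_ring_id(); that is omitted here. Names (_group_map, _new_group)
# are illustrative only.
_group_map = {}

def _new_group(ranks):
    key = ','.join(map(str, sorted(ranks)))
    if key not in _group_map:
        _group_map[key] = {"id": len(_group_map) + 1, "ranks": sorted(ranks)}
    return _group_map[key]

g1 = _new_group([2, 0])
g2 = _new_group([0, 2])
assert g1 is g2                      # same ranks -> same cached group
assert g1["ranks"] == [0, 2]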
# This implementation borrows heavily from Paddle/python/paddle/distributed/collective.py.
# Fleet also has a collective helper which uses ops to initialize communication in
# Paddle/python/paddle/distributed/fleet/meta_optimizers/common.py. We use the former
# because it is simpler. This should be enhanced to manage process membership and
# the instantiation process in a more general way. In the future, the process group may
# also handle the choice of communication implementation.
class ProcessGroup:
def __init__(self, group_id, ranks):
self._group_id = group_id
self._ranks = sorted(ranks)
self._nranks = len(self._ranks)
self._is_instantiate = False
@property
def id(self):
return self._group_id
# @property
# def key(self):
# return ''.join(map(str, sorted(self._ranks)))
def local_rank(self, global_rank):
if global_rank in self._ranks:
return self._ranks.index(global_rank)
else:
assert False, \
"Rank {} doesn't belong to this group".format(global_rank)
def is_instantiate(self):
return self._is_instantiate
def instantiate(self):
if self._is_instantiate:
return
ring_id = self.id
genv = _get_global_env()
global_rank = genv.rank
if self._nranks >= 2:
strategy = core.ParallelStrategy()
strategy.nranks = self._nranks
strategy.local_rank = self.local_rank(global_rank)
strategy.trainer_endpoints = [
genv.trainer_endpoints[i] for i in self._ranks
]
strategy.current_endpoint = genv.current_endpoint
strategy.nrings = 1
if core.is_compiled_with_cuda():
place = core.CUDAPlace(genv.device_id)
core.NCCLParallelContext(strategy,
place).init_with_ring_id(ring_id)
else:
assert False, ("No CUDA device found")
# TODO(shenliang03): This is a temporary solution to solve the problem of
# hang caused by cross-creation of new_group
tmp = paddle.to_tensor(
[1], dtype="int32") if in_dygraph_mode() else fill_constant(
[0], dtype="int32", value="1")
paddle.distributed.all_reduce(tmp, use_calc_stream=True)
paddle.distributed.wait(tmp)
self._is_instantiate = True
def __str__(self):
string = "id: {}, nranks: {}, ranks: {}.".format(
self.id, self._nranks, ", ".join(map(str, self._ranks)))
return string
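# A minimal usage sketch of the ProcessGroup class defined above: construction only
# records the (sorted) ranks, and no communicator is created until instantiate() is
# called, so this bookkeeping runs on a single process without a launched job.
pg = ProcessGroup(group_id=1, ranks=[4, 7, 6])
assert pg.id == 1
assert pg.local_rank(6) == 1         # ranks are stored sorted: [4, 6, 7]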
......@@ -14,6 +14,7 @@
import threading
import paddle.fluid.core as core
import numpy as np
def is_valid_list_index(list, index):
......@@ -155,3 +156,125 @@ def print_program_with_distributed_attr(program, dist_context=None):
print(program)
set_default_distributed_context(original_default_context)
lock.release()
def _get_comm_group(processes, shape, axis, rank):
"""
Given a rank and the process mesh the rank belongs to,
compute the communication peers of the rank along the given axis in the mesh.
Example: 16 processes managed in a 4-dimensional mesh with shape [2, 2, 2, 2].
The communication peers of rank 0 (itself included) are:
in axis 0: [0, 8]
in axis 1: [0, 4]
in axis 2: [0, 2]
in axis 3: [0, 1]
"""
# NOTE _linear_idx2coordinate assumes the process mesh starts at 0 and is contiguous;
# the trick below supports process meshes that do not start at 0 or are not contiguous
rank_relative = processes.index(rank)
coordinate = _linear_idx2coordinate(shape, rank_relative)
coordinates_in_group = [coordinate[:] for i in range(shape[axis])]
# select comm group
for i in range(shape[axis]):
coordinates_in_group[i][axis] = i
ranks_in_group_relative = [
_coordinate2linear_idx(shape, coordinate)
for coordinate in coordinates_in_group
]
ranks_in_group = [processes[idx] for idx in ranks_in_group_relative]
return sorted(ranks_in_group)
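# A usage sketch matching the docstring example above: 16 processes in a [2, 2, 2, 2]
# mesh, peers of rank 0 along each axis (row-major order, the last mesh dim is
# contiguous). Assumes the helpers in this file are importable, e.g. as
# paddle.distributed.auto_parallel.utils (module path is an assumption).
processes = list(range(16))
shape = [2, 2, 2, 2]
assert _get_comm_group(processes, shape, axis=0, rank=0) == [0, 8]
assert _get_comm_group(processes, shape, axis=1, rank=0) == [0, 4]
assert _get_comm_group(processes, shape, axis=2, rank=0) == [0, 2]
assert _get_comm_group(processes, shape, axis=3, rank=0) == [0, 1]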
def _coordinate2linear_idx(mesh_shape, coordinate):
"""
convert a coordinate in multidimensional mesh space into a scala idx in linear space.
it use Row-major order for dimension conversion.
so it has: [most_significant_dim, ..., least_significant_dim]
assume:
the size of i-th dimension to be: S[i]
the index of j-th dimension is: I[j]
linear_idx of a n dimensional coordinate is:
I[n-1] * (S[n-2] * S[n-3] * S[n-4] * .... S[0]) +
I[n-2] * ( S[n-3] * S[n-4] * .... S[0]) +
I[n-3] * ( S[n-4] * .... S[0]) +
...
I[1] * ( S[0]) +
I[0]
"""
# NOTE the following function works under a strong assumption:
# the processes in the mesh
# 1. start from 0
# 2. are contiguous
# it will be wrong if the above conditions are not met,
# e.g. process_mesh = { process_groups = [7, 8, 9, 10, 12, 13, 14, 15], mesh = [2, 4]}
# if you want a more general mapping, you should use the cartesian product
assert len(mesh_shape) == len(
coordinate
), "coordinate should have the same size as mesh shape, but got shape: {}, coordinate: {}".format(
mesh_shape, coordinate)
for i in range(len(mesh_shape)):
assert coordinate[
i] >= 0, "index in dimension [{}] is less than zero. coordinate: {}".format(
i, coordinate)
assert coordinate[i] < mesh_shape[
i], "index beyond extent in dimension [{}]. shape: {}, coordinate: {}".format(
i, mesh_shape, coordinate)
base = mesh_shape[-1]
linear_idx = coordinate[-1]
# row major order
for i in range(len(mesh_shape) - 2, -1, -1):
linear_idx += base * coordinate[i]
base *= mesh_shape[i]
return linear_idx
def _linear_idx2coordinate(mesh_shape, linear_idx):
"""
mapping a linear scala into multidimensional mesh space, return it coordinate in that space.
it is the inverse function of _coordinate2linear_idx.
assume:
the size of i-th dimension to be: S[i]
the index of j-th dimension is: I[j]
the coordinate given linear_idx is:
I[0] = linear_idx % S[0]
I[0] = (linear_idx / S[0]) % S[1]
I[0] = (linear_idx / (S[0] * S[1])) % S[2]
....
"""
assert linear_idx >= 0, "linear index [{}] is least than zero".format(
linear_idx)
assert linear_idx < np.prod(
mesh_shape
), "linear index beyond the extent of mesh shape. shape: {}, linear index: {}".format(
mesh_shape, linear_idx)
base = 1
coordinate = [-1] * len(mesh_shape)
for i in reversed(range(len(mesh_shape))):
offset = linear_idx // base
coordinate[i] = int(offset % mesh_shape[i])
base *= mesh_shape[i]
# row major order
return coordinate
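# A quick round-trip check of the two index helpers above for a [2, 4] mesh:
# coordinate [1, 2] maps to linear index 1 * 4 + 2 = 6, and back again.
mesh = [2, 4]
assert _coordinate2linear_idx(mesh, [1, 2]) == 6
assert _linear_idx2coordinate(mesh, 6) == [1, 2]
# every linear index should survive the round trip
for idx in range(8):
    assert _coordinate2linear_idx(mesh, _linear_idx2coordinate(mesh, idx)) == idx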
......@@ -79,6 +79,8 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer_base)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_distributed_strategy)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_static_mp_layers)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner_gpt)
foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
list(REMOVE_ITEM TEST_OPS ${TEST_OP})
endforeach()
......@@ -206,6 +208,8 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
LIST(REMOVE_ITEM TEST_OPS test_dygraph_recompute)
list(REMOVE_ITEM TEST_OPS test_parallel_class_center_sample)
LIST(REMOVE_ITEM TEST_OPS test_parallel_margin_cross_entropy)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt)
elseif(WITH_GPU)
if (${CUDNN_VERSION} VERSION_LESS 7100)
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
......