Unverified commit a622b701 authored by JZ-LIANG, committed by GitHub

[Auto Parallel] Logical Partition & Dist Op (#35117)

* support shard reader

* support shard reader

* add parallel mode

* update process mesh

* add method to compute comm_group

* implement dist_embedding forward func

* implement dist matmul forward func

* implement dist reshape forward func

* add transpiler framework

* add transpiler forward

* implement transpiler forward

* implement transpiler backward & update

* add process

* add unitest

* chmod

* chmod

* chmod

* update unitest

* add unitest for gpt

* remove unused print

* rename transpiler --> partitioner

* rename transpiler --> partitioner

* chmod

* chmod

* bug fixed

* remove amp function

* update case for dp mode

* update case for dp mode
Parent 280d7421
@@ -14,6 +14,7 @@
import copy
from collections import defaultdict
from paddle.fluid import core
class TensorDistributedAttribute:
@@ -77,6 +78,8 @@ class TensorDistributedAttribute:
self._is_parameter = True
def is_valid(self):
if self.get_owner_tensor().type == core.VarDesc.VarType.READER:
return True
tensor_shape = self.get_owner_tensor().desc.shape()
if len(tensor_shape) != len(self.get_dims_mapping()):
return False
@@ -222,6 +225,8 @@ class OperatorDistributedAttribute:
self._is_parameters[name] = True
def is_valid(self):
if "read" in self.get_owner_op().type:
return True
for name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(name)
shape = self.get_input_shape(name)
...
@@ -15,9 +15,11 @@
import copy
from collections import defaultdict
from paddle.fluid import framework
from paddle.fluid import core
from .attribute import TensorDistributedAttribute
from .attribute import OperatorDistributedAttribute
from .utils import append_distributed_attr_suffix
from .interface import _g_process_mesh_map
# There always exists a default context for user. And user can set it to another one.
DEFAULT_DISTRIBUTED_CONTEXT = None
@@ -49,6 +51,20 @@ class DistributedContext:
self._op_distributed_attr_map_for_program = {}
self._tensor_distributed_attr_map_for_graph = {}
self._op_distributed_attr_map_for_graph = {}
# The following is hard-coded and will be removed in the future
self._data_parallel_axis = None
self._model_parallel_axis = None
self._process_mesh = _g_process_mesh_map.get(0, None)
if self._process_mesh is not None:
if self._process_mesh.ndim == 1:
self._data_parallel_axis = 0
self._model_parallel_axis = 0
else:
self._data_parallel_axis = 0
self._model_parallel_axis = 1
else:
self._data_parallel_axis = -1
self._model_parallel_axis = -1
def is_initialized_for_program(self):
return self._is_initialized_for_program
@@ -99,6 +115,19 @@ class DistributedContext:
op_node_id = op_node.id()
self._op_distributed_attr_map_for_graph[op_node_id] = op_dist_attr
def set_process_mesh(self, process_mesh):
self._process_mesh = process_mesh
if self._process_mesh is not None:
if self._process_mesh.ndim == 1:
self._data_parallel_axis = 0
self._model_parallel_axis = 0
else:
self._data_parallel_axis = 0
self._model_parallel_axis = 1
else:
self._data_parallel_axis = -1
self._model_parallel_axis = -1
def initialize_distributed_attr_for_program(self, program):
if self._is_initialized_for_program:
return
@@ -377,3 +406,11 @@ class DistributedContext:
if dims_mapping[i] != -1 and process_mesh_shape[
dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
def _get_data_parallel_info(self):
# This function is hard-coded and will be removed in the future
return self._data_parallel_axis, self._process_mesh
def _get_model_parallel_info(self):
# This function is hard-coded and will be removed in the future
return self._model_parallel_axis, self._process_mesh
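The hard-coded convention above is: a 1-D process mesh uses axis 0 for both data and model parallelism, a 2-D mesh uses axis 0 for data parallelism and axis 1 for model parallelism, and -1 means no mesh has been registered yet. A minimal standalone sketch of that rule (the helper name is hypothetical and only mirrors the logic in DistributedContext):

def _infer_parallel_axes(process_mesh_ndim):
    # returns (data_parallel_axis, model_parallel_axis)
    if process_mesh_ndim is None:
        return -1, -1
    if process_mesh_ndim == 1:
        return 0, 0
    return 0, 1

assert _infer_parallel_axes(1) == (0, 0)
assert _infer_parallel_axes(2) == (0, 1)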
@@ -184,6 +184,13 @@ class ProcessMesh(object):
"parent with id %d does not exist." % self._parent_id)
return _g_process_mesh_map[self._parent_id]
@property
def ndim(self):
r"""
Get the number of dimensions of the ProcessMesh.
"""
return len(self._topology)
def set_placement(self, order):
"""
Set the map from logical processes to physical ones using the
@@ -229,6 +236,13 @@ class ProcessMesh(object):
for idx, l_id in enumerate(logical_order):
_user_defined_physical_map[l_id] = order[idx]
def _reset_global_process_mesh_map(self):
"""
Remove all process meshes from _g_process_mesh_map, making it empty.
"""
global _g_process_mesh_map
_g_process_mesh_map = dict()
def __eq__(self, other):
assert other and isinstance(other, ProcessMesh)
if self.topology != other.topology or self.process_group != other.process_group:
...
@@ -33,6 +33,8 @@ class DistributedOperator:
class DistributedOperatorImpl:
def __init__(self):
self._name = None
self._forward_implemented = False
self._backward_implemented = False
def forward(self, dist_ctx, *args, **kwargs):
raise NotImplementedError("Please Implement this method in Subclass.")
...
@@ -22,6 +22,12 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from ..process import new_process_group
from ..utils import _get_comm_group
class DistributedEmbedding(DistributedOperator):
@@ -39,6 +45,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedEmbeddingImpl, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
@@ -92,6 +100,110 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 2, "row_parallel_embedding takes 2 input variables but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 1, "row_parallel_embedding takes 1 output variable but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['Ids']
) == 1, "row_parallel_embedding input Ids takes 1 variable but got {}".format(
input_name_mapping['Ids'])
assert len(
input_name_mapping['W']
) == 1, "row_parallel_embedding input W takes 1 variable but got {}".format(
input_name_mapping['W'])
assert len(
output_name_mapping['Out']
) == 1, "row_parallel_embedding output Out takes 1 variable but got {}".format(
output_name_mapping['Out'])
Ids_var = dst_block.var(input_name_mapping['Ids'][0])
Weight_var = dst_block.var(input_name_mapping['W'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
# get dist attribute info
embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[0]
process_mesh_shape = op_dist_attr.get_process_mesh().topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group
# calculate embedding offset
# TODO generalize here, using cartesian product to allow any dimensional mesh shape
mesh_shape = len(process_mesh_shape)
assert mesh_shape <= 2, "row_parallel_embedding only supports 1 or 2 dimensional process mesh, but got {}".format(
process_mesh_shape)
num_partition = process_mesh_shape[embedding_row_dim_mapping]
# TODO generalize here, support any mesh group
if mesh_shape == 1:
relative_idx = process_mesh_group.index(rank_id)
else:
relative_idx = rank_id % num_partition
per_part_size = Weight_var.shape[0]
relative_idx = relative_idx * per_part_size
# TODO calculate ring id
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, rank_id)
group = new_process_group(group_ranks)
# append op
check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'],
'c_embedding')
intermediate_var_0 = dst_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_embedding", 'tmp'])),
dtype=Weight_var.dtype,
shape=Out_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_var.stop_gradient)
check_variable_and_dtype(
Out_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'c_allreduce_sum')
dst_block.append_op(
type='c_embedding',
inputs={'Ids': [Ids_var],
'W': [Weight_var]},
outputs={'Out': [intermediate_var_0]},
attrs={"start_index": relative_idx})
# use_model_parallel
dst_block.append_op(
type='c_allreduce_sum',
inputs={'X': [intermediate_var_0]},
outputs={'Out': [Out_var]},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
})
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"lookup_table_v2", 0))
else:
return static_handle
register_distributed_operator_impl("lookup_table_v2", register_distributed_operator_impl("lookup_table_v2",
DistributedEmbeddingImpl("row_parallel")) DistributedEmbeddingImpl("row_parallel"))
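The forward above shards the embedding table by rows: each rank along the model-parallel axis holds per_part_size rows and shifts the lookup ids by its starting row (the start_index attribute of c_embedding) before the partial results are combined with c_allreduce_sum. A minimal sketch of that offset arithmetic (hypothetical helper and numbers, not part of this commit):

def _embedding_start_index(rank_id, process_mesh_group, process_mesh_shape,
                           row_axis, per_part_size):
    # mirrors the relative_idx computation in the forward above
    num_partition = process_mesh_shape[row_axis]
    if len(process_mesh_shape) == 1:
        relative_idx = process_mesh_group.index(rank_id)
    else:
        relative_idx = rank_id % num_partition
    return relative_idx * per_part_size

# e.g. a [2, 2] mesh with the table rows split over axis 1 into parts of 512 rows:
assert _embedding_start_index(3, [0, 1, 2, 3], [2, 2], 1, 512) == 512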
@@ -22,6 +22,12 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from ..process import new_process_group
from ..utils import _get_comm_group
def _update_dims_mapping_for_matmul(op_dist_attr):
@@ -37,7 +43,6 @@ def _update_dims_mapping_for_matmul(op_dist_attr):
y_dims_mapping_len = len(y_dims_mapping)
out_dims_mapping_len = len(out_dims_mapping)
# print("before", x_dims_mapping, y_dims_mapping, out_dims_mapping)
# Add dim mapping to Make sure the length dims_mapping be at least 2
if x_dims_mapping_len == 1:
x_dims_mapping.insert(0, -1)
@@ -109,7 +114,6 @@ def _update_dims_mapping_for_matmul(op_dist_attr):
if y_dims_mapping_len == 1:
y_dims_mapping.pop(1)
# print("after", x_dims_mapping, y_dims_mapping, out_dims_mapping)
assert len(x_dims_mapping) == x_dims_mapping_len
assert len(y_dims_mapping) == y_dims_mapping_len
assert len(out_dims_mapping) == out_dims_mapping_len
@@ -131,6 +135,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl0, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
@@ -170,12 +176,101 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
changed = True
return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 2, "col_parallel_linear takes 2 input variables but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 1, "col_parallel_linear takes 1 output variable but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['X']
) == 1, "col_parallel_linear input X takes 1 variable but got {}".format(
input_name_mapping['X'])
assert len(
input_name_mapping['Y']
) == 1, "col_parallel_linear input Y takes 1 variable but got {}".format(
input_name_mapping['Y'])
assert len(
output_name_mapping['Out']
) == 1, "col_parallel_linear output Out takes 1 variable but got {}".format(
output_name_mapping['Out'])
X_var = dst_block.var(input_name_mapping['X'][0])
Weight_var = dst_block.var(input_name_mapping['Y'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
# TODO infer logical comm representation
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, rank_id)
group = new_process_group(group_ranks)
intermediate_var_0 = dst_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_identity", 'tmp'])),
dtype=X_var.dtype,
shape=X_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=X_var.stop_gradient)
check_variable_and_dtype(
X_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_identity')
dst_block.append_op(
type='c_identity',
inputs={'X': [X_var]},
outputs={'Out': intermediate_var_0},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
})
check_variable_and_dtype(intermediate_var_0, 'x',
['float16', 'float32', 'float64'],
'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
attrs = {
'transpose_X': False,
'transpose_Y': False,
'alpha': 1,
}
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
dst_block.append_op(
type='matmul',
inputs=inputs,
outputs={'Out': Out_var},
attrs=attrs)
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"matmul", 0))
else:
return static_handle
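For the column-parallel case above, c_identity is a no-op in the forward pass (it becomes an allreduce of the input gradient in the backward pass), and the local matmul against a column shard of the weight already yields a disjoint slice of the output, so no forward communication of the result is needed. A small numpy sketch of that property (illustrative shapes only, not part of this commit):

import numpy as np

X = np.random.rand(4, 8)            # replicated activation
W = np.random.rand(8, 6)            # weight, columns sharded over 2 ranks
W_shards = np.split(W, 2, axis=1)   # each rank holds an (8, 3) shard

out_shards = [X @ w for w in W_shards]  # per-rank local matmul
assert np.allclose(np.concatenate(out_shards, axis=1), X @ W)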
# RowParallel
class DistributedMatmulImpl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl1, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
@@ -217,6 +312,86 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
changed = True
return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 2, "row_parallel_linear takes 2 input variables but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 1, "row_parallel_linear takes 1 output variable but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['X']
) == 1, "row_parallel_linear input X takes 1 variable but got {}".format(
input_name_mapping['X'])
assert len(
input_name_mapping['Y']
) == 1, "row_parallel_linear input Y takes 1 variable but got {}".format(
input_name_mapping['Y'])
assert len(
output_name_mapping['Out']
) == 1, "row_parallel_linear output Out takes 1 variable but got {}".format(
output_name_mapping['Out'])
X_var = dst_block.var(input_name_mapping['X'][0])
Weight_var = dst_block.var(input_name_mapping['Y'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
# TODO infer logical comm representation
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, rank_id)
group = new_process_group(group_ranks)
check_variable_and_dtype(
X_var, 'x', ['float16', 'float32', 'float64'], 'linear')
check_dtype(X_var.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
attrs = {
'transpose_X': False,
'transpose_Y': False,
'alpha': 1,
}
inputs = {'X': X_var, 'Y': Weight_var}
intermediate_var_0 = dst_block.create_var(
shape=Out_var.shape,
dtype=Out_var.dtype,
type=Out_var.type,
lod_level=Out_var.lod_level,
persistable=False,
is_data=False,
need_check_feed=Out_var.desc.need_check_feed())
dst_block.append_op(
type='matmul',
inputs=inputs,
outputs={'Out': intermediate_var_0},
attrs=attrs)
dst_block.append_op(
type='c_allreduce_sum',
inputs={'X': intermediate_var_0},
outputs={'Out': Out_var},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True
})
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"matmul", 0))
else:
return static_handle
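For the row-parallel case above, each rank multiplies its activation shard by a row shard of the weight into a temporary variable, and c_allreduce_sum adds the partial products into Out. A small numpy sketch of that identity (illustrative shapes only, not part of this commit):

import numpy as np

X = np.random.rand(4, 8)
W = np.random.rand(8, 6)
X_shards = np.split(X, 2, axis=1)   # each rank holds a (4, 4) activation shard
W_shards = np.split(W, 2, axis=0)   # each rank holds a (4, 6) weight shard

partial = [x @ w for x, w in zip(X_shards, W_shards)]
assert np.allclose(sum(partial), X @ W)   # what the allreduce_sum produces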
# ReplicateParallel
class DistributedMatmulImpl2(DistributedOperatorImpl):
...
@@ -22,6 +22,10 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
class DistributedReshape2(DistributedOperator):
@@ -37,6 +41,8 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReshapeImpl0, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
@@ -91,11 +97,90 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 3, "Dist op of Reshape takes 3 input variables but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 2, "Dist op of Reshape takes 2 output variables but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['X']
) == 1, "Dist op of Reshape input X takes 1 variable but got {}".format(
input_name_mapping['X'])
assert len(
input_name_mapping['ShapeTensor']
) <= 1, "Dist op of Reshape input ShapeTensor takes 0 or 1 variable but got {}".format(
input_name_mapping['ShapeTensor'])
assert len(
input_name_mapping['Shape']
) <= 1, "Dist op of Reshape input Shape takes 0 or 1 variable but got {}".format(
input_name_mapping['Shape'])
assert len(
output_name_mapping['Out']
) == 1, "Dist op of Reshape output Out takes 1 variable but got {}".format(
output_name_mapping['Out'])
assert len(
output_name_mapping['XShape']
) == 1, "Dist op of Reshape output XShape takes 1 variable but got {}".format(
output_name_mapping['XShape'])
X_var = dst_block.var(input_name_mapping['X'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
XShape_var = dst_block.var(output_name_mapping['XShape'][0])
shape_list = src_op.desc.attr("shape")
ShapeTensor_var_list = []
for name in input_name_mapping['ShapeTensor']:
ShapeTensor_var_list.append(name)
Shape_var_list = []
for name in input_name_mapping['Shape']:
Shape_var_list.append(name)
# get dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.get_process_mesh().topology
# modify target shape
for idx, axis in enumerate(dim_mapping):
if axis >= 0:
if len(shape_list) > idx:
shape_list[idx] = shape_list[idx] // process_mesh_shape[
axis]
# create op
new_op_desc = dst_block.desc.append_op()
new_op_desc.copy_from(src_op.desc)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
new_op_desc.set_input('Shape', Shape_var_list)
new_op_desc.set_input('X', [X_var.name])
new_op_desc.set_output('XShape', [XShape_var.name])
new_op_desc.set_output('Out', [Out_var.name])
new_op_desc._set_attr('shape', shape_list)
dst_block._sync_with_cpp()
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"reshape2", 0))
else:
return static_handle
class DistributedReshapeImpl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReshapeImpl1, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
@@ -150,6 +235,83 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 3, "Dist op of Reshape takes 3 input variables but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 2, "Dist op of Reshape takes 2 output variables but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['X']
) == 1, "Dist op of Reshape input X takes 1 variable but got {}".format(
input_name_mapping['X'])
assert len(
input_name_mapping['ShapeTensor']
) <= 1, "Dist op of Reshape input ShapeTensor takes 0 or 1 variable but got {}".format(
input_name_mapping['ShapeTensor'])
assert len(
input_name_mapping['Shape']
) <= 1, "Dist op of Reshape input Shape takes 0 or 1 variable but got {}".format(
input_name_mapping['Shape'])
assert len(
output_name_mapping['Out']
) == 1, "Dist op of Reshape output Out takes 1 variable but got {}".format(
output_name_mapping['Out'])
assert len(
output_name_mapping['XShape']
) == 1, "Dist op of Reshape output XShape takes 1 variable but got {}".format(
output_name_mapping['XShape'])
X_var = dst_block.var(input_name_mapping['X'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
XShape_var = dst_block.var(output_name_mapping['XShape'][0])
shape_list = src_op.desc.attr("shape")
ShapeTensor_var_list = []
for name in input_name_mapping['ShapeTensor']:
ShapeTensor_var_list.append(name)
Shape_var_list = []
for name in input_name_mapping['Shape']:
Shape_var_list.append(name)
# get dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.get_process_mesh().topology
# modify target shape
for idx, axis in enumerate(dim_mapping):
if axis >= 0:
if len(shape_list) > idx:
shape_list[idx] = shape_list[idx] // process_mesh_shape[
axis]
# create op
new_op_desc = dst_block.desc.append_op()
new_op_desc.copy_from(src_op.desc)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
new_op_desc.set_input('Shape', Shape_var_list)
new_op_desc.set_input('X', [X_var.name])
new_op_desc.set_output('XShape', [XShape_var.name])
new_op_desc.set_output('Out', [Out_var.name])
new_op_desc._set_attr('shape', shape_list)
dst_block._sync_with_cpp()
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"reshape2", 0))
else:
return static_handle
register_distributed_operator_impl("reshape2", register_distributed_operator_impl("reshape2",
DistributedReshapeImpl0("add_one_dim_back")) DistributedReshapeImpl0("add_one_dim_back"))
......
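Both reshape implementations above keep the original reshape2 op and only shrink the target shape: every dimension whose dims_mapping entry is >= 0 is divided by the mesh extent of the mapped axis, so the attribute describes the local tensor. A minimal sketch of that adjustment (hypothetical helper, not part of this commit):

def _local_target_shape(shape_list, dims_mapping, process_mesh_shape):
    # mirrors the shape_list modification in the forward methods above
    shape_list = list(shape_list)
    for idx, axis in enumerate(dims_mapping):
        if axis >= 0 and len(shape_list) > idx:
            shape_list[idx] = shape_list[idx] // process_mesh_shape[axis]
    return shape_list

# e.g. a global target shape [16, 1024, 64] sharded on axis 1 of a [2, 4] mesh:
assert _local_target_shape([16, 1024, 64], [-1, 1, -1], [2, 4]) == [16, 256, 64]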
@@ -47,7 +47,6 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
x_name = op_desc.input('X')[0]
axis = op_desc.attr('axis')
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
# print("softmax axis", axis)
if axis != -1 and axis != len(x_dims_mapping) - 1:
return False
...
This diff is collapsed.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import paddle
import paddle.fluid.core as core
from ..collective import _get_global_env
from ..collective import _new_ring_id
from ...fluid.framework import in_dygraph_mode
from ...fluid.layers.tensor import fill_constant
LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = None
PROCESSOR_TO_PHYSICAL_PROCESS_MAP = None
def get_all_logical_process_set():
from .interface import _g_process_mesh_map
all_logical_process_set = set(_g_process_mesh_map[0].process_group)
return all_logical_process_set
def get_logical_process_to_physical_process_map():
global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
return LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
def set_logical_process_to_physical_process_map(mapping):
global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = mapping
def get_processor_to_physical_process_map():
global PROCESSOR_TO_PHYSICAL_PROCESS_MAP
return PROCESSOR_TO_PHYSICAL_PROCESS_MAP
def set_processor_to_physical_process_map(mapping):
global PROCESSOR_TO_PHYSICAL_PROCESS_MAP
PROCESSOR_TO_PHYSICAL_PROCESS_MAP = mapping
PROCESS_GROUP_MAP = {}
def get_all_process_groups():
global PROCESS_GROUP_MAP
return PROCESS_GROUP_MAP.values()
def new_process_group(ranks):
global PROCESS_GROUP_MAP
if not PROCESS_GROUP_MAP:
genv = _get_global_env()
PROCESS_GROUP_MAP["global_group"] = ProcessGroup(
0, list(range(genv.world_size)))
# A key constructed from ranks is used in the global process group map
key = ''.join(map(str, sorted(ranks)))
if key not in PROCESS_GROUP_MAP:
num_groups = len(PROCESS_GROUP_MAP)
# Note: our process group may interfere with the original implementation
# so the created group id should start from the original _new_ring_id()
group_id = _new_ring_id() + num_groups + 1
pg = ProcessGroup(group_id, ranks)
PROCESS_GROUP_MAP[key] = pg
return pg
else:
pg = PROCESS_GROUP_MAP[key]
return pg
# This implementation borrows a lot from Paddle/python/paddle/distributed/collective.py;
# Fleet also has a collective helper which uses ops to initialize communication in
# Paddle/python/paddle/distributed/fleet/meta_optimizers/common.py. We use the first one
# because it seems simple. This should be enhanced to manage the process membership and
# the instantiation process in a more general way. In the future, the process group may
# handle the communication implementation choice.
class ProcessGroup:
def __init__(self, group_id, ranks):
self._group_id = group_id
self._ranks = sorted(ranks)
self._nranks = len(self._ranks)
self._is_instantiate = False
@property
def id(self):
return self._group_id
# @property
# def key(self):
# return ''.join(map(str, sorted(self._ranks)))
def local_rank(self, global_rank):
if global_rank in self._ranks:
return self._ranks.index(global_rank)
else:
assert False, \
"Rank {} doesn't belong to this group".format(global_rank)
def is_instantiate(self):
return self._is_instantiate
def instantiate(self):
if self._is_instantiate:
return
ring_id = self.id
genv = _get_global_env()
global_rank = genv.rank
if self._nranks >= 2:
strategy = core.ParallelStrategy()
strategy.nranks = self._nranks
strategy.local_rank = self.local_rank(global_rank)
strategy.trainer_endpoints = [
genv.trainer_endpoints[i] for i in self._ranks
]
strategy.current_endpoint = genv.current_endpoint
strategy.nrings = 1
if core.is_compiled_with_cuda():
place = core.CUDAPlace(genv.device_id)
core.NCCLParallelContext(strategy,
place).init_with_ring_id(ring_id)
else:
assert False, ("No CUDA device found")
# TODO(shenliang03): This is a temporary solution to solve the problem of
# hang caused by cross-creation of new_group
tmp = paddle.to_tensor(
[1], dtype="int32") if in_dygraph_mode() else fill_constant(
[0], dtype="int32", value="1")
paddle.distributed.all_reduce(tmp, use_calc_stream=True)
paddle.distributed.wait(tmp)
self._is_instantiate = True
def __str__(self):
string = "id: {}, nranks: {}, ranks: {}.".format(
self.id, self._nranks, ", ".join(map(str, self._ranks)))
return string
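new_process_group above deduplicates groups by a key built from the sorted ranks and offsets new ring ids past _new_ring_id() so they do not collide with groups created elsewhere. A minimal sketch of the dedup behaviour (a standalone toy map, not the real ProcessGroup):

_demo_group_map = {}

def _demo_new_group(ranks, base_ring_id=0):
    # same keying scheme as new_process_group: sorted ranks joined as a string
    key = ''.join(map(str, sorted(ranks)))
    if key not in _demo_group_map:
        group_id = base_ring_id + len(_demo_group_map) + 1
        _demo_group_map[key] = (group_id, sorted(ranks))
    return _demo_group_map[key]

g1 = _demo_new_group([2, 0])
g2 = _demo_new_group([0, 2])
assert g1 is g2 and g1[1] == [0, 2]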
@@ -14,6 +14,7 @@
import threading
import paddle.fluid.core as core
import numpy as np
def is_valid_list_index(list, index):
@@ -155,3 +156,125 @@ def print_program_with_distributed_attr(program, dist_context=None):
print(program)
set_default_distributed_context(original_default_context)
lock.release()
def _get_comm_group(processes, shape, axis, rank):
"""
Given a rank and the process mesh the rank belongs to,
compute the communication peers of the rank based on the given axis of the mesh.
Example: 16 processes managed in a 4-dimensional mesh with shape [2, 2, 2, 2].
the communication peers of rank 0 (included) are the following:
in axis 0: [0, 8]
in axis 1: [0, 4]
in axis 2: [0, 2]
in axis 3: [0, 1]
"""
# NOTE _linear_idx2coordinate assumes the process mesh starts from 0 and is continuous
# trick to support process meshes that do not start from 0 or are not continuous
rank_relative = processes.index(rank)
coordinate = _linear_idx2coordinate(shape, rank_relative)
coordinates_in_group = [coordinate[:] for i in range(shape[axis])]
# select comm group
for i in range(shape[axis]):
coordinates_in_group[i][axis] = i
ranks_in_group_relative = [
_coordinate2linear_idx(shape, coordinate)
for coordinate in coordinates_in_group
]
ranks_in_group = [processes[idx] for idx in ranks_in_group_relative]
return sorted(ranks_in_group)
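Assuming the helpers in this file are importable, a quick usage sketch of _get_comm_group for the docstring's 16-process [2, 2, 2, 2] mesh (row-major ordering, so axis 0 has the largest stride):

# from paddle.distributed.auto_parallel.utils import _get_comm_group  # assumed path
processes = list(range(16))
shape = [2, 2, 2, 2]
assert _get_comm_group(processes, shape, 0, 0) == [0, 8]
assert _get_comm_group(processes, shape, 1, 0) == [0, 4]
assert _get_comm_group(processes, shape, 2, 0) == [0, 2]
assert _get_comm_group(processes, shape, 3, 0) == [0, 1]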
def _coordinate2linear_idx(mesh_shape, coordinate):
"""
convert a coordinate in multidimensional mesh space into a scalar idx in linear space.
it uses row-major order for dimension conversion.
so it has: [most_significant_dim, ..., least_significant_dim]
assume:
the size of i-th dimension to be: S[i]
the index of j-th dimension is: I[j]
linear_idx of a n dimensional coordinate is:
I[0] * (S[1] * S[2] * .... * S[n-1]) +
I[1] * (S[2] * .... * S[n-1]) +
I[2] * (S[3] * .... * S[n-1]) +
...
I[n-2] * (S[n-1]) +
I[n-1]
"""
# NOTE the following function works based on a strong assumption
# that the processes in the mesh are
# 1. starting from 0
# 2. continuous
# it will be wrong if the above conditions are not met,
# e.g. process_mesh = { process_groups = [7, 8, 9, 10, 12, 13, 14, 15], mesh = [2, 4]}
# if you want a more general mapping, you should use cartesian product
assert len(mesh_shape) == len(
coordinate
), "coordinate should have the same size as mesh shape, but got shape: {}, coordinate: {}".format(
mesh_shape, coordinate)
for i in range(len(mesh_shape)):
assert coordinate[
i] >= 0, "index in dimension [{}] is less than zero. coordinate: {}".format(
i, coordinate)
assert coordinate[i] < mesh_shape[
i], "index beyond extent in dimension [{}]. shape: {}, coordinate: {}".format(
i, mesh_shape, coordinate)
base = mesh_shape[-1]
linear_idx = coordinate[-1]
# row major order
for i in range(len(mesh_shape) - 2, -1, -1):
linear_idx += base * coordinate[i]
base *= mesh_shape[i]
return linear_idx
def _linear_idx2coordinate(mesh_shape, linear_idx):
"""
map a linear scalar idx into multidimensional mesh space and return its coordinate in that space.
it is the inverse function of _coordinate2linear_idx.
assume:
the size of i-th dimension to be: S[i]
the index of j-th dimension is: I[j]
the coordinate given linear_idx is:
I[n-1] = linear_idx % S[n-1]
I[n-2] = (linear_idx / S[n-1]) % S[n-2]
I[n-3] = (linear_idx / (S[n-1] * S[n-2])) % S[n-3]
....
"""
assert linear_idx >= 0, "linear index [{}] is less than zero".format(
linear_idx)
assert linear_idx < np.prod(
mesh_shape
), "linear index beyond the extent of mesh shape. shape: {}, linear index: {}".format(
mesh_shape, linear_idx)
base = 1
coordinate = [-1] * len(mesh_shape)
for i in reversed(range(len(mesh_shape))):
offset = linear_idx // base
coordinate[i] = int(offset % mesh_shape[i])
base *= mesh_shape[i]
# row major order
return coordinate
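_coordinate2linear_idx and _linear_idx2coordinate are inverses of each other under row-major ordering; a quick round-trip check over a small mesh:

import numpy as np

mesh_shape = [2, 3, 4]
for idx in range(int(np.prod(mesh_shape))):
    coord = _linear_idx2coordinate(mesh_shape, idx)
    assert _coordinate2linear_idx(mesh_shape, coord) == idx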
@@ -79,6 +79,8 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer_base)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_distributed_strategy)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_static_mp_layers)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner_gpt)
foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
list(REMOVE_ITEM TEST_OPS ${TEST_OP})
endforeach()
@@ -206,6 +208,8 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
LIST(REMOVE_ITEM TEST_OPS test_dygraph_recompute)
list(REMOVE_ITEM TEST_OPS test_parallel_class_center_sample)
LIST(REMOVE_ITEM TEST_OPS test_parallel_margin_cross_entropy)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt)
elseif(WITH_GPU)
if (${CUDNN_VERSION} VERSION_LESS 7100)
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
...