未验证 提交 a622b701 编写于 作者: J JZ-LIANG 提交者: GitHub

[Auto Parallel] Logical Partition & Dist Op (#35117)

* support shard reader

* support shard reader

* add parallel mode

* update process mesh

* add method to compute comm_group

* implement dist_embedding forward func

* implement dist matmul forward func

* implement dist reshape forward func

* add transpiler framework

* add transpiler forward

* implement transpiler forward

* implement transpiler backward & update

* add process

* add unitest

* chmod

* chmod

* chmod

* update unitest

* add unitest for gpt

* remove unused print

* rename transpiler --> partitioner

* rename transpiler --> partitioner

* chmod

* chmod

* bug fixed

* remove amp function

* update case for dp mode

* update case for dp mode
上级 280d7421
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import copy import copy
from collections import defaultdict from collections import defaultdict
from paddle.fluid import core
class TensorDistributedAttribute: class TensorDistributedAttribute:
...@@ -77,6 +78,8 @@ class TensorDistributedAttribute: ...@@ -77,6 +78,8 @@ class TensorDistributedAttribute:
self._is_parameter = True self._is_parameter = True
def is_valid(self): def is_valid(self):
if self.get_owner_tensor().type == core.VarDesc.VarType.READER:
return True
tensor_shape = self.get_owner_tensor().desc.shape() tensor_shape = self.get_owner_tensor().desc.shape()
if len(tensor_shape) != len(self.get_dims_mapping()): if len(tensor_shape) != len(self.get_dims_mapping()):
return False return False
...@@ -222,6 +225,8 @@ class OperatorDistributedAttribute: ...@@ -222,6 +225,8 @@ class OperatorDistributedAttribute:
self._is_parameters[name] = True self._is_parameters[name] = True
def is_valid(self): def is_valid(self):
if "read" in self.get_owner_op().type:
return True
for name in self.get_owner_op().desc.input_arg_names(): for name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(name) dims_mapping = self.get_input_dims_mapping(name)
shape = self.get_input_shape(name) shape = self.get_input_shape(name)
......
...@@ -15,9 +15,11 @@ ...@@ -15,9 +15,11 @@
import copy import copy
from collections import defaultdict from collections import defaultdict
from paddle.fluid import framework from paddle.fluid import framework
from paddle.fluid import core
from .attribute import TensorDistributedAttribute from .attribute import TensorDistributedAttribute
from .attribute import OperatorDistributedAttribute from .attribute import OperatorDistributedAttribute
from .utils import append_distributed_attr_suffix from .utils import append_distributed_attr_suffix
from .interface import _g_process_mesh_map
# There always exists a default context for user. And user can set it to another one. # There always exists a default context for user. And user can set it to another one.
DEFAULT_DISTRIBUTED_CONTEXT = None DEFAULT_DISTRIBUTED_CONTEXT = None
...@@ -49,6 +51,20 @@ class DistributedContext: ...@@ -49,6 +51,20 @@ class DistributedContext:
self._op_distributed_attr_map_for_program = {} self._op_distributed_attr_map_for_program = {}
self._tensor_distributed_attr_map_for_graph = {} self._tensor_distributed_attr_map_for_graph = {}
self._op_distributed_attr_map_for_graph = {} self._op_distributed_attr_map_for_graph = {}
# The following is a hard code and will be removed in the future
self._data_parallel_axis = None
self._model_parallel_axis = None
self._process_mesh = _g_process_mesh_map.get(0, None)
if self._process_mesh is not None:
if self._process_mesh.ndim == 1:
self._data_parallel_axis = 0
self._model_parallel_axis = 0
else:
self._data_parallel_axis = 0
self._model_parallel_axis = 1
else:
self._data_parallel_axis = -1
self._model_parallel_axis = -1
def is_initialized_for_program(self): def is_initialized_for_program(self):
return self._is_initialized_for_program return self._is_initialized_for_program
...@@ -99,6 +115,19 @@ class DistributedContext: ...@@ -99,6 +115,19 @@ class DistributedContext:
op_node_id = op_node.id() op_node_id = op_node.id()
self._op_distributed_attr_map_for_graph[op_node_id] = op_dist_attr self._op_distributed_attr_map_for_graph[op_node_id] = op_dist_attr
def set_process_mesh(self, process_mesh):
self._process_mesh = process_mesh
if self._process_mesh is not None:
if self._process_mesh.ndim == 1:
self._data_parallel_axis = 0
self._model_parallel_axis = 0
else:
self._data_parallel_axis = 0
self._model_parallel_axis = 1
else:
self._data_parallel_axis = -1
self._model_parallel_axis = -1
def initialize_distributed_attr_for_program(self, program): def initialize_distributed_attr_for_program(self, program):
if self._is_initialized_for_program: if self._is_initialized_for_program:
return return
...@@ -377,3 +406,11 @@ class DistributedContext: ...@@ -377,3 +406,11 @@ class DistributedContext:
if dims_mapping[i] != -1 and process_mesh_shape[ if dims_mapping[i] != -1 and process_mesh_shape[
dims_mapping[i]] > tensor_shape[i]: dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1 dims_mapping[i] = -1
def _get_data_parallel_info(self):
# This function is a hard code, and will be obsoleted in the future
return self._data_parallel_axis, self._process_mesh
def _get_model_parallel_info(self):
# This function is a hard code, and will be obsoleted in the future
return self._model_parallel_axis, self._process_mesh
...@@ -184,6 +184,13 @@ class ProcessMesh(object): ...@@ -184,6 +184,13 @@ class ProcessMesh(object):
"parent with id %d does not exist." % self._parent_id) "parent with id %d does not exist." % self._parent_id)
return _g_process_mesh_map[self._parent_id] return _g_process_mesh_map[self._parent_id]
@property
def ndim(self):
r"""
Get the number of dimension of ProcessMesh.
"""
return len(self._topology)
def set_placement(self, order): def set_placement(self, order):
""" """
Set the map from logical processes to physical ones using the Set the map from logical processes to physical ones using the
...@@ -229,6 +236,13 @@ class ProcessMesh(object): ...@@ -229,6 +236,13 @@ class ProcessMesh(object):
for idx, l_id in enumerate(logical_order): for idx, l_id in enumerate(logical_order):
_user_defined_physical_map[l_id] = order[idx] _user_defined_physical_map[l_id] = order[idx]
def _reset_global_process_mesh_map(self):
"""
Remove all process mesh in _g_process_mesh_map, make it empty.
"""
_g_process_mesh_map = dict()
def __eq__(self, other): def __eq__(self, other):
assert other and isinstance(other, ProcessMesh) assert other and isinstance(other, ProcessMesh)
if self.topology != other.topology or self.process_group != other.process_group: if self.topology != other.topology or self.process_group != other.process_group:
......
...@@ -33,6 +33,8 @@ class DistributedOperator: ...@@ -33,6 +33,8 @@ class DistributedOperator:
class DistributedOperatorImpl: class DistributedOperatorImpl:
def __init__(self): def __init__(self):
self._name = None self._name = None
self._forward_implemented = False
self._backward_implemented = False
def forward(self, dist_ctx, *args, **kwargs): def forward(self, dist_ctx, *args, **kwargs):
raise NotImplementedError("Please Implement this method in Subclass.") raise NotImplementedError("Please Implement this method in Subclass.")
......
...@@ -22,6 +22,12 @@ from ..utils import is_valid_list_index ...@@ -22,6 +22,12 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from ..process import new_process_group
from ..utils import _get_comm_group
class DistributedEmbedding(DistributedOperator): class DistributedEmbedding(DistributedOperator):
...@@ -39,6 +45,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -39,6 +45,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedEmbeddingImpl, self).__init__() super(DistributedEmbeddingImpl, self).__init__()
self._name = name self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr): def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """ """ No restriction for now. """
...@@ -92,6 +100,110 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -92,6 +100,110 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return changed return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 2, "row_parallel_embedding take 2 inputs variable but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 1, "row_parallel_embedding take 2 inputs variable but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['Ids']
) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format(
input_name_mapping['Ids'])
assert len(
input_name_mapping['W']
) == 1, "row_parallel_embedding input W take 1 variable but got {}".format(
input_name_mapping['W'])
assert len(
output_name_mapping['Out']
) == 1, "row_parallel_embedding input Out take 1 variable but got {}".format(
input_name_mapping['Out'])
Ids_var = dst_block.var(input_name_mapping['Ids'][0])
Weight_var = dst_block.var(input_name_mapping['W'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
# got dist attribute info
embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[0]
process_mesh_shape = op_dist_attr.get_process_mesh().topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group
# caculate embedding offset
# TODO generalize here, using cartisian product to allow any dimensional mesh shape
mesh_shape = len(process_mesh_shape)
assert mesh_shape <= 2, "row_parallel_embedding only support 1 or 2 dimensional process mesh, but got {}".format(
process_mesh_shape)
num_partition = process_mesh_shape[embedding_row_dim_mapping]
# TODO generalize here, support any mesh group
if mesh_shape == 1:
relative_idx = process_mesh_group.index(rank_id)
else:
relative_idx = rank_id % num_partition
per_part_size = Weight_var.shape[0]
relative_idx = relative_idx * per_part_size
# TODO caculate ring id
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, rank_id)
group = new_process_group(group_ranks)
# append op
check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'],
'c_embedding')
intermediate_var_0 = dst_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_embedding", 'tmp'])),
dtype=Weight_var.dtype,
shape=Out_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_var.stop_gradient)
check_variable_and_dtype(
Out_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'c_allreduce_sum')
dst_block.append_op(
type='c_embedding',
inputs={'Ids': [Ids_var],
'W': [Weight_var]},
outputs={'Out': [intermediate_var_0]},
attrs={"start_index": relative_idx})
# use_model_parallel
dst_block.append_op(
type='c_allreduce_sum',
inputs={'X': [intermediate_var_0]},
outputs={'Out': [Out_var]},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
})
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"matmul", 0))
else:
return static_handle
register_distributed_operator_impl("lookup_table_v2", register_distributed_operator_impl("lookup_table_v2",
DistributedEmbeddingImpl("row_parallel")) DistributedEmbeddingImpl("row_parallel"))
...@@ -22,6 +22,12 @@ from ..utils import is_valid_list_index ...@@ -22,6 +22,12 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from ..process import new_process_group
from ..utils import _get_comm_group
def _update_dims_mapping_for_matmul(op_dist_attr): def _update_dims_mapping_for_matmul(op_dist_attr):
...@@ -37,7 +43,6 @@ def _update_dims_mapping_for_matmul(op_dist_attr): ...@@ -37,7 +43,6 @@ def _update_dims_mapping_for_matmul(op_dist_attr):
y_dims_mapping_len = len(y_dims_mapping) y_dims_mapping_len = len(y_dims_mapping)
out_dims_mapping_len = len(out_dims_mapping) out_dims_mapping_len = len(out_dims_mapping)
# print("before", x_dims_mapping, y_dims_mapping, out_dims_mapping)
# Add dim mapping to Make sure the length dims_mapping be at least 2 # Add dim mapping to Make sure the length dims_mapping be at least 2
if x_dims_mapping_len == 1: if x_dims_mapping_len == 1:
x_dims_mapping.insert(0, -1) x_dims_mapping.insert(0, -1)
...@@ -109,7 +114,6 @@ def _update_dims_mapping_for_matmul(op_dist_attr): ...@@ -109,7 +114,6 @@ def _update_dims_mapping_for_matmul(op_dist_attr):
if y_dims_mapping_len == 1: if y_dims_mapping_len == 1:
y_dims_mapping.pop(1) y_dims_mapping.pop(1)
# print("after", x_dims_mapping, y_dims_mapping, out_dims_mapping)
assert len(x_dims_mapping) == x_dims_mapping_len assert len(x_dims_mapping) == x_dims_mapping_len
assert len(y_dims_mapping) == y_dims_mapping_len assert len(y_dims_mapping) == y_dims_mapping_len
assert len(out_dims_mapping) == out_dims_mapping_len assert len(out_dims_mapping) == out_dims_mapping_len
...@@ -131,6 +135,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -131,6 +135,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedMatmulImpl0, self).__init__() super(DistributedMatmulImpl0, self).__init__()
self._name = name self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr): def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """ """ No restriction for now. """
...@@ -170,12 +176,101 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -170,12 +176,101 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
changed = True changed = True
return changed return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 2, "col_parallel_linear take 2 inputs variable but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 1, "col_parallel_linear take 2 inputs variable but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['X']
) == 1, "col_parallel_linear input X take 1 variable but got {}".format(
input_name_mapping['X'])
assert len(
input_name_mapping['Y']
) == 1, "col_parallel_linear input Y take 1 variable but got {}".format(
input_name_mapping['Y'])
assert len(
output_name_mapping['Out']
) == 1, "col_parallel_linear input Out take 1 variable but got {}".format(
input_name_mapping['Out'])
X_var = dst_block.var(input_name_mapping['X'][0])
Weight_var = dst_block.var(input_name_mapping['Y'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
# TODO infer logic comm presentation
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, rank_id)
group = new_process_group(group_ranks)
intermediate_var_0 = dst_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_identity", 'tmp'])),
dtype=X_var.dtype,
shape=X_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=X_var.stop_gradient)
check_variable_and_dtype(
X_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_identity')
dst_block.append_op(
type='c_identity',
inputs={'X': [X_var]},
outputs={'Out': intermediate_var_0},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
})
check_variable_and_dtype(intermediate_var_0, 'x',
['float16', 'float32', 'float64'],
'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
attrs = {
'transpose_X': False,
'transpose_Y': False,
'alpha': 1,
}
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
dst_block.append_op(
type='matmul',
inputs=inputs,
outputs={'Out': Out_var},
attrs=attrs)
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"matmul", 0))
else:
return static_handle
# RowParallel # RowParallel
class DistributedMatmulImpl1(DistributedOperatorImpl): class DistributedMatmulImpl1(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedMatmulImpl1, self).__init__() super(DistributedMatmulImpl1, self).__init__()
self._name = name self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr): def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """ """ No restriction for now. """
...@@ -217,6 +312,86 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -217,6 +312,86 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
changed = True changed = True
return changed return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 2, "col_parallel_linear take 2 inputs variable but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 1, "col_parallel_linear take 2 inputs variable but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['X']
) == 1, "col_parallel_linear input X take 1 variable but got {}".format(
input_name_mapping['X'])
assert len(
input_name_mapping['Y']
) == 1, "col_parallel_linear input Y take 1 variable but got {}".format(
input_name_mapping['Y'])
assert len(
output_name_mapping['Out']
) == 1, "col_parallel_linear input Out take 1 variable but got {}".format(
input_name_mapping['Out'])
X_var = dst_block.var(input_name_mapping['X'][0])
Weight_var = dst_block.var(input_name_mapping['Y'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
# TODO infer logic comm presentation
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, rank_id)
group = new_process_group(group_ranks)
check_variable_and_dtype(
X_var, 'x', ['float16', 'float32', 'float64'], 'linear')
check_dtype(X_var.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
attrs = {
'transpose_X': False,
'transpose_Y': False,
'alpha': 1,
}
inputs = {'X': X_var, 'Y': Weight_var}
intermediate_var_0 = dst_block.create_var(
shape=Out_var.shape,
dtype=Out_var.dtype,
type=Out_var.type,
lod_level=Out_var.lod_level,
persistable=False,
is_data=False,
need_check_feed=Out_var.desc.need_check_feed())
dst_block.append_op(
type='matmul',
inputs=inputs,
outputs={'Out': intermediate_var_0},
attrs=attrs)
dst_block.append_op(
type='c_allreduce_sum',
inputs={'X': intermediate_var_0},
outputs={'Out': Out_var},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True
})
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"matmul", 0))
else:
return static_handle
# ReplicateParallel # ReplicateParallel
class DistributedMatmulImpl2(DistributedOperatorImpl): class DistributedMatmulImpl2(DistributedOperatorImpl):
......
...@@ -22,6 +22,10 @@ from ..utils import is_valid_list_index ...@@ -22,6 +22,10 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
class DistributedReshape2(DistributedOperator): class DistributedReshape2(DistributedOperator):
...@@ -37,6 +41,8 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -37,6 +41,8 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedReshapeImpl0, self).__init__() super(DistributedReshapeImpl0, self).__init__()
self._name = name self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr): def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """ """ No restriction for now. """
...@@ -91,11 +97,90 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -91,11 +97,90 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
return changed return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 3, "Dist op of Reshape take 3 inputs variable but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 2, "Dist op of Reshape take 2 inputs variable but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['X']
) == 1, "Dist op of Reshape input X take 1 variable but got {}".format(
input_name_mapping['X'])
assert len(
input_name_mapping['ShapeTensor']
) <= 1, "Dist op of Reshape input ShapeTensor take 0 or 1 variable but got {}".format(
input_name_mapping['ShapeTensor'])
assert len(
input_name_mapping['Shape']
) <= 1, "Dist op of Reshape input Shape take 0 or 1 variable but got {}".format(
input_name_mapping['Shape'])
assert len(
output_name_mapping['Out']
) == 1, "Dist op of Reshape input Out take 1 variable but got {}".format(
input_name_mapping['Out'])
assert len(
output_name_mapping['XShape']
) == 1, "Dist op of Reshape input XShape take 1 variable but got {}".format(
input_name_mapping['XShape'])
X_var = dst_block.var(input_name_mapping['X'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
XShape_var = dst_block.var(output_name_mapping['XShape'][0])
shape_list = src_op.desc.attr("shape")
ShapeTensor_var_list = []
for name in input_name_mapping['ShapeTensor']:
ShapeTensor_var_list.append(name)
Shape_var_list = []
for name in input_name_mapping['Shape']:
Shape_var_list.append(name)
# got dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.get_process_mesh().topology
# modify target shape
for idx, axis in enumerate(dim_mapping):
if axis >= 0:
if len(shape_list) > idx:
shape_list[idx] = shape_list[idx] // process_mesh_shape[
axis]
# create op
new_op_desc = dst_block.desc.append_op()
new_op_desc.copy_from(src_op.desc)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
new_op_desc.set_input('Shape', Shape_var_list)
new_op_desc.set_input('X', [X_var.name])
new_op_desc.set_output('XShape', [XShape_var.name])
new_op_desc.set_output('Out', [Out_var.name])
new_op_desc._set_attr('shape', shape_list)
dst_block._sync_with_cpp()
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"matmul", 0))
else:
return static_handle
class DistributedReshapeImpl1(DistributedOperatorImpl): class DistributedReshapeImpl1(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedReshapeImpl1, self).__init__() super(DistributedReshapeImpl1, self).__init__()
self._name = name self._name = name
self._forward_implemented = True
self._backward_implemented = False
def is_process_mesh_compatible(self, op_dist_attr): def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """ """ No restriction for now. """
...@@ -150,6 +235,83 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -150,6 +235,83 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
return changed return changed
def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
input_name_mapping
) == 3, "Dist op of Reshape take 3 inputs variable but got {}".format(
input_name_mapping)
assert len(
output_name_mapping
) == 2, "Dist op of Reshape take 2 inputs variable but got {}".format(
output_name_mapping)
assert len(
input_name_mapping['X']
) == 1, "Dist op of Reshape input X take 1 variable but got {}".format(
input_name_mapping['X'])
assert len(
input_name_mapping['ShapeTensor']
) <= 1, "Dist op of Reshape input ShapeTensor take 0 or 1 variable but got {}".format(
input_name_mapping['ShapeTensor'])
assert len(
input_name_mapping['Shape']
) <= 1, "Dist op of Reshape input Shape take 0 or 1 variable but got {}".format(
input_name_mapping['Shape'])
assert len(
output_name_mapping['Out']
) == 1, "Dist op of Reshape input Out take 1 variable but got {}".format(
input_name_mapping['Out'])
assert len(
output_name_mapping['XShape']
) == 1, "Dist op of Reshape input XShape take 1 variable but got {}".format(
input_name_mapping['XShape'])
X_var = dst_block.var(input_name_mapping['X'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])
XShape_var = dst_block.var(output_name_mapping['XShape'][0])
shape_list = src_op.desc.attr("shape")
ShapeTensor_var_list = []
for name in input_name_mapping['ShapeTensor']:
ShapeTensor_var_list.append(name)
Shape_var_list = []
for name in input_name_mapping['Shape']:
Shape_var_list.append(name)
# got dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.get_process_mesh().topology
# modify target shape
for idx, axis in enumerate(dim_mapping):
if axis >= 0:
if len(shape_list) > idx:
shape_list[idx] = shape_list[idx] // process_mesh_shape[
axis]
# create op
new_op_desc = dst_block.desc.append_op()
new_op_desc.copy_from(src_op.desc)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
new_op_desc.set_input('Shape', Shape_var_list)
new_op_desc.set_input('X', [X_var.name])
new_op_desc.set_output('XShape', [XShape_var.name])
new_op_desc.set_output('Out', [Out_var.name])
new_op_desc._set_attr('shape', shape_list)
dst_block._sync_with_cpp()
if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"matmul", 0))
else:
return static_handle
register_distributed_operator_impl("reshape2", register_distributed_operator_impl("reshape2",
DistributedReshapeImpl0("add_one_dim_back")) DistributedReshapeImpl0("add_one_dim_back"))
......
...@@ -47,7 +47,6 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): ...@@ -47,7 +47,6 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
axis = op_desc.attr('axis') axis = op_desc.attr('axis')
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
# print("softmax axis", axis)
if axis != -1 and axis != len(x_dims_mapping) - 1: if axis != -1 and axis != len(x_dims_mapping) - 1:
return False return False
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid import framework as framework
from paddle.fluid import core, unique_name
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.backward import append_backward, _some_in_set_, _append_grad_suffix_
from paddle.distributed.auto_parallel.operators.common import get_distributed_operator
from paddle.distributed.auto_parallel.operators.common import find_best_compatible_distributed_operator_impl
from paddle.fluid.clip import GradientClipBase, GradientClipByNorm, error_clip_callback, append_gradient_clip_ops, ClipGradByGlobalNorm
from paddle.distributed.fleet.base.distributed_strategy import DistributedStrategy
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from .process import new_process_group
from .interface import _g_process_mesh_map
from .utils import _get_comm_group
__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
class Partitioner(object):
"""
warning:: Partitioner is experimental and subject to change.
Partitioner convert a program into another program.
Given a serial program which has been auto completed with shard annotation, the Partitioner
convert the serial program into a "distributed" program. The Partitioner will modify the serial
program in following two ways, which is also the major difference between serial and distributed program:
1. partition op: replace a serial op into its corresponding dist op infered from the shard annotation
2. partition var: if a var is sharded, modify the shape of var according to its shard annotation
Partitioner is supposed to be call by the auto parallel framework, and not supposed to be directly called by user.
Example:
....
import paddle.distributed.auto_parallel as auto
from paddle.fluid.distributed_attribute import get_default_distributed_context
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
# create serial program with forward only
with static.program_guard(serial_main_program, serial_start_program):
model = create_model(config)
tokens = static.data(name="tokens", shape=[batch_size, sequence_len], dtype='int64')
labels = static.data(name="labels", shape=[batch_size, sequence_len], dtype='int64')
loss_mask = static.data(name="loss_mask", shape=[batch_size, sequence_len], dtype='int64')
preds = model(tokens)
loss = criterion(preds, labels, loss_mask)
# auto completion
auto.ProcessMesh(shape=[2, 4], process_group=[0, 1, 2, 3, 4, 5, 6, 7])
annotated_main_program = auto.complete_annotation(serial_main_program)
auto_paralle_context = get_default_distributed_context()
# distributed strategy & rank info
rank_id = paddle.distributed.get_rank()
dist_strategy = fleet.DistributedStrategy()
# create partitioner
Partitioner = Partitioner(dist_strategy, auto_paralle_context, rank_id)
# create dist program with forward only
# for distributed inference, using partitioned_main_prog from here
partitioned_main_prog, partitioned_startup_prog = Partitioner.transpile_forward(complete_train_program, start_program)
# create dist program with forward/backward/update
# for distributed training, using partitioned_main_prog from here
dist_params_grads = Partitioner.apply_backward(loss, complete_train_program, start_program, partitioned_main_prog, partitioned_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
opt_ops = Partitioner.apply_optimize(optimizer, dist_params_grads, partitioned_main_prog, partitioned_startup_prog)
"""
def __init__(self, dist_strategy, auto_parallel_context, rank_id=0):
"""
Args:
dist_strategy (paddle.fleet.distributed_strategy): used to determine the user defined distributed strategy.
auto_parallel_context (paddle.fluid.DistributedContext): used to access the distributed_attr of var & op, every Partitioner object could maintain its own DistributedContext member, and partition program base on that shard scenario.
rank_id (int): global rank id to which the partitioned distributed program belong.
"""
if not isinstance(dist_strategy, DistributedStrategy):
raise TypeError(
"dist_strategy be paddle.fleet.base.DistributedStrategy, got %s here"
% type(dist_strategy))
if not isinstance(auto_parallel_context, DistributedContext):
raise TypeError(
"auto_parallel_context be paddle.fluid.DistributedContext, got %s here"
% type(auto_parallel_context))
self._dist_strategy = dist_strategy
self._auto_parallel_context = auto_parallel_context
self._rank_id = rank_id
self._serial2dist_varname_mapping = {}
self._dist_varname_suffix = ""
# TODO if there is some dist op that is not compatible
# with auto_backward in forward, the following flag
# should be set to False
self._compatible_with_auto_backward = True
# data parallelism
self._enable_data_parallel = False
self._dp_degree = 0
self._dp_group = None
# tensor parallelism
self._enable_tensor_parallel = False
self._tp_degree = 0
self._tp_group = None
def transpile_forward(self, serial_main_program, serial_startup_program):
"""
take serial forward programs with shard annotation, create a new distributed forward programs based on the serial ones.
instead of modify the input programs inplace, this function will preserve the inputs and create new program for output.
beside replace the serial op with its dist op, if user has defined other strategy in fleet.distributed_strategy, and if
those strategy need to transpile (modify) the forward network program, those forward program modification should also be done within this
function in auto parallel scenario, in order to facilitate distributed inference/evaluation which need to DECOUPLE strategy specific forward transpilation with fleet.distributed_optimizer.minimize().
by now the fleet.distributed_strategy that need transpile forward program are following:
1. (optimizer) sharding
Args:
main_program (paddle.fluid.framework.program): serial main program with forward network only
startup_program (paddle.fluid.framework.program): serial startup program with forward network only
return:
main_program (paddle.fluid.framework.program): distributed main program with forward network only
startup_program (paddle.fluid.framework.program): distributed startup program with forward network only
"""
dist_main_program, dist_startup_program = self.transpile_forward_impl(
serial_main_program, serial_startup_program)
return dist_main_program, dist_startup_program
def apply_backward(self,
serial_loss,
serial_main_program,
serial_startup_program,
dist_main_program,
dist_startup_program,
parameter_list=None,
no_grad_set=None,
callbacks=None):
"""
A complete training neural network is made up of forward and backward propagation.
This function is to generate the dist backward program for the distributed forward program.
By now, the current automatical backward mechanism in paddle framework might NOT handle the backward generation for
some dist ops correctly, some so we now have two ways to genenate the backward program:
1. dist_forward_program --> auto_backward --> dist_backward_program (if auto_backward could handle all dist op)
2. serial_forward_program --> auto_backward --> serial_backward_program --> dist_op_backward_transpile --> dist_backward_program (if auto_backward could not handle all dist op)
the backprogram is append the input dist program inplaced.
Args:
serial_loss (Variable) the loss in serial program that to be minimized
serial_main_program (paddle.fluid.framework.program): serial main program with forward network only
serial_startup_program (paddle.fluid.framework.program): serial startup program with forward network only
dist_main_program (paddle.fluid.framework.program): dist main program with forward network only
dist_startup_program (paddle.fluid.framework.program): dist startup program with forward network only
parameter_list (Iterable, optional): Iterable of ``Variable`` or ``Variable.name`` to update
to minimize ``loss``. The default value is None, at this time all parameters
will be updated.
no_grad_set (set, optional): Set of ``Variable`` or ``Variable.name`` that don't need
to be updated. The default value is None.
callbacks (list, optional): list of callable objects to run when appending backward
operator for one parameter. The default value is None.
return:
params_grads (list) list of tuple that contain param and its grad variable
"""
params_grads = self.apply_backward_impl(
serial_loss, serial_main_program, serial_startup_program,
dist_main_program, dist_startup_program)
return params_grads
def apply_optimize(self, user_define_optimizer, params_grads,
dist_main_program, dist_startup_program):
"""
append update related ops to the program: clip, weight decay, ops
filter optimize op if sharding is enable
naive gradient synchronization before update
Args:
user_define_optimizer (paddle.fluid.optimizer):
params_grads (list) list of tuple that contain param and its grad variable
dist_main_program (paddle.fluid.framework.program): dist main program with forward & backward network
dist_startup_program (paddle.fluid.framework.program): dist startup program with forward & backward network
"""
optimize_ops = self.apply_optimize_impl(user_define_optimizer,
params_grads, dist_main_program,
dist_startup_program)
return optimize_ops
def transpile_forward_impl(self, main_program, startup_program):
if not isinstance(main_program, (Program)):
raise TypeError(
"dist_strategy be paddle.fluid.framework.program, got %s here" %
type(main_program))
if not isinstance(startup_program, (Program)):
raise TypeError(
"auto_parallel_context be paddle.fluid.framework.program, got %s here"
% type(startup_program))
# check if shard annotated serial program valid
if not self._is_valid_annotated_program(main_program):
raise RuntimeError(
"Not all vars or ops are annotated in main program !")
# determine parallelism mode
self._determine_parallel_mode(main_program)
# dist op & partition vars
new_main_prog, new_startup_program = self._dist_var_op_forward_transpile(
main_program, startup_program)
# Sharding
if self._dist_strategy.sharding:
new_main_prog, new_startup_program = self._sharding_forward_transpile(
new_main_prog, new_startup_program)
return new_main_prog, new_startup_program
def apply_backward_impl(self,
serial_loss,
serial_main_program,
serial_startup_program,
dist_main_program,
dist_startup_program,
parameter_list=None,
no_grad_set=None,
callbacks=None):
"""
"""
params_grads = self._dist_var_op_backward_transpile(
serial_loss, serial_main_program, serial_startup_program,
dist_main_program, dist_startup_program)
# Sharding
if self._dist_strategy.sharding:
self._sharding_backward_transpile(new_main_prog,
new_startup_program)
# Data Parallel pass
if self._enable_data_parallel:
self._gradient_sync_transpile(dist_main_program,
dist_startup_program)
return params_grads
def apply_optimize_impl(self, user_define_optimizer, params_grads,
dist_main_program, dist_startup_program):
"""
append update related ops to the program: clip, weight decay, ops
filter optimize op if sharding is enable
naive gradient synchronization before update
Args:
user_define_optimizer (paddle.fluid.optimizer):
params_grads (list) list of tuple that contain param and its grad variable
dist_main_program (paddle.fluid.framework.program): dist main program with forward & backward network
dist_startup_program (paddle.fluid.framework.program): dist startup program with forward & backward network
"""
if self._dist_strategy.sharding:
params_grads = sharding_optimize_transpile(
params_grads, dist_main_program, dist_startup_program)
optimize_ops = self._optimize_transpile(user_define_optimizer,
params_grads, dist_main_program,
dist_startup_program)
return optimize_ops
def _dist_var_op_forward_transpile(self,
serial_main_program,
serial_startup_program=None):
"""
1. partition variables
2. replace local op with corresponding dist op
"""
partitioned_main_prog = fluid.Program()
partitioned_global_block = partitioned_main_prog.global_block()
serial_global_block = serial_main_program.global_block()
serial_ops = serial_main_program.global_block().ops
# transpile main program
for op in serial_ops:
# partititon input variables
for serial_input_varname in op.desc.input_arg_names():
if serial_input_varname not in self._serial2dist_varname_mapping:
new_varname = serial_input_varname + self._dist_varname_suffix
if serial_global_block.has_var(serial_input_varname):
_partition_var(self._auto_parallel_context,
serial_global_block,
partitioned_global_block,
serial_input_varname, new_varname)
else:
assert serial_input_varname in __varname_not_in_block__
self._serial2dist_varname_mapping[
serial_input_varname] = new_varname
# partition output vars
for serial_output_varname in op.desc.output_arg_names():
if serial_output_varname not in self._serial2dist_varname_mapping:
new_varname = serial_output_varname + self._dist_varname_suffix
_partition_var(self._auto_parallel_context,
serial_global_block,
partitioned_global_block,
serial_output_varname, new_varname)
self._serial2dist_varname_mapping[
serial_output_varname] = new_varname
# partition op
if _found_match_dist_op(self._auto_parallel_context, op):
# replace with corresponding dist op
_insert_dist_op(op, partitioned_global_block,
self._serial2dist_varname_mapping,
self._auto_parallel_context, self._rank_id)
else:
# replicate op
_insert_src_op(op, partitioned_global_block,
self._serial2dist_varname_mapping)
# transpile startup program
if serial_startup_program == None:
partitioned_startup_prog = None
else:
partitioned_startup_prog = fluid.Program()
# create parameter
partitioned_startup_global_block = partitioned_startup_prog.global_block(
)
param2shape = {}
for var in partitioned_main_prog.list_vars():
if isinstance(var, Parameter):
_partition_parameter(self._auto_parallel_context, var,
partitioned_startup_global_block,
var.name, var.shape)
param2shape[var.name] = var.shape
# copy initializer
for op in serial_startup_program.global_block().ops:
output_vars = op.desc.output_arg_names()
assert len(
output_vars
) == 1, "initializer should output only ONE variable, but got [{}]".format(
str(op.desc))
assert self._serial2dist_varname_mapping[output_vars[
0]] in param2shape, "try to initialize [{}] which is not a Parameter".format(
output_vars[0])
new_op_desc = partitioned_startup_global_block.desc.append_op()
new_op_desc.copy_from(op.desc)
new_op_desc._rename_output(
output_vars[0],
self._serial2dist_varname_mapping[output_vars[0]])
new_op_desc._set_attr("shape", param2shape[
self._serial2dist_varname_mapping[output_vars[0]]])
partitioned_startup_global_block._sync_with_cpp()
# MP broadcast not split parameter
# NOTE Theoretically, the MP param init broadcast should be handled by
# each dist op itself. but if we insert the broadcast op at that moment, the broadcast
# will before the initializer, which lead to a undertermined case.
if self._enable_tensor_parallel:
param_to_sync = []
for param in partitioned_startup_prog.all_parameters():
if not self._is_var_distributed(param):
param_to_sync.append(param)
# FIXME the ring id should be set by autoparallel.mapping module
# it should be determined by dp groups butfixed it here for hacking
partitioned_startup_global_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self._tp_group.id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
partitioned_startup_global_block.append_op(
type='c_sync_comm_stream',
inputs={'X': param_to_sync},
outputs={'Out': param_to_sync},
attrs={
'ring_id': self._tp_group.id,
OP_ROLE_KEY: OpRole.Forward
})
partitioned_startup_global_block._sync_with_cpp()
# DP init param broadcast
if self._enable_data_parallel:
# parameters initialization synchronization
param_to_sync = []
for param in partitioned_startup_global_block.all_parameters():
param_to_sync.append(param)
# FIXME the ring id should be set by autoparallel.mapping module
# it should be determined by dp groups butfixed it here for hacking
partitioned_startup_global_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self._dp_group.id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
partitioned_startup_global_block.append_op(
type='c_sync_comm_stream',
inputs={'X': param_to_sync},
outputs={'Out': param_to_sync},
attrs={
'ring_id': self._dp_group.id,
OP_ROLE_KEY: OpRole.Forward
})
partitioned_startup_global_block._sync_with_cpp()
return partitioned_main_prog, partitioned_startup_prog
def _dist_var_op_backward_transpile(self,
serial_loss,
serial_main_program,
serial_startup_program,
dist_main_program,
dist_startup_program,
parameter_list=None,
no_grad_set=None,
callbacks=None):
"""
so far, the auto_backward case only guarantee the correcotness of backward ops for curtain Dist ops:
1. NV-Megatron-like parallel embedding
2. NV-Megatron-like row parallel linear
3. NV-Megatron-like col parallel linear
"""
if self._compatible_with_auto_backward:
assert isinstance(
serial_loss, Variable), "The target loss should be an Variable."
dist_loss = self._serial_varname2dist_var(serial_loss.name,
dist_main_program)
assert len(dist_loss.shape) == 1 and dist_loss.shape[0] == 1, \
"The dist loss.shape should be (1L,), but the current dist loss.shape is {}. " \
"Maybe that you should call fluid.layers.mean to process the current loss.".format(
dist_loss.shape)
# update parameter list
if parameter_list:
parameter_list = [
self._serial_varname2dist_var(param.name, dist_main_program)
for param in parameter_list
]
# update parameter no_grad_set
if no_grad_set:
no_grad_set = [
self._serial_varname2dist_var(param.name, dist_main_program)
for param in no_grad_set
]
return _auto_backward(
dist_loss,
dist_startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set,
callbacks=callbacks)
# replace dist grad ops
else:
raise RuntimeError("transpile NOT implemented !")
def _optimize_transpile(self, user_define_optimizer, params_grads,
main_program, startup_program):
with program_guard(main_program, startup_program):
optimize_ops = user_define_optimizer.apply_gradients(params_grads)
return optimize_ops
def _is_valid_annotated_program(self, program):
# TODO (ZJ-LIANG) should check all block
ops = program.global_block().ops
vars_ = program.list_vars()
op_dist_attrs = [
self._auto_parallel_context.get_op_distributed_attr_for_program(op)
for op in ops
]
var_dist_attrs = [
self._auto_parallel_context.get_tensor_distributed_attr_for_program(
var) for var in vars_
]
all_ops_annotated = all(dist_attr is not None
for dist_attr in op_dist_attrs)
all_vars_annotated = all(dist_attr is not None
for dist_attr in var_dist_attrs)
return all_ops_annotated and all_vars_annotated
def _serial_varname2dist_var(self, serial_varname, dist_program):
assert serial_varname in self._serial2dist_varname_mapping, "The serial var [{}] is not found in var name mapping".format(
serial_varname)
dist_varname = self._serial2dist_varname_mapping[serial_varname]
assert dist_program.global_block().has_var(
dist_varname
), "The dist var [{}] is not found in dist program".format(dist_varname)
dist_var = dist_program.global_block().var(dist_varname)
return dist_var
def _determine_parallel_mode(self, program):
"""
determine the parallelism that is enabled
NOTE a hard rule and should be updated in future
"""
for param in program.all_parameters():
if self._is_var_distributed(param):
self._enable_tensor_parallel = True
break
for var in program.list_vars():
var_dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program(
var)
if not var_dist_attr.is_parameter():
mapping = var_dist_attr.get_dims_mapping()
mesh = var_dist_attr.get_process_mesh().topology
if mapping[0] >= 0 and mesh[mapping[0]] > 1:
self._enable_data_parallel = True
break
# tensor parallelism
if self._enable_tensor_parallel:
model_parallel_axis, process_mesh = self._auto_parallel_context._get_model_parallel_info(
)
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, self._rank_id)
self._tp_degree = len(group_ranks)
self._tp_group = new_process_group(group_ranks)
# data parallelism
data_parallel_axis, process_mesh = self._auto_parallel_context._get_data_parallel_info(
)
if self._enable_data_parallel:
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
data_parallel_axis, self._rank_id)
self._dp_degree = len(group_ranks)
self._dp_group = new_process_group(group_ranks)
def _is_var_distributed(self, var):
dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program(
var)
assert dist_attr is not None, "dist_attr of var [{}] is None".format(
var.name)
return _is_distributed(dist_attr)
def _sharding_forward_transpile(self, main_prog, startup_program):
"""
this transpile conduct the modification in forward program need by sharding strategy
which majorly include:
1. partition the parameter
2. insert broadcast op
3. insert sync op
NOTE the transpile modification is inplace on the input program
"""
raise NotImplementedError(
"Sharding is NOT support in AutoParallel yet!")
def _sharding_backward_transpile(self, main_prog, startup_program):
"""
this transpile conduct the modification in backward program need by sharding strategy
which majorly include:
1. partition the gradient
2. insert broadcast op
3. insert sync op
NOTE the transpile modification is inplace on the input program
"""
raise NotImplementedError(
"Sharding is NOT support in AutoParallel yet!")
def _sharding_optimize_transpile(self, params_grads, dist_main_program,
dist_startup_program):
"""
shard params_grads
append the broadcast to sync parameters
"""
raise RuntimeError("sharding transpile is NOT implemented !")
def _gradient_sync_transpile(self, main_program, startup_program):
"""
append the gradient allreduce ops for all parameters' grad in case of Data Parallel
"""
# scale loss by dp degree
main_global_block = main_program.global_block()
for idx, op in reversed(list(enumerate(main_global_block.ops))):
if is_loss_grad_op(op):
loss_grad_var = main_global_block.vars[op.output_arg_names[0]]
main_global_block._insert_op_without_sync(
idx + 1,
type='scale',
inputs={'X': loss_grad_var},
outputs={'Out': loss_grad_var},
attrs={
'scale': 1.0 / self._dp_degree,
OP_ROLE_KEY: OpRole.Backward
})
break
main_global_block._sync_with_cpp()
# gradient synchronization
# NOTE naive gradient sync without overlapping
# so there is not need to sync between calc and comm
# collecting grad var
grad_to_sync = []
for idx, op in reversed(list(enumerate(main_global_block.ops))):
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) != 0:
assert len(op_role_var) % 2 == 0
for i in range(0, len(op_role_var), 2):
param, reduced_grad = op_role_var[i], op_role_var[i + 1]
assert (reduced_grad not in grad_to_sync)
grad_to_sync.append(reduced_grad)
if is_optimizer_op(op):
first_optimize_op_idx = idx
# insert allreduce
for grad in grad_to_sync:
# FIXME the ring id should be set by autoparallel.mapping module
# it should be determined by dp groups butfixed it here for hacking
main_global_block.append_op(
type='c_allreduce_sum',
inputs={'X': grad},
outputs={'Out': grad},
attrs={
'ring_id': self._dp_group.id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward
})
main_global_block.append_op(
type='c_sync_comm_stream',
inputs={'X': grad_to_sync},
outputs={'Out': grad_to_sync},
attrs={'ring_id': self._dp_group.id,
OP_ROLE_KEY: OpRole.Backward})
main_global_block._sync_with_cpp()
def _get_no_grad_set_name(no_grad_set):
no_grad_set_name = set()
if no_grad_set is not None:
if isinstance(no_grad_set, (set, list, tuple)):
for i, no_grad_var in enumerate(no_grad_set):
if isinstance(no_grad_var, framework.Variable):
no_grad_set_name.add(no_grad_var.name)
elif isinstance(no_grad_var, six.string_types):
no_grad_set_name.add(no_grad_var)
else:
raise TypeError(
"The type of no_grad_set's member must be paddle.fluid.Variable or str, but received %s."
% (type(no_grad_var)))
else:
raise TypeError(
"The type of no_grad_set should be set or list or tuple, but received {}".
format(type(no_grad_set)))
return no_grad_set_name
def _get_no_grad_set(loss, no_grad_set=None):
no_grad_set = _get_no_grad_set_name(no_grad_set)
parameters = loss.block.program.global_block().all_parameters()
param_no_trainable = set(
[param.name for param in parameters if param.trainable is False])
# If the parameter is no trainable, it should not have a gradient.
no_grad_set.update(param_no_trainable)
return no_grad_set
def _found_match_dist_op(auto_paralle_context, op):
dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(op)
dist_ops = get_distributed_operator(op.type)
return dist_ops and dist_attr.get_impl_idx() >= 0 and dist_ops.get_impl( \
dist_attr.get_impl_idx())._forward_implemented
def _auto_backward(loss,
startup_program=None,
parameter_list=None,
no_grad_set=None,
callbacks=None):
"""
modification is inplaced
"""
act_no_grad_set = _get_no_grad_set(loss, no_grad_set)
assert isinstance(loss, Variable), "The target loss should be an Variable."
if callbacks is None:
callbacks = [error_clip_callback]
else:
assert (isinstance(callbacks, list))
assert len(loss.shape) == 1 and loss.shape[0] == 1, \
"The loss.shape should be (1L,), but the current loss.shape is {}. " \
"Maybe that you should call fluid.layers.mean to process the current loss.".format(
loss.shape)
program = loss.block.program
with program_guard(program, startup_program):
params_grads = append_backward(loss, parameter_list, act_no_grad_set,
callbacks)
return params_grads
def _is_distributed(dist_attr):
mapping = dist_attr.get_dims_mapping()
mesh = dist_attr.get_process_mesh().topology
for idx in range(len(mapping)):
if mapping[idx] >= 0 and mesh[mapping[idx]] > 1:
return True
return False
def _get_dist_shape(var, dist_attr):
var_shape = var.shape
mapping = dist_attr.get_dims_mapping()
mesh = dist_attr.get_process_mesh().topology
assert len(var_shape) == len(
mapping
), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
var_shape, mapping)
new_shape = []
for idx in range(len(var_shape)):
if var_shape[idx] == -1 or mapping[idx] == -1:
new_shape.append(var_shape[idx])
else:
assert var_shape[idx] % mesh[mapping[
idx]] == 0, "un-event partition: var_shape[idx]=[{}], mesh[{}]".format(
var_shape[idx], mesh[mapping[idx]])
new_shape.append(var_shape[idx] // mesh[mapping[idx]])
return new_shape
def _partition_parameter(auto_paralle_context, src_var, dst_block, dst_varname,
dst_shape):
# NOTE hack to copied Parameter
# not initialized parameter, need to initialize it
copied_kwargs = {}
copied_kwargs['trainable'] = src_var.trainable
copied_kwargs['optimize_attr'] = src_var.optimize_attr
copied_kwargs['regularizer'] = src_var.regularizer
copied_kwargs['do_model_average'] = src_var.do_model_average
copied_kwargs['need_clip'] = src_var.need_clip
param = Parameter(
block=dst_block,
type=src_var.type,
name=dst_varname,
shape=dst_shape,
dtype=src_var.dtype,
lod_level=src_var.lod_level,
error_clip=src_var.error_clip,
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer,
**copied_kwargs)
# set dist attr uid
# distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
# param.desc.set_distributed_attr_uid(distributed_attr_uid)
dist_attr = copy.deepcopy(
auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
dist_attr._owner_tensor = param
dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
src_var)._owner_context
auto_paralle_context.set_tensor_distributed_attr_for_program(param,
dist_attr)
def _partition_intermediate_var(auto_paralle_context, src_var, dst_block,
dst_varname, dst_shape):
var = dst_block.create_var(
type=src_var.type,
name=dst_varname,
shape=dst_shape,
dtype=src_var.dtype,
lod_level=src_var.lod_level,
persistable=src_var.persistable,
error_clip=src_var.error_clip,
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer)
# set dist attr uid
# distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
# var.desc.set_distributed_attr_uid(distributed_attr_uid)
dist_attr = copy.deepcopy(
auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
dist_attr._owner_tensor = var
dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
src_var)._owner_context
auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr)
def _partition_var(auto_paralle_context, src_block, dst_block, src_varname,
dst_varname):
"""
partition include: split + replicate
"""
src_var = src_block.var(src_varname)
if src_var.type == core.VarDesc.VarType.READER:
dst_block.create_var(
type=src_var.type,
name=dst_varname,
persistable=True,
stop_gradient=True)
else:
dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
src_var)
target_shape = _get_dist_shape(src_var, dist_attr)
if isinstance(src_var, Parameter):
_partition_parameter(auto_paralle_context, src_var, dst_block,
dst_varname, target_shape)
else:
_partition_intermediate_var(auto_paralle_context, src_var,
dst_block, dst_varname, target_shape)
def _insert_src_op(src_op, dst_block, varname_mapping):
new_op_desc = dst_block.desc.append_op()
new_op_desc.copy_from(src_op.desc)
for local_varname in src_op.desc.input_arg_names():
new_op_desc._rename_input(local_varname, varname_mapping[local_varname])
for local_varname in src_op.desc.output_arg_names():
new_op_desc._rename_output(local_varname,
varname_mapping[local_varname])
dst_block._sync_with_cpp()
def _insert_dist_op(src_op, dst_block, varname_mapping, auto_paralle_context,
rank_id):
# build input varname mapping
input_mapping = {}
for input_name in src_op.desc.input_names():
varnames = []
for varname in src_op.desc.input(input_name):
varnames.append(varname_mapping[varname])
input_mapping[input_name] = varnames
# build output varname mapping
output_mapping = {}
for output_name in src_op.desc.output_names():
varnames = []
for varname in src_op.desc.output(output_name):
varnames.append(varname_mapping[varname])
output_mapping[output_name] = varnames
# append dist op
dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(src_op)
dist_ops = get_distributed_operator(src_op.type)
append_op_handle = dist_ops.get_impl(dist_attr.get_impl_idx()).forward(
src_op)
append_op_handle(
dst_block,
src_op,
dist_attr,
input_mapping,
output_mapping,
rank_id=rank_id)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import paddle
import paddle.fluid.core as core
from ..collective import _get_global_env
from ..collective import _new_ring_id
from ...fluid.framework import in_dygraph_mode
from ...fluid.layers.tensor import fill_constant
LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = None
PROCESSOR_TO_PHYSICAL_PROCESS_MAP = None
def get_all_logical_process_set():
from .interface import _g_process_mesh_map
all_logical_process_set = set(_g_process_mesh_map[0].process_group)
return all_logical_process_set
def get_logical_process_to_physical_process_map():
global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
return LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
def set_logical_process_to_physical_process_map(mapping):
global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = mapping
def get_processor_to_physical_process_map():
global PROCESSOR_TO_PHYSICAL_PROCESS_MAP
return PROCESSOR_TO_PHYSICAL_PROCESS_MAP
def set_processor_to_physical_process_map(mapping):
global PROCESSOR_TO_PHYSICAL_PROCESS_MAP
PROCESSOR_TO_PHYSICAL_PROCESS_MAP = mapping
PROCESS_GROUP_MAP = {}
def get_all_process_groups():
global PROCESS_GROUP_MAP
return PROCESS_GROUP_MAP.values()
def new_process_group(ranks):
global PROCESS_GROUP_MAP
if not PROCESS_GROUP_MAP:
genv = _get_global_env()
PROCESS_GROUP_MAP["global_group"] = ProcessGroup(
0, list(range(genv.world_size)))
# A key constructed from ranks is used in the global process group map
key = ''.join(map(str, sorted(ranks)))
if key not in PROCESS_GROUP_MAP:
num_groups = len(PROCESS_GROUP_MAP)
# Note: our process group may interfere with the original implementation
# so the created group id should start from the original _new_ring_id()
group_id = _new_ring_id() + num_groups + 1
pg = ProcessGroup(group_id, ranks)
PROCESS_GROUP_MAP[key] = pg
return pg
else:
pg = PROCESS_GROUP_MAP[key]
return pg
# This implementation refers to lots of Paddle/python/paddle/distributed/collective.py,
# Fleet also has a collective helper which uses ops to initialize communication in
# Paddle/python/paddle/distributed/fleet/meta_optimizers/common.py. We use the first one
# because it seems simple. This should be enhanced to manage the process membership and
# the instantiation process in a more general way. In the future, the process group may
# handle the communication implementation choice.
class ProcessGroup:
def __init__(self, group_id, ranks):
self._group_id = group_id
self._ranks = sorted(ranks)
self._nranks = len(self._ranks)
self._is_instantiate = False
@property
def id(self):
return self._group_id
# @property
# def key(self):
# return ''.join(map(str, sorted(self._ranks)))
def local_rank(self, global_rank):
if global_rank in self._ranks:
return self._ranks.index(global_rank)
else:
assert False, \
"Rank {} doesn't belong to this group".format(global_rank)
def is_instantiate(self):
return self._is_instantiate
def instantiate(self):
if self._is_instantiate:
return
ring_id = self.id
genv = _get_global_env()
global_rank = genv.rank
if self._nranks >= 2:
strategy = core.ParallelStrategy()
strategy.nranks = self._nranks
strategy.local_rank = self.local_rank(global_rank)
strategy.trainer_endpoints = [
genv.trainer_endpoints[i] for i in self._ranks
]
strategy.current_endpoint = genv.current_endpoint
strategy.nrings = 1
if core.is_compiled_with_cuda():
place = core.CUDAPlace(genv.device_id)
core.NCCLParallelContext(strategy,
place).init_with_ring_id(ring_id)
else:
assert False, ("No CUDA device found")
# TODO(shenliang03): This is a temporary solution to solve the problem of
# hang caused by cross-creation of new_group
tmp = paddle.to_tensor(
[1], dtype="int32") if in_dygraph_mode() else fill_constant(
[0], dtype="int32", value="1")
paddle.distributed.all_reduce(tmp, use_calc_stream=True)
paddle.distributed.wait(tmp)
self._is_instantiate = True
def __str__(self):
string = "id: {}, nranks: {}, ranks: {}.".format(
self.id, self._nranks, ", ".join(map(str, self._ranks)))
return string
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import threading import threading
import paddle.fluid.core as core import paddle.fluid.core as core
import numpy as np
def is_valid_list_index(list, index): def is_valid_list_index(list, index):
...@@ -155,3 +156,125 @@ def print_program_with_distributed_attr(program, dist_context=None): ...@@ -155,3 +156,125 @@ def print_program_with_distributed_attr(program, dist_context=None):
print(program) print(program)
set_default_distributed_context(original_default_context) set_default_distributed_context(original_default_context)
lock.release() lock.release()
def _get_comm_group(processes, shape, axis, rank):
"""
Given a rank and the processes mesh the rank belongs to,
compute the communication peers of the rank based on the give axis in the mesh.
Example: 16 processes managed in a 4-Dimensinal mesh with shape of [2, 2, 2, 2].
the rank communication peers of rank 0 (included) are following:
in axis 0: [0, 1]
in axis 1: [0, 2]
in axis 2: [0, 4]
in axis 3: [0, 8]
"""
# NOTE _linear_idx2coordinate assume processes mesh start with 0 and continuous
# tricks to support processes mesh when it is not start with 0 or continuous
rank_relatvie = processes.index(rank)
coordinate = _linear_idx2coordinate(shape, rank_relatvie)
coordinates_in_group = [coordinate[:] for i in range(shape[axis])]
# select comm group
for i in range(shape[axis]):
coordinates_in_group[i][axis] = i
ranks_in_group_relative = [
_coordinate2linear_idx(shape, coordinate)
for coordinate in coordinates_in_group
]
ranks_in_group = [processes[idx] for idx in ranks_in_group_relative]
return sorted(ranks_in_group)
def _coordinate2linear_idx(mesh_shape, coordinate):
"""
convert a coordinate in multidimensional mesh space into a scala idx in linear space.
it use Row-major order for dimension conversion.
so it has: [most_significant_dim, ..., least_significant_dim]
assume:
the size of i-th dimension to be: S[i]
the index of j-th dimension is: I[j]
linear_idx of a n dimensional coordinate is:
I[n-1] * (S[n-2] * S[n-3] * S[n-4] * .... S[0]) +
I[n-2] * ( S[n-3] * S[n-4] * .... S[0]) +
I[n-3] * ( S[n-4] * .... S[0]) +
...
I[1] * ( S[0]) +
I[0]
"""
# NOTE the following function work based on a strong an assumption
# that the processes in mesh are
# 1. starts from 0
# 2. continuous
# it will be wrong if ths above condition doesnot meet,
# e.g. process_mesh = { process_groups = [7, 8, 9,10, 12, 13, 14, 15], mesh = [2, 4]}
# if you want a more general mapping, you should use cartesian product
assert len(mesh_shape) == len(
coordinate
), "coordinate should have the same size as mesh shape, but got shape: {}, coordinate: {}".format(
mesh_shape, coordinate)
for i in range(len(mesh_shape)):
assert coordinate[
i] >= 0, "index in dimension [{}] is least than zero. coordinate: {}".format(
i, coordinate)
assert coordinate[i] < mesh_shape[
i], "index beyond extent in dimension [{}]. shape: {}, coordinate: {}".format(
i, mesh_shape, coordinate)
base = mesh_shape[-1]
linear_idx = coordinate[-1]
# row major order
for i in range(len(mesh_shape) - 2, -1, -1):
linear_idx += base * coordinate[i]
base *= mesh_shape[i]
return linear_idx
def _linear_idx2coordinate(mesh_shape, linear_idx):
"""
mapping a linear scala into multidimensional mesh space, return it coordinate in that space.
it is the inverse function of _coordinate2linear_idx.
assume:
the size of i-th dimension to be: S[i]
the index of j-th dimension is: I[j]
the coordinate given linear_idx is:
I[0] = linear_idx % S[0]
I[0] = (linear_idx / S[0]) % S[1]
I[0] = (linear_idx / (S[0] * S[1])) % S[2]
....
"""
assert linear_idx >= 0, "linear index [{}] is least than zero".format(
linear_idx)
assert linear_idx < np.prod(
mesh_shape
), "linear index beyond the extent of mesh shape. shape: {}, linear index: {}".format(
mesh_shape, linear_idx)
base = 1
coordinate = [-1] * len(mesh_shape)
for i in reversed(range(len(mesh_shape))):
offset = linear_idx / base
coordinate[i] = int(offset % mesh_shape[i])
base *= mesh_shape[i]
# row major order
return coordinate
...@@ -79,6 +79,8 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer_base) ...@@ -79,6 +79,8 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer_base)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_distributed_strategy) list(APPEND MIXED_DIST_TEST_OPS test_fleet_distributed_strategy)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto) list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_static_mp_layers) list(APPEND MIXED_DIST_TEST_OPS test_fleet_static_mp_layers)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner_gpt)
foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
list(REMOVE_ITEM TEST_OPS ${TEST_OP}) list(REMOVE_ITEM TEST_OPS ${TEST_OP})
endforeach() endforeach()
...@@ -206,6 +208,8 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) ...@@ -206,6 +208,8 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
LIST(REMOVE_ITEM TEST_OPS test_dygraph_recompute) LIST(REMOVE_ITEM TEST_OPS test_dygraph_recompute)
list(REMOVE_ITEM TEST_OPS test_parallel_class_center_sample) list(REMOVE_ITEM TEST_OPS test_parallel_class_center_sample)
LIST(REMOVE_ITEM TEST_OPS test_parallel_margin_cross_entropy) LIST(REMOVE_ITEM TEST_OPS test_parallel_margin_cross_entropy)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt)
elseif(WITH_GPU) elseif(WITH_GPU)
if (${CUDNN_VERSION} VERSION_LESS 7100) if (${CUDNN_VERSION} VERSION_LESS 7100)
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op) LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import unittest.mock
from io import StringIO
import numpy as np
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.tensor as tensor
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.auto_parallel.context import set_default_distributed_context
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import _get_comm_group
from paddle.distributed.auto_parallel.process import new_process_group
paddle.enable_static()
_global_parallel_stratergy = None
_global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
def get_programs(annotated_func):
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
global _global_process_mesh
dist_context.set_process_mesh(_global_process_mesh)
train_program, start_program = annotated_func(train_program, start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
rank_id = 3
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog = partitioner.transpile_forward(
complete_train_program, start_program)
return complete_train_program, start_program, test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog, dist_context
def is_all_parameters_shape_equal(prog1, prog2):
params1 = prog1.all_parameters()
params2 = prog2.all_parameters()
params1.sort(key=lambda x: x.name)
params2.sort(key=lambda x: x.name)
shape1 = [tensor.shape for tensor in params1]
shape2 = [tensor.shape for tensor in params2]
if len(shape1) != len(shape2):
return False
for i in range(len(shape1)):
if shape1[i] != shape2[i]:
return False
return True
def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit):
for i in range(len(varnames1)):
var1 = prog1.global_block().var(varnames1[i])
var2 = prog2.global_block().var(varnames2[i])
if var1.shape[axis] != (var2.shape[axis] // nsplit):
return False
return True
def initialization_check(mode, dist_context, dist_startup_prog,
serial_startup_prog, var_need_broadcast):
if 'mp' in mode:
mp_parallel_axis, process_mesh = dist_context._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology, mp_parallel_axis,
3)
mp_ring_id = new_process_group(group_ranks).id
broadcast_ops = [
op for op in dist_startup_prog.global_block().ops
if (op.type == "c_broadcast" and op.desc.attr("ring_id") ==
mp_ring_id)
]
broadcast_varnames = sorted(
[op.desc.output_arg_names()[0] for op in broadcast_ops])
if broadcast_varnames != var_need_broadcast:
return False
if 'dp' in mode:
dp_parallel_axis, process_mesh = dist_context._get_data_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology, dp_parallel_axis,
3)
dp_ring_id = new_process_group(group_ranks).id
nparam = len(serial_startup_prog.all_parameters())
nbroadcast_dp = len([
op for op in dist_startup_prog.global_block().ops
if (op.type == "c_broadcast" and op.desc.attr("ring_id") ==
dp_ring_id)
])
if nparam != nbroadcast_dp:
return False
if "dp" in mode and 'mp' in mode:
nbroadcast = len([
op for op in dist_startup_prog.global_block().ops
if op.type == "c_broadcast"
])
if len(var_need_broadcast) + nbroadcast_dp != nbroadcast:
return False
return True
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1])
else:
auto.shard_tensor(
self.linear0.weight, _global_process_mesh,
dim_mapping=[-1, -1])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh,
dim_mapping=[-1, -1])
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.dropout(out)
return out
def mlp_pretrain_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input = static.data(
name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
if _global_parallel_stratergy == "dp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
out = mlp(input)
return train_program, start_program
class TestMLPAutoPartitioner(unittest.TestCase):
def test_mlp_dp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
mlp_pretrain_forward)
# parameter should not be partitioned
self.assertTrue(
is_all_parameters_shape_equal(serial_main_prog, dist_main_prog))
self.assertTrue(
is_all_parameters_shape_equal(serial_startup_prog,
dist_startup_prog))
# op in main prog should be the same
serial_ops = serial_main_prog.global_block().ops
dist_ops = dist_main_prog.global_block().ops
serial_ops = [op.type for op in serial_ops]
dist_ops = [op.type for op in dist_ops]
self.assertTrue(serial_ops == dist_ops)
# parameter initialization
var_need_broadcast = []
self.assertTrue(
initialization_check(_global_parallel_stratergy, dist_context,
dist_startup_prog, serial_startup_prog,
var_need_broadcast))
def test_mlp_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
mlp_pretrain_forward)
# param should be partition
nrank = 4
# col parallel
weights = ['linear_0.w_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 1, nrank))
weights = ['linear_0.b_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, nrank))
# row parallel
weights = ['linear_1.w_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, nrank))
weights = ['linear_1.b_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, 1))
# row and col allreduce
dist_ops = dist_main_prog.global_block().ops
dist_ops = [op.type for op in dist_ops]
ref_ops = [
'layer_norm', 'c_identity', 'matmul', 'elementwise_add', 'gelu',
'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout'
]
self.assertTrue(dist_ops == ref_ops)
# parameter initialization
var_need_broadcast = sorted(
['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0'])
self.assertTrue(
initialization_check(_global_parallel_stratergy, dist_context,
dist_startup_prog, serial_startup_prog,
var_need_broadcast))
def test_mlp_dp_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
mlp_pretrain_forward)
# param should be partition
nrank = 4
# col parallel
weights = ['linear_0.w_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 1, nrank))
weights = ['linear_0.b_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, nrank))
# row parallel
weights = ['linear_1.w_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, nrank))
weights = ['linear_1.b_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, 1))
# row and col allreduce
dist_ops = dist_main_prog.global_block().ops
dist_ops = [op.type for op in dist_ops]
ref_ops = [
'layer_norm', 'c_identity', 'matmul', 'elementwise_add', 'gelu',
'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout'
]
self.assertTrue(dist_ops == ref_ops)
# parameter initialization
var_need_broadcast = sorted(
['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0'])
self.assertTrue(
initialization_check(_global_parallel_stratergy, dist_context,
dist_startup_prog, serial_startup_prog,
var_need_broadcast))
class AttentionLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
sequence_len=512,
intermediate_size=4 * 1024,
num_heads=16,
dropout_ratio=0.1,
initializer_range=0.02):
super(AttentionLayer, self).__init__()
self.hidden_size = hidden_size
self.sequence_len = sequence_len
self.embed_dim = self.hidden_size
self.kdim = self.embed_dim
self.vdim = self.embed_dim
self.num_heads = num_heads
self.head_dim = self.embed_dim // self.num_heads
assert self.head_dim * self.num_heads == self.embed_dim, \
"embed_dim must be divisible by num_heads"
self.dropout_ratio = dropout_ratio
self.initializer_range = initializer_range
self.training = True
self.attn_mask = None
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.q_proj = nn.Linear(
self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.k_proj = nn.Linear(
self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.v_proj = nn.Linear(
self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.out_proj = nn.Linear(
self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
def forward(self, input):
if _global_parallel_stratergy == "dp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
q = self.q_proj(input)
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
k = self.k_proj(input)
v = self.v_proj(input)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
# scale dot product attention
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
if self.attn_mask is not None:
product = product + self.attn_mask
weights = F.softmax(product)
if self.dropout_ratio:
weights = F.dropout(
weights,
self.dropout_ratio,
training=self.training,
mode="upscale_in_train")
out = tensor.matmul(weights, v)
# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.out_proj(out)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1])
return out
def attn_pretrain_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input = static.data(
name="query",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
attn = AttentionLayer(
hidden_size=hidden_size,
sequence_len=sequence_len,
intermediate_size=4 * hidden_size,
num_heads=16,
dropout_ratio=0.1,
initializer_range=0.02)
out = attn(input)
return train_program, start_program
class TestAttentionAutoPartitioner(unittest.TestCase):
def test_attn_dp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
attn_pretrain_forward)
# parameter should not be partitioned
self.assertTrue(
is_all_parameters_shape_equal(serial_main_prog, dist_main_prog))
self.assertTrue(
is_all_parameters_shape_equal(serial_startup_prog,
dist_startup_prog))
# op in main prog should be the same
serial_ops = serial_main_prog.global_block().ops
dist_ops = dist_main_prog.global_block().ops
serial_ops = [op.type for op in serial_ops]
dist_ops = [op.type for op in dist_ops]
self.assertTrue(serial_ops == dist_ops)
# parameter initialization
var_need_broadcast = []
self.assertTrue(
initialization_check(_global_parallel_stratergy, dist_context,
dist_startup_prog, serial_startup_prog,
var_need_broadcast))
def test_attn_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
attn_pretrain_forward)
# param should be partition
nrank = 4
# col parallel
weights = ['linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 1, nrank))
weights = ['linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, nrank))
# row parallel
weights = ['linear_3.w_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, nrank))
weights = ['linear_3.b_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, 1))
# row and col allreduce
dist_ops = dist_main_prog.global_block().ops
dist_ops = [op.type for op in dist_ops]
ref_ops = [
'c_identity', 'matmul', 'elementwise_add', 'reshape2', 'transpose2',
'c_identity', 'matmul', 'elementwise_add', 'c_identity', 'matmul',
'elementwise_add', 'reshape2', 'transpose2', 'reshape2',
'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2',
'transpose2', 'reshape2', 'matmul', 'c_allreduce_sum',
'elementwise_add'
]
self.assertTrue(dist_ops == ref_ops)
# parameter initialization
var_need_broadcast = ['linear_3.b_0']
self.assertTrue(
initialization_check(_global_parallel_stratergy, dist_context,
dist_startup_prog, serial_startup_prog,
var_need_broadcast))
def test_attn_dp_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
attn_pretrain_forward)
# param should be partition
nrank = 4
# col parallel
weights = ['linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 1, nrank))
weights = ['linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, nrank))
# row parallel
weights = ['linear_3.w_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, nrank))
weights = ['linear_3.b_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, 1))
# row and col allreduce
dist_ops = dist_main_prog.global_block().ops
dist_ops = [op.type for op in dist_ops]
ref_ops = [
'c_identity', 'matmul', 'elementwise_add', 'reshape2', 'transpose2',
'c_identity', 'matmul', 'elementwise_add', 'c_identity', 'matmul',
'elementwise_add', 'reshape2', 'transpose2', 'reshape2',
'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2',
'transpose2', 'reshape2', 'matmul', 'c_allreduce_sum',
'elementwise_add'
]
self.assertTrue(dist_ops == ref_ops)
# parameter initialization
var_need_broadcast = ['linear_3.b_0']
self.assertTrue(
initialization_check(_global_parallel_stratergy, dist_context,
dist_startup_prog, serial_startup_prog,
var_need_broadcast))
class DecoderLayer(nn.Layer):
def __init__(self,
vocab_size=32768,
hidden_size=1024,
sequence_len=512,
max_position_embeddings=512,
intermediate_size=4 * 1024,
num_heads=16,
dropout_ratio=0.1,
initializer_range=0.02):
super(DecoderLayer, self).__init__()
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.max_position_embeddings = max_position_embeddings
self.sequence_len = sequence_len
self.embed_dim = self.hidden_size
self.kdim = self.embed_dim
self.vdim = self.embed_dim
self.num_heads = num_heads
self.dropout_ratio = dropout_ratio
self.initializer_range = initializer_range
self.training = True
self.attn_mask = None
self.head_dim = self.embed_dim // self.num_heads
assert self.head_dim * self.num_heads == self.embed_dim, \
"embed_dim must be divisible by num_heads"
self.word_embeddings = nn.Embedding(
self.vocab_size,
self.hidden_size,
weight_attr=paddle.ParamAttr(
name="word_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range)))
self.position_embeddings = nn.Embedding(
self.max_position_embeddings,
self.hidden_size,
weight_attr=paddle.ParamAttr(
name="pos_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range)))
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range))
bias_attr = None
self.q_proj = nn.Linear(
self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.k_proj = nn.Linear(
self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.v_proj = nn.Linear(
self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.out_proj = nn.Linear(
self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
intermediate_size = 4 * self.hidden_size
d_model = self.hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout1 = nn.Dropout(self.dropout_ratio)
self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
def forward(self, input_ids, position_ids):
if _global_parallel_stratergy == "dp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
input_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[1, -1])
embeddings = input_embeddings + position_embeddings
embeddings = self.dropout1(embeddings)
# Pre-norm
target = self.norm(embeddings)
# The following is the attention part
q = self.q_proj(target)
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
k = self.k_proj(target)
v = self.v_proj(target)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
# scale dot product attention
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
if self.attn_mask is not None:
product = product + self.attn_mask
weights = F.softmax(product)
if self.dropout_ratio:
weights = F.dropout(
weights,
self.dropout_ratio,
training=self.training,
mode="upscale_in_train")
out = tensor.matmul(weights, v)
# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.out_proj(out)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1])
else:
auto.shard_tensor(
self.out_proj.weight,
_global_process_mesh,
dim_mapping=[-1, -1])
# Add residual
residual = embeddings + self.dropout2(out)
# Pre-norm
out0 = self.norm(residual)
# The following is the MLP part
out1 = self.linear0(out0)
out2 = F.gelu(out1, approximate=True)
out3 = self.linear1(out2)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1])
# Add residual
final = residual + self.dropout3(out3)
return final
def decoder_pretrain_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input_ids = static.data(
name="input_ids", shape=[batch_size, sequence_len], dtype='int64')
position_ids = static.data(
name="position_ids",
shape=[batch_size, sequence_len],
dtype='int64')
decoder = DecoderLayer(
vocab_size=32768,
hidden_size=hidden_size,
sequence_len=sequence_len,
max_position_embeddings=512,
intermediate_size=4 * hidden_size,
num_heads=16,
dropout_ratio=0.1,
initializer_range=0.02)
out = decoder(input_ids, position_ids)
return train_program, start_program
class TestDecoderLayerPartitioner(unittest.TestCase):
def test_decoder_dp_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
decoder_pretrain_forward)
# param should be partition
nrank = 4
# col parallel
weights = [
'linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0', 'linear_4.w_0'
]
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 1, nrank))
weights = [
'linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0', 'linear_4.b_0'
]
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, nrank))
# row parallel
weights = ['word_embeddings', 'linear_3.w_0', 'linear_5.w_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, nrank))
weights = [
'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0',
'layer_norm_0.w_0', 'linear_5.b_0'
]
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, 1))
# row and col allreduce
dist_ops = dist_main_prog.global_block().ops
dist_ops = [op.type for op in dist_ops]
ref_ops = [
'c_embedding', 'c_allreduce_sum', 'lookup_table_v2',
'elementwise_add', 'dropout', 'layer_norm', 'c_identity', 'matmul',
'elementwise_add', 'reshape2', 'transpose2', 'c_identity', 'matmul',
'elementwise_add', 'c_identity', 'matmul', 'elementwise_add',
'reshape2', 'transpose2', 'reshape2', 'transpose2', 'matmul',
'softmax', 'dropout', 'matmul_v2', 'transpose2', 'reshape2',
'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout',
'elementwise_add', 'layer_norm', 'c_identity', 'matmul',
'elementwise_add', 'gelu', 'matmul', 'c_allreduce_sum',
'elementwise_add', 'dropout', 'elementwise_add'
]
self.assertTrue(dist_ops == ref_ops)
# parameter initialization
var_need_broadcast = sorted([
'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0',
'layer_norm_0.w_0', 'linear_5.b_0'
])
self.assertTrue(
initialization_check(_global_parallel_stratergy, dist_context,
dist_startup_prog, serial_startup_prog,
var_need_broadcast))
def test_decoder_noparallel(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "None"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
decoder_pretrain_forward)
# param should be partition
nrank = 1
# col parallel
weights = [
'linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0', 'linear_4.w_0'
]
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 1, nrank))
weights = [
'linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0', 'linear_4.b_0'
]
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, nrank))
# row parallel
weights = ['word_embeddings', 'linear_3.w_0', 'linear_5.w_0']
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, nrank))
weights = [
'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0',
'layer_norm_0.w_0', 'linear_5.b_0'
]
self.assertTrue(
check_tensor_split(dist_main_prog, weights, serial_main_prog,
weights, 0, 1))
# row and col allreduce
dist_ops = dist_main_prog.global_block().ops
dist_ops = [op.type for op in dist_ops]
ref_ops = [
'lookup_table_v2', 'lookup_table_v2', 'elementwise_add', 'dropout',
'layer_norm', 'matmul', 'elementwise_add', 'reshape2', 'transpose2',
'matmul', 'elementwise_add', 'matmul', 'elementwise_add',
'reshape2', 'transpose2', 'reshape2', 'transpose2', 'matmul',
'softmax', 'dropout', 'matmul_v2', 'transpose2', 'reshape2',
'matmul', 'elementwise_add', 'dropout', 'elementwise_add',
'layer_norm', 'matmul', 'elementwise_add', 'gelu', 'matmul',
'elementwise_add', 'dropout', 'elementwise_add'
]
self.assertTrue(dist_ops == ref_ops)
dist_ops = dist_startup_prog.global_block().ops
dist_ops = [op.type for op in dist_ops]
ref_ops = [
'gaussian_random', 'gaussian_random', 'gaussian_random',
'fill_constant', 'gaussian_random', 'fill_constant',
'gaussian_random', 'fill_constant', 'gaussian_random',
'fill_constant', 'gaussian_random', 'fill_constant',
'gaussian_random', 'fill_constant', 'fill_constant', 'fill_constant'
]
self.assertTrue(dist_ops == ref_ops)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import collections
import math
import unittest
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.tensor as tensor
import paddle.utils as utils
from paddle.fluid import layers
from paddle.fluid.framework import in_dygraph_mode
from paddle.nn.layer.transformer import _convert_param_attr_to_list
from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.distributed import fleet
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import _get_comm_group
from paddle.distributed.auto_parallel.process import new_process_group
paddle.enable_static()
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
_global_parallel_stratergy = None
_global_process_mesh = None
def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit):
for i in range(len(varnames1)):
var1 = prog1.global_block().var(varnames1[i] + '@GRAD')
var2 = prog2.global_block().var(varnames2[i])
if var1.shape[axis] != (var2.shape[axis] // nsplit):
return False
return True
class MultiHeadAttention(nn.Layer):
"""
Attention mapps queries and a set of key-value pairs to outputs, and
Multi-Head Attention performs multiple parallel attention to jointly attending
to information from different representation subspaces.
"""
Cache = collections.namedtuple("Cache", ["k", "v"])
StaticCache = collections.namedtuple("StaticCache", ["k", "v"])
def __init__(self,
embed_dim,
num_heads,
dropout=0.,
kdim=None,
vdim=None,
need_weights=False,
weight_attr=None,
bias_attr=None,
topo=None,
fuse=False):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.need_weights = need_weights
self.fuse = fuse
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if topo is None or topo.mp_info.size == 1:
if self.fuse:
assert self.kdim == embed_dim
assert self.vdim == embed_dim
self.qkv_proj = nn.Linear(
embed_dim, 3 * embed_dim, weight_attr, bias_attr=bias_attr)
else:
self.q_proj = nn.Linear(
embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
self.k_proj = nn.Linear(
self.kdim, embed_dim, weight_attr, bias_attr=bias_attr)
self.v_proj = nn.Linear(
self.vdim, embed_dim, weight_attr, bias_attr=bias_attr)
self.out_proj = nn.Linear(
embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
def _fuse_prepare_qkv(self, query):
mix_layer = self.qkv_proj(query)
mix_layer = paddle.reshape_(mix_layer,
[0, 0, self.num_heads, 3 * self.head_dim])
mix_layer = paddle.transpose(mix_layer, [0, 2, 1, 3])
q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1)
return q, k, v
def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
r"""
Prapares linear projected queries, keys and values for usage of subsequnt
multiple parallel attention. If `cache` is not None, using cached results
to reduce redundant calculations.
"""
q = self.q_proj(query)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
if isinstance(cache, self.StaticCache):
# for encoder-decoder attention in inference and has cached
k, v = cache.k, cache.v
else:
k, v = self.compute_kv(key, value)
if isinstance(cache, self.Cache):
# for decoder self-attention in inference
k = tensor.concat([cache.k, k], axis=2)
v = tensor.concat([cache.v, v], axis=2)
if use_cache is True:
cache = self.Cache(k, v)
return (q, k, v) if use_cache is False else (q, k, v, cache)
def compute_kv(self, key, value):
r"""
Applies linear projection on input keys and values, then splits heads
(reshape and transpose) to get keys and values from different representation
subspaces. The results are used as key-values pairs for subsequent multiple
parallel attention.
It is part of calculations in multi-head attention, and is provided as
a method to pre-compute and prefetch these results, thus we can use them
to construct cache for inference.
"""
k = self.k_proj(key)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
v = self.v_proj(value)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
return k, v
def gen_cache(self, key, value=None, type=Cache):
"""
Generates cache for `forward` usage in inference accroding to arguments.
The generated cache is an instance of `MultiHeadAttention.Cache` or an
instance of `MultiHeadAttention.StaticCache`.
"""
if type == MultiHeadAttention.StaticCache: # static_kv
k, v = self.compute_kv(key, value)
return self.StaticCache(k, v)
elif value is None: # incremental_state
k = layers.fill_constant_batch_size_like(
input=key,
shape=[-1, self.num_heads, 0, self.head_dim],
dtype=key.dtype,
value=0)
v = layers.fill_constant_batch_size_like(
input=key,
shape=[-1, self.num_heads, 0, self.head_dim],
dtype=key.dtype,
value=0)
return self.Cache(k, v)
else:
# incremental_state with initial value, mainly for usage like UniLM
return self.Cache(key, value)
def forward(self,
query,
key,
value,
attn_mask=None,
use_cache=False,
cache=None):
r"""
Applies multi-head attention to map queries and a set of key-value pairs
to outputs.
"""
key = query if key is None else key
value = query if value is None else value
# compute q ,k ,v
if use_cache is False:
if self.fuse:
q, k, v = self._fuse_prepare_qkv(query)
else:
q, k, v = self._prepare_qkv(query, key, value, use_cache, cache)
else:
q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
cache)
# scale dot product attention
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
if attn_mask is not None:
product = product + attn_mask
weights = F.softmax(product)
if self.dropout:
weights = F.dropout(
weights,
self.dropout,
training=self.training,
mode="upscale_in_train")
out = tensor.matmul(weights, v)
# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.out_proj(out)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1])
outs = [out]
if self.need_weights:
outs.append(weights)
if use_cache:
outs.append(cache)
return out if len(outs) == 1 else tuple(outs)
class TransformerDecoder(nn.Layer):
"""
TransformerDecoder is a stack of N decoder layers.
"""
def __init__(self,
decoder_layers,
num_layers,
norm=None,
hidden_size=None,
topo=None):
super(TransformerDecoder, self).__init__()
self.topo = topo
self.num_layers = num_layers
self.layers = decoder_layers
self.norm = norm
if norm is "LayerNorm":
self.norm = nn.LayerNorm(hidden_size)
elif norm is not None:
raise ValueError("Only support LayerNorm")
self.checkpoints = []
def forward(self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
use_cache=False,
cache=None):
r"""
Applies a stack of N Transformer decoder layers on inputs. If `norm` is
provided, also applies layer normalization on the output of last decoder
layer.
"""
output = tgt
new_caches = []
self.checkpoints = []
for i, mod in enumerate(self.layers):
if cache is None:
if use_cache:
output, new_cache = mod(output,
memory,
tgt_mask=tgt_mask,
use_cache=use_cache,
cache=cache)
new_caches.append(new_cache)
else:
output = mod(output,
memory,
tgt_mask=tgt_mask,
use_cache=use_cache,
cache=cache)
else:
output, new_cache = mod(output,
memory,
tgt_mask=tgt_mask,
use_cache=use_cache,
cache=cache[i])
new_caches.append(new_cache)
self.checkpoints.append(output.name)
if self.norm is not None:
output = self.norm(output)
return output if use_cache is False else (output, new_caches)
def gen_cache(self, memory, do_zip=False):
r"""
Generates cache for `forward` usage. The generated cache is a list, and
each element in it is a tuple( :code:`(incremental_cache, static_cache)` )
produced by `TransformerDecoderLayer.gen_cache`. See `TransformerDecoderLayer.gen_cache`
for more details. If `do_zip` is True, apply `zip` on these tuples to get
a list with two elements.
"""
cache = [layer.gen_cache(memory) for layer in self.layers]
if do_zip:
cache = list(zip(*cache))
return cache
class TransformerDecoderLayer(nn.Layer):
"""
The transformer decoder layer.
It contains multiheadattention and some linear layers.
"""
def __init__(self,
d_model,
nhead,
dim_feedforward,
dropout=0.1,
activation="gelu",
attn_dropout=None,
act_dropout=None,
normalize_before=True,
weight_attr=None,
bias_attr=None,
topo=None):
self._config = locals()
self._config.pop("self")
self._config.pop("__class__", None) # py3
super(TransformerDecoderLayer, self).__init__()
attn_dropout = dropout if attn_dropout is None else attn_dropout
act_dropout = dropout if act_dropout is None else act_dropout
self.normalize_before = normalize_before
weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
self.self_attn = MultiHeadAttention(
d_model,
nhead,
dropout=attn_dropout,
weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0],
topo=topo)
if topo is None or topo.mp_info.size == 1:
self.linear1 = nn.Linear(
d_model,
dim_feedforward,
weight_attrs[2],
bias_attr=bias_attrs[2])
self.linear2 = nn.Linear(
dim_feedforward,
d_model,
weight_attrs[2],
bias_attr=bias_attrs[2])
self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
self.dropout2 = nn.Dropout(act_dropout, mode="upscale_in_train")
self.activation = getattr(F, activation)
def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
if use_cache is False:
tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
else:
tgt, incremental_cache = self.self_attn(tgt, tgt, tgt, tgt_mask,
use_cache, cache)
tgt = residual + self.dropout1(tgt)
if not self.normalize_before:
tgt = self.norm1(tgt)
residual = tgt
if self.normalize_before:
tgt = self.norm2(tgt)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1])
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1])
# tgt = self.dropout2(
# self.linear2(F.gelu(
# self.linear1(tgt), approximate=True)))
tgt = self.linear1(tgt)
tgt = F.gelu(tgt, approximate=True)
tgt = self.dropout2(self.linear2(tgt))
tgt = residual + tgt
if not self.normalize_before:
tgt = self.norm2(tgt)
return tgt if use_cache is False else (tgt, incremental_cache)
def gen_cache(self, memory):
incremental_cache = self.self_attn.gen_cache(
memory, type=self.self_attn.Cache)
return incremental_cache
class GPTEmbeddings(nn.Layer):
"""
Include embeddings from word, position and token_type embeddings
"""
def __init__(self,
vocab_size,
hidden_size=768,
hidden_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
topo=None):
super(GPTEmbeddings, self).__init__()
if topo is None or topo.mp_info.size == 1:
self.word_embeddings = nn.Embedding(
vocab_size,
hidden_size,
weight_attr=paddle.ParamAttr(
name="word_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
self.position_embeddings = nn.Embedding(
max_position_embeddings,
hidden_size,
weight_attr=paddle.ParamAttr(
name="pos_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
self.dropout = nn.Dropout(hidden_dropout_prob)
def forward(self, input_ids, position_ids=None):
if position_ids is None:
ones = paddle.ones_like(input_ids, dtype="int64")
seq_length = paddle.cumsum(ones, axis=-1)
position_ids = seq_length - ones
input_embedings = self.word_embeddings(input_ids)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[1, -1])
position_embeddings = self.position_embeddings(position_ids)
embeddings = input_embedings + position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
class GPTModel(nn.Layer):
"""
The base model of gpt.
"""
def __init__(self,
vocab_size,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
pad_token_id=0,
topo=None):
super(GPTModel, self).__init__()
self.pad_token_id = pad_token_id
self.initializer_range = initializer_range
self.topo = topo
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.pipline_mode = topo is not None and topo.pp_info.size > 1
if self.pipline_mode:
self.layer_per_stage = num_hidden_layers // self.topo.pp_info.size
self.embeddings = GPTEmbeddings(
vocab_size, hidden_size, hidden_dropout_prob,
max_position_embeddings, type_vocab_size, self.initializer_range,
topo)
decoder_layers = nn.LayerList()
for i in range(num_hidden_layers):
DecoderLayer = TransformerDecoderLayer
decoder_layers.append(
DecoderLayer(
d_model=hidden_size,
nhead=num_attention_heads,
dim_feedforward=intermediate_size,
dropout=hidden_dropout_prob,
activation=hidden_act,
attn_dropout=attention_probs_dropout_prob,
act_dropout=hidden_dropout_prob,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range)),
bias_attr=None,
topo=topo))
Decoder = TransformerDecoder
self.decoder = Decoder(
decoder_layers,
num_hidden_layers,
norm="LayerNorm",
hidden_size=hidden_size,
topo=topo)
self.checkpoints = []
def forward(self,
input_ids,
position_ids=None,
attention_mask=None,
use_cache=False,
cache=None):
self.checkpoints = []
if attention_mask is None:
length = paddle.shape(input_ids)[1]
# Use bool mask
attention_mask = paddle.tensor.tril(
paddle.ones(
(length, length),
dtype=self.embeddings.word_embeddings.weight.dtype))
if position_ids is None:
past_length = 0
if cache is not None:
past_length = paddle.shape(cache[0].k)[-2]
position_ids = paddle.arange(
past_length,
paddle.shape(input_ids)[-1] + past_length,
dtype='int64')
position_ids = position_ids.unsqueeze(0)
# .expand_as(input_ids)
position_ids = paddle.fluid.layers.expand_as(position_ids,
input_ids)
embedding_output = self.embeddings(
input_ids=input_ids, position_ids=position_ids)
# TODO, use registered buffer
causal_mask = paddle.tensor.triu(
paddle.ones((paddle.shape(input_ids)[-1],
paddle.shape(input_ids)[-1])) * -1e9,
diagonal=1)
if attention_mask is not None:
attention_mask = attention_mask + causal_mask
else:
attention_mask = causal_mask
# The tensor returned by triu not in static graph.
attention_mask.stop_gradient = True
encoder_outputs = self.decoder(
embedding_output,
memory=None,
tgt_mask=attention_mask,
use_cache=use_cache,
cache=cache)
self.checkpoints.extend(self.decoder.checkpoints)
return encoder_outputs
class GPTForPretraining(nn.Layer):
"""
The pretraining model of GPT.
It returns some logits and cached_kvs.
"""
def __init__(self, gpt):
super(GPTForPretraining, self).__init__()
self.gpt = gpt
self.share_param = False
self.weight = self.gpt.embeddings.word_embeddings.weight
if not self.share_param:
self.weight = self.create_parameter(shape=self.weight.shape)
def parallel_matmul(self, lm_output, logit_weights, parallel_output, topo):
if topo is not None and topo.mp_info.size > 1:
input_parallel = paddle.distributed.collective._c_identity(
lm_output, group=None)
logits = paddle.matmul(
input_parallel, logit_weights, transpose_y=True)
if parallel_output:
return logits
return paddle.distributed.collective._c_concat(logits, group=None)
else:
logits = paddle.matmul(lm_output, logit_weights, transpose_y=True)
return logits
def forward(self,
input_ids,
position_ids=None,
attention_mask=None,
masked_positions=None,
use_cache=False,
cache=None):
outputs = self.gpt(input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
use_cache=use_cache,
cache=cache)
if use_cache:
encoder_outputs, cached_kvs = outputs[:2]
else:
encoder_outputs = outputs
logits = self.parallel_matmul(encoder_outputs, self.weight, True,
self.gpt.topo)
if use_cache:
return logits, cached_kvs
else:
return logits
class GPTPretrainingCriterion(nn.Layer):
"""
Criterion for GPT.
It calculates the final loss.
"""
def __init__(self, topo=None):
super(GPTPretrainingCriterion, self).__init__()
if topo is None or topo.mp_info.size == 1:
self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none")
else:
self.loss_func = paddle.distributed.collective._c_softmax_with_cross_entropy
def forward(self, prediction_scores, masked_lm_labels, loss_mask):
masked_lm_loss = self.loss_func(prediction_scores,
masked_lm_labels.unsqueeze(2))
loss_mask = loss_mask.reshape([-1])
masked_lm_loss = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask)
loss = masked_lm_loss / loss_mask.sum()
return loss
def gpt_pretrain_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 16
sequence_len = 512
input_ids = static.data(
name="input_ids", shape=[batch_size, sequence_len], dtype='int64')
position_ids = static.data(
name="position_ids",
shape=[batch_size, sequence_len],
dtype='int64')
attention_mask = static.data(
name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float64')
labels = static.data(
name="labels", shape=[batch_size, sequence_len], dtype='int64')
loss_mask = static.data(
name="loss_mask", shape=[batch_size, sequence_len], dtype='float64')
if _global_parallel_stratergy == "dp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
gpt = GPTModel(
vocab_size=32768,
hidden_size=768,
num_hidden_layers=2,
num_attention_heads=12,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=1024,
type_vocab_size=16,
initializer_range=0.02,
pad_token_id=0,
topo=None)
model = GPTForPretraining(gpt)
preds = model(input_ids, position_ids, attention_mask)
criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask)
return train_program, start_program, loss
class TestGPTPartitioner(unittest.TestCase):
def test_gpt_dp_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
dist_context.set_process_mesh(_global_process_mesh)
train_program, start_program, loss = gpt_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
rank_id = 3
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, start_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, start_program,
auto_parallel_main_prog, auto_parallel_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
nrank = 4
# col parallel
weights = [
'linear_0.w_0',
'linear_6.w_0',
'linear_10.w_0',
]
self.assertTrue(
check_tensor_split(auto_parallel_main_prog, weights,
complete_train_program, weights, 1, nrank))
# row parallel
weights = ['word_embeddings', 'linear_9.w_0', 'linear_11.w_0']
self.assertTrue(
check_tensor_split(auto_parallel_main_prog, weights,
complete_train_program, weights, 0, nrank))
weights = ['pos_embeddings', 'layer_norm_0.b_0', 'layer_norm_4.w_0']
self.assertTrue(
check_tensor_split(auto_parallel_main_prog, weights,
complete_train_program, weights, 0, 1))
all_params = sorted(
[param.name for param in start_program.all_parameters()])
allreduce_grads = [
'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2',
'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2',
'layer_norm_7.tmp_2', 'layer_norm_8.tmp_2'
]
mp_parallel_axis, process_mesh = dist_context._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology, mp_parallel_axis,
3)
mp_ring_id = new_process_group(group_ranks).id
dp_parallel_axis, process_mesh = dist_context._get_data_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology, dp_parallel_axis,
3)
dp_ring_id = new_process_group(group_ranks).id
tensor_parallel_allreduce_vars = sorted([
op.desc.output_arg_names()[0].split("@")[0]
for op in auto_parallel_main_prog.global_block().ops
if (op.type == "c_allreduce_sum" and op.attr('op_role') == 1 and
op.desc.attr("ring_id") == mp_ring_id)
])
data_parallel_allreduce_vars = sorted([
op.desc.output_arg_names()[0].split("@")[0]
for op in auto_parallel_main_prog.global_block().ops
if (op.type == "c_allreduce_sum" and op.desc.attr("ring_id") ==
dp_ring_id)
])
self.assertTrue(all_params == data_parallel_allreduce_vars)
self.assertTrue(allreduce_grads == tensor_parallel_allreduce_vars)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册