From a622b7017a4f5e361c8d73e5f8a0c065fdc84553 Mon Sep 17 00:00:00 2001 From: JZ-LIANG <38102074+JZ-LIANG@users.noreply.github.com> Date: Thu, 2 Sep 2021 11:01:15 +0800 Subject: [PATCH] [Auto Parallel] Logical Partition & Dist Op (#35117) * support shard reader * support shard reader * add parallel mode * update process mesh * add method to compute comm_group * implement dist_embedding forward func * implement dist matmul forward func * implement dist reshape forward func * add transpiler framework * add transpiler forward * implement transpiler forward * implement transpiler backward & update * add process * add unitest * chmod * chmod * chmod * update unitest * add unitest for gpt * remove unused print * rename transpiler --> partitioner * rename transpiler --> partitioner * chmod * chmod * bug fixed * remove amp function * update case for dp mode * update case for dp mode --- .../distributed/auto_parallel/attribute.py | 5 + .../distributed/auto_parallel/context.py | 37 + .../distributed/auto_parallel/interface.py | 14 + .../auto_parallel/operators/common.py | 2 + .../auto_parallel/operators/dist_embedding.py | 112 +++ .../auto_parallel/operators/dist_matmul.py | 179 +++- .../auto_parallel/operators/dist_reshape.py | 162 +++ .../auto_parallel/operators/dist_softmax.py | 1 - .../distributed/auto_parallel/partitioner.py | 925 +++++++++++++++++ .../distributed/auto_parallel/process.py | 149 +++ .../paddle/distributed/auto_parallel/utils.py | 123 +++ .../fluid/tests/unittests/CMakeLists.txt | 4 + .../test_auto_parallel_partitioner.py | 948 ++++++++++++++++++ .../test_auto_parallel_partitioner_gpt.py | 857 ++++++++++++++++ 14 files changed, 3515 insertions(+), 3 deletions(-) create mode 100755 python/paddle/distributed/auto_parallel/partitioner.py create mode 100644 python/paddle/distributed/auto_parallel/process.py mode change 100644 => 100755 python/paddle/distributed/auto_parallel/utils.py create mode 100755 python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py create mode 100755 python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py diff --git a/python/paddle/distributed/auto_parallel/attribute.py b/python/paddle/distributed/auto_parallel/attribute.py index 0ca1b7e9444..879e94b8373 100644 --- a/python/paddle/distributed/auto_parallel/attribute.py +++ b/python/paddle/distributed/auto_parallel/attribute.py @@ -14,6 +14,7 @@ import copy from collections import defaultdict +from paddle.fluid import core class TensorDistributedAttribute: @@ -77,6 +78,8 @@ class TensorDistributedAttribute: self._is_parameter = True def is_valid(self): + if self.get_owner_tensor().type == core.VarDesc.VarType.READER: + return True tensor_shape = self.get_owner_tensor().desc.shape() if len(tensor_shape) != len(self.get_dims_mapping()): return False @@ -222,6 +225,8 @@ class OperatorDistributedAttribute: self._is_parameters[name] = True def is_valid(self): + if "read" in self.get_owner_op().type: + return True for name in self.get_owner_op().desc.input_arg_names(): dims_mapping = self.get_input_dims_mapping(name) shape = self.get_input_shape(name) diff --git a/python/paddle/distributed/auto_parallel/context.py b/python/paddle/distributed/auto_parallel/context.py index ff2adc7eacf..bddf9368255 100644 --- a/python/paddle/distributed/auto_parallel/context.py +++ b/python/paddle/distributed/auto_parallel/context.py @@ -15,9 +15,11 @@ import copy from collections import defaultdict from paddle.fluid import framework +from paddle.fluid import core from .attribute import 
TensorDistributedAttribute from .attribute import OperatorDistributedAttribute from .utils import append_distributed_attr_suffix +from .interface import _g_process_mesh_map # There always exists a default context for user. And user can set it to another one. DEFAULT_DISTRIBUTED_CONTEXT = None @@ -49,6 +51,20 @@ class DistributedContext: self._op_distributed_attr_map_for_program = {} self._tensor_distributed_attr_map_for_graph = {} self._op_distributed_attr_map_for_graph = {} + # The following is a hard code and will be removed in the future + self._data_parallel_axis = None + self._model_parallel_axis = None + self._process_mesh = _g_process_mesh_map.get(0, None) + if self._process_mesh is not None: + if self._process_mesh.ndim == 1: + self._data_parallel_axis = 0 + self._model_parallel_axis = 0 + else: + self._data_parallel_axis = 0 + self._model_parallel_axis = 1 + else: + self._data_parallel_axis = -1 + self._model_parallel_axis = -1 def is_initialized_for_program(self): return self._is_initialized_for_program @@ -99,6 +115,19 @@ class DistributedContext: op_node_id = op_node.id() self._op_distributed_attr_map_for_graph[op_node_id] = op_dist_attr + def set_process_mesh(self, process_mesh): + self._process_mesh = process_mesh + if self._process_mesh is not None: + if self._process_mesh.ndim == 1: + self._data_parallel_axis = 0 + self._model_parallel_axis = 0 + else: + self._data_parallel_axis = 0 + self._model_parallel_axis = 1 + else: + self._data_parallel_axis = -1 + self._model_parallel_axis = -1 + def initialize_distributed_attr_for_program(self, program): if self._is_initialized_for_program: return @@ -377,3 +406,11 @@ class DistributedContext: if dims_mapping[i] != -1 and process_mesh_shape[ dims_mapping[i]] > tensor_shape[i]: dims_mapping[i] = -1 + + def _get_data_parallel_info(self): + # This function is a hard code, and will be obsoleted in the future + return self._data_parallel_axis, self._process_mesh + + def _get_model_parallel_info(self): + # This function is a hard code, and will be obsoleted in the future + return self._model_parallel_axis, self._process_mesh diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py index 4c408345f17..348edaef681 100644 --- a/python/paddle/distributed/auto_parallel/interface.py +++ b/python/paddle/distributed/auto_parallel/interface.py @@ -184,6 +184,13 @@ class ProcessMesh(object): "parent with id %d does not exist." % self._parent_id) return _g_process_mesh_map[self._parent_id] + @property + def ndim(self): + r""" + Get the number of dimension of ProcessMesh. + """ + return len(self._topology) + def set_placement(self, order): """ Set the map from logical processes to physical ones using the @@ -229,6 +236,13 @@ class ProcessMesh(object): for idx, l_id in enumerate(logical_order): _user_defined_physical_map[l_id] = order[idx] + def _reset_global_process_mesh_map(self): + """ + Remove all process mesh in _g_process_mesh_map, make it empty. 
+ """ + + _g_process_mesh_map = dict() + def __eq__(self, other): assert other and isinstance(other, ProcessMesh) if self.topology != other.topology or self.process_group != other.process_group: diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index c5e253c0e0b..ef2f5083449 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -33,6 +33,8 @@ class DistributedOperator: class DistributedOperatorImpl: def __init__(self): self._name = None + self._forward_implemented = False + self._backward_implemented = False def forward(self, dist_ctx, *args, **kwargs): raise NotImplementedError("Please Implement this method in Subclass.") diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index 81d3925bb5d..5d1cfcbf69e 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -22,6 +22,12 @@ from ..utils import is_valid_list_index from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping +from paddle.fluid import core, unique_name +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.framework import Program, Parameter, Variable, program_guard +from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from ..process import new_process_group +from ..utils import _get_comm_group class DistributedEmbedding(DistributedOperator): @@ -39,6 +45,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): def __init__(self, name): super(DistributedEmbeddingImpl, self).__init__() self._name = name + self._forward_implemented = True + self._backward_implemented = False def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. 
""" @@ -92,6 +100,110 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): return changed + def forward(self, serial_op): + def static_handle(dst_block, + src_op, + op_dist_attr, + input_name_mapping, + output_name_mapping, + rank_id=0): + assert len( + input_name_mapping + ) == 2, "row_parallel_embedding take 2 inputs variable but got {}".format( + input_name_mapping) + assert len( + output_name_mapping + ) == 1, "row_parallel_embedding take 2 inputs variable but got {}".format( + output_name_mapping) + assert len( + input_name_mapping['Ids'] + ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( + input_name_mapping['Ids']) + assert len( + input_name_mapping['W'] + ) == 1, "row_parallel_embedding input W take 1 variable but got {}".format( + input_name_mapping['W']) + assert len( + output_name_mapping['Out'] + ) == 1, "row_parallel_embedding input Out take 1 variable but got {}".format( + input_name_mapping['Out']) + + Ids_var = dst_block.var(input_name_mapping['Ids'][0]) + Weight_var = dst_block.var(input_name_mapping['W'][0]) + Out_var = dst_block.var(output_name_mapping['Out'][0]) + + # got dist attribute info + embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping( + Weight_var.name)[0] + process_mesh_shape = op_dist_attr.get_process_mesh().topology + process_mesh_group = op_dist_attr.get_process_mesh().process_group + + # caculate embedding offset + # TODO generalize here, using cartisian product to allow any dimensional mesh shape + mesh_shape = len(process_mesh_shape) + assert mesh_shape <= 2, "row_parallel_embedding only support 1 or 2 dimensional process mesh, but got {}".format( + process_mesh_shape) + num_partition = process_mesh_shape[embedding_row_dim_mapping] + # TODO generalize here, support any mesh group + if mesh_shape == 1: + relative_idx = process_mesh_group.index(rank_id) + else: + relative_idx = rank_id % num_partition + + per_part_size = Weight_var.shape[0] + relative_idx = relative_idx * per_part_size + + # TODO caculate ring id + model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( + )._get_model_parallel_info() + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, + model_parallel_axis, rank_id) + group = new_process_group(group_ranks) + + # append op + check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'], + 'c_embedding') + + intermediate_var_0 = dst_block.create_var( + name=unique_name.generate_with_ignorable_key(".".join( + ["c_embedding", 'tmp'])), + dtype=Weight_var.dtype, + shape=Out_var.shape, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=Out_var.stop_gradient) + + check_variable_and_dtype( + Out_var, 'tensor', + ['float16', 'float32', 'float64', 'int32', 'int64'], + 'c_allreduce_sum') + + dst_block.append_op( + type='c_embedding', + inputs={'Ids': [Ids_var], + 'W': [Weight_var]}, + outputs={'Out': [intermediate_var_0]}, + attrs={"start_index": relative_idx}) + + # use_model_parallel + dst_block.append_op( + type='c_allreduce_sum', + inputs={'X': [intermediate_var_0]}, + outputs={'Out': [Out_var]}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'use_model_parallel': True, + }) + + if in_dygraph_mode(): + raise NotImplementedError( + "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( + "matmul", 0)) + else: + return static_handle + register_distributed_operator_impl("lookup_table_v2", DistributedEmbeddingImpl("row_parallel")) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py 
b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index fbeb0edd418..9059feeaf85 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -22,6 +22,12 @@ from ..utils import is_valid_list_index from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping +from paddle.fluid import core, unique_name +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.framework import Program, Parameter, Variable, program_guard +from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from ..process import new_process_group +from ..utils import _get_comm_group def _update_dims_mapping_for_matmul(op_dist_attr): @@ -37,7 +43,6 @@ def _update_dims_mapping_for_matmul(op_dist_attr): y_dims_mapping_len = len(y_dims_mapping) out_dims_mapping_len = len(out_dims_mapping) - # print("before", x_dims_mapping, y_dims_mapping, out_dims_mapping) # Add dim mapping to Make sure the length dims_mapping be at least 2 if x_dims_mapping_len == 1: x_dims_mapping.insert(0, -1) @@ -109,7 +114,6 @@ def _update_dims_mapping_for_matmul(op_dist_attr): if y_dims_mapping_len == 1: y_dims_mapping.pop(1) - # print("after", x_dims_mapping, y_dims_mapping, out_dims_mapping) assert len(x_dims_mapping) == x_dims_mapping_len assert len(y_dims_mapping) == y_dims_mapping_len assert len(out_dims_mapping) == out_dims_mapping_len @@ -131,6 +135,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): def __init__(self, name): super(DistributedMatmulImpl0, self).__init__() self._name = name + self._forward_implemented = True + self._backward_implemented = False def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. 
""" @@ -170,12 +176,101 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): changed = True return changed + def forward(self, serial_op): + def static_handle(dst_block, + src_op, + op_dist_attr, + input_name_mapping, + output_name_mapping, + rank_id=0): + assert len( + input_name_mapping + ) == 2, "col_parallel_linear take 2 inputs variable but got {}".format( + input_name_mapping) + assert len( + output_name_mapping + ) == 1, "col_parallel_linear take 2 inputs variable but got {}".format( + output_name_mapping) + assert len( + input_name_mapping['X'] + ) == 1, "col_parallel_linear input X take 1 variable but got {}".format( + input_name_mapping['X']) + assert len( + input_name_mapping['Y'] + ) == 1, "col_parallel_linear input Y take 1 variable but got {}".format( + input_name_mapping['Y']) + assert len( + output_name_mapping['Out'] + ) == 1, "col_parallel_linear input Out take 1 variable but got {}".format( + input_name_mapping['Out']) + X_var = dst_block.var(input_name_mapping['X'][0]) + Weight_var = dst_block.var(input_name_mapping['Y'][0]) + Out_var = dst_block.var(output_name_mapping['Out'][0]) + + # TODO infer logic comm presentation + model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( + )._get_model_parallel_info() + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, + model_parallel_axis, rank_id) + group = new_process_group(group_ranks) + + intermediate_var_0 = dst_block.create_var( + name=unique_name.generate_with_ignorable_key(".".join( + ["c_identity", 'tmp'])), + dtype=X_var.dtype, + shape=X_var.shape, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=X_var.stop_gradient) + + check_variable_and_dtype( + X_var, 'tensor', + ['float16', 'float32', 'float64', 'int32', 'int64'], + '_c_identity') + + dst_block.append_op( + type='c_identity', + inputs={'X': [X_var]}, + outputs={'Out': intermediate_var_0}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'use_model_parallel': True, + }) + + check_variable_and_dtype(intermediate_var_0, 'x', + ['float16', 'float32', 'float64'], + 'linear') + check_dtype(intermediate_var_0.dtype, 'dtype', + ['float16', 'float32', 'float64'], 'linear') + attrs = { + 'transpose_X': False, + 'transpose_Y': False, + 'alpha': 1, + } + inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} + dst_block.append_op( + type='matmul', + inputs=inputs, + outputs={'Out': Out_var}, + attrs=attrs) + + if in_dygraph_mode(): + raise NotImplementedError( + "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( + "matmul", 0)) + else: + return static_handle + # RowParallel class DistributedMatmulImpl1(DistributedOperatorImpl): def __init__(self, name): super(DistributedMatmulImpl1, self).__init__() self._name = name + self._forward_implemented = True + self._backward_implemented = False def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. 
""" @@ -217,6 +312,86 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): changed = True return changed + def forward(self, serial_op): + def static_handle(dst_block, + src_op, + op_dist_attr, + input_name_mapping, + output_name_mapping, + rank_id=0): + assert len( + input_name_mapping + ) == 2, "col_parallel_linear take 2 inputs variable but got {}".format( + input_name_mapping) + assert len( + output_name_mapping + ) == 1, "col_parallel_linear take 2 inputs variable but got {}".format( + output_name_mapping) + assert len( + input_name_mapping['X'] + ) == 1, "col_parallel_linear input X take 1 variable but got {}".format( + input_name_mapping['X']) + assert len( + input_name_mapping['Y'] + ) == 1, "col_parallel_linear input Y take 1 variable but got {}".format( + input_name_mapping['Y']) + assert len( + output_name_mapping['Out'] + ) == 1, "col_parallel_linear input Out take 1 variable but got {}".format( + input_name_mapping['Out']) + X_var = dst_block.var(input_name_mapping['X'][0]) + Weight_var = dst_block.var(input_name_mapping['Y'][0]) + Out_var = dst_block.var(output_name_mapping['Out'][0]) + + # TODO infer logic comm presentation + model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( + )._get_model_parallel_info() + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, + model_parallel_axis, rank_id) + group = new_process_group(group_ranks) + + check_variable_and_dtype( + X_var, 'x', ['float16', 'float32', 'float64'], 'linear') + check_dtype(X_var.dtype, 'dtype', + ['float16', 'float32', 'float64'], 'linear') + attrs = { + 'transpose_X': False, + 'transpose_Y': False, + 'alpha': 1, + } + inputs = {'X': X_var, 'Y': Weight_var} + intermediate_var_0 = dst_block.create_var( + shape=Out_var.shape, + dtype=Out_var.dtype, + type=Out_var.type, + lod_level=Out_var.lod_level, + persistable=False, + is_data=False, + need_check_feed=Out_var.desc.need_check_feed()) + dst_block.append_op( + type='matmul', + inputs=inputs, + outputs={'Out': intermediate_var_0}, + attrs=attrs) + + dst_block.append_op( + type='c_allreduce_sum', + inputs={'X': intermediate_var_0}, + outputs={'Out': Out_var}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'use_model_parallel': True + }) + + if in_dygraph_mode(): + raise NotImplementedError( + "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( + "matmul", 0)) + else: + return static_handle + # ReplicateParallel class DistributedMatmulImpl2(DistributedOperatorImpl): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py index 40da0e2f609..e7fbe9cfeba 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py @@ -22,6 +22,10 @@ from ..utils import is_valid_list_index from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping +from paddle.fluid import core, unique_name +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.framework import Program, Parameter, Variable, program_guard +from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype class DistributedReshape2(DistributedOperator): @@ -37,6 +41,8 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): def __init__(self, name): super(DistributedReshapeImpl0, self).__init__() self._name = name + 
self._forward_implemented = True + self._backward_implemented = False def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. """ @@ -91,11 +97,90 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): return changed + def forward(self, serial_op): + def static_handle(dst_block, + src_op, + op_dist_attr, + input_name_mapping, + output_name_mapping, + rank_id=0): + assert len( + input_name_mapping + ) == 3, "Dist op of Reshape take 3 inputs variable but got {}".format( + input_name_mapping) + assert len( + output_name_mapping + ) == 2, "Dist op of Reshape take 2 inputs variable but got {}".format( + output_name_mapping) + assert len( + input_name_mapping['X'] + ) == 1, "Dist op of Reshape input X take 1 variable but got {}".format( + input_name_mapping['X']) + assert len( + input_name_mapping['ShapeTensor'] + ) <= 1, "Dist op of Reshape input ShapeTensor take 0 or 1 variable but got {}".format( + input_name_mapping['ShapeTensor']) + assert len( + input_name_mapping['Shape'] + ) <= 1, "Dist op of Reshape input Shape take 0 or 1 variable but got {}".format( + input_name_mapping['Shape']) + assert len( + output_name_mapping['Out'] + ) == 1, "Dist op of Reshape input Out take 1 variable but got {}".format( + input_name_mapping['Out']) + assert len( + output_name_mapping['XShape'] + ) == 1, "Dist op of Reshape input XShape take 1 variable but got {}".format( + input_name_mapping['XShape']) + + X_var = dst_block.var(input_name_mapping['X'][0]) + Out_var = dst_block.var(output_name_mapping['Out'][0]) + XShape_var = dst_block.var(output_name_mapping['XShape'][0]) + shape_list = src_op.desc.attr("shape") + ShapeTensor_var_list = [] + for name in input_name_mapping['ShapeTensor']: + ShapeTensor_var_list.append(name) + Shape_var_list = [] + for name in input_name_mapping['Shape']: + Shape_var_list.append(name) + + # got dist attribute info + dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) + process_mesh_shape = op_dist_attr.get_process_mesh().topology + + # modify target shape + for idx, axis in enumerate(dim_mapping): + if axis >= 0: + if len(shape_list) > idx: + shape_list[idx] = shape_list[idx] // process_mesh_shape[ + axis] + + # create op + new_op_desc = dst_block.desc.append_op() + new_op_desc.copy_from(src_op.desc) + new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list) + new_op_desc.set_input('Shape', Shape_var_list) + new_op_desc.set_input('X', [X_var.name]) + new_op_desc.set_output('XShape', [XShape_var.name]) + new_op_desc.set_output('Out', [Out_var.name]) + new_op_desc._set_attr('shape', shape_list) + + dst_block._sync_with_cpp() + + if in_dygraph_mode(): + raise NotImplementedError( + "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( + "matmul", 0)) + else: + return static_handle + class DistributedReshapeImpl1(DistributedOperatorImpl): def __init__(self, name): super(DistributedReshapeImpl1, self).__init__() self._name = name + self._forward_implemented = True + self._backward_implemented = False def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. 
""" @@ -150,6 +235,83 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): return changed + def forward(self, serial_op): + def static_handle(dst_block, + src_op, + op_dist_attr, + input_name_mapping, + output_name_mapping, + rank_id=0): + assert len( + input_name_mapping + ) == 3, "Dist op of Reshape take 3 inputs variable but got {}".format( + input_name_mapping) + assert len( + output_name_mapping + ) == 2, "Dist op of Reshape take 2 inputs variable but got {}".format( + output_name_mapping) + assert len( + input_name_mapping['X'] + ) == 1, "Dist op of Reshape input X take 1 variable but got {}".format( + input_name_mapping['X']) + assert len( + input_name_mapping['ShapeTensor'] + ) <= 1, "Dist op of Reshape input ShapeTensor take 0 or 1 variable but got {}".format( + input_name_mapping['ShapeTensor']) + assert len( + input_name_mapping['Shape'] + ) <= 1, "Dist op of Reshape input Shape take 0 or 1 variable but got {}".format( + input_name_mapping['Shape']) + assert len( + output_name_mapping['Out'] + ) == 1, "Dist op of Reshape input Out take 1 variable but got {}".format( + input_name_mapping['Out']) + assert len( + output_name_mapping['XShape'] + ) == 1, "Dist op of Reshape input XShape take 1 variable but got {}".format( + input_name_mapping['XShape']) + + X_var = dst_block.var(input_name_mapping['X'][0]) + Out_var = dst_block.var(output_name_mapping['Out'][0]) + XShape_var = dst_block.var(output_name_mapping['XShape'][0]) + shape_list = src_op.desc.attr("shape") + ShapeTensor_var_list = [] + for name in input_name_mapping['ShapeTensor']: + ShapeTensor_var_list.append(name) + Shape_var_list = [] + for name in input_name_mapping['Shape']: + Shape_var_list.append(name) + + # got dist attribute info + dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) + process_mesh_shape = op_dist_attr.get_process_mesh().topology + + # modify target shape + for idx, axis in enumerate(dim_mapping): + if axis >= 0: + if len(shape_list) > idx: + shape_list[idx] = shape_list[idx] // process_mesh_shape[ + axis] + + # create op + new_op_desc = dst_block.desc.append_op() + new_op_desc.copy_from(src_op.desc) + new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list) + new_op_desc.set_input('Shape', Shape_var_list) + new_op_desc.set_input('X', [X_var.name]) + new_op_desc.set_output('XShape', [XShape_var.name]) + new_op_desc.set_output('Out', [Out_var.name]) + new_op_desc._set_attr('shape', shape_list) + + dst_block._sync_with_cpp() + + if in_dygraph_mode(): + raise NotImplementedError( + "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( + "matmul", 0)) + else: + return static_handle + register_distributed_operator_impl("reshape2", DistributedReshapeImpl0("add_one_dim_back")) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py index fad11aadf80..dc78bdee1fb 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py @@ -47,7 +47,6 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): x_name = op_desc.input('X')[0] axis = op_desc.attr('axis') x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - # print("softmax axis", axis) if axis != -1 and axis != len(x_dims_mapping) - 1: return False diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py new file mode 100755 index 00000000000..03497f2967c --- /dev/null +++ 
b/python/paddle/distributed/auto_parallel/partitioner.py @@ -0,0 +1,925 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import copy +import numpy as np +import paddle +import paddle.fluid as fluid +from paddle.fluid import core +from paddle.fluid import framework as framework +from paddle.fluid import core, unique_name +from paddle.fluid.framework import Program, Parameter, Variable, program_guard +from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from paddle.fluid.backward import append_backward, _some_in_set_, _append_grad_suffix_ +from paddle.distributed.auto_parallel.operators.common import get_distributed_operator +from paddle.distributed.auto_parallel.operators.common import find_best_compatible_distributed_operator_impl +from paddle.fluid.clip import GradientClipBase, GradientClipByNorm, error_clip_callback, append_gradient_clip_ops, ClipGradByGlobalNorm +from paddle.distributed.fleet.base.distributed_strategy import DistributedStrategy +from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op +from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY +from .process import new_process_group +from .interface import _g_process_mesh_map +from .utils import _get_comm_group + +__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"] + + +class Partitioner(object): + """ + warning:: Partitioner is experimental and subject to change. + + Partitioner convert a program into another program. + Given a serial program which has been auto completed with shard annotation, the Partitioner + convert the serial program into a "distributed" program. The Partitioner will modify the serial + program in following two ways, which is also the major difference between serial and distributed program: + 1. partition op: replace a serial op into its corresponding dist op infered from the shard annotation + 2. partition var: if a var is sharded, modify the shape of var according to its shard annotation + + Partitioner is supposed to be call by the auto parallel framework, and not supposed to be directly called by user. + + Example: + .... 
+ import paddle.distributed.auto_parallel as auto + from paddle.fluid.distributed_attribute import get_default_distributed_context + from paddle.distributed import fleet + from paddle.distributed.auto_parallel.partitioner import Partitioner + + # create serial program with forward only + with static.program_guard(serial_main_program, serial_start_program): + model = create_model(config) + tokens = static.data(name="tokens", shape=[batch_size, sequence_len], dtype='int64') + labels = static.data(name="labels", shape=[batch_size, sequence_len], dtype='int64') + loss_mask = static.data(name="loss_mask", shape=[batch_size, sequence_len], dtype='int64') + preds = model(tokens) + loss = criterion(preds, labels, loss_mask) + + # auto completion + auto.ProcessMesh(shape=[2, 4], process_group=[0, 1, 2, 3, 4, 5, 6, 7]) + annotated_main_program = auto.complete_annotation(serial_main_program) + auto_paralle_context = get_default_distributed_context() + + # distributed strategy & rank info + rank_id = paddle.distributed.get_rank() + dist_strategy = fleet.DistributedStrategy() + + # create partitioner + Partitioner = Partitioner(dist_strategy, auto_paralle_context, rank_id) + + # create dist program with forward only + # for distributed inference, using partitioned_main_prog from here + partitioned_main_prog, partitioned_startup_prog = Partitioner.transpile_forward(complete_train_program, start_program) + + # create dist program with forward/backward/update + # for distributed training, using partitioned_main_prog from here + dist_params_grads = Partitioner.apply_backward(loss, complete_train_program, start_program, partitioned_main_prog, partitioned_startup_prog) + optimizer = paddle.fluid.optimizer.AdamOptimizer( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + opt_ops = Partitioner.apply_optimize(optimizer, dist_params_grads, partitioned_main_prog, partitioned_startup_prog) + """ + + def __init__(self, dist_strategy, auto_parallel_context, rank_id=0): + """ + Args: + dist_strategy (paddle.fleet.distributed_strategy): used to determine the user defined distributed strategy. + auto_parallel_context (paddle.fluid.DistributedContext): used to access the distributed_attr of var & op, every Partitioner object could maintain its own DistributedContext member, and partition program base on that shard scenario. + rank_id (int): global rank id to which the partitioned distributed program belong. 
+ """ + + if not isinstance(dist_strategy, DistributedStrategy): + raise TypeError( + "dist_strategy be paddle.fleet.base.DistributedStrategy, got %s here" + % type(dist_strategy)) + + if not isinstance(auto_parallel_context, DistributedContext): + raise TypeError( + "auto_parallel_context be paddle.fluid.DistributedContext, got %s here" + % type(auto_parallel_context)) + + self._dist_strategy = dist_strategy + self._auto_parallel_context = auto_parallel_context + self._rank_id = rank_id + self._serial2dist_varname_mapping = {} + self._dist_varname_suffix = "" + + # TODO if there is some dist op that is not compatible + # with auto_backward in forward, the following flag + # should be set to False + self._compatible_with_auto_backward = True + + # data parallelism + self._enable_data_parallel = False + self._dp_degree = 0 + self._dp_group = None + + # tensor parallelism + self._enable_tensor_parallel = False + self._tp_degree = 0 + self._tp_group = None + + def transpile_forward(self, serial_main_program, serial_startup_program): + """ + take serial forward programs with shard annotation, create a new distributed forward programs based on the serial ones. + instead of modify the input programs inplace, this function will preserve the inputs and create new program for output. + + beside replace the serial op with its dist op, if user has defined other strategy in fleet.distributed_strategy, and if + those strategy need to transpile (modify) the forward network program, those forward program modification should also be done within this + function in auto parallel scenario, in order to facilitate distributed inference/evaluation which need to DECOUPLE strategy specific forward transpilation with fleet.distributed_optimizer.minimize(). + + by now the fleet.distributed_strategy that need transpile forward program are following: + 1. (optimizer) sharding + + Args: + main_program (paddle.fluid.framework.program): serial main program with forward network only + startup_program (paddle.fluid.framework.program): serial startup program with forward network only + + return: + main_program (paddle.fluid.framework.program): distributed main program with forward network only + startup_program (paddle.fluid.framework.program): distributed startup program with forward network only + """ + + dist_main_program, dist_startup_program = self.transpile_forward_impl( + serial_main_program, serial_startup_program) + return dist_main_program, dist_startup_program + + def apply_backward(self, + serial_loss, + serial_main_program, + serial_startup_program, + dist_main_program, + dist_startup_program, + parameter_list=None, + no_grad_set=None, + callbacks=None): + """ + A complete training neural network is made up of forward and backward propagation. + This function is to generate the dist backward program for the distributed forward program. + + By now, the current automatical backward mechanism in paddle framework might NOT handle the backward generation for + some dist ops correctly, some so we now have two ways to genenate the backward program: + 1. dist_forward_program --> auto_backward --> dist_backward_program (if auto_backward could handle all dist op) + 2. serial_forward_program --> auto_backward --> serial_backward_program --> dist_op_backward_transpile --> dist_backward_program (if auto_backward could not handle all dist op) + + the backprogram is append the input dist program inplaced. 
+ + Args: + serial_loss (Variable) the loss in serial program that to be minimized + serial_main_program (paddle.fluid.framework.program): serial main program with forward network only + serial_startup_program (paddle.fluid.framework.program): serial startup program with forward network only + dist_main_program (paddle.fluid.framework.program): dist main program with forward network only + dist_startup_program (paddle.fluid.framework.program): dist startup program with forward network only + parameter_list (Iterable, optional): Iterable of ``Variable`` or ``Variable.name`` to update + to minimize ``loss``. The default value is None, at this time all parameters + will be updated. + no_grad_set (set, optional): Set of ``Variable`` or ``Variable.name`` that don't need + to be updated. The default value is None. + callbacks (list, optional): list of callable objects to run when appending backward + operator for one parameter. The default value is None. + + return: + params_grads (list) list of tuple that contain param and its grad variable + """ + params_grads = self.apply_backward_impl( + serial_loss, serial_main_program, serial_startup_program, + dist_main_program, dist_startup_program) + return params_grads + + def apply_optimize(self, user_define_optimizer, params_grads, + dist_main_program, dist_startup_program): + """ + append update related ops to the program: clip, weight decay, ops + filter optimize op if sharding is enable + naive gradient synchronization before update + + Args: + user_define_optimizer (paddle.fluid.optimizer): + params_grads (list) list of tuple that contain param and its grad variable + dist_main_program (paddle.fluid.framework.program): dist main program with forward & backward network + dist_startup_program (paddle.fluid.framework.program): dist startup program with forward & backward network + """ + + optimize_ops = self.apply_optimize_impl(user_define_optimizer, + params_grads, dist_main_program, + dist_startup_program) + + return optimize_ops + + def transpile_forward_impl(self, main_program, startup_program): + + if not isinstance(main_program, (Program)): + raise TypeError( + "dist_strategy be paddle.fluid.framework.program, got %s here" % + type(main_program)) + + if not isinstance(startup_program, (Program)): + raise TypeError( + "auto_parallel_context be paddle.fluid.framework.program, got %s here" + % type(startup_program)) + + # check if shard annotated serial program valid + if not self._is_valid_annotated_program(main_program): + raise RuntimeError( + "Not all vars or ops are annotated in main program !") + + # determine parallelism mode + self._determine_parallel_mode(main_program) + + # dist op & partition vars + new_main_prog, new_startup_program = self._dist_var_op_forward_transpile( + main_program, startup_program) + + # Sharding + if self._dist_strategy.sharding: + new_main_prog, new_startup_program = self._sharding_forward_transpile( + new_main_prog, new_startup_program) + + return new_main_prog, new_startup_program + + def apply_backward_impl(self, + serial_loss, + serial_main_program, + serial_startup_program, + dist_main_program, + dist_startup_program, + parameter_list=None, + no_grad_set=None, + callbacks=None): + """ + """ + + params_grads = self._dist_var_op_backward_transpile( + serial_loss, serial_main_program, serial_startup_program, + dist_main_program, dist_startup_program) + # Sharding + if self._dist_strategy.sharding: + self._sharding_backward_transpile(new_main_prog, + new_startup_program) + + # Data Parallel pass + if 
self._enable_data_parallel: + self._gradient_sync_transpile(dist_main_program, + dist_startup_program) + + return params_grads + + def apply_optimize_impl(self, user_define_optimizer, params_grads, + dist_main_program, dist_startup_program): + """ + append update related ops to the program: clip, weight decay, ops + filter optimize op if sharding is enable + naive gradient synchronization before update + + Args: + user_define_optimizer (paddle.fluid.optimizer): + params_grads (list) list of tuple that contain param and its grad variable + dist_main_program (paddle.fluid.framework.program): dist main program with forward & backward network + dist_startup_program (paddle.fluid.framework.program): dist startup program with forward & backward network + """ + + if self._dist_strategy.sharding: + params_grads = sharding_optimize_transpile( + params_grads, dist_main_program, dist_startup_program) + + optimize_ops = self._optimize_transpile(user_define_optimizer, + params_grads, dist_main_program, + dist_startup_program) + + return optimize_ops + + def _dist_var_op_forward_transpile(self, + serial_main_program, + serial_startup_program=None): + """ + 1. partition variables + 2. replace local op with corresponding dist op + """ + + partitioned_main_prog = fluid.Program() + partitioned_global_block = partitioned_main_prog.global_block() + serial_global_block = serial_main_program.global_block() + serial_ops = serial_main_program.global_block().ops + + # transpile main program + for op in serial_ops: + + # partititon input variables + for serial_input_varname in op.desc.input_arg_names(): + if serial_input_varname not in self._serial2dist_varname_mapping: + new_varname = serial_input_varname + self._dist_varname_suffix + if serial_global_block.has_var(serial_input_varname): + _partition_var(self._auto_parallel_context, + serial_global_block, + partitioned_global_block, + serial_input_varname, new_varname) + else: + assert serial_input_varname in __varname_not_in_block__ + + self._serial2dist_varname_mapping[ + serial_input_varname] = new_varname + + # partition output vars + for serial_output_varname in op.desc.output_arg_names(): + if serial_output_varname not in self._serial2dist_varname_mapping: + new_varname = serial_output_varname + self._dist_varname_suffix + _partition_var(self._auto_parallel_context, + serial_global_block, + partitioned_global_block, + serial_output_varname, new_varname) + self._serial2dist_varname_mapping[ + serial_output_varname] = new_varname + + # partition op + if _found_match_dist_op(self._auto_parallel_context, op): + # replace with corresponding dist op + _insert_dist_op(op, partitioned_global_block, + self._serial2dist_varname_mapping, + self._auto_parallel_context, self._rank_id) + else: + # replicate op + _insert_src_op(op, partitioned_global_block, + self._serial2dist_varname_mapping) + + # transpile startup program + if serial_startup_program == None: + partitioned_startup_prog = None + else: + partitioned_startup_prog = fluid.Program() + # create parameter + partitioned_startup_global_block = partitioned_startup_prog.global_block( + ) + param2shape = {} + for var in partitioned_main_prog.list_vars(): + if isinstance(var, Parameter): + _partition_parameter(self._auto_parallel_context, var, + partitioned_startup_global_block, + var.name, var.shape) + param2shape[var.name] = var.shape + + # copy initializer + for op in serial_startup_program.global_block().ops: + output_vars = op.desc.output_arg_names() + assert len( + output_vars + ) == 1, "initializer should 
output only ONE variable, but got [{}]".format( + str(op.desc)) + assert self._serial2dist_varname_mapping[output_vars[ + 0]] in param2shape, "try to initialize [{}] which is not a Parameter".format( + output_vars[0]) + new_op_desc = partitioned_startup_global_block.desc.append_op() + new_op_desc.copy_from(op.desc) + new_op_desc._rename_output( + output_vars[0], + self._serial2dist_varname_mapping[output_vars[0]]) + new_op_desc._set_attr("shape", param2shape[ + self._serial2dist_varname_mapping[output_vars[0]]]) + partitioned_startup_global_block._sync_with_cpp() + + # MP broadcast not split parameter + # NOTE Theoretically, the MP param init broadcast should be handled by + # each dist op itself. but if we insert the broadcast op at that moment, the broadcast + # will before the initializer, which lead to a undertermined case. + if self._enable_tensor_parallel: + param_to_sync = [] + for param in partitioned_startup_prog.all_parameters(): + if not self._is_var_distributed(param): + param_to_sync.append(param) + # FIXME the ring id should be set by autoparallel.mapping module + # it should be determined by dp groups butfixed it here for hacking + partitioned_startup_global_block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': self._tp_group.id, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Forward + }) + partitioned_startup_global_block.append_op( + type='c_sync_comm_stream', + inputs={'X': param_to_sync}, + outputs={'Out': param_to_sync}, + attrs={ + 'ring_id': self._tp_group.id, + OP_ROLE_KEY: OpRole.Forward + }) + partitioned_startup_global_block._sync_with_cpp() + + # DP init param broadcast + if self._enable_data_parallel: + # parameters initialization synchronization + param_to_sync = [] + + for param in partitioned_startup_global_block.all_parameters(): + param_to_sync.append(param) + + # FIXME the ring id should be set by autoparallel.mapping module + # it should be determined by dp groups butfixed it here for hacking + partitioned_startup_global_block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': self._dp_group.id, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Forward + }) + partitioned_startup_global_block.append_op( + type='c_sync_comm_stream', + inputs={'X': param_to_sync}, + outputs={'Out': param_to_sync}, + attrs={ + 'ring_id': self._dp_group.id, + OP_ROLE_KEY: OpRole.Forward + }) + partitioned_startup_global_block._sync_with_cpp() + + return partitioned_main_prog, partitioned_startup_prog + + def _dist_var_op_backward_transpile(self, + serial_loss, + serial_main_program, + serial_startup_program, + dist_main_program, + dist_startup_program, + parameter_list=None, + no_grad_set=None, + callbacks=None): + """ + so far, the auto_backward case only guarantee the correcotness of backward ops for curtain Dist ops: + 1. NV-Megatron-like parallel embedding + 2. NV-Megatron-like row parallel linear + 3. NV-Megatron-like col parallel linear + """ + + if self._compatible_with_auto_backward: + assert isinstance( + serial_loss, Variable), "The target loss should be an Variable." + dist_loss = self._serial_varname2dist_var(serial_loss.name, + dist_main_program) + + assert len(dist_loss.shape) == 1 and dist_loss.shape[0] == 1, \ + "The dist loss.shape should be (1L,), but the current dist loss.shape is {}. 
" \ + "Maybe that you should call fluid.layers.mean to process the current loss.".format( + dist_loss.shape) + + # update parameter list + if parameter_list: + parameter_list = [ + self._serial_varname2dist_var(param.name, dist_main_program) + for param in parameter_list + ] + + # update parameter no_grad_set + if no_grad_set: + no_grad_set = [ + self._serial_varname2dist_var(param.name, dist_main_program) + for param in no_grad_set + ] + + return _auto_backward( + dist_loss, + dist_startup_program, + parameter_list=parameter_list, + no_grad_set=no_grad_set, + callbacks=callbacks) + # replace dist grad ops + else: + raise RuntimeError("transpile NOT implemented !") + + def _optimize_transpile(self, user_define_optimizer, params_grads, + main_program, startup_program): + + with program_guard(main_program, startup_program): + optimize_ops = user_define_optimizer.apply_gradients(params_grads) + + return optimize_ops + + def _is_valid_annotated_program(self, program): + + # TODO (ZJ-LIANG) should check all block + ops = program.global_block().ops + vars_ = program.list_vars() + op_dist_attrs = [ + self._auto_parallel_context.get_op_distributed_attr_for_program(op) + for op in ops + ] + var_dist_attrs = [ + self._auto_parallel_context.get_tensor_distributed_attr_for_program( + var) for var in vars_ + ] + + all_ops_annotated = all(dist_attr is not None + for dist_attr in op_dist_attrs) + all_vars_annotated = all(dist_attr is not None + for dist_attr in var_dist_attrs) + + return all_ops_annotated and all_vars_annotated + + def _serial_varname2dist_var(self, serial_varname, dist_program): + assert serial_varname in self._serial2dist_varname_mapping, "The serial var [{}] is not found in var name mapping".format( + serial_varname) + dist_varname = self._serial2dist_varname_mapping[serial_varname] + + assert dist_program.global_block().has_var( + dist_varname + ), "The dist var [{}] is not found in dist program".format(dist_varname) + dist_var = dist_program.global_block().var(dist_varname) + + return dist_var + + def _determine_parallel_mode(self, program): + """ + determine the parallelism that is enabled + NOTE a hard rule and should be updated in future + """ + + for param in program.all_parameters(): + if self._is_var_distributed(param): + self._enable_tensor_parallel = True + break + + for var in program.list_vars(): + var_dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program( + var) + if not var_dist_attr.is_parameter(): + mapping = var_dist_attr.get_dims_mapping() + mesh = var_dist_attr.get_process_mesh().topology + if mapping[0] >= 0 and mesh[mapping[0]] > 1: + self._enable_data_parallel = True + break + + # tensor parallelism + if self._enable_tensor_parallel: + model_parallel_axis, process_mesh = self._auto_parallel_context._get_model_parallel_info( + ) + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, + model_parallel_axis, self._rank_id) + self._tp_degree = len(group_ranks) + self._tp_group = new_process_group(group_ranks) + + # data parallelism + data_parallel_axis, process_mesh = self._auto_parallel_context._get_data_parallel_info( + ) + if self._enable_data_parallel: + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, + data_parallel_axis, self._rank_id) + self._dp_degree = len(group_ranks) + self._dp_group = new_process_group(group_ranks) + + def _is_var_distributed(self, var): + + dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program( + var) + assert dist_attr is 
not None, "dist_attr of var [{}] is None".format( + var.name) + return _is_distributed(dist_attr) + + def _sharding_forward_transpile(self, main_prog, startup_program): + """ + this transpile conduct the modification in forward program need by sharding strategy + which majorly include: + 1. partition the parameter + 2. insert broadcast op + 3. insert sync op + + NOTE the transpile modification is inplace on the input program + """ + + raise NotImplementedError( + "Sharding is NOT support in AutoParallel yet!") + + def _sharding_backward_transpile(self, main_prog, startup_program): + """ + this transpile conduct the modification in backward program need by sharding strategy + which majorly include: + 1. partition the gradient + 2. insert broadcast op + 3. insert sync op + + NOTE the transpile modification is inplace on the input program + """ + + raise NotImplementedError( + "Sharding is NOT support in AutoParallel yet!") + + def _sharding_optimize_transpile(self, params_grads, dist_main_program, + dist_startup_program): + """ + shard params_grads + append the broadcast to sync parameters + """ + raise RuntimeError("sharding transpile is NOT implemented !") + + def _gradient_sync_transpile(self, main_program, startup_program): + """ + append the gradient allreduce ops for all parameters' grad in case of Data Parallel + """ + + # scale loss by dp degree + main_global_block = main_program.global_block() + for idx, op in reversed(list(enumerate(main_global_block.ops))): + if is_loss_grad_op(op): + loss_grad_var = main_global_block.vars[op.output_arg_names[0]] + main_global_block._insert_op_without_sync( + idx + 1, + type='scale', + inputs={'X': loss_grad_var}, + outputs={'Out': loss_grad_var}, + attrs={ + 'scale': 1.0 / self._dp_degree, + OP_ROLE_KEY: OpRole.Backward + }) + break + main_global_block._sync_with_cpp() + + # gradient synchronization + # NOTE naive gradient sync without overlapping + # so there is not need to sync between calc and comm + # collecting grad var + grad_to_sync = [] + for idx, op in reversed(list(enumerate(main_global_block.ops))): + if is_backward_op(op) and \ + OP_ROLE_VAR_KEY in op.attr_names: + op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY] + if len(op_role_var) != 0: + assert len(op_role_var) % 2 == 0 + for i in range(0, len(op_role_var), 2): + param, reduced_grad = op_role_var[i], op_role_var[i + 1] + assert (reduced_grad not in grad_to_sync) + grad_to_sync.append(reduced_grad) + if is_optimizer_op(op): + first_optimize_op_idx = idx + + # insert allreduce + for grad in grad_to_sync: + # FIXME the ring id should be set by autoparallel.mapping module + # it should be determined by dp groups butfixed it here for hacking + main_global_block.append_op( + type='c_allreduce_sum', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={ + 'ring_id': self._dp_group.id, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Backward + }) + main_global_block.append_op( + type='c_sync_comm_stream', + inputs={'X': grad_to_sync}, + outputs={'Out': grad_to_sync}, + attrs={'ring_id': self._dp_group.id, + OP_ROLE_KEY: OpRole.Backward}) + main_global_block._sync_with_cpp() + + +def _get_no_grad_set_name(no_grad_set): + no_grad_set_name = set() + if no_grad_set is not None: + if isinstance(no_grad_set, (set, list, tuple)): + for i, no_grad_var in enumerate(no_grad_set): + if isinstance(no_grad_var, framework.Variable): + no_grad_set_name.add(no_grad_var.name) + elif isinstance(no_grad_var, six.string_types): + no_grad_set_name.add(no_grad_var) + else: + raise TypeError( + 
"The type of no_grad_set's member must be paddle.fluid.Variable or str, but received %s." + % (type(no_grad_var))) + else: + raise TypeError( + "The type of no_grad_set should be set or list or tuple, but received {}". + format(type(no_grad_set))) + return no_grad_set_name + + +def _get_no_grad_set(loss, no_grad_set=None): + no_grad_set = _get_no_grad_set_name(no_grad_set) + parameters = loss.block.program.global_block().all_parameters() + param_no_trainable = set( + [param.name for param in parameters if param.trainable is False]) + # If the parameter is no trainable, it should not have a gradient. + no_grad_set.update(param_no_trainable) + + return no_grad_set + + +def _found_match_dist_op(auto_paralle_context, op): + dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(op) + dist_ops = get_distributed_operator(op.type) + + return dist_ops and dist_attr.get_impl_idx() >= 0 and dist_ops.get_impl( \ + dist_attr.get_impl_idx())._forward_implemented + + +def _auto_backward(loss, + startup_program=None, + parameter_list=None, + no_grad_set=None, + callbacks=None): + """ + modification is inplaced + """ + act_no_grad_set = _get_no_grad_set(loss, no_grad_set) + assert isinstance(loss, Variable), "The target loss should be an Variable." + + if callbacks is None: + callbacks = [error_clip_callback] + else: + assert (isinstance(callbacks, list)) + + assert len(loss.shape) == 1 and loss.shape[0] == 1, \ + "The loss.shape should be (1L,), but the current loss.shape is {}. " \ + "Maybe that you should call fluid.layers.mean to process the current loss.".format( + loss.shape) + + program = loss.block.program + with program_guard(program, startup_program): + params_grads = append_backward(loss, parameter_list, act_no_grad_set, + callbacks) + + return params_grads + + +def _is_distributed(dist_attr): + + mapping = dist_attr.get_dims_mapping() + mesh = dist_attr.get_process_mesh().topology + for idx in range(len(mapping)): + if mapping[idx] >= 0 and mesh[mapping[idx]] > 1: + return True + + return False + + +def _get_dist_shape(var, dist_attr): + + var_shape = var.shape + mapping = dist_attr.get_dims_mapping() + mesh = dist_attr.get_process_mesh().topology + assert len(var_shape) == len( + mapping + ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format( + var_shape, mapping) + new_shape = [] + for idx in range(len(var_shape)): + if var_shape[idx] == -1 or mapping[idx] == -1: + new_shape.append(var_shape[idx]) + else: + assert var_shape[idx] % mesh[mapping[ + idx]] == 0, "un-event partition: var_shape[idx]=[{}], mesh[{}]".format( + var_shape[idx], mesh[mapping[idx]]) + new_shape.append(var_shape[idx] // mesh[mapping[idx]]) + + return new_shape + + +def _partition_parameter(auto_paralle_context, src_var, dst_block, dst_varname, + dst_shape): + # NOTE hack to copied Parameter + # not initialized parameter, need to initialize it + copied_kwargs = {} + copied_kwargs['trainable'] = src_var.trainable + copied_kwargs['optimize_attr'] = src_var.optimize_attr + copied_kwargs['regularizer'] = src_var.regularizer + copied_kwargs['do_model_average'] = src_var.do_model_average + copied_kwargs['need_clip'] = src_var.need_clip + + param = Parameter( + block=dst_block, + type=src_var.type, + name=dst_varname, + shape=dst_shape, + dtype=src_var.dtype, + lod_level=src_var.lod_level, + error_clip=src_var.error_clip, + stop_gradient=src_var.stop_gradient, + is_data=src_var.is_data, + belong_to_optimizer=src_var.belong_to_optimizer, + **copied_kwargs) + + # set dist attr uid + # 
distributed_attr_uid = src_var.desc.get_distributed_attr_uid() + # param.desc.set_distributed_attr_uid(distributed_attr_uid) + dist_attr = copy.deepcopy( + auto_paralle_context.get_tensor_distributed_attr_for_program(src_var)) + dist_attr._owner_tensor = param + dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program( + src_var)._owner_context + auto_paralle_context.set_tensor_distributed_attr_for_program(param, + dist_attr) + + +def _partition_intermediate_var(auto_paralle_context, src_var, dst_block, + dst_varname, dst_shape): + var = dst_block.create_var( + type=src_var.type, + name=dst_varname, + shape=dst_shape, + dtype=src_var.dtype, + lod_level=src_var.lod_level, + persistable=src_var.persistable, + error_clip=src_var.error_clip, + stop_gradient=src_var.stop_gradient, + is_data=src_var.is_data, + belong_to_optimizer=src_var.belong_to_optimizer) + + # set dist attr uid + # distributed_attr_uid = src_var.desc.get_distributed_attr_uid() + # var.desc.set_distributed_attr_uid(distributed_attr_uid) + dist_attr = copy.deepcopy( + auto_paralle_context.get_tensor_distributed_attr_for_program(src_var)) + dist_attr._owner_tensor = var + dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program( + src_var)._owner_context + auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr) + + +def _partition_var(auto_paralle_context, src_block, dst_block, src_varname, + dst_varname): + """ + partition include: split + replicate + """ + src_var = src_block.var(src_varname) + + if src_var.type == core.VarDesc.VarType.READER: + dst_block.create_var( + type=src_var.type, + name=dst_varname, + persistable=True, + stop_gradient=True) + else: + dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program( + src_var) + target_shape = _get_dist_shape(src_var, dist_attr) + + if isinstance(src_var, Parameter): + _partition_parameter(auto_paralle_context, src_var, dst_block, + dst_varname, target_shape) + else: + _partition_intermediate_var(auto_paralle_context, src_var, + dst_block, dst_varname, target_shape) + + +def _insert_src_op(src_op, dst_block, varname_mapping): + + new_op_desc = dst_block.desc.append_op() + new_op_desc.copy_from(src_op.desc) + for local_varname in src_op.desc.input_arg_names(): + new_op_desc._rename_input(local_varname, varname_mapping[local_varname]) + for local_varname in src_op.desc.output_arg_names(): + new_op_desc._rename_output(local_varname, + varname_mapping[local_varname]) + dst_block._sync_with_cpp() + + +def _insert_dist_op(src_op, dst_block, varname_mapping, auto_paralle_context, + rank_id): + + # build input varname mapping + input_mapping = {} + for input_name in src_op.desc.input_names(): + varnames = [] + for varname in src_op.desc.input(input_name): + varnames.append(varname_mapping[varname]) + input_mapping[input_name] = varnames + + # build output varname mapping + output_mapping = {} + for output_name in src_op.desc.output_names(): + varnames = [] + for varname in src_op.desc.output(output_name): + varnames.append(varname_mapping[varname]) + output_mapping[output_name] = varnames + + # append dist op + dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(src_op) + dist_ops = get_distributed_operator(src_op.type) + append_op_handle = dist_ops.get_impl(dist_attr.get_impl_idx()).forward( + src_op) + append_op_handle( + dst_block, + src_op, + dist_attr, + input_mapping, + output_mapping, + rank_id=rank_id) diff --git 
a/python/paddle/distributed/auto_parallel/process.py b/python/paddle/distributed/auto_parallel/process.py new file mode 100644 index 00000000000..b919645b96c --- /dev/null +++ b/python/paddle/distributed/auto_parallel/process.py @@ -0,0 +1,149 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import paddle +import paddle.fluid.core as core +from ..collective import _get_global_env +from ..collective import _new_ring_id +from ...fluid.framework import in_dygraph_mode +from ...fluid.layers.tensor import fill_constant + +LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = None +PROCESSOR_TO_PHYSICAL_PROCESS_MAP = None + + +def get_all_logical_process_set(): + from .interface import _g_process_mesh_map + all_logical_process_set = set(_g_process_mesh_map[0].process_group) + return all_logical_process_set + + +def get_logical_process_to_physical_process_map(): + global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP + return LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP + + +def set_logical_process_to_physical_process_map(mapping): + global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP + LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = mapping + + +def get_processor_to_physical_process_map(): + global PROCESSOR_TO_PHYSICAL_PROCESS_MAP + return PROCESSOR_TO_PHYSICAL_PROCESS_MAP + + +def set_processor_to_physical_process_map(mapping): + global PROCESSOR_TO_PHYSICAL_PROCESS_MAP + PROCESSOR_TO_PHYSICAL_PROCESS_MAP = mapping + + +PROCESS_GROUP_MAP = {} + + +def get_all_process_groups(): + global PROCESS_GROUP_MAP + return PROCESS_GROUP_MAP.values() + + +def new_process_group(ranks): + global PROCESS_GROUP_MAP + if not PROCESS_GROUP_MAP: + genv = _get_global_env() + PROCESS_GROUP_MAP["global_group"] = ProcessGroup( + 0, list(range(genv.world_size))) + # A key constructed from ranks is used in the global process group map + key = ''.join(map(str, sorted(ranks))) + if key not in PROCESS_GROUP_MAP: + num_groups = len(PROCESS_GROUP_MAP) + # Note: our process group may interfere with the original implementation + # so the created group id should start from the original _new_ring_id() + group_id = _new_ring_id() + num_groups + 1 + pg = ProcessGroup(group_id, ranks) + PROCESS_GROUP_MAP[key] = pg + return pg + else: + pg = PROCESS_GROUP_MAP[key] + return pg + + +# This implementation refers to lots of Paddle/python/paddle/distributed/collective.py, +# Fleet also has a collective helper which uses ops to initialize communication in +# Paddle/python/paddle/distributed/fleet/meta_optimizers/common.py. We use the first one +# because it seems simple. This should be enhanced to manage the process membership and +# the instantiation process in a more general way. In the future, the process group may +# handle the communication implementation choice. 
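Editor's note: the sketch below is not part of the patch; it is a minimal standalone model of the caching behaviour that new_process_group() implements above, namely one group per distinct rank set, keyed by the sorted ranks, with the actual NCCL setup deferred to ProcessGroup.instantiate() in the class that follows. _FakeGroup and _get_or_create_group are hypothetical names used only for illustration.

# Standalone model of the de-duplication in new_process_group():
# asking twice for the same set of ranks must return the same group object.
_GROUP_CACHE = {}


class _FakeGroup:
    def __init__(self, group_id, ranks):
        self.id = group_id
        self.ranks = sorted(ranks)


def _get_or_create_group(ranks):
    # key the cache by the sorted rank list, as PROCESS_GROUP_MAP does
    key = ','.join(map(str, sorted(ranks)))
    if key not in _GROUP_CACHE:
        _GROUP_CACHE[key] = _FakeGroup(len(_GROUP_CACHE) + 1, ranks)
    return _GROUP_CACHE[key]


g1 = _get_or_create_group([0, 4])    # e.g. the data-parallel peers of rank 0 in a 2 x 4 mesh
g2 = _get_or_create_group([4, 0])    # same peers, different order
assert g1 is g2 and g1.id == 1

One small difference from the patch is the ',' delimiter in the cache key: the patch joins the sorted ranks with no separator, which can become ambiguous once multi-digit ranks appear (for example [1, 12] and [112] both map to "112").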
+class ProcessGroup: + def __init__(self, group_id, ranks): + self._group_id = group_id + self._ranks = sorted(ranks) + self._nranks = len(self._ranks) + self._is_instantiate = False + + @property + def id(self): + return self._group_id + + # @property + # def key(self): + # return ''.join(map(str, sorted(self._ranks))) + + def local_rank(self, global_rank): + if global_rank in self._ranks: + return self._ranks.index(global_rank) + else: + assert False, \ + "Rank {} doesn't belong to this group".format(global_rank) + + def is_instantiate(self): + return self._is_instantiate + + def instantiate(self): + if self._is_instantiate: + return + ring_id = self.id + genv = _get_global_env() + global_rank = genv.rank + + if self._nranks >= 2: + strategy = core.ParallelStrategy() + strategy.nranks = self._nranks + strategy.local_rank = self.local_rank(global_rank) + strategy.trainer_endpoints = [ + genv.trainer_endpoints[i] for i in self._ranks + ] + strategy.current_endpoint = genv.current_endpoint + strategy.nrings = 1 + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(genv.device_id) + core.NCCLParallelContext(strategy, + place).init_with_ring_id(ring_id) + else: + assert False, ("No CUDA device found") + + # TODO(shenliang03): This is a temporary solution to solve the problem of + # hang caused by cross-creation of new_group + tmp = paddle.to_tensor( + [1], dtype="int32") if in_dygraph_mode() else fill_constant( + [0], dtype="int32", value="1") + paddle.distributed.all_reduce(tmp, use_calc_stream=True) + paddle.distributed.wait(tmp) + + self._is_instantiate = True + + def __str__(self): + string = "id: {}, nranks: {}, ranks: {}.".format( + self.id, self._nranks, ", ".join(map(str, self._ranks))) + return string diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py old mode 100644 new mode 100755 index a4a73ae5c0a..c864375271b --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -14,6 +14,7 @@ import threading import paddle.fluid.core as core +import numpy as np def is_valid_list_index(list, index): @@ -155,3 +156,125 @@ def print_program_with_distributed_attr(program, dist_context=None): print(program) set_default_distributed_context(original_default_context) lock.release() + + +def _get_comm_group(processes, shape, axis, rank): + """ + Given a rank and the processes mesh the rank belongs to, + compute the communication peers of the rank based on the give axis in the mesh. + + Example: 16 processes managed in a 4-Dimensinal mesh with shape of [2, 2, 2, 2]. 
+ the rank communication peers of rank 0 (included) are following: + in axis 0: [0, 1] + in axis 1: [0, 2] + in axis 2: [0, 4] + in axis 3: [0, 8] + """ + + # NOTE _linear_idx2coordinate assume processes mesh start with 0 and continuous + # tricks to support processes mesh when it is not start with 0 or continuous + rank_relatvie = processes.index(rank) + coordinate = _linear_idx2coordinate(shape, rank_relatvie) + coordinates_in_group = [coordinate[:] for i in range(shape[axis])] + + # select comm group + for i in range(shape[axis]): + coordinates_in_group[i][axis] = i + + ranks_in_group_relative = [ + _coordinate2linear_idx(shape, coordinate) + for coordinate in coordinates_in_group + ] + ranks_in_group = [processes[idx] for idx in ranks_in_group_relative] + + return sorted(ranks_in_group) + + +def _coordinate2linear_idx(mesh_shape, coordinate): + """ + convert a coordinate in multidimensional mesh space into a scala idx in linear space. + + it use Row-major order for dimension conversion. + so it has: [most_significant_dim, ..., least_significant_dim] + assume: + + the size of i-th dimension to be: S[i] + the index of j-th dimension is: I[j] + + linear_idx of a n dimensional coordinate is: + + I[n-1] * (S[n-2] * S[n-3] * S[n-4] * .... S[0]) + + I[n-2] * ( S[n-3] * S[n-4] * .... S[0]) + + I[n-3] * ( S[n-4] * .... S[0]) + + ... + I[1] * ( S[0]) + + I[0] + + """ + # NOTE the following function work based on a strong an assumption + # that the processes in mesh are + # 1. starts from 0 + # 2. continuous + # it will be wrong if ths above condition doesnot meet, + # e.g. process_mesh = { process_groups = [7, 8, 9,10, 12, 13, 14, 15], mesh = [2, 4]} + # if you want a more general mapping, you should use cartesian product + + assert len(mesh_shape) == len( + coordinate + ), "coordinate should have the same size as mesh shape, but got shape: {}, coordinate: {}".format( + mesh_shape, coordinate) + for i in range(len(mesh_shape)): + assert coordinate[ + i] >= 0, "index in dimension [{}] is least than zero. coordinate: {}".format( + i, coordinate) + assert coordinate[i] < mesh_shape[ + i], "index beyond extent in dimension [{}]. shape: {}, coordinate: {}".format( + i, mesh_shape, coordinate) + + base = mesh_shape[-1] + linear_idx = coordinate[-1] + + # row major order + for i in range(len(mesh_shape) - 2, -1, -1): + linear_idx += base * coordinate[i] + base *= mesh_shape[i] + + return linear_idx + + +def _linear_idx2coordinate(mesh_shape, linear_idx): + """ + mapping a linear scala into multidimensional mesh space, return it coordinate in that space. + + it is the inverse function of _coordinate2linear_idx. + assume: + + the size of i-th dimension to be: S[i] + the index of j-th dimension is: I[j] + + the coordinate given linear_idx is: + + I[0] = linear_idx % S[0] + I[0] = (linear_idx / S[0]) % S[1] + I[0] = (linear_idx / (S[0] * S[1])) % S[2] + .... + + """ + + assert linear_idx >= 0, "linear index [{}] is least than zero".format( + linear_idx) + assert linear_idx < np.prod( + mesh_shape + ), "linear index beyond the extent of mesh shape. 
shape: {}, linear index: {}".format( + mesh_shape, linear_idx) + + base = 1 + coordinate = [-1] * len(mesh_shape) + + for i in reversed(range(len(mesh_shape))): + offset = linear_idx / base + coordinate[i] = int(offset % mesh_shape[i]) + base *= mesh_shape[i] + + # row major order + return coordinate diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index fb7f18fcc4e..5ca9624b980 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -79,6 +79,8 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer_base) list(APPEND MIXED_DIST_TEST_OPS test_fleet_distributed_strategy) list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto) list(APPEND MIXED_DIST_TEST_OPS test_fleet_static_mp_layers) +list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner) +list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner_gpt) foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) list(REMOVE_ITEM TEST_OPS ${TEST_OP}) endforeach() @@ -206,6 +208,8 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) LIST(REMOVE_ITEM TEST_OPS test_dygraph_recompute) list(REMOVE_ITEM TEST_OPS test_parallel_class_center_sample) LIST(REMOVE_ITEM TEST_OPS test_parallel_margin_cross_entropy) + LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner) + LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt) elseif(WITH_GPU) if (${CUDNN_VERSION} VERSION_LESS 7100) LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py new file mode 100755 index 00000000000..f1049084cfb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py @@ -0,0 +1,948 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
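Editor's note: the two index helpers added to utils.py above implement a plain row-major (C order) mapping between mesh coordinates and linear ranks, so numpy's ravel_multi_index / unravel_index can be used as an independent cross-check. The snippet below is a standalone sketch for the 2 x 4 mesh and rank_id 3 used throughout the partitioner unit tests; it assumes the process list is contiguous and starts at 0, which is exactly the precondition the NOTE in _coordinate2linear_idx spells out. One detail worth flagging: with this row-major convention the last mesh axis is the fastest varying one, so for the [2, 2, 2, 2] example in the _get_comm_group docstring the peers of rank 0 along axis 0 are actually [0, 8] and along axis 3 are [0, 1]; the docstring lists the axes in the opposite order.

import numpy as np

mesh_shape = (2, 4)
processes = list(range(8))    # assumes contiguous ranks starting at 0

# round trip: linear rank -> coordinate -> linear rank
for rank in processes:
    coord = np.unravel_index(rank, mesh_shape)
    assert int(np.ravel_multi_index(coord, mesh_shape)) == rank

# communication peers of rank 3 (coordinate (0, 3)), mirroring
# _get_comm_group(processes, [2, 4], axis, 3):
coord = list(np.unravel_index(3, mesh_shape))
dp_peers = sorted(int(np.ravel_multi_index((i, coord[1]), mesh_shape))
                  for i in range(mesh_shape[0]))       # vary axis 0
mp_peers = sorted(int(np.ravel_multi_index((coord[0], j), mesh_shape))
                  for j in range(mesh_shape[1]))       # vary axis 1
assert dp_peers == [3, 7]          # data-parallel group of rank 3
assert mp_peers == [0, 1, 2, 3]    # model-parallel group of rank 3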
+ +from __future__ import print_function + +import unittest +import unittest.mock +from io import StringIO +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.static as static +import paddle.nn.functional as F +import paddle.utils as utils +import paddle.tensor as tensor +from paddle.fluid import layers +from paddle.nn.layer.transformer import _convert_param_attr_to_list +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program +from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr +from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix +from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed.auto_parallel.context import set_default_distributed_context +from paddle.distributed import fleet +from paddle.distributed.auto_parallel.partitioner import Partitioner +from paddle.distributed.auto_parallel.utils import _get_comm_group +from paddle.distributed.auto_parallel.process import new_process_group + +paddle.enable_static() +_global_parallel_stratergy = None +_global_process_mesh = None +ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) + + +def get_programs(annotated_func): + train_program = static.Program() + start_program = static.Program() + dist_context = DistributedContext() + global _global_process_mesh + dist_context.set_process_mesh(_global_process_mesh) + train_program, start_program = annotated_func(train_program, start_program) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + + rank_id = 3 + dist_strategy = fleet.DistributedStrategy() + partitioner = Partitioner(dist_strategy, dist_context, rank_id) + test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog = partitioner.transpile_forward( + complete_train_program, start_program) + + return complete_train_program, start_program, test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog, dist_context + + +def is_all_parameters_shape_equal(prog1, prog2): + + params1 = prog1.all_parameters() + params2 = prog2.all_parameters() + params1.sort(key=lambda x: x.name) + params2.sort(key=lambda x: x.name) + shape1 = [tensor.shape for tensor in params1] + shape2 = [tensor.shape for tensor in params2] + + if len(shape1) != len(shape2): + return False + for i in range(len(shape1)): + if shape1[i] != shape2[i]: + return False + return True + + +def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit): + + for i in range(len(varnames1)): + var1 = prog1.global_block().var(varnames1[i]) + var2 = prog2.global_block().var(varnames2[i]) + if var1.shape[axis] != (var2.shape[axis] // nsplit): + return False + + return True + + +def initialization_check(mode, dist_context, dist_startup_prog, + serial_startup_prog, var_need_broadcast): + if 'mp' in mode: + mp_parallel_axis, process_mesh = dist_context._get_model_parallel_info() + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, mp_parallel_axis, + 3) + mp_ring_id = new_process_group(group_ranks).id + broadcast_ops = [ + op for op in dist_startup_prog.global_block().ops + if (op.type == "c_broadcast" and op.desc.attr("ring_id") == + mp_ring_id) + ] + broadcast_varnames = sorted( + [op.desc.output_arg_names()[0] for op in broadcast_ops]) + if broadcast_varnames != var_need_broadcast: + return False + + if 'dp' in mode: + dp_parallel_axis, process_mesh = 
dist_context._get_data_parallel_info() + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, dp_parallel_axis, + 3) + dp_ring_id = new_process_group(group_ranks).id + nparam = len(serial_startup_prog.all_parameters()) + nbroadcast_dp = len([ + op for op in dist_startup_prog.global_block().ops + if (op.type == "c_broadcast" and op.desc.attr("ring_id") == + dp_ring_id) + ]) + if nparam != nbroadcast_dp: + return False + + if "dp" in mode and 'mp' in mode: + nbroadcast = len([ + op for op in dist_startup_prog.global_block().ops + if op.type == "c_broadcast" + ]) + if len(var_need_broadcast) + nbroadcast_dp != nbroadcast: + return False + + return True + + +class MLPLayer(nn.Layer): + def __init__(self, + hidden_size=1024, + intermediate_size=4 * 1024, + dropout_ratio=0.1, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range)) + bias_attr = None + + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") + + def forward(self, input): + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) + else: + auto.shard_tensor( + self.linear0.weight, _global_process_mesh, + dim_mapping=[-1, -1]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, + dim_mapping=[-1, -1]) + + out = self.norm(input) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = self.linear1(out) + out = self.dropout(out) + + return out + + +def mlp_pretrain_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + hidden_size = 1024 + sequence_len = 512 + input = static.data( + name="input", + shape=[batch_size, sequence_len, hidden_size], + dtype='float32') + + if _global_parallel_stratergy == "dp": + auto.shard_tensor( + input, _global_process_mesh, dim_mapping=[0, -1, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + input, _global_process_mesh, dim_mapping=[0, -1, -1]) + + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) + out = mlp(input) + return train_program, start_program + + +class TestMLPAutoPartitioner(unittest.TestCase): + def test_mlp_dp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[0, 1, 2, 3], parent=ROOT_MESH) + + serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( + mlp_pretrain_forward) + + # parameter should not be partitioned + self.assertTrue( + is_all_parameters_shape_equal(serial_main_prog, dist_main_prog)) + self.assertTrue( + is_all_parameters_shape_equal(serial_startup_prog, + dist_startup_prog)) + + # op in main prog 
should be the same + serial_ops = serial_main_prog.global_block().ops + dist_ops = dist_main_prog.global_block().ops + serial_ops = [op.type for op in serial_ops] + dist_ops = [op.type for op in dist_ops] + self.assertTrue(serial_ops == dist_ops) + + # parameter initialization + var_need_broadcast = [] + self.assertTrue( + initialization_check(_global_parallel_stratergy, dist_context, + dist_startup_prog, serial_startup_prog, + var_need_broadcast)) + + def test_mlp_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[0, 1, 2, 3], parent=ROOT_MESH) + serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( + mlp_pretrain_forward) + + # param should be partition + nrank = 4 + # col parallel + weights = ['linear_0.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 1, nrank)) + weights = ['linear_0.b_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + # row parallel + weights = ['linear_1.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + weights = ['linear_1.b_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, 1)) + + # row and col allreduce + dist_ops = dist_main_prog.global_block().ops + dist_ops = [op.type for op in dist_ops] + ref_ops = [ + 'layer_norm', 'c_identity', 'matmul', 'elementwise_add', 'gelu', + 'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout' + ] + self.assertTrue(dist_ops == ref_ops) + + # parameter initialization + var_need_broadcast = sorted( + ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0']) + self.assertTrue( + initialization_check(_global_parallel_stratergy, dist_context, + dist_startup_prog, serial_startup_prog, + var_need_broadcast)) + + def test_mlp_dp_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp_mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( + mlp_pretrain_forward) + + # param should be partition + nrank = 4 + # col parallel + weights = ['linear_0.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 1, nrank)) + weights = ['linear_0.b_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + # row parallel + weights = ['linear_1.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + weights = ['linear_1.b_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, 1)) + + # row and col allreduce + dist_ops = dist_main_prog.global_block().ops + dist_ops = [op.type for op in dist_ops] + ref_ops = [ + 'layer_norm', 'c_identity', 'matmul', 'elementwise_add', 'gelu', + 'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout' + ] + self.assertTrue(dist_ops == ref_ops) + + # parameter initialization + var_need_broadcast = sorted( + ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0']) + self.assertTrue( + initialization_check(_global_parallel_stratergy, dist_context, + dist_startup_prog, serial_startup_prog, + var_need_broadcast)) + + +class 
AttentionLayer(nn.Layer): + def __init__(self, + hidden_size=1024, + sequence_len=512, + intermediate_size=4 * 1024, + num_heads=16, + dropout_ratio=0.1, + initializer_range=0.02): + super(AttentionLayer, self).__init__() + self.hidden_size = hidden_size + self.sequence_len = sequence_len + self.embed_dim = self.hidden_size + self.kdim = self.embed_dim + self.vdim = self.embed_dim + self.num_heads = num_heads + self.head_dim = self.embed_dim // self.num_heads + assert self.head_dim * self.num_heads == self.embed_dim, \ + "embed_dim must be divisible by num_heads" + self.dropout_ratio = dropout_ratio + self.initializer_range = initializer_range + self.training = True + self.attn_mask = None + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range)) + bias_attr = None + + self.q_proj = nn.Linear( + self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.k_proj = nn.Linear( + self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.v_proj = nn.Linear( + self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.out_proj = nn.Linear( + self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) + + def forward(self, input): + if _global_parallel_stratergy == "dp": + auto.shard_tensor( + input, _global_process_mesh, dim_mapping=[0, -1, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + input, _global_process_mesh, dim_mapping=[0, -1, -1]) + + q = self.q_proj(input) + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + + k = self.k_proj(input) + v = self.v_proj(input) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + + # scale dot product attention + product = layers.matmul( + x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) + + if self.attn_mask is not None: + product = product + self.attn_mask + + weights = F.softmax(product) + + if self.dropout_ratio: + weights = F.dropout( + weights, + self.dropout_ratio, + training=self.training, + mode="upscale_in_train") + + out = tensor.matmul(weights, v) + + # combine heads + out = tensor.transpose(out, perm=[0, 2, 1, 3]) + out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.out_proj(out) + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.out_proj.weight, _global_process_mesh, + dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.out_proj.weight, _global_process_mesh, + dim_mapping=[1, -1]) + + return out + + +def attn_pretrain_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + 
hidden_size = 1024 + sequence_len = 512 + input = static.data( + name="query", + shape=[batch_size, sequence_len, hidden_size], + dtype='float32') + attn = AttentionLayer( + hidden_size=hidden_size, + sequence_len=sequence_len, + intermediate_size=4 * hidden_size, + num_heads=16, + dropout_ratio=0.1, + initializer_range=0.02) + out = attn(input) + + return train_program, start_program + + +class TestAttentionAutoPartitioner(unittest.TestCase): + def test_attn_dp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[0, 1, 2, 3], parent=ROOT_MESH) + + serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( + attn_pretrain_forward) + # parameter should not be partitioned + self.assertTrue( + is_all_parameters_shape_equal(serial_main_prog, dist_main_prog)) + self.assertTrue( + is_all_parameters_shape_equal(serial_startup_prog, + dist_startup_prog)) + + # op in main prog should be the same + serial_ops = serial_main_prog.global_block().ops + dist_ops = dist_main_prog.global_block().ops + serial_ops = [op.type for op in serial_ops] + dist_ops = [op.type for op in dist_ops] + self.assertTrue(serial_ops == dist_ops) + + # parameter initialization + var_need_broadcast = [] + self.assertTrue( + initialization_check(_global_parallel_stratergy, dist_context, + dist_startup_prog, serial_startup_prog, + var_need_broadcast)) + + def test_attn_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[0, 1, 2, 3], parent=ROOT_MESH) + + serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( + attn_pretrain_forward) + + # param should be partition + nrank = 4 + # col parallel + weights = ['linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 1, nrank)) + weights = ['linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + # row parallel + weights = ['linear_3.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + weights = ['linear_3.b_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, 1)) + + # row and col allreduce + dist_ops = dist_main_prog.global_block().ops + dist_ops = [op.type for op in dist_ops] + ref_ops = [ + 'c_identity', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', + 'c_identity', 'matmul', 'elementwise_add', 'c_identity', 'matmul', + 'elementwise_add', 'reshape2', 'transpose2', 'reshape2', + 'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2', + 'transpose2', 'reshape2', 'matmul', 'c_allreduce_sum', + 'elementwise_add' + ] + self.assertTrue(dist_ops == ref_ops) + + # parameter initialization + var_need_broadcast = ['linear_3.b_0'] + self.assertTrue( + initialization_check(_global_parallel_stratergy, dist_context, + dist_startup_prog, serial_startup_prog, + var_need_broadcast)) + + def test_attn_dp_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp_mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + + serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, 
dist_context = get_programs( + attn_pretrain_forward) + + # param should be partition + nrank = 4 + # col parallel + weights = ['linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 1, nrank)) + weights = ['linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + # row parallel + weights = ['linear_3.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + weights = ['linear_3.b_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, 1)) + + # row and col allreduce + dist_ops = dist_main_prog.global_block().ops + dist_ops = [op.type for op in dist_ops] + ref_ops = [ + 'c_identity', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', + 'c_identity', 'matmul', 'elementwise_add', 'c_identity', 'matmul', + 'elementwise_add', 'reshape2', 'transpose2', 'reshape2', + 'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2', + 'transpose2', 'reshape2', 'matmul', 'c_allreduce_sum', + 'elementwise_add' + ] + self.assertTrue(dist_ops == ref_ops) + + # parameter initialization + var_need_broadcast = ['linear_3.b_0'] + self.assertTrue( + initialization_check(_global_parallel_stratergy, dist_context, + dist_startup_prog, serial_startup_prog, + var_need_broadcast)) + + +class DecoderLayer(nn.Layer): + def __init__(self, + vocab_size=32768, + hidden_size=1024, + sequence_len=512, + max_position_embeddings=512, + intermediate_size=4 * 1024, + num_heads=16, + dropout_ratio=0.1, + initializer_range=0.02): + super(DecoderLayer, self).__init__() + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.max_position_embeddings = max_position_embeddings + self.sequence_len = sequence_len + self.embed_dim = self.hidden_size + self.kdim = self.embed_dim + self.vdim = self.embed_dim + self.num_heads = num_heads + self.dropout_ratio = dropout_ratio + self.initializer_range = initializer_range + self.training = True + self.attn_mask = None + + self.head_dim = self.embed_dim // self.num_heads + assert self.head_dim * self.num_heads == self.embed_dim, \ + "embed_dim must be divisible by num_heads" + self.word_embeddings = nn.Embedding( + self.vocab_size, + self.hidden_size, + weight_attr=paddle.ParamAttr( + name="word_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=self.initializer_range))) + self.position_embeddings = nn.Embedding( + self.max_position_embeddings, + self.hidden_size, + weight_attr=paddle.ParamAttr( + name="pos_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=self.initializer_range))) + + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=self.initializer_range)) + bias_attr = None + self.q_proj = nn.Linear( + self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.k_proj = nn.Linear( + self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.v_proj = nn.Linear( + self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.out_proj = nn.Linear( + self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) + + intermediate_size = 4 * self.hidden_size + d_model = self.hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=self.initializer_range)) + bias_attr = None + self.linear0 = nn.Linear( + d_model, dim_feedforward, 
weight_attr, bias_attr=bias_attr) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + self.dropout1 = nn.Dropout(self.dropout_ratio) + self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train") + self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train") + + def forward(self, input_ids, position_ids): + if _global_parallel_stratergy == "dp": + auto.shard_tensor( + input_ids, _global_process_mesh, dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + input_ids, _global_process_mesh, dim_mapping=[0, -1]) + + input_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.word_embeddings.weight, + _global_process_mesh, + dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.word_embeddings.weight, + _global_process_mesh, + dim_mapping=[1, -1]) + + embeddings = input_embeddings + position_embeddings + embeddings = self.dropout1(embeddings) + + # Pre-norm + target = self.norm(embeddings) + + # The following is the attention part + q = self.q_proj(target) + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + + k = self.k_proj(target) + v = self.v_proj(target) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + + # scale dot product attention + product = layers.matmul( + x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) + + if self.attn_mask is not None: + product = product + self.attn_mask + + weights = F.softmax(product) + + if self.dropout_ratio: + weights = F.dropout( + weights, + self.dropout_ratio, + training=self.training, + mode="upscale_in_train") + + out = tensor.matmul(weights, v) + + # combine heads + out = tensor.transpose(out, perm=[0, 2, 1, 3]) + out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.out_proj(out) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.out_proj.weight, _global_process_mesh, + dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.out_proj.weight, _global_process_mesh, + dim_mapping=[1, -1]) + else: + auto.shard_tensor( + self.out_proj.weight, + _global_process_mesh, + dim_mapping=[-1, -1]) + + # Add residual + residual = embeddings + self.dropout2(out) + + # Pre-norm + out0 = self.norm(residual) + + # The following is the MLP part + out1 = self.linear0(out0) + out2 = F.gelu(out1, approximate=True) + out3 = self.linear1(out2) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( 
+ self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) + + # Add residual + final = residual + self.dropout3(out3) + return final + + +def decoder_pretrain_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + hidden_size = 1024 + sequence_len = 512 + input_ids = static.data( + name="input_ids", shape=[batch_size, sequence_len], dtype='int64') + position_ids = static.data( + name="position_ids", + shape=[batch_size, sequence_len], + dtype='int64') + decoder = DecoderLayer( + vocab_size=32768, + hidden_size=hidden_size, + sequence_len=sequence_len, + max_position_embeddings=512, + intermediate_size=4 * hidden_size, + num_heads=16, + dropout_ratio=0.1, + initializer_range=0.02) + out = decoder(input_ids, position_ids) + + return train_program, start_program + + +class TestDecoderLayerPartitioner(unittest.TestCase): + def test_decoder_dp_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp_mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( + decoder_pretrain_forward) + + # param should be partition + nrank = 4 + # col parallel + weights = [ + 'linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0', 'linear_4.w_0' + ] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 1, nrank)) + weights = [ + 'linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0', 'linear_4.b_0' + ] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + # row parallel + weights = ['word_embeddings', 'linear_3.w_0', 'linear_5.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + weights = [ + 'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0', + 'layer_norm_0.w_0', 'linear_5.b_0' + ] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, 1)) + + # row and col allreduce + dist_ops = dist_main_prog.global_block().ops + dist_ops = [op.type for op in dist_ops] + ref_ops = [ + 'c_embedding', 'c_allreduce_sum', 'lookup_table_v2', + 'elementwise_add', 'dropout', 'layer_norm', 'c_identity', 'matmul', + 'elementwise_add', 'reshape2', 'transpose2', 'c_identity', 'matmul', + 'elementwise_add', 'c_identity', 'matmul', 'elementwise_add', + 'reshape2', 'transpose2', 'reshape2', 'transpose2', 'matmul', + 'softmax', 'dropout', 'matmul_v2', 'transpose2', 'reshape2', + 'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout', + 'elementwise_add', 'layer_norm', 'c_identity', 'matmul', + 'elementwise_add', 'gelu', 'matmul', 'c_allreduce_sum', + 'elementwise_add', 'dropout', 'elementwise_add' + ] + self.assertTrue(dist_ops == ref_ops) + + # parameter initialization + var_need_broadcast = sorted([ + 'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0', + 'layer_norm_0.w_0', 'linear_5.b_0' + ]) + self.assertTrue( + initialization_check(_global_parallel_stratergy, dist_context, + dist_startup_prog, serial_startup_prog, + var_need_broadcast)) + + def 
test_decoder_noparallel(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "None" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( + decoder_pretrain_forward) + + # param should be partition + nrank = 1 + # col parallel + weights = [ + 'linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0', 'linear_4.w_0' + ] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 1, nrank)) + weights = [ + 'linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0', 'linear_4.b_0' + ] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + # row parallel + weights = ['word_embeddings', 'linear_3.w_0', 'linear_5.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + weights = [ + 'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0', + 'layer_norm_0.w_0', 'linear_5.b_0' + ] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, 1)) + + # row and col allreduce + dist_ops = dist_main_prog.global_block().ops + dist_ops = [op.type for op in dist_ops] + ref_ops = [ + 'lookup_table_v2', 'lookup_table_v2', 'elementwise_add', 'dropout', + 'layer_norm', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', + 'matmul', 'elementwise_add', 'matmul', 'elementwise_add', + 'reshape2', 'transpose2', 'reshape2', 'transpose2', 'matmul', + 'softmax', 'dropout', 'matmul_v2', 'transpose2', 'reshape2', + 'matmul', 'elementwise_add', 'dropout', 'elementwise_add', + 'layer_norm', 'matmul', 'elementwise_add', 'gelu', 'matmul', + 'elementwise_add', 'dropout', 'elementwise_add' + ] + self.assertTrue(dist_ops == ref_ops) + dist_ops = dist_startup_prog.global_block().ops + dist_ops = [op.type for op in dist_ops] + ref_ops = [ + 'gaussian_random', 'gaussian_random', 'gaussian_random', + 'fill_constant', 'gaussian_random', 'fill_constant', + 'gaussian_random', 'fill_constant', 'gaussian_random', + 'fill_constant', 'gaussian_random', 'fill_constant', + 'gaussian_random', 'fill_constant', 'fill_constant', 'fill_constant' + ] + self.assertTrue(dist_ops == ref_ops) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py new file mode 100755 index 00000000000..b02c5f8a84f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py @@ -0,0 +1,857 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
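Editor's note: the split-shape assertions in the partitioner tests above (and the '@GRAD' variant of check_tensor_split in the GPT test that follows) all reduce to one rule, the same one _get_dist_shape applies in partitioner.py: a tensor dimension mapped to mesh axis a is divided by the extent of that axis, and a dimension mapped to -1 is replicated. The sketch below restates that rule standalone; shard_shape is a hypothetical helper, and the sizes match the hidden_size=1024, mesh [2, 4] configuration of the dp_mp tests.

# Standalone restatement of the shard-shape rule the tests assert on.
def shard_shape(shape, dims_mapping, mesh_topology):
    out = []
    for size, mapped in zip(shape, dims_mapping):
        if mapped == -1:
            out.append(size)                       # replicated dimension
        else:
            assert size % mesh_topology[mapped] == 0, "uneven partition"
            out.append(size // mesh_topology[mapped])
    return out


mesh = [2, 4]      # dp degree 2 on axis 0, mp degree 4 on axis 1
hidden = 1024
# column-parallel weight (e.g. linear_0.w_0 in the MLP test), dims_mapping [-1, 1]
assert shard_shape([hidden, 4 * hidden], [-1, 1], mesh) == [1024, 1024]
# row-parallel weight (e.g. linear_1.w_0 in the MLP test), dims_mapping [1, -1]
assert shard_shape([4 * hidden, hidden], [1, -1], mesh) == [1024, 1024]
# replicated bias of a row-parallel linear, dims_mapping [-1]
assert shard_shape([hidden], [-1], mesh) == [1024]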
+ +from __future__ import print_function + +import collections +import math +import unittest + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import paddle.tensor as tensor +import paddle.utils as utils +from paddle.fluid import layers +from paddle.fluid.framework import in_dygraph_mode +from paddle.nn.layer.transformer import _convert_param_attr_to_list +from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer +from paddle.distributed import fleet +import paddle.static as static +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program +from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr +from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed.auto_parallel.partitioner import Partitioner +from paddle.distributed.auto_parallel.utils import _get_comm_group +from paddle.distributed.auto_parallel.process import new_process_group + +paddle.enable_static() +ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) +_global_parallel_stratergy = None +_global_process_mesh = None + + +def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit): + + for i in range(len(varnames1)): + var1 = prog1.global_block().var(varnames1[i] + '@GRAD') + var2 = prog2.global_block().var(varnames2[i]) + if var1.shape[axis] != (var2.shape[axis] // nsplit): + return False + + return True + + +class MultiHeadAttention(nn.Layer): + """ + Attention mapps queries and a set of key-value pairs to outputs, and + Multi-Head Attention performs multiple parallel attention to jointly attending + to information from different representation subspaces. + """ + + Cache = collections.namedtuple("Cache", ["k", "v"]) + StaticCache = collections.namedtuple("StaticCache", ["k", "v"]) + + def __init__(self, + embed_dim, + num_heads, + dropout=0., + kdim=None, + vdim=None, + need_weights=False, + weight_attr=None, + bias_attr=None, + topo=None, + fuse=False): + super(MultiHeadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.need_weights = need_weights + self.fuse = fuse + + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + if topo is None or topo.mp_info.size == 1: + if self.fuse: + assert self.kdim == embed_dim + assert self.vdim == embed_dim + self.qkv_proj = nn.Linear( + embed_dim, 3 * embed_dim, weight_attr, bias_attr=bias_attr) + else: + self.q_proj = nn.Linear( + embed_dim, embed_dim, weight_attr, bias_attr=bias_attr) + self.k_proj = nn.Linear( + self.kdim, embed_dim, weight_attr, bias_attr=bias_attr) + self.v_proj = nn.Linear( + self.vdim, embed_dim, weight_attr, bias_attr=bias_attr) + self.out_proj = nn.Linear( + embed_dim, embed_dim, weight_attr, bias_attr=bias_attr) + + def _fuse_prepare_qkv(self, query): + mix_layer = self.qkv_proj(query) + mix_layer = paddle.reshape_(mix_layer, + [0, 0, self.num_heads, 3 * self.head_dim]) + mix_layer = paddle.transpose(mix_layer, [0, 2, 1, 3]) + q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1) + return q, k, v + + def _prepare_qkv(self, query, key, value, use_cache=False, cache=None): + r""" + Prapares linear projected queries, keys and values for usage of subsequnt + multiple parallel 
attention. If `cache` is not None, using cached results + to reduce redundant calculations. + """ + q = self.q_proj(query) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + + if isinstance(cache, self.StaticCache): + # for encoder-decoder attention in inference and has cached + k, v = cache.k, cache.v + else: + k, v = self.compute_kv(key, value) + + if isinstance(cache, self.Cache): + # for decoder self-attention in inference + k = tensor.concat([cache.k, k], axis=2) + v = tensor.concat([cache.v, v], axis=2) + if use_cache is True: + cache = self.Cache(k, v) + + return (q, k, v) if use_cache is False else (q, k, v, cache) + + def compute_kv(self, key, value): + r""" + Applies linear projection on input keys and values, then splits heads + (reshape and transpose) to get keys and values from different representation + subspaces. The results are used as key-values pairs for subsequent multiple + parallel attention. + It is part of calculations in multi-head attention, and is provided as + a method to pre-compute and prefetch these results, thus we can use them + to construct cache for inference. + """ + k = self.k_proj(key) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + v = self.v_proj(value) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + return k, v + + def gen_cache(self, key, value=None, type=Cache): + """ + Generates cache for `forward` usage in inference accroding to arguments. + The generated cache is an instance of `MultiHeadAttention.Cache` or an + instance of `MultiHeadAttention.StaticCache`. + """ + if type == MultiHeadAttention.StaticCache: # static_kv + k, v = self.compute_kv(key, value) + return self.StaticCache(k, v) + elif value is None: # incremental_state + k = layers.fill_constant_batch_size_like( + input=key, + shape=[-1, self.num_heads, 0, self.head_dim], + dtype=key.dtype, + value=0) + v = layers.fill_constant_batch_size_like( + input=key, + shape=[-1, self.num_heads, 0, self.head_dim], + dtype=key.dtype, + value=0) + return self.Cache(k, v) + else: + # incremental_state with initial value, mainly for usage like UniLM + return self.Cache(key, value) + + def forward(self, + query, + key, + value, + attn_mask=None, + use_cache=False, + cache=None): + r""" + Applies multi-head attention to map queries and a set of key-value pairs + to outputs. 
+ """ + key = query if key is None else key + value = query if value is None else value + # compute q ,k ,v + if use_cache is False: + if self.fuse: + q, k, v = self._fuse_prepare_qkv(query) + else: + q, k, v = self._prepare_qkv(query, key, value, use_cache, cache) + else: + q, k, v, cache = self._prepare_qkv(query, key, value, use_cache, + cache) + # scale dot product attention + product = layers.matmul( + x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) + + if attn_mask is not None: + product = product + attn_mask + + weights = F.softmax(product) + if self.dropout: + weights = F.dropout( + weights, + self.dropout, + training=self.training, + mode="upscale_in_train") + + out = tensor.matmul(weights, v) + + # combine heads + out = tensor.transpose(out, perm=[0, 2, 1, 3]) + out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.out_proj(out) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.out_proj.weight, _global_process_mesh, + dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.out_proj.weight, _global_process_mesh, + dim_mapping=[1, -1]) + + outs = [out] + if self.need_weights: + outs.append(weights) + if use_cache: + outs.append(cache) + return out if len(outs) == 1 else tuple(outs) + + +class TransformerDecoder(nn.Layer): + """ + TransformerDecoder is a stack of N decoder layers. + """ + + def __init__(self, + decoder_layers, + num_layers, + norm=None, + hidden_size=None, + topo=None): + super(TransformerDecoder, self).__init__() + + self.topo = topo + self.num_layers = num_layers + self.layers = decoder_layers + self.norm = norm + if norm is "LayerNorm": + self.norm = nn.LayerNorm(hidden_size) + elif norm is not None: + raise ValueError("Only support LayerNorm") + self.checkpoints = [] + + def forward(self, + tgt, + memory, + tgt_mask=None, + memory_mask=None, + use_cache=False, + cache=None): + r""" + Applies a stack of N Transformer decoder layers on inputs. If `norm` is + provided, also applies layer normalization on the output of last decoder + layer. + """ + output = tgt + new_caches = [] + self.checkpoints = [] + + for i, mod in enumerate(self.layers): + if cache is None: + if use_cache: + output, new_cache = mod(output, + memory, + tgt_mask=tgt_mask, + use_cache=use_cache, + cache=cache) + new_caches.append(new_cache) + else: + output = mod(output, + memory, + tgt_mask=tgt_mask, + use_cache=use_cache, + cache=cache) + + else: + output, new_cache = mod(output, + memory, + tgt_mask=tgt_mask, + use_cache=use_cache, + cache=cache[i]) + new_caches.append(new_cache) + self.checkpoints.append(output.name) + + if self.norm is not None: + output = self.norm(output) + return output if use_cache is False else (output, new_caches) + + def gen_cache(self, memory, do_zip=False): + r""" + Generates cache for `forward` usage. The generated cache is a list, and + each element in it is a tuple( :code:`(incremental_cache, static_cache)` ) + produced by `TransformerDecoderLayer.gen_cache`. See `TransformerDecoderLayer.gen_cache` + for more details. If `do_zip` is True, apply `zip` on these tuples to get + a list with two elements. + """ + cache = [layer.gen_cache(memory) for layer in self.layers] + if do_zip: + cache = list(zip(*cache)) + return cache + + +class TransformerDecoderLayer(nn.Layer): + """ + The transformer decoder layer. + It contains multiheadattention and some linear layers. 
+ """ + + def __init__(self, + d_model, + nhead, + dim_feedforward, + dropout=0.1, + activation="gelu", + attn_dropout=None, + act_dropout=None, + normalize_before=True, + weight_attr=None, + bias_attr=None, + topo=None): + self._config = locals() + self._config.pop("self") + self._config.pop("__class__", None) # py3 + + super(TransformerDecoderLayer, self).__init__() + attn_dropout = dropout if attn_dropout is None else attn_dropout + act_dropout = dropout if act_dropout is None else act_dropout + self.normalize_before = normalize_before + + weight_attrs = _convert_param_attr_to_list(weight_attr, 3) + bias_attrs = _convert_param_attr_to_list(bias_attr, 3) + + self.self_attn = MultiHeadAttention( + d_model, + nhead, + dropout=attn_dropout, + weight_attr=weight_attrs[0], + bias_attr=bias_attrs[0], + topo=topo) + if topo is None or topo.mp_info.size == 1: + self.linear1 = nn.Linear( + d_model, + dim_feedforward, + weight_attrs[2], + bias_attr=bias_attrs[2]) + self.linear2 = nn.Linear( + dim_feedforward, + d_model, + weight_attrs[2], + bias_attr=bias_attrs[2]) + + self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5) + self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5) + self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train") + self.dropout2 = nn.Dropout(act_dropout, mode="upscale_in_train") + self.activation = getattr(F, activation) + + def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None): + residual = tgt + + if self.normalize_before: + tgt = self.norm1(tgt) + + if use_cache is False: + tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache) + else: + tgt, incremental_cache = self.self_attn(tgt, tgt, tgt, tgt_mask, + use_cache, cache) + tgt = residual + self.dropout1(tgt) + if not self.normalize_before: + tgt = self.norm1(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm2(tgt) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1]) + + # tgt = self.dropout2( + # self.linear2(F.gelu( + # self.linear1(tgt), approximate=True))) + tgt = self.linear1(tgt) + tgt = F.gelu(tgt, approximate=True) + tgt = self.dropout2(self.linear2(tgt)) + tgt = residual + tgt + + if not self.normalize_before: + tgt = self.norm2(tgt) + + return tgt if use_cache is False else (tgt, incremental_cache) + + def gen_cache(self, memory): + incremental_cache = self.self_attn.gen_cache( + memory, type=self.self_attn.Cache) + return incremental_cache + + +class GPTEmbeddings(nn.Layer): + """ + Include embeddings from word, position and token_type embeddings + """ + + def __init__(self, + vocab_size, + hidden_size=768, + hidden_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + topo=None): + super(GPTEmbeddings, self).__init__() + if topo is None or topo.mp_info.size == 1: + self.word_embeddings = nn.Embedding( + vocab_size, + hidden_size, + weight_attr=paddle.ParamAttr( + name="word_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range))) + self.position_embeddings = nn.Embedding( + max_position_embeddings, + hidden_size, + 
weight_attr=paddle.ParamAttr( + name="pos_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range))) + + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, input_ids, position_ids=None): + if position_ids is None: + ones = paddle.ones_like(input_ids, dtype="int64") + seq_length = paddle.cumsum(ones, axis=-1) + position_ids = seq_length - ones + + input_embedings = self.word_embeddings(input_ids) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.word_embeddings.weight, + _global_process_mesh, + dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.word_embeddings.weight, + _global_process_mesh, + dim_mapping=[1, -1]) + + position_embeddings = self.position_embeddings(position_ids) + embeddings = input_embedings + position_embeddings + embeddings = self.dropout(embeddings) + return embeddings + + +class GPTModel(nn.Layer): + """ + The base model of gpt. + """ + + def __init__(self, + vocab_size, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + pad_token_id=0, + topo=None): + super(GPTModel, self).__init__() + + self.pad_token_id = pad_token_id + self.initializer_range = initializer_range + self.topo = topo + self.hidden_size = hidden_size + self.vocab_size = vocab_size + + self.pipline_mode = topo is not None and topo.pp_info.size > 1 + if self.pipline_mode: + self.layer_per_stage = num_hidden_layers // self.topo.pp_info.size + + self.embeddings = GPTEmbeddings( + vocab_size, hidden_size, hidden_dropout_prob, + max_position_embeddings, type_vocab_size, self.initializer_range, + topo) + + decoder_layers = nn.LayerList() + for i in range(num_hidden_layers): + DecoderLayer = TransformerDecoderLayer + decoder_layers.append( + DecoderLayer( + d_model=hidden_size, + nhead=num_attention_heads, + dim_feedforward=intermediate_size, + dropout=hidden_dropout_prob, + activation=hidden_act, + attn_dropout=attention_probs_dropout_prob, + act_dropout=hidden_dropout_prob, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Normal( + mean=0.0, std=self.initializer_range)), + bias_attr=None, + topo=topo)) + + Decoder = TransformerDecoder + + self.decoder = Decoder( + decoder_layers, + num_hidden_layers, + norm="LayerNorm", + hidden_size=hidden_size, + topo=topo) + + self.checkpoints = [] + + def forward(self, + input_ids, + position_ids=None, + attention_mask=None, + use_cache=False, + cache=None): + self.checkpoints = [] + if attention_mask is None: + length = paddle.shape(input_ids)[1] + # Use bool mask + attention_mask = paddle.tensor.tril( + paddle.ones( + (length, length), + dtype=self.embeddings.word_embeddings.weight.dtype)) + if position_ids is None: + past_length = 0 + if cache is not None: + past_length = paddle.shape(cache[0].k)[-2] + position_ids = paddle.arange( + past_length, + paddle.shape(input_ids)[-1] + past_length, + dtype='int64') + position_ids = position_ids.unsqueeze(0) + # .expand_as(input_ids) + position_ids = paddle.fluid.layers.expand_as(position_ids, + input_ids) + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids) + + # TODO, use registered buffer + causal_mask = paddle.tensor.triu( + paddle.ones((paddle.shape(input_ids)[-1], + paddle.shape(input_ids)[-1])) * -1e9, + diagonal=1) + + if attention_mask is not None: + 
attention_mask = attention_mask + causal_mask
+ else:
+ attention_mask = causal_mask
+
+ # The tensor returned by triu is created inside the static graph, so
+ # its gradient is stopped explicitly.
+ attention_mask.stop_gradient = True
+
+ encoder_outputs = self.decoder(
+ embedding_output,
+ memory=None,
+ tgt_mask=attention_mask,
+ use_cache=use_cache,
+ cache=cache)
+ self.checkpoints.extend(self.decoder.checkpoints)
+ return encoder_outputs
+
+
+class GPTForPretraining(nn.Layer):
+ """
+ The pretraining model of GPT.
+ It returns the lm logits and, when `use_cache` is True, the cached key/value tensors.
+ """
+
+ def __init__(self, gpt):
+ super(GPTForPretraining, self).__init__()
+ self.gpt = gpt
+ self.share_param = False
+ self.weight = self.gpt.embeddings.word_embeddings.weight
+ if not self.share_param:
+ self.weight = self.create_parameter(shape=self.weight.shape)
+
+ def parallel_matmul(self, lm_output, logit_weights, parallel_output, topo):
+ # With tensor model parallelism, the logits are computed against the
+ # local shard of `logit_weights`; unless `parallel_output` is True they
+ # are concatenated across the model-parallel group.
+ if topo is not None and topo.mp_info.size > 1:
+ input_parallel = paddle.distributed.collective._c_identity(
+ lm_output, group=None)
+
+ logits = paddle.matmul(
+ input_parallel, logit_weights, transpose_y=True)
+
+ if parallel_output:
+ return logits
+
+ return paddle.distributed.collective._c_concat(logits, group=None)
+ else:
+ logits = paddle.matmul(lm_output, logit_weights, transpose_y=True)
+ return logits
+
+ def forward(self,
+ input_ids,
+ position_ids=None,
+ attention_mask=None,
+ masked_positions=None,
+ use_cache=False,
+ cache=None):
+ outputs = self.gpt(input_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ use_cache=use_cache,
+ cache=cache)
+ if use_cache:
+ encoder_outputs, cached_kvs = outputs[:2]
+ else:
+ encoder_outputs = outputs
+ logits = self.parallel_matmul(encoder_outputs, self.weight, True,
+ self.gpt.topo)
+
+ if use_cache:
+ return logits, cached_kvs
+ else:
+ return logits
+
+
+class GPTPretrainingCriterion(nn.Layer):
+ """
+ Criterion for GPT.
+ It calculates the final loss.
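+ The loss is a token-level cross entropy between `prediction_scores` and
+ `masked_lm_labels`, weighted by `loss_mask` and averaged over the number
+ of unmasked tokens.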
+ """ + + def __init__(self, topo=None): + super(GPTPretrainingCriterion, self).__init__() + if topo is None or topo.mp_info.size == 1: + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none") + else: + self.loss_func = paddle.distributed.collective._c_softmax_with_cross_entropy + + def forward(self, prediction_scores, masked_lm_labels, loss_mask): + masked_lm_loss = self.loss_func(prediction_scores, + masked_lm_labels.unsqueeze(2)) + + loss_mask = loss_mask.reshape([-1]) + masked_lm_loss = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask) + loss = masked_lm_loss / loss_mask.sum() + return loss + + +def gpt_pretrain_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 16 + sequence_len = 512 + input_ids = static.data( + name="input_ids", shape=[batch_size, sequence_len], dtype='int64') + position_ids = static.data( + name="position_ids", + shape=[batch_size, sequence_len], + dtype='int64') + attention_mask = static.data( + name="attention_mask", + shape=[batch_size, 1, sequence_len, sequence_len], + dtype='float64') + labels = static.data( + name="labels", shape=[batch_size, sequence_len], dtype='int64') + loss_mask = static.data( + name="loss_mask", shape=[batch_size, sequence_len], dtype='float64') + + if _global_parallel_stratergy == "dp": + auto.shard_tensor( + input_ids, _global_process_mesh, dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + input_ids, _global_process_mesh, dim_mapping=[0, -1]) + + gpt = GPTModel( + vocab_size=32768, + hidden_size=768, + num_hidden_layers=2, + num_attention_heads=12, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=1024, + type_vocab_size=16, + initializer_range=0.02, + pad_token_id=0, + topo=None) + + model = GPTForPretraining(gpt) + + preds = model(input_ids, position_ids, attention_mask) + + criterion = GPTPretrainingCriterion() + + loss = criterion(preds, labels, loss_mask) + + return train_program, start_program, loss + + +class TestGPTPartitioner(unittest.TestCase): + def test_gpt_dp_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp_mp" + global _global_process_mesh + + _global_process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + + train_program = static.Program() + start_program = static.Program() + dist_context = DistributedContext() + dist_context.set_process_mesh(_global_process_mesh) + train_program, start_program, loss = gpt_pretrain_forward(train_program, + start_program) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + rank_id = 3 + dist_strategy = fleet.DistributedStrategy() + partitioner = Partitioner(dist_strategy, dist_context, rank_id) + auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward( + complete_train_program, start_program) + dist_params_grads = partitioner.apply_backward( + loss, complete_train_program, start_program, + auto_parallel_main_prog, auto_parallel_startup_prog) + optimizer = paddle.fluid.optimizer.AdamOptimizer( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, + auto_parallel_main_prog, + auto_parallel_startup_prog) + + nrank = 4 + # col parallel + weights = [ + 'linear_0.w_0', + 'linear_6.w_0', + 'linear_10.w_0', + ] + self.assertTrue( + 
check_tensor_split(auto_parallel_main_prog, weights, + complete_train_program, weights, 1, nrank)) + + # row parallel + weights = ['word_embeddings', 'linear_9.w_0', 'linear_11.w_0'] + self.assertTrue( + check_tensor_split(auto_parallel_main_prog, weights, + complete_train_program, weights, 0, nrank)) + + weights = ['pos_embeddings', 'layer_norm_0.b_0', 'layer_norm_4.w_0'] + self.assertTrue( + check_tensor_split(auto_parallel_main_prog, weights, + complete_train_program, weights, 0, 1)) + + all_params = sorted( + [param.name for param in start_program.all_parameters()]) + allreduce_grads = [ + 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', + 'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2', + 'layer_norm_7.tmp_2', 'layer_norm_8.tmp_2' + ] + mp_parallel_axis, process_mesh = dist_context._get_model_parallel_info() + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, mp_parallel_axis, + 3) + mp_ring_id = new_process_group(group_ranks).id + dp_parallel_axis, process_mesh = dist_context._get_data_parallel_info() + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, dp_parallel_axis, + 3) + dp_ring_id = new_process_group(group_ranks).id + tensor_parallel_allreduce_vars = sorted([ + op.desc.output_arg_names()[0].split("@")[0] + for op in auto_parallel_main_prog.global_block().ops + if (op.type == "c_allreduce_sum" and op.attr('op_role') == 1 and + op.desc.attr("ring_id") == mp_ring_id) + ]) + data_parallel_allreduce_vars = sorted([ + op.desc.output_arg_names()[0].split("@")[0] + for op in auto_parallel_main_prog.global_block().ops + if (op.type == "c_allreduce_sum" and op.desc.attr("ring_id") == + dp_ring_id) + ]) + + self.assertTrue(all_params == data_parallel_allreduce_vars) + self.assertTrue(allreduce_grads == tensor_parallel_allreduce_vars) + + +if __name__ == "__main__": + unittest.main() -- GitLab