From fa97e5bad5829a079749f3947bf93af0673c8d5d Mon Sep 17 00:00:00 2001 From: wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com> Date: Fri, 16 Sep 2022 10:53:58 +0800 Subject: [PATCH] refactor mp. (#45803) * refactor mp. * update setup.py. * update mp_layers.py for compatibility. * add documents for mp_layers.py * update init.py * update collective.py. * update. * update mp_ops.py * update. * update code style. * update code style. --- python/paddle/distributed/collective.py | 790 +----------------- .../distributed/communication/comm_utils.py | 50 ++ .../distributed/fleet/layers/mpu/__init__.py | 24 + .../distributed/fleet/layers/mpu/mp_layers.py | 466 +++++++++++ .../distributed/fleet/layers/mpu/mp_ops.py | 772 +++++++++++++++++ .../distributed/fleet/layers/mpu/random.py | 243 ++++++ .../parallel_layers/mp_layers.py | 299 +------ .../meta_parallel/parallel_layers/random.py | 234 +----- python/setup.py.in | 2 + 9 files changed, 1581 insertions(+), 1299 deletions(-) create mode 100644 python/paddle/distributed/communication/comm_utils.py create mode 100644 python/paddle/distributed/fleet/layers/mpu/__init__.py create mode 100644 python/paddle/distributed/fleet/layers/mpu/mp_layers.py create mode 100644 python/paddle/distributed/fleet/layers/mpu/mp_ops.py create mode 100644 python/paddle/distributed/fleet/layers/mpu/random.py diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 57731b8ad0e..5ceb046f550 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -40,46 +40,23 @@ import paddle.fluid.core as core from paddle import _C_ops, _legacy_C_ops import paddle.fluid.dygraph_utils as dygraph_utils import contextlib +from .fleet.layers.mpu.mp_ops import split +from .fleet.layers.mpu.mp_ops import _c_identity +from .fleet.layers.mpu.mp_ops import _c_concat +from .fleet.layers.mpu.mp_ops import _c_split +from .fleet.layers.mpu.mp_ops import _mp_allreduce +from .fleet.layers.mpu.mp_ops import _c_lookup_table +from .fleet.layers.mpu.mp_ops import _Linear +from .fleet.layers.mpu.mp_ops import _set_var_distributed +from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy +from .fleet.layers.mpu.mp_ops import _linear +from .fleet.layers.mpu.mp_ops import _parallel_linear +from .fleet.layers.mpu.mp_ops import _parallel_embedding +from .communication.comm_utils import ReduceOp __all__ = [] -class ReduceOp: - """ - Specify the type of operation used for element-wise reductions. - It should be one of the following values: - - ReduceOp.SUM - - ReduceOp.MAX - - ReduceOp.MIN - - ReduceOp.PROD - - Examples: - .. code-block:: python - - # required: distributed - import paddle - import paddle.distributed as dist - - dist.init_parallel_env() - if dist.get_rank() == 0: - data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]]) - else: - data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) - dist.all_reduce(data, op=dist.ReduceOp.SUM) - print(data) - # [[5, 7, 9], [5, 7, 9]] (2 GPUs) - """ - SUM = 0 - MAX = 1 - MIN = 2 - PROD = 3 - AVG = 4 - - class Group(): """ The abstract representation of group. @@ -1259,747 +1236,6 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True): }) -def _c_identity(tensor, group=None): - """ - Return a copy of the tensor, mainly used with model parallel. - - Args: - tensor (Tensor): The input Tensor. Its data type - should be float16, float32, float64, int32 or int64. - group (int): The id of the process group to work on. - - Returns: - Tensor. - """ - if group is not None and not group.is_member(): - return - ring_id = 0 if group is None else group.id - - if _non_static_mode(): - return _legacy_C_ops.c_identity(tensor, 'use_calc_stream', True, - 'ring_id', ring_id, - 'use_model_parallel', True) - op_type = 'c_identity' - helper = LayerHelper(op_type, **locals()) - out = helper.create_variable_for_type_inference(dtype=tensor.dtype) - - check_variable_and_dtype( - tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], - '_c_identity') - - helper.append_op(type=op_type, - inputs={'X': tensor}, - outputs={'Out': out}, - attrs={ - 'ring_id': ring_id, - 'use_calc_stream': True, - 'use_model_parallel': True, - }) - return out - - -def _c_concat(tensor, group=None): - """ - Return allgather of the tensor, mainly used with model parallel. - - Args: - tensor (Tensor): The input Tensor. Its data type - should be float16, float32, float64, int32 or int64. - group (int): The id of the process group to work on. - - Returns: - Tensor. - """ - if group is not None and not group.is_member(): - return - group = _get_default_group() if group is None else group - ring_id = group.id - - global_rank = _get_global_env().rank - rank = group.rank - nranks = group.nranks - - if _non_static_mode(): - return _legacy_C_ops.c_concat(tensor, 'ring_id', ring_id, - 'use_calc_stream', True, 'rank', rank, - 'nranks', nranks, 'use_model_parallel', - True) - - op_type = 'c_concat' - helper = LayerHelper(op_type, **locals()) - out = helper.create_variable_for_type_inference(dtype=tensor.dtype) - - check_variable_and_dtype( - tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], - '_c_concat') - - helper.append_op(type=op_type, - inputs={'X': tensor}, - outputs={'Out': out}, - attrs={ - 'ring_id': ring_id, - 'use_calc_stream': True, - 'use_model_parallel': True, - 'nranks': nranks, - 'rank': rank - }) - return out - - -def _c_split(tensor, group=None): - """ - Split tensor evenly among all members, mainly used with model parallel. - - Args: - tensor (Tensor): The input Tensor. Its data type - should be float16, float32, float64, int32 or int64. - rank (int): The rank of the current process. - group (int): The id of the process group to work on. - - Returns: - Tensor. - """ - if group is not None and not group.is_member(): - return - ring_id = 0 if group is None else group.id - - global_rank = _get_global_env().rank - rank = global_rank if group is None else group.get_group_rank(global_rank) - nranks = _get_global_env().world_size if group is None else group.nranks - - if _non_static_mode(): - return _legacy_C_ops.c_split(tensor, 'use_calc_stream', True, 'ring_id', - ring_id, 'rank', rank, 'nranks', nranks, - 'use_model_parallel', True) - - op_type = 'c_split' - helper = LayerHelper(op_type, **locals()) - out = helper.create_variable_for_type_inference(dtype=tensor.dtype) - - check_variable_and_dtype( - tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], - '_c_split') - - helper.append_op(type=op_type, - inputs={'X': tensor}, - outputs={'Out': out}, - attrs={ - 'ring_id': ring_id, - 'use_calc_stream': True, - 'rank': rank, - 'nranks': nranks, - 'use_model_parallel': True, - }) - return out - - -def _mp_allreduce(tensor, - op=ReduceOp.SUM, - group=None, - use_calc_stream=True, - use_model_parallel=True): - """[it is same as allreduce above, but it supports model parallel. And it support inplace startegy] - """ - if group is not None and not group.is_member(): - return - - if in_dygraph_mode(): - group = _get_default_group() if group is None else group - assert op == ReduceOp.SUM, "Unknown parameter: {}.".format(op) - - from paddle.autograd import PyLayer - - class mp_allreduce_eager(PyLayer): - - @staticmethod - def forward(ctx, tensor, group, use_calc_stream, - use_model_parallel): - ctx.ring_id = group.id - - if use_calc_stream: - op_type = _get_reduce_op(op, "_mp_allreduce") - group.process_group.allreduce_on_calc_stream( - tensor, op_type) - return tensor - else: - return _legacy_C_ops.c_allreduce_sum_( - tensor, 'use_calc_stream', use_calc_stream, 'ring_id', - ring_id, "use_model_parallel", use_model_parallel) - - @staticmethod - def backward(ctx, dy): - return _legacy_C_ops.c_identity(dy, 'use_calc_stream', True, - 'ring_id', ctx.ring_id, - 'use_model_parallel', True) - - return mp_allreduce_eager.apply(tensor, group, use_calc_stream, - use_model_parallel) - - ring_id = 0 if group is None else group.id - if _in_legacy_dygraph(): - if op == ReduceOp.SUM: - return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream', - use_calc_stream, 'ring_id', - ring_id, "use_model_parallel", - use_model_parallel) - else: - raise ValueError("Unknown parameter: {}.".format(op)) - - op_type = 'c_allreduce_sum' - helper = LayerHelper(op_type, **locals()) - out = helper.create_variable_for_type_inference(dtype=tensor.dtype) - - check_variable_and_dtype( - tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], - op_type) - - helper.append_op(type=op_type, - inputs={'X': tensor}, - outputs={'Out': out}, - attrs={ - 'ring_id': ring_id, - 'use_calc_stream': use_calc_stream, - 'use_model_parallel': use_model_parallel, - }) - return out - - -def _c_lookup_table(table, index, start_index=0, name=None): - """ - Lookup table according to index. - - Args: - table (Tensor): The input Tensor. Its data type - should be float16, float32, float64. - index (Tensor): The index to lookup table. - start_index (int): The initial index for table range. - name (string): The name of the api - - Returns: - Tensor. - """ - if _non_static_mode(): - return _legacy_C_ops.c_embedding(table, index, "start_index", - start_index) - - op_type = 'c_embedding' - helper = LayerHelper(op_type, **locals()) - dtype = helper.input_dtype(input_param_name='table') - check_variable_and_dtype(index, 'input', ['int32', 'int64'], op_type) - tmp = helper.create_variable_for_type_inference(dtype) - helper.append_op(type='c_embedding', - inputs={ - 'Ids': index, - 'W': table - }, - outputs={'Out': tmp}, - attrs={"start_index": start_index}) - return tmp - - -class _Linear(layers.Layer): - """ - Linear - """ - - def __init__(self, - in_features, - out_features, - weight_attr=None, - bias_attr=None, - name=None): - super(_Linear, self).__init__() - self._dtype = self._helper.get_default_dtype() - self._weight_attr = weight_attr - self._bias_attr = bias_attr - self.weight = self.create_parameter(shape=[in_features, out_features], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False) - self.bias = self.create_parameter(shape=[out_features], - attr=self._bias_attr, - dtype=self._dtype, - is_bias=True) - self.name = name - - def forward(self, input): - out = _linear(x=input, - weight=self.weight, - bias=self.bias, - name=self.name) - return out - - def extra_repr(self): - name_str = ', name={}'.format(self.name) if self.name else '' - return 'in_features={}, out_features={}, dtype={}{}'.format( - self.weight.shape[0], self.weight.shape[1], self._dtype, name_str) - - -def _c_softmax_with_cross_entropy(logits, - label, - group=None, - return_softmax=False): - if group is not None and not group.is_member(): - return - ring_id = 0 if group is None else group.id - global_rank = _get_global_env().rank - rank = global_rank if group is None else group.get_group_rank(global_rank) - nranks = _get_global_env().world_size if group is None else group.nranks - - input_dims = len(list(logits.shape)) - label_dims = len(list(label.shape)) - if input_dims - 1 != label_dims and input_dims != label_dims: - raise ValueError( - 'Expected nput_dims - 1 = label_dims or input_dims == label_dims\ - (got nput_dims{}, label_dims{})'.format(input_dims, label_dims)) - if input_dims - 1 == label_dims: - label = paddle.unsqueeze(label, axis=-1) - - if _non_static_mode(): - softmax, loss = _legacy_C_ops.c_softmax_with_cross_entropy( - logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks) - if not return_softmax: - return loss - else: - return loss, softmax - - attrs = { - 'ring_id': ring_id, - 'rank': rank, - 'nranks': nranks, - } - helper = LayerHelper('c_softmax_with_cross_entropy', **locals()) - softmax = helper.create_variable_for_type_inference(dtype=logits.dtype) - loss = helper.create_variable_for_type_inference(dtype=logits.dtype) - helper.append_op(type='c_softmax_with_cross_entropy', - inputs={ - 'Logits': logits, - 'Label': label - }, - outputs={ - 'Softmax': softmax, - 'Loss': loss - }, - attrs=attrs) - - if return_softmax: - return loss, softmax - - return loss - - -def _linear(x, weight, bias=None, name=None): - """ - Fuction Linear - """ - if _non_static_mode(): - pre_bias = _varbase_creator(dtype=x.dtype) - _legacy_C_ops.matmul(x, weight, pre_bias, 'transpose_X', False, - 'transpose_Y', False, "alpha", 1) - return dygraph_utils._append_bias_in_dygraph(pre_bias, - bias, - axis=len(x.shape) - 1) - else: - helper = LayerHelper('linear', **locals()) - dtype = x.dtype - assert len( - x.shape) < 4, "X latitude is not supported greater than 3 now." - - check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], - 'linear') - check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear') - - inputs = {'X': [x], 'Y': [weight]} - attrs = { - 'transpose_X': False, - 'transpose_Y': False, - 'alpha': 1, - } - tmp = helper.create_variable_for_type_inference(dtype) - helper.append_op(type='matmul_v2', - inputs=inputs, - outputs={'Out': tmp}, - attrs=attrs) - if bias is not None: - res = helper.create_variable_for_type_inference(dtype) - helper.append_op(type='elementwise_add', - inputs={ - 'X': [tmp], - 'Y': [bias] - }, - outputs={'Out': [res]}, - attrs={'axis': len(x.shape) - 1}) - else: - res = tmp - return res - - -def _set_var_distributed(var): - if var is None: - return - - var.is_distributed = True - - # NOTE: use current_block and find_var_recursive to support while_loop - startup_block = paddle.static.default_startup_program().current_block() - main_block = paddle.static.default_main_program().current_block() - startup_block._find_var_recursive(var.name).is_distributed = True - main_block._find_var_recursive(var.name).is_distributed = True - - -def _parallel_linear(x, - num_rows, - num_cols, - axis, - param_attr, - bias_attr, - gather_out, - inner_rank, - nranks, - split_tensor, - name, - group=None): - """ - Parallel Linear - - axis the dimension of the parameter of linear layer. - axis = 0: the row dimension - axis = 1: the col dimension - - """ - if group is not None and not group.is_member(): - return - ring_id = 0 if group is None else group.id - - if axis == 0: - if split_tensor: - x = _c_split(x, group=group) - else: - x = _c_identity(x, group=group) - - linear = paddle.nn.Linear(num_rows, - num_cols, - weight_attr=param_attr, - bias_attr=bias_attr, - name=name) - - # NOTE: npu linear function use matmul_v2 but linear use matmul - linear_function = _linear if core.is_compiled_with_npu()\ - else paddle.nn.functional.linear - linear_out = linear_function( - x, - linear.weight, - # NOTE(wangxi): row split, bias need add after allreduce - None if axis == 0 else linear.bias, - linear.name) - - _set_var_distributed(linear.weight) - # set is_distributed for splited bias - # if a linear layer is splited by row, each rank would hold a complete bias and they should be the same in each rank. - # if a linear layer is splited by col, the bias would also be split into each rank as its weight - if axis == 1 and linear._bias_attr != False: - _set_var_distributed(linear.bias) - - if not gather_out: return linear_out - - out_shape = list(linear_out.shape) - out_shape[0] *= 1 if axis == 0 else nranks - main_block = paddle.static.default_main_program().current_block() - out = main_block.create_var( - shape=out_shape, - dtype=linear_out.dtype, - type=linear_out.type, - lod_level=linear_out.lod_level, - persistable=False, - is_data=False, - need_check_feed=linear_out.desc.need_check_feed()) - if axis == 0: - main_block.append_op(type='c_allreduce_sum', - inputs={'X': linear_out}, - outputs={'Out': out}, - attrs={ - 'ring_id': ring_id, - 'use_calc_stream': True, - 'use_model_parallel': True - }) - if linear.bias is not None: - out = out + linear.bias - else: - main_block.append_op(type='c_concat', - inputs={'X': linear_out}, - outputs={'Out': out}, - attrs={ - 'rank': inner_rank, - 'ring_id': ring_id, - 'nranks': nranks, - 'use_calc_stream': True, - 'use_model_parallel': True - }) - return out - - -def _parallel_embedding(x, - per_part_embeddings, - origin_size, - param_attr, - inner_rank, - num_partitions, - name, - group=None): - """ - Parallel Embedding - """ - if group is not None and not group.is_member(): - return - ring_id = 0 if group is None else group.id - - helper = LayerHelper("_parallel_embedding", **locals()) - - per_part_size = per_part_embeddings - rank = inner_rank - - vocab_start_index = rank * per_part_size - dtype = helper.get_default_dtype() - size = [per_part_size, origin_size[1]] - - weight = helper.create_parameter(attr=param_attr, - shape=size, - dtype=dtype, - is_bias=False) - - if num_partitions == 1: - return paddle.nn.functional.embedding(x, - weight=weight, - padding_idx=None, - sparse=False, - name=name) - - startup_block = paddle.static.default_startup_program().global_block() - main_block = paddle.static.default_main_program().global_block() - startup_block.vars[weight.name].is_distributed = True - main_block.vars[weight.name].is_distributed = True - - output_parallel = paddle.distributed.collective._c_lookup_table( - weight, x, start_index=vocab_start_index, name=name) - out = paddle.distributed.collective._mp_allreduce(output_parallel, - group=group, - use_calc_stream=True, - use_model_parallel=True) - return out - - -def split(x, - size, - operation, - axis=0, - num_partitions=1, - gather_out=True, - weight_attr=None, - bias_attr=None, - name=None): - """ - - Split the weight of the specified operation into multiple devices - and do the computation in parallel. - - Now the following three cases are supported. - - Case 1: Parallel Embedding - The weight of the embedding operation is a NxM matrix with N rows and M columns. - With parallel embedding, the weight is split into num_partitions partitions, each - of which is a matrix with (N/num_partitions + 1) rows and M column where the last - row as the padding idx. - - Suppose we split the NxM weight into two partitons on device_0 and device_1 - respectively. Then, one each device, the final weight has (N/2 + 1) rows with the - index range from 0 to N/2. On device_0, all values in the input within [0, N/2 -1] - keep unchanged and all other values are changed to N/2 which is the padding index and - are mapped to all zeros after embedding. In the same way, on device_1, the value V in the - input within [N/2, N-1] will be changed to (V - N/2), and all other values are changed - to N/2 and are mapped to all zeros after embedding. Finally, the results on the two - devices are sum-reduced. - - The Embedding put on single card is as shown below: - - .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_single.png - :width: 800 - :height: 350 - :alt: single_embedding - :align: center - - Parallel Embedding is shown as below: - - .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_split.png - :width: 800 - :alt: split_embedding - :align: center - - Case 2: Row Parallel Linear - The weight of the linear operation is a NxM matrix with N rows and M columns. - With row parallel linear, the weight is split into num_partitions partitions, each - of which is a matrix with N/num_partitions rows and M column. - - The linear layer put on single card is shown as below, the input variable is represented by X, - the weight matrix is represented by W and the output vaiable is O. The linear layer on single card is - simple matrix multiplication operation, O = X * W. - - .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_single.png - :width: 800 - :alt: single_linear - :align: center - - Row Parallel Linear is shown as below. As the name suggests, Row Parallel Linear splits the weight matrix W into - [[W_row1], [W_row2]] along the row. And accordingly the input is splitted along the column into [X_col1, X_col2] and multiply their - respective weight matrices. Finally apply AllReduce on the output from each card to get the final output. - - .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_row.png - :width: 800 - :alt: split_row - :align: center - - Case 3: Column Parallel Linear - The weight of the linear operation is a NxM matrix with N rows and M columns. - With column parallel linear, the weight is split into num_paratitions partitions, each - of which is a matrix with N rows and M/num_partitions column. - - The linear layer put on single card has been illustrated on case 2 and Column Parallel Linear - is shown as below. The Column Parallel Linear splits the weight matrix W into [W_col1, W_col2] along the column and - these splitted matrices respectively multiply the input. Finally apply AllGather on the output from each card to get the final output. - - .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col.png - :width: 800 - :alt: split_col - :align: center - - As observed, the column parallel linear and row parallel linear can be combined to skip one ALLGATHER communication - operator. Furthermore the Attention and MLP can be combined to imporve the performance as shown below. - - .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col_row.png - :width: 800 - :alt: split_col_row - :align: center - - Args: - x (Tensor): Input tensor. It's data type should be float16, float32, float64, int32 or int64. - size (list|tuple): A list or tuple with two elements indicating the shape of the weight. - operation (str): The name of the operation. The supported operations are 'linear' and 'embedding'. - axis (int, Optional): Indicate along which axis to split the weight. Default: 0. - num_partitions (int, Optional): How many parts the weight is partitioned. Default: 1. - gather_out (bool, Optional): Whether to gather the output after computation. By default, the output - on each partitions will be gathered after computation. Default: True. - weight_attr (ParamAttr, Optional): The parameter attribute for the learnable - weights(Parameter) of the specified operation. Default: None. - bias_attr (ParamAttr, Optional): The parameter attribute for the bias - of the specified operation. Default: None. - name (str, Optional): The default value is None. Normally there is no need for user to set this - property. Default: None. For more information, please refer to :ref:`api_guide_Name`. - - Returns: - Tensor. - - Examples: - .. code-block:: python - - # required: distributed - import paddle - import paddle.distributed.fleet as fleet - - paddle.enable_static() - paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id) - fleet.init(is_collective=True) - data = paddle.randint(0, 8, shape=[10,4]) - emb_out = paddle.distributed.split( - data, - (8, 8), - operation="embedding", - num_partitions=2) - - """ - assert isinstance( - size, - (list, tuple)), ("The type of size for " - "paddle.distributed.split must be list or tuple.") - assert len(size) == 2, ("Number of elements in size of " - "paddle.distributed.split must be two.") - assert isinstance(operation, str), ("The type of operation for " - "paddle.distributed.split must be str.") - supported_operations = [ - 'linear', - 'embedding', - ] - assert operation in supported_operations, ( - "The operation for " - "paddle.distributed.split must be one of {}.".format( - supported_operations)) - if _non_static_mode(): - raise ValueError( - "paddle.distributed.split cannot be used in dynamic " - "graph mode, plese use ParallelEmbedding, ParallelRowLinear, " - "ParallelColumnLinear instead.") - else: - from .fleet import fleet - assert fleet._role_maker, ("To use paddle.distributed.split, " - "you must call fleet.init() firstly.") - rank = fleet.worker_index() - nranks = fleet.worker_num() - - # rank within a model parallel group - inner_rank = rank % num_partitions - - if operation == "embedding": - assert axis == 0, ("We only support to split the weight of embedding " - "along the first axis now.") - assert size[0] % num_partitions == 0, \ - "The length of the vocabulary must be divisible by num_partitions " \ - "but received vocabulary={} num_partitions={}".format(size[0], num_partitions) - - per_part_size = size[0] // num_partitions - emb_out = _parallel_embedding(x, - per_part_size, - size, - weight_attr, - inner_rank, - num_partitions, - name, - group=None) - return emb_out - else: - should_split = False - if axis == 0: - assert size[0] % num_partitions == 0, ( - "Number of rows of the weight for linear ({}) must be" - " divisible by num_partitions ({})".format( - size[0], num_partitions)) - per_part_size = size[0] // num_partitions - linear_size = (per_part_size, size[1]) - if x.shape[-1] == size[0]: should_split = True - - elif axis == 1: - assert size[1] % num_partitions == 0, ( - "Number of column of the weight for linear ({}) must be" - " divisible by num_partitions ({})".format( - size[1], num_partitions)) - per_part_size = size[1] // num_partitions - linear_size = (size[0], per_part_size) - else: - raise ValueError("The value of axis must be 0 or 1, but the value " - "given is {}.".format(axis)) - - linear_out = _parallel_linear(x, - linear_size[0], - linear_size[1], - axis, - weight_attr, - bias_attr, - gather_out, - inner_rank, - num_partitions, - should_split, - name=name, - group=None) - return linear_out - - def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True): """ Scatter tensors in in_tensor_list to all participators averagely and gather the result tensors in out_tensor_list. diff --git a/python/paddle/distributed/communication/comm_utils.py b/python/paddle/distributed/communication/comm_utils.py new file mode 100644 index 00000000000..62e1bcb4cca --- /dev/null +++ b/python/paddle/distributed/communication/comm_utils.py @@ -0,0 +1,50 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class ReduceOp: + """ + + Specify the type of operation used for element-wise reductions. + It should be one of the following values: + + ReduceOp.SUM + + ReduceOp.MAX + + ReduceOp.MIN + + ReduceOp.PROD + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + if dist.get_rank() == 0: + data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]]) + else: + data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) + dist.all_reduce(data, op=dist.ReduceOp.SUM) + print(data) + # [[5, 7, 9], [5, 7, 9]] (2 GPUs) + """ + SUM = 0 + MAX = 1 + MIN = 2 + PROD = 3 + AVG = 4 diff --git a/python/paddle/distributed/fleet/layers/mpu/__init__.py b/python/paddle/distributed/fleet/layers/mpu/__init__.py new file mode 100644 index 00000000000..11b69702650 --- /dev/null +++ b/python/paddle/distributed/fleet/layers/mpu/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .mp_layers import VocabParallelEmbedding +from .mp_layers import ColumnParallelLinear +from .mp_layers import RowParallelLinear +from .mp_layers import ParallelCrossEntropy + +from .random import RNGStatesTracker +from .random import get_rng_state_tracker +from .random import model_parallel_random_seed +from .random import determinate_seed +from .random import dropout diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py new file mode 100644 index 00000000000..2ba9ce9ed76 --- /dev/null +++ b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py @@ -0,0 +1,466 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from . import mp_ops +from paddle.fluid import core +from paddle.fluid.dygraph.layers import Layer +from .random import get_rng_state_tracker +from paddle.nn import functional as F +from paddle import framework +from paddle.autograd import PyLayer +from ...base import topology as tp + +__all__ = [] + +# Follow this paper to achieve the file: +# Shoeybi M, Patwary M, Puri R, et al. Megatron-lm: Training multi-billion parameter +# language models using model parallelism[J]. arXiv preprint arXiv:1909.08053, 2019. (https://arxiv.org/abs/1909.08053) + + +def is_fused_matmul_bias_supported(): + if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm(): + return hasattr(core.ops, 'fused_gemm_epilogue') + else: + return False + + +class VocabParallelEmbedding(Layer): + """Embedding mp parallelized in the vocabulary dimension. + this class is used for splitting embedding in mp group. + + Args: + num_embeddings(int): One element which indicate the size of the dictionary of embeddings. + embedding_dim(int): One element which indicate the size of each embedding vector respectively. + weight_attr(ParamAttr|None): To specify the weight parameter property. Default: None, which means the + default weight parameter property is used. See usage for details in :ref:`api_ParamAttr` . In addition, + user-defined or pre-trained word vectors can be loaded with the :attr:`param_attr` parameter. + The local word vector needs to be transformed into numpy format, and the shape of local word + vector should be consistent with :attr:`num_embeddings` . Then :ref:`api_initializer_NumpyArrayInitializer` + is used to load custom or pre-trained word vectors. See code example for details. + mp_group(Group): The tensor parallel group. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Examples: + .. code-block:: python + import paddle + from paddle.distributed import fleet + + class SimpleMPNet(paddle.nn.Layer): + def __init__(self, vocab_size, hidden_size, inner_size, output_size): + super(SimpleMPNet, self).__init__() + self.linear1 = fleet.meta_parallel.ColumnParallelLinear( + hidden_size, + inner_size, + gather_output=False, + has_bias=True) + + self.linear2 = fleet.meta_parallel.RowParallelLinear( + inner_size, + hidden_size, + input_is_parallel=True, + has_bias=True) + + self.linear3 = paddle.nn.Linear(hidden_size, output_size) + + self.embedding = fleet.meta_parallel.VocabParallelEmbedding( + vocab_size, + hidden_size) + + def forward(self, x): + x = self.embedding(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + return x + """ + + def __init__(self, + num_embeddings, + embedding_dim, + weight_attr=None, + mp_group=None, + name=None): + super(VocabParallelEmbedding, self).__init__() + + self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group( + ) if mp_group is None else mp_group + self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size( + ) if mp_group is None else mp_group.nranks + self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank( + ) if mp_group is None else mp_group.rank + + self.origin_num_embeddings = num_embeddings + self.is_mp = (self.world_size > 1) + + assert num_embeddings % self.world_size == 0, ( + "The length of the vocabulary must be divisible by the parallelism degree of MP" + ) + + per_part_size = num_embeddings // self.world_size + + self.vocab_start_index = self.rank * per_part_size + self._dtype = self._helper.get_default_dtype() + self._size = [per_part_size, embedding_dim] + self._weight_attr = weight_attr + self._name = name + + if self.is_mp and paddle.in_dynamic_mode(): + with get_rng_state_tracker().rng_state(): + self.weight = self.create_parameter(attr=self._weight_attr, + shape=self._size, + dtype=self._dtype, + is_bias=False) + else: + self.weight = self.create_parameter(attr=self._weight_attr, + shape=self._size, + dtype=self._dtype, + is_bias=False) + + self.weight.is_distributed = True if self.is_mp else False + + def forward(self, x): + if self.is_mp: + output_parallel = mp_ops._c_lookup_table( + self.weight, + x, + start_index=self.vocab_start_index, + name=self._name) + output = mp_ops._mp_allreduce(output_parallel, + group=self.model_parallel_group, + use_calc_stream=True, + use_model_parallel=True) + else: + output = F.embedding(x, + weight=self.weight, + padding_idx=None, + sparse=False, + name=self._name) + return output + + +class ColumnParallelLinear(Layer): + """Linear layer with mp parallelized(column). + this class is used for splitting Linear Layer in mp group, column split the weight of the Linear layer. + + Args: + in_features(int): The number of input units. + out_features(int): The number of output units. + weight_attr(ParamAttr|None): The attribute for the learnable weight of this layer. The default value is None + and the weight will be initialized to zero. For detailed information, please refer to paddle.ParamAttr. + has_bias(bool): whether to add bias. + gather_output(bool): whether to do allgahter for the output of each rank. + fuse_matmul_bias(bool): whether to fuse matmul and bias. + mp_group(Group): The tensor parallel group. + name(str, optional): Normally there is no need for user to set this parameter. + For detailed information, please refer to :ref:`api_guide_Name` . + + Examples: + .. code-block:: python + import paddle + from paddle.distributed import fleet + + class SimpleMPNet(paddle.nn.Layer): + def __init__(self, vocab_size, hidden_size, inner_size, output_size): + super(SimpleMPNet, self).__init__() + self.linear1 = fleet.meta_parallel.ColumnParallelLinear( + hidden_size, + inner_size, + gather_output=False, + has_bias=True) + + self.linear2 = fleet.meta_parallel.RowParallelLinear( + inner_size, + hidden_size, + input_is_parallel=True, + has_bias=True) + + self.linear3 = paddle.nn.Linear(hidden_size, output_size) + + self.embedding = fleet.meta_parallel.VocabParallelEmbedding( + vocab_size, + hidden_size) + + def forward(self, x): + x = self.embedding(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + return x + """ + + def __init__(self, + in_features, + out_features, + weight_attr=None, + has_bias=None, + gather_output=True, + fuse_matmul_bias=False, + mp_group=None, + name=None): + super(ColumnParallelLinear, self).__init__() + + self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group( + ) if mp_group is None else mp_group + self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size( + ) if mp_group is None else mp_group.nranks + self._name = name + self.is_mp = (self.world_size > 1) + + self.gather_output = gather_output + assert out_features % self.world_size == 0, ( + "Number of column of the weight for linear ({}) must be" + " divisible by model parallel size ({})".format( + out_features, self.world_size)) + self.output_size_per_partition = out_features // self.world_size + + self._weight_attr = weight_attr + self._dtype = self._helper.get_default_dtype() + + if self.is_mp and paddle.in_dynamic_mode(): + with get_rng_state_tracker().rng_state(): + self.weight = self.create_parameter( + shape=[in_features, self.output_size_per_partition], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False) + else: + self.weight = self.create_parameter( + shape=[in_features, self.output_size_per_partition], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False) + + self.weight.is_distributed = True if self.is_mp else False + + if has_bias: + # initialize bias to zero like Megatron + self.bias = self.create_parameter( + shape=[self.output_size_per_partition], + attr=paddle.nn.initializer.Constant(value=0.0), + dtype=self._dtype, + is_bias=True) + self.bias.is_distributed = True if self.is_mp else False + else: + self.bias = None + + self.linear = F.linear + + if fuse_matmul_bias: + if not is_fused_matmul_bias_supported(): + raise NotImplementedError( + "You set fuse_matmul_bias=True in ColumnParallelLinear, " + "however, the paddle you are using not support this operation. " + "Please set fuse_matmul_bias=False or use paddle compiled " + "with cuda 11.6 or higher.") + from paddle.incubate.nn.functional import fused_linear + self.linear = fused_linear + + def forward(self, x): + # use inner api to process identity + if self.is_mp: + input_parallel = mp_ops._c_identity(x, + group=self.model_parallel_group) + else: + input_parallel = x + + output_parallel = self.linear(input_parallel, + self.weight, + self.bias, + name=self._name) + + if self.gather_output and self.is_mp: + output = mp_ops._c_concat(output_parallel, + group=self.model_parallel_group) + else: + output = output_parallel + return output + + +class RowParallelLinear(Layer): + """Linear layer with mp parallelized(row). + this class is used for splitting Linear Layer in mp group, row split the weight of the Linear layer. + + Args: + in_features(int): The number of input units. + out_features(int): The number of output units. + weight_attr(ParamAttr|None): The attribute for the learnable weight of this layer. The default value is None + and the weight will be initialized to zero. For detailed information, please refer to paddle.ParamAttr. + has_bias(bool): whether to add bias. + input_is_parallel(bool): whether the input has alreadly been splitted across the mp group. + fuse_matmul_bias(bool): whether to fuse matmul and bias. + mp_group(Group): The tensor parallel group. + name(str, optional): Normally there is no need for user to set this parameter. + For detailed information, please refer to :ref:`api_guide_Name` . + + Examples: + .. code-block:: python + import paddle + from paddle.distributed import fleet + + class SimpleMPNet(paddle.nn.Layer): + def __init__(self, vocab_size, hidden_size, inner_size, output_size): + super(SimpleMPNet, self).__init__() + self.linear1 = fleet.meta_parallel.ColumnParallelLinear( + hidden_size, + inner_size, + gather_output=False, + has_bias=True) + + self.linear2 = fleet.meta_parallel.RowParallelLinear( + inner_size, + hidden_size, + input_is_parallel=True, + has_bias=True) + + self.linear3 = paddle.nn.Linear(hidden_size, output_size) + + self.embedding = fleet.meta_parallel.VocabParallelEmbedding( + vocab_size, + hidden_size) + + def forward(self, x): + x = self.embedding(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + return x + """ + + def __init__(self, + in_features, + out_features, + weight_attr=None, + has_bias=True, + input_is_parallel=False, + fuse_matmul_bias=False, + mp_group=None, + name=None): + super(RowParallelLinear, self).__init__() + + self.in_features = in_features + self.out_features = out_features + self.input_is_parallel = input_is_parallel + self._weight_attr = weight_attr + self._dtype = self._helper.get_default_dtype() + self._name = name + + self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group( + ) if mp_group is None else mp_group + self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size( + ) if mp_group is None else mp_group.nranks + self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank( + ) if mp_group is None else mp_group.rank + + self.is_mp = (self.world_size > 1) + assert in_features % self.world_size == 0, ( + "Number of row of the weight for linear ({}) must be" + " divisible by model parallel size ({})".format( + in_features, self.world_size)) + + self.input_size_per_partition = in_features // self.world_size + + if self.is_mp and paddle.in_dynamic_mode(): + with get_rng_state_tracker().rng_state(): + self.weight = self.create_parameter( + shape=[self.input_size_per_partition, self.out_features], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False) + else: + self.weight = self.create_parameter( + shape=[self.input_size_per_partition, self.out_features], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False) + + self.weight.is_distributed = True if self.is_mp else False + + if has_bias: + self.bias = self.create_parameter( + shape=[self.out_features], + attr=paddle.nn.initializer.Constant(value=0.0), + dtype=self._dtype, + is_bias=True) + else: + self.bias = None + + self.linear = F.linear + + if fuse_matmul_bias: + if not is_fused_matmul_bias_supported(): + raise NotImplementedError( + "You set fuse_matmul_bias=True in RowParallelLinear, " + "however, the paddle you are using not support this operation. " + "Please set fuse_matmul_bias=False or use paddle compiled " + "with cuda 11.6 or higher.") + from paddle.incubate.nn.functional import fused_linear + self.linear = fused_linear + + def forward(self, x): + if self.input_is_parallel or (not self.is_mp): + input_parallel = x + else: + # split last dim + input_parallel = mp_ops._c_split(x, group=self.model_parallel_group) + + if self.is_mp: + output_parallel = self.linear(input_parallel, + self.weight, + name=self._name) + output_ = mp_ops._mp_allreduce(output_parallel, + group=self.model_parallel_group, + use_calc_stream=True, + use_model_parallel=True) + output = output_ + self.bias if self.bias is not None else output_ + else: + output = self.linear(input_parallel, + self.weight, + self.bias, + name=self._name) + + return output + + +class ParallelCrossEntropy(Layer): + """CrossEntropy with mp parallelized. + this class is used for splitting softmax cross entropy in mp group. + + Args: + mp_group(Group): The tensor parallel group. + name(str, optional): Normally there is no need for user to set this parameter. + For detailed information, please refer to :ref:`api_guide_Name` . + + Examples: + .. code-block:: python + loss_func = ParallelCrossEntropy() + loss = loss_func(img, lable) + """ + + def __init__(self, mp_group=None, name=None): + super(ParallelCrossEntropy, self).__init__() + self.name = name + self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group( + ) if mp_group is None else mp_group + self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size( + ) if mp_group is None else mp_group.nranks + self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank( + ) if mp_group is None else mp_group.rank + + def forward(self, input, label): + loss = mp_ops._c_softmax_with_cross_entropy( + input, label, group=self.model_parallel_group) + return loss diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py new file mode 100644 index 00000000000..dc4dc05c7ba --- /dev/null +++ b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py @@ -0,0 +1,772 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle import _C_ops, _legacy_C_ops +from paddle.fluid import core +from paddle.fluid.framework import _non_static_mode +from paddle.fluid.framework import _in_legacy_dygraph +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.data_feeder import check_variable_and_dtype +from paddle.fluid.dygraph import layers +from paddle.distributed import collective +from ....communication.comm_utils import ReduceOp +from paddle.fluid.data_feeder import check_dtype +import paddle.fluid.dygraph_utils as dygraph_utils + + +def _c_identity(tensor, group=None): + """ + Return a copy of the tensor, mainly used with model parallel. + + Args: + tensor (Tensor): The input Tensor. Its data type + should be float16, float32, float64, int32 or int64. + group (int): The id of the process group to work on. + + Returns: + Tensor. + """ + if group is not None and not group.is_member(): + return + ring_id = 0 if group is None else group.id + + if _non_static_mode(): + return _legacy_C_ops.c_identity(tensor, 'use_calc_stream', True, + 'ring_id', ring_id, + 'use_model_parallel', True) + op_type = 'c_identity' + helper = LayerHelper(op_type, **locals()) + out = helper.create_variable_for_type_inference(dtype=tensor.dtype) + + check_variable_and_dtype( + tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], + '_c_identity') + + helper.append_op(type=op_type, + inputs={'X': tensor}, + outputs={'Out': out}, + attrs={ + 'ring_id': ring_id, + 'use_calc_stream': True, + 'use_model_parallel': True, + }) + return out + + +def _c_concat(tensor, group=None): + """ + Return allgather of the tensor, mainly used with model parallel. + + Args: + tensor (Tensor): The input Tensor. Its data type + should be float16, float32, float64, int32 or int64. + group (int): The id of the process group to work on. + + Returns: + Tensor. + """ + if group is not None and not group.is_member(): + return + group = collective._get_default_group() if group is None else group + ring_id = group.id + + global_rank = collective._get_global_env().rank + rank = group.rank + nranks = group.nranks + + if _non_static_mode(): + return _legacy_C_ops.c_concat(tensor, 'ring_id', ring_id, + 'use_calc_stream', True, 'rank', rank, + 'nranks', nranks, 'use_model_parallel', + True) + + op_type = 'c_concat' + helper = LayerHelper(op_type, **locals()) + out = helper.create_variable_for_type_inference(dtype=tensor.dtype) + + check_variable_and_dtype( + tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], + '_c_concat') + + helper.append_op(type=op_type, + inputs={'X': tensor}, + outputs={'Out': out}, + attrs={ + 'ring_id': ring_id, + 'use_calc_stream': True, + 'use_model_parallel': True, + 'nranks': nranks, + 'rank': rank + }) + return out + + +def _c_split(tensor, group=None): + """ + Split tensor evenly among all members, mainly used with model parallel. + + Args: + tensor (Tensor): The input Tensor. Its data type + should be float16, float32, float64, int32 or int64. + rank (int): The rank of the current process. + group (int): The id of the process group to work on. + + Returns: + Tensor. + """ + if group is not None and not group.is_member(): + return + ring_id = 0 if group is None else group.id + + global_rank = collective._get_global_env().rank + rank = global_rank if group is None else group.get_group_rank(global_rank) + nranks = collective._get_global_env( + ).world_size if group is None else group.nranks + + if _non_static_mode(): + return _legacy_C_ops.c_split(tensor, 'use_calc_stream', True, 'ring_id', + ring_id, 'rank', rank, 'nranks', nranks, + 'use_model_parallel', True) + + op_type = 'c_split' + helper = LayerHelper(op_type, **locals()) + out = helper.create_variable_for_type_inference(dtype=tensor.dtype) + + check_variable_and_dtype( + tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], + '_c_split') + + helper.append_op(type=op_type, + inputs={'X': tensor}, + outputs={'Out': out}, + attrs={ + 'ring_id': ring_id, + 'use_calc_stream': True, + 'rank': rank, + 'nranks': nranks, + 'use_model_parallel': True, + }) + return out + + +def _mp_allreduce(tensor, + op=ReduceOp.SUM, + group=None, + use_calc_stream=True, + use_model_parallel=True): + """[it is same as allreduce above, but it supports model parallel. And it support inplace startegy] + """ + if group is not None and not group.is_member(): + return + + if in_dygraph_mode(): + group = collective._get_default_group() if group is None else group + assert op == ReduceOp.SUM, "Unknown parameter: {}.".format(op) + + from paddle.autograd import PyLayer + + class mp_allreduce_eager(PyLayer): + + @staticmethod + def forward(ctx, tensor, group, use_calc_stream, + use_model_parallel): + ctx.ring_id = group.id + + if use_calc_stream: + op_type = collective._get_reduce_op(op, "_mp_allreduce") + group.process_group.allreduce_on_calc_stream( + tensor, op_type) + return tensor + else: + return _legacy_C_ops.c_allreduce_sum_( + tensor, 'use_calc_stream', use_calc_stream, 'ring_id', + ring_id, "use_model_parallel", use_model_parallel) + + @staticmethod + def backward(ctx, dy): + return _legacy_C_ops.c_identity(dy, 'use_calc_stream', True, + 'ring_id', ctx.ring_id, + 'use_model_parallel', True) + + return mp_allreduce_eager.apply(tensor, group, use_calc_stream, + use_model_parallel) + + ring_id = 0 if group is None else group.id + if _in_legacy_dygraph(): + if op == ReduceOp.SUM: + return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream', + use_calc_stream, 'ring_id', + ring_id, "use_model_parallel", + use_model_parallel) + else: + raise ValueError("Unknown parameter: {}.".format(op)) + + op_type = 'c_allreduce_sum' + helper = LayerHelper(op_type, **locals()) + out = helper.create_variable_for_type_inference(dtype=tensor.dtype) + + check_variable_and_dtype( + tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], + op_type) + + helper.append_op(type=op_type, + inputs={'X': tensor}, + outputs={'Out': out}, + attrs={ + 'ring_id': ring_id, + 'use_calc_stream': use_calc_stream, + 'use_model_parallel': use_model_parallel, + }) + return out + + +def _c_lookup_table(table, index, start_index=0, name=None): + """ + Lookup table according to index. + + Args: + table (Tensor): The input Tensor. Its data type + should be float16, float32, float64. + index (Tensor): The index to lookup table. + start_index (int): The initial index for table range. + name (string): The name of the api + + Returns: + Tensor. + """ + if _non_static_mode(): + return _legacy_C_ops.c_embedding(table, index, "start_index", + start_index) + + op_type = 'c_embedding' + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='table') + check_variable_and_dtype(index, 'input', ['int32', 'int64'], op_type) + tmp = helper.create_variable_for_type_inference(dtype) + helper.append_op(type='c_embedding', + inputs={ + 'Ids': index, + 'W': table + }, + outputs={'Out': tmp}, + attrs={"start_index": start_index}) + return tmp + + +class _Linear(layers.Layer): + """ + Linear + """ + + def __init__(self, + in_features, + out_features, + weight_attr=None, + bias_attr=None, + name=None): + super(_Linear, self).__init__() + self._dtype = self._helper.get_default_dtype() + self._weight_attr = weight_attr + self._bias_attr = bias_attr + self.weight = self.create_parameter(shape=[in_features, out_features], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False) + self.bias = self.create_parameter(shape=[out_features], + attr=self._bias_attr, + dtype=self._dtype, + is_bias=True) + self.name = name + + def forward(self, input): + out = _linear(x=input, + weight=self.weight, + bias=self.bias, + name=self.name) + return out + + def extra_repr(self): + name_str = ', name={}'.format(self.name) if self.name else '' + return 'in_features={}, out_features={}, dtype={}{}'.format( + self.weight.shape[0], self.weight.shape[1], self._dtype, name_str) + + +def _c_softmax_with_cross_entropy(logits, + label, + group=None, + return_softmax=False): + if group is not None and not group.is_member(): + return + ring_id = 0 if group is None else group.id + global_rank = collective._get_global_env().rank + rank = global_rank if group is None else group.get_group_rank(global_rank) + nranks = collective._get_global_env( + ).world_size if group is None else group.nranks + + input_dims = len(list(logits.shape)) + label_dims = len(list(label.shape)) + if input_dims - 1 != label_dims and input_dims != label_dims: + raise ValueError( + 'Expected nput_dims - 1 = label_dims or input_dims == label_dims\ + (got nput_dims{}, label_dims{})'.format(input_dims, label_dims)) + if input_dims - 1 == label_dims: + label = paddle.unsqueeze(label, axis=-1) + + if _non_static_mode(): + softmax, loss = _legacy_C_ops.c_softmax_with_cross_entropy( + logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks) + if not return_softmax: + return loss + else: + return loss, softmax + + attrs = { + 'ring_id': ring_id, + 'rank': rank, + 'nranks': nranks, + } + helper = LayerHelper('c_softmax_with_cross_entropy', **locals()) + softmax = helper.create_variable_for_type_inference(dtype=logits.dtype) + loss = helper.create_variable_for_type_inference(dtype=logits.dtype) + helper.append_op(type='c_softmax_with_cross_entropy', + inputs={ + 'Logits': logits, + 'Label': label + }, + outputs={ + 'Softmax': softmax, + 'Loss': loss + }, + attrs=attrs) + + if return_softmax: + return loss, softmax + + return loss + + +def _linear(x, weight, bias=None, name=None): + """ + Fuction Linear + """ + if _non_static_mode(): + pre_bias = _varbase_creator(dtype=x.dtype) + _legacy_C_ops.matmul(x, weight, pre_bias, 'transpose_X', False, + 'transpose_Y', False, "alpha", 1) + return dygraph_utils._append_bias_in_dygraph(pre_bias, + bias, + axis=len(x.shape) - 1) + else: + helper = LayerHelper('linear', **locals()) + dtype = x.dtype + assert len( + x.shape) < 4, "X latitude is not supported greater than 3 now." + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'linear') + check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear') + + inputs = {'X': [x], 'Y': [weight]} + attrs = { + 'transpose_X': False, + 'transpose_Y': False, + 'alpha': 1, + } + tmp = helper.create_variable_for_type_inference(dtype) + helper.append_op(type='matmul_v2', + inputs=inputs, + outputs={'Out': tmp}, + attrs=attrs) + if bias is not None: + res = helper.create_variable_for_type_inference(dtype) + helper.append_op(type='elementwise_add', + inputs={ + 'X': [tmp], + 'Y': [bias] + }, + outputs={'Out': [res]}, + attrs={'axis': len(x.shape) - 1}) + else: + res = tmp + return res + + +def _set_var_distributed(var): + if var is None: + return + + var.is_distributed = True + + # NOTE: use current_block and find_var_recursive to support while_loop + startup_block = paddle.static.default_startup_program().current_block() + main_block = paddle.static.default_main_program().current_block() + startup_block._find_var_recursive(var.name).is_distributed = True + main_block._find_var_recursive(var.name).is_distributed = True + + +def _parallel_linear(x, + num_rows, + num_cols, + axis, + param_attr, + bias_attr, + gather_out, + inner_rank, + nranks, + split_tensor, + name, + group=None): + """ + Parallel Linear + + axis the dimension of the parameter of linear layer. + axis = 0: the row dimension + axis = 1: the col dimension + + """ + if group is not None and not group.is_member(): + return + ring_id = 0 if group is None else group.id + + if axis == 0: + if split_tensor: + x = _c_split(x, group=group) + else: + x = _c_identity(x, group=group) + + linear = paddle.nn.Linear(num_rows, + num_cols, + weight_attr=param_attr, + bias_attr=bias_attr, + name=name) + + # NOTE: npu linear function use matmul_v2 but linear use matmul + linear_function = _linear if core.is_compiled_with_npu()\ + else paddle.nn.functional.linear + linear_out = linear_function( + x, + linear.weight, + # NOTE(wangxi): row split, bias need add after allreduce + None if axis == 0 else linear.bias, + linear.name) + + _set_var_distributed(linear.weight) + # set is_distributed for splited bias + # if a linear layer is splited by row, each rank would hold a complete bias and they should be the same in each rank. + # if a linear layer is splited by col, the bias would also be split into each rank as its weight + if axis == 1 and linear._bias_attr != False: + _set_var_distributed(linear.bias) + + if not gather_out: return linear_out + + out_shape = list(linear_out.shape) + out_shape[0] *= 1 if axis == 0 else nranks + main_block = paddle.static.default_main_program().current_block() + out = main_block.create_var( + shape=out_shape, + dtype=linear_out.dtype, + type=linear_out.type, + lod_level=linear_out.lod_level, + persistable=False, + is_data=False, + need_check_feed=linear_out.desc.need_check_feed()) + if axis == 0: + main_block.append_op(type='c_allreduce_sum', + inputs={'X': linear_out}, + outputs={'Out': out}, + attrs={ + 'ring_id': ring_id, + 'use_calc_stream': True, + 'use_model_parallel': True + }) + if linear.bias is not None: + out = out + linear.bias + else: + main_block.append_op(type='c_concat', + inputs={'X': linear_out}, + outputs={'Out': out}, + attrs={ + 'rank': inner_rank, + 'ring_id': ring_id, + 'nranks': nranks, + 'use_calc_stream': True, + 'use_model_parallel': True + }) + return out + + +def _parallel_embedding(x, + per_part_embeddings, + origin_size, + param_attr, + inner_rank, + num_partitions, + name, + group=None): + """ + Parallel Embedding + """ + if group is not None and not group.is_member(): + return + ring_id = 0 if group is None else group.id + + helper = LayerHelper("_parallel_embedding", **locals()) + + per_part_size = per_part_embeddings + rank = inner_rank + + vocab_start_index = rank * per_part_size + dtype = helper.get_default_dtype() + size = [per_part_size, origin_size[1]] + + weight = helper.create_parameter(attr=param_attr, + shape=size, + dtype=dtype, + is_bias=False) + + if num_partitions == 1: + return paddle.nn.functional.embedding(x, + weight=weight, + padding_idx=None, + sparse=False, + name=name) + + startup_block = paddle.static.default_startup_program().global_block() + main_block = paddle.static.default_main_program().global_block() + startup_block.vars[weight.name].is_distributed = True + main_block.vars[weight.name].is_distributed = True + + output_parallel = _c_lookup_table(weight, + x, + start_index=vocab_start_index, + name=name) + out = _mp_allreduce(output_parallel, + group=group, + use_calc_stream=True, + use_model_parallel=True) + return out + + +def split(x, + size, + operation, + axis=0, + num_partitions=1, + gather_out=True, + weight_attr=None, + bias_attr=None, + name=None): + """ + + Split the weight of the specified operation into multiple devices + and do the computation in parallel. + + Now the following three cases are supported. + + Case 1: Parallel Embedding + The weight of the embedding operation is a NxM matrix with N rows and M columns. + With parallel embedding, the weight is split into num_partitions partitions, each + of which is a matrix with (N/num_partitions + 1) rows and M column where the last + row as the padding idx. + + Suppose we split the NxM weight into two partitons on device_0 and device_1 + respectively. Then, one each device, the final weight has (N/2 + 1) rows with the + index range from 0 to N/2. On device_0, all values in the input within [0, N/2 -1] + keep unchanged and all other values are changed to N/2 which is the padding index and + are mapped to all zeros after embedding. In the same way, on device_1, the value V in the + input within [N/2, N-1] will be changed to (V - N/2), and all other values are changed + to N/2 and are mapped to all zeros after embedding. Finally, the results on the two + devices are sum-reduced. + + The Embedding put on single card is as shown below: + + .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_single.png + :width: 800 + :height: 350 + :alt: single_embedding + :align: center + + Parallel Embedding is shown as below: + + .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_split.png + :width: 800 + :alt: split_embedding + :align: center + + Case 2: Row Parallel Linear + The weight of the linear operation is a NxM matrix with N rows and M columns. + With row parallel linear, the weight is split into num_partitions partitions, each + of which is a matrix with N/num_partitions rows and M column. + + The linear layer put on single card is shown as below, the input variable is represented by X, + the weight matrix is represented by W and the output vaiable is O. The linear layer on single card is + simple matrix multiplication operation, O = X * W. + + .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_single.png + :width: 800 + :alt: single_linear + :align: center + + Row Parallel Linear is shown as below. As the name suggests, Row Parallel Linear splits the weight matrix W into + [[W_row1], [W_row2]] along the row. And accordingly the input is splitted along the column into [X_col1, X_col2] and multiply their + respective weight matrices. Finally apply AllReduce on the output from each card to get the final output. + + .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_row.png + :width: 800 + :alt: split_row + :align: center + + Case 3: Column Parallel Linear + The weight of the linear operation is a NxM matrix with N rows and M columns. + With column parallel linear, the weight is split into num_paratitions partitions, each + of which is a matrix with N rows and M/num_partitions column. + + The linear layer put on single card has been illustrated on case 2 and Column Parallel Linear + is shown as below. The Column Parallel Linear splits the weight matrix W into [W_col1, W_col2] along the column and + these splitted matrices respectively multiply the input. Finally apply AllGather on the output from each card to get the final output. + + .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col.png + :width: 800 + :alt: split_col + :align: center + + As observed, the column parallel linear and row parallel linear can be combined to skip one ALLGATHER communication + operator. Furthermore the Attention and MLP can be combined to imporve the performance as shown below. + + .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col_row.png + :width: 800 + :alt: split_col_row + :align: center + + Args: + x (Tensor): Input tensor. It's data type should be float16, float32, float64, int32 or int64. + size (list|tuple): A list or tuple with two elements indicating the shape of the weight. + operation (str): The name of the operation. The supported operations are 'linear' and 'embedding'. + axis (int, Optional): Indicate along which axis to split the weight. Default: 0. + num_partitions (int, Optional): How many parts the weight is partitioned. Default: 1. + gather_out (bool, Optional): Whether to gather the output after computation. By default, the output + on each partitions will be gathered after computation. Default: True. + weight_attr (ParamAttr, Optional): The parameter attribute for the learnable + weights(Parameter) of the specified operation. Default: None. + bias_attr (ParamAttr, Optional): The parameter attribute for the bias + of the specified operation. Default: None. + name (str, Optional): The default value is None. Normally there is no need for user to set this + property. Default: None. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed.fleet as fleet + + paddle.enable_static() + paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id) + fleet.init(is_collective=True) + data = paddle.randint(0, 8, shape=[10,4]) + emb_out = paddle.distributed.split( + data, + (8, 8), + operation="embedding", + num_partitions=2) + + """ + assert isinstance( + size, + (list, tuple)), ("The type of size for " + "paddle.distributed.split must be list or tuple.") + assert len(size) == 2, ("Number of elements in size of " + "paddle.distributed.split must be two.") + assert isinstance(operation, str), ("The type of operation for " + "paddle.distributed.split must be str.") + supported_operations = [ + 'linear', + 'embedding', + ] + assert operation in supported_operations, ( + "The operation for " + "paddle.distributed.split must be one of {}.".format( + supported_operations)) + if _non_static_mode(): + raise ValueError( + "paddle.distributed.split cannot be used in dynamic " + "graph mode, plese use ParallelEmbedding, ParallelRowLinear, " + "ParallelColumnLinear instead.") + else: + from paddle.distributed.fleet import fleet + assert fleet._role_maker, ("To use paddle.distributed.split, " + "you must call fleet.init() firstly.") + rank = fleet.worker_index() + nranks = fleet.worker_num() + + # rank within a model parallel group + inner_rank = rank % num_partitions + + if operation == "embedding": + assert axis == 0, ("We only support to split the weight of embedding " + "along the first axis now.") + assert size[0] % num_partitions == 0, \ + "The length of the vocabulary must be divisible by num_partitions " \ + "but received vocabulary={} num_partitions={}".format(size[0], num_partitions) + + per_part_size = size[0] // num_partitions + emb_out = _parallel_embedding(x, + per_part_size, + size, + weight_attr, + inner_rank, + num_partitions, + name, + group=None) + return emb_out + else: + should_split = False + if axis == 0: + assert size[0] % num_partitions == 0, ( + "Number of rows of the weight for linear ({}) must be" + " divisible by num_partitions ({})".format( + size[0], num_partitions)) + per_part_size = size[0] // num_partitions + linear_size = (per_part_size, size[1]) + if x.shape[-1] == size[0]: should_split = True + + elif axis == 1: + assert size[1] % num_partitions == 0, ( + "Number of column of the weight for linear ({}) must be" + " divisible by num_partitions ({})".format( + size[1], num_partitions)) + per_part_size = size[1] // num_partitions + linear_size = (size[0], per_part_size) + else: + raise ValueError("The value of axis must be 0 or 1, but the value " + "given is {}.".format(axis)) + + linear_out = _parallel_linear(x, + linear_size[0], + linear_size[1], + axis, + weight_attr, + bias_attr, + gather_out, + inner_rank, + num_partitions, + should_split, + name=name, + group=None) + return linear_out diff --git a/python/paddle/distributed/fleet/layers/mpu/random.py b/python/paddle/distributed/fleet/layers/mpu/random.py new file mode 100644 index 00000000000..7577be6253c --- /dev/null +++ b/python/paddle/distributed/fleet/layers/mpu/random.py @@ -0,0 +1,243 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np +import contextlib +from paddle import _C_ops, _legacy_C_ops +from paddle.fluid import core +from paddle.fluid.data_feeder import check_variable_and_dtype +from paddle.fluid.framework import _non_static_mode, default_main_program, Variable +from paddle.fluid.layer_helper import LayerHelper + +__all__ = [] + +MODEL_PARALLEL_RNG = 'model_parallel_rng' + +# This file is inspired by Megatron to control random states for MP: +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/random.py + + +class RNGStatesTracker: + """ + Tracker the RNG states. + """ + + def __init__(self): + # Map from name to the rng state. + self.states_ = {} + self.seeds_ = set() + + def reset(self): + self.states_ = {} + self.seeds_ = set() + + def add(self, name, seed): + if seed in self.seeds_: + raise ValueError('seed {} already exists'.format(seed)) + self.seeds_.add(seed) + if name in self.states_: + raise ValueError('state {} already exists'.format(name)) + orig_rng_state = paddle.get_cuda_rng_state() + paddle.seed(seed) + self.states_[name] = paddle.get_cuda_rng_state() + paddle.set_cuda_rng_state(orig_rng_state) + + def get_states_tracker(self): + states = {} + for name in self.states_: + states[name] = self.states_[name] + return states + + def set_states_tracker(self, states): + self.states_ = states + + @contextlib.contextmanager + def rng_state(self, name=MODEL_PARALLEL_RNG): + if name not in self.states_: + raise ValueError('state {} does not exist'.format(name)) + orig_cuda_rng_state = paddle.get_cuda_rng_state() + paddle.set_cuda_rng_state(self.states_[name]) + try: + yield + finally: + self.states_[name] = paddle.get_cuda_rng_state() + paddle.set_cuda_rng_state(orig_cuda_rng_state) + + +RNG_STATE_TRACKER = RNGStatesTracker() + + +def get_rng_state_tracker(): + return RNG_STATE_TRACKER + + +def model_parallel_random_seed(seed=None): + import paddle.distributed.fleet as fleet + hcg = fleet.get_hybrid_communicate_group() + rank = hcg.get_model_parallel_rank() + + if seed: + global_seed = seed + local_seed = seed * 1024 + rank * 100 + else: + global_seed = np.random.randint(0, 655350) + local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1) + + RNG_STATE_TRACKER.reset() + RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed) + paddle.seed(global_seed) + + +def determinate_seed(rng_name): + assert rng_name is not None and rng_name != "" + helper = LayerHelper('seed', **locals()) + out = helper.create_variable_for_type_inference(dtype=paddle.int32) + # set force_cpu to reduce sync copy from CPU->GPU->CPU, and reduce pipeline hang + helper.append_op(type='seed', + outputs={'Out': out}, + attrs={ + 'deterministic': True, + 'rng_name': rng_name, + 'force_cpu': True + }) + return out + + +def dropout(x, + p=0.5, + axis=None, + rng_name=None, + training=True, + mode="upscale_in_train", + name=None): + """ + Dropout is a regularization technique for reducing overfitting by preventing + neuron co-adaption during training. The dropout operator randomly sets the + outputs of some units to zero, while upscale others according to the given + dropout probability. + + Args: + x (Tensor): The input tensor. The data type is float32 or float64. + p (float|int): Probability of setting units to zero. Default 0.5. + axis (int|list|tuple): The axis along which the dropout is performed. Default None. + rng_name (str): The random seed generator name, which used to obtain deterministic results. + training (bool): A flag indicating whether it is in train phrase or not. Default True. + mode(str): ['upscale_in_train'(default) | 'downscale_in_infer']. + + 1. upscale_in_train(default), upscale the output at training time + + - train: out = input * mask / ( 1.0 - dropout_prob ) + - inference: out = input + + 2. downscale_in_infer, downscale the output at inference + + - train: out = input * mask + - inference: out = input * (1.0 - dropout_prob) + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor representing the dropout, has same shape and data type as `x` . + + + Examples: + We use ``p=0.5`` in the following description for simplicity. + + 1. When ``axis=None`` , this is commonly used dropout, which dropout each element of x randomly. + + .. code-block:: text + + Let's see a simple case when x is a 2d tensor with shape 2*3: + [[1 2 3] + [4 5 6]] + we generate mask with the same shape as x, which is 2*3. The value of mask is + sampled from a Bernoulli distribution randomly. For example, we may get such mask: + [[0 1 0] + [1 0 1]] + So the output is obtained from elementwise multiply of x and mask: + [[0 2 0] + [4 0 6]] + Using default setting, i.e. ``mode='upscale_in_train'`` , + if in training phase, the final upscale output is: + [[0 4 0 ] + [8 0 12]] + if in test phase, the output is the same as input: + [[1 2 3] + [4 5 6]] + we can also set ``mode='downscale_in_infer'`` , then + if in training phase, the final output is: + [[0 2 0] + [4 0 6]] + if in test phase, the scale output is: + [[0.5 1. 1.5] + [2. 2.5 3. ]] + + """ + if rng_name is None: + return paddle.nn.functional.dropout(x, p, axis, training, mode, name) + + if not isinstance(p, (float, int, Variable)): + raise TypeError("p argument should be a number(int|float) or Variable") + + # fast return for p == 0 + if isinstance(p, (int, float)) and p == 0: return x + + assert 0 <= p <= 1, ValueError("p argument should between 0 and 1") + assert mode in ('downscale_in_infer', 'upscale_in_train'), \ + ValueError( + "mode argument should be 'downscale_in_infer' or 'upscale_in_train'") + + assert axis is None, \ + TypeError("unsupport axis when using random seed generator") + + mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer + + # dygraph using tracker, doesn't need determinate seed + if _non_static_mode(): + out, mask = _legacy_C_ops.dropout(x, 'dropout_prob', p, 'is_test', + not training, 'fix_seed', False, + 'seed', 0, 'dropout_implementation', + mode) + return out + + seed = determinate_seed(rng_name) + + if isinstance(p, Variable) and not p.shape != [1]: + raise TypeError( + "Required p.shape == [1] if type(p) is Variable, but received p.shape = {}" + .format(p.shape)) + + helper = LayerHelper('dropout', **locals()) + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'dropout') + + out = helper.create_variable_for_type_inference(dtype=x.dtype) + mask = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) + + helper.append_op(type='dropout', + inputs={ + 'X': [x], + 'Seed': seed + }, + outputs={ + 'Out': [out], + 'Mask': [mask] + }, + attrs={ + 'dropout_prob': p, + 'is_test': not training, + 'dropout_implementation': mode, + }) + return out diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py index 6cb69bc73ce..66a1c877562 100644 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,298 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle -from paddle.fluid import core -from paddle.fluid.dygraph.layers import Layer -from .random import get_rng_state_tracker -from paddle.nn import functional as F -from paddle import framework -from ...base import topology as tp -from paddle.autograd import PyLayer +from ...layers.mpu.mp_layers import VocabParallelEmbedding # noqa: F401 +from ...layers.mpu.mp_layers import ColumnParallelLinear # noqa: F401 +from ...layers.mpu.mp_layers import RowParallelLinear # noqa: F401 +from ...layers.mpu.mp_layers import ParallelCrossEntropy # noqa: F401 __all__ = [] - -# Follow this paper to achieve the file: -# Shoeybi M, Patwary M, Puri R, et al. Megatron-lm: Training multi-billion parameter -# language models using model parallelism[J]. arXiv preprint arXiv:1909.08053, 2019. (https://arxiv.org/abs/1909.08053) - - -def is_fused_matmul_bias_supported(): - if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm(): - return hasattr(core.ops, 'fused_gemm_epilogue') - else: - return False - - -class VocabParallelEmbedding(Layer): - - def __init__(self, - num_embeddings, - embedding_dim, - weight_attr=None, - mp_group=None, - name=None): - super(VocabParallelEmbedding, self).__init__() - - self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group( - ) if mp_group is None else mp_group - self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size( - ) if mp_group is None else mp_group.nranks - self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank( - ) if mp_group is None else mp_group.rank - - self.origin_num_embeddings = num_embeddings - self.is_mp = (self.world_size > 1) - - assert num_embeddings % self.world_size == 0, ( - "The length of the vocabulary must be divisible by the parallelism degree of MP" - ) - - per_part_size = num_embeddings // self.world_size - - self.vocab_start_index = self.rank * per_part_size - self._dtype = self._helper.get_default_dtype() - self._size = [per_part_size, embedding_dim] - self._weight_attr = weight_attr - self._name = name - - if self.is_mp and paddle.in_dynamic_mode(): - with get_rng_state_tracker().rng_state(): - self.weight = self.create_parameter(attr=self._weight_attr, - shape=self._size, - dtype=self._dtype, - is_bias=False) - else: - self.weight = self.create_parameter(attr=self._weight_attr, - shape=self._size, - dtype=self._dtype, - is_bias=False) - - self.weight.is_distributed = True if self.is_mp else False - - def forward(self, x): - if self.is_mp: - output_parallel = paddle.distributed.collective._c_lookup_table( - self.weight, - x, - start_index=self.vocab_start_index, - name=self._name) - output = paddle.distributed.collective._mp_allreduce( - output_parallel, - group=self.model_parallel_group, - use_calc_stream=True, - use_model_parallel=True) - else: - output = F.embedding(x, - weight=self.weight, - padding_idx=None, - sparse=False, - name=self._name) - return output - - -class ColumnParallelLinear(Layer): - - def __init__(self, - in_features, - out_features, - weight_attr=None, - has_bias=None, - gather_output=True, - fuse_matmul_bias=False, - mp_group=None, - name=None): - super(ColumnParallelLinear, self).__init__() - - self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group( - ) if mp_group is None else mp_group - self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size( - ) if mp_group is None else mp_group.nranks - self._name = name - self.is_mp = (self.world_size > 1) - - self.gather_output = gather_output - assert out_features % self.world_size == 0, ( - "Number of column of the weight for linear ({}) must be" - " divisible by model parallel size ({})".format( - out_features, self.world_size)) - self.output_size_per_partition = out_features // self.world_size - - self._weight_attr = weight_attr - self._dtype = self._helper.get_default_dtype() - - if self.is_mp and paddle.in_dynamic_mode(): - with get_rng_state_tracker().rng_state(): - self.weight = self.create_parameter( - shape=[in_features, self.output_size_per_partition], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False) - else: - self.weight = self.create_parameter( - shape=[in_features, self.output_size_per_partition], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False) - - self.weight.is_distributed = True if self.is_mp else False - - if has_bias: - # initialize bias to zero like Megatron - self.bias = self.create_parameter( - shape=[self.output_size_per_partition], - attr=paddle.nn.initializer.Constant(value=0.0), - dtype=self._dtype, - is_bias=True) - self.bias.is_distributed = True if self.is_mp else False - else: - self.bias = None - - self.linear = F.linear - - if fuse_matmul_bias: - if not is_fused_matmul_bias_supported(): - raise NotImplementedError( - "You set fuse_matmul_bias=True in ColumnParallelLinear, " - "however, the paddle you are using not support this operation. " - "Please set fuse_matmul_bias=False or use paddle compiled " - "with cuda 11.6 or higher.") - from paddle.incubate.nn.functional import fused_linear - self.linear = fused_linear - - def forward(self, x): - # use inner api to process identity - if self.is_mp: - input_parallel = paddle.distributed.collective._c_identity( - x, group=self.model_parallel_group) - else: - input_parallel = x - - output_parallel = self.linear(input_parallel, - self.weight, - self.bias, - name=self._name) - - if self.gather_output and self.is_mp: - output = paddle.distributed.collective._c_concat( - output_parallel, group=self.model_parallel_group) - else: - output = output_parallel - return output - - -class RowParallelLinear(Layer): - - def __init__(self, - in_features, - out_features, - weight_attr=None, - has_bias=True, - input_is_parallel=False, - fuse_matmul_bias=False, - mp_group=None, - name=None): - super(RowParallelLinear, self).__init__() - - self.in_features = in_features - self.out_features = out_features - self.input_is_parallel = input_is_parallel - self._weight_attr = weight_attr - self._dtype = self._helper.get_default_dtype() - self._name = name - - self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group( - ) if mp_group is None else mp_group - self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size( - ) if mp_group is None else mp_group.nranks - self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank( - ) if mp_group is None else mp_group.rank - - self.is_mp = (self.world_size > 1) - assert in_features % self.world_size == 0, ( - "Number of row of the weight for linear ({}) must be" - " divisible by model parallel size ({})".format( - in_features, self.world_size)) - - self.input_size_per_partition = in_features // self.world_size - - if self.is_mp and paddle.in_dynamic_mode(): - with get_rng_state_tracker().rng_state(): - self.weight = self.create_parameter( - shape=[self.input_size_per_partition, self.out_features], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False) - else: - self.weight = self.create_parameter( - shape=[self.input_size_per_partition, self.out_features], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False) - - self.weight.is_distributed = True if self.is_mp else False - - if has_bias: - self.bias = self.create_parameter( - shape=[self.out_features], - attr=paddle.nn.initializer.Constant(value=0.0), - dtype=self._dtype, - is_bias=True) - else: - self.bias = None - - self.linear = F.linear - - if fuse_matmul_bias: - if not is_fused_matmul_bias_supported(): - raise NotImplementedError( - "You set fuse_matmul_bias=True in RowParallelLinear, " - "however, the paddle you are using not support this operation. " - "Please set fuse_matmul_bias=False or use paddle compiled " - "with cuda 11.6 or higher.") - from paddle.incubate.nn.functional import fused_linear - self.linear = fused_linear - - def forward(self, x): - if self.input_is_parallel or (not self.is_mp): - input_parallel = x - else: - # split last dim - input_parallel = paddle.distributed.collective._c_split( - x, group=self.model_parallel_group) - - if self.is_mp: - output_parallel = self.linear(input_parallel, - self.weight, - name=self._name) - output_ = paddle.distributed.collective._mp_allreduce( - output_parallel, - group=self.model_parallel_group, - use_calc_stream=True, - use_model_parallel=True) - output = output_ + self.bias if self.bias is not None else output_ - else: - output = self.linear(input_parallel, - self.weight, - self.bias, - name=self._name) - - return output - - -class ParallelCrossEntropy(Layer): - - def __init__(self, mp_group=None, name=None): - super(ParallelCrossEntropy, self).__init__() - self.name = name - self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group( - ) if mp_group is None else mp_group - self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size( - ) if mp_group is None else mp_group.nranks - self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank( - ) if mp_group is None else mp_group.rank - - def forward(self, input, label): - loss = paddle.distributed.collective._c_softmax_with_cross_entropy( - input, label, group=self.model_parallel_group) - return loss diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py index 900c0f79798..9deed30db66 100644 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,232 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle -import contextlib -import numpy as np -from paddle import _C_ops, _legacy_C_ops -from paddle.fluid import core -from paddle.fluid.data_feeder import check_variable_and_dtype -from paddle.fluid.framework import _non_static_mode, default_main_program, Variable -from paddle.fluid.layer_helper import LayerHelper +from ...layers.mpu.random import RNGStatesTracker # noqa: F401 +from ...layers.mpu.random import get_rng_state_tracker # noqa: F401 +from ...layers.mpu.random import model_parallel_random_seed # noqa: F401 +from ...layers.mpu.random import determinate_seed # noqa: F401 +from ...layers.mpu.random import dropout # noqa: F401 __all__ = [] - -MODEL_PARALLEL_RNG = 'model_parallel_rng' - -# This file is inspired by Megatron to control random states for MP: -# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/random.py - - -class RNGStatesTracker: - """ - Tracker the RNG states. - """ - - def __init__(self): - # Map from name to the rng state. - self.states_ = {} - self.seeds_ = set() - - def reset(self): - self.states_ = {} - self.seeds_ = set() - - def add(self, name, seed): - if seed in self.seeds_: - raise ValueError('seed {} already exists'.format(seed)) - self.seeds_.add(seed) - if name in self.states_: - raise ValueError('state {} already exists'.format(name)) - orig_rng_state = paddle.get_cuda_rng_state() - paddle.seed(seed) - self.states_[name] = paddle.get_cuda_rng_state() - paddle.set_cuda_rng_state(orig_rng_state) - - def get_states_tracker(self): - states = {} - for name in self.states_: - states[name] = self.states_[name] - return states - - def set_states_tracker(self, states): - self.states_ = states - - @contextlib.contextmanager - def rng_state(self, name=MODEL_PARALLEL_RNG): - if name not in self.states_: - raise ValueError('state {} does not exist'.format(name)) - orig_cuda_rng_state = paddle.get_cuda_rng_state() - paddle.set_cuda_rng_state(self.states_[name]) - try: - yield - finally: - self.states_[name] = paddle.get_cuda_rng_state() - paddle.set_cuda_rng_state(orig_cuda_rng_state) - - -RNG_STATE_TRACKER = RNGStatesTracker() - - -def get_rng_state_tracker(): - return RNG_STATE_TRACKER - - -def model_parallel_random_seed(seed=None): - import paddle.distributed.fleet as fleet - hcg = fleet.get_hybrid_communicate_group() - rank = hcg.get_model_parallel_rank() - - if seed: - global_seed = seed - local_seed = seed * 1024 + rank * 100 - else: - global_seed = np.random.randint(0, 655350) - local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1) - - RNG_STATE_TRACKER.reset() - RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed) - paddle.seed(global_seed) - - -def determinate_seed(rng_name): - assert rng_name is not None and rng_name != "" - helper = LayerHelper('seed', **locals()) - out = helper.create_variable_for_type_inference(dtype=paddle.int32) - # set force_cpu to reduce sync copy from CPU->GPU->CPU, and reduce pipeline hang - helper.append_op(type='seed', - outputs={'Out': out}, - attrs={ - 'deterministic': True, - 'rng_name': rng_name, - 'force_cpu': True - }) - return out - - -def dropout(x, - p=0.5, - axis=None, - rng_name=None, - training=True, - mode="upscale_in_train", - name=None): - """ - Dropout is a regularization technique for reducing overfitting by preventing - neuron co-adaption during training. The dropout operator randomly sets the - outputs of some units to zero, while upscale others according to the given - dropout probability. - - Args: - x (Tensor): The input tensor. The data type is float32 or float64. - p (float|int): Probability of setting units to zero. Default 0.5. - axis (int|list|tuple): The axis along which the dropout is performed. Default None. - rng_name (str): The random seed generator name, which used to obtain deterministic results. - training (bool): A flag indicating whether it is in train phrase or not. Default True. - mode(str): ['upscale_in_train'(default) | 'downscale_in_infer']. - - 1. upscale_in_train(default), upscale the output at training time - - - train: out = input * mask / ( 1.0 - dropout_prob ) - - inference: out = input - - 2. downscale_in_infer, downscale the output at inference - - - train: out = input * mask - - inference: out = input * (1.0 - dropout_prob) - name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. - - Returns: - A Tensor representing the dropout, has same shape and data type as `x` . - - - Examples: - We use ``p=0.5`` in the following description for simplicity. - - 1. When ``axis=None`` , this is commonly used dropout, which dropout each element of x randomly. - - .. code-block:: text - - Let's see a simple case when x is a 2d tensor with shape 2*3: - [[1 2 3] - [4 5 6]] - we generate mask with the same shape as x, which is 2*3. The value of mask is - sampled from a Bernoulli distribution randomly. For example, we may get such mask: - [[0 1 0] - [1 0 1]] - So the output is obtained from elementwise multiply of x and mask: - [[0 2 0] - [4 0 6]] - Using default setting, i.e. ``mode='upscale_in_train'`` , - if in training phase, the final upscale output is: - [[0 4 0 ] - [8 0 12]] - if in test phase, the output is the same as input: - [[1 2 3] - [4 5 6]] - we can also set ``mode='downscale_in_infer'`` , then - if in training phase, the final output is: - [[0 2 0] - [4 0 6]] - if in test phase, the scale output is: - [[0.5 1. 1.5] - [2. 2.5 3. ]] - - """ - if rng_name is None: - return paddle.nn.functional.dropout(x, p, axis, training, mode, name) - - if not isinstance(p, (float, int, Variable)): - raise TypeError("p argument should be a number(int|float) or Variable") - - # fast return for p == 0 - if isinstance(p, (int, float)) and p == 0: return x - - assert 0 <= p <= 1, ValueError("p argument should between 0 and 1") - assert mode in ('downscale_in_infer', 'upscale_in_train'), \ - ValueError( - "mode argument should be 'downscale_in_infer' or 'upscale_in_train'") - - assert axis is None, \ - TypeError("unsupport axis when using random seed generator") - - mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer - - # dygraph using tracker, doesn't need determinate seed - if _non_static_mode(): - out, mask = _legacy_C_ops.dropout(x, 'dropout_prob', p, 'is_test', - not training, 'fix_seed', False, - 'seed', 0, 'dropout_implementation', - mode) - return out - - seed = determinate_seed(rng_name) - - if isinstance(p, Variable) and not p.shape != [1]: - raise TypeError( - "Required p.shape == [1] if type(p) is Variable, but received p.shape = {}" - .format(p.shape)) - - helper = LayerHelper('dropout', **locals()) - check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], - 'dropout') - - out = helper.create_variable_for_type_inference(dtype=x.dtype) - mask = helper.create_variable_for_type_inference( - dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) - - helper.append_op(type='dropout', - inputs={ - 'X': [x], - 'Seed': seed - }, - outputs={ - 'Out': [out], - 'Mask': [mask] - }, - attrs={ - 'dropout_prob': p, - 'is_test': not training, - 'dropout_implementation': mode, - }) - return out diff --git a/python/setup.py.in b/python/setup.py.in index 3d400881de3..04ad7bf0388 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -307,6 +307,8 @@ packages=['paddle', 'paddle.distributed.fleet.metrics', 'paddle.distributed.fleet.proto', 'paddle.distributed.fleet.utils', + 'paddle.distributed.fleet.layers', + 'paddle.distributed.fleet.layers.mpu', 'paddle.distributed.fleet.meta_parallel', 'paddle.distributed.fleet.meta_parallel.pp_utils', 'paddle.distributed.fleet.meta_parallel.sharding', -- GitLab