Unverified commit fa97e5ba, authored by wuhuachaocoding, committed by GitHub

refactor mp. (#45803)

* refactor mp.

* update setup.py.

* update mp_layers.py for compatibility.

* add documents for mp_layers.py

* update init.py

* update collective.py.

* update.

* update mp_ops.py

* update.

* update code style.

* update code style.
Parent ae00f428
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class ReduceOp:
"""
Specify the type of operation used for element-wise reductions.
It should be one of the following values:
ReduceOp.SUM
ReduceOp.MAX
ReduceOp.MIN
ReduceOp.PROD
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
if dist.get_rank() == 0:
data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
dist.all_reduce(data, op=dist.ReduceOp.SUM)
print(data)
# [[5, 7, 9], [5, 7, 9]] (2 GPUs)
"""
SUM = 0
MAX = 1
MIN = 2
PROD = 3
AVG = 4
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .mp_layers import VocabParallelEmbedding
from .mp_layers import ColumnParallelLinear
from .mp_layers import RowParallelLinear
from .mp_layers import ParallelCrossEntropy
from .random import RNGStatesTracker
from .random import get_rng_state_tracker
from .random import model_parallel_random_seed
from .random import determinate_seed
from .random import dropout
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from . import mp_ops
from paddle.fluid import core
from paddle.fluid.dygraph.layers import Layer
from .random import get_rng_state_tracker
from paddle.nn import functional as F
from paddle import framework
from paddle.autograd import PyLayer
from ...base import topology as tp
__all__ = []
# Follow this paper to achieve the file:
# Shoeybi M, Patwary M, Puri R, et al. Megatron-lm: Training multi-billion parameter
# language models using model parallelism[J]. arXiv preprint arXiv:1909.08053, 2019. (https://arxiv.org/abs/1909.08053)
def is_fused_matmul_bias_supported():
if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
return hasattr(core.ops, 'fused_gemm_epilogue')
else:
return False
class VocabParallelEmbedding(Layer):
"""Embedding layer parallelized along the vocabulary dimension (model parallel).
This class splits the embedding table across the mp group.
Args:
num_embeddings(int): The size of the embedding dictionary (the vocabulary size).
embedding_dim(int): The size of each embedding vector.
weight_attr(ParamAttr|None): To specify the weight parameter property. Default: None, which means the
default weight parameter property is used. See usage for details in :ref:`api_ParamAttr` . In addition,
user-defined or pre-trained word vectors can be loaded with the :attr:`weight_attr` parameter.
The local word vectors need to be converted to numpy format, and their shape should be consistent
with :attr:`num_embeddings` . Then :ref:`api_initializer_NumpyArrayInitializer` is used to load
custom or pre-trained word vectors. See the code example for details.
mp_group(Group): The tensor parallel group.
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually the name does not need to be set. Default: None.
Examples:
.. code-block:: python
import paddle
from paddle.distributed import fleet
class SimpleMPNet(paddle.nn.Layer):
def __init__(self, vocab_size, hidden_size, inner_size, output_size):
super(SimpleMPNet, self).__init__()
self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
hidden_size,
inner_size,
gather_output=False,
has_bias=True)
self.linear2 = fleet.meta_parallel.RowParallelLinear(
inner_size,
hidden_size,
input_is_parallel=True,
has_bias=True)
self.linear3 = paddle.nn.Linear(hidden_size, output_size)
self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
vocab_size,
hidden_size)
def forward(self, x):
x = self.embedding(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
return x
"""
def __init__(self,
num_embeddings,
embedding_dim,
weight_attr=None,
mp_group=None,
name=None):
super(VocabParallelEmbedding, self).__init__()
self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
) if mp_group is None else mp_group
self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
) if mp_group is None else mp_group.nranks
self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
) if mp_group is None else mp_group.rank
self.origin_num_embeddings = num_embeddings
self.is_mp = (self.world_size > 1)
assert num_embeddings % self.world_size == 0, (
"The length of the vocabulary must be divisible by the parallelism degree of MP"
)
per_part_size = num_embeddings // self.world_size
self.vocab_start_index = self.rank * per_part_size
self._dtype = self._helper.get_default_dtype()
self._size = [per_part_size, embedding_dim]
self._weight_attr = weight_attr
self._name = name
if self.is_mp and paddle.in_dynamic_mode():
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(attr=self._weight_attr,
shape=self._size,
dtype=self._dtype,
is_bias=False)
else:
self.weight = self.create_parameter(attr=self._weight_attr,
shape=self._size,
dtype=self._dtype,
is_bias=False)
self.weight.is_distributed = True if self.is_mp else False
def forward(self, x):
if self.is_mp:
output_parallel = mp_ops._c_lookup_table(
self.weight,
x,
start_index=self.vocab_start_index,
name=self._name)
output = mp_ops._mp_allreduce(output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
else:
output = F.embedding(x,
weight=self.weight,
padding_idx=None,
sparse=False,
name=self._name)
return output
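# Illustrative sketch (not part of the layer above): a plain numpy model of
# what VocabParallelEmbedding computes. Each rank owns a contiguous slice of
# the vocabulary rows, ids outside that slice contribute zeros locally, and
# the all-reduce(SUM) done by mp_ops._mp_allreduce recovers the full lookup.
# The helper name and sizes below are hypothetical, for illustration only.
def _vocab_parallel_embedding_numpy_sketch():
    import numpy as np

    vocab_size, embedding_dim, world_size = 8, 4, 2
    rng = np.random.default_rng(0)
    table = rng.standard_normal((vocab_size, embedding_dim))
    ids = np.array([0, 3, 5, 7])

    per_part = vocab_size // world_size
    partials = []
    for rank in range(world_size):
        start = rank * per_part
        shard = table[start:start + per_part]           # rows owned by this rank
        in_range = (ids >= start) & (ids < start + per_part)
        local = np.zeros((len(ids), embedding_dim))
        local[in_range] = shard[ids[in_range] - start]  # local lookup, zeros elsewhere
        partials.append(local)

    # Summing the partial lookups (the "all-reduce") matches a full-table lookup.
    assert np.allclose(sum(partials), table[ids])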
class ColumnParallelLinear(Layer):
"""Linear layer with model parallelism (column split).
This class splits a Linear layer across the mp group by splitting its weight along the column (output) dimension.
Args:
in_features(int): The number of input units.
out_features(int): The number of output units.
weight_attr(ParamAttr|None): The attribute for the learnable weight of this layer. Default: None, which
means the default weight parameter property is used. For detailed information, please refer to paddle.ParamAttr.
has_bias(bool): Whether to add a bias term.
gather_output(bool): Whether to all-gather the output of each rank so that every rank holds the full output.
fuse_matmul_bias(bool): Whether to fuse the matmul with the bias add.
mp_group(Group): The tensor parallel group.
name(str, optional): Normally there is no need for the user to set this parameter.
For detailed information, please refer to :ref:`api_guide_Name` .
Examples:
.. code-block:: python
import paddle
from paddle.distributed import fleet
class SimpleMPNet(paddle.nn.Layer):
def __init__(self, vocab_size, hidden_size, inner_size, output_size):
super(SimpleMPNet, self).__init__()
self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
hidden_size,
inner_size,
gather_output=False,
has_bias=True)
self.linear2 = fleet.meta_parallel.RowParallelLinear(
inner_size,
hidden_size,
input_is_parallel=True,
has_bias=True)
self.linear3 = paddle.nn.Linear(hidden_size, output_size)
self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
vocab_size,
hidden_size)
def forward(self, x):
x = self.embedding(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
return x
"""
def __init__(self,
in_features,
out_features,
weight_attr=None,
has_bias=None,
gather_output=True,
fuse_matmul_bias=False,
mp_group=None,
name=None):
super(ColumnParallelLinear, self).__init__()
self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
) if mp_group is None else mp_group
self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
) if mp_group is None else mp_group.nranks
self._name = name
self.is_mp = (self.world_size > 1)
self.gather_output = gather_output
assert out_features % self.world_size == 0, (
"The number of columns of the linear weight ({}) must be"
" divisible by model parallel size ({})".format(
out_features, self.world_size))
self.output_size_per_partition = out_features // self.world_size
self._weight_attr = weight_attr
self._dtype = self._helper.get_default_dtype()
if self.is_mp and paddle.in_dynamic_mode():
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
else:
self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
self.weight.is_distributed = True if self.is_mp else False
if has_bias:
# initialize bias to zero like Megatron
self.bias = self.create_parameter(
shape=[self.output_size_per_partition],
attr=paddle.nn.initializer.Constant(value=0.0),
dtype=self._dtype,
is_bias=True)
self.bias.is_distributed = True if self.is_mp else False
else:
self.bias = None
self.linear = F.linear
if fuse_matmul_bias:
if not is_fused_matmul_bias_supported():
raise NotImplementedError(
"You set fuse_matmul_bias=True in ColumnParallelLinear, "
"however, the Paddle you are using does not support this operation. "
"Please set fuse_matmul_bias=False or use Paddle compiled "
"with CUDA 11.6 or higher.")
from paddle.incubate.nn.functional import fused_linear
self.linear = fused_linear
def forward(self, x):
# _c_identity is an identity in the forward pass and an all-reduce of the gradient in the backward pass
if self.is_mp:
input_parallel = mp_ops._c_identity(x,
group=self.model_parallel_group)
else:
input_parallel = x
output_parallel = self.linear(input_parallel,
self.weight,
self.bias,
name=self._name)
if self.gather_output and self.is_mp:
output = mp_ops._c_concat(output_parallel,
group=self.model_parallel_group)
else:
output = output_parallel
return output
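# Illustrative sketch (not part of the layer above): the column-split math
# behind ColumnParallelLinear in plain numpy. Every rank sees the full input
# (_c_identity is an identity in the forward pass), multiplies by its column
# slice of the weight, and gather_output=True corresponds to concatenating the
# slices along the last dimension (_c_concat). The helper name and sizes below
# are hypothetical, for illustration only.
def _column_parallel_linear_numpy_sketch():
    import numpy as np

    in_features, out_features, world_size = 6, 8, 2
    rng = np.random.default_rng(1)
    x = rng.standard_normal((3, in_features))
    w = rng.standard_normal((in_features, out_features))
    b = rng.standard_normal(out_features)

    cols = out_features // world_size
    partials = [
        x @ w[:, r * cols:(r + 1) * cols] + b[r * cols:(r + 1) * cols]
        for r in range(world_size)
    ]

    # Concatenating the per-rank slices reproduces the unsplit linear layer.
    assert np.allclose(np.concatenate(partials, axis=-1), x @ w + b)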
class RowParallelLinear(Layer):
"""Linear layer with model parallelism (row split).
This class splits a Linear layer across the mp group by splitting its weight along the row (input) dimension.
Args:
in_features(int): The number of input units.
out_features(int): The number of output units.
weight_attr(ParamAttr|None): The attribute for the learnable weight of this layer. Default: None, which
means the default weight parameter property is used. For detailed information, please refer to paddle.ParamAttr.
has_bias(bool): Whether to add a bias term.
input_is_parallel(bool): Whether the input has already been split across the mp group.
fuse_matmul_bias(bool): Whether to fuse the matmul with the bias add.
mp_group(Group): The tensor parallel group.
name(str, optional): Normally there is no need for the user to set this parameter.
For detailed information, please refer to :ref:`api_guide_Name` .
Examples:
.. code-block:: python
import paddle
from paddle.distributed import fleet
class SimpleMPNet(paddle.nn.Layer):
def __init__(self, vocab_size, hidden_size, inner_size, output_size):
super(SimpleMPNet, self).__init__()
self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
hidden_size,
inner_size,
gather_output=False,
has_bias=True)
self.linear2 = fleet.meta_parallel.RowParallelLinear(
inner_size,
hidden_size,
input_is_parallel=True,
has_bias=True)
self.linear3 = paddle.nn.Linear(hidden_size, output_size)
self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
vocab_size,
hidden_size)
def forward(self, x):
x = self.embedding(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
return x
"""
def __init__(self,
in_features,
out_features,
weight_attr=None,
has_bias=True,
input_is_parallel=False,
fuse_matmul_bias=False,
mp_group=None,
name=None):
super(RowParallelLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.input_is_parallel = input_is_parallel
self._weight_attr = weight_attr
self._dtype = self._helper.get_default_dtype()
self._name = name
self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
) if mp_group is None else mp_group
self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
) if mp_group is None else mp_group.nranks
self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
) if mp_group is None else mp_group.rank
self.is_mp = (self.world_size > 1)
assert in_features % self.world_size == 0, (
"The number of rows of the linear weight ({}) must be"
" divisible by model parallel size ({})".format(
in_features, self.world_size))
self.input_size_per_partition = in_features // self.world_size
if self.is_mp and paddle.in_dynamic_mode():
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
else:
self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
self.weight.is_distributed = True if self.is_mp else False
if has_bias:
self.bias = self.create_parameter(
shape=[self.out_features],
attr=paddle.nn.initializer.Constant(value=0.0),
dtype=self._dtype,
is_bias=True)
else:
self.bias = None
self.linear = F.linear
if fuse_matmul_bias:
if not is_fused_matmul_bias_supported():
raise NotImplementedError(
"You set fuse_matmul_bias=True in RowParallelLinear, "
"however, the Paddle you are using does not support this operation. "
"Please set fuse_matmul_bias=False or use Paddle compiled "
"with CUDA 11.6 or higher.")
from paddle.incubate.nn.functional import fused_linear
self.linear = fused_linear
def forward(self, x):
if self.input_is_parallel or (not self.is_mp):
input_parallel = x
else:
# split last dim
input_parallel = mp_ops._c_split(x, group=self.model_parallel_group)
if self.is_mp:
output_parallel = self.linear(input_parallel,
self.weight,
name=self._name)
output_ = mp_ops._mp_allreduce(output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
output = output_ + self.bias if self.bias is not None else output_
else:
output = self.linear(input_parallel,
self.weight,
self.bias,
name=self._name)
return output
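# Illustrative sketch (not part of the layer above): the row-split math behind
# RowParallelLinear in plain numpy. The input is split along its last
# dimension (_c_split), each rank multiplies its slice by the matching row
# block of the weight, the partial products are summed across ranks
# (_mp_allreduce), and the bias is added once after the reduction so that it
# is not accumulated world_size times. Names and sizes are hypothetical.
def _row_parallel_linear_numpy_sketch():
    import numpy as np

    in_features, out_features, world_size = 8, 5, 2
    rng = np.random.default_rng(2)
    x = rng.standard_normal((3, in_features))
    w = rng.standard_normal((in_features, out_features))
    b = rng.standard_normal(out_features)

    rows = in_features // world_size
    partials = [
        x[:, r * rows:(r + 1) * rows] @ w[r * rows:(r + 1) * rows]
        for r in range(world_size)
    ]

    # Sum of the partial products plus a single bias add equals the full matmul.
    assert np.allclose(sum(partials) + b, x @ w + b)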
class ParallelCrossEntropy(Layer):
"""Softmax cross entropy with model parallelism.
This class computes softmax cross entropy when the logits are split along the class dimension across the mp group.
Args:
mp_group(Group): The tensor parallel group.
name(str, optional): Normally there is no need for the user to set this parameter.
For detailed information, please refer to :ref:`api_guide_Name` .
Examples:
.. code-block:: python
loss_func = ParallelCrossEntropy()
loss = loss_func(logits, label)
"""
def __init__(self, mp_group=None, name=None):
super(ParallelCrossEntropy, self).__init__()
self.name = name
self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
) if mp_group is None else mp_group
self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
) if mp_group is None else mp_group.nranks
self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
) if mp_group is None else mp_group.rank
def forward(self, input, label):
loss = mp_ops._c_softmax_with_cross_entropy(
input, label, group=self.model_parallel_group)
return loss
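# Illustrative sketch (not part of the layer above): why a class-parallel
# softmax cross entropy needs cross-rank reductions. With logits sharded along
# the class dimension, the global max and the global sum of exponentials must
# be reduced across shards before the per-sample loss can be formed, and each
# label's logit lives on exactly one shard. Names and sizes are hypothetical.
def _parallel_cross_entropy_numpy_sketch():
    import numpy as np

    num_classes, world_size, batch = 8, 2, 3
    rng = np.random.default_rng(3)
    logits = rng.standard_normal((batch, num_classes))
    labels = np.array([1, 6, 4])

    cols = num_classes // world_size
    shards = [logits[:, r * cols:(r + 1) * cols] for r in range(world_size)]

    # "all-reduce(MAX)" of the per-shard maxima, then "all-reduce(SUM)" of the
    # shifted exponentials.
    global_max = np.max([s.max(axis=1) for s in shards], axis=0)
    sum_exp = sum(np.exp(s - global_max[:, None]).sum(axis=1) for s in shards)

    # Pick each label's logit from the shard that owns it (zeros elsewhere).
    target = np.zeros(batch)
    for r, s in enumerate(shards):
        owned = (labels >= r * cols) & (labels < (r + 1) * cols)
        target[owned] = s[np.arange(batch)[owned], labels[owned] - r * cols]

    loss = np.log(sum_exp) + global_max - target
    reference = -np.log(
        np.exp(logits)[np.arange(batch), labels] / np.exp(logits).sum(axis=1))
    assert np.allclose(loss, reference)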
This diff has been collapsed.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import contextlib
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid import core
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import _non_static_mode, default_main_program, Variable
from paddle.fluid.layer_helper import LayerHelper
__all__ = []
MODEL_PARALLEL_RNG = 'model_parallel_rng'
# This file is inspired by Megatron to control random states for MP:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/random.py
class RNGStatesTracker:
"""
Tracker of the RNG states.
"""
def __init__(self):
# Map from name to the rng state.
self.states_ = {}
self.seeds_ = set()
def reset(self):
self.states_ = {}
self.seeds_ = set()
def add(self, name, seed):
if seed in self.seeds_:
raise ValueError('seed {} already exists'.format(seed))
self.seeds_.add(seed)
if name in self.states_:
raise ValueError('state {} already exists'.format(name))
orig_rng_state = paddle.get_cuda_rng_state()
paddle.seed(seed)
self.states_[name] = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(orig_rng_state)
def get_states_tracker(self):
states = {}
for name in self.states_:
states[name] = self.states_[name]
return states
def set_states_tracker(self, states):
self.states_ = states
@contextlib.contextmanager
def rng_state(self, name=MODEL_PARALLEL_RNG):
if name not in self.states_:
raise ValueError('state {} does not exist'.format(name))
orig_cuda_rng_state = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(self.states_[name])
try:
yield
finally:
self.states_[name] = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(orig_cuda_rng_state)
RNG_STATE_TRACKER = RNGStatesTracker()
def get_rng_state_tracker():
return RNG_STATE_TRACKER
def model_parallel_random_seed(seed=None):
import paddle.distributed.fleet as fleet
hcg = fleet.get_hybrid_communicate_group()
rank = hcg.get_model_parallel_rank()
if seed:
global_seed = seed
local_seed = seed * 1024 + rank * 100
else:
global_seed = np.random.randint(0, 655350)
local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1)
RNG_STATE_TRACKER.reset()
RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
paddle.seed(global_seed)
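# A minimal usage sketch of the seeding scheme above, assuming a CUDA build of
# Paddle with a visible GPU (the tracker snapshots CUDA RNG states) and a
# single process standing in for one mp rank. The global seed is identical on
# every rank, while the tracked local seed differs per rank, so parameters
# created inside rng_state() are initialized independently across ranks while
# everything outside it stays identical. The helper name is hypothetical.
def _rng_tracker_usage_sketch(rank=0, seed=2022):
    tracker = RNGStatesTracker()
    tracker.add(MODEL_PARALLEL_RNG, seed * 1024 + rank * 100)  # rank-local seed

    paddle.seed(seed)                       # global seed, same on all ranks
    with tracker.rng_state(MODEL_PARALLEL_RNG):
        w_shard = paddle.randn([4, 8])      # drawn from the rank-local stream
    w_shared = paddle.randn([4, 8])         # drawn from the global stream
    return w_shard, w_shared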
def determinate_seed(rng_name):
assert rng_name is not None and rng_name != ""
helper = LayerHelper('seed', **locals())
out = helper.create_variable_for_type_inference(dtype=paddle.int32)
# set force_cpu to reduce sync copy from CPU->GPU->CPU, and reduce pipeline hang
helper.append_op(type='seed',
outputs={'Out': out},
attrs={
'deterministic': True,
'rng_name': rng_name,
'force_cpu': True
})
return out
def dropout(x,
p=0.5,
axis=None,
rng_name=None,
training=True,
mode="upscale_in_train",
name=None):
"""
Dropout is a regularization technique for reducing overfitting by preventing
neuron co-adaptation during training. The dropout operator randomly sets the
outputs of some units to zero, while upscaling the others according to the given
dropout probability.
Args:
x (Tensor): The input tensor. The data type is float32 or float64.
p (float|int): Probability of setting units to zero. Default 0.5.
axis (int|list|tuple): The axis along which the dropout is performed. Default None.
rng_name (str): The name of the random seed generator, which is used to obtain deterministic results. Default None.
training (bool): A flag indicating whether it is in the training phase or not. Default True.
mode(str): ['upscale_in_train'(default) | 'downscale_in_infer'].
1. upscale_in_train(default), upscale the output at training time
- train: out = input * mask / ( 1.0 - dropout_prob )
- inference: out = input
2. downscale_in_infer, downscale the output at inference
- train: out = input * mask
- inference: out = input * (1.0 - dropout_prob)
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor representing the dropout result, with the same shape and data type as `x` .
Examples:
We use ``p=0.5`` in the following description for simplicity.
1. When ``axis=None`` , this is the commonly used dropout, which randomly drops out each element of x.
.. code-block:: text
Let's see a simple case when x is a 2d tensor with shape 2*3:
[[1 2 3]
[4 5 6]]
we generate a mask with the same shape as x, which is 2*3. The values of the mask are
randomly sampled from a Bernoulli distribution. For example, we may get such a mask:
[[0 1 0]
[1 0 1]]
So the output is obtained by elementwise multiplication of x and the mask:
[[0 2 0]
[4 0 6]]
Using the default setting, i.e. ``mode='upscale_in_train'`` ,
if in the training phase, the final upscaled output is:
[[0 4 0 ]
[8 0 12]]
if in the test phase, the output is the same as the input:
[[1 2 3]
[4 5 6]]
We can also set ``mode='downscale_in_infer'`` ; then
if in the training phase, the final output is:
[[0 2 0]
[4 0 6]]
if in the test phase, the scaled output is:
[[0.5 1. 1.5]
[2. 2.5 3. ]]
"""
if rng_name is None:
return paddle.nn.functional.dropout(x, p, axis, training, mode, name)
if not isinstance(p, (float, int, Variable)):
raise TypeError("p argument should be a number(int|float) or Variable")
# fast return for p == 0
if isinstance(p, (int, float)) and p == 0: return x
assert 0 <= p <= 1, ValueError("p argument should be between 0 and 1")
assert mode in ('downscale_in_infer', 'upscale_in_train'), \
ValueError(
"mode argument should be 'downscale_in_infer' or 'upscale_in_train'")
assert axis is None, \
TypeError("axis is not supported when using the random seed generator")
mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode  # map to the legacy name used by the dropout op
# In dygraph mode the RNG state tracker is used, so no deterministic seed op is needed
if _non_static_mode():
out, mask = _legacy_C_ops.dropout(x, 'dropout_prob', p, 'is_test',
not training, 'fix_seed', False,
'seed', 0, 'dropout_implementation',
mode)
return out
seed = determinate_seed(rng_name)
if isinstance(p, Variable) and p.shape != [1]:
raise TypeError(
"Required p.shape == [1] if type(p) is Variable, but received p.shape = {}"
.format(p.shape))
helper = LayerHelper('dropout', **locals())
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'dropout')
out = helper.create_variable_for_type_inference(dtype=x.dtype)
mask = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
helper.append_op(type='dropout',
inputs={
'X': [x],
'Seed': seed
},
outputs={
'Out': [out],
'Mask': [mask]
},
attrs={
'dropout_prob': p,
'is_test': not training,
'dropout_implementation': mode,
})
return out
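# A tiny numpy check (not Paddle code) of the two scaling modes described in
# the docstring above, using a fixed mask instead of a random one. The helper
# name is hypothetical, for illustration only.
def _dropout_scaling_sketch():
    import numpy as np

    x = np.array([[1., 2., 3.], [4., 5., 6.]])
    mask = np.array([[0., 1., 0.], [1., 0., 1.]])
    p = 0.5

    # upscale_in_train: scale at training time, identity at inference time.
    train_up = x * mask / (1.0 - p)
    assert np.allclose(train_up, [[0., 4., 0.], [8., 0., 12.]])

    # downscale_in_infer: plain masking at training time, scale at inference.
    train_down = x * mask
    infer_down = x * (1.0 - p)
    assert np.allclose(train_down, [[0., 2., 0.], [4., 0., 6.]])
    assert np.allclose(infer_down, [[0.5, 1., 1.5], [2., 2.5, 3.]])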
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,298 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid import core
from paddle.fluid.dygraph.layers import Layer
from .random import get_rng_state_tracker
from paddle.nn import functional as F
from paddle import framework
from ...base import topology as tp
from paddle.autograd import PyLayer
from ...layers.mpu.mp_layers import VocabParallelEmbedding # noqa: F401
from ...layers.mpu.mp_layers import ColumnParallelLinear # noqa: F401
from ...layers.mpu.mp_layers import RowParallelLinear # noqa: F401
from ...layers.mpu.mp_layers import ParallelCrossEntropy # noqa: F401
__all__ = []
# Follow this paper to achieve the file:
# Shoeybi M, Patwary M, Puri R, et al. Megatron-lm: Training multi-billion parameter
# language models using model parallelism[J]. arXiv preprint arXiv:1909.08053, 2019. (https://arxiv.org/abs/1909.08053)
def is_fused_matmul_bias_supported():
if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
return hasattr(core.ops, 'fused_gemm_epilogue')
else:
return False
class VocabParallelEmbedding(Layer):
def __init__(self,
num_embeddings,
embedding_dim,
weight_attr=None,
mp_group=None,
name=None):
super(VocabParallelEmbedding, self).__init__()
self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
) if mp_group is None else mp_group
self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
) if mp_group is None else mp_group.nranks
self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
) if mp_group is None else mp_group.rank
self.origin_num_embeddings = num_embeddings
self.is_mp = (self.world_size > 1)
assert num_embeddings % self.world_size == 0, (
"The length of the vocabulary must be divisible by the parallelism degree of MP"
)
per_part_size = num_embeddings // self.world_size
self.vocab_start_index = self.rank * per_part_size
self._dtype = self._helper.get_default_dtype()
self._size = [per_part_size, embedding_dim]
self._weight_attr = weight_attr
self._name = name
if self.is_mp and paddle.in_dynamic_mode():
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(attr=self._weight_attr,
shape=self._size,
dtype=self._dtype,
is_bias=False)
else:
self.weight = self.create_parameter(attr=self._weight_attr,
shape=self._size,
dtype=self._dtype,
is_bias=False)
self.weight.is_distributed = True if self.is_mp else False
def forward(self, x):
if self.is_mp:
output_parallel = paddle.distributed.collective._c_lookup_table(
self.weight,
x,
start_index=self.vocab_start_index,
name=self._name)
output = paddle.distributed.collective._mp_allreduce(
output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
else:
output = F.embedding(x,
weight=self.weight,
padding_idx=None,
sparse=False,
name=self._name)
return output
class ColumnParallelLinear(Layer):
def __init__(self,
in_features,
out_features,
weight_attr=None,
has_bias=None,
gather_output=True,
fuse_matmul_bias=False,
mp_group=None,
name=None):
super(ColumnParallelLinear, self).__init__()
self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
) if mp_group is None else mp_group
self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
) if mp_group is None else mp_group.nranks
self._name = name
self.is_mp = (self.world_size > 1)
self.gather_output = gather_output
assert out_features % self.world_size == 0, (
"The number of columns of the linear weight ({}) must be"
" divisible by model parallel size ({})".format(
out_features, self.world_size))
self.output_size_per_partition = out_features // self.world_size
self._weight_attr = weight_attr
self._dtype = self._helper.get_default_dtype()
if self.is_mp and paddle.in_dynamic_mode():
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
else:
self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
self.weight.is_distributed = True if self.is_mp else False
if has_bias:
# initialize bias to zero like Megatron
self.bias = self.create_parameter(
shape=[self.output_size_per_partition],
attr=paddle.nn.initializer.Constant(value=0.0),
dtype=self._dtype,
is_bias=True)
self.bias.is_distributed = True if self.is_mp else False
else:
self.bias = None
self.linear = F.linear
if fuse_matmul_bias:
if not is_fused_matmul_bias_supported():
raise NotImplementedError(
"You set fuse_matmul_bias=True in ColumnParallelLinear, "
"however, the Paddle you are using does not support this operation. "
"Please set fuse_matmul_bias=False or use Paddle compiled "
"with CUDA 11.6 or higher.")
from paddle.incubate.nn.functional import fused_linear
self.linear = fused_linear
def forward(self, x):
# use inner api to process identity
if self.is_mp:
input_parallel = paddle.distributed.collective._c_identity(
x, group=self.model_parallel_group)
else:
input_parallel = x
output_parallel = self.linear(input_parallel,
self.weight,
self.bias,
name=self._name)
if self.gather_output and self.is_mp:
output = paddle.distributed.collective._c_concat(
output_parallel, group=self.model_parallel_group)
else:
output = output_parallel
return output
class RowParallelLinear(Layer):
def __init__(self,
in_features,
out_features,
weight_attr=None,
has_bias=True,
input_is_parallel=False,
fuse_matmul_bias=False,
mp_group=None,
name=None):
super(RowParallelLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.input_is_parallel = input_is_parallel
self._weight_attr = weight_attr
self._dtype = self._helper.get_default_dtype()
self._name = name
self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
) if mp_group is None else mp_group
self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
) if mp_group is None else mp_group.nranks
self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
) if mp_group is None else mp_group.rank
self.is_mp = (self.world_size > 1)
assert in_features % self.world_size == 0, (
"The number of rows of the linear weight ({}) must be"
" divisible by model parallel size ({})".format(
in_features, self.world_size))
self.input_size_per_partition = in_features // self.world_size
if self.is_mp and paddle.in_dynamic_mode():
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
else:
self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
self.weight.is_distributed = True if self.is_mp else False
if has_bias:
self.bias = self.create_parameter(
shape=[self.out_features],
attr=paddle.nn.initializer.Constant(value=0.0),
dtype=self._dtype,
is_bias=True)
else:
self.bias = None
self.linear = F.linear
if fuse_matmul_bias:
if not is_fused_matmul_bias_supported():
raise NotImplementedError(
"You set fuse_matmul_bias=True in RowParallelLinear, "
"however, the Paddle you are using does not support this operation. "
"Please set fuse_matmul_bias=False or use Paddle compiled "
"with CUDA 11.6 or higher.")
from paddle.incubate.nn.functional import fused_linear
self.linear = fused_linear
def forward(self, x):
if self.input_is_parallel or (not self.is_mp):
input_parallel = x
else:
# split last dim
input_parallel = paddle.distributed.collective._c_split(
x, group=self.model_parallel_group)
if self.is_mp:
output_parallel = self.linear(input_parallel,
self.weight,
name=self._name)
output_ = paddle.distributed.collective._mp_allreduce(
output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
output = output_ + self.bias if self.bias is not None else output_
else:
output = self.linear(input_parallel,
self.weight,
self.bias,
name=self._name)
return output
class ParallelCrossEntropy(Layer):
def __init__(self, mp_group=None, name=None):
super(ParallelCrossEntropy, self).__init__()
self.name = name
self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
) if mp_group is None else mp_group
self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
) if mp_group is None else mp_group.nranks
self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
) if mp_group is None else mp_group.rank
def forward(self, input, label):
loss = paddle.distributed.collective._c_softmax_with_cross_entropy(
input, label, group=self.model_parallel_group)
return loss
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,232 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import contextlib
import numpy as np
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid import core
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import _non_static_mode, default_main_program, Variable
from paddle.fluid.layer_helper import LayerHelper
from ...layers.mpu.random import RNGStatesTracker # noqa: F401
from ...layers.mpu.random import get_rng_state_tracker # noqa: F401
from ...layers.mpu.random import model_parallel_random_seed # noqa: F401
from ...layers.mpu.random import determinate_seed # noqa: F401
from ...layers.mpu.random import dropout # noqa: F401
__all__ = []
MODEL_PARALLEL_RNG = 'model_parallel_rng'
# This file is inspired by Megatron to control random states for MP:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/random.py
class RNGStatesTracker:
"""
Tracker of the RNG states.
"""
def __init__(self):
# Map from name to the rng state.
self.states_ = {}
self.seeds_ = set()
def reset(self):
self.states_ = {}
self.seeds_ = set()
def add(self, name, seed):
if seed in self.seeds_:
raise ValueError('seed {} already exists'.format(seed))
self.seeds_.add(seed)
if name in self.states_:
raise ValueError('state {} already exists'.format(name))
orig_rng_state = paddle.get_cuda_rng_state()
paddle.seed(seed)
self.states_[name] = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(orig_rng_state)
def get_states_tracker(self):
states = {}
for name in self.states_:
states[name] = self.states_[name]
return states
def set_states_tracker(self, states):
self.states_ = states
@contextlib.contextmanager
def rng_state(self, name=MODEL_PARALLEL_RNG):
if name not in self.states_:
raise ValueError('state {} does not exist'.format(name))
orig_cuda_rng_state = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(self.states_[name])
try:
yield
finally:
self.states_[name] = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(orig_cuda_rng_state)
RNG_STATE_TRACKER = RNGStatesTracker()
def get_rng_state_tracker():
return RNG_STATE_TRACKER
def model_parallel_random_seed(seed=None):
import paddle.distributed.fleet as fleet
hcg = fleet.get_hybrid_communicate_group()
rank = hcg.get_model_parallel_rank()
if seed:
global_seed = seed
local_seed = seed * 1024 + rank * 100
else:
global_seed = np.random.randint(0, 655350)
local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1)
RNG_STATE_TRACKER.reset()
RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
paddle.seed(global_seed)
def determinate_seed(rng_name):
assert rng_name is not None and rng_name != ""
helper = LayerHelper('seed', **locals())
out = helper.create_variable_for_type_inference(dtype=paddle.int32)
# set force_cpu to reduce sync copy from CPU->GPU->CPU, and reduce pipeline hang
helper.append_op(type='seed',
outputs={'Out': out},
attrs={
'deterministic': True,
'rng_name': rng_name,
'force_cpu': True
})
return out
def dropout(x,
p=0.5,
axis=None,
rng_name=None,
training=True,
mode="upscale_in_train",
name=None):
"""
Dropout is a regularization technique for reducing overfitting by preventing
neuron co-adaptation during training. The dropout operator randomly sets the
outputs of some units to zero, while upscaling the others according to the given
dropout probability.
Args:
x (Tensor): The input tensor. The data type is float32 or float64.
p (float|int): Probability of setting units to zero. Default 0.5.
axis (int|list|tuple): The axis along which the dropout is performed. Default None.
rng_name (str): The name of the random seed generator, which is used to obtain deterministic results. Default None.
training (bool): A flag indicating whether it is in the training phase or not. Default True.
mode(str): ['upscale_in_train'(default) | 'downscale_in_infer'].
1. upscale_in_train(default), upscale the output at training time
- train: out = input * mask / ( 1.0 - dropout_prob )
- inference: out = input
2. downscale_in_infer, downscale the output at inference
- train: out = input * mask
- inference: out = input * (1.0 - dropout_prob)
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor representing the dropout result, with the same shape and data type as `x` .
Examples:
We use ``p=0.5`` in the following description for simplicity.
1. When ``axis=None`` , this is the commonly used dropout, which randomly drops out each element of x.
.. code-block:: text
Let's see a simple case when x is a 2d tensor with shape 2*3:
[[1 2 3]
[4 5 6]]
we generate a mask with the same shape as x, which is 2*3. The values of the mask are
randomly sampled from a Bernoulli distribution. For example, we may get such a mask:
[[0 1 0]
[1 0 1]]
So the output is obtained by elementwise multiplication of x and the mask:
[[0 2 0]
[4 0 6]]
Using the default setting, i.e. ``mode='upscale_in_train'`` ,
if in the training phase, the final upscaled output is:
[[0 4 0 ]
[8 0 12]]
if in the test phase, the output is the same as the input:
[[1 2 3]
[4 5 6]]
We can also set ``mode='downscale_in_infer'`` ; then
if in the training phase, the final output is:
[[0 2 0]
[4 0 6]]
if in the test phase, the scaled output is:
[[0.5 1. 1.5]
[2. 2.5 3. ]]
"""
if rng_name is None:
return paddle.nn.functional.dropout(x, p, axis, training, mode, name)
if not isinstance(p, (float, int, Variable)):
raise TypeError("p argument should be a number(int|float) or Variable")
# fast return for p == 0
if isinstance(p, (int, float)) and p == 0: return x
assert 0 <= p <= 1, ValueError("p argument should be between 0 and 1")
assert mode in ('downscale_in_infer', 'upscale_in_train'), \
ValueError(
"mode argument should be 'downscale_in_infer' or 'upscale_in_train'")
assert axis is None, \
TypeError("axis is not supported when using the random seed generator")
mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode  # map to the legacy name used by the dropout op
# dygraph using tracker, doesn't need determinate seed
if _non_static_mode():
out, mask = _legacy_C_ops.dropout(x, 'dropout_prob', p, 'is_test',
not training, 'fix_seed', False,
'seed', 0, 'dropout_implementation',
mode)
return out
seed = determinate_seed(rng_name)
if isinstance(p, Variable) and p.shape != [1]:
raise TypeError(
"Required p.shape == [1] if type(p) is Variable, but received p.shape = {}"
.format(p.shape))
helper = LayerHelper('dropout', **locals())
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'dropout')
out = helper.create_variable_for_type_inference(dtype=x.dtype)
mask = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
helper.append_op(type='dropout',
inputs={
'X': [x],
'Seed': seed
},
outputs={
'Out': [out],
'Mask': [mask]
},
attrs={
'dropout_prob': p,
'is_test': not training,
'dropout_implementation': mode,
})
return out
@@ -307,6 +307,8 @@ packages=['paddle',
'paddle.distributed.fleet.metrics',
'paddle.distributed.fleet.proto',
'paddle.distributed.fleet.utils',
'paddle.distributed.fleet.layers',
'paddle.distributed.fleet.layers.mpu',
'paddle.distributed.fleet.meta_parallel',
'paddle.distributed.fleet.meta_parallel.pp_utils',
'paddle.distributed.fleet.meta_parallel.sharding',
......