Commit a45e0d18 authored by: M Megvii Engine Team

fix(imperative): improve the imperative interface of mha

GitOrigin-RevId: 0e6da8ece5ed474ecd131d8b8e05db03ec5eae5c
Parent a92aea1f
...@@ -1333,16 +1333,34 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
(pdef('MultiHeadAttn')
.add_fields('uint32', Doc('num_heads', 'Number of parallel attention heads.'), '1')
.add_fields('uint32', Doc('embeding_size', 'Total dimension of the model.'), '0')
.add_fields('uint32', Doc('k_size', 'Total number of features for keys.'), '0')
.add_fields('uint32', Doc('v_size', 'Total number of features for values.'), '0')
.add_fields('uint32', Doc('qproj_size', 'query weight projection.'), '0')
.add_fields('uint32', Doc('kproj_size', 'key weight projection.'), '0')
.add_fields('uint32', Doc('vproj_size', 'value weight projection.'), '0')
.add_fields('uint32', Doc('oproj_size', 'output weight projection.'), '0')
.add_fields('bool', Doc('qbias', 'Whether to add query bias.'), 'false')
.add_fields('bool', Doc('kbias', 'Whether to add key bias.'), 'false')
.add_fields('bool', Doc('vbias', 'Whether to add value bias.'), 'false')
.add_fields('bool', Doc('obias', 'Whether to add out bias.'), 'false')
.add_fields('float32', Doc('sm_scaler', 'Softmax smoothing (1.0 >= smScaler >= 0.0) or sharpening (smScaler > 1.0) coefficient.'), '1.f')
.add_fields('uint32', Doc('input_order', 'The sequence data layout, allows the user to select 3! = 6 different data layouts or permutations of BEAM, BATCH and TIME dimensions.'), '0')
.add_enum('ATTN_MASK_TYPE',
Doc('NO_MASK = 0', 'Indicates that there is no mask.'),
Doc('DEFAULT_MASK = 1', 'Use the default mask which the upper right triangle of the mask is -inf, and the diagonal and lower left triangle are all 0.'),
Doc('CUDNN_STYLE_MASK = 2', 'Indicates the use of a cudnn style mask.'),
Doc('USER_DEFINED_MASK = 3', 'Use the user-defined mask.'), name_field="attn_mask_type")
.add_enum(Doc('TENSOR_COMBINATION_TYPE', 'Used to determine whether mask tensor and bias_kv tensor exist in the input. Note that bias_kv here is not kbias and vbias in the linear layer, and bias_kv here will be added to the K and V at sequence dimensions, where K and V are the matrices of key and value after projection, and K and V will be used to calculate the attention matrix.'),
Doc('NONE = 0', 'Indicates that there are no mask tensor and bias_kv tensor in the input.'),
Doc('ONLY_MASK = 1',
'Indicates that there is only mask tensor in input.'),
Doc('ONLY_BIASKV = 2', 'Indicates that there is only bias_kv tensor in input.'),
Doc('ALL = 3', 'Indicates that there are mask tensor and bias_kv tensor in the input.'), name_field="tensor_combination_type")
.add_fields('bool', Doc('add_zero_attn', 'Whether to add a new batch of zeros to the key and value sequences.'), 'false')
.add_fields('bool', Doc('need_weights', 'Whether to return the attention matrix, which is the output result of softmax.'), 'false')
.add_fields('bool', Doc('reslink', 'Whether to add input query to final output.'), 'false')
.add_fields('bool', Doc('training', 'Whether it is in training mode.'), 'true')
.add_fields('bool', Doc('bias', 'Whether to add linear bias.'), 'false')
.add_fields('bool', Doc('attn_mask', 'Whether to add attn_mask.'), 'false')
.add_fields('bool', Doc('enable_qproj', 'enable query weight projection.'), 'true')
.add_fields('bool', Doc('enable_kproj', 'enable key weight projection.'), 'true')
.add_fields('bool', Doc('enable_vproj', 'enable value weight projection.'), 'true')
.add_fields('bool', Doc('enable_oproj', 'enable output weight projection.'), 'true')
.add_fields('uint64', Doc('seed', 'Random number seed for drop'), '0')
.add_fields('float32', Doc('attn_prob', 'Dropout probability on attention, is applied directly to the softmax output'), '0.f')
.add_fields('float32', Doc('out_prob', 'Dropout probability on output, alters the multi-head attention output'), '0.f')
......
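Editor's note: the ``TENSOR_COMBINATION_TYPE`` doc above distinguishes ``bias_kv`` from the linear-layer biases: ``bias_kv`` is added to the projected K and V along the sequence dimension before the attention matrix is computed. A minimal numpy sketch of that behaviour follows; the shapes and variable names are illustrative assumptions, not the operator's implementation.

# Illustrative only: what "adding bias_kv at the sequence dimension" means.
# batch/seq/proj_dim and the variable names are assumptions for this sketch.
import numpy as np
batch, seq, proj_dim = 2, 5, 8
K = np.random.rand(batch, seq, proj_dim).astype("float32")  # projected keys
V = np.random.rand(batch, seq, proj_dim).astype("float32")  # projected values
bias_k = np.random.rand(1, 1, proj_dim).astype("float32")
bias_v = np.random.rand(1, 1, proj_dim).astype("float32")
# bias_kv acts as one extra "token": it is concatenated along the sequence axis,
# so the attention matrix gains one extra key/value column.
K = np.concatenate([K, np.broadcast_to(bias_k, (batch, 1, proj_dim))], axis=1)
V = np.concatenate([V, np.broadcast_to(bias_v, (batch, 1, proj_dim))], axis=1)
print(K.shape, V.shape)  # (2, 6, 8) (2, 6, 8)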
...@@ -109,14 +109,14 @@ void MultiHeadAttnStatus::set(
kSize = k.shape[2];
vSize = v.shape[2];
numHeads = p.num_heads;
qProjSize = p.enable_qproj ? qSize / numHeads : 0;
kProjSize = p.enable_kproj ? kSize / numHeads : 0;
vProjSize = p.enable_vproj ? vSize / numHeads : 0;
oProjSize = p.enable_oproj ? qSize : 0;
attnMask = p.attn_mask;
qProjSize = p.qproj_size ? qSize / numHeads : 0;
kProjSize = p.kproj_size ? kSize / numHeads : 0;
vProjSize = p.vproj_size ? vSize / numHeads : 0;
oProjSize = p.oproj_size ? qSize : 0;
attnMask = p.attn_mask_type >= param::MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK;
cudnnDataType_t cudnn_dtype = to_cudnn_dtype(q.dtype);
auto flag = CUDNN_ATTN_QUERYMAP_ONE_TO_ONE;
if (p.bias)
if (p.qbias or p.kbias or p.vbias or p.obias)
flag = flag | CUDNN_ATTN_ENABLE_PROJ_BIASES;
#if CUDNN_VERSION < 8600
// TODO: CUDNN_VERSION < 8600 and out dropout > 0.0, we need to go to the proxy cuda
...@@ -134,7 +134,9 @@ void MultiHeadAttnStatus::set(
vProjSize, oProjSize, seqLenQ, seqLenK, batchSize, 1));
#endif
auxArray.set(batchSize, seqLenQ, seqLenK, p.attn_mask);
auxArray.set(
batchSize, seqLenQ, seqLenK,
p.attn_mask_type >= param::MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK);
if (p.training)
cudnnGetMultiHeadAttnBuffers(
...@@ -157,16 +159,18 @@ bool MultiHeadAttnStatus::is_initialized(
return false;
if (q.shape[0] != batchSize or q.shape[1] != seqLenQ or k.shape[1] != seqLenK or
q.shape[2] != qSize or k.shape[2] != kSize or v.shape[2] != vSize or
attnMask != p.attn_mask or numHeads != p.num_heads) {
attnMask != (p.attn_mask_type >=
param::MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK) or
numHeads != p.num_heads) {
return false;
}
if ((p.enable_qproj && (qProjSize == 0 or qProjSize != qSize / p.num_heads)) or
(p.enable_kproj && (kProjSize == 0 or kProjSize != kSize / p.num_heads)) or
(p.enable_vproj && (vProjSize == 0 or vProjSize != vSize / p.num_heads)) or
(p.enable_oproj && (oProjSize == 0 or oProjSize != q.shape[2])))
if ((p.qproj_size && (qProjSize == 0 or qProjSize != qSize / p.num_heads)) or
(p.kproj_size && (kProjSize == 0 or kProjSize != kSize / p.num_heads)) or
(p.vproj_size && (vProjSize == 0 or vProjSize != vSize / p.num_heads)) or
(p.oproj_size && (oProjSize == 0 or oProjSize != q.shape[2])))
return false;
if ((!p.enable_qproj && qProjSize != 0) or (!p.enable_kproj && kProjSize != 0) or
(!p.enable_vproj && vProjSize != 0) or (!p.enable_oproj && oProjSize != 0))
if ((!p.qproj_size && qProjSize != 0) or (!p.kproj_size && kProjSize != 0) or
(!p.vproj_size && vProjSize != 0) or (!p.oproj_size && oProjSize != 0))
return false;
if (!auxArray.is_initialized(batchSize, seqLenQ, seqLenK, attnMask))
return false;
......
...@@ -163,7 +163,7 @@ void MultiHeadAttnBackwardImpl::exec(
#else
#if CUDNN_VERSION < 8600
megdnn_assert(
!param().bias,
!(param().qbias or param().kbias or param().vbias or param().obias),
"If the cudnn version is lower than 8.6.0, param().bias must be false, "
"but got true, because there is an error in the "
"dbias result during the backward calculation.");
......
...@@ -9,7 +9,7 @@ from megengine import Parameter
from ..device import get_cudnn_version, is_cuda_available
from ..functional.nn import multi_head_attention
from ..tensor import Tensor
from .init import ones_, zeros_
from .init import ones_, xavier_uniform_, zeros_
from .module import Module
...@@ -24,19 +24,37 @@ class MultiHeadAttention(Module):
where :math:`h_i=W_{V,i}V \text{Softmax}\Big( \text{smScaler} \cdot K^TW^T_{K,i}W_{Q,i}q \Big),\text{for }i\text{ = 0 ... nHeads-1}`.
Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported, and if the cudnn version >=8.6.0, the calculation results are completely correct; If the cudnn version >=8.0.4 but <8.6.0, if there is a bias, only the dbias result calculated from the backward is incorrect. If there is no bias, the forward and backward calculations are correct; If the cudnn version is less than 8.0.4, this operator is not supported.
When the following conditions are met, you can go to the cudnn backend:
- ``cudnn version`` greater than or equal to 8.0.4 and ``bias`` is ``False`` and ``training`` is ``False``
- ``cudnn version`` greater than or equal to 8.6.0
- ``add_bias_kv`` is ``False``
- ``add_zero_attn`` is ``False``
- ``need_weights`` is ``False``
- ``average_attn_weights`` is ``False``
- ``maybe_cudnn_style_mask`` is ``True`` if supported, otherwise ``False``
- ``attn_mask`` and ``key_padding_mask`` are cudnn-style masks, i.e. the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)` (see the sketch after this list).
- The shape of attn_mask is :math:`(2, L)`, where the :math:`(0, :)` elements specify the start index and the :math:`(1, :)` elements specify the end index; the start index is inclusive and the end index is exclusive. The start index (i.e. elements in `attn_mask[0, x]`) must be less than the corresponding end index (i.e. elements in `attn_mask[1, x]`). The end index must be less than or equal to :math:`S`, where :math:`S` is the source sequence length and :math:`L` is the target sequence length.
- The shape of key_padding_mask is :math:`(2, N)`, where the :math:`(0, :)` elements specify the target sequence padding in the cudnn-style mask and each element must be less than or equal to :math:`L`, and the :math:`(1, :)` elements specify the source sequence padding in the cudnn-style mask and each element must be less than or equal to :math:`S`, where :math:`S` is the source sequence length and :math:`L` is the target sequence length.
- ``qbias``, ``kbias``, ``vbias`` and ``obias`` are equal
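To make the cudnn-style shapes above concrete, here is a minimal numpy sketch of the two masks; the lengths and values are placeholders and not part of the module's API.

# Sketch of the cudnn-style masks described above. N = batch size,
# L = target sequence length, S = source sequence length (placeholder values).
import numpy as np
N, L, S = 2, 4, 6
# attn_mask, shape (2, L): row 0 holds start indices (inclusive),
# row 1 holds end indices, each end index <= S.
attn_mask = np.stack([np.zeros(L, dtype="int32"), np.full(L, S, dtype="int32")])
# key_padding_mask, shape (2, N): row 0 holds the target-sequence padding
# values (each <= L), row 1 holds the source-sequence padding values (each <= S).
key_padding_mask = np.stack([np.full(N, L, dtype="int32"), np.full(N, S, dtype="int32")])
print(attn_mask.shape, key_padding_mask.shape)  # (2, 4) (2, 2)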
Args:
embed_dim: Total dimension of the model.
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
attn_dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
out_dropout: Dropout probability on ``output``. Default: ``0.0`` (no dropout).
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
add_bias_kv: If specified, adds bias to the key and value sequences at sequence dim. Default: ``False``.
Different from kbias and vbias, bias_kv here is not kbias and vbias in the linear layer, and bias_kv here will be added to the K and V at sequence dimensions, where K and V are the matrices of key and value after projection, and K and V will be used to calculate the attention matrix.
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences.
Default: ``False``.
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
enable_qproj: enable query weight projection. Default: ``True``.
enable_kproj: enable key weight projection. Default: ``True``.
enable_vproj: enable value weight projection. Default: ``True``.
enable_oproj: enable output weight projection. Default: ``True``.
Examples::
>>> import numpy as np
...@@ -44,7 +62,7 @@ class MultiHeadAttention(Module):
>>> x = Tensor(np.arange(batch_size * seq_len * embed_dim).astype(np.float32).reshape(batch_size, seq_len, embed_dim))
>>> multihead_attn = M.MultiHeadAttention(embed_dim, num_heads)
>>> if is_cuda_available() and get_cudnn_version() >= 8004:
... out = multihead_attn(x, x, x)
... out = multihead_attn(x, x, x)[0]
... out.numpy().shape
... else:
... print(np.zeros((2,4,4)).shape)
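For reference, the default mask mentioned in this commit (upper-right triangle ``-inf``, diagonal and lower-left triangle ``0``) looks like the sketch below; it uses plain numpy and only illustrates the mask layout.

# Illustration of the "default mask" described in this commit (numpy only).
import numpy as np
L = 4
default_mask = np.triu(np.full((L, L), -np.inf, dtype="float32"), k=1)
print(default_mask)
# [[  0. -inf -inf -inf]
#  [  0.   0. -inf -inf]
#  [  0.   0.   0. -inf]
#  [  0.   0.   0.   0.]]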
...@@ -57,84 +75,143 @@ class MultiHeadAttention(Module):
num_heads,
attn_dropout=0.0,
out_dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
bias=True,
enable_qproj=True,
enable_kproj=True,
enable_vproj=True,
enable_oproj=True,
**kwargs
):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.add_bias_kv = add_bias_kv
self.add_zero_attn = add_zero_attn
self.num_heads = num_heads
self.attn_dropout = attn_dropout
self.out_dropout = out_dropout
self.head_dim = embed_dim // num_heads
self.unsupport_reason = " The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation."
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
assert (
self._qkv_same_embed_dim
), "it does not support the case where q, k, and v are different."
assert add_bias_kv == False, (
"add_bias_kv should be set to False, and configuration of this parameter is not supported now."
+ self.unsupport_reason
)
assert add_zero_attn == False, (
"add_zero_attn should be set to False, and configuration of this parameter is not supported now."
+ self.unsupport_reason
)
self.bias = bias
self.enable_qproj = enable_qproj
self.enable_kproj = enable_kproj
self.enable_vproj = enable_vproj
self.enable_oproj = enable_oproj
self.nproj = enable_qproj + enable_kproj + enable_vproj + enable_oproj
if self.bias:
io_weight = np.ones((embed_dim, self.nproj * embed_dim))
io_bias = np.zeros((1, self.nproj * embed_dim))
self.io_weight_bias = Parameter(
np.concatenate((io_weight, io_bias), axis=0), dtype="float32"
)
else:
self.io_weight_bias = Parameter(
np.ones((self.nproj * embed_dim, embed_dim), dtype="float32")
)
self.weight_bias_len = (
self.embed_dim + self.kdim + self.vdim + self.embed_dim
) * self.embed_dim + (4 * self.embed_dim if self.bias else 0)
self.io_weight_bias = Parameter(
np.empty((1, self.weight_bias_len), dtype="float32")
)
self.bias_k = (
Parameter(np.empty((1, 1, embed_dim), dtype="float32"))
if self.add_bias_kv
else None
)
self.bias_v = (
Parameter(np.empty((1, 1, embed_dim), dtype="float32"))
if self.add_bias_kv
else None
)
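A quick worked example of the flattened parameter length computed above, assuming embed_dim = kdim = vdim = 4 and bias=True (the numbers are illustrative only):

# weight part: q/k/v/o projection matrices flattened together
embed_dim = kdim = vdim = 4
weight_len = (embed_dim + kdim + vdim + embed_dim) * embed_dim  # 16 * 4 = 64
bias_len = 4 * embed_dim                                        # q/k/v/o biases = 16
weight_bias_len = weight_len + bias_len
print(weight_bias_len)  # 80, i.e. io_weight_bias has shape (1, 80)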
self.reset_parameters()
def reset_parameters(self):
self.attn_dropout = 0.0
self.out_dropout = 0.0
if self.bias:
io_weight = np.ones((self.embed_dim, self.nproj * self.embed_dim))
io_bias = np.zeros((1, self.nproj * self.embed_dim))
self.io_weight_bias._reset(np.concatenate((io_weight, io_bias), axis=0))
else:
ones_(self.io_weight_bias)
xavier_uniform_(self.io_weight_bias)
if self.bias:
weight_len = (
self.embed_dim + self.kdim + self.vdim + self.embed_dim
) * self.embed_dim
self.io_weight_bias[0, weight_len:,] = 0
if self.add_bias_kv:
xavier_uniform_(self.bias_k)
else:
self.bias_k = None
if self.add_bias_kv:
xavier_uniform_(self.bias_v)
else:
self.bias_v = None
def forward(
self, query, key, value, attn_mask: bool = True,
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
need_weights: bool = False,
average_attn_weights: bool = False,
is_causal: bool = False,
maybe_cudnn_style_mask: bool = False,
):
r""" r"""
Args: Args:
query: Query embeddings of shape :math:`(N, L, E_q)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, query: Query embeddings of shape :math:`(N, L, E_q)`,
and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against where :math:`N` is the batch size, :math:`L` is the target sequence length, and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against key-value pairs to produce the output. See "Attention Is All You Need" for more details.
key-value pairs to produce the output. See "Attention Is All You Need" for more details. key: Key embeddings of shape :math:`(N, S, E_k)`,
key: Key embeddings of shape :math:`(N, S, E_k)`, where :math:`N` is the batch size, :math:`S` is the source sequence length, and where :math:`N` is the batch size, :math:`S` is the source sequence length, and :math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details.
:math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details. value: Value embeddings of shape :math:`(N, S, E_v)`,
value: Value embeddings of shape :math:`(N, S, E_v)`, where :math:`N` is the batch size, :math:`S` is the source sequence length, and where :math:`N` is the batch size, :math:`S` is the source sequence length, and :math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details.
:math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details. key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` to ignore for the purpose of
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, Note: Should be set to None, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Note: User-defined mask not supported now, only support no mask or default mask, where the upper right triangle is all -inf, and the diagonal and lower left triangle are all 0. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
need_weights: indicates whether to return the attention weight, which is the output result of softmax. Default: `True`
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
is_causal: If specified, applies a causal mask as attention mask. Default: ``False``
Warning: ``is_causal`` provides a hint that ``attn_mask`` is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.
maybe_cudnn_style_mask: if specified, applies a cudnn style mask as attention mask. Default: ``False``
Note: In the cudnn style, the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`.
Warning: like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is a cudnn style mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. In addition, if the ``_merge_masks`` function returns ``merge_type=cudnn_style_mask``, please ensure that other conditions are correct so that it can run the implementation of cudnn, otherwise an error will be reported.
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that the underlying implementation only accepts two types of mask type, namely "no_mask" and "default_mask", and we may try to loosen this option after submitting the commit that users can pass in custom attention mask tensors.
Outputs:
- **attn_output** - Attention outputs of shape :math:`(N, L, E)`,
where :math:`L` is the target sequence length, :math:`N` is
the batch size, and :math:`E` is the embedding dimension ``embed_dim``.
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N * \text{num\_heads}, L, S)`.
Note: Now only None will be returned. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
""" """
assert key_padding_mask is None, (
"key_padding_mask should be None, and configuration of this parameter is not supported now."
+ self.unsupport_reason
)
assert need_weights == False, (
"need_weights should be set to False, and configuration of this parameter is not supported now."
+ self.unsupport_reason
)
assert average_attn_weights == False, (
"average_attn_weights should be set to False, and configuration of this parameter is not supported now."
+ self.unsupport_reason
)
assert maybe_cudnn_style_mask == False, (
"maybe_cudnn_style_mask should be set to False, and configuration of this parameter is not supported now."
+ self.unsupport_reason
)
return multi_head_attention(
query,
...@@ -145,13 +222,24 @@ class MultiHeadAttention(Module):
self.attn_dropout,
self.out_dropout,
self.io_weight_bias,
self.bias,
qproj_size=self.embed_dim,
kproj_size=self.embed_dim,
vproj_size=self.embed_dim,
oproj_size=self.embed_dim,
qbias=self.bias,
kbias=self.bias,
vbias=self.bias,
obias=self.bias,
bias_k=self.bias_k,
bias_v=self.bias_v,
add_zero_attn=self.add_zero_attn,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
enable_qproj=self.enable_qproj,
enable_kproj=self.enable_kproj,
enable_vproj=self.enable_vproj,
enable_oproj=self.enable_oproj,
average_attn_weights=average_attn_weights,
is_causal=is_causal,
maybe_cudnn_style_mask=maybe_cudnn_style_mask,
)
def _module_info_string(self) -> str:
......
...@@ -296,11 +296,19 @@ struct OpMeth<MultiHeadAttn> {
handle_seed == opdef.seed,
"inconsistent multiheadattn seed: dropout op: %lu handle: %lu",
handle_seed, opdef.seed);
return {opdef.num_heads, opdef.sm_scaler, opdef.input_order,
opdef.reslink, opdef.training, opdef.bias,
opdef.attn_mask, opdef.enable_qproj, opdef.enable_kproj,
opdef.enable_vproj, opdef.enable_oproj, handle_seed,
opdef.attn_prob, opdef.out_prob};
return {opdef.num_heads, opdef.embeding_size,
opdef.k_size, opdef.v_size,
opdef.qproj_size, opdef.kproj_size,
opdef.vproj_size, opdef.oproj_size,
opdef.qbias, opdef.kbias,
opdef.vbias, opdef.obias,
opdef.sm_scaler, opdef.input_order,
opdef.attn_mask_type, opdef.tensor_combination_type,
opdef.add_zero_attn, opdef.need_weights,
opdef.reslink, opdef.training,
handle_seed, opdef.attn_prob,
opdef.out_prob};
}
};
......
...@@ -20,6 +20,8 @@
cb(::megdnn::param::CvtColor::Mode); \
cb(::megdnn::param::Elemwise::Mode); \
cb(::megdnn::param::ElemwiseMultiType::Mode); \
cb(::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE); \
cb(::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE); \
cb(::megdnn::param::Padding::PaddingMode); \
cb(::megdnn::param::RNNCell::NonlineMode); \
cb(::megdnn::param::ROIAlignV0::Mode); \
......
c5a5d1bd44473912f14cecee3df6409e ../../dnn/scripts/opr_param_defs.py
0a8cd3cd50cadfaae0478ee70621618e ../../dnn/scripts/opr_param_defs.py
4ed3e8cbef0fa5f4d6824d8d55dec722 ../../src/core/include/megbrain/ir/ops.td
9e9636d66694dd7d5a7853247a5406f9 ../../src/core/include/megbrain/ir/ops.td
dc2d4ec8f4f5e203ce0a76bc20f62529 generated/opdef.h.inl
283dffd0e9cd28db5155c44cf4eda148 generated/opdef.h.inl
906957f12994d43c69248a6acfefa396 generated/opdef.cpp.inl
5e8d57337c3aec6f4b3b30ef9ba141f8 generated/opdef.cpp.inl
8817af8997ba0cc00048e71093755238 generated/opdef.py.inl
7f470236e4b5b00bdeaec321bc7187b5 generated/opdef.py.inl
c43ae8b706e3f3658fe3cc0f60061981 generated/opdef.cpy.inl
003addd357423b880cd06410f5bf624b generated/opdef.cpy.inl
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
d468302f2d4b113913b76b5a181aae56 generated/enum_macro.h
...@@ -5200,28 +5200,54 @@ size_t MultiHeadAttn_hash_impl(const OpDef& def_) {
mgb::hash_pair_combine(
mgb::hash(op_.num_heads),
mgb::hash_pair_combine(
mgb::hash(op_.sm_scaler),
mgb::hash_pair_combine(
mgb::hash(op_.input_order),
mgb::hash_pair_combine(
mgb::hash(op_.reslink),
mgb::hash_pair_combine(
mgb::hash(op_.training),
mgb::hash_pair_combine(
mgb::hash(op_.bias),
mgb::hash_pair_combine(
mgb::hash(op_.attn_mask),
mgb::hash_pair_combine(
mgb::hash(op_.enable_qproj),
mgb::hash_pair_combine(
mgb::hash(op_.enable_kproj),
mgb::hash_pair_combine(
mgb::hash(op_.enable_vproj),
mgb::hash_pair_combine(
mgb::hash(op_.enable_oproj),
mgb::hash_pair_combine(
mgb::hash(op_.attn_prob),
mgb::hash(op_.out_prob)
mgb::hash(op_.embeding_size),
mgb::hash_pair_combine(
mgb::hash(op_.k_size),
mgb::hash_pair_combine(
mgb::hash(op_.v_size),
mgb::hash_pair_combine(
mgb::hash(op_.qproj_size),
mgb::hash_pair_combine(
mgb::hash(op_.kproj_size),
mgb::hash_pair_combine(
mgb::hash(op_.vproj_size),
mgb::hash_pair_combine(
mgb::hash(op_.oproj_size),
mgb::hash_pair_combine(
mgb::hash(op_.qbias),
mgb::hash_pair_combine(
mgb::hash(op_.kbias),
mgb::hash_pair_combine(
mgb::hash(op_.vbias),
mgb::hash_pair_combine(
mgb::hash(op_.obias),
mgb::hash_pair_combine(
mgb::hash(op_.sm_scaler),
mgb::hash_pair_combine(
mgb::hash(op_.input_order),
mgb::hash_pair_combine(
mgb::hash(op_.attn_mask_type),
mgb::hash_pair_combine(
mgb::hash(op_.tensor_combination_type),
mgb::hash_pair_combine(
mgb::hash(op_.add_zero_attn),
mgb::hash_pair_combine(
mgb::hash(op_.need_weights),
mgb::hash_pair_combine(
mgb::hash(op_.reslink),
mgb::hash_pair_combine(
mgb::hash(op_.training),
mgb::hash_pair_combine(
mgb::hash(op_.attn_prob),
mgb::hash(op_.out_prob))
)
)
)
)
)
)
)
)
)
)
)
...@@ -5242,22 +5268,63 @@ bool MultiHeadAttn_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
&&b_ = rhs_.cast_final_safe<MultiHeadAttn>();
static_cast<void>(a_);
static_cast<void>(b_);
return a_.handle == b_.handle && a_.num_heads == b_.num_heads && a_.sm_scaler == b_.sm_scaler && a_.input_order == b_.input_order && a_.reslink == b_.reslink && a_.training == b_.training && a_.bias == b_.bias && a_.attn_mask == b_.attn_mask && a_.enable_qproj == b_.enable_qproj && a_.enable_kproj == b_.enable_kproj && a_.enable_vproj == b_.enable_vproj && a_.enable_oproj == b_.enable_oproj && a_.attn_prob == b_.attn_prob && a_.out_prob == b_.out_prob;}
return a_.handle == b_.handle && a_.num_heads == b_.num_heads && a_.embeding_size == b_.embeding_size && a_.k_size == b_.k_size && a_.v_size == b_.v_size && a_.qproj_size == b_.qproj_size && a_.kproj_size == b_.kproj_size && a_.vproj_size == b_.vproj_size && a_.oproj_size == b_.oproj_size && a_.qbias == b_.qbias && a_.kbias == b_.kbias && a_.vbias == b_.vbias && a_.obias == b_.obias && a_.sm_scaler == b_.sm_scaler && a_.input_order == b_.input_order && a_.reslink == b_.reslink && a_.training == b_.training && a_.need_weights == b_.need_weights && a_.attn_mask_type == b_.attn_mask_type && a_.add_zero_attn == b_.add_zero_attn && a_.tensor_combination_type == b_.tensor_combination_type && a_.attn_prob == b_.attn_prob && a_.out_prob == b_.out_prob;}
std::vector<std::pair<const char*, std::string>> MultiHeadAttn_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<MultiHeadAttn>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
props_.emplace_back("num_heads", std::to_string(op_.num_heads));
props_.emplace_back("embeding_size", std::to_string(op_.embeding_size));
props_.emplace_back("k_size", std::to_string(op_.k_size));
props_.emplace_back("v_size", std::to_string(op_.v_size));
props_.emplace_back("qproj_size", std::to_string(op_.qproj_size));
props_.emplace_back("kproj_size", std::to_string(op_.kproj_size));
props_.emplace_back("vproj_size", std::to_string(op_.vproj_size));
props_.emplace_back("oproj_size", std::to_string(op_.oproj_size));
props_.emplace_back("qbias", std::to_string(op_.qbias));
props_.emplace_back("kbias", std::to_string(op_.kbias));
props_.emplace_back("vbias", std::to_string(op_.vbias));
props_.emplace_back("obias", std::to_string(op_.obias));
props_.emplace_back("sm_scaler", std::to_string(op_.sm_scaler)); props_.emplace_back("sm_scaler", std::to_string(op_.sm_scaler));
props_.emplace_back("input_order", std::to_string(op_.input_order)); props_.emplace_back("input_order", std::to_string(op_.input_order));
switch (op_.attn_mask_type){
case MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK:
props_.emplace_back("attn_mask_type", "NO_MASK");
break;
case MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK:
props_.emplace_back("attn_mask_type", "DEFAULT_MASK");
break;
case MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK:
props_.emplace_back("attn_mask_type", "CUDNN_STYLE_MASK");
break;
case MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK:
props_.emplace_back("attn_mask_type", "USER_DEFINED_MASK");
break;
default:
props_.emplace_back("attn_mask_type", "INVALID");
break;
}
switch (op_.tensor_combination_type){
case MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE:
props_.emplace_back("tensor_combination_type", "NONE");
break;
case MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK:
props_.emplace_back("tensor_combination_type", "ONLY_MASK");
break;
case MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV:
props_.emplace_back("tensor_combination_type", "ONLY_BIASKV");
break;
case MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL:
props_.emplace_back("tensor_combination_type", "ALL");
break;
default:
props_.emplace_back("tensor_combination_type", "INVALID");
break;
}
props_.emplace_back("need_weights", std::to_string(op_.need_weights));
props_.emplace_back("add_zero_attn", std::to_string(op_.add_zero_attn));
props_.emplace_back("reslink", std::to_string(op_.reslink)); props_.emplace_back("reslink", std::to_string(op_.reslink));
props_.emplace_back("training", std::to_string(op_.training)); props_.emplace_back("training", std::to_string(op_.training));
props_.emplace_back("bias", std::to_string(op_.bias));
props_.emplace_back("attn_mask", std::to_string(op_.attn_mask));
props_.emplace_back("enable_qproj", std::to_string(op_.enable_qproj));
props_.emplace_back("enable_kproj", std::to_string(op_.enable_kproj));
props_.emplace_back("enable_vproj", std::to_string(op_.enable_vproj));
props_.emplace_back("enable_oproj", std::to_string(op_.enable_oproj));
props_.emplace_back("seed", std::to_string(op_.seed)); props_.emplace_back("seed", std::to_string(op_.seed));
props_.emplace_back("attn_prob", std::to_string(op_.attn_prob)); props_.emplace_back("attn_prob", std::to_string(op_.attn_prob));
props_.emplace_back("out_prob", std::to_string(op_.out_prob)); props_.emplace_back("out_prob", std::to_string(op_.out_prob));
......
...@@ -1398,26 +1398,37 @@ class MultiHeadAttn : public OpDefImplBase<MultiHeadAttn> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using ATTN_MASK_TYPE = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE;
using TENSOR_COMBINATION_TYPE = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE;
uint32_t num_heads = 1;
uint32_t embeding_size = 0;
uint32_t k_size = 0;
uint32_t v_size = 0;
uint32_t qproj_size = 0;
uint32_t kproj_size = 0;
uint32_t vproj_size = 0;
uint32_t oproj_size = 0;
bool qbias = false;
bool kbias = false;
bool vbias = false;
bool obias = false;
float sm_scaler = 1.f;
uint32_t input_order = 0;
ATTN_MASK_TYPE attn_mask_type = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK;
TENSOR_COMBINATION_TYPE tensor_combination_type = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE;
bool need_weights = false;
bool add_zero_attn = false;
bool reslink = false;
bool training = true;
bool bias = false;
bool attn_mask = false;
bool enable_qproj = true;
bool enable_kproj = true;
bool enable_vproj = true;
bool enable_oproj = true;
uint64_t seed = 0;
float attn_prob = 0.f;
float out_prob = 0.f;
size_t handle;
MultiHeadAttn() = default;
MultiHeadAttn(uint32_t num_heads_, float sm_scaler_, uint32_t input_order_, bool reslink_, bool training_, bool bias_, bool attn_mask_, bool enable_qproj_, bool enable_kproj_, bool enable_vproj_, bool enable_oproj_, uint64_t seed_, float attn_prob_, float out_prob_, size_t handle_, std::string scope_ = {}): num_heads(num_heads_), sm_scaler(sm_scaler_), input_order(input_order_), reslink(reslink_), training(training_), bias(bias_), attn_mask(attn_mask_), enable_qproj(enable_qproj_), enable_kproj(enable_kproj_), enable_vproj(enable_vproj_), enable_oproj(enable_oproj_), seed(seed_), attn_prob(attn_prob_), out_prob(out_prob_), handle(handle_) { set_scope(scope_); }
MultiHeadAttn(uint32_t num_heads_, uint32_t embeding_size_, uint32_t k_size_, uint32_t v_size_, uint32_t qproj_size_, uint32_t kproj_size_, uint32_t vproj_size_, uint32_t oproj_size_, bool qbias_, bool kbias_, bool vbias_, bool obias_, float sm_scaler_, uint32_t input_order_, ATTN_MASK_TYPE attn_mask_type_, TENSOR_COMBINATION_TYPE tensor_combination_type_, bool need_weights_, bool add_zero_attn_, bool reslink_, bool training_, uint64_t seed_, float attn_prob_, float out_prob_, size_t handle_, std::string scope_ = {}): num_heads(num_heads_), embeding_size(embeding_size_), k_size(k_size_), v_size(v_size_), qproj_size(qproj_size_), kproj_size(kproj_size_), vproj_size(vproj_size_), oproj_size(oproj_size_), qbias(qbias_), kbias(kbias_), vbias(vbias_), obias(obias_), sm_scaler(sm_scaler_), input_order(input_order_), attn_mask_type(attn_mask_type_), tensor_combination_type(tensor_combination_type_), need_weights(need_weights_), add_zero_attn(add_zero_attn_), reslink(reslink_), training(training_), seed(seed_), attn_prob(attn_prob_), out_prob(out_prob_), handle(handle_) { set_scope(scope_); }
MultiHeadAttn(::megdnn::param::MultiHeadAttn packed_param_0, size_t handle_): num_heads(packed_param_0.num_heads), sm_scaler(packed_param_0.sm_scaler), input_order(packed_param_0.input_order), reslink(packed_param_0.reslink), training(packed_param_0.training), bias(packed_param_0.bias), attn_mask(packed_param_0.attn_mask), enable_qproj(packed_param_0.enable_qproj), enable_kproj(packed_param_0.enable_kproj), enable_vproj(packed_param_0.enable_vproj), enable_oproj(packed_param_0.enable_oproj), seed(packed_param_0.seed), attn_prob(packed_param_0.attn_prob), out_prob(packed_param_0.out_prob), handle(handle_) {}
MultiHeadAttn(::megdnn::param::MultiHeadAttn packed_param_0, size_t handle_): num_heads(packed_param_0.num_heads), embeding_size(packed_param_0.embeding_size), k_size(packed_param_0.k_size), v_size(packed_param_0.v_size), qproj_size(packed_param_0.qproj_size), kproj_size(packed_param_0.kproj_size), vproj_size(packed_param_0.vproj_size), oproj_size(packed_param_0.oproj_size), qbias(packed_param_0.qbias), kbias(packed_param_0.kbias), vbias(packed_param_0.vbias), obias(packed_param_0.obias), sm_scaler(packed_param_0.sm_scaler), input_order(packed_param_0.input_order), attn_mask_type(packed_param_0.attn_mask_type), tensor_combination_type(packed_param_0.tensor_combination_type), need_weights(packed_param_0.need_weights), add_zero_attn(packed_param_0.add_zero_attn), reslink(packed_param_0.reslink), training(packed_param_0.training), seed(packed_param_0.seed), attn_prob(packed_param_0.attn_prob), out_prob(packed_param_0.out_prob), handle(handle_) {}
::megdnn::param::MultiHeadAttn param() const {
return {num_heads, sm_scaler, input_order, reslink, training, bias, attn_mask, enable_qproj, enable_kproj, enable_vproj, enable_oproj, seed, attn_prob, out_prob};
return {num_heads, embeding_size, k_size, v_size, qproj_size, kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias, sm_scaler, input_order, attn_mask_type, tensor_combination_type, need_weights, add_zero_attn, reslink, training, seed, attn_prob, out_prob};
}
};
......
...@@ -1479,20 +1479,59 @@ MeshIndexingInst
py::class_<MultiHeadAttn, std::shared_ptr<MultiHeadAttn>, OpDef> MultiHeadAttnInst(m, "MultiHeadAttn");
py::enum_<MultiHeadAttn::ATTN_MASK_TYPE>(MultiHeadAttnInst, "ATTN_MASK_TYPE")
.value("NO_MASK", MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK)
.value("DEFAULT_MASK", MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK)
.value("CUDNN_STYLE_MASK", MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK)
.value("USER_DEFINED_MASK", MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "NO_MASK") return MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK;
if (str == "DEFAULT_MASK") return MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK;
if (str == "CUDNN_STYLE_MASK") return MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK;
if (str == "USER_DEFINED_MASK") return MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, MultiHeadAttn::ATTN_MASK_TYPE>();
py::enum_<MultiHeadAttn::TENSOR_COMBINATION_TYPE>(MultiHeadAttnInst, "TENSOR_COMBINATION_TYPE")
.value("NONE", MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE)
.value("ONLY_MASK", MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK)
.value("ONLY_BIASKV", MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV)
.value("ALL", MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "NONE") return MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE;
if (str == "ONLY_MASK") return MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK;
if (str == "ONLY_BIASKV") return MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV;
if (str == "ALL") return MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, MultiHeadAttn::TENSOR_COMBINATION_TYPE>();
MultiHeadAttnInst
.def(py::init<uint32_t, float, uint32_t, bool, bool, bool, bool, bool, bool, bool, bool, uint64_t, float, float, size_t, std::string>(), py::arg("num_heads") = 1, py::arg("sm_scaler") = 1.f, py::arg("input_order") = 0, py::arg("reslink") = false, py::arg("training") = true, py::arg("bias") = false, py::arg("attn_mask") = false, py::arg("enable_qproj") = true, py::arg("enable_kproj") = true, py::arg("enable_vproj") = true, py::arg("enable_oproj") = true, py::arg("seed") = 0, py::arg("attn_prob") = 0.f, py::arg("out_prob") = 0.f, py::arg("handle"), py::arg("scope") = {})
.def(py::init<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, bool, bool, bool, bool, float, uint32_t, ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE, ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE, bool, bool, bool, bool, uint64_t, float, float, size_t, std::string>(), py::arg("num_heads") = 1, py::arg("embeding_size") = 0, py::arg("k_size") = 0, py::arg("v_size") = 0, py::arg("qproj_size") = 0, py::arg("kproj_size") = 0, py::arg("vproj_size") = 0, py::arg("oproj_size") = 0, py::arg("qbias") = false, py::arg("kbias") = false, py::arg("vbias") = false, py::arg("obias") = false, py::arg("sm_scaler") = 1.f, py::arg("input_order") = 0, py::arg("attn_mask_type") = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK, py::arg("tensor_combination_type") = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE, py::arg("need_weights") = false, py::arg("add_zero_attn") = false, py::arg("reslink") = false, py::arg("training") = true, py::arg("seed") = 0, py::arg("attn_prob") = 0.f, py::arg("out_prob") = 0.f, py::arg("handle"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("num_heads", &MultiHeadAttn::num_heads)
.def_readwrite("embeding_size", &MultiHeadAttn::embeding_size)
.def_readwrite("k_size", &MultiHeadAttn::k_size)
.def_readwrite("v_size", &MultiHeadAttn::v_size)
.def_readwrite("qproj_size", &MultiHeadAttn::qproj_size)
.def_readwrite("kproj_size", &MultiHeadAttn::kproj_size)
.def_readwrite("vproj_size", &MultiHeadAttn::vproj_size)
.def_readwrite("oproj_size", &MultiHeadAttn::oproj_size)
.def_readwrite("qbias", &MultiHeadAttn::qbias)
.def_readwrite("kbias", &MultiHeadAttn::kbias)
.def_readwrite("vbias", &MultiHeadAttn::vbias)
.def_readwrite("obias", &MultiHeadAttn::obias)
.def_readwrite("sm_scaler", &MultiHeadAttn::sm_scaler) .def_readwrite("sm_scaler", &MultiHeadAttn::sm_scaler)
.def_readwrite("input_order", &MultiHeadAttn::input_order) .def_readwrite("input_order", &MultiHeadAttn::input_order)
.def_readwrite("attn_mask_type", &MultiHeadAttn::attn_mask_type)
.def_readwrite("tensor_combination_type", &MultiHeadAttn::tensor_combination_type)
.def_readwrite("need_weights", &MultiHeadAttn::need_weights)
.def_readwrite("add_zero_attn", &MultiHeadAttn::add_zero_attn)
.def_readwrite("reslink", &MultiHeadAttn::reslink) .def_readwrite("reslink", &MultiHeadAttn::reslink)
.def_readwrite("training", &MultiHeadAttn::training) .def_readwrite("training", &MultiHeadAttn::training)
.def_readwrite("bias", &MultiHeadAttn::bias)
.def_readwrite("attn_mask", &MultiHeadAttn::attn_mask)
.def_readwrite("enable_qproj", &MultiHeadAttn::enable_qproj)
.def_readwrite("enable_kproj", &MultiHeadAttn::enable_kproj)
.def_readwrite("enable_vproj", &MultiHeadAttn::enable_vproj)
.def_readwrite("enable_oproj", &MultiHeadAttn::enable_oproj)
.def_readwrite("seed", &MultiHeadAttn::seed) .def_readwrite("seed", &MultiHeadAttn::seed)
.def_readwrite("attn_prob", &MultiHeadAttn::attn_prob) .def_readwrite("attn_prob", &MultiHeadAttn::attn_prob)
.def_readwrite("out_prob", &MultiHeadAttn::out_prob) .def_readwrite("out_prob", &MultiHeadAttn::out_prob)
......
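The bindings in the hunk above expose the two new enums on MultiHeadAttn and register an implicit std::string conversion for each. A rough sketch of how that could be exercised from Python, assuming the op is exposed as megengine.core.ops.builtin.MultiHeadAttn (the module path and the printed repr are assumptions, not confirmed by this commit):

# Rough sketch only; the import path is an assumption.
from megengine.core.ops import builtin
op = builtin.MultiHeadAttn()  # default-constructed op (see .def(py::init<>()) above)
op.attn_mask_type = "DEFAULT_MASK"  # strings convert implicitly per the binding above
op.tensor_combination_type = builtin.MultiHeadAttn.TENSOR_COMBINATION_TYPE.ONLY_MASK
print(op.attn_mask_type, op.tensor_combination_type)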
...@@ -558,7 +558,6 @@ def RegionRestrictedConvolution: MgbHashableOp<"RegionRestrictedConvolution", [C
def RegionRestrictedConvolutionBackwardData: MgbHashableOp<"RegionRestrictedConvolutionBackwardData", [ConvolutionParam]>;
def MaskedFill: MgbHashableOp<"MaskedFill", [FillParam]>;
def MultiHeadAttn: MgbHashableOp<"MultiHeadAttn", [MultiHeadAttnParam]> {
let extraArguments = (ins
MgbSizeTAddr:$handle
...@@ -571,28 +570,54 @@ def MultiHeadAttn: MgbHashableOp<"MultiHeadAttn", [MultiHeadAttnParam]> {
mgb::hash_pair_combine(
mgb::hash($_self.num_heads),
mgb::hash_pair_combine(
mgb::hash($_self.sm_scaler),
mgb::hash_pair_combine(
mgb::hash($_self.input_order),
mgb::hash_pair_combine(
mgb::hash($_self.reslink),
mgb::hash_pair_combine(
mgb::hash($_self.training),
mgb::hash_pair_combine(
mgb::hash($_self.bias),
mgb::hash_pair_combine(
mgb::hash($_self.attn_mask),
mgb::hash_pair_combine(
mgb::hash($_self.enable_qproj),
mgb::hash_pair_combine(
mgb::hash($_self.enable_kproj),
mgb::hash_pair_combine(
mgb::hash($_self.enable_vproj),
mgb::hash_pair_combine(
mgb::hash($_self.enable_oproj),
mgb::hash_pair_combine(
mgb::hash($_self.attn_prob),
mgb::hash($_self.out_prob)
mgb::hash($_self.embeding_size),
mgb::hash_pair_combine(
mgb::hash($_self.k_size),
mgb::hash_pair_combine(
mgb::hash($_self.v_size),
mgb::hash_pair_combine(
mgb::hash($_self.qproj_size),
mgb::hash_pair_combine(
mgb::hash($_self.kproj_size),
mgb::hash_pair_combine(
mgb::hash($_self.vproj_size),
mgb::hash_pair_combine(
mgb::hash($_self.oproj_size),
mgb::hash_pair_combine(
mgb::hash($_self.qbias),
mgb::hash_pair_combine(
mgb::hash($_self.kbias),
mgb::hash_pair_combine(
mgb::hash($_self.vbias),
mgb::hash_pair_combine(
mgb::hash($_self.obias),
mgb::hash_pair_combine(
mgb::hash($_self.sm_scaler),
mgb::hash_pair_combine(
mgb::hash($_self.input_order),
mgb::hash_pair_combine(
mgb::hash($_self.attn_mask_type),
mgb::hash_pair_combine(
mgb::hash($_self.tensor_combination_type),
mgb::hash_pair_combine(
mgb::hash($_self.add_zero_attn),
mgb::hash_pair_combine(
mgb::hash($_self.need_weights),
mgb::hash_pair_combine(
mgb::hash($_self.reslink),
mgb::hash_pair_combine(
mgb::hash($_self.training),
mgb::hash_pair_combine(
mgb::hash($_self.attn_prob),
mgb::hash($_self.out_prob))
)
)
)
)
)
)
)
)
)
)
)
...@@ -608,7 +633,7 @@ def MultiHeadAttn: MgbHashableOp<"MultiHeadAttn", [MultiHeadAttnParam]> {
)
);
}];
let cmpFunction = [{return $0.handle == $1.handle && $0.num_heads == $1.num_heads && $0.sm_scaler == $1.sm_scaler && $0.input_order == $1.input_order && $0.reslink == $1.reslink && $0.training == $1.training && $0.bias == $1.bias && $0.attn_mask == $1.attn_mask && $0.enable_qproj == $1.enable_qproj && $0.enable_kproj == $1.enable_kproj && $0.enable_vproj == $1.enable_vproj && $0.enable_oproj == $1.enable_oproj && $0.attn_prob == $1.attn_prob && $0.out_prob == $1.out_prob;}];
let cmpFunction = [{return $0.handle == $1.handle && $0.num_heads == $1.num_heads && $0.embeding_size == $1.embeding_size && $0.k_size == $1.k_size && $0.v_size == $1.v_size && $0.qproj_size == $1.qproj_size && $0.kproj_size == $1.kproj_size && $0.vproj_size == $1.vproj_size && $0.oproj_size == $1.oproj_size && $0.qbias == $1.qbias && $0.kbias == $1.kbias && $0.vbias == $1.vbias && $0.obias == $1.obias && $0.sm_scaler == $1.sm_scaler && $0.input_order == $1.input_order && $0.reslink == $1.reslink && $0.training == $1.training && $0.need_weights == $1.need_weights && $0.attn_mask_type == $1.attn_mask_type && $0.add_zero_attn == $1.add_zero_attn && $0.tensor_combination_type == $1.tensor_combination_type && $0.attn_prob == $1.attn_prob && $0.out_prob == $1.out_prob;}];
}
......