diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 33fe4cbace400b96707866681e1b2fc184e42899..17bac96bf1a396e2cbd7d68e55d78bcc19ac57a5 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -1333,16 +1333,34 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), (pdef('MultiHeadAttn') .add_fields('uint32', Doc('num_heads', 'Number of parallel attention heads.'), '1') + .add_fields('uint32', Doc('embeding_size', 'Total dimension of the model.'), '0') + .add_fields('uint32', Doc('k_size', 'Total number of features for keys.'), '0') + .add_fields('uint32', Doc('v_size', 'Total number of features for values.'), '0') + .add_fields('uint32', Doc('qproj_size', 'query weight projection.'), '0') + .add_fields('uint32', Doc('kproj_size', 'key weight projection.'), '0') + .add_fields('uint32', Doc('vproj_size', 'value weight projection.'), '0') + .add_fields('uint32', Doc('oproj_size', 'output weight projection.'), '0') + .add_fields('bool', Doc('qbias', 'Whether to add query bias.'), 'false') + .add_fields('bool', Doc('kbias', 'Whether to add key bias.'), 'false') + .add_fields('bool', Doc('vbias', 'Whether to add value bias.'), 'false') + .add_fields('bool', Doc('obias', 'Whether to add out bias.'), 'false') .add_fields('float32', Doc('sm_scaler', 'Softmax smoothing (1.0 >= smScaler >= 0.0) or sharpening (smScaler > 1.0) coefficient.'), '1.f') .add_fields('uint32', Doc('input_order', 'The sequence data layout, allows the user to select 3! = 6 different data layouts or permutations of BEAM, BATCH and TIME dimensions.'), '0') + .add_enum('ATTN_MASK_TYPE', + Doc('NO_MASK = 0', 'Indicates that there is no mask.'), + Doc('DEFAULT_MASK = 1', 'Use the default mask which the upper right triangle of the mask is -inf, and the diagonal and lower left triangle are all 0.'), + Doc('CUDNN_STYLE_MASK = 2', 'Indicates the use of a cudnn style mask.'), + Doc('USER_DEFINED_MASK = 3', 'Use the user-defined mask.'), name_field="attn_mask_type") + .add_enum(Doc('TENSOR_COMBINATION_TYPE', 'Used to determine whether mask tensor and bias_kv tensor exist in the input. Note that bias_kv here is not kbias and vbias in the linear layer, and bias_kv here will be added to the K and V at sequence dimensions, where K and V are the matrices of key and value after projection, and K and V will be used to calculate the attention matrix.'), + Doc('NONE = 0', 'Indicates that there are no mask tensor and bias_kv tensor in the input.'), + Doc('ONLY_MASK = 1', + 'Indicates that there is only mask tensor in input.'), + Doc('ONLY_BIASKV = 2', 'Indicates that there is only bias_kv tensor in input.'), + Doc('ALL = 3', 'Indicates that there are mask tensor and bias_kv tensor in the input.'), name_field="tensor_combination_type") + .add_fields('bool', Doc('add_zero_attn', 'Whether to add a new batch of zeros to the key and value sequences.'), 'false') + .add_fields('bool', Doc('need_weights', 'Whether to return the attention matrix, which is the output result of softmax.'), 'false') .add_fields('bool', Doc('reslink', 'Whether to add input query to final output.'), 'false') .add_fields('bool', Doc('training', 'Whether it is in training mode.'), 'true') - .add_fields('bool', Doc('bias', 'Whether to add linear bias.'), 'false') - .add_fields('bool', Doc('attn_mask', 'Whether to add attn_mask.'), 'false') - .add_fields('bool', Doc('enable_qproj', 'enable query weight projection.'), 'true') - .add_fields('bool', Doc('enable_kproj', 'enable key weight projection.'), 'true') - .add_fields('bool', Doc('enable_vproj', 'enable value weight projection.'), 'true') - .add_fields('bool', Doc('enable_oproj', 'enable output weight projection.'), 'true') .add_fields('uint64', Doc('seed', 'Random number seed for drop'), '0') .add_fields('float32', Doc('attn_prob', 'Dropout probability on attention, is applied directly to the softmax output'), '0.f') .add_fields('float32', Doc('out_prob', 'Dropout probability on output, alters the multi-head attention output'), '0.f') diff --git a/dnn/src/cuda/multi_head_attn/helper.cpp b/dnn/src/cuda/multi_head_attn/helper.cpp index ab685629242deb8f3682214d5afe778fb135d07a..03e7c8e90f0fe1fcd4aa19a9dd0e8b9bb42c467c 100644 --- a/dnn/src/cuda/multi_head_attn/helper.cpp +++ b/dnn/src/cuda/multi_head_attn/helper.cpp @@ -109,14 +109,14 @@ void MultiHeadAttnStatus::set( kSize = k.shape[2]; vSize = v.shape[2]; numHeads = p.num_heads; - qProjSize = p.enable_qproj ? qSize / numHeads : 0; - kProjSize = p.enable_kproj ? kSize / numHeads : 0; - vProjSize = p.enable_vproj ? vSize / numHeads : 0; - oProjSize = p.enable_oproj ? qSize : 0; - attnMask = p.attn_mask; + qProjSize = p.qproj_size ? qSize / numHeads : 0; + kProjSize = p.kproj_size ? kSize / numHeads : 0; + vProjSize = p.vproj_size ? vSize / numHeads : 0; + oProjSize = p.oproj_size ? qSize : 0; + attnMask = p.attn_mask_type >= param::MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK; cudnnDataType_t cudnn_dtype = to_cudnn_dtype(q.dtype); auto flag = CUDNN_ATTN_QUERYMAP_ONE_TO_ONE; - if (p.bias) + if (p.qbias or p.kbias or p.vbias or p.obias) flag = flag | CUDNN_ATTN_ENABLE_PROJ_BIASES; #if CUDNN_VERSION < 8600 // TODO: CUDNN_VERSION < 8600 and out dropout > 0.0, we need to go to the proxy cuda @@ -134,7 +134,9 @@ void MultiHeadAttnStatus::set( vProjSize, oProjSize, seqLenQ, seqLenK, batchSize, 1)); #endif - auxArray.set(batchSize, seqLenQ, seqLenK, p.attn_mask); + auxArray.set( + batchSize, seqLenQ, seqLenK, + p.attn_mask_type >= param::MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK); if (p.training) cudnnGetMultiHeadAttnBuffers( @@ -157,16 +159,18 @@ bool MultiHeadAttnStatus::is_initialized( return false; if (q.shape[0] != batchSize or q.shape[1] != seqLenQ or k.shape[1] != seqLenK or q.shape[2] != qSize or k.shape[2] != kSize or v.shape[2] != vSize or - attnMask != p.attn_mask or numHeads != p.num_heads) { + attnMask != (p.attn_mask_type >= + param::MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK) or + numHeads != p.num_heads) { return false; } - if ((p.enable_qproj && (qProjSize == 0 or qProjSize != qSize / p.num_heads)) or - (p.enable_kproj && (kProjSize == 0 or kProjSize != kSize / p.num_heads)) or - (p.enable_vproj && (vProjSize == 0 or vProjSize != vSize / p.num_heads)) or - (p.enable_oproj && (oProjSize == 0 or oProjSize != q.shape[2]))) + if ((p.qproj_size && (qProjSize == 0 or qProjSize != qSize / p.num_heads)) or + (p.kproj_size && (kProjSize == 0 or kProjSize != kSize / p.num_heads)) or + (p.vproj_size && (vProjSize == 0 or vProjSize != vSize / p.num_heads)) or + (p.oproj_size && (oProjSize == 0 or oProjSize != q.shape[2]))) return false; - if ((!p.enable_qproj && qProjSize != 0) or (!p.enable_kproj && kProjSize != 0) or - (!p.enable_vproj && vProjSize != 0) or (!p.enable_oproj && oProjSize != 0)) + if ((!p.qproj_size && qProjSize != 0) or (!p.kproj_size && kProjSize != 0) or + (!p.vproj_size && vProjSize != 0) or (!p.oproj_size && oProjSize != 0)) return false; if (!auxArray.is_initialized(batchSize, seqLenQ, seqLenK, attnMask)) return false; diff --git a/dnn/src/cuda/multi_head_attn/opr_impl.cpp b/dnn/src/cuda/multi_head_attn/opr_impl.cpp index f4125cdd18db020a58fb5ca90006befc10a26e31..10ea4d9ac96e1f9d69c198c4d5655dad5480e0b7 100644 --- a/dnn/src/cuda/multi_head_attn/opr_impl.cpp +++ b/dnn/src/cuda/multi_head_attn/opr_impl.cpp @@ -163,7 +163,7 @@ void MultiHeadAttnBackwardImpl::exec( #else #if CUDNN_VERSION < 8600 megdnn_assert( - !param().bias, + !(param().qbias or param().kbias or param().vbias or param().obias), "If the cudnn version is lower than 8.6.0, param().bias must be false, " "but got true, because there is an error in the " "dbias result during the backward calculation."); diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 2e1a6221ffa6719ecc435e03d82800e4fc95c737..69fc6ead2ab0951c22790cb9d011457742cc5b30 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -4,6 +4,8 @@ from builtins import min as _builtins_min from functools import lru_cache from typing import NamedTuple, Optional, Sequence, Tuple, Union +import numpy as np + from ..core import _config from ..core._imperative_rt.core2 import ( Const, @@ -54,6 +56,7 @@ from .tensor import ( squeeze, transpose, zeros, + zeros_like, ) __all__ = [ @@ -2053,6 +2056,221 @@ def region_restricted_conv( return output +def _mha_shape_check( + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + num_heads: int, +): + # Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask` + # and returns if the input is batched or not. + # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor. + q_dim = query.ndim + k_dim = key.ndim + v_dim = value.ndim + kpm_dim = key_padding_mask.ndim if key_padding_mask is not None else 0 + kpm_shape = key_padding_mask.shape if key_padding_mask is not None else None + am_dim = attn_mask.ndim if attn_mask is not None else 0 + am_shape = attn_mask.shape if attn_mask is not None else None + # Shape check. + if q_dim == 3: + # Batched Inputs + is_batched = True + assert k_dim == 3 and v_dim == 3, ( + "For batched (3-D) `query`, expected `key` and `value` to be 3-D" + f" but found {k_dim}-D and {v_dim}-D tensors respectively" + ) + q_shape0, q_shape1, _ = query.shape + k_shape0, k_shape1, _ = key.shape + v_shape0, v_shape1, _ = value.shape + assert q_shape0 == k_shape0 and k_shape0 == v_shape0, ( + "For batched (3-D) `query`, expected the batch sizes of `query`, `key` and `value` to be equal" + f" but found query batch size is {q_shape0}, key batch size is {k_shape0} and value batch size is {v_shape0} respectively" + ) + assert k_shape1 == v_shape1, ( + "For batched (3-D) `query`, expected the sequence length of `key` and `value` to be equal" + f" but found key seqlen is {k_shape1} and value seqlen is {v_shape1} respectively" + ) + if key_padding_mask is not None: + assert kpm_dim == 2, ( + "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D" + f" but found {kpm_dim}-D tensor instead" + ) + expected_shape0 = (k_shape0, k_shape1) # norm style + expected_shape1 = (2, k_shape0) # cudnn style + assert expected_shape0 == kpm_shape and expected_shape1 == kpm_shape, ( + f"For batched (3-D) `query`, expected `key_padding_mask.shape` equal {expected_shape0} or {expected_shape1}" + f" but found {kpm_shape} instead" + ) + if attn_mask is not None: + assert am_dim in (2, 3), ( + "For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {am_dim}-D tensor instead" + ) + if am_dim == 2: + expected_shape0 = (q_shape1, k_shape1) # norm style + expected_shape1 = (2, q_shape1) # cudnn style + assert ( + am_shape == expected_shape0 or am_shape == expected_shape1 + ), f"Expected `attn_mask` shape to be {expected_shape0} or {expected_shape1} but got {am_shape}" + if am_dim == 3: + expected_shape = (q_shape0 * num_heads, q_shape1, k_shape1) + assert ( + am_shape == expected_shape + ), f"Expected `attn_mask` shape to be {expected_shape0} but got {am_shape}" + elif q_dim == 2: + # Unbatched Inputs + is_batched = False + assert k_dim == 2 and v_dim == 2, ( + "For unbatched (2-D) `query`, expected `key` and `value` to be 2-D" + f" but found {k_dim}-D and {v_dim}-D tensors respectively" + ) + q_shape0, q_shape1 = query.shape + k_shape0, k_shape1 = key.shape + v_shape0, v_shape1 = value.shape + assert k_shape0 == v_shape0, ( + "For unbatched (3-D) `query`, expected the sequence length of `key` and `value` to be equal" + f" but found key seqlen is {k_shape0} and query seqlen is {v_shape0} respectively" + ) + if key_padding_mask is not None: + assert kpm_dim in (1, 2), ( + "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None`, 1-D or 2-D" + f" but found {kpm_dim}-D tensor instead" + ) + expected_shape0 = k_shape0 # norm style + expected_shape1 = (2, 1) # cudnn style + assert expected_shape0 == kpm_shape or expected_shape1 == kpm_shape, ( + f"For batched (3-D) `query`, expected `key_padding_mask.shape` equal {expected_shape0} or {expected_shape1}" + f" but found {kpm_shape} tensor instead" + ) + if attn_mask is not None: + assert am_dim in (2, 3), ( + "For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {am_dim}-D tensor instead" + ) + if am_dim == 2: + expected_shape0 = (q_shape0, k_shape0) # normal style mask + expected_shape1 = (2, q_shape0) # cudnn style mask + assert ( + am_shape == expected_shape0 or am_shape == expected_shape1 + ), f"Expected `attn_mask` shape to be {expected_shape0} or {expected_shape1} but got {am_shape}" + if am_dim == 3: + expected_shape = (num_heads, q_shape0, k_shape0) + assert ( + am_shape == expected_shape + ), f"Expected `attn_mask` shape to be {expected_shape} but got {am_shape}" + else: + raise AssertionError( + f"query should be unbatched 2D or batched 3D tensor but received {q_dim}-D query tensor" + ) + + return is_batched + + +def _canonical_mask( + mask: Optional[Tensor], + mask_name: str, + other_type, + other_name: str, + target_type, + check_other: bool = True, +) -> Optional[Tensor]: + if mask is not None: + _mask_dtype = mask.dtype + _mask_is_float = ( + _mask_dtype == np.float16 + or _mask_dtype == np.float32 + or _mask_dtype == np.float64 + ) + assert ( + _mask_dtype == bool or _mask_is_float + ), f"only bool and floating types of {mask_name} are supported" + if check_other and other_type is not None: + if _mask_dtype != other_type: + get_logger().warning( + f"Support for mismatched {mask_name} and {other_name} " + "is deprecated. Use same type for both instead." + ) + if not _mask_is_float: + mask_ = zeros_like(mask).astype(target_type) + mask_[mask] = float("-inf") + return mask_ + return mask + + +def _merge_masks( + attn_mask: Tensor, + key_padding_mask: Tensor, + query: Tensor, + key: Tensor, + add_bias_kv: bool = False, + add_zero_attn: bool = False, + is_causal: bool = False, + maybe_cudnn_style_mask: bool = False, + num_heads: int = 0, +): + r""" + Determine mask type and combine masks if necessary. + + Note: This function will continue to improve with the iteration of MHA. + + Args: + attn_mask: MHA's attention mask tensor, the shape is :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)` + key_padding_mask: MHA's padding mask tensor, the shape is :math:`(N, S)` + query: MHA's query, the shape is :math:`(N, L, E_q)` + key: MHA's key, the shape is :math:`(N, S, E_k)` + add_bias_kv: used to determine whether pad is needed on the sequence dimension of attn_mask and key_padding_mask, from MHA's ``add_bias_kv``. + add_zero_attn: used to determine whether pad is needed on the sequence dimension of attn_mask and key_padding_mask, from MHA's ``add_zero_attn``. + is_causal: MHA's is_causal, is_causal provides a hint that attn_mask is the causal mask. + maybe_cudnn_style_mask: MHA's maybe_cudnn_style_mask, like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is the cudnn style mask. + num_heads: MHA's head number. + Returns: + merged_mask: merged mask, may be None, the shape is :math:`(L, S)`, :math:`(2\cdotL + 2\cdotN)` or :math:`(N\cdot\text{num\_heads}, L, S)` + mask_type: merged mask type ``("no_mask", "default_mask", "cudnn_style_mask" or "user_defined_mask")`` + """ + mask_type = "no_mask" + merged_mask = None + seq_qlen = query.shape[1] + seq_klen = key.shape[1] + attn_mask_np = attn_mask.numpy() if attn_mask is not None else None + + # is_causal is used to hint whether to use a causal mask, where the upper right triangle is all -inf, + # and the diagonal and lower left triangle are all 0. But if attn_mask is given, attn_mask is used first. + if is_causal and attn_mask is None and key_padding_mask is None: + # At this point, merged_mask = None + mask_type = "default_mask" + elif is_causal and attn_mask is not None and key_padding_mask is None: + # At this point, merged_mask = attn_mask + default_mask_np = np.triu( + -float("inf") * np.ones((seq_qlen, seq_klen)), k=1 + ).astype("float32") + if (attn_mask_np == default_mask_np).all(): + mask_type = "default_mask" + else: + mask_type = "user_defined_mask" + merged_mask = attn_mask + else: + if attn_mask is not None: + # At this point, merged_mask = attn_mask + default_mask_np = np.triu( + -float("inf") * np.ones((seq_qlen, seq_klen)), k=1 + ).astype("float32") + if ( + attn_mask_np == default_mask_np + and (attn_mask_np == default_mask_np).all() + ): + mask_type = "default_mask" + merged_mask = attn_mask + elif np.all(attn_mask_np == 0): + mask_type = "no_mask" + else: + mask_type = "user_defined_mask" + merged_mask = attn_mask + return merged_mask, mask_type + + def multi_head_attention( query: Tensor, key: Tensor, @@ -2062,28 +2280,39 @@ def multi_head_attention( attn_drop: float, out_drop: float, io_weight_bias: Optional[Tensor], - bias: bool = False, + qproj_size: Optional[int] = None, + kproj_size: Optional[int] = None, + vproj_size: Optional[int] = None, + oproj_size: Optional[int] = None, + qbias: bool = False, + kbias: bool = False, + vbias: bool = False, + obias: bool = False, + bias_k: Optional[Tensor] = None, + bias_v: Optional[Tensor] = None, + add_zero_attn: bool = False, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + need_weights: bool = False, + average_attn_weights: bool = False, + is_causal: bool = False, + maybe_cudnn_style_mask: bool = False, reslink: bool = False, training: bool = True, - attn_mask: bool = False, - enable_qproj: bool = True, - enable_kproj: bool = True, - enable_vproj: bool = True, - enable_oproj: bool = True, ): r"""Allows the model to jointly attend to information from different representation subspaces. See `Attention Is All You Need `_. .. math:: - \text{MultiHeadAttn}\big(q,K,V, W_Q, W_V, W_O\big) = \sum^{nHeads-1}_{i=0}W_{O,i}h_i + \text{MultiHeadAttn}\big(q, k, v, W_Q, W_K, W_V, W_O\big) = \sum^{nHeads-1}_{i=0}W_{O,i}h_i - where :math:`h_i=W_{V,i}V \text{Softmax}\Big( \text{smScaler} \cdot K^TW^T_{K,i}W_{Q,i}q \Big),\text{for }i\text{ = 0 ... nHeads-1}`. + where :math:`h_i=W_{V,i}v \text{Softmax}\Big( \text{smScaler} \cdot k^TW^T_{K,i}W_{Q,i}q \Big),\text{for }i\text{ = 0 ... nHeads-1}`. See :class:`~.module.MultiHeadAttn` for more details. - + Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported, and if the cudnn version >=8.6.0, the calculation results are completely correct; If the cudnn version >=8.0.4 but <8.6.0, if there is a bias, only the dbias result calculated from the backward is incorrect. If there is no bias, the forward and backward calculations are correct; If the cudnn version is less than 8.0.4, this operator is not supported. - + Args: query, key, value: map a query and a set of key-value pairs to an output. See "Attention Is All You Need" for more details. @@ -2091,21 +2320,133 @@ def multi_head_attention( num_heads: parallel attention heads. attn_drop: probability of an element to be zeroed, used in attention matrix. out_drop: probability of an element to be zeroed, used in final output. - io_weight_bias: input/output projection weight/bias all in one, used for cudnn api. - bias: used to indicate a bias in io_weight_bias, used for cudnn api. - reslink: add input query to final output. + io_weight_bias: input/output projection weight/bias all in one. + The order of arrangement is: query weight, key weight, value weight, out weight, query bias, key bias, value bias, out bias, the following parameters will be used to indicate whether these items exist: qproj_size, kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias. + Note: :math:`Y=X@W+B` is used here instead of :math:`Y=X@W^T+B` in pytorch. + qproj_size: indicates the projection size of query weight in io_weight_bias, 0 indicates disabled query projection and no query projection weight. + kproj_size: indicates the projection size of key weight in io_weight_bias, 0 indicates disabled key projection and no key projection weight. + vproj_size: indicates the projection size of value weight in io_weight_bias, 0 indicates disabled value projection and no value projection weight. + oproj_size: indicates the projection size of out weight in io_weight_bias, 0 indicates disabled output projection and no output projection weight. + qbias: indicates whether there is a query bias in io_weight_bias, this parameter is only valid when qproj_size > 0. + kbias: indicates whether there is a key bias in io_weight_bias, this parameter is only valid when kproj_size > 0. + vbias: indicates whether there is a value bias in io_weight_bias, this parameter is only valid when vproj_size > 0. + obias: indicates whether there is a out bias in io_weight_bias, this parameter is only valid when oproj_size > 0. + bias_k, bias_v: the bias of the key and value sequences to be added at sequence dim. distinguished from kbias and vbias, bias_kv here is not kbias and vbias in the linear layer, and bias_kv here will be added to the K and V at sequence dimensions, where K and V are the matrices of key and value after projection, and K and V will be used to calculate the attention matrix. + Note: Should be set to None, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. + add_zero_attn: if specified, adds a new batch of zeros to the key and value sequences at sequence dim. Default: ``False``. + Note: should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. + key_padding_mask: if specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` to ignore for the purpose of + attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + Note: Should be set to None, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + Note: User-defined mask not supported now, only support no mask or default mask, where the upper right triangle is all -inf, and the diagonal and lower left triangle are all 0. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. + need_weights: indicates whether to return the attention weight, which is the output result of softmax. Default: `True` + Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. + average_attn_weights: if true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. + is_causal: if specified, applies a causal mask as attention mask. Default: ``False`` + Warning: ``is_causal`` provides a hint that ``attn_mask`` is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. + maybe_cudnn_style_mask: if specified, applies a cudnn style mask as attention mask. Default: ``False`` + Note: In the cudnn style, the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`. + Warning: like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is a cudnn style mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. In addition, if the ``_merge_masks`` function returns ``merge_type=cudnn_style_mask``, please ensure that other conditions are correct so that it can run the implementation of cudnn, otherwise an error will be reported. + Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that the underlying implementation only accepts two types of mask type, namely "no_mask" and "default_mask", and we may try to loosen this option after submitting the commit that users can pass in custom attention mask tensors. + reslink: add input query to final output. + Note: It is only valid if the input query is the same as the shape of the output. training: will apply dropout if is ``True``. - attn_mask: used to indicate whether to add a mask to the attention matrix. - By default, the upper right triangle of the mask is -inf, and the diagonal and lower left triangle are all 0. - Default: `True` - enable_qproj: enable query weight projection. Default: ``True``. - enable_kproj: enable key weight projection. Default: ``True``. - enable_vproj: enable value weight projection. Default: ``True``. - enable_oproj: enable output weight projection. Default: ``True``. """ - - head_dim = embed_dim // num_heads + qproj_size = embed_dim if qproj_size is None else qproj_size + kproj_size = embed_dim if kproj_size is None else kproj_size + vproj_size = embed_dim if vproj_size is None else vproj_size + oproj_size = embed_dim if oproj_size is None else oproj_size + if qbias: + assert ( + qproj_size is not None and qproj_size > 0 + ), "when query projection bias is true, query projection weight must be given." + if kbias: + assert ( + kproj_size is not None and kproj_size > 0 + ), "when key projection bias is true, key projection weight must be given" + if vbias: + assert ( + vproj_size is not None and vproj_size > 0 + ), "when value projection bias is true, value projection weight must be given" + if obias: + assert ( + oproj_size is not None and oproj_size > 0 + ), "when output projection bias is true, output projection weight must be given" + unsupport_reason = " The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation." + assert add_zero_attn is False, ( + "add_zero_attn should be False, and configuration of this parameter is not supported now." + + unsupport_reason + ) + assert key_padding_mask is None, ( + "key_padding_mask should be None, and configuration of this parameter is not supported now." + + unsupport_reason + ) + assert need_weights == False, ( + "need_weights should be set to False, and configuration of this parameter is not supported now." + + unsupport_reason + ) + assert average_attn_weights == False, ( + "average_attn_weights should be set to False, and configuration of this parameter is not supported now." + + unsupport_reason + ) + assert maybe_cudnn_style_mask == False, ( + "maybe_cudnn_style_mask should be set to False, and configuration of this parameter is not supported now." + + unsupport_reason + ) + assert bias_k is None, ( + "bias_k should be None, and configuration of this parameter is not supported now." + + unsupport_reason + ) + assert bias_v is None, ( + "bias_v should be None, and configuration of this parameter is not supported now." + + unsupport_reason + ) + head_dim = (qproj_size if qproj_size != 0 else embed_dim) // num_heads smScaler = head_dim ** -0.5 + k_size = key.shape[2] + v_size = value.shape[2] + + is_batched = _mha_shape_check( + query, key, value, key_padding_mask, attn_mask, num_heads + ) + if not is_batched: + query = expand_dims(query, 0) + key = expand_dims(key, 0) + value = expand_dims(value, 0) + if key_padding_mask is not None: + key_padding_mask = expand_dims(key_padding_mask, 0) + + key_padding_mask = _canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=attn_mask, + other_name="attn_mask", + target_type=query.dtype, + ) + attn_mask = _canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + attn_mask_tensor, attn_mask_type = _merge_masks( + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + query=query, + key=key, + add_bias_kv=bias_k is not None and bias_v is not None, + add_zero_attn=add_zero_attn, + is_causal=is_causal, + maybe_cudnn_style_mask=maybe_cudnn_style_mask, + num_heads=num_heads, + ) op = builtin.MultiHeadAttn( num_heads=num_heads, @@ -2116,16 +2457,25 @@ def multi_head_attention( training=training, input_order=0, seed=_get_global_rng_seed(), - bias=bias, - attn_mask=attn_mask, - enable_qproj=enable_qproj, - enable_kproj=enable_kproj, - enable_vproj=enable_vproj, - enable_oproj=enable_oproj, + attn_mask_type=attn_mask_type, + add_zero_attn=add_zero_attn, + embeding_size=embed_dim, + k_size=k_size, + v_size=v_size, + qproj_size=qproj_size, + kproj_size=kproj_size, + vproj_size=vproj_size, + oproj_size=oproj_size, + qbias=qbias, + kbias=kbias, + vbias=vbias, + obias=obias, + need_weights=need_weights, + tensor_combination_type="none", ) out, reserveSpace = apply(op, query, key, value, io_weight_bias) - return out + return out, None from .loss import * # isort:skip diff --git a/imperative/python/megengine/module/multiheadattn.py b/imperative/python/megengine/module/multiheadattn.py index ad80b7d62e1310578d7fd32a58aa032c1729a5cb..cc26620a874682c7bbbec9dd1d32cf1250771718 100644 --- a/imperative/python/megengine/module/multiheadattn.py +++ b/imperative/python/megengine/module/multiheadattn.py @@ -9,7 +9,7 @@ from megengine import Parameter from ..device import get_cudnn_version, is_cuda_available from ..functional.nn import multi_head_attention from ..tensor import Tensor -from .init import ones_, zeros_ +from .init import ones_, xavier_uniform_, zeros_ from .module import Module @@ -24,19 +24,37 @@ class MultiHeadAttention(Module): where :math:`h_i=W_{V,i}V \text{Softmax}\Big( \text{smScaler} \cdot K^TW^T_{K,i}W_{Q,i}q \Big),\text{for }i\text{ = 0 ... nHeads-1}`. Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported, and if the cudnn version >=8.6.0, the calculation results are completely correct; If the cudnn version >=8.0.4 but <8.6.0, if there is a bias, only the dbias result calculated from the backward is incorrect. If there is no bias, the forward and backward calculations are correct; If the cudnn version is less than 8.0.4, this operator is not supported. + + When the following conditions are met, you can go to the cudnn backend: + + - ``cudnn version`` greater than or equal to 8.0.4 and ``bias`` is ``False`` and ``training`` is ``False`` + - ``cudnn version`` greater than or equal to 8.6.0 + - ``add_bias_kv`` is ``False`` + - ``add_zero_attn`` is ``False`` + - ``need_weights`` is ``False`` + - ``average_attn_weights`` is ``False`` + - ``maybe_cudnn_style_mask`` is ``True`` if support else ``False`` + - ``attn_mask`` and ``key_padding_mask`` is cudnn style mask, i.e. the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`. + - The shape of attn_mask is :math:`(2, L)`, where :math:`(0, :)` elements specify the start index, :math:`(1, :)` elements specify the end index, the start index is inclusive, and the end index is not exclusive. The start index (i.e. elements in `attn_mask[0, x]`) must be less than the corresponding end index (i.e. elements in `attn_mask[1, x]`). The end index must be less than or equal to :math:`S`, where :math:`S` is the source sequence length, :math:`L` is the target sequence length. + - The shape of key_padding_mask is :math:`(2, N)`, where :math:`(0, :)` elements specify the target sequence padding in cudnn style mask and the element must equal to or less than :math:`L`, :math:`(1, :)` elements specify the source sequence padding in cudnn style mask and the element must equal to or less than :math:`S`, where :math:`S` is the source sequence length, :math:`L` is the target sequence length. + - ``qbias``, ``kbias``, ``vbias`` and ``obias`` are equal + Args: embed_dim: Total dimension of the model. num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). - dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). + attn_dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). + out_dropout: Dropout probability on ``output``. Default: ``0.0`` (no dropout). bias: If specified, adds bias to input / output projection layers. Default: ``True``. + add_bias_kv: If specified, adds bias to the key and value sequences at sequence dim. Default: ``False``. + Different from kbias and vbias, bias_kv here is not kbias and vbias in the linear layer, and bias_kv here will be added to the K and V at sequence dimensions, where K and V are the matrices of key and value after projection, and K and V will be used to calculate the attention matrix. + Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. + add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences. + Default: ``False``. + Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). - enable_qproj: enable query weight projection. Default: ``True``. - enable_kproj: enable key weight projection. Default: ``True``. - enable_vproj: enable value weight projection. Default: ``True``. - enable_oproj: enable output weight projection. Default: ``True``. Examples:: >>> import numpy as np @@ -44,7 +62,7 @@ class MultiHeadAttention(Module): >>> x = Tensor(np.arange(batch_size * seq_len * embed_dim).astype(np.float32).reshape(batch_size, seq_len, embed_dim)) >>> multihead_attn = M.MultiHeadAttention(embed_dim, num_heads) >>> if is_cuda_available() and get_cudnn_version() >= 8004: - ... out = multihead_attn(x, x, x) + ... out = multihead_attn(x, x, x)[0] ... out.numpy().shape ... else: ... print(np.zeros((2,4,4)).shape) @@ -57,84 +75,143 @@ class MultiHeadAttention(Module): num_heads, attn_dropout=0.0, out_dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, kdim=None, vdim=None, - bias=True, - enable_qproj=True, - enable_kproj=True, - enable_vproj=True, - enable_oproj=True, **kwargs ): super().__init__(**kwargs) self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + self.add_bias_kv = add_bias_kv + self.add_zero_attn = add_zero_attn self.num_heads = num_heads self.attn_dropout = attn_dropout self.out_dropout = out_dropout self.head_dim = embed_dim // num_heads + self.unsupport_reason = " The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation." assert ( self.head_dim * num_heads == self.embed_dim ), "embed_dim must be divisible by num_heads" - assert ( - self._qkv_same_embed_dim - ), "it does not support the case where q, k, and v are different." + assert add_bias_kv == False, ( + "add_bias_kv should be set to False, and configuration of this parameter is not supported now." + + self.unsupport_reason + ) + assert add_zero_attn == False, ( + "add_zero_attn should be set to False, and configuration of this parameter is not supported now." + + self.unsupport_reason + ) + self.bias = bias + self.weight_bias_len = ( + self.embed_dim + self.kdim + self.vdim + self.embed_dim + ) * self.embed_dim + (4 * self.embed_dim if self.bias else 0) - self.enable_qproj = enable_qproj - self.enable_kproj = enable_kproj - self.enable_vproj = enable_vproj - self.enable_oproj = enable_oproj - self.nproj = enable_qproj + enable_kproj + enable_vproj + enable_oproj + self.io_weight_bias = Parameter( + np.empty((1, self.weight_bias_len), dtype="float32") + ) + self.bias_k = ( + Parameter(np.empty((1, 1, embed_dim), dtype="float32")) + if self.add_bias_kv + else None + ) + self.bias_v = ( + Parameter(np.empty((1, 1, embed_dim), dtype="float32")) + if self.add_bias_kv + else None + ) - if self.bias: - io_weight = np.ones((embed_dim, self.nproj * embed_dim)) - io_bias = np.zeros((1, self.nproj * embed_dim)) - self.io_weight_bias = Parameter( - np.concatenate((io_weight, io_bias), axis=0), dtype="float32" - ) - else: - self.io_weight_bias = Parameter( - np.ones((self.nproj * embed_dim, embed_dim), dtype="float32") - ) self.reset_parameters() def reset_parameters(self): self.attn_dropout = 0.0 self.out_dropout = 0.0 + xavier_uniform_(self.io_weight_bias) if self.bias: - io_weight = np.ones((self.embed_dim, self.nproj * self.embed_dim)) - io_bias = np.zeros((1, self.nproj * self.embed_dim)) - self.io_weight_bias._reset(np.concatenate((io_weight, io_bias), axis=0)) + weight_len = ( + self.embed_dim + self.kdim + self.vdim + self.embed_dim + ) * self.embed_dim + self.io_weight_bias[0, weight_len:,] = 0 + + if self.add_bias_kv: + xavier_uniform_(self.bias_k) else: - ones_(self.io_weight_bias) + self.bias_k = None + if self.add_bias_kv: + xavier_uniform_(self.bias_v) + else: + self.bias_v = None def forward( - self, query, key, value, attn_mask: bool = True, + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + need_weights: bool = False, + average_attn_weights: bool = False, + is_causal: bool = False, + maybe_cudnn_style_mask: bool = False, ): r""" Args: - query: Query embeddings of shape :math:`(N, L, E_q)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, - and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against - key-value pairs to produce the output. See "Attention Is All You Need" for more details. - key: Key embeddings of shape :math:`(N, S, E_k)`, where :math:`N` is the batch size, :math:`S` is the source sequence length, and - :math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details. - value: Value embeddings of shape :math:`(N, S, E_v)`, where :math:`N` is the batch size, :math:`S` is the source sequence length, and - :math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details. - attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape - :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, - :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be - broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. - + query: Query embeddings of shape :math:`(N, L, E_q)`, + where :math:`N` is the batch size, :math:`L` is the target sequence length, and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against key-value pairs to produce the output. See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(N, S, E_k)`, + where :math:`N` is the batch size, :math:`S` is the source sequence length, and :math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(N, S, E_v)`, + where :math:`N` is the batch size, :math:`S` is the source sequence length, and :math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` to ignore for the purpose of + attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + Note: Should be set to None, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + Note: User-defined mask not supported now, only support no mask or default mask, where the upper right triangle is all -inf, and the diagonal and lower left triangle are all 0. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. + need_weights: indicates whether to return the attention weight, which is the output result of softmax. Default: `True` + Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. + is_causal: If specified, applies a causal mask as attention mask. Default: ``False`` + Warning: ``is_causal`` provides a hint that ``attn_mask`` is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. + maybe_cudnn_style_mask: if specified, applies a cudnn style mask as attention mask. Default: ``False`` + Note: In the cudnn style, the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`. + Warning: like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is a cudnn style mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. In addition, if the ``_merge_masks`` function returns ``merge_type=cudnn_style_mask``, please ensure that other conditions are correct so that it can run the implementation of cudnn, otherwise an error will be reported. + Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that the underlying implementation only accepts two types of mask type, namely "no_mask" and "default_mask", and we may try to loosen this option after submitting the commit that users can pass in custom attention mask tensors. Outputs: - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N * \text{num\_heads}, L, S)`. + Note: Now only None will be returned. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. """ + assert key_padding_mask is None, ( + "key_padding_mask should be None, and configuration of this parameter is not supported now." + + self.unsupport_reason + ) + assert need_weights == False, ( + "need_weights should be set to False, and configuration of this parameter is not supported now." + + self.unsupport_reason + ) + assert average_attn_weights == False, ( + "average_attn_weights should be set to False, and configuration of this parameter is not supported now." + + self.unsupport_reason + ) + assert maybe_cudnn_style_mask == False, ( + "maybe_cudnn_style_mask should be set to False, and configuration of this parameter is not supported now." + + self.unsupport_reason + ) return multi_head_attention( query, @@ -145,13 +222,24 @@ class MultiHeadAttention(Module): self.attn_dropout, self.out_dropout, self.io_weight_bias, - self.bias, + qproj_size=self.embed_dim, + kproj_size=self.embed_dim, + vproj_size=self.embed_dim, + oproj_size=self.embed_dim, + qbias=self.bias, + kbias=self.bias, + vbias=self.bias, + obias=self.bias, + bias_k=self.bias_k, + bias_v=self.bias_v, + add_zero_attn=self.add_zero_attn, training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, attn_mask=attn_mask, - enable_qproj=self.enable_qproj, - enable_kproj=self.enable_kproj, - enable_vproj=self.enable_vproj, - enable_oproj=self.enable_oproj, + average_attn_weights=average_attn_weights, + is_causal=is_causal, + maybe_cudnn_style_mask=maybe_cudnn_style_mask, ) def _module_info_string(self) -> str: diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp index 19e012956c39392b483002b65d624df536ba61ed..ca8cd3327255536083384910c94672e6ed315d8e 100644 --- a/imperative/src/impl/ops/rng.cpp +++ b/imperative/src/impl/ops/rng.cpp @@ -296,11 +296,19 @@ struct OpMeth { handle_seed == opdef.seed, "inconsistent multiheadattn seed: dropout op: %lu handle: %lu", handle_seed, opdef.seed); - return {opdef.num_heads, opdef.sm_scaler, opdef.input_order, - opdef.reslink, opdef.training, opdef.bias, - opdef.attn_mask, opdef.enable_qproj, opdef.enable_kproj, - opdef.enable_vproj, opdef.enable_oproj, handle_seed, - opdef.attn_prob, opdef.out_prob}; + + return {opdef.num_heads, opdef.embeding_size, + opdef.k_size, opdef.v_size, + opdef.qproj_size, opdef.kproj_size, + opdef.vproj_size, opdef.oproj_size, + opdef.qbias, opdef.kbias, + opdef.vbias, opdef.obias, + opdef.sm_scaler, opdef.input_order, + opdef.attn_mask_type, opdef.tensor_combination_type, + opdef.add_zero_attn, opdef.need_weights, + opdef.reslink, opdef.training, + handle_seed, opdef.attn_prob, + opdef.out_prob}; } }; diff --git a/imperative/tablegen/generated/enum_macro.h b/imperative/tablegen/generated/enum_macro.h index 56b71baa3a6826113d11c8676e81b3f6dbdcc1eb..b3fbfd2dc4938742fc1f10e9f2dfbeff39b0b2c1 100644 --- a/imperative/tablegen/generated/enum_macro.h +++ b/imperative/tablegen/generated/enum_macro.h @@ -20,6 +20,8 @@ cb(::megdnn::param::CvtColor::Mode); \ cb(::megdnn::param::Elemwise::Mode); \ cb(::megdnn::param::ElemwiseMultiType::Mode); \ + cb(::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE); \ + cb(::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE); \ cb(::megdnn::param::Padding::PaddingMode); \ cb(::megdnn::param::RNNCell::NonlineMode); \ cb(::megdnn::param::ROIAlignV0::Mode); \ diff --git a/imperative/tablegen/generated/hash.txt b/imperative/tablegen/generated/hash.txt index ab7a867652a7546e3a5a7787eb1999710dacef95..3e050f39a392df0992cd4063c7400124868837e3 100644 --- a/imperative/tablegen/generated/hash.txt +++ b/imperative/tablegen/generated/hash.txt @@ -1,7 +1,7 @@ -c5a5d1bd44473912f14cecee3df6409e ../../dnn/scripts/opr_param_defs.py -4ed3e8cbef0fa5f4d6824d8d55dec722 ../../src/core/include/megbrain/ir/ops.td -dc2d4ec8f4f5e203ce0a76bc20f62529 generated/opdef.h.inl -906957f12994d43c69248a6acfefa396 generated/opdef.cpp.inl -8817af8997ba0cc00048e71093755238 generated/opdef.py.inl -c43ae8b706e3f3658fe3cc0f60061981 generated/opdef.cpy.inl -71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h +0a8cd3cd50cadfaae0478ee70621618e ../../dnn/scripts/opr_param_defs.py +9e9636d66694dd7d5a7853247a5406f9 ../../src/core/include/megbrain/ir/ops.td +283dffd0e9cd28db5155c44cf4eda148 generated/opdef.h.inl +5e8d57337c3aec6f4b3b30ef9ba141f8 generated/opdef.cpp.inl +7f470236e4b5b00bdeaec321bc7187b5 generated/opdef.py.inl +003addd357423b880cd06410f5bf624b generated/opdef.cpy.inl +d468302f2d4b113913b76b5a181aae56 generated/enum_macro.h diff --git a/imperative/tablegen/generated/opdef.cpp.inl b/imperative/tablegen/generated/opdef.cpp.inl index bbde2e9b665fdb16eaa930a3819570c29d6fd314..b0f2d2ce49dfb85f131704664d35fd90d011f164 100644 --- a/imperative/tablegen/generated/opdef.cpp.inl +++ b/imperative/tablegen/generated/opdef.cpp.inl @@ -5200,28 +5200,54 @@ size_t MultiHeadAttn_hash_impl(const OpDef& def_) { mgb::hash_pair_combine( mgb::hash(op_.num_heads), mgb::hash_pair_combine( - mgb::hash(op_.sm_scaler), + mgb::hash(op_.embeding_size), mgb::hash_pair_combine( - mgb::hash(op_.input_order), + mgb::hash(op_.k_size), mgb::hash_pair_combine( - mgb::hash(op_.reslink), + mgb::hash(op_.v_size), mgb::hash_pair_combine( - mgb::hash(op_.training), + mgb::hash(op_.qproj_size), mgb::hash_pair_combine( - mgb::hash(op_.bias), + mgb::hash(op_.kproj_size), mgb::hash_pair_combine( - mgb::hash(op_.attn_mask), + mgb::hash(op_.vproj_size), mgb::hash_pair_combine( - mgb::hash(op_.enable_qproj), + mgb::hash(op_.oproj_size), mgb::hash_pair_combine( - mgb::hash(op_.enable_kproj), + mgb::hash(op_.qbias), mgb::hash_pair_combine( - mgb::hash(op_.enable_vproj), + mgb::hash(op_.kbias), mgb::hash_pair_combine( - mgb::hash(op_.enable_oproj), + mgb::hash(op_.vbias), mgb::hash_pair_combine( - mgb::hash(op_.attn_prob), - mgb::hash(op_.out_prob) + mgb::hash(op_.obias), + mgb::hash_pair_combine( + mgb::hash(op_.sm_scaler), + mgb::hash_pair_combine( + mgb::hash(op_.input_order), + mgb::hash_pair_combine( + mgb::hash(op_.attn_mask_type), + mgb::hash_pair_combine( + mgb::hash(op_.tensor_combination_type), + mgb::hash_pair_combine( + mgb::hash(op_.add_zero_attn), + mgb::hash_pair_combine( + mgb::hash(op_.need_weights), + mgb::hash_pair_combine( + mgb::hash(op_.reslink), + mgb::hash_pair_combine( + mgb::hash(op_.training), + mgb::hash_pair_combine( + mgb::hash(op_.attn_prob), + mgb::hash(op_.out_prob)) + ) + ) + ) + ) + ) + ) + ) + ) ) ) ) @@ -5242,22 +5268,63 @@ bool MultiHeadAttn_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { &&b_ = rhs_.cast_final_safe(); static_cast(a_); static_cast(b_); -return a_.handle == b_.handle && a_.num_heads == b_.num_heads && a_.sm_scaler == b_.sm_scaler && a_.input_order == b_.input_order && a_.reslink == b_.reslink && a_.training == b_.training && a_.bias == b_.bias && a_.attn_mask == b_.attn_mask && a_.enable_qproj == b_.enable_qproj && a_.enable_kproj == b_.enable_kproj && a_.enable_vproj == b_.enable_vproj && a_.enable_oproj == b_.enable_oproj && a_.attn_prob == b_.attn_prob && a_.out_prob == b_.out_prob;} +return a_.handle == b_.handle && a_.num_heads == b_.num_heads && a_.embeding_size == b_.embeding_size && a_.k_size == b_.k_size && a_.v_size == b_.v_size && a_.qproj_size == b_.qproj_size && a_.kproj_size == b_.kproj_size && a_.vproj_size == b_.vproj_size && a_.oproj_size == b_.oproj_size && a_.qbias == b_.qbias && a_.kbias == b_.kbias && a_.vbias == b_.vbias && a_.obias == b_.obias && a_.sm_scaler == b_.sm_scaler && a_.input_order == b_.input_order && a_.reslink == b_.reslink && a_.training == b_.training && a_.need_weights == b_.need_weights && a_.attn_mask_type == b_.attn_mask_type && a_.add_zero_attn == b_.add_zero_attn && a_.tensor_combination_type == b_.tensor_combination_type && a_.attn_prob == b_.attn_prob && a_.out_prob == b_.out_prob;} std::vector> MultiHeadAttn_props_impl(const OpDef& def_) { auto&& op_ = def_.cast_final_safe(); static_cast(op_); std::vector> props_; props_.emplace_back("num_heads", std::to_string(op_.num_heads)); + props_.emplace_back("embeding_size", std::to_string(op_.embeding_size)); + props_.emplace_back("k_size", std::to_string(op_.k_size)); + props_.emplace_back("v_size", std::to_string(op_.v_size)); + props_.emplace_back("qproj_size", std::to_string(op_.qproj_size)); + props_.emplace_back("kproj_size", std::to_string(op_.kproj_size)); + props_.emplace_back("vproj_size", std::to_string(op_.vproj_size)); + props_.emplace_back("oproj_size", std::to_string(op_.oproj_size)); + props_.emplace_back("qbias", std::to_string(op_.qbias)); + props_.emplace_back("kbias", std::to_string(op_.kbias)); + props_.emplace_back("vbias", std::to_string(op_.vbias)); + props_.emplace_back("obias", std::to_string(op_.obias)); props_.emplace_back("sm_scaler", std::to_string(op_.sm_scaler)); props_.emplace_back("input_order", std::to_string(op_.input_order)); + switch (op_.attn_mask_type){ + case MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK: + props_.emplace_back("attn_mask_type", "NO_MASK"); + break; + case MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK: + props_.emplace_back("attn_mask_type", "DEFAULT_MASK"); + break; + case MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK: + props_.emplace_back("attn_mask_type", "CUDNN_STYLE_MASK"); + break; + case MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK: + props_.emplace_back("attn_mask_type", "USER_DEFINED_MASK"); + break; + default: + props_.emplace_back("attn_mask_type", "INVALID"); + break; + } + switch (op_.tensor_combination_type){ + case MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE: + props_.emplace_back("tensor_combination_type", "NONE"); + break; + case MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK: + props_.emplace_back("tensor_combination_type", "ONLY_MASK"); + break; + case MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV: + props_.emplace_back("tensor_combination_type", "ONLY_BIASKV"); + break; + case MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL: + props_.emplace_back("tensor_combination_type", "ALL"); + break; + default: + props_.emplace_back("tensor_combination_type", "INVALID"); + break; + } + props_.emplace_back("need_weights", std::to_string(op_.need_weights)); + props_.emplace_back("add_zero_attn", std::to_string(op_.add_zero_attn)); props_.emplace_back("reslink", std::to_string(op_.reslink)); props_.emplace_back("training", std::to_string(op_.training)); - props_.emplace_back("bias", std::to_string(op_.bias)); - props_.emplace_back("attn_mask", std::to_string(op_.attn_mask)); - props_.emplace_back("enable_qproj", std::to_string(op_.enable_qproj)); - props_.emplace_back("enable_kproj", std::to_string(op_.enable_kproj)); - props_.emplace_back("enable_vproj", std::to_string(op_.enable_vproj)); - props_.emplace_back("enable_oproj", std::to_string(op_.enable_oproj)); props_.emplace_back("seed", std::to_string(op_.seed)); props_.emplace_back("attn_prob", std::to_string(op_.attn_prob)); props_.emplace_back("out_prob", std::to_string(op_.out_prob)); diff --git a/imperative/tablegen/generated/opdef.cpy.inl b/imperative/tablegen/generated/opdef.cpy.inl index 98a5af8f2005284d566ee74cae7b5f3eb9e3df83..3673cd6ad8db5cd1a5abd33c32710878b203f8ab 100644 --- a/imperative/tablegen/generated/opdef.cpy.inl +++ b/imperative/tablegen/generated/opdef.cpy.inl @@ -15043,6 +15043,176 @@ void _init_py_MeshIndexing(py::module m) { mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MeshIndexing::typeinfo(), &py_type).second); } +template<> struct EnumTrait { + static constexpr const char *name = "MultiHeadAttn.ATTN_MASK_TYPE"; + static constexpr std::underlying_type_t max = 4 - 1; +}; +template<> PyTypeObject* EnumWrapper::type = nullptr; + +template<> const char* +EnumWrapper::members[] = {"NO_MASK", "DEFAULT_MASK", "CUDNN_STYLE_MASK", "USER_DEFINED_MASK"}; + +template<> std::unordered_map +EnumWrapper::mem2value = {{normalize_enum("NO_MASK"), MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK}, {normalize_enum("DEFAULT_MASK"), MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK}, {normalize_enum("CUDNN_STYLE_MASK"), MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK}, {normalize_enum("USER_DEFINED_MASK"), MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK}}; +template<> PyObject* EnumWrapper::pyobj_insts[4] = {nullptr}; + +void _init_py_MultiHeadAttn_ATTN_MASK_TYPE(PyTypeObject& py_type) { + auto& e_type = EnumWrapper::type; + + static PyMethodDef tp_methods[] = { + {const_cast("dump"), (PyCFunction)EnumWrapper::py_dump, METH_NOARGS, NULL}, + {NULL} /* Sentinel */ + }; + + static PyType_Slot slots[] = { + {Py_tp_repr, (void*)EnumWrapper::py_repr}, + {Py_tp_richcompare, (void*)EnumWrapper::tp_richcompare}, + {Py_tp_methods, tp_methods}, + + {0, NULL} + }; + static PyType_Spec spec = { + // name + "megengine.core._imperative_rt.ops.MultiHeadAttn.ATTN_MASK_TYPE", + // basicsize + sizeof(EnumWrapper), + // itemsize + 0, + // flags + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, + // slots + slots + }; + e_type = reinterpret_cast(PyType_FromSpec(&spec)); + + mgb_assert( + e_type->tp_setattro( + reinterpret_cast(e_type), + py::cast("__name__").release().ptr(), + py::cast("ATTN_MASK_TYPE").release().ptr()) >= 0); + + mgb_assert( + e_type->tp_setattro( + reinterpret_cast(e_type), + py::cast("__module__").release().ptr(), + py::cast("megengine.core._imperative_rt.ops").release().ptr()) >= 0); + + mgb_assert( + e_type->tp_setattro( + reinterpret_cast(e_type), + py::cast("__qualname__").release().ptr(), + py::cast("MultiHeadAttn.ATTN_MASK_TYPE").release().ptr()) >= 0); +{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "NO_MASK", inst) >= 0); + EnumWrapper::pyobj_insts[0] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "DEFAULT_MASK", inst) >= 0); + EnumWrapper::pyobj_insts[1] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "CUDNN_STYLE_MASK", inst) >= 0); + EnumWrapper::pyobj_insts[2] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "USER_DEFINED_MASK", inst) >= 0); + EnumWrapper::pyobj_insts[3] = inst; +} + Py_INCREF(e_type); + mgb_assert(PyDict_SetItemString( + py_type.tp_dict, "ATTN_MASK_TYPE", reinterpret_cast(e_type)) >= 0); +} + +template<> struct EnumTrait { + static constexpr const char *name = "MultiHeadAttn.TENSOR_COMBINATION_TYPE"; + static constexpr std::underlying_type_t max = 4 - 1; +}; +template<> PyTypeObject* EnumWrapper::type = nullptr; + +template<> const char* +EnumWrapper::members[] = {"NONE", "ONLY_MASK", "ONLY_BIASKV", "ALL"}; + +template<> std::unordered_map +EnumWrapper::mem2value = {{normalize_enum("NONE"), MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE}, {normalize_enum("ONLY_MASK"), MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK}, {normalize_enum("ONLY_BIASKV"), MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV}, {normalize_enum("ALL"), MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL}}; +template<> PyObject* EnumWrapper::pyobj_insts[4] = {nullptr}; + +void _init_py_MultiHeadAttn_TENSOR_COMBINATION_TYPE(PyTypeObject& py_type) { + auto& e_type = EnumWrapper::type; + + static PyMethodDef tp_methods[] = { + {const_cast("dump"), (PyCFunction)EnumWrapper::py_dump, METH_NOARGS, NULL}, + {NULL} /* Sentinel */ + }; + + static PyType_Slot slots[] = { + {Py_tp_repr, (void*)EnumWrapper::py_repr}, + {Py_tp_richcompare, (void*)EnumWrapper::tp_richcompare}, + {Py_tp_methods, tp_methods}, + + {0, NULL} + }; + static PyType_Spec spec = { + // name + "megengine.core._imperative_rt.ops.MultiHeadAttn.TENSOR_COMBINATION_TYPE", + // basicsize + sizeof(EnumWrapper), + // itemsize + 0, + // flags + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, + // slots + slots + }; + e_type = reinterpret_cast(PyType_FromSpec(&spec)); + + mgb_assert( + e_type->tp_setattro( + reinterpret_cast(e_type), + py::cast("__name__").release().ptr(), + py::cast("TENSOR_COMBINATION_TYPE").release().ptr()) >= 0); + + mgb_assert( + e_type->tp_setattro( + reinterpret_cast(e_type), + py::cast("__module__").release().ptr(), + py::cast("megengine.core._imperative_rt.ops").release().ptr()) >= 0); + + mgb_assert( + e_type->tp_setattro( + reinterpret_cast(e_type), + py::cast("__qualname__").release().ptr(), + py::cast("MultiHeadAttn.TENSOR_COMBINATION_TYPE").release().ptr()) >= 0); +{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "NONE", inst) >= 0); + EnumWrapper::pyobj_insts[0] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ONLY_MASK", inst) >= 0); + EnumWrapper::pyobj_insts[1] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ONLY_BIASKV", inst) >= 0); + EnumWrapper::pyobj_insts[2] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ALL", inst) >= 0); + EnumWrapper::pyobj_insts[3] = inst; +} + Py_INCREF(e_type); + mgb_assert(PyDict_SetItemString( + py_type.tp_dict, "TENSOR_COMBINATION_TYPE", reinterpret_cast(e_type)) >= 0); +} + PyOpDefBegin(MultiHeadAttn) // { static PyGetSetDef py_getsetters[]; static PyMethodDef tp_methods[]; @@ -15053,16 +15223,25 @@ PyOpDefBegin(MultiHeadAttn) // { std::unordered_map state { {"num_heads", serialization::dump(opdef.num_heads)}, + {"embeding_size", serialization::dump(opdef.embeding_size)}, + {"k_size", serialization::dump(opdef.k_size)}, + {"v_size", serialization::dump(opdef.v_size)}, + {"qproj_size", serialization::dump(opdef.qproj_size)}, + {"kproj_size", serialization::dump(opdef.kproj_size)}, + {"vproj_size", serialization::dump(opdef.vproj_size)}, + {"oproj_size", serialization::dump(opdef.oproj_size)}, + {"qbias", serialization::dump(opdef.qbias)}, + {"kbias", serialization::dump(opdef.kbias)}, + {"vbias", serialization::dump(opdef.vbias)}, + {"obias", serialization::dump(opdef.obias)}, {"sm_scaler", serialization::dump(opdef.sm_scaler)}, {"input_order", serialization::dump(opdef.input_order)}, + {"attn_mask_type", serialization::dump(opdef.attn_mask_type)}, + {"tensor_combination_type", serialization::dump(opdef.tensor_combination_type)}, + {"need_weights", serialization::dump(opdef.need_weights)}, + {"add_zero_attn", serialization::dump(opdef.add_zero_attn)}, {"reslink", serialization::dump(opdef.reslink)}, {"training", serialization::dump(opdef.training)}, - {"bias", serialization::dump(opdef.bias)}, - {"attn_mask", serialization::dump(opdef.attn_mask)}, - {"enable_qproj", serialization::dump(opdef.enable_qproj)}, - {"enable_kproj", serialization::dump(opdef.enable_kproj)}, - {"enable_vproj", serialization::dump(opdef.enable_vproj)}, - {"enable_oproj", serialization::dump(opdef.enable_oproj)}, {"seed", serialization::dump(opdef.seed)}, {"attn_prob", serialization::dump(opdef.attn_prob)}, {"out_prob", serialization::dump(opdef.out_prob)}, @@ -15085,72 +15264,135 @@ PyOpDefBegin(MultiHeadAttn) // { } { - auto&& iter = state.find("sm_scaler"); + auto&& iter = state.find("embeding_size"); if (iter != state.end()) { - opdef.sm_scaler = serialization::load(iter->second); + opdef.embeding_size = serialization::load(iter->second); } } { - auto&& iter = state.find("input_order"); + auto&& iter = state.find("k_size"); if (iter != state.end()) { - opdef.input_order = serialization::load(iter->second); + opdef.k_size = serialization::load(iter->second); } } { - auto&& iter = state.find("reslink"); + auto&& iter = state.find("v_size"); if (iter != state.end()) { - opdef.reslink = serialization::load(iter->second); + opdef.v_size = serialization::load(iter->second); } } { - auto&& iter = state.find("training"); + auto&& iter = state.find("qproj_size"); if (iter != state.end()) { - opdef.training = serialization::load(iter->second); + opdef.qproj_size = serialization::load(iter->second); } } { - auto&& iter = state.find("bias"); + auto&& iter = state.find("kproj_size"); if (iter != state.end()) { - opdef.bias = serialization::load(iter->second); + opdef.kproj_size = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("vproj_size"); + if (iter != state.end()) { + opdef.vproj_size = serialization::load(iter->second); } } { - auto&& iter = state.find("attn_mask"); + auto&& iter = state.find("oproj_size"); if (iter != state.end()) { - opdef.attn_mask = serialization::load(iter->second); + opdef.oproj_size = serialization::load(iter->second); } } { - auto&& iter = state.find("enable_qproj"); + auto&& iter = state.find("qbias"); if (iter != state.end()) { - opdef.enable_qproj = serialization::load(iter->second); + opdef.qbias = serialization::load(iter->second); } } { - auto&& iter = state.find("enable_kproj"); + auto&& iter = state.find("kbias"); if (iter != state.end()) { - opdef.enable_kproj = serialization::load(iter->second); + opdef.kbias = serialization::load(iter->second); } } { - auto&& iter = state.find("enable_vproj"); + auto&& iter = state.find("vbias"); if (iter != state.end()) { - opdef.enable_vproj = serialization::load(iter->second); + opdef.vbias = serialization::load(iter->second); } } { - auto&& iter = state.find("enable_oproj"); + auto&& iter = state.find("obias"); if (iter != state.end()) { - opdef.enable_oproj = serialization::load(iter->second); + opdef.obias = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("sm_scaler"); + if (iter != state.end()) { + opdef.sm_scaler = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("input_order"); + if (iter != state.end()) { + opdef.input_order = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("attn_mask_type"); + if (iter != state.end()) { + opdef.attn_mask_type = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("tensor_combination_type"); + if (iter != state.end()) { + opdef.tensor_combination_type = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("need_weights"); + if (iter != state.end()) { + opdef.need_weights = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("add_zero_attn"); + if (iter != state.end()) { + opdef.add_zero_attn = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("reslink"); + if (iter != state.end()) { + opdef.reslink = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("training"); + if (iter != state.end()) { + opdef.training = serialization::load(iter->second); } } @@ -15190,9 +15432,9 @@ PyOpDefBegin(MultiHeadAttn) // { PyOpDefEnd(MultiHeadAttn) int PyOp(MultiHeadAttn)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { - static const char* kwlist[] = {"num_heads", "sm_scaler", "input_order", "reslink", "training", "bias", "attn_mask", "enable_qproj", "enable_kproj", "enable_vproj", "enable_oproj", "seed", "attn_prob", "out_prob", "handle", "scope", NULL}; - PyObject *num_heads = NULL, *sm_scaler = NULL, *input_order = NULL, *reslink = NULL, *training = NULL, *bias = NULL, *attn_mask = NULL, *enable_qproj = NULL, *enable_kproj = NULL, *enable_vproj = NULL, *enable_oproj = NULL, *seed = NULL, *attn_prob = NULL, *out_prob = NULL, *handle = NULL, *scope = NULL; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOOOOOOOOOOOOO", const_cast(kwlist), &num_heads, &sm_scaler, &input_order, &reslink, &training, &bias, &attn_mask, &enable_qproj, &enable_kproj, &enable_vproj, &enable_oproj, &seed, &attn_prob, &out_prob, &handle, &scope)) + static const char* kwlist[] = {"num_heads", "embeding_size", "k_size", "v_size", "qproj_size", "kproj_size", "vproj_size", "oproj_size", "qbias", "kbias", "vbias", "obias", "sm_scaler", "input_order", "attn_mask_type", "tensor_combination_type", "need_weights", "add_zero_attn", "reslink", "training", "seed", "attn_prob", "out_prob", "handle", "scope", NULL}; + PyObject *num_heads = NULL, *embeding_size = NULL, *k_size = NULL, *v_size = NULL, *qproj_size = NULL, *kproj_size = NULL, *vproj_size = NULL, *oproj_size = NULL, *qbias = NULL, *kbias = NULL, *vbias = NULL, *obias = NULL, *sm_scaler = NULL, *input_order = NULL, *attn_mask_type = NULL, *tensor_combination_type = NULL, *need_weights = NULL, *add_zero_attn = NULL, *reslink = NULL, *training = NULL, *seed = NULL, *attn_prob = NULL, *out_prob = NULL, *handle = NULL, *scope = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOOOOOOOOOOOOOOOOOOOOOO", const_cast(kwlist), &num_heads, &embeding_size, &k_size, &v_size, &qproj_size, &kproj_size, &vproj_size, &oproj_size, &qbias, &kbias, &vbias, &obias, &sm_scaler, &input_order, &attn_mask_type, &tensor_combination_type, &need_weights, &add_zero_attn, &reslink, &training, &seed, &attn_prob, &out_prob, &handle, &scope)) return -1; if (num_heads) { @@ -15204,93 +15446,174 @@ int PyOp(MultiHeadAttn)::py_init(PyObject *self, PyObject *args, PyObject *kwds) } CATCH_ALL(-1) } - if (sm_scaler) { + if (embeding_size) { try { // TODO: remove this guard which is used for pybind11 implicit conversion py::detail::loader_life_support guard{}; - reinterpret_cast(self)->inst().sm_scaler = - py::cast(py::handle(sm_scaler)); + reinterpret_cast(self)->inst().embeding_size = + py::cast(py::handle(embeding_size)); } CATCH_ALL(-1) } - if (input_order) { + if (k_size) { try { // TODO: remove this guard which is used for pybind11 implicit conversion py::detail::loader_life_support guard{}; - reinterpret_cast(self)->inst().input_order = - py::cast(py::handle(input_order)); + reinterpret_cast(self)->inst().k_size = + py::cast(py::handle(k_size)); } CATCH_ALL(-1) } - if (reslink) { + if (v_size) { try { // TODO: remove this guard which is used for pybind11 implicit conversion py::detail::loader_life_support guard{}; - reinterpret_cast(self)->inst().reslink = - py::cast(py::handle(reslink)); + reinterpret_cast(self)->inst().v_size = + py::cast(py::handle(v_size)); } CATCH_ALL(-1) } - if (training) { + if (qproj_size) { try { // TODO: remove this guard which is used for pybind11 implicit conversion py::detail::loader_life_support guard{}; - reinterpret_cast(self)->inst().training = - py::cast(py::handle(training)); + reinterpret_cast(self)->inst().qproj_size = + py::cast(py::handle(qproj_size)); } CATCH_ALL(-1) } - if (bias) { + if (kproj_size) { try { // TODO: remove this guard which is used for pybind11 implicit conversion py::detail::loader_life_support guard{}; - reinterpret_cast(self)->inst().bias = - py::cast(py::handle(bias)); + reinterpret_cast(self)->inst().kproj_size = + py::cast(py::handle(kproj_size)); } CATCH_ALL(-1) } - if (attn_mask) { + if (vproj_size) { try { // TODO: remove this guard which is used for pybind11 implicit conversion py::detail::loader_life_support guard{}; - reinterpret_cast(self)->inst().attn_mask = - py::cast(py::handle(attn_mask)); + reinterpret_cast(self)->inst().vproj_size = + py::cast(py::handle(vproj_size)); } CATCH_ALL(-1) } - if (enable_qproj) { + if (oproj_size) { try { // TODO: remove this guard which is used for pybind11 implicit conversion py::detail::loader_life_support guard{}; - reinterpret_cast(self)->inst().enable_qproj = - py::cast(py::handle(enable_qproj)); + reinterpret_cast(self)->inst().oproj_size = + py::cast(py::handle(oproj_size)); } CATCH_ALL(-1) } - if (enable_kproj) { + if (qbias) { try { // TODO: remove this guard which is used for pybind11 implicit conversion py::detail::loader_life_support guard{}; - reinterpret_cast(self)->inst().enable_kproj = - py::cast(py::handle(enable_kproj)); + reinterpret_cast(self)->inst().qbias = + py::cast(py::handle(qbias)); } CATCH_ALL(-1) } - if (enable_vproj) { + if (kbias) { try { // TODO: remove this guard which is used for pybind11 implicit conversion py::detail::loader_life_support guard{}; - reinterpret_cast(self)->inst().enable_vproj = - py::cast(py::handle(enable_vproj)); + reinterpret_cast(self)->inst().kbias = + py::cast(py::handle(kbias)); } CATCH_ALL(-1) } - if (enable_oproj) { + if (vbias) { try { // TODO: remove this guard which is used for pybind11 implicit conversion py::detail::loader_life_support guard{}; - reinterpret_cast(self)->inst().enable_oproj = - py::cast(py::handle(enable_oproj)); + reinterpret_cast(self)->inst().vbias = + py::cast(py::handle(vbias)); + } CATCH_ALL(-1) + } + + if (obias) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().obias = + py::cast(py::handle(obias)); + } CATCH_ALL(-1) + } + + if (sm_scaler) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().sm_scaler = + py::cast(py::handle(sm_scaler)); + } CATCH_ALL(-1) + } + + if (input_order) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().input_order = + py::cast(py::handle(input_order)); + } CATCH_ALL(-1) + } + + if (attn_mask_type) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().attn_mask_type = + py::cast(py::handle(attn_mask_type)); + } CATCH_ALL(-1) + } + + if (tensor_combination_type) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().tensor_combination_type = + py::cast(py::handle(tensor_combination_type)); + } CATCH_ALL(-1) + } + + if (need_weights) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().need_weights = + py::cast(py::handle(need_weights)); + } CATCH_ALL(-1) + } + + if (add_zero_attn) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().add_zero_attn = + py::cast(py::handle(add_zero_attn)); + } CATCH_ALL(-1) + } + + if (reslink) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().reslink = + py::cast(py::handle(reslink)); + } CATCH_ALL(-1) + } + + if (training) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().training = + py::cast(py::handle(training)); } CATCH_ALL(-1) } @@ -15342,16 +15665,25 @@ int PyOp(MultiHeadAttn)::py_init(PyObject *self, PyObject *args, PyObject *kwds) PyGetSetDef PyOp(MultiHeadAttn)::py_getsetters[] = { {const_cast("num_heads"), py_get_generic(MultiHeadAttn, num_heads), py_set_generic(MultiHeadAttn, num_heads), const_cast("num_heads"), NULL}, + {const_cast("embeding_size"), py_get_generic(MultiHeadAttn, embeding_size), py_set_generic(MultiHeadAttn, embeding_size), const_cast("embeding_size"), NULL}, + {const_cast("k_size"), py_get_generic(MultiHeadAttn, k_size), py_set_generic(MultiHeadAttn, k_size), const_cast("k_size"), NULL}, + {const_cast("v_size"), py_get_generic(MultiHeadAttn, v_size), py_set_generic(MultiHeadAttn, v_size), const_cast("v_size"), NULL}, + {const_cast("qproj_size"), py_get_generic(MultiHeadAttn, qproj_size), py_set_generic(MultiHeadAttn, qproj_size), const_cast("qproj_size"), NULL}, + {const_cast("kproj_size"), py_get_generic(MultiHeadAttn, kproj_size), py_set_generic(MultiHeadAttn, kproj_size), const_cast("kproj_size"), NULL}, + {const_cast("vproj_size"), py_get_generic(MultiHeadAttn, vproj_size), py_set_generic(MultiHeadAttn, vproj_size), const_cast("vproj_size"), NULL}, + {const_cast("oproj_size"), py_get_generic(MultiHeadAttn, oproj_size), py_set_generic(MultiHeadAttn, oproj_size), const_cast("oproj_size"), NULL}, + {const_cast("qbias"), py_get_generic(MultiHeadAttn, qbias), py_set_generic(MultiHeadAttn, qbias), const_cast("qbias"), NULL}, + {const_cast("kbias"), py_get_generic(MultiHeadAttn, kbias), py_set_generic(MultiHeadAttn, kbias), const_cast("kbias"), NULL}, + {const_cast("vbias"), py_get_generic(MultiHeadAttn, vbias), py_set_generic(MultiHeadAttn, vbias), const_cast("vbias"), NULL}, + {const_cast("obias"), py_get_generic(MultiHeadAttn, obias), py_set_generic(MultiHeadAttn, obias), const_cast("obias"), NULL}, {const_cast("sm_scaler"), py_get_generic(MultiHeadAttn, sm_scaler), py_set_generic(MultiHeadAttn, sm_scaler), const_cast("sm_scaler"), NULL}, {const_cast("input_order"), py_get_generic(MultiHeadAttn, input_order), py_set_generic(MultiHeadAttn, input_order), const_cast("input_order"), NULL}, + {const_cast("attn_mask_type"), py_get_generic(MultiHeadAttn, attn_mask_type), py_set_generic(MultiHeadAttn, attn_mask_type), const_cast("attn_mask_type"), NULL}, + {const_cast("tensor_combination_type"), py_get_generic(MultiHeadAttn, tensor_combination_type), py_set_generic(MultiHeadAttn, tensor_combination_type), const_cast("tensor_combination_type"), NULL}, + {const_cast("need_weights"), py_get_generic(MultiHeadAttn, need_weights), py_set_generic(MultiHeadAttn, need_weights), const_cast("need_weights"), NULL}, + {const_cast("add_zero_attn"), py_get_generic(MultiHeadAttn, add_zero_attn), py_set_generic(MultiHeadAttn, add_zero_attn), const_cast("add_zero_attn"), NULL}, {const_cast("reslink"), py_get_generic(MultiHeadAttn, reslink), py_set_generic(MultiHeadAttn, reslink), const_cast("reslink"), NULL}, {const_cast("training"), py_get_generic(MultiHeadAttn, training), py_set_generic(MultiHeadAttn, training), const_cast("training"), NULL}, - {const_cast("bias"), py_get_generic(MultiHeadAttn, bias), py_set_generic(MultiHeadAttn, bias), const_cast("bias"), NULL}, - {const_cast("attn_mask"), py_get_generic(MultiHeadAttn, attn_mask), py_set_generic(MultiHeadAttn, attn_mask), const_cast("attn_mask"), NULL}, - {const_cast("enable_qproj"), py_get_generic(MultiHeadAttn, enable_qproj), py_set_generic(MultiHeadAttn, enable_qproj), const_cast("enable_qproj"), NULL}, - {const_cast("enable_kproj"), py_get_generic(MultiHeadAttn, enable_kproj), py_set_generic(MultiHeadAttn, enable_kproj), const_cast("enable_kproj"), NULL}, - {const_cast("enable_vproj"), py_get_generic(MultiHeadAttn, enable_vproj), py_set_generic(MultiHeadAttn, enable_vproj), const_cast("enable_vproj"), NULL}, - {const_cast("enable_oproj"), py_get_generic(MultiHeadAttn, enable_oproj), py_set_generic(MultiHeadAttn, enable_oproj), const_cast("enable_oproj"), NULL}, {const_cast("seed"), py_get_generic(MultiHeadAttn, seed), py_set_generic(MultiHeadAttn, seed), const_cast("seed"), NULL}, {const_cast("attn_prob"), py_get_generic(MultiHeadAttn, attn_prob), py_set_generic(MultiHeadAttn, attn_prob), const_cast("attn_prob"), NULL}, {const_cast("out_prob"), py_get_generic(MultiHeadAttn, out_prob), py_set_generic(MultiHeadAttn, out_prob), const_cast("out_prob"), NULL}, @@ -15376,7 +15708,7 @@ PyMethodDef PyOp(MultiHeadAttn)::py_init_methoddef = { "__init__", (PyCFunction)PyOp(MultiHeadAttn)::py_init_proxy, METH_VARARGS | METH_KEYWORDS, - "__init__(self, num_heads: int = ..., sm_scaler: float = ..., input_order: int = ..., reslink: bool = ..., training: bool = ..., bias: bool = ..., attn_mask: bool = ..., enable_qproj: bool = ..., enable_kproj: bool = ..., enable_vproj: bool = ..., enable_oproj: bool = ..., seed: int = ..., attn_prob: float = ..., out_prob: float = ..., handle: int = ...) -> None\n" + "__init__(self, num_heads: int = ..., embeding_size: int = ..., k_size: int = ..., v_size: int = ..., qproj_size: int = ..., kproj_size: int = ..., vproj_size: int = ..., oproj_size: int = ..., qbias: bool = ..., kbias: bool = ..., vbias: bool = ..., obias: bool = ..., sm_scaler: float = ..., input_order: int = ..., attn_mask_type: Union[str, ATTN_MASK_TYPE] = ..., tensor_combination_type: Union[str, TENSOR_COMBINATION_TYPE] = ..., need_weights: bool = ..., add_zero_attn: bool = ..., reslink: bool = ..., training: bool = ..., seed: int = ..., attn_prob: float = ..., out_prob: float = ..., handle: int = ...) -> None\n" }; void _init_py_MultiHeadAttn(py::module m) { @@ -15398,7 +15730,9 @@ void _init_py_MultiHeadAttn(py::module m) { PyObject* descr = PyDescr_NewMethod(&PyOpType(MultiHeadAttn), &PyOp(MultiHeadAttn)::py_init_methoddef); PyDict_SetItemString(py_type.tp_dict, "__init__", descr); mgb_assert(PyType_Ready(&py_type) >= 0); - + _init_py_MultiHeadAttn_ATTN_MASK_TYPE(py_type); + _init_py_MultiHeadAttn_TENSOR_COMBINATION_TYPE(py_type); + PyType_Modified(&py_type); m.add_object("MultiHeadAttn", reinterpret_cast(&py_type)); mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MultiHeadAttn::typeinfo(), &py_type).second); diff --git a/imperative/tablegen/generated/opdef.h.inl b/imperative/tablegen/generated/opdef.h.inl index 5f774aaae359105370fe43db18fe10512118414b..7bb498e446818c3cd5898176354582b869ab876e 100644 --- a/imperative/tablegen/generated/opdef.h.inl +++ b/imperative/tablegen/generated/opdef.h.inl @@ -1398,26 +1398,37 @@ class MultiHeadAttn : public OpDefImplBase { MGB_DYN_TYPE_OBJ_FINAL_DECL; public: + using ATTN_MASK_TYPE = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE; + using TENSOR_COMBINATION_TYPE = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE; uint32_t num_heads = 1; + uint32_t embeding_size = 0; + uint32_t k_size = 0; + uint32_t v_size = 0; + uint32_t qproj_size = 0; + uint32_t kproj_size = 0; + uint32_t vproj_size = 0; + uint32_t oproj_size = 0; + bool qbias = false; + bool kbias = false; + bool vbias = false; + bool obias = false; float sm_scaler = 1.f; uint32_t input_order = 0; + ATTN_MASK_TYPE attn_mask_type = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK; + TENSOR_COMBINATION_TYPE tensor_combination_type = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE; + bool need_weights = false; + bool add_zero_attn = false; bool reslink = false; bool training = true; - bool bias = false; - bool attn_mask = false; - bool enable_qproj = true; - bool enable_kproj = true; - bool enable_vproj = true; - bool enable_oproj = true; uint64_t seed = 0; float attn_prob = 0.f; float out_prob = 0.f; size_t handle; MultiHeadAttn() = default; - MultiHeadAttn(uint32_t num_heads_, float sm_scaler_, uint32_t input_order_, bool reslink_, bool training_, bool bias_, bool attn_mask_, bool enable_qproj_, bool enable_kproj_, bool enable_vproj_, bool enable_oproj_, uint64_t seed_, float attn_prob_, float out_prob_, size_t handle_, std::string scope_ = {}): num_heads(num_heads_), sm_scaler(sm_scaler_), input_order(input_order_), reslink(reslink_), training(training_), bias(bias_), attn_mask(attn_mask_), enable_qproj(enable_qproj_), enable_kproj(enable_kproj_), enable_vproj(enable_vproj_), enable_oproj(enable_oproj_), seed(seed_), attn_prob(attn_prob_), out_prob(out_prob_), handle(handle_) { set_scope(scope_); } - MultiHeadAttn(::megdnn::param::MultiHeadAttn packed_param_0, size_t handle_): num_heads(packed_param_0.num_heads), sm_scaler(packed_param_0.sm_scaler), input_order(packed_param_0.input_order), reslink(packed_param_0.reslink), training(packed_param_0.training), bias(packed_param_0.bias), attn_mask(packed_param_0.attn_mask), enable_qproj(packed_param_0.enable_qproj), enable_kproj(packed_param_0.enable_kproj), enable_vproj(packed_param_0.enable_vproj), enable_oproj(packed_param_0.enable_oproj), seed(packed_param_0.seed), attn_prob(packed_param_0.attn_prob), out_prob(packed_param_0.out_prob), handle(handle_) {} + MultiHeadAttn(uint32_t num_heads_, uint32_t embeding_size_, uint32_t k_size_, uint32_t v_size_, uint32_t qproj_size_, uint32_t kproj_size_, uint32_t vproj_size_, uint32_t oproj_size_, bool qbias_, bool kbias_, bool vbias_, bool obias_, float sm_scaler_, uint32_t input_order_, ATTN_MASK_TYPE attn_mask_type_, TENSOR_COMBINATION_TYPE tensor_combination_type_, bool need_weights_, bool add_zero_attn_, bool reslink_, bool training_, uint64_t seed_, float attn_prob_, float out_prob_, size_t handle_, std::string scope_ = {}): num_heads(num_heads_), embeding_size(embeding_size_), k_size(k_size_), v_size(v_size_), qproj_size(qproj_size_), kproj_size(kproj_size_), vproj_size(vproj_size_), oproj_size(oproj_size_), qbias(qbias_), kbias(kbias_), vbias(vbias_), obias(obias_), sm_scaler(sm_scaler_), input_order(input_order_), attn_mask_type(attn_mask_type_), tensor_combination_type(tensor_combination_type_), need_weights(need_weights_), add_zero_attn(add_zero_attn_), reslink(reslink_), training(training_), seed(seed_), attn_prob(attn_prob_), out_prob(out_prob_), handle(handle_) { set_scope(scope_); } + MultiHeadAttn(::megdnn::param::MultiHeadAttn packed_param_0, size_t handle_): num_heads(packed_param_0.num_heads), embeding_size(packed_param_0.embeding_size), k_size(packed_param_0.k_size), v_size(packed_param_0.v_size), qproj_size(packed_param_0.qproj_size), kproj_size(packed_param_0.kproj_size), vproj_size(packed_param_0.vproj_size), oproj_size(packed_param_0.oproj_size), qbias(packed_param_0.qbias), kbias(packed_param_0.kbias), vbias(packed_param_0.vbias), obias(packed_param_0.obias), sm_scaler(packed_param_0.sm_scaler), input_order(packed_param_0.input_order), attn_mask_type(packed_param_0.attn_mask_type), tensor_combination_type(packed_param_0.tensor_combination_type), need_weights(packed_param_0.need_weights), add_zero_attn(packed_param_0.add_zero_attn), reslink(packed_param_0.reslink), training(packed_param_0.training), seed(packed_param_0.seed), attn_prob(packed_param_0.attn_prob), out_prob(packed_param_0.out_prob), handle(handle_) {} ::megdnn::param::MultiHeadAttn param() const { - return {num_heads, sm_scaler, input_order, reslink, training, bias, attn_mask, enable_qproj, enable_kproj, enable_vproj, enable_oproj, seed, attn_prob, out_prob}; + return {num_heads, embeding_size, k_size, v_size, qproj_size, kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias, sm_scaler, input_order, attn_mask_type, tensor_combination_type, need_weights, add_zero_attn, reslink, training, seed, attn_prob, out_prob}; } }; diff --git a/imperative/tablegen/generated/opdef.py.inl b/imperative/tablegen/generated/opdef.py.inl index b6591c362de4a3e22a8f2b79463ff9703abbe2b1..c93f9bb50c2d4c3fd9eeb4fd5294d947eff7e265 100644 --- a/imperative/tablegen/generated/opdef.py.inl +++ b/imperative/tablegen/generated/opdef.py.inl @@ -1479,20 +1479,59 @@ MeshIndexingInst py::class_, OpDef> MultiHeadAttnInst(m, "MultiHeadAttn"); +py::enum_(MultiHeadAttnInst, "ATTN_MASK_TYPE") + .value("NO_MASK", MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK) + .value("DEFAULT_MASK", MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK) + .value("CUDNN_STYLE_MASK", MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK) + .value("USER_DEFINED_MASK", MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK) + .def(py::init([](const std::string& in) { + auto&& str = normalize_enum(in); + if (str == "NO_MASK") return MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK; + if (str == "DEFAULT_MASK") return MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK; + if (str == "CUDNN_STYLE_MASK") return MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK; + if (str == "USER_DEFINED_MASK") return MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK; + throw py::cast_error("invalid enum value " + in); + })); +py::implicitly_convertible(); + +py::enum_(MultiHeadAttnInst, "TENSOR_COMBINATION_TYPE") + .value("NONE", MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE) + .value("ONLY_MASK", MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK) + .value("ONLY_BIASKV", MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV) + .value("ALL", MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL) + .def(py::init([](const std::string& in) { + auto&& str = normalize_enum(in); + if (str == "NONE") return MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE; + if (str == "ONLY_MASK") return MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK; + if (str == "ONLY_BIASKV") return MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV; + if (str == "ALL") return MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL; + throw py::cast_error("invalid enum value " + in); + })); +py::implicitly_convertible(); + MultiHeadAttnInst - .def(py::init(), py::arg("num_heads") = 1, py::arg("sm_scaler") = 1.f, py::arg("input_order") = 0, py::arg("reslink") = false, py::arg("training") = true, py::arg("bias") = false, py::arg("attn_mask") = false, py::arg("enable_qproj") = true, py::arg("enable_kproj") = true, py::arg("enable_vproj") = true, py::arg("enable_oproj") = true, py::arg("seed") = 0, py::arg("attn_prob") = 0.f, py::arg("out_prob") = 0.f, py::arg("handle"), py::arg("scope") = {}) + .def(py::init(), py::arg("num_heads") = 1, py::arg("embeding_size") = 0, py::arg("k_size") = 0, py::arg("v_size") = 0, py::arg("qproj_size") = 0, py::arg("kproj_size") = 0, py::arg("vproj_size") = 0, py::arg("oproj_size") = 0, py::arg("qbias") = false, py::arg("kbias") = false, py::arg("vbias") = false, py::arg("obias") = false, py::arg("sm_scaler") = 1.f, py::arg("input_order") = 0, py::arg("attn_mask_type") = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK, py::arg("tensor_combination_type") = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE, py::arg("need_weights") = false, py::arg("add_zero_attn") = false, py::arg("reslink") = false, py::arg("training") = true, py::arg("seed") = 0, py::arg("attn_prob") = 0.f, py::arg("out_prob") = 0.f, py::arg("handle"), py::arg("scope") = {}) .def(py::init<>()) .def_readwrite("num_heads", &MultiHeadAttn::num_heads) + .def_readwrite("embeding_size", &MultiHeadAttn::embeding_size) + .def_readwrite("k_size", &MultiHeadAttn::k_size) + .def_readwrite("v_size", &MultiHeadAttn::v_size) + .def_readwrite("qproj_size", &MultiHeadAttn::qproj_size) + .def_readwrite("kproj_size", &MultiHeadAttn::kproj_size) + .def_readwrite("vproj_size", &MultiHeadAttn::vproj_size) + .def_readwrite("oproj_size", &MultiHeadAttn::oproj_size) + .def_readwrite("qbias", &MultiHeadAttn::qbias) + .def_readwrite("kbias", &MultiHeadAttn::kbias) + .def_readwrite("vbias", &MultiHeadAttn::vbias) + .def_readwrite("obias", &MultiHeadAttn::obias) .def_readwrite("sm_scaler", &MultiHeadAttn::sm_scaler) .def_readwrite("input_order", &MultiHeadAttn::input_order) + .def_readwrite("attn_mask_type", &MultiHeadAttn::attn_mask_type) + .def_readwrite("tensor_combination_type", &MultiHeadAttn::tensor_combination_type) + .def_readwrite("need_weights", &MultiHeadAttn::need_weights) + .def_readwrite("add_zero_attn", &MultiHeadAttn::add_zero_attn) .def_readwrite("reslink", &MultiHeadAttn::reslink) .def_readwrite("training", &MultiHeadAttn::training) - .def_readwrite("bias", &MultiHeadAttn::bias) - .def_readwrite("attn_mask", &MultiHeadAttn::attn_mask) - .def_readwrite("enable_qproj", &MultiHeadAttn::enable_qproj) - .def_readwrite("enable_kproj", &MultiHeadAttn::enable_kproj) - .def_readwrite("enable_vproj", &MultiHeadAttn::enable_vproj) - .def_readwrite("enable_oproj", &MultiHeadAttn::enable_oproj) .def_readwrite("seed", &MultiHeadAttn::seed) .def_readwrite("attn_prob", &MultiHeadAttn::attn_prob) .def_readwrite("out_prob", &MultiHeadAttn::out_prob) diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index c1f8eff0284b3fe810cdb824a5f0691fa3a05892..3384a7923696d6054281b1069cc1269492d40e93 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -558,7 +558,6 @@ def RegionRestrictedConvolution: MgbHashableOp<"RegionRestrictedConvolution", [C def RegionRestrictedConvolutionBackwardData: MgbHashableOp<"RegionRestrictedConvolutionBackwardData", [ConvolutionParam]>; def MaskedFill: MgbHashableOp<"MaskedFill", [FillParam]>; - def MultiHeadAttn: MgbHashableOp<"MultiHeadAttn", [MultiHeadAttnParam]> { let extraArguments = (ins MgbSizeTAddr:$handle @@ -571,28 +570,54 @@ def MultiHeadAttn: MgbHashableOp<"MultiHeadAttn", [MultiHeadAttnParam]> { mgb::hash_pair_combine( mgb::hash($_self.num_heads), mgb::hash_pair_combine( - mgb::hash($_self.sm_scaler), + mgb::hash($_self.embeding_size), mgb::hash_pair_combine( - mgb::hash($_self.input_order), + mgb::hash($_self.k_size), mgb::hash_pair_combine( - mgb::hash($_self.reslink), + mgb::hash($_self.v_size), mgb::hash_pair_combine( - mgb::hash($_self.training), + mgb::hash($_self.qproj_size), mgb::hash_pair_combine( - mgb::hash($_self.bias), + mgb::hash($_self.kproj_size), mgb::hash_pair_combine( - mgb::hash($_self.attn_mask), + mgb::hash($_self.vproj_size), mgb::hash_pair_combine( - mgb::hash($_self.enable_qproj), + mgb::hash($_self.oproj_size), mgb::hash_pair_combine( - mgb::hash($_self.enable_kproj), + mgb::hash($_self.qbias), mgb::hash_pair_combine( - mgb::hash($_self.enable_vproj), + mgb::hash($_self.kbias), mgb::hash_pair_combine( - mgb::hash($_self.enable_oproj), + mgb::hash($_self.vbias), mgb::hash_pair_combine( - mgb::hash($_self.attn_prob), - mgb::hash($_self.out_prob) + mgb::hash($_self.obias), + mgb::hash_pair_combine( + mgb::hash($_self.sm_scaler), + mgb::hash_pair_combine( + mgb::hash($_self.input_order), + mgb::hash_pair_combine( + mgb::hash($_self.attn_mask_type), + mgb::hash_pair_combine( + mgb::hash($_self.tensor_combination_type), + mgb::hash_pair_combine( + mgb::hash($_self.add_zero_attn), + mgb::hash_pair_combine( + mgb::hash($_self.need_weights), + mgb::hash_pair_combine( + mgb::hash($_self.reslink), + mgb::hash_pair_combine( + mgb::hash($_self.training), + mgb::hash_pair_combine( + mgb::hash($_self.attn_prob), + mgb::hash($_self.out_prob)) + ) + ) + ) + ) + ) + ) + ) + ) ) ) ) @@ -608,7 +633,7 @@ def MultiHeadAttn: MgbHashableOp<"MultiHeadAttn", [MultiHeadAttnParam]> { ) ); }]; - let cmpFunction = [{return $0.handle == $1.handle && $0.num_heads == $1.num_heads && $0.sm_scaler == $1.sm_scaler && $0.input_order == $1.input_order && $0.reslink == $1.reslink && $0.training == $1.training && $0.bias == $1.bias && $0.attn_mask == $1.attn_mask && $0.enable_qproj == $1.enable_qproj && $0.enable_kproj == $1.enable_kproj && $0.enable_vproj == $1.enable_vproj && $0.enable_oproj == $1.enable_oproj && $0.attn_prob == $1.attn_prob && $0.out_prob == $1.out_prob;}]; + let cmpFunction = [{return $0.handle == $1.handle && $0.num_heads == $1.num_heads && $0.embeding_size == $1.embeding_size && $0.k_size == $1.k_size && $0.v_size == $1.v_size && $0.qproj_size == $1.qproj_size && $0.kproj_size == $1.kproj_size && $0.vproj_size == $1.vproj_size && $0.oproj_size == $1.oproj_size && $0.qbias == $1.qbias && $0.kbias == $1.kbias && $0.vbias == $1.vbias && $0.obias == $1.obias && $0.sm_scaler == $1.sm_scaler && $0.input_order == $1.input_order && $0.reslink == $1.reslink && $0.training == $1.training && $0.need_weights == $1.need_weights && $0.attn_mask_type == $1.attn_mask_type && $0.add_zero_attn == $1.add_zero_attn && $0.tensor_combination_type == $1.tensor_combination_type && $0.attn_prob == $1.attn_prob && $0.out_prob == $1.out_prob;}]; }