.add_fields('uint32',Doc('input_order','The sequence data layout, allows the user to select 3! = 6 different data layouts or permutations of BEAM, BATCH and TIME dimensions.'),'0')
.add_fields('uint32',Doc('input_order','The sequence data layout, allows the user to select 3! = 6 different data layouts or permutations of BEAM, BATCH and TIME dimensions.'),'0')
.add_enum('ATTN_MASK_TYPE',
Doc('NO_MASK = 0','Indicates that there is no mask.'),
Doc('DEFAULT_MASK = 1','Use the default mask which the upper right triangle of the mask is -inf, and the diagonal and lower left triangle are all 0.'),
Doc('CUDNN_STYLE_MASK = 2','Indicates the use of a cudnn style mask.'),
Doc('USER_DEFINED_MASK = 3','Use the user-defined mask.'),name_field="attn_mask_type")
.add_enum(Doc('TENSOR_COMBINATION_TYPE','Used to determine whether mask tensor and bias_kv tensor exist in the input. Note that bias_kv here is not kbias and vbias in the linear layer, and bias_kv here will be added to the K and V at sequence dimensions, where K and V are the matrices of key and value after projection, and K and V will be used to calculate the attention matrix.'),
Doc('NONE = 0','Indicates that there are no mask tensor and bias_kv tensor in the input.'),
Doc('ONLY_MASK = 1',
'Indicates that there is only mask tensor in input.'),
Doc('ONLY_BIASKV = 2','Indicates that there is only bias_kv tensor in input.'),
Doc('ALL = 3','Indicates that there are mask tensor and bias_kv tensor in the input.'),name_field="tensor_combination_type")
.add_fields('bool',Doc('add_zero_attn','Whether to add a new batch of zeros to the key and value sequences.'),'false')
.add_fields('bool',Doc('need_weights','Whether to return the attention matrix, which is the output result of softmax.'),'false')
.add_fields('bool',Doc('reslink','Whether to add input query to final output.'),'false')
.add_fields('bool',Doc('reslink','Whether to add input query to final output.'),'false')
.add_fields('bool',Doc('training','Whether it is in training mode.'),'true')
.add_fields('bool',Doc('training','Whether it is in training mode.'),'true')
.add_fields('bool',Doc('bias','Whether to add linear bias.'),'false')
.add_fields('bool',Doc('attn_mask','Whether to add attn_mask.'),'false')
),f"Expected `attn_mask` shape to be {expected_shape0} or {expected_shape1} but got {am_shape}"
ifam_dim==3:
expected_shape=(num_heads,q_shape0,k_shape0)
assert(
am_shape==expected_shape
),f"Expected `attn_mask` shape to be {expected_shape} but got {am_shape}"
else:
raiseAssertionError(
f"query should be unbatched 2D or batched 3D tensor but received {q_dim}-D query tensor"
)
returnis_batched
def_canonical_mask(
mask:Optional[Tensor],
mask_name:str,
other_type,
other_name:str,
target_type,
check_other:bool=True,
)->Optional[Tensor]:
ifmaskisnotNone:
_mask_dtype=mask.dtype
_mask_is_float=(
_mask_dtype==np.float16
or_mask_dtype==np.float32
or_mask_dtype==np.float64
)
assert(
_mask_dtype==boolor_mask_is_float
),f"only bool and floating types of {mask_name} are supported"
ifcheck_otherandother_typeisnotNone:
if_mask_dtype!=other_type:
get_logger().warning(
f"Support for mismatched {mask_name} and {other_name} "
"is deprecated. Use same type for both instead."
)
ifnot_mask_is_float:
mask_=zeros_like(mask).astype(target_type)
mask_[mask]=float("-inf")
returnmask_
returnmask
def_merge_masks(
attn_mask:Tensor,
key_padding_mask:Tensor,
query:Tensor,
key:Tensor,
add_bias_kv:bool=False,
add_zero_attn:bool=False,
is_causal:bool=False,
maybe_cudnn_style_mask:bool=False,
num_heads:int=0,
):
r"""
Determine mask type and combine masks if necessary.
Note: This function will continue to improve with the iteration of MHA.
Args:
attn_mask: MHA's attention mask tensor, the shape is :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`
key_padding_mask: MHA's padding mask tensor, the shape is :math:`(N, S)`
query: MHA's query, the shape is :math:`(N, L, E_q)`
key: MHA's key, the shape is :math:`(N, S, E_k)`
add_bias_kv: used to determine whether pad is needed on the sequence dimension of attn_mask and key_padding_mask, from MHA's ``add_bias_kv``.
add_zero_attn: used to determine whether pad is needed on the sequence dimension of attn_mask and key_padding_mask, from MHA's ``add_zero_attn``.
is_causal: MHA's is_causal, is_causal provides a hint that attn_mask is the causal mask.
maybe_cudnn_style_mask: MHA's maybe_cudnn_style_mask, like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is the cudnn style mask.
num_heads: MHA's head number.
Returns:
merged_mask: merged mask, may be None, the shape is :math:`(L, S)`, :math:`(2\cdotL + 2\cdotN)` or :math:`(N\cdot\text{num\_heads}, L, S)`
mask_type: merged mask type ``("no_mask", "default_mask", "cudnn_style_mask" or "user_defined_mask")``
See :class:`~.module.MultiHeadAttn` for more details.
See :class:`~.module.MultiHeadAttn` for more details.
Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported, and if the cudnn version >=8.6.0, the calculation results are completely correct; If the cudnn version >=8.0.4 but <8.6.0, if there is a bias, only the dbias result calculated from the backward is incorrect. If there is no bias, the forward and backward calculations are correct; If the cudnn version is less than 8.0.4, this operator is not supported.
Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported, and if the cudnn version >=8.6.0, the calculation results are completely correct; If the cudnn version >=8.0.4 but <8.6.0, if there is a bias, only the dbias result calculated from the backward is incorrect. If there is no bias, the forward and backward calculations are correct; If the cudnn version is less than 8.0.4, this operator is not supported.
Args:
Args:
query, key, value: map a query and a set of key-value pairs to an output.
query, key, value: map a query and a set of key-value pairs to an output.
attn_drop: probability of an element to be zeroed, used in attention matrix.
attn_drop: probability of an element to be zeroed, used in attention matrix.
out_drop: probability of an element to be zeroed, used in final output.
out_drop: probability of an element to be zeroed, used in final output.
io_weight_bias: input/output projection weight/bias all in one, used for cudnn api.
io_weight_bias: input/output projection weight/bias all in one.
bias: used to indicate a bias in io_weight_bias, used for cudnn api.
The order of arrangement is: query weight, key weight, value weight, out weight, query bias, key bias, value bias, out bias, the following parameters will be used to indicate whether these items exist: qproj_size, kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias.
reslink: add input query to final output.
Note: :math:`Y=X@W+B` is used here instead of :math:`Y=X@W^T+B` in pytorch.
qproj_size: indicates the projection size of query weight in io_weight_bias, 0 indicates disabled query projection and no query projection weight.
kproj_size: indicates the projection size of key weight in io_weight_bias, 0 indicates disabled key projection and no key projection weight.
vproj_size: indicates the projection size of value weight in io_weight_bias, 0 indicates disabled value projection and no value projection weight.
oproj_size: indicates the projection size of out weight in io_weight_bias, 0 indicates disabled output projection and no output projection weight.
qbias: indicates whether there is a query bias in io_weight_bias, this parameter is only valid when qproj_size > 0.
kbias: indicates whether there is a key bias in io_weight_bias, this parameter is only valid when kproj_size > 0.
vbias: indicates whether there is a value bias in io_weight_bias, this parameter is only valid when vproj_size > 0.
obias: indicates whether there is a out bias in io_weight_bias, this parameter is only valid when oproj_size > 0.
bias_k, bias_v: the bias of the key and value sequences to be added at sequence dim. distinguished from kbias and vbias, bias_kv here is not kbias and vbias in the linear layer, and bias_kv here will be added to the K and V at sequence dimensions, where K and V are the matrices of key and value after projection, and K and V will be used to calculate the attention matrix.
Note: Should be set to None, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
add_zero_attn: if specified, adds a new batch of zeros to the key and value sequences at sequence dim. Default: ``False``.
Note: should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
key_padding_mask: if specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` to ignore for the purpose of
attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
Note: Should be set to None, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Note: User-defined mask not supported now, only support no mask or default mask, where the upper right triangle is all -inf, and the diagonal and lower left triangle are all 0. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
need_weights: indicates whether to return the attention weight, which is the output result of softmax. Default: `True`
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
average_attn_weights: if true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
is_causal: if specified, applies a causal mask as attention mask. Default: ``False``
Warning: ``is_causal`` provides a hint that ``attn_mask`` is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.
maybe_cudnn_style_mask: if specified, applies a cudnn style mask as attention mask. Default: ``False``
Note: In the cudnn style, the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`.
Warning: like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is a cudnn style mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. In addition, if the ``_merge_masks`` function returns ``merge_type=cudnn_style_mask``, please ensure that other conditions are correct so that it can run the implementation of cudnn, otherwise an error will be reported.
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that the underlying implementation only accepts two types of mask type, namely "no_mask" and "default_mask", and we may try to loosen this option after submitting the commit that users can pass in custom attention mask tensors.
reslink: add input query to final output.
Note: It is only valid if the input query is the same as the shape of the output.
training: will apply dropout if is ``True``.
training: will apply dropout if is ``True``.
attn_mask: used to indicate whether to add a mask to the attention matrix.
By default, the upper right triangle of the mask is -inf, and the diagonal and lower left triangle are all 0.
),"when query projection bias is true, query projection weight must be given."
ifkbias:
assert(
kproj_sizeisnotNoneandkproj_size>0
),"when key projection bias is true, key projection weight must be given"
ifvbias:
assert(
vproj_sizeisnotNoneandvproj_size>0
),"when value projection bias is true, value projection weight must be given"
ifobias:
assert(
oproj_sizeisnotNoneandoproj_size>0
),"when output projection bias is true, output projection weight must be given"
unsupport_reason=" The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation."
assertadd_zero_attnisFalse,(
"add_zero_attn should be False, and configuration of this parameter is not supported now."
+unsupport_reason
)
assertkey_padding_maskisNone,(
"key_padding_mask should be None, and configuration of this parameter is not supported now."
+unsupport_reason
)
assertneed_weights==False,(
"need_weights should be set to False, and configuration of this parameter is not supported now."
+unsupport_reason
)
assertaverage_attn_weights==False,(
"average_attn_weights should be set to False, and configuration of this parameter is not supported now."
+unsupport_reason
)
assertmaybe_cudnn_style_mask==False,(
"maybe_cudnn_style_mask should be set to False, and configuration of this parameter is not supported now."
+unsupport_reason
)
assertbias_kisNone,(
"bias_k should be None, and configuration of this parameter is not supported now."
+unsupport_reason
)
assertbias_visNone,(
"bias_v should be None, and configuration of this parameter is not supported now."
Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported, and if the cudnn version >=8.6.0, the calculation results are completely correct; If the cudnn version >=8.0.4 but <8.6.0, if there is a bias, only the dbias result calculated from the backward is incorrect. If there is no bias, the forward and backward calculations are correct; If the cudnn version is less than 8.0.4, this operator is not supported.
Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported, and if the cudnn version >=8.6.0, the calculation results are completely correct; If the cudnn version >=8.0.4 but <8.6.0, if there is a bias, only the dbias result calculated from the backward is incorrect. If there is no bias, the forward and backward calculations are correct; If the cudnn version is less than 8.0.4, this operator is not supported.
When the following conditions are met, you can go to the cudnn backend:
- ``cudnn version`` greater than or equal to 8.0.4 and ``bias`` is ``False`` and ``training`` is ``False``
- ``cudnn version`` greater than or equal to 8.6.0
- ``add_bias_kv`` is ``False``
- ``add_zero_attn`` is ``False``
- ``need_weights`` is ``False``
- ``average_attn_weights`` is ``False``
- ``maybe_cudnn_style_mask`` is ``True`` if support else ``False``
- ``attn_mask`` and ``key_padding_mask`` is cudnn style mask, i.e. the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`.
- The shape of attn_mask is :math:`(2, L)`, where :math:`(0, :)` elements specify the start index, :math:`(1, :)` elements specify the end index, the start index is inclusive, and the end index is not exclusive. The start index (i.e. elements in `attn_mask[0, x]`) must be less than the corresponding end index (i.e. elements in `attn_mask[1, x]`). The end index must be less than or equal to :math:`S`, where :math:`S` is the source sequence length, :math:`L` is the target sequence length.
- The shape of key_padding_mask is :math:`(2, N)`, where :math:`(0, :)` elements specify the target sequence padding in cudnn style mask and the element must equal to or less than :math:`L`, :math:`(1, :)` elements specify the source sequence padding in cudnn style mask and the element must equal to or less than :math:`S`, where :math:`S` is the source sequence length, :math:`L` is the target sequence length.
- ``qbias``, ``kbias``, ``vbias`` and ``obias`` are equal
Args:
Args:
embed_dim: Total dimension of the model.
embed_dim: Total dimension of the model.
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
attn_dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
out_dropout: Dropout probability on ``output``. Default: ``0.0`` (no dropout).
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
add_bias_kv: If specified, adds bias to the key and value sequences at sequence dim. Default: ``False``.
Different from kbias and vbias, bias_kv here is not kbias and vbias in the linear layer, and bias_kv here will be added to the K and V at sequence dimensions, where K and V are the matrices of key and value after projection, and K and V will be used to calculate the attention matrix.
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences.
Default: ``False``.
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
self.unsupport_reason=" The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation."
assert(
assert(
self.head_dim*num_heads==self.embed_dim
self.head_dim*num_heads==self.embed_dim
),"embed_dim must be divisible by num_heads"
),"embed_dim must be divisible by num_heads"
assert(
assertadd_bias_kv==False,(
self._qkv_same_embed_dim
"add_bias_kv should be set to False, and configuration of this parameter is not supported now."
),"it does not support the case where q, k, and v are different."
+self.unsupport_reason
)
assertadd_zero_attn==False,(
"add_zero_attn should be set to False, and configuration of this parameter is not supported now."
query: Query embeddings of shape :math:`(N, L, E_q)`, where :math:`N` is the batch size, :math:`L` is the target sequence length,
query: Query embeddings of shape :math:`(N, L, E_q)`,
and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against
where :math:`N` is the batch size, :math:`L` is the target sequence length, and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against key-value pairs to produce the output. See "Attention Is All You Need" for more details.
key-value pairs to produce the output. See "Attention Is All You Need" for more details.
key: Key embeddings of shape :math:`(N, S, E_k)`,
key: Key embeddings of shape :math:`(N, S, E_k)`, where :math:`N` is the batch size, :math:`S` is the source sequence length, and
where :math:`N` is the batch size, :math:`S` is the source sequence length, and :math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details.
:math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details.
value: Value embeddings of shape :math:`(N, S, E_v)`,
value: Value embeddings of shape :math:`(N, S, E_v)`, where :math:`N` is the batch size, :math:`S` is the source sequence length, and
where :math:`N` is the batch size, :math:`S` is the source sequence length, and :math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details.
:math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details.
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` to ignore for the purpose of
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
Note: Should be set to None, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Note: User-defined mask not supported now, only support no mask or default mask, where the upper right triangle is all -inf, and the diagonal and lower left triangle are all 0. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
need_weights: indicates whether to return the attention weight, which is the output result of softmax. Default: `True`
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
is_causal: If specified, applies a causal mask as attention mask. Default: ``False``
Warning: ``is_causal`` provides a hint that ``attn_mask`` is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.
maybe_cudnn_style_mask: if specified, applies a cudnn style mask as attention mask. Default: ``False``
Note: In the cudnn style, the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`.
Warning: like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is a cudnn style mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. In addition, if the ``_merge_masks`` function returns ``merge_type=cudnn_style_mask``, please ensure that other conditions are correct so that it can run the implementation of cudnn, otherwise an error will be reported.
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that the underlying implementation only accepts two types of mask type, namely "no_mask" and "default_mask", and we may try to loosen this option after submitting the commit that users can pass in custom attention mask tensors.
Outputs:
Outputs:
- **attn_output** - Attention outputs of shape :math:`(N, L, E)`,
- **attn_output** - Attention outputs of shape :math:`(N, L, E)`,
where :math:`L` is the target sequence length, :math:`N` is
where :math:`L` is the target sequence length, :math:`N` is
the batch size, and :math:`E` is the embedding dimension ``embed_dim``.
the batch size, and :math:`E` is the embedding dimension ``embed_dim``.
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N * \text{num\_heads}, L, S)`.
Note: Now only None will be returned. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
"""
"""
assertkey_padding_maskisNone,(
"key_padding_mask should be None, and configuration of this parameter is not supported now."
+self.unsupport_reason
)
assertneed_weights==False,(
"need_weights should be set to False, and configuration of this parameter is not supported now."
+self.unsupport_reason
)
assertaverage_attn_weights==False,(
"average_attn_weights should be set to False, and configuration of this parameter is not supported now."
+self.unsupport_reason
)
assertmaybe_cudnn_style_mask==False,(
"maybe_cudnn_style_mask should be set to False, and configuration of this parameter is not supported now."
+self.unsupport_reason
)
returnmulti_head_attention(
returnmulti_head_attention(
query,
query,
...
@@ -145,13 +222,24 @@ class MultiHeadAttention(Module):
...
@@ -145,13 +222,24 @@ class MultiHeadAttention(Module):