Commit 5bcb8435 authored by Megvii Engine Team

fix(dnn): improve the c++ interface of mha

GitOrigin-RevId: 1dba138d4f5579ba67d7dae17cc7b2dac3b93c93
Parent 04dc29e4
@@ -2580,66 +2580,139 @@ class MultiHeadAttnBase : public OperatorBase {
}; };
class MultiHeadAttnForward : public MultiHeadAttnBase { class MultiHeadAttnForward : public MultiHeadAttnBase {
DEF_OPR_IMPL(MultiHeadAttnForward, MultiHeadAttnBase, 4, 2); DEF_OPR_IMPL(MultiHeadAttnForward, MultiHeadAttnBase, 7, 4);
public: public:
/**
* \param[in] queries (N, L, E_q), where N is the batch size, L is the target
* sequence length, and E_q is the query embedding dimension embed_dim.
* \param[in] keys (N, S, E_k), where N is the batch size, S is the source
* sequence length, and E_k is the key embedding dimension k_dim.
* \param[in] values (N, S, E_v), where N is the batch size, S is the source
* sequence length, and E_v is the value embedding dimension v_dim.
* \param[in] qkvo_weight_bias, input/output projection weight/bias all in one.
* The order of arrangement is: query weight, key weight, value weight,
* out weight, query bias, key bias, value bias, out bias, the following parameters
* in param will be used to indicate whether these items exist: qproj_size,
* kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias.
* Note: Y=X@W+B is used here instead of Y=X@W^T+B in pytorch.
* \param[in] attn_mask, (N*num_heads, L, S) or (L, S), where N is the batch size,
* num_heads is the number of parallel attention heads, L is the target sequence
* length, and S is the source sequence length. attention mask is obtained by
* combining attn_mask, key_padding_mask, is_causal and maybe_cudnn_style_mask by
* mge.functional._merge_masks.
* \param[in] bias_k, (1, 1, kproj_size), where kproj_size is the projected
* dimension of key weight, if kproj_size == 0, will be the key embedding dimension
* k_dim.
* Note: bias_k and bias_v are biases appended to the K and V sequences along the
* sequence dimension; they are distinct from kbias and vbias, i.e. bias_kv here is
* not the kbias and vbias of the linear layer. They are concatenated to K and V
* (the key and value matrices after projection), and the resulting K and V are
* used to compute the attention matrix.
* \param[in] bias_v, (1, 1, vproj_size), where vproj_size is the projected
* dimension of value weight, if vproj_size == 0, will be the value embedding
* dimension v_dim.
* Note: see bias_k.
* \param[out] out, (N, L, oproj_size), where N is the batch size, L is the target
* sequence length, and oproj_size is the projected dimension of the output weight;
* if oproj_size == 0, it falls back to the projected dimension of the value weight
* vproj_size, and if vproj_size == 0 as well, to the value embedding dimension
* v_dim.
* \param[out] attn_weight, (N * num_heads, L, S), where N is the batch size,
* num_heads is the number of parallel attention heads, L is the target sequence
* length, and S is the source sequence length.
* Note: attn_weight is the output of softmax.
* \param[out] mask_reservespace, when param.training=true, this output is needed to
* save the masks of attention dropout and output dropout.
* \param[out] othr_reservespace, when param.training=true, this output is needed to
* save the intermediate calculation results for the backward pass.
*/
virtual void exec( virtual void exec(
_megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
_megdnn_tensor_in wqkv, _megdnn_tensor_out out, _megdnn_tensor_in qkvo_weight_bias, _megdnn_tensor_in attn_mask,
_megdnn_tensor_out reserveSpace, _megdnn_workspace workspace) = 0; _megdnn_tensor_in bias_k, _megdnn_tensor_in bias_v, _megdnn_tensor_out out,
MGE_WIN_DECLSPEC_FUC void deduce_layout( _megdnn_tensor_out attn_weight, _megdnn_tensor_out mask_reservespace,
_megdnn_tensor_out othr_reservespace, _megdnn_workspace workspace) = 0;
virtual void deduce_layout(
const TensorLayout& queries, const TensorLayout& keys, const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out, const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
TensorLayout& reserveSpace); const TensorLayout& attn_mask, const TensorLayout& bias_k,
const TensorLayout& bias_v, TensorLayout& out, TensorLayout& attn_weight,
TensorLayout& mask_reservespace, TensorLayout& othr_reservespace) = 0;
virtual size_t get_workspace_in_bytes( virtual size_t get_workspace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys, const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
const TensorLayout& out, const TensorLayout& reserveSpace) = 0; const TensorLayout& attn_mask, const TensorLayout& bias_k,
virtual size_t get_reservespace_in_bytes( const TensorLayout& bias_v, const TensorLayout& out,
const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& othr_reservespace) = 0;
virtual size_t get_mask_reservespace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys, const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
const TensorLayout& out, const TensorLayout& reserveSpace) = 0; const TensorLayout& attn_mask, const TensorLayout& bias_k,
const TensorLayout& bias_v, const TensorLayout& out,
const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& othr_reservespace) = 0;
virtual size_t get_othr_reservespace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
const TensorLayout& attn_mask, const TensorLayout& bias_k,
const TensorLayout& bias_v, const TensorLayout& out,
const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& othr_reservespace) = 0;
protected: protected:
void check_exec( void check_exec(
const TensorLayout& queries, const TensorLayout& keys, const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
const TensorLayout& out, const TensorLayout& reserveSpace, const TensorLayout& attn_mask, const TensorLayout& bias_k,
size_t workspace_in_bytes); const TensorLayout& bias_v, const TensorLayout& out,
const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& othr_reservespace, size_t workspace_in_bytes);
}; };
using MultiHeadAttn = MultiHeadAttnForward; using MultiHeadAttn = MultiHeadAttnForward;
class MultiHeadAttnBackward : public MultiHeadAttnBase { class MultiHeadAttnBackward : public MultiHeadAttnBase {
DEF_OPR_IMPL(MultiHeadAttnBackward, MultiHeadAttnBase, 6, 4); DEF_OPR_IMPL(MultiHeadAttnBackward, MultiHeadAttnBase, 9, 6);
public: public:
virtual void exec( virtual void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
_megdnn_tensor_in values, _megdnn_tensor_in wqkv, _megdnn_tensor_in values, _megdnn_tensor_in qkvo_weight_bias,
_megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries, _megdnn_tensor_in attn_mask, _megdnn_tensor_in attn_weight,
_megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues, _megdnn_tensor_in mask_reservespace, _megdnn_tensor_in othr_reservespace,
_megdnn_tensor_out dweights, _megdnn_workspace workspace) = 0; _megdnn_tensor_out dqueries, _megdnn_tensor_out dkeys,
_megdnn_tensor_out dvalues, _megdnn_tensor_out dqkvo_weight_bias,
_megdnn_tensor_out dbias_k, _megdnn_tensor_out dbias_v,
_megdnn_workspace workspace) = 0;
MGE_WIN_DECLSPEC_FUC void deduce_layout( MGE_WIN_DECLSPEC_FUC void deduce_layout(
const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& diff, const TensorLayout& queries,
const TensorLayout& keys, const TensorLayout& values, const TensorLayout& keys, const TensorLayout& values,
const TensorLayout& wqkv, const TensorLayout& reserveSpace, const TensorLayout& qkvo_weight_bias, const TensorLayout& attn_mask,
TensorLayout& dqueries, TensorLayout& dkeys, TensorLayout& dvalues, const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
TensorLayout& dweights); const TensorLayout& othr_reservespace, TensorLayout& dqueries,
TensorLayout& dkeys, TensorLayout& dvalues, TensorLayout& dqkvo_weight_bias,
TensorLayout& dbias_k, TensorLayout& dbias_v);
virtual size_t get_workspace_in_bytes( virtual size_t get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& diff, const TensorLayout& queries,
const TensorLayout& keys, const TensorLayout& values, const TensorLayout& keys, const TensorLayout& values,
const TensorLayout& wqkv, const TensorLayout& reserveSpace, const TensorLayout& qkvo_weight_bias, const TensorLayout& attn_mask,
const TensorLayout& dqueries, const TensorLayout& dkeys, const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& dvalues, const TensorLayout& dweights) = 0; const TensorLayout& othr_reservespace, const TensorLayout& dqueries,
const TensorLayout& dkeys, const TensorLayout& dvalues,
const TensorLayout& dqkvo_weight_bias, const TensorLayout& dbias_k,
const TensorLayout& dbias_v) = 0;
protected: protected:
void check_exec( void check_exec(
const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& diff, const TensorLayout& queries,
const TensorLayout& keys, const TensorLayout& values, const TensorLayout& keys, const TensorLayout& values,
const TensorLayout& wqkv, const TensorLayout& reserveSpace, const TensorLayout& qkvo_weight_bias, const TensorLayout& attn_mask,
const TensorLayout& dqueries, const TensorLayout& dkeys, const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& dvalues, const TensorLayout& dweights, const TensorLayout& othr_reservespace, const TensorLayout& dqueries,
size_t workspace_in_bytes); const TensorLayout& dkeys, const TensorLayout& dvalues,
const TensorLayout& dqkvo_weight_bias, const TensorLayout& dbias_k,
const TensorLayout& dbias_v, size_t workspace_in_bytes);
}; };
} // namespace megdnn } // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h" #include "megdnn/internal/opr_header_epilogue.h"
...
This diff has been collapsed.
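The doc comment added above pins down the shape contract of the new seven-input / four-output forward interface. As a cross-check, here is a minimal Python sketch (not part of the patch; the helper name and arguments are illustrative) of how the out and attn_weight shapes follow from the query/key/value shapes and the projection sizes described there:

# Minimal sketch of the documented shape rules (illustrative, not library code).
def mha_forward_shapes(N, L, S, v_dim, num_heads,
                       vproj_size=0, oproj_size=0,
                       have_biaskv=False, add_zero_attn=False):
    # bias_k/bias_v and add_zero_attn each append one extra position to the
    # key/value sequence, so the attention matrix gains one column per feature.
    S_eff = S + (1 if have_biaskv else 0) + (1 if add_zero_attn else 0)
    attn_weight = (N * num_heads, L, S_eff)          # softmax output
    # out width falls back: oproj_size -> vproj_size -> v_dim (0 means "disabled").
    osize = oproj_size or vproj_size or v_dim
    out = (N, L, osize)
    return out, attn_weight

# batch 2, target length 4, source length 5, v_dim 8, 2 heads, no projections:
assert mha_forward_shapes(2, 4, 5, 8, 2) == ((2, 4, 8), (4, 4, 5))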
@@ -148,8 +148,8 @@ DEF(RegionRestrictedConvolutionBackwardFilter, 5, true, false);
DEF(GroupNormForward, 6, true, true); DEF(GroupNormForward, 6, true, true);
DEF(GroupNormBackward, 8, true, true); DEF(GroupNormBackward, 8, true, true);
DEF(MaskedFill, 3, false, true); DEF(MaskedFill, 3, false, true);
DEF(MultiHeadAttnForward, 6, true, true); DEF(MultiHeadAttnForward, 11, true, true);
DEF(MultiHeadAttnBackward, 10, true, true); DEF(MultiHeadAttnBackward, 15, true, true);
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
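For reference, the arity recorded in each DEF entry is simply the operator's input count plus its output count, matching the DEF_OPR_IMPL numbers in the header above; a trivial sanity check:

# Forward: 7 inputs (queries, keys, values, qkvo_weight_bias, attn_mask, bias_k,
# bias_v) + 4 outputs (out, attn_weight, mask_reservespace, othr_reservespace).
assert 7 + 4 == 11
# Backward: 9 inputs (diff, queries, keys, values, qkvo_weight_bias, attn_mask,
# attn_weight, mask_reservespace, othr_reservespace) + 6 outputs (dqueries, dkeys,
# dvalues, dqkvo_weight_bias, dbias_k, dbias_v).
assert 9 + 6 == 15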
@@ -8,54 +8,84 @@ namespace cuda {
void MultiHeadAttnForwardImpl::deduce_layout( void MultiHeadAttnForwardImpl::deduce_layout(
const TensorLayout& queries, const TensorLayout& keys, const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out, const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
TensorLayout& reserveSpace) { const TensorLayout& attn_mask, const TensorLayout& bias_k,
const TensorLayout& bias_v, TensorLayout& out, TensorLayout& attn_weight,
TensorLayout& mask_reservespace, TensorLayout& othr_reservespace) {
#if CUDNN_VERSION < 8004 #if CUDNN_VERSION < 8004
// TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation.
MEGDNN_MARK_USED_VAR(queries); MEGDNN_MARK_USED_VAR(queries);
MEGDNN_MARK_USED_VAR(keys); MEGDNN_MARK_USED_VAR(keys);
MEGDNN_MARK_USED_VAR(values); MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(wqkv); MEGDNN_MARK_USED_VAR(qkvo_weight_bias);
MEGDNN_MARK_USED_VAR(attn_mask);
MEGDNN_MARK_USED_VAR(bias_k);
MEGDNN_MARK_USED_VAR(bias_v);
MEGDNN_MARK_USED_VAR(out); MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(reserveSpace); MEGDNN_MARK_USED_VAR(attn_weight);
MEGDNN_MARK_USED_VAR(mask_reservespace);
MEGDNN_MARK_USED_VAR(othr_reservespace);
return; return;
#else #else
MEGDNN_MARK_USED_VAR(keys); MEGDNN_MARK_USED_VAR(qkvo_weight_bias);
MEGDNN_MARK_USED_VAR(wqkv); MEGDNN_MARK_USED_VAR(attn_mask);
auto p = param();
megdnn_assert( megdnn_assert(
queries.ndim == 3, queries.ndim == 3, "queries.ndim should be 3, but got %zu", queries.ndim);
"queries.ndim should be 3[batch, sequence, embeding], but got %zu",
queries.ndim);
if (!desc_status.is_initialized(param(), queries, keys, values)) { if (!desc_status.is_initialized(p, queries, keys, values))
desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values); desc_status.set(cudnn_handle(this->handle()), p, queries, keys, values);
auto input_type = p.tensor_combination_type;
using INPUT_TYPE = Param::TENSOR_COMBINATION_TYPE;
bool have_biaskv =
input_type == INPUT_TYPE::ONLY_BIASKV or input_type == INPUT_TYPE::ALL;
size_t attn_seqk_dim_add = (have_biaskv ? 1 : 0) + (p.add_zero_attn ? 1 : 0);
attn_weight = TensorLayout(
TensorShape{
queries.shape[0] * p.num_heads, queries.shape[1],
keys.shape[1] + attn_seqk_dim_add},
queries.dtype);
size_t osize = p.oproj_size != 0 ? p.oproj_size
: (p.vproj_size != 0 ? p.vproj_size : p.v_size);
out = TensorLayout(
TensorShape{queries.shape[0], queries.shape[1], osize}, queries.dtype);
mask_reservespace = TensorLayout(TensorShape{0}, dtype::Uint8());
othr_reservespace =
TensorLayout(TensorShape{desc_status.sizeReserve}, queries.dtype);
out = TensorLayout(
TensorShape{queries.shape[0], queries.shape[1], queries.shape[2]},
queries.dtype);
reserveSpace =
TensorLayout(TensorShape{desc_status.sizeReserve}, queries.dtype);
}
#endif #endif
} }
size_t MultiHeadAttnForwardImpl::get_workspace_in_bytes( size_t MultiHeadAttnForwardImpl::get_workspace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys, const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out, const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
const TensorLayout& reserveSpace) { const TensorLayout& attn_mask, const TensorLayout& bias_k,
const TensorLayout& bias_v, const TensorLayout& out,
const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& othr_reservespace) {
#if CUDNN_VERSION < 8004 #if CUDNN_VERSION < 8004
// TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation.
MEGDNN_MARK_USED_VAR(queries); MEGDNN_MARK_USED_VAR(queries);
MEGDNN_MARK_USED_VAR(keys); MEGDNN_MARK_USED_VAR(keys);
MEGDNN_MARK_USED_VAR(values); MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(wqkv); MEGDNN_MARK_USED_VAR(qkvo_weight_bias);
MEGDNN_MARK_USED_VAR(attn_mask);
MEGDNN_MARK_USED_VAR(bias_k);
MEGDNN_MARK_USED_VAR(bias_v);
MEGDNN_MARK_USED_VAR(out); MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(reserveSpace); MEGDNN_MARK_USED_VAR(attn_weight);
MEGDNN_MARK_USED_VAR(mask_reservespace);
MEGDNN_MARK_USED_VAR(othr_reservespace);
return 0; return 0;
#else #else
MEGDNN_MARK_USED_VAR(wqkv); MEGDNN_MARK_USED_VAR(qkvo_weight_bias);
MEGDNN_MARK_USED_VAR(attn_mask);
MEGDNN_MARK_USED_VAR(out); MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(reserveSpace); MEGDNN_MARK_USED_VAR(attn_weight);
MEGDNN_MARK_USED_VAR(mask_reservespace);
MEGDNN_MARK_USED_VAR(othr_reservespace);
if (!desc_status.is_initialized(param(), queries, keys, values)) if (!desc_status.is_initialized(param(), queries, keys, values))
desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values); desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values);
@@ -64,47 +94,102 @@ size_t MultiHeadAttnForwardImpl::get_workspace_in_bytes(
#endif #endif
} }
size_t MultiHeadAttnForwardImpl::get_reservespace_in_bytes( size_t MultiHeadAttnForwardImpl::get_mask_reservespace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys, const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out, const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
const TensorLayout& reserveSpace) { const TensorLayout& attn_mask, const TensorLayout& bias_k,
const TensorLayout& bias_v, const TensorLayout& out,
const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& othr_reservespace) {
#if CUDNN_VERSION < 8004 #if CUDNN_VERSION < 8004
// TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation.
MEGDNN_MARK_USED_VAR(queries); MEGDNN_MARK_USED_VAR(queries);
MEGDNN_MARK_USED_VAR(keys); MEGDNN_MARK_USED_VAR(keys);
MEGDNN_MARK_USED_VAR(values); MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(wqkv); MEGDNN_MARK_USED_VAR(qkvo_weight_bias);
MEGDNN_MARK_USED_VAR(attn_mask);
MEGDNN_MARK_USED_VAR(bias_k);
MEGDNN_MARK_USED_VAR(bias_v);
MEGDNN_MARK_USED_VAR(out); MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(reserveSpace); MEGDNN_MARK_USED_VAR(attn_weight);
MEGDNN_MARK_USED_VAR(mask_reservespace);
MEGDNN_MARK_USED_VAR(othr_reservespace);
return 0; return 0;
#else #else
MEGDNN_MARK_USED_VAR(wqkv); MEGDNN_MARK_USED_VAR(qkvo_weight_bias);
MEGDNN_MARK_USED_VAR(attn_mask);
MEGDNN_MARK_USED_VAR(out); MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(reserveSpace); MEGDNN_MARK_USED_VAR(attn_weight);
MEGDNN_MARK_USED_VAR(mask_reservespace);
MEGDNN_MARK_USED_VAR(othr_reservespace);
if (!desc_status.is_initialized(param(), queries, keys, values))
desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values);
return 0;
#endif
}
size_t MultiHeadAttnForwardImpl::get_othr_reservespace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
const TensorLayout& attn_mask, const TensorLayout& bias_k,
const TensorLayout& bias_v, const TensorLayout& out,
const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& othr_reservespace) {
#if CUDNN_VERSION < 8004
// TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation.
MEGDNN_MARK_USED_VAR(queries);
MEGDNN_MARK_USED_VAR(keys);
MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(qkvo_weight_bias);
MEGDNN_MARK_USED_VAR(attn_mask);
MEGDNN_MARK_USED_VAR(bias_k);
MEGDNN_MARK_USED_VAR(bias_v);
MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(attn_weight);
MEGDNN_MARK_USED_VAR(mask_reservespace);
MEGDNN_MARK_USED_VAR(othr_reservespace);
return 0;
#else
MEGDNN_MARK_USED_VAR(qkvo_weight_bias);
MEGDNN_MARK_USED_VAR(attn_mask);
MEGDNN_MARK_USED_VAR(bias_k);
MEGDNN_MARK_USED_VAR(bias_v);
MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(attn_weight);
MEGDNN_MARK_USED_VAR(mask_reservespace);
MEGDNN_MARK_USED_VAR(othr_reservespace);
if (!desc_status.is_initialized(param(), queries, keys, values)) if (!desc_status.is_initialized(param(), queries, keys, values))
desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values); desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values);
return desc_status.sizeReserve; return desc_status.sizeReserve;
#endif #endif
} }
void MultiHeadAttnForwardImpl::exec( void MultiHeadAttnForwardImpl::exec(
_megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
_megdnn_tensor_in wqkv, _megdnn_tensor_out out, _megdnn_tensor_out reserveSpace, _megdnn_tensor_in qkvo_weight_bias, _megdnn_tensor_in attn_mask,
_megdnn_workspace workspace) { _megdnn_tensor_in bias_k, _megdnn_tensor_in bias_v, _megdnn_tensor_out out,
_megdnn_tensor_out attn_weight, _megdnn_tensor_out mask_reservespace,
_megdnn_tensor_out othr_reservespace, _megdnn_workspace workspace) {
#if CUDNN_VERSION < 8004 #if CUDNN_VERSION < 8004
// TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation.
MEGDNN_MARK_USED_VAR(queries); MEGDNN_MARK_USED_VAR(queries);
MEGDNN_MARK_USED_VAR(keys); MEGDNN_MARK_USED_VAR(keys);
MEGDNN_MARK_USED_VAR(values); MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(wqkv); MEGDNN_MARK_USED_VAR(qkvo_weight_bias);
MEGDNN_MARK_USED_VAR(attn_mask);
MEGDNN_MARK_USED_VAR(bias_k);
MEGDNN_MARK_USED_VAR(bias_v);
MEGDNN_MARK_USED_VAR(out); MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(reserveSpace); MEGDNN_MARK_USED_VAR(attn_weight);
MEGDNN_MARK_USED_VAR(workspace); MEGDNN_MARK_USED_VAR(mask_reservespace);
MEGDNN_MARK_USED_VAR(othr_reservespace);
megdnn_throw( megdnn_throw(
"The cudnn version is lower than 8.0.4. Please upgrade the cudnn version."); "The cudnn version is lower than 8.0.4. Please upgrade the cudnn version.");
#else #else
check_exec( check_exec(
queries.layout, keys.layout, values.layout, wqkv.layout, out.layout, queries.layout, keys.layout, values.layout, qkvo_weight_bias.layout,
reserveSpace.layout, workspace.size); attn_mask.layout, bias_k.layout, bias_v.layout, out.layout,
attn_weight.layout, mask_reservespace.layout, othr_reservespace.layout,
workspace.size);
auto p = param(); auto p = param();
if (!desc_status.is_initialized(p, queries.layout, keys.layout, values.layout)) if (!desc_status.is_initialized(p, queries.layout, keys.layout, values.layout))
@@ -112,12 +197,16 @@ void MultiHeadAttnForwardImpl::exec(
cudnn_handle(this->handle()), p, queries.layout, keys.layout, cudnn_handle(this->handle()), p, queries.layout, keys.layout,
values.layout); values.layout);
size_t osize =
desc_status.oProjSize != 0
? desc_status.oProjSize
: (desc_status.vProjSize != 0 ? desc_status.vProjSize * p.num_heads
: desc_status.vSize);
SeqTensorDesc q{queries.layout, desc_status.batchSize, SeqTensorDesc q{queries.layout, desc_status.batchSize,
desc_status.seqLenQ, desc_status.qSize, desc_status.seqLenQ, desc_status.qSize,
p.input_order, desc_status.auxArray.seqQArray}; p.input_order, desc_status.auxArray.seqQArray};
SeqTensorDesc o{out.layout, desc_status.batchSize, SeqTensorDesc o{out.layout, desc_status.batchSize, desc_status.seqLenQ,
desc_status.seqLenQ, desc_status.oProjSize, osize, p.input_order, desc_status.auxArray.seqQArray};
p.input_order, desc_status.auxArray.seqQArray};
SeqTensorDesc k{keys.layout, desc_status.batchSize, SeqTensorDesc k{keys.layout, desc_status.batchSize,
desc_status.seqLenK, desc_status.kSize, desc_status.seqLenK, desc_status.kSize,
p.input_order, desc_status.auxArray.seqKArray}; p.input_order, desc_status.auxArray.seqKArray};
@@ -132,19 +221,22 @@ void MultiHeadAttnForwardImpl::exec(
q.desc, queries.raw_ptr(), p.reslink ? queries.raw_ptr() : NULL, k.desc, q.desc, queries.raw_ptr(), p.reslink ? queries.raw_ptr() : NULL, k.desc,
keys.raw_ptr(), v.desc, values.raw_ptr(), o.desc, out.raw_ptr(), keys.raw_ptr(), v.desc, values.raw_ptr(), o.desc, out.raw_ptr(),
desc_status.sizeWeights, desc_status.sizeWeights,
desc_status.sizeWeights > 0 ? wqkv.raw_ptr() : NULL, desc_status.sizeWeights > 0 ? qkvo_weight_bias.raw_ptr() : NULL,
desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeWkspace, workspace.raw_ptr,
p.training ? desc_status.sizeReserve : 0, p.training ? desc_status.sizeReserve : 0,
p.training ? reserveSpace.raw_ptr() : NULL)); p.training ? othr_reservespace.raw_ptr() : NULL));
#endif #endif
} }
void MultiHeadAttnBackwardImpl::exec( void MultiHeadAttnBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
_megdnn_tensor_in values, _megdnn_tensor_in wqkv, _megdnn_tensor_in values, _megdnn_tensor_in qkvo_weight_bias,
_megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries, _megdnn_tensor_in attn_mask, _megdnn_tensor_in attn_weight,
_megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues, _megdnn_tensor_in mask_reservespace, _megdnn_tensor_in othr_reservespace,
_megdnn_tensor_out dweights, _megdnn_workspace workspace) { _megdnn_tensor_out dqueries, _megdnn_tensor_out dkeys,
_megdnn_tensor_out dvalues, _megdnn_tensor_out dqkvo_weight_bias,
_megdnn_tensor_out dbias_k, _megdnn_tensor_out dbias_v,
_megdnn_workspace workspace) {
#if CUDNN_VERSION < 8004 #if CUDNN_VERSION < 8004
// TODO: CUDNN_VERSION < 8004 and param().bias = true, we need to go to the proxy // TODO: CUDNN_VERSION < 8004 and param().bias = true, we need to go to the proxy
// cuda implementation. // cuda implementation.
@@ -152,12 +244,17 @@ void MultiHeadAttnBackwardImpl::exec(
MEGDNN_MARK_USED_VAR(queries); MEGDNN_MARK_USED_VAR(queries);
MEGDNN_MARK_USED_VAR(keys); MEGDNN_MARK_USED_VAR(keys);
MEGDNN_MARK_USED_VAR(values); MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(wqkv); MEGDNN_MARK_USED_VAR(qkvo_weight_bias);
MEGDNN_MARK_USED_VAR(reserveSpace); MEGDNN_MARK_USED_VAR(attn_mask);
MEGDNN_MARK_USED_VAR(attn_weight);
MEGDNN_MARK_USED_VAR(mask_reservespace);
MEGDNN_MARK_USED_VAR(othr_reservespace);
MEGDNN_MARK_USED_VAR(dqueries); MEGDNN_MARK_USED_VAR(dqueries);
MEGDNN_MARK_USED_VAR(dkeys); MEGDNN_MARK_USED_VAR(dkeys);
MEGDNN_MARK_USED_VAR(dvalues); MEGDNN_MARK_USED_VAR(dvalues);
MEGDNN_MARK_USED_VAR(dweights); MEGDNN_MARK_USED_VAR(dqkvo_weight_bias);
MEGDNN_MARK_USED_VAR(dbias_k);
MEGDNN_MARK_USED_VAR(dbias_v);
megdnn_throw( megdnn_throw(
"The cudnn version is lower than 8.0.4. Please upgrade the cudnn version."); "The cudnn version is lower than 8.0.4. Please upgrade the cudnn version.");
#else #else
@@ -168,11 +265,16 @@ void MultiHeadAttnBackwardImpl::exec(
"but got true, because there is an error in the " "but got true, because there is an error in the "
"dbias result during the backward calculation."); "dbias result during the backward calculation.");
#endif #endif
MEGDNN_MARK_USED_VAR(attn_mask);
MEGDNN_MARK_USED_VAR(dbias_k);
MEGDNN_MARK_USED_VAR(dbias_v);
check_exec( check_exec(
diff.layout, queries.layout, keys.layout, values.layout, wqkv.layout, diff.layout, queries.layout, keys.layout, values.layout,
reserveSpace.layout, dqueries.layout, dkeys.layout, dvalues.layout, qkvo_weight_bias.layout, attn_mask.layout, attn_weight.layout,
dweights.layout, workspace.size); mask_reservespace.layout, othr_reservespace.layout, dqueries.layout,
dkeys.layout, dvalues.layout, dqkvo_weight_bias.layout, dbias_k.layout,
dbias_v.layout, workspace.size);
auto p = param(); auto p = param();
if (!desc_status.is_initialized(p, queries.layout, keys.layout, values.layout)) if (!desc_status.is_initialized(p, queries.layout, keys.layout, values.layout))
@@ -180,12 +282,16 @@ void MultiHeadAttnBackwardImpl::exec(
cudnn_handle(this->handle()), p, queries.layout, keys.layout, cudnn_handle(this->handle()), p, queries.layout, keys.layout,
values.layout); values.layout);
size_t osize =
desc_status.oProjSize != 0
? desc_status.oProjSize
: (desc_status.vProjSize != 0 ? desc_status.vProjSize * p.num_heads
: desc_status.vSize);
SeqTensorDesc q{queries.layout, desc_status.batchSize, SeqTensorDesc q{queries.layout, desc_status.batchSize,
desc_status.seqLenQ, desc_status.qSize, desc_status.seqLenQ, desc_status.qSize,
p.input_order, desc_status.auxArray.seqQArray}; p.input_order, desc_status.auxArray.seqQArray};
SeqTensorDesc d{diff.layout, desc_status.batchSize, SeqTensorDesc d{diff.layout, desc_status.batchSize, desc_status.seqLenQ,
desc_status.seqLenQ, desc_status.oProjSize, osize, p.input_order, desc_status.auxArray.seqQArray};
p.input_order, desc_status.auxArray.seqQArray};
SeqTensorDesc k{keys.layout, desc_status.batchSize, SeqTensorDesc k{keys.layout, desc_status.batchSize,
desc_status.seqLenK, desc_status.kSize, desc_status.seqLenK, desc_status.kSize,
p.input_order, desc_status.auxArray.seqKArray}; p.input_order, desc_status.auxArray.seqKArray};
@@ -200,11 +306,11 @@ void MultiHeadAttnBackwardImpl::exec(
d.desc, diff.raw_ptr(), q.desc, dqueries.raw_ptr(), queries.raw_ptr(), d.desc, diff.raw_ptr(), q.desc, dqueries.raw_ptr(), queries.raw_ptr(),
k.desc, dkeys.raw_ptr(), keys.raw_ptr(), v.desc, dvalues.raw_ptr(), k.desc, dkeys.raw_ptr(), keys.raw_ptr(), v.desc, dvalues.raw_ptr(),
values.raw_ptr(), desc_status.sizeWeights, values.raw_ptr(), desc_status.sizeWeights,
desc_status.sizeWeights > 0 ? wqkv.raw_ptr() : NULL, desc_status.sizeWeights > 0 ? qkvo_weight_bias.raw_ptr() : NULL,
desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve, desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve,
reserveSpace.raw_ptr())); othr_reservespace.raw_ptr()));
cuda_check(cudaMemset(dweights.raw_ptr(), 0, desc_status.sizeWeights)); cuda_check(cudaMemset(dqkvo_weight_bias.raw_ptr(), 0, desc_status.sizeWeights));
#if CUDNN_VERSION < 8600 #if CUDNN_VERSION < 8600
cuda_check(cudaDeviceSynchronize()); cuda_check(cudaDeviceSynchronize());
#endif #endif
@@ -212,28 +318,35 @@ void MultiHeadAttnBackwardImpl::exec(
cudnn_handle(this->handle()), desc_status.attn_desc, CUDNN_WGRAD_MODE_ADD, cudnn_handle(this->handle()), desc_status.attn_desc, CUDNN_WGRAD_MODE_ADD,
q.desc, queries.raw_ptr(), k.desc, keys.raw_ptr(), v.desc, values.raw_ptr(), q.desc, queries.raw_ptr(), k.desc, keys.raw_ptr(), v.desc, values.raw_ptr(),
d.desc, diff.raw_ptr(), desc_status.sizeWeights, d.desc, diff.raw_ptr(), desc_status.sizeWeights,
desc_status.sizeWeights > 0 ? wqkv.raw_ptr() : NULL, desc_status.sizeWeights > 0 ? qkvo_weight_bias.raw_ptr() : NULL,
desc_status.sizeWeights > 0 ? dweights.raw_ptr() : NULL, desc_status.sizeWeights > 0 ? dqkvo_weight_bias.raw_ptr() : NULL,
desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve, desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve,
reserveSpace.raw_ptr())); othr_reservespace.raw_ptr()));
#endif #endif
} }
size_t MultiHeadAttnBackwardImpl::get_workspace_in_bytes( size_t MultiHeadAttnBackwardImpl::get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys, const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
const TensorLayout& reserveSpace, const TensorLayout& dqueries, const TensorLayout& attn_mask, const TensorLayout& attn_weight,
const TensorLayout& dkeys, const TensorLayout& dvalues, const TensorLayout& mask_reservespace, const TensorLayout& othr_reservespace,
const TensorLayout& dweights) { const TensorLayout& dqueries, const TensorLayout& dkeys,
const TensorLayout& dvalues, const TensorLayout& dqkvo_weight_bias,
const TensorLayout& dbias_k, const TensorLayout& dbias_v) {
MEGDNN_MARK_USED_VAR(diff); MEGDNN_MARK_USED_VAR(diff);
MEGDNN_MARK_USED_VAR(queries); MEGDNN_MARK_USED_VAR(queries);
MEGDNN_MARK_USED_VAR(keys); MEGDNN_MARK_USED_VAR(keys);
MEGDNN_MARK_USED_VAR(values); MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(wqkv); MEGDNN_MARK_USED_VAR(qkvo_weight_bias);
MEGDNN_MARK_USED_VAR(reserveSpace); MEGDNN_MARK_USED_VAR(attn_mask);
MEGDNN_MARK_USED_VAR(attn_weight);
MEGDNN_MARK_USED_VAR(mask_reservespace);
MEGDNN_MARK_USED_VAR(othr_reservespace);
MEGDNN_MARK_USED_VAR(dqueries); MEGDNN_MARK_USED_VAR(dqueries);
MEGDNN_MARK_USED_VAR(dkeys); MEGDNN_MARK_USED_VAR(dkeys);
MEGDNN_MARK_USED_VAR(dvalues); MEGDNN_MARK_USED_VAR(dvalues);
MEGDNN_MARK_USED_VAR(dweights); MEGDNN_MARK_USED_VAR(dqkvo_weight_bias);
MEGDNN_MARK_USED_VAR(dbias_k);
MEGDNN_MARK_USED_VAR(dbias_v);
return 0; return 0;
} }
} // namespace cuda } // namespace cuda
...
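One behavioural change above is that the output SeqTensorDesc no longer hard-codes desc_status.oProjSize but uses an osize fallback, so the interface also works when the output projection is disabled. Below is a small Python sketch of that selection, mirroring the branches in exec (the descriptor fields are treated as plain integers; reading vProjSize as a per-head width is an assumption based on the * num_heads factor in the code):

# Output width used for the o/d SeqTensorDesc, mirrored from exec above.
def output_width(o_proj_size, v_proj_size, v_size, num_heads):
    if o_proj_size != 0:
        return o_proj_size                 # explicit output projection
    if v_proj_size != 0:
        return v_proj_size * num_heads     # heads concatenated after value projection
    return v_size                          # no projection at all

assert output_width(32, 16, 64, 4) == 32
assert output_width(0, 16, 64, 4) == 64
assert output_width(0, 0, 64, 4) == 64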
@@ -19,20 +19,37 @@ public:
void exec( void exec(
_megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
_megdnn_tensor_in wqkv, _megdnn_tensor_out out, _megdnn_tensor_in qkvo_weight_bias, _megdnn_tensor_in attn_mask,
_megdnn_tensor_out reserveSpace, _megdnn_workspace workspace) override; _megdnn_tensor_in bias_k, _megdnn_tensor_in bias_v, _megdnn_tensor_out out,
_megdnn_tensor_out attn_weight, _megdnn_tensor_out mask_reservespace,
_megdnn_tensor_out othr_reservespace, _megdnn_workspace workspace) override;
void deduce_layout( void deduce_layout(
const TensorLayout& queries, const TensorLayout& keys, const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out, const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
TensorLayout& reserveSpace); const TensorLayout& attn_mask, const TensorLayout& bias_k,
size_t get_reservespace_in_bytes( const TensorLayout& bias_v, TensorLayout& out, TensorLayout& attn_weight,
TensorLayout& mask_reservespace, TensorLayout& othr_reservespace) override;
size_t get_mask_reservespace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys, const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
const TensorLayout& out, const TensorLayout& reserveSpace) override; const TensorLayout& attn_mask, const TensorLayout& bias_k,
const TensorLayout& bias_v, const TensorLayout& out,
const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& othr_reservespace) override;
size_t get_othr_reservespace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
const TensorLayout& attn_mask, const TensorLayout& bias_k,
const TensorLayout& bias_v, const TensorLayout& out,
const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& othr_reservespace) override;
size_t get_workspace_in_bytes( size_t get_workspace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys, const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
const TensorLayout& out, const TensorLayout& reserveSpace) override; const TensorLayout& attn_mask, const TensorLayout& bias_k,
const TensorLayout& bias_v, const TensorLayout& out,
const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& othr_reservespace) override;
}; };
class MultiHeadAttnBackwardImpl final : public MultiHeadAttnBackward { class MultiHeadAttnBackwardImpl final : public MultiHeadAttnBackward {
@@ -43,16 +60,22 @@ public:
#endif #endif
void exec( void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
_megdnn_tensor_in values, _megdnn_tensor_in wqkv, _megdnn_tensor_in values, _megdnn_tensor_in qkvo_weight_bias,
_megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries, _megdnn_tensor_in attn_mask, _megdnn_tensor_in attn_weight,
_megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues, _megdnn_tensor_in mask_reservespace, _megdnn_tensor_in othr_reservespace,
_megdnn_tensor_out dweights, _megdnn_workspace workspace) override; _megdnn_tensor_out dqueries, _megdnn_tensor_out dkeys,
_megdnn_tensor_out dvalues, _megdnn_tensor_out dqkvo_weight_bias,
_megdnn_tensor_out dbias_k, _megdnn_tensor_out dbias_v,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes( size_t get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& diff, const TensorLayout& queries,
const TensorLayout& keys, const TensorLayout& values, const TensorLayout& keys, const TensorLayout& values,
const TensorLayout& wqkv, const TensorLayout& reserveSpace, const TensorLayout& qkvo_weight_bias, const TensorLayout& attn_mask,
const TensorLayout& dqueries, const TensorLayout& dkeys, const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& dvalues, const TensorLayout& dweights) override; const TensorLayout& othr_reservespace, const TensorLayout& dqueries,
const TensorLayout& dkeys, const TensorLayout& dvalues,
const TensorLayout& dqkvo_weight_bias, const TensorLayout& dbias_k,
const TensorLayout& dbias_v) override;
}; };
} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn
...
@@ -8,45 +8,45 @@ namespace naive {
using Param = MultiHeadAttnBase::Param; using Param = MultiHeadAttnBase::Param;
size_t MultiHeadAttnForwardImpl::get_workspace_in_bytes( size_t MultiHeadAttnForwardImpl::get_workspace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys, const TensorLayout& /*queries*/, const TensorLayout& /*keys*/,
const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out, const TensorLayout& /*values*/, const TensorLayout& /*qkvo_weight_bias*/,
const TensorLayout& reserveSpace) { const TensorLayout& /*attn_mask*/, const TensorLayout& /*bias_k*/,
MEGDNN_MARK_USED_VAR(queries); const TensorLayout& /*bias_v*/, const TensorLayout& /*out*/,
MEGDNN_MARK_USED_VAR(keys); const TensorLayout& /*attn_weight*/, const TensorLayout& /*mask_reservespace*/,
MEGDNN_MARK_USED_VAR(values); const TensorLayout& /*othr_reservespace*/) {
MEGDNN_MARK_USED_VAR(wqkv);
MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(reserveSpace);
megdnn_throw("unsupported naive multiheadattn forward\n"); megdnn_throw("unsupported naive multiheadattn forward\n");
} }
void MultiHeadAttnForwardImpl::exec( void MultiHeadAttnForwardImpl::exec(
_megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
_megdnn_tensor_in wqkv, _megdnn_tensor_out out, _megdnn_tensor_out reserveSpace, _megdnn_tensor_in qkvo_weight_bias, _megdnn_tensor_in attn_mask,
_megdnn_workspace workspace) { _megdnn_tensor_in bias_k, _megdnn_tensor_in bias_v, _megdnn_tensor_out out,
MEGDNN_MARK_USED_VAR(queries); _megdnn_tensor_out attn_weight, _megdnn_tensor_out mask_reservespace,
MEGDNN_MARK_USED_VAR(keys); _megdnn_tensor_out othr_reservespace, _megdnn_workspace workspace) {
MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(wqkv);
MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(reserveSpace);
check_exec( check_exec(
queries.layout, keys.layout, values.layout, wqkv.layout, out.layout, queries.layout, keys.layout, values.layout, qkvo_weight_bias.layout,
reserveSpace.layout, workspace.size); attn_mask.layout, bias_k.layout, bias_v.layout, out.layout,
attn_weight.layout, mask_reservespace.layout, othr_reservespace.layout,
workspace.size);
megdnn_throw("unsupported naive multiheadattn forward\n"); megdnn_throw("unsupported naive multiheadattn forward\n");
} }
void MultiHeadAttnBackwardImpl::exec( void MultiHeadAttnBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
_megdnn_tensor_in values, _megdnn_tensor_in wqkv, _megdnn_tensor_in values, _megdnn_tensor_in qkvo_weight_bias,
_megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries, _megdnn_tensor_in attn_mask, _megdnn_tensor_in attn_weight,
_megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues, _megdnn_tensor_in mask_reservespace, _megdnn_tensor_in othr_reservespace,
_megdnn_tensor_out dweights, _megdnn_workspace workspace) { _megdnn_tensor_out dqueries, _megdnn_tensor_out dkeys,
_megdnn_tensor_out dvalues, _megdnn_tensor_out dqkvo_weight_bias,
_megdnn_tensor_out dbias_k, _megdnn_tensor_out dbias_v,
_megdnn_workspace workspace) {
check_exec( check_exec(
diff.layout, queries.layout, keys.layout, values.layout, wqkv.layout, diff.layout, queries.layout, keys.layout, values.layout,
reserveSpace.layout, dqueries.layout, dkeys.layout, dvalues.layout, qkvo_weight_bias.layout, attn_mask.layout, attn_weight.layout,
dweights.layout, workspace.size); mask_reservespace.layout, othr_reservespace.layout, dqueries.layout,
dkeys.layout, dvalues.layout, dqkvo_weight_bias.layout, dbias_k.layout,
dbias_v.layout, workspace.size);
megdnn_throw("unsupported naive multiheadattn backward\n"); megdnn_throw("unsupported naive multiheadattn backward\n");
} }
...
@@ -14,17 +14,43 @@ public:
using MultiHeadAttnForward::MultiHeadAttnForward; using MultiHeadAttnForward::MultiHeadAttnForward;
void exec( void exec(
_megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
_megdnn_tensor_in wqkv, _megdnn_tensor_out out, _megdnn_tensor_in qkvo_weight_bias, _megdnn_tensor_in attn_mask,
_megdnn_tensor_out reserveSpace, _megdnn_workspace workspace) override; _megdnn_tensor_in bias_k, _megdnn_tensor_in bias_v, _megdnn_tensor_out out,
_megdnn_tensor_out attn_weight, _megdnn_tensor_out mask_reservespace,
_megdnn_tensor_out othr_reservespace, _megdnn_workspace workspace) override;
void deduce_layout(
const TensorLayout& /*queries*/, const TensorLayout& /*keys*/,
const TensorLayout& /*values*/, const TensorLayout& /*qkvo_weight_bias*/,
const TensorLayout& /*attn_mask*/, const TensorLayout& /*bias_k*/,
const TensorLayout& /*bias_v*/, TensorLayout& /*out*/,
TensorLayout& /*attn_weight*/, TensorLayout& /*mask_reservespace*/,
TensorLayout& /*othr_reservespace*/) override {}
size_t get_workspace_in_bytes( size_t get_workspace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv,
const TensorLayout& out, const TensorLayout& reserveSpace) override;
size_t get_reservespace_in_bytes(
const TensorLayout& /*queries*/, const TensorLayout& /*keys*/, const TensorLayout& /*queries*/, const TensorLayout& /*keys*/,
const TensorLayout& /*values*/, const TensorLayout& /*wqkv*/, const TensorLayout& /*values*/, const TensorLayout& /*qkvo_weight_bias*/,
const TensorLayout& /*out*/, const TensorLayout& /*attn_mask*/, const TensorLayout& /*bias_k*/,
const TensorLayout& /*reserveSpace*/) override { const TensorLayout& /*bias_v*/, const TensorLayout& /*out*/,
const TensorLayout& /*attn_weight*/,
const TensorLayout& /*mask_reservespace*/,
const TensorLayout& /*othr_reservespace*/) override;
size_t get_mask_reservespace_in_bytes(
const TensorLayout& /*queries*/, const TensorLayout& /*keys*/,
const TensorLayout& /*values*/, const TensorLayout& /*qkvo_weight_bias*/,
const TensorLayout& /*attn_mask*/, const TensorLayout& /*bias_k*/,
const TensorLayout& /*bias_v*/, const TensorLayout& /*out*/,
const TensorLayout& /*attn_weight*/,
const TensorLayout& /*mask_reservespace*/,
const TensorLayout& /*othr_reservespace*/) override {
return 0;
}
size_t get_othr_reservespace_in_bytes(
const TensorLayout& /*queries*/, const TensorLayout& /*keys*/,
const TensorLayout& /*values*/, const TensorLayout& /*qkvo_weight_bias*/,
const TensorLayout& /*attn_mask*/, const TensorLayout& /*bias_k*/,
const TensorLayout& /*bias_v*/, const TensorLayout& /*out*/,
const TensorLayout& /*attn_weight*/,
const TensorLayout& /*mask_reservespace*/,
const TensorLayout& /*othr_reservespace*/) override {
return 0; return 0;
} }
}; };
@@ -34,17 +60,23 @@ public:
using MultiHeadAttnBackward::MultiHeadAttnBackward; using MultiHeadAttnBackward::MultiHeadAttnBackward;
void exec( void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
_megdnn_tensor_in values, _megdnn_tensor_in wqkv, _megdnn_tensor_in values, _megdnn_tensor_in qkvo_weight_bias,
_megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries, _megdnn_tensor_in attn_mask, _megdnn_tensor_in attn_weight,
_megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues, _megdnn_tensor_in mask_reservespace, _megdnn_tensor_in othr_reservespace,
_megdnn_tensor_out dweights, _megdnn_workspace workspace) override; _megdnn_tensor_out dqueries, _megdnn_tensor_out dkeys,
_megdnn_tensor_out dvalues, _megdnn_tensor_out dqkvo_weight_bias,
_megdnn_tensor_out dbias_k, _megdnn_tensor_out dbias_v,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes( size_t get_workspace_in_bytes(
const TensorLayout& /*diff*/, const TensorLayout& /* queries*/, const TensorLayout& /*diff*/, const TensorLayout& /*queries*/,
const TensorLayout& /*keyes*/, const TensorLayout& /* values*/, const TensorLayout& /*keys*/, const TensorLayout& /*values*/,
const TensorLayout& /*wqkv*/, const TensorLayout& /* reserveSpace*/, const TensorLayout& /*qkvo_weight_bias*/, const TensorLayout& /*attn_mask*/,
const TensorLayout& /*dqueries*/, const TensorLayout& /* dkeyes*/, const TensorLayout& /*attn_weight*/,
const TensorLayout& /*dvalues*/, const TensorLayout& /*mask_reservespace*/,
const TensorLayout& /* dweights*/) override { const TensorLayout& /*othr_reservespace*/, const TensorLayout& /*dqueries*/,
const TensorLayout& /*dkeys*/, const TensorLayout& /*dvalues*/,
const TensorLayout& /*dqkvo_weight_bias*/, const TensorLayout& /*dbias_k*/,
const TensorLayout& /*dbias_v*/) override {
return 0; return 0;
} }
}; };
...
@@ -899,24 +899,24 @@ def gelu(x):
def softplus(inp: Tensor) -> Tensor: def softplus(inp: Tensor) -> Tensor:
r"""Applies the element-wise function: r"""Applies the element-wise function:
.. math:: .. math::
\text{softplus}(x) = \log(1 + \exp(x)) \text{softplus}(x) = \log(1 + \exp(x))
softplus is a smooth approximation to the ReLU function and can be used softplus is a smooth approximation to the ReLU function and can be used
to constrain the output to be always positive. to constrain the output to be always positive.
For numerical stability the implementation follows this transformation: For numerical stability the implementation follows this transformation:
.. math:: .. math::
\text{softplus}(x) = \log(1 + \exp(x)) \text{softplus}(x) = \log(1 + \exp(x))
= \log(1 + \exp(-\text{abs}(x))) + \max(x, 0) = \log(1 + \exp(-\text{abs}(x))) + \max(x, 0)
= \log1p(\exp(-\text{abs}(x))) + \text{relu}(x) = \log1p(\exp(-\text{abs}(x))) + \text{relu}(x)
Examples: Examples:
>>> import numpy as np >>> import numpy as np
>>> x = Tensor(np.arange(-3, 3, dtype=np.float32)) >>> x = Tensor(np.arange(-3, 3, dtype=np.float32))
>>> y = F.softplus(x) >>> y = F.softplus(x)
>>> y.numpy().round(decimals=4) >>> y.numpy().round(decimals=4)
array([0.0486, 0.1269, 0.3133, 0.6931, 1.3133, 2.1269], dtype=float32) array([0.0486, 0.1269, 0.3133, 0.6931, 1.3133, 2.1269], dtype=float32)
""" """
return _elwise(inp, mode=Elemwise.Mode.SOFTPLUS) return _elwise(inp, mode=Elemwise.Mode.SOFTPLUS)
@@ -2213,7 +2213,7 @@ def _merge_masks(
): ):
r""" r"""
Determine mask type and combine masks if necessary. Determine mask type and combine masks if necessary.
Note: This function will continue to improve with the iteration of MHA. Note: This function will continue to improve with the iteration of MHA.
Args: Args:
@@ -2224,7 +2224,7 @@ def _merge_masks(
add_bias_kv: used to determine whether pad is needed on the sequence dimension of attn_mask and key_padding_mask, from MHA's ``add_bias_kv``. add_bias_kv: used to determine whether pad is needed on the sequence dimension of attn_mask and key_padding_mask, from MHA's ``add_bias_kv``.
add_zero_attn: used to determine whether pad is needed on the sequence dimension of attn_mask and key_padding_mask, from MHA's ``add_zero_attn``. add_zero_attn: used to determine whether pad is needed on the sequence dimension of attn_mask and key_padding_mask, from MHA's ``add_zero_attn``.
is_causal: MHA's is_causal, is_causal provides a hint that attn_mask is the causal mask. is_causal: MHA's is_causal, is_causal provides a hint that attn_mask is the causal mask.
maybe_cudnn_style_mask: MHA's maybe_cudnn_style_mask, like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is the cudnn style mask. maybe_cudnn_style_mask: MHA's maybe_cudnn_style_mask, like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is the cudnn style mask.
num_heads: MHA's head number. num_heads: MHA's head number.
Returns: Returns:
merged_mask: merged mask, may be None, the shape is :math:`(L, S)`, :math:`(2\cdot L + 2\cdot N)` or :math:`(N\cdot\text{num\_heads}, L, S)`
@@ -2320,8 +2320,8 @@ def multi_head_attention(
num_heads: parallel attention heads. num_heads: parallel attention heads.
attn_drop: probability of an element to be zeroed, used in attention matrix. attn_drop: probability of an element to be zeroed, used in attention matrix.
out_drop: probability of an element to be zeroed, used in final output. out_drop: probability of an element to be zeroed, used in final output.
io_weight_bias: input/output projection weight/bias all in one. io_weight_bias: input/output projection weight/bias all in one.
The order of arrangement is: query weight, key weight, value weight, out weight, query bias, key bias, value bias, out bias, the following parameters will be used to indicate whether these items exist: qproj_size, kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias. The order of arrangement is: query weight, key weight, value weight, out weight, query bias, key bias, value bias, out bias, the following parameters will be used to indicate whether these items exist: qproj_size, kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias.
Note: :math:`Y=X@W+B` is used here instead of :math:`Y=X@W^T+B` in pytorch. Note: :math:`Y=X@W+B` is used here instead of :math:`Y=X@W^T+B` in pytorch.
qproj_size: indicates the projection size of query weight in io_weight_bias, 0 indicates disabled query projection and no query projection weight. qproj_size: indicates the projection size of query weight in io_weight_bias, 0 indicates disabled query projection and no query projection weight.
kproj_size: indicates the projection size of key weight in io_weight_bias, 0 indicates disabled key projection and no key projection weight. kproj_size: indicates the projection size of key weight in io_weight_bias, 0 indicates disabled key projection and no key projection weight.
@@ -2335,7 +2335,7 @@ def multi_head_attention(
Note: Should be set to None, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. Note: Should be set to None, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
add_zero_attn: if specified, adds a new batch of zeros to the key and value sequences at sequence dim. Default: ``False``. add_zero_attn: if specified, adds a new batch of zeros to the key and value sequences at sequence dim. Default: ``False``.
Note: should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. Note: should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
key_padding_mask: if specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` to ignore for the purpose of key_padding_mask: if specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` to ignore for the purpose of
attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
Note: Should be set to None, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. Note: Should be set to None, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
@@ -2353,9 +2353,22 @@ def multi_head_attention(
Note: In the cudnn style, the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`. Note: In the cudnn style, the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`.
Warning: like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is a cudnn style mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. In addition, if the ``_merge_masks`` function returns ``merge_type=cudnn_style_mask``, please ensure that other conditions are correct so that it can run the implementation of cudnn, otherwise an error will be reported. Warning: like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is a cudnn style mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. In addition, if the ``_merge_masks`` function returns ``merge_type=cudnn_style_mask``, please ensure that other conditions are correct so that it can run the implementation of cudnn, otherwise an error will be reported.
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that the underlying implementation only accepts two types of mask type, namely "no_mask" and "default_mask", and we may try to loosen this option after submitting the commit that users can pass in custom attention mask tensors. Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that the underlying implementation only accepts two types of mask type, namely "no_mask" and "default_mask", and we may try to loosen this option after submitting the commit that users can pass in custom attention mask tensors.
reslink: add input query to final output. reslink: add input query to final output.
Note: It is only valid if the input query is the same as the shape of the output. Note: It is only valid if the input query is the same as the shape of the output.
training: will apply dropout if is ``True``. training: will apply dropout if is ``True``.
Outputs:
- **out[0]=attn_output** - Attention outputs of shape :math:`(N, L, E)`,
where :math:`L` is the target sequence length, :math:`N` is
the batch size, and :math:`E` is the embedding dimension ``embed_dim``.
- **out[1]=attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N * \text{num\_heads}, L, S)`.
Note: Now only None will be returned. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
- **out[2]=mask_reservespace** - Used to save the dropout mask needed for backward propagation.
- **out[3]=othr_reservespace** - Used to save the intermediate results that need to be used in backward propagation.
""" """
qproj_size = embed_dim if qproj_size is None else qproj_size qproj_size = embed_dim if qproj_size is None else qproj_size
kproj_size = embed_dim if kproj_size is None else kproj_size kproj_size = embed_dim if kproj_size is None else kproj_size
...@@ -2448,6 +2461,21 @@ def multi_head_attention( ...@@ -2448,6 +2461,21 @@ def multi_head_attention(
num_heads=num_heads, num_heads=num_heads,
) )
def get_tensor_combination_type(attn_mask_tensor, bias_k, bias_v):
bias_kv = bias_k is not None and bias_v is not None
if not bias_kv and attn_mask_tensor is None:
return "none"
elif not bias_kv and attn_mask_tensor is not None:
return "only_mask"
elif bias_kv and attn_mask_tensor is None:
return "only_biaskv"
else:
return "all"
tensor_combination_type = get_tensor_combination_type(
attn_mask_tensor, bias_k, bias_v
)
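As a quick, standalone illustration of the helper above (the argument values are arbitrary placeholders; ``None`` stands for "tensor not provided"), the presence of the optional tensors maps onto the four combination types as follows:

```python
# standalone copy of the helper above, so the mapping can be checked in isolation
def get_tensor_combination_type(attn_mask_tensor, bias_k, bias_v):
    bias_kv = bias_k is not None and bias_v is not None
    if not bias_kv and attn_mask_tensor is None:
        return "none"
    elif not bias_kv and attn_mask_tensor is not None:
        return "only_mask"
    elif bias_kv and attn_mask_tensor is None:
        return "only_biaskv"
    else:
        return "all"

assert get_tensor_combination_type(None, None, None) == "none"
assert get_tensor_combination_type("mask", None, None) == "only_mask"
assert get_tensor_combination_type(None, "bias_k", "bias_v") == "only_biaskv"
assert get_tensor_combination_type("mask", "bias_k", "bias_v") == "all"
```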
op = builtin.MultiHeadAttn( op = builtin.MultiHeadAttn(
num_heads=num_heads, num_heads=num_heads,
sm_scaler=smScaler, sm_scaler=smScaler,
...@@ -2471,11 +2499,20 @@ def multi_head_attention( ...@@ -2471,11 +2499,20 @@ def multi_head_attention(
vbias=vbias, vbias=vbias,
obias=obias, obias=obias,
need_weights=need_weights, need_weights=need_weights,
tensor_combination_type="none", tensor_combination_type=tensor_combination_type,
) )
if tensor_combination_type == "none":
out = apply(op, query, key, value, io_weight_bias)
elif tensor_combination_type == "only_mask":
out = apply(op, query, key, value, io_weight_bias, attn_mask_tensor)
elif tensor_combination_type == "only_biaskv":
out = apply(op, query, key, value, io_weight_bias, bias_k, bias_v)
else:
out = apply(
op, query, key, value, io_weight_bias, attn_mask_tensor, bias_k, bias_v
)
out, reserveSpace = apply(op, query, key, value, io_weight_bias) return out[0], out[1]
return out, None
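The dispatch above amounts to "append whichever optional tensors are present, in the order attn_mask, bias_k, bias_v". A condensed, framework-free restatement of that rule (the function name `build_apply_args` is illustrative only, not MegEngine API):

```python
def build_apply_args(query, key, value, io_weight_bias,
                     attn_mask_tensor=None, bias_k=None, bias_v=None):
    """Return the positional inputs in the order expected by the MultiHeadAttn
    op for the four tensor_combination_type cases."""
    args = [query, key, value, io_weight_bias]
    if attn_mask_tensor is not None:                # "only_mask" or "all"
        args.append(attn_mask_tensor)
    if bias_k is not None and bias_v is not None:   # "only_biaskv" or "all"
        args += [bias_k, bias_v]
    return tuple(args)

# 4, 5, 6 or 7 inputs depending on which optional tensors are given
assert len(build_apply_args("q", "k", "v", "w")) == 4
assert len(build_apply_args("q", "k", "v", "w", attn_mask_tensor="m")) == 5
assert len(build_apply_args("q", "k", "v", "w", bias_k="bk", bias_v="bv")) == 6
assert len(build_apply_args("q", "k", "v", "w", "m", "bk", "bv")) == 7
```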
from .loss import * # isort:skip from .loss import * # isort:skip
......
...@@ -436,6 +436,28 @@ _INST_RNG_MAKER(4) ...@@ -436,6 +436,28 @@ _INST_RNG_MAKER(4)
#undef _FOR_EACH_OUT #undef _FOR_EACH_OUT
#undef _FOR_EACH_IN #undef _FOR_EACH_IN
#define _FOR_EACH_IN(subfix) \
inputs[0] subfix, inputs[1] subfix, inputs[2] subfix, inputs[3] subfix, \
inputs[4] subfix,
_INST_RNG_MAKER(5)
#undef _FOR_EACH_IN
#define _FOR_EACH_IN(subfix) \
inputs[0] subfix, inputs[1] subfix, inputs[2] subfix, inputs[3] subfix, \
inputs[4] subfix, inputs[5] subfix,
_INST_RNG_MAKER(6)
#undef _FOR_EACH_IN
#define _FOR_EACH_IN(subfix) \
inputs[0] subfix, inputs[1] subfix, inputs[2] subfix, inputs[3] subfix, \
inputs[4] subfix, inputs[5] subfix, inputs[6] subfix,
#define _FOR_EACH_OUT(subfix) \
outputs[0] subfix, outputs[1] subfix, outputs[2] subfix, outputs[3] subfix
_INST_RNG_INVOLKER(7, 4)
_INST_RNG_MAKER(7)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN
#undef _INST_RNG_INVOLKER #undef _INST_RNG_INVOLKER
#undef _INST_RNG_MAKER #undef _INST_RNG_MAKER
...@@ -541,37 +563,90 @@ SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>( ...@@ -541,37 +563,90 @@ SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>(
return dests; return dests;
} }
template <typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> _infer_output_attrs(
const OpDef& op, const SmallVector<TensorLayout>& inputs, const CompNode cn){};
template <> template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<MultiHeadAttn>( std::tuple<SmallVector<LogicalTensorDesc>, bool> _infer_output_attrs<MultiHeadAttn>(
const OpDef& op, const SmallVector<TensorPtr>& inputs) { const OpDef& op, const SmallVector<TensorLayout>& inputs, const CompNode cn) {
SmallVector<LogicalTensorDesc> dests(2); bool success = inputs[0].ndim != 0;
auto&& cn = inputs[0]->comp_node();
dests[0].comp_node = cn; SmallVector<LogicalTensorDesc> dests(4);
dests[0].layout = TensorLayout(inputs[0]->layout());
dests[0].layout.dtype = inputs[0]->layout().dtype;
auto get_reservespace_in_bytes = [&]() -> size_t { // retrieve dnn_op from glob cache
// retrieve dnn_op from glob cache auto&& rng = op.cast_final_safe<MultiHeadAttn>();
auto&& rng = op.cast_final_safe<MultiHeadAttn>(); auto handle = rng.handle;
auto handle = rng.handle; if (!handle) {
if (!handle) { handle = RNGDnnOpManager::get_default_handle(cn);
handle = RNGDnnOpManager::get_default_handle(cn); }
} auto dnn_op_thread_safe = RNGDnnOpManager::inst().get_dnn_op<megdnn::MultiHeadAttn>(
auto dnn_op_thread_safe = handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn);
RNGDnnOpManager::inst().get_dnn_op<megdnn::MultiHeadAttn>( auto dnn_op = std::get<1>(dnn_op_thread_safe);
handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn); dnn_op->param() = OpMeth<MultiHeadAttn>::make_param(rng);
auto dnn_op = std::get<1>(dnn_op_thread_safe);
dnn_op->param() = OpMeth<MultiHeadAttn>::make_param(rng);
return dnn_op->get_reservespace_in_bytes( TensorLayout out, attn_weight, mask_layout, othr_layout;
inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(), dnn_op->deduce_layout(
inputs[3]->layout(), {}, {}); inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5], inputs[6],
}; out, attn_weight, mask_layout, othr_layout);
dests[0].comp_node = cn;
dests[0].layout = out;
dests[0].layout.dtype = inputs[0].dtype;
dests[1].comp_node = cn; dests[1].comp_node = cn;
dests[1].layout = dests[1].layout = attn_weight;
TensorLayout(TensorShape({get_reservespace_in_bytes()}), dtype::Byte()); if (success) {
return dests; dests[2].comp_node = cn;
dests[2].layout = mask_layout;
dests[3].comp_node = cn;
dests[3].layout = othr_layout;
} else {
dests[2].comp_node = cn;
dests[2].layout = TensorLayout(dtype::Byte());
dests[3].comp_node = cn;
dests[3].layout = TensorLayout(inputs[0].dtype);
}
return {dests, success};
}
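The forward op now reports four outputs (out, attn_weight, mask_reservespace, othr_reservespace) instead of two; when the query shape is still unknown (ndim == 0), the reserve-space layouts cannot be deduced, so placeholders are returned together with success = false. A rough, dependency-free Python sketch of that contract (all names here are illustrative, not MegEngine API):

```python
from collections import namedtuple

Desc = namedtuple("Desc", "name layout")

def describe_mha_outputs(deduced, shape_known):
    """Build the four output descriptors; fall back to placeholder layouts for
    the reserve spaces when the input shape is not known yet."""
    dests = [Desc("out", deduced["out"]),
             Desc("attn_weight", deduced["attn_weight"])]
    if shape_known:
        dests += [Desc("mask_reservespace", deduced["mask"]),
                  Desc("othr_reservespace", deduced["othr"])]
    else:
        dests += [Desc("mask_reservespace", "placeholder: Byte"),
                  Desc("othr_reservespace", "placeholder: input dtype")]
    return dests, shape_known

dests, ok = describe_mha_outputs(
    {"out": "(N, L, O)", "attn_weight": "(N*heads, L, S)",
     "mask": "(M,)", "othr": "(R,)"},
    shape_known=True)
assert ok and len(dests) == 4
```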
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<MultiHeadAttn>(
const OpDef& op, const SmallVector<TensorPtr>& inputs) {
using INPUT_TYPE = opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE;
auto&& cn = inputs[0]->comp_node();
auto input_type = op.cast_final_safe<MultiHeadAttn>().tensor_combination_type;
std::tuple<SmallVector<LogicalTensorDesc>, bool> ret;
TensorLayout empty_layout;
if (input_type == INPUT_TYPE::NONE)
ret = _infer_output_attrs<MultiHeadAttn>(
op,
{inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
inputs[3]->layout(), empty_layout, empty_layout, empty_layout},
cn);
else if (input_type == INPUT_TYPE::ONLY_MASK)
ret = _infer_output_attrs<MultiHeadAttn>(
op,
{inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
inputs[3]->layout(), inputs[4]->layout(), empty_layout, empty_layout},
cn);
else if (input_type == INPUT_TYPE::ONLY_BIASKV)
ret = _infer_output_attrs<MultiHeadAttn>(
op,
{inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
inputs[3]->layout(), empty_layout, inputs[4]->layout(),
inputs[5]->layout()},
cn);
else
ret = _infer_output_attrs<MultiHeadAttn>(
op,
{inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
inputs[3]->layout(), inputs[4]->layout(), inputs[5]->layout(),
inputs[6]->layout()},
cn);
return std::get<0>(ret);
} }
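Both infer_output_attrs here and the fallible variant further down apply the same padding rule: extend the provided layouts to the fixed 7-slot signature (query, key, value, weight_bias, attn_mask, bias_k, bias_v), filling the unused slots with empty layouts. A minimal Python sketch of that rule (illustrative only):

```python
EMPTY = object()  # stands in for an empty TensorLayout

def pad_to_seven(layouts, combination_type):
    """Map the variable-length input list onto the fixed 7-slot layout list
    (q, k, v, weight_bias, attn_mask, bias_k, bias_v)."""
    q, k, v, wb = layouts[:4]
    rest = layouts[4:]
    if combination_type == "none":
        return [q, k, v, wb, EMPTY, EMPTY, EMPTY]
    if combination_type == "only_mask":
        return [q, k, v, wb, rest[0], EMPTY, EMPTY]
    if combination_type == "only_biaskv":
        return [q, k, v, wb, EMPTY, rest[0], rest[1]]
    return [q, k, v, wb, rest[0], rest[1], rest[2]]  # "all"

assert pad_to_seven(list("qkvw") + ["mask"], "only_mask")[4] == "mask"
assert pad_to_seven(list("qkvw") + ["bk", "bv"], "only_biaskv")[5:] == ["bk", "bv"]
```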
template <typename Op> template <typename Op>
...@@ -587,6 +662,127 @@ SmallVector<TensorPtr> apply_on_physical_tensor( ...@@ -587,6 +662,127 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return outputs; return outputs;
} }
template <>
SmallVector<TensorPtr> apply_on_physical_tensor<MultiHeadAttn>(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
using INPUT_TYPE = opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE;
SmallVector<TensorPtr> outputs;
SmallVector<LogicalTensorDesc> desc =
infer_output_attrs<MultiHeadAttn>(def, inputs);
for (auto&& i : desc) {
outputs.push_back(Tensor::make(i.layout, i.comp_node));
}
auto&& rng = def.cast_final_safe<MultiHeadAttn>();
auto dest = outputs[0];
if (dest->layout().is_empty())
return outputs;
auto cn = dest->comp_node();
auto handle = rng.handle;
if (!handle) {
handle = RNGDnnOpManager::get_default_handle(cn);
}
// retrieve dnn_op from glob cache
auto dnn_op_thread_safe =
RNGDnnOpManager::inst().get_dnn_op<typename OpMeth<MultiHeadAttn>::DnnOp>(
handle, reinterpret_cast<size_t>(def.dyn_typeinfo()), cn);
auto initialized = std::get<0>(dnn_op_thread_safe);
auto dnn_op = std::get<1>(dnn_op_thread_safe);
if (initialized) {
auto handle_seed = RNGDnnOpManager::get_seed(handle);
mgb_assert(
dnn_op->param().seed == handle_seed,
"inconsistent rng seed: handle: %lu, dnn_op: %lu", handle_seed,
dnn_op->param().seed);
}
dnn_op->param() = OpMeth<MultiHeadAttn>::make_param(rng);
auto input_type = rng.tensor_combination_type;
std::shared_ptr<Tensor> empty_dnn(nullptr);
size_t wk_size = 0;
TensorLayout empty_layout;
megdnn::TensorND empty_tensor;
if (input_type == INPUT_TYPE::ALL) {
wk_size = dnn_op->get_workspace_in_bytes(
inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
inputs[3]->layout(), inputs[4]->layout(), inputs[5]->layout(),
inputs[6]->layout(), outputs[0]->layout(), outputs[1]->layout(),
outputs[2]->layout(), outputs[3]->layout());
auto workspace = Blob::make(outputs[0]->comp_node(), wk_size);
megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size);
dnn_op->exec(
inputs[0]->dev_tensor().as_megdnn(),
inputs[1]->dev_tensor().as_megdnn(),
inputs[2]->dev_tensor().as_megdnn(),
inputs[3]->dev_tensor().as_megdnn(),
inputs[4]->dev_tensor().as_megdnn(),
inputs[5]->dev_tensor().as_megdnn(),
inputs[6]->dev_tensor().as_megdnn(),
outputs[0]->dev_tensor().as_megdnn(),
outputs[1]->dev_tensor().as_megdnn(),
outputs[2]->dev_tensor().as_megdnn(),
outputs[3]->dev_tensor().as_megdnn(), dnn_wk);
} else if (input_type == INPUT_TYPE::ONLY_MASK) {
wk_size = dnn_op->get_workspace_in_bytes(
inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
inputs[3]->layout(), inputs[4]->layout(), empty_layout, empty_layout,
outputs[0]->layout(), outputs[1]->layout(), outputs[2]->layout(),
outputs[3]->layout());
auto workspace = Blob::make(outputs[0]->comp_node(), wk_size);
megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size);
dnn_op->exec(
inputs[0]->dev_tensor().as_megdnn(),
inputs[1]->dev_tensor().as_megdnn(),
inputs[2]->dev_tensor().as_megdnn(),
inputs[3]->dev_tensor().as_megdnn(),
inputs[4]->dev_tensor().as_megdnn(), empty_tensor, empty_tensor,
outputs[0]->dev_tensor().as_megdnn(),
outputs[1]->dev_tensor().as_megdnn(),
outputs[2]->dev_tensor().as_megdnn(),
outputs[3]->dev_tensor().as_megdnn(), dnn_wk);
} else if (input_type == INPUT_TYPE::ONLY_BIASKV) {
wk_size = dnn_op->get_workspace_in_bytes(
inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
inputs[3]->layout(), empty_layout, inputs[4]->layout(),
inputs[5]->layout(), outputs[0]->layout(), outputs[1]->layout(),
outputs[2]->layout(), outputs[3]->layout());
auto workspace = Blob::make(outputs[0]->comp_node(), wk_size);
megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size);
dnn_op->exec(
inputs[0]->dev_tensor().as_megdnn(),
inputs[1]->dev_tensor().as_megdnn(),
inputs[2]->dev_tensor().as_megdnn(),
inputs[3]->dev_tensor().as_megdnn(), empty_tensor,
                inputs[4]->dev_tensor().as_megdnn(),
                inputs[5]->dev_tensor().as_megdnn(),
outputs[0]->dev_tensor().as_megdnn(),
outputs[1]->dev_tensor().as_megdnn(),
outputs[2]->dev_tensor().as_megdnn(),
outputs[3]->dev_tensor().as_megdnn(), dnn_wk);
} else {
wk_size = dnn_op->get_workspace_in_bytes(
inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
inputs[3]->layout(), empty_layout, empty_layout, empty_layout,
outputs[0]->layout(), outputs[1]->layout(), outputs[2]->layout(),
outputs[3]->layout());
auto workspace = Blob::make(outputs[0]->comp_node(), wk_size);
megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size);
dnn_op->exec(
inputs[0]->dev_tensor().as_megdnn(),
inputs[1]->dev_tensor().as_megdnn(),
inputs[2]->dev_tensor().as_megdnn(),
inputs[3]->dev_tensor().as_megdnn(), empty_tensor, empty_tensor,
empty_tensor, outputs[0]->dev_tensor().as_megdnn(),
outputs[1]->dev_tensor().as_megdnn(),
outputs[2]->dev_tensor().as_megdnn(),
outputs[3]->dev_tensor().as_megdnn(), dnn_wk);
}
return outputs;
}
template <typename Op, typename Output> template <typename Op, typename Output>
Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
size_t nr_inp = inputs.size(); size_t nr_inp = inputs.size();
...@@ -601,6 +797,23 @@ Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { ...@@ -601,6 +797,23 @@ Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
return _RNGOprMaker<mgb_nr_inp>::make(inputs, rng); return _RNGOprMaker<mgb_nr_inp>::make(inputs, rng);
} }
template <>
SymbolVarArray apply_on_var_node<MultiHeadAttn, SymbolVarArray>(
const OpDef& def, const VarNodeArray& inputs) {
auto&& rng = def.cast_final_safe<MultiHeadAttn>();
using INPUT_TYPE = opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE;
auto input_type = rng.tensor_combination_type;
if (input_type == INPUT_TYPE::ALL) {
return _RNGOprMaker<7>::make(inputs, rng);
} else if (input_type == INPUT_TYPE::ONLY_BIASKV) {
return _RNGOprMaker<6>::make(inputs, rng);
} else if (input_type == INPUT_TYPE::ONLY_MASK) {
return _RNGOprMaker<5>::make(inputs, rng);
} else {
return _RNGOprMaker<4>::make(inputs, rng);
}
}
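The var-node path encodes the same rule as the Python layer: the tensor combination type fixes how many inputs the op takes. Stated as a small, illustrative Python mapping:

```python
# number of graph inputs expected for each tensor_combination_type
MHA_INPUT_ARITY = {
    "all": 7,          # q, k, v, weight_bias, attn_mask, bias_k, bias_v
    "only_biaskv": 6,  # q, k, v, weight_bias, bias_k, bias_v
    "only_mask": 5,    # q, k, v, weight_bias, attn_mask
    "none": 4,         # q, k, v, weight_bias
}

assert sorted(MHA_INPUT_ARITY.values()) == [4, 5, 6, 7]
```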
template <typename Op> template <typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
...@@ -671,39 +884,38 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dro ...@@ -671,39 +884,38 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dro
template <> template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible< std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<
MultiHeadAttn>(const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) { MultiHeadAttn>(const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) {
bool success = inputs[0].layout.ndim != 0; using INPUT_TYPE = opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE;
auto&& cn = inputs[0].comp_node;
SmallVector<LogicalTensorDesc> dests(2); auto input_type = op.cast_final_safe<MultiHeadAttn>().tensor_combination_type;
auto cn = inputs[0].comp_node;
dests[0].comp_node = cn; std::tuple<SmallVector<LogicalTensorDesc>, bool> ret;
dests[0].layout = TensorLayout(inputs[0].layout); TensorLayout empty_layout;
dests[0].layout.dtype = inputs[0].layout.dtype; if (input_type == INPUT_TYPE::NONE)
ret = _infer_output_attrs<MultiHeadAttn>(
auto get_reservespace_in_bytes = [&]() -> size_t { op,
auto&& rng = op.cast_final_safe<MultiHeadAttn>(); {inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout,
auto handle = rng.handle; empty_layout, empty_layout, empty_layout},
if (!handle) { cn);
handle = RNGDnnOpManager::get_default_handle(cn); else if (input_type == INPUT_TYPE::ONLY_MASK)
} ret = _infer_output_attrs<MultiHeadAttn>(
auto dnn_op_thread_safe = op,
RNGDnnOpManager::inst().get_dnn_op<megdnn::MultiHeadAttn>( {inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout,
handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn); inputs[4].layout, empty_layout, empty_layout},
auto dnn_op = std::get<1>(dnn_op_thread_safe); cn);
dnn_op->param() = OpMeth<MultiHeadAttn>::make_param(rng); else if (input_type == INPUT_TYPE::ONLY_BIASKV)
ret = _infer_output_attrs<MultiHeadAttn>(
return dnn_op->get_reservespace_in_bytes( op,
inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout, {inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout,
{}, {}); empty_layout, inputs[4].layout, inputs[5].layout},
}; cn);
dests[1].comp_node = cn; else
if (success) { ret = _infer_output_attrs<MultiHeadAttn>(
dests[1].layout = op,
TensorLayout(TensorShape({get_reservespace_in_bytes()}), dtype::Byte()); {inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout,
} else { inputs[4].layout, inputs[5].layout, inputs[6].layout},
dests[1].layout = TensorLayout(dtype::Byte()); cn);
}
return ret;
return {dests, success};
} }
template <typename Op> template <typename Op>
......
0a8cd3cd50cadfaae0478ee70621618e ../../dnn/scripts/opr_param_defs.py 0a8cd3cd50cadfaae0478ee70621618e ../../dnn/scripts/opr_param_defs.py
9e9636d66694dd7d5a7853247a5406f9 ../../src/core/include/megbrain/ir/ops.td 9e9636d66694dd7d5a7853247a5406f9 ../../src/core/include/megbrain/ir/ops.td
283dffd0e9cd28db5155c44cf4eda148 generated/opdef.h.inl 2c15c869c1731d1bc5f25f9b132f4f08 generated/opdef.h.inl
5e8d57337c3aec6f4b3b30ef9ba141f8 generated/opdef.cpp.inl 0dabeee4b8f81be4c1809906b99795a5 generated/opdef.cpp.inl
7f470236e4b5b00bdeaec321bc7187b5 generated/opdef.py.inl be20faf18eccbc56f535b012170ed90a generated/opdef.py.inl
003addd357423b880cd06410f5bf624b generated/opdef.cpy.inl af9ab62fe962d409bb65e66af5f44a79 generated/opdef.cpy.inl
d468302f2d4b113913b76b5a181aae56 generated/enum_macro.h d468302f2d4b113913b76b5a181aae56 generated/enum_macro.h
...@@ -5321,8 +5321,8 @@ std::vector<std::pair<const char*, std::string>> MultiHeadAttn_props_impl(const ...@@ -5321,8 +5321,8 @@ std::vector<std::pair<const char*, std::string>> MultiHeadAttn_props_impl(const
props_.emplace_back("tensor_combination_type", "INVALID"); props_.emplace_back("tensor_combination_type", "INVALID");
break; break;
} }
props_.emplace_back("need_weights", std::to_string(op_.need_weights));
props_.emplace_back("add_zero_attn", std::to_string(op_.add_zero_attn)); props_.emplace_back("add_zero_attn", std::to_string(op_.add_zero_attn));
props_.emplace_back("need_weights", std::to_string(op_.need_weights));
props_.emplace_back("reslink", std::to_string(op_.reslink)); props_.emplace_back("reslink", std::to_string(op_.reslink));
props_.emplace_back("training", std::to_string(op_.training)); props_.emplace_back("training", std::to_string(op_.training));
props_.emplace_back("seed", std::to_string(op_.seed)); props_.emplace_back("seed", std::to_string(op_.seed));
......
...@@ -15238,8 +15238,8 @@ PyOpDefBegin(MultiHeadAttn) // { ...@@ -15238,8 +15238,8 @@ PyOpDefBegin(MultiHeadAttn) // {
{"input_order", serialization<decltype(opdef.input_order)>::dump(opdef.input_order)}, {"input_order", serialization<decltype(opdef.input_order)>::dump(opdef.input_order)},
{"attn_mask_type", serialization<decltype(opdef.attn_mask_type)>::dump(opdef.attn_mask_type)}, {"attn_mask_type", serialization<decltype(opdef.attn_mask_type)>::dump(opdef.attn_mask_type)},
{"tensor_combination_type", serialization<decltype(opdef.tensor_combination_type)>::dump(opdef.tensor_combination_type)}, {"tensor_combination_type", serialization<decltype(opdef.tensor_combination_type)>::dump(opdef.tensor_combination_type)},
{"need_weights", serialization<decltype(opdef.need_weights)>::dump(opdef.need_weights)},
{"add_zero_attn", serialization<decltype(opdef.add_zero_attn)>::dump(opdef.add_zero_attn)}, {"add_zero_attn", serialization<decltype(opdef.add_zero_attn)>::dump(opdef.add_zero_attn)},
{"need_weights", serialization<decltype(opdef.need_weights)>::dump(opdef.need_weights)},
{"reslink", serialization<decltype(opdef.reslink)>::dump(opdef.reslink)}, {"reslink", serialization<decltype(opdef.reslink)>::dump(opdef.reslink)},
{"training", serialization<decltype(opdef.training)>::dump(opdef.training)}, {"training", serialization<decltype(opdef.training)>::dump(opdef.training)},
{"seed", serialization<decltype(opdef.seed)>::dump(opdef.seed)}, {"seed", serialization<decltype(opdef.seed)>::dump(opdef.seed)},
...@@ -15369,16 +15369,16 @@ PyOpDefBegin(MultiHeadAttn) // { ...@@ -15369,16 +15369,16 @@ PyOpDefBegin(MultiHeadAttn) // {
} }
{ {
auto&& iter = state.find("need_weights"); auto&& iter = state.find("add_zero_attn");
if (iter != state.end()) { if (iter != state.end()) {
opdef.need_weights = serialization<decltype(opdef.need_weights)>::load(iter->second); opdef.add_zero_attn = serialization<decltype(opdef.add_zero_attn)>::load(iter->second);
} }
} }
{ {
auto&& iter = state.find("add_zero_attn"); auto&& iter = state.find("need_weights");
if (iter != state.end()) { if (iter != state.end()) {
opdef.add_zero_attn = serialization<decltype(opdef.add_zero_attn)>::load(iter->second); opdef.need_weights = serialization<decltype(opdef.need_weights)>::load(iter->second);
} }
} }
...@@ -15432,9 +15432,9 @@ PyOpDefBegin(MultiHeadAttn) // { ...@@ -15432,9 +15432,9 @@ PyOpDefBegin(MultiHeadAttn) // {
PyOpDefEnd(MultiHeadAttn) PyOpDefEnd(MultiHeadAttn)
int PyOp(MultiHeadAttn)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { int PyOp(MultiHeadAttn)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"num_heads", "embeding_size", "k_size", "v_size", "qproj_size", "kproj_size", "vproj_size", "oproj_size", "qbias", "kbias", "vbias", "obias", "sm_scaler", "input_order", "attn_mask_type", "tensor_combination_type", "need_weights", "add_zero_attn", "reslink", "training", "seed", "attn_prob", "out_prob", "handle", "scope", NULL}; static const char* kwlist[] = {"num_heads", "embeding_size", "k_size", "v_size", "qproj_size", "kproj_size", "vproj_size", "oproj_size", "qbias", "kbias", "vbias", "obias", "sm_scaler", "input_order", "attn_mask_type", "tensor_combination_type", "add_zero_attn", "need_weights", "reslink", "training", "seed", "attn_prob", "out_prob", "handle", "scope", NULL};
PyObject *num_heads = NULL, *embeding_size = NULL, *k_size = NULL, *v_size = NULL, *qproj_size = NULL, *kproj_size = NULL, *vproj_size = NULL, *oproj_size = NULL, *qbias = NULL, *kbias = NULL, *vbias = NULL, *obias = NULL, *sm_scaler = NULL, *input_order = NULL, *attn_mask_type = NULL, *tensor_combination_type = NULL, *need_weights = NULL, *add_zero_attn = NULL, *reslink = NULL, *training = NULL, *seed = NULL, *attn_prob = NULL, *out_prob = NULL, *handle = NULL, *scope = NULL; PyObject *num_heads = NULL, *embeding_size = NULL, *k_size = NULL, *v_size = NULL, *qproj_size = NULL, *kproj_size = NULL, *vproj_size = NULL, *oproj_size = NULL, *qbias = NULL, *kbias = NULL, *vbias = NULL, *obias = NULL, *sm_scaler = NULL, *input_order = NULL, *attn_mask_type = NULL, *tensor_combination_type = NULL, *add_zero_attn = NULL, *need_weights = NULL, *reslink = NULL, *training = NULL, *seed = NULL, *attn_prob = NULL, *out_prob = NULL, *handle = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOOOOOOOOOOOOOOOOOOOOOO", const_cast<char**>(kwlist), &num_heads, &embeding_size, &k_size, &v_size, &qproj_size, &kproj_size, &vproj_size, &oproj_size, &qbias, &kbias, &vbias, &obias, &sm_scaler, &input_order, &attn_mask_type, &tensor_combination_type, &need_weights, &add_zero_attn, &reslink, &training, &seed, &attn_prob, &out_prob, &handle, &scope)) if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOOOOOOOOOOOOOOOOOOOOOO", const_cast<char**>(kwlist), &num_heads, &embeding_size, &k_size, &v_size, &qproj_size, &kproj_size, &vproj_size, &oproj_size, &qbias, &kbias, &vbias, &obias, &sm_scaler, &input_order, &attn_mask_type, &tensor_combination_type, &add_zero_attn, &need_weights, &reslink, &training, &seed, &attn_prob, &out_prob, &handle, &scope))
return -1; return -1;
if (num_heads) { if (num_heads) {
...@@ -15581,21 +15581,21 @@ int PyOp(MultiHeadAttn)::py_init(PyObject *self, PyObject *args, PyObject *kwds) ...@@ -15581,21 +15581,21 @@ int PyOp(MultiHeadAttn)::py_init(PyObject *self, PyObject *args, PyObject *kwds)
} CATCH_ALL(-1) } CATCH_ALL(-1)
} }
if (need_weights) { if (add_zero_attn) {
try { try {
// TODO: remove this guard which is used for pybind11 implicit conversion // TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{}; py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().need_weights = reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().add_zero_attn =
py::cast<decltype(MultiHeadAttn::need_weights)>(py::handle(need_weights)); py::cast<decltype(MultiHeadAttn::add_zero_attn)>(py::handle(add_zero_attn));
} CATCH_ALL(-1) } CATCH_ALL(-1)
} }
if (add_zero_attn) { if (need_weights) {
try { try {
// TODO: remove this guard which is used for pybind11 implicit conversion // TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{}; py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().add_zero_attn = reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().need_weights =
py::cast<decltype(MultiHeadAttn::add_zero_attn)>(py::handle(add_zero_attn)); py::cast<decltype(MultiHeadAttn::need_weights)>(py::handle(need_weights));
} CATCH_ALL(-1) } CATCH_ALL(-1)
} }
...@@ -15680,8 +15680,8 @@ PyGetSetDef PyOp(MultiHeadAttn)::py_getsetters[] = { ...@@ -15680,8 +15680,8 @@ PyGetSetDef PyOp(MultiHeadAttn)::py_getsetters[] = {
{const_cast<char*>("input_order"), py_get_generic(MultiHeadAttn, input_order), py_set_generic(MultiHeadAttn, input_order), const_cast<char*>("input_order"), NULL}, {const_cast<char*>("input_order"), py_get_generic(MultiHeadAttn, input_order), py_set_generic(MultiHeadAttn, input_order), const_cast<char*>("input_order"), NULL},
{const_cast<char*>("attn_mask_type"), py_get_generic(MultiHeadAttn, attn_mask_type), py_set_generic(MultiHeadAttn, attn_mask_type), const_cast<char*>("attn_mask_type"), NULL}, {const_cast<char*>("attn_mask_type"), py_get_generic(MultiHeadAttn, attn_mask_type), py_set_generic(MultiHeadAttn, attn_mask_type), const_cast<char*>("attn_mask_type"), NULL},
{const_cast<char*>("tensor_combination_type"), py_get_generic(MultiHeadAttn, tensor_combination_type), py_set_generic(MultiHeadAttn, tensor_combination_type), const_cast<char*>("tensor_combination_type"), NULL}, {const_cast<char*>("tensor_combination_type"), py_get_generic(MultiHeadAttn, tensor_combination_type), py_set_generic(MultiHeadAttn, tensor_combination_type), const_cast<char*>("tensor_combination_type"), NULL},
{const_cast<char*>("need_weights"), py_get_generic(MultiHeadAttn, need_weights), py_set_generic(MultiHeadAttn, need_weights), const_cast<char*>("need_weights"), NULL},
{const_cast<char*>("add_zero_attn"), py_get_generic(MultiHeadAttn, add_zero_attn), py_set_generic(MultiHeadAttn, add_zero_attn), const_cast<char*>("add_zero_attn"), NULL}, {const_cast<char*>("add_zero_attn"), py_get_generic(MultiHeadAttn, add_zero_attn), py_set_generic(MultiHeadAttn, add_zero_attn), const_cast<char*>("add_zero_attn"), NULL},
{const_cast<char*>("need_weights"), py_get_generic(MultiHeadAttn, need_weights), py_set_generic(MultiHeadAttn, need_weights), const_cast<char*>("need_weights"), NULL},
{const_cast<char*>("reslink"), py_get_generic(MultiHeadAttn, reslink), py_set_generic(MultiHeadAttn, reslink), const_cast<char*>("reslink"), NULL}, {const_cast<char*>("reslink"), py_get_generic(MultiHeadAttn, reslink), py_set_generic(MultiHeadAttn, reslink), const_cast<char*>("reslink"), NULL},
{const_cast<char*>("training"), py_get_generic(MultiHeadAttn, training), py_set_generic(MultiHeadAttn, training), const_cast<char*>("training"), NULL}, {const_cast<char*>("training"), py_get_generic(MultiHeadAttn, training), py_set_generic(MultiHeadAttn, training), const_cast<char*>("training"), NULL},
{const_cast<char*>("seed"), py_get_generic(MultiHeadAttn, seed), py_set_generic(MultiHeadAttn, seed), const_cast<char*>("seed"), NULL}, {const_cast<char*>("seed"), py_get_generic(MultiHeadAttn, seed), py_set_generic(MultiHeadAttn, seed), const_cast<char*>("seed"), NULL},
...@@ -15708,7 +15708,7 @@ PyMethodDef PyOp(MultiHeadAttn)::py_init_methoddef = { ...@@ -15708,7 +15708,7 @@ PyMethodDef PyOp(MultiHeadAttn)::py_init_methoddef = {
"__init__", "__init__",
(PyCFunction)PyOp(MultiHeadAttn)::py_init_proxy, (PyCFunction)PyOp(MultiHeadAttn)::py_init_proxy,
METH_VARARGS | METH_KEYWORDS, METH_VARARGS | METH_KEYWORDS,
"__init__(self, num_heads: int = ..., embeding_size: int = ..., k_size: int = ..., v_size: int = ..., qproj_size: int = ..., kproj_size: int = ..., vproj_size: int = ..., oproj_size: int = ..., qbias: bool = ..., kbias: bool = ..., vbias: bool = ..., obias: bool = ..., sm_scaler: float = ..., input_order: int = ..., attn_mask_type: Union[str, ATTN_MASK_TYPE] = ..., tensor_combination_type: Union[str, TENSOR_COMBINATION_TYPE] = ..., need_weights: bool = ..., add_zero_attn: bool = ..., reslink: bool = ..., training: bool = ..., seed: int = ..., attn_prob: float = ..., out_prob: float = ..., handle: int = ...) -> None\n" "__init__(self, num_heads: int = ..., embeding_size: int = ..., k_size: int = ..., v_size: int = ..., qproj_size: int = ..., kproj_size: int = ..., vproj_size: int = ..., oproj_size: int = ..., qbias: bool = ..., kbias: bool = ..., vbias: bool = ..., obias: bool = ..., sm_scaler: float = ..., input_order: int = ..., attn_mask_type: Union[str, ATTN_MASK_TYPE] = ..., tensor_combination_type: Union[str, TENSOR_COMBINATION_TYPE] = ..., add_zero_attn: bool = ..., need_weights: bool = ..., reslink: bool = ..., training: bool = ..., seed: int = ..., attn_prob: float = ..., out_prob: float = ..., handle: int = ...) -> None\n"
}; };
void _init_py_MultiHeadAttn(py::module m) { void _init_py_MultiHeadAttn(py::module m) {
......
...@@ -1416,8 +1416,8 @@ public: ...@@ -1416,8 +1416,8 @@ public:
uint32_t input_order = 0; uint32_t input_order = 0;
ATTN_MASK_TYPE attn_mask_type = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK; ATTN_MASK_TYPE attn_mask_type = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK;
TENSOR_COMBINATION_TYPE tensor_combination_type = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE; TENSOR_COMBINATION_TYPE tensor_combination_type = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE;
bool need_weights = false;
bool add_zero_attn = false; bool add_zero_attn = false;
bool need_weights = false;
bool reslink = false; bool reslink = false;
bool training = true; bool training = true;
uint64_t seed = 0; uint64_t seed = 0;
...@@ -1425,10 +1425,10 @@ public: ...@@ -1425,10 +1425,10 @@ public:
float out_prob = 0.f; float out_prob = 0.f;
size_t handle; size_t handle;
MultiHeadAttn() = default; MultiHeadAttn() = default;
MultiHeadAttn(uint32_t num_heads_, uint32_t embeding_size_, uint32_t k_size_, uint32_t v_size_, uint32_t qproj_size_, uint32_t kproj_size_, uint32_t vproj_size_, uint32_t oproj_size_, bool qbias_, bool kbias_, bool vbias_, bool obias_, float sm_scaler_, uint32_t input_order_, ATTN_MASK_TYPE attn_mask_type_, TENSOR_COMBINATION_TYPE tensor_combination_type_, bool need_weights_, bool add_zero_attn_, bool reslink_, bool training_, uint64_t seed_, float attn_prob_, float out_prob_, size_t handle_, std::string scope_ = {}): num_heads(num_heads_), embeding_size(embeding_size_), k_size(k_size_), v_size(v_size_), qproj_size(qproj_size_), kproj_size(kproj_size_), vproj_size(vproj_size_), oproj_size(oproj_size_), qbias(qbias_), kbias(kbias_), vbias(vbias_), obias(obias_), sm_scaler(sm_scaler_), input_order(input_order_), attn_mask_type(attn_mask_type_), tensor_combination_type(tensor_combination_type_), need_weights(need_weights_), add_zero_attn(add_zero_attn_), reslink(reslink_), training(training_), seed(seed_), attn_prob(attn_prob_), out_prob(out_prob_), handle(handle_) { set_scope(scope_); } MultiHeadAttn(uint32_t num_heads_, uint32_t embeding_size_, uint32_t k_size_, uint32_t v_size_, uint32_t qproj_size_, uint32_t kproj_size_, uint32_t vproj_size_, uint32_t oproj_size_, bool qbias_, bool kbias_, bool vbias_, bool obias_, float sm_scaler_, uint32_t input_order_, ATTN_MASK_TYPE attn_mask_type_, TENSOR_COMBINATION_TYPE tensor_combination_type_, bool add_zero_attn_, bool need_weights_, bool reslink_, bool training_, uint64_t seed_, float attn_prob_, float out_prob_, size_t handle_, std::string scope_ = {}): num_heads(num_heads_), embeding_size(embeding_size_), k_size(k_size_), v_size(v_size_), qproj_size(qproj_size_), kproj_size(kproj_size_), vproj_size(vproj_size_), oproj_size(oproj_size_), qbias(qbias_), kbias(kbias_), vbias(vbias_), obias(obias_), sm_scaler(sm_scaler_), input_order(input_order_), attn_mask_type(attn_mask_type_), tensor_combination_type(tensor_combination_type_), add_zero_attn(add_zero_attn_), need_weights(need_weights_), reslink(reslink_), training(training_), seed(seed_), attn_prob(attn_prob_), out_prob(out_prob_), handle(handle_) { set_scope(scope_); }
MultiHeadAttn(::megdnn::param::MultiHeadAttn packed_param_0, size_t handle_): num_heads(packed_param_0.num_heads), embeding_size(packed_param_0.embeding_size), k_size(packed_param_0.k_size), v_size(packed_param_0.v_size), qproj_size(packed_param_0.qproj_size), kproj_size(packed_param_0.kproj_size), vproj_size(packed_param_0.vproj_size), oproj_size(packed_param_0.oproj_size), qbias(packed_param_0.qbias), kbias(packed_param_0.kbias), vbias(packed_param_0.vbias), obias(packed_param_0.obias), sm_scaler(packed_param_0.sm_scaler), input_order(packed_param_0.input_order), attn_mask_type(packed_param_0.attn_mask_type), tensor_combination_type(packed_param_0.tensor_combination_type), need_weights(packed_param_0.need_weights), add_zero_attn(packed_param_0.add_zero_attn), reslink(packed_param_0.reslink), training(packed_param_0.training), seed(packed_param_0.seed), attn_prob(packed_param_0.attn_prob), out_prob(packed_param_0.out_prob), handle(handle_) {} MultiHeadAttn(::megdnn::param::MultiHeadAttn packed_param_0, size_t handle_): num_heads(packed_param_0.num_heads), embeding_size(packed_param_0.embeding_size), k_size(packed_param_0.k_size), v_size(packed_param_0.v_size), qproj_size(packed_param_0.qproj_size), kproj_size(packed_param_0.kproj_size), vproj_size(packed_param_0.vproj_size), oproj_size(packed_param_0.oproj_size), qbias(packed_param_0.qbias), kbias(packed_param_0.kbias), vbias(packed_param_0.vbias), obias(packed_param_0.obias), sm_scaler(packed_param_0.sm_scaler), input_order(packed_param_0.input_order), attn_mask_type(packed_param_0.attn_mask_type), tensor_combination_type(packed_param_0.tensor_combination_type), add_zero_attn(packed_param_0.add_zero_attn), need_weights(packed_param_0.need_weights), reslink(packed_param_0.reslink), training(packed_param_0.training), seed(packed_param_0.seed), attn_prob(packed_param_0.attn_prob), out_prob(packed_param_0.out_prob), handle(handle_) {}
::megdnn::param::MultiHeadAttn param() const { ::megdnn::param::MultiHeadAttn param() const {
return {num_heads, embeding_size, k_size, v_size, qproj_size, kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias, sm_scaler, input_order, attn_mask_type, tensor_combination_type, need_weights, add_zero_attn, reslink, training, seed, attn_prob, out_prob}; return {num_heads, embeding_size, k_size, v_size, qproj_size, kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias, sm_scaler, input_order, attn_mask_type, tensor_combination_type, add_zero_attn, need_weights, reslink, training, seed, attn_prob, out_prob};
} }
}; };
......
...@@ -1510,7 +1510,7 @@ py::enum_<MultiHeadAttn::TENSOR_COMBINATION_TYPE>(MultiHeadAttnInst, "TENSOR_COM ...@@ -1510,7 +1510,7 @@ py::enum_<MultiHeadAttn::TENSOR_COMBINATION_TYPE>(MultiHeadAttnInst, "TENSOR_COM
py::implicitly_convertible<std::string, MultiHeadAttn::TENSOR_COMBINATION_TYPE>(); py::implicitly_convertible<std::string, MultiHeadAttn::TENSOR_COMBINATION_TYPE>();
MultiHeadAttnInst MultiHeadAttnInst
.def(py::init<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, bool, bool, bool, bool, float, uint32_t, ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE, ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE, bool, bool, bool, bool, uint64_t, float, float, size_t, std::string>(), py::arg("num_heads") = 1, py::arg("embeding_size") = 0, py::arg("k_size") = 0, py::arg("v_size") = 0, py::arg("qproj_size") = 0, py::arg("kproj_size") = 0, py::arg("vproj_size") = 0, py::arg("oproj_size") = 0, py::arg("qbias") = false, py::arg("kbias") = false, py::arg("vbias") = false, py::arg("obias") = false, py::arg("sm_scaler") = 1.f, py::arg("input_order") = 0, py::arg("attn_mask_type") = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK, py::arg("tensor_combination_type") = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE, py::arg("need_weights") = false, py::arg("add_zero_attn") = false, py::arg("reslink") = false, py::arg("training") = true, py::arg("seed") = 0, py::arg("attn_prob") = 0.f, py::arg("out_prob") = 0.f, py::arg("handle"), py::arg("scope") = {}) .def(py::init<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, bool, bool, bool, bool, float, uint32_t, ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE, ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE, bool, bool, bool, bool, uint64_t, float, float, size_t, std::string>(), py::arg("num_heads") = 1, py::arg("embeding_size") = 0, py::arg("k_size") = 0, py::arg("v_size") = 0, py::arg("qproj_size") = 0, py::arg("kproj_size") = 0, py::arg("vproj_size") = 0, py::arg("oproj_size") = 0, py::arg("qbias") = false, py::arg("kbias") = false, py::arg("vbias") = false, py::arg("obias") = false, py::arg("sm_scaler") = 1.f, py::arg("input_order") = 0, py::arg("attn_mask_type") = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK, py::arg("tensor_combination_type") = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE, py::arg("add_zero_attn") = false, py::arg("need_weights") = false, py::arg("reslink") = false, py::arg("training") = true, py::arg("seed") = 0, py::arg("attn_prob") = 0.f, py::arg("out_prob") = 0.f, py::arg("handle"), py::arg("scope") = {})
.def(py::init<>()) .def(py::init<>())
.def_readwrite("num_heads", &MultiHeadAttn::num_heads) .def_readwrite("num_heads", &MultiHeadAttn::num_heads)
.def_readwrite("embeding_size", &MultiHeadAttn::embeding_size) .def_readwrite("embeding_size", &MultiHeadAttn::embeding_size)
...@@ -1528,8 +1528,8 @@ MultiHeadAttnInst ...@@ -1528,8 +1528,8 @@ MultiHeadAttnInst
.def_readwrite("input_order", &MultiHeadAttn::input_order) .def_readwrite("input_order", &MultiHeadAttn::input_order)
.def_readwrite("attn_mask_type", &MultiHeadAttn::attn_mask_type) .def_readwrite("attn_mask_type", &MultiHeadAttn::attn_mask_type)
.def_readwrite("tensor_combination_type", &MultiHeadAttn::tensor_combination_type) .def_readwrite("tensor_combination_type", &MultiHeadAttn::tensor_combination_type)
.def_readwrite("need_weights", &MultiHeadAttn::need_weights)
.def_readwrite("add_zero_attn", &MultiHeadAttn::add_zero_attn) .def_readwrite("add_zero_attn", &MultiHeadAttn::add_zero_attn)
.def_readwrite("need_weights", &MultiHeadAttn::need_weights)
.def_readwrite("reslink", &MultiHeadAttn::reslink) .def_readwrite("reslink", &MultiHeadAttn::reslink)
.def_readwrite("training", &MultiHeadAttn::training) .def_readwrite("training", &MultiHeadAttn::training)
.def_readwrite("seed", &MultiHeadAttn::seed) .def_readwrite("seed", &MultiHeadAttn::seed)
......
...@@ -184,6 +184,12 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU ...@@ -184,6 +184,12 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2) #define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" #include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 5
#define _NR_OUTPUTS 4
#define _FOREACH_IO(_i, _o) \
_i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2), _o(3)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 5 #define _NR_INPUTS 5
#define _NR_OUTPUTS 5 #define _NR_OUTPUTS 5
#define _FOREACH_IO(_i, _o) \ #define _FOREACH_IO(_i, _o) \
...@@ -218,6 +224,25 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU ...@@ -218,6 +224,25 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU
_i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _i(6), _o(0), _o(1), _o(2) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _i(6), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" #include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 7
#define _NR_OUTPUTS 4
#define _FOREACH_IO(_i, _o) \
_i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _i(6), _o(0), _o(1), _o(2), _o(3)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 8
#define _NR_OUTPUTS 4
#define _FOREACH_IO(_i, _o) \
_i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _i(6), _i(7), _o(0), _o(1), _o(2), _o(3)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 9
#define _NR_OUTPUTS 6
#define _FOREACH_IO(_i, _o) \
_i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _i(6), _i(7), _i(8), _o(0), _o(1), \
_o(2), _o(3), _o(4), _o(5)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 9 #define _NR_INPUTS 9
#define _NR_OUTPUTS 4 #define _NR_OUTPUTS 4
#define _FOREACH_IO(_i, _o) \ #define _FOREACH_IO(_i, _o) \
......
(This file's diff is collapsed.)
...@@ -33,13 +33,35 @@ struct OprMaker<opr::DropoutForward, 1> { ...@@ -33,13 +33,35 @@ struct OprMaker<opr::DropoutForward, 1> {
template <> template <>
struct OprMaker<opr::MultiHeadAttn, 0> { struct OprMaker<opr::MultiHeadAttn, 0> {
using Param = opr::MultiHeadAttn::Param; using Param = opr::MultiHeadAttn::Param;
using INPUT_TYPE = Param::TENSOR_COMBINATION_TYPE;
static cg::OperatorNodeBase* make( static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
const OperatorNodeConfig& config) { const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph); MGB_MARK_USED_VAR(graph);
return opr::MultiHeadAttn::make(i[0], i[1], i[2], i[3], param, config)[0] if (i.size() == 7) {
.node() mgb_assert(INPUT_TYPE::ALL == param.tensor_combination_type);
->owner_opr(); return opr::MultiHeadAttn::make(
i[0], i[1], i[2], i[3], i[4], i[5], i[6], param, config)[0]
.node()
->owner_opr();
} else if (i.size() == 6) {
mgb_assert(INPUT_TYPE::ONLY_BIASKV == param.tensor_combination_type);
return opr::MultiHeadAttn::make(
i[0], i[1], i[2], i[3], i[4], i[5], param, config)[0]
.node()
->owner_opr();
} else if (i.size() == 5) {
mgb_assert(INPUT_TYPE::ONLY_MASK == param.tensor_combination_type);
return opr::MultiHeadAttn::make(
i[0], i[1], i[2], i[3], i[4], param, config)[0]
.node()
->owner_opr();
} else {
mgb_assert(INPUT_TYPE::NONE == param.tensor_combination_type);
return opr::MultiHeadAttn::make(i[0], i[1], i[2], i[3], param, config)[0]
.node()
->owner_opr();
}
} }
}; };
...@@ -52,10 +74,18 @@ struct OprMaker<opr::MultiHeadAttnBackward, 0> { ...@@ -52,10 +74,18 @@ struct OprMaker<opr::MultiHeadAttnBackward, 0> {
const OperatorNodeConfig& config) { const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph); MGB_MARK_USED_VAR(graph);
return opr::MultiHeadAttnBackward::make( if (i.size() == 8)
i[0], i[1], i[2], i[3], i[4], i[5], param, config)[0] return opr::MultiHeadAttnBackward::make(
.node() i[0], i[1], i[2], i[3], i[4], i[5], i[6], i[7], param,
->owner_opr(); config)[0]
.node()
->owner_opr();
else
return opr::MultiHeadAttnBackward::make(
i[0], i[1], i[2], i[3], i[4], i[5], i[6], i[7], i[8], param,
config)[0]
.node()
->owner_opr();
} }
}; };
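For the backward op, the deserializer above distinguishes 8 versus 9 saved inputs, where the 9-input form additionally carries attn_mask. An illustrative mapping of the two input lists (names follow the MultiHeadAttnBackward declarations shown further below; this is a sketch, not MegEngine API):

```python
MHA_BACKWARD_INPUTS = {
    9: ["diff", "queries", "keys", "values", "qkvo_weight_bias",
        "attn_mask", "attn_weight", "mask_reservespace", "othr_reservespace"],
    8: ["diff", "queries", "keys", "values", "qkvo_weight_bias",
        "attn_weight", "mask_reservespace", "othr_reservespace"],
}

assert len(MHA_BACKWARD_INPUTS[9]) == 9 and len(MHA_BACKWARD_INPUTS[8]) == 8
```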
......
...@@ -87,14 +87,50 @@ _DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG) ...@@ -87,14 +87,50 @@ _DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG)
#undef _OUTPUTS #undef _OUTPUTS
#undef _INPUTS #undef _INPUTS
/* ================= 4 input ================= */
#define _INPUTS(preifx) preifx i0, preifx i1, preifx i2, preifx i3
#define _OUTPUTS SymbolVarArray
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(MultiHeadAttnForward)
#undef _OUTPUTS
#undef _INPUTS
#undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS #undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
MultiHeadAttnForward, RNGOprBase<megdnn::MultiHeadAttnForward>) // {
void add_input_layout_constraint() override;
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override;
public:
MultiHeadAttnForward(
VarNode* queries, VarNode* keys, VarNode* values, VarNode* qkvo_weight_bias,
VarNode* attn_mask, VarNode* bias_k, VarNode* bias_v, const Param& param,
const OperatorNodeConfig& config);
MultiHeadAttnForward(
VarNode* queries, VarNode* keys, VarNode* values, VarNode* qkvo_weight_bias,
VarNode* attn_mask, const Param& param, const OperatorNodeConfig& config);
MultiHeadAttnForward(
VarNode* queries, VarNode* keys, VarNode* values, VarNode* qkvo_weight_bias,
VarNode* bias_k, VarNode* bias_v, const Param& param,
const OperatorNodeConfig& config);
MultiHeadAttnForward(
VarNode* queries, VarNode* keys, VarNode* values, VarNode* qkvo_weight_bias,
const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar queries, SymbolVar keys, SymbolVar values,
SymbolVar qkvo_weight_bias, SymbolVar attn_mask, SymbolVar bias_k,
SymbolVar bias_v, const Param& param = {},
const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar queries, SymbolVar keys, SymbolVar values,
SymbolVar qkvo_weight_bias, SymbolVar attn_mask, const Param& param = {},
const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar queries, SymbolVar keys, SymbolVar values,
SymbolVar qkvo_weight_bias, SymbolVar bias_k, SymbolVar bias_v,
const Param& param = {}, const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar queries, SymbolVar keys, SymbolVar values,
SymbolVar qkvo_weight_bias, const Param& param = {},
const OperatorNodeConfig& config = {});
void init_output_static_infer_desc() override;
void scn_do_execute() override;
};
} // namespace intl } // namespace intl
using UniformRNG = intl::UniformRNG; using UniformRNG = intl::UniformRNG;
...@@ -146,13 +182,27 @@ MGB_DEFINE_OPR_CLASS_WITH_EXPORT( ...@@ -146,13 +182,27 @@ MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
public: public:
MGE_WIN_DECLSPEC_FUC MultiHeadAttnBackward( MGE_WIN_DECLSPEC_FUC MultiHeadAttnBackward(
VarNode* diff, VarNode* queries, VarNode* keys, VarNode* values, VarNode* diff, VarNode* queries, VarNode* keys, VarNode* values,
VarNode* wqkv, VarNode* reserveSpace, const Param& param, VarNode* qkvo_weight_bias, VarNode* attn_mask, VarNode* attn_weight,
VarNode* mask_reservespace, VarNode* othr_reservespace, const Param& param,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC MultiHeadAttnBackward(
VarNode* diff, VarNode* queries, VarNode* keys, VarNode* values,
VarNode* qkvo_weight_bias, VarNode* attn_weight, VarNode* mask_reservespace,
VarNode* othr_reservespace, const Param& param,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar diff, SymbolVar queries, SymbolVar keys, SymbolVar values, SymbolVar diff, SymbolVar queries, SymbolVar keys, SymbolVar values,
SymbolVar wqkv, SymbolVar reserveSpace, const Param& param = {}, SymbolVar qkvo_weight_bias, SymbolVar attn_mask, SymbolVar attn_weight,
const OperatorNodeConfig& config = {}); SymbolVar mask_reservespace, SymbolVar othr_reservespace,
const Param& param = {}, const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar diff, SymbolVar queries, SymbolVar keys, SymbolVar values,
SymbolVar qkvo_weight_bias, SymbolVar attn_weight,
SymbolVar mask_reservespace, SymbolVar othr_reservespace,
const Param& param = {}, const OperatorNodeConfig& config = {});
private: private:
void init_output_static_infer_desc() override; void init_output_static_infer_desc() override;
......