diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h index fa0be658e7c8ace2f1e6c8fdab8ed2971ac6c6cc..01a46a470a106b21757ff5c3e8cf7996f6cf171a 100644 --- a/dnn/include/megdnn/oprs/nn.h +++ b/dnn/include/megdnn/oprs/nn.h @@ -2580,66 +2580,139 @@ class MultiHeadAttnBase : public OperatorBase { }; class MultiHeadAttnForward : public MultiHeadAttnBase { - DEF_OPR_IMPL(MultiHeadAttnForward, MultiHeadAttnBase, 4, 2); + DEF_OPR_IMPL(MultiHeadAttnForward, MultiHeadAttnBase, 7, 4); public: + /** + * \param[in] queries (N, L, E_q), where N is the batch size, L is the target + * sequence length, and E_q is the query embedding dimension embed_dim. + * \param[in] keys (N, S, E_k), where N is the batch size, S is the source + * sequence length, and E_k is the key embedding dimension k_dim. + * \param[in] values (N, S, E_v), where N is the batch size, S is the source + * sequence length, and E_v is the value embedding dimension v_dim. + * \param[in] qkvo_weight_bias, input/output projection weight/bias all in one. + * The order of arrangement is: query weight, key weight, value weight, + * out weight, query bias, key bias, value bias, out bias, the following parameters + * in param will be used to indicate whether these items exist: qproj_size, + * kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias. + * Note: Y=X@W+B is used here instead of Y=X@W^T+B in pytorch. + * \param[in] attn_mask, (N*num_heads, L, S) or (L, S), where N is the batch size, + * num_heads is the number of parallel attention heads, L is the target sequence + * length, and S is the source sequence length. attention mask is obtained by + * combining attn_mask, key_padding_mask, is_causal and maybe_cudnn_style_mask by + * mge.functional._merge_masks. + * \param[in] bias_k, (1, 1, kproj_size), where kproj_size is the projected + * dimension of key weight, if kproj_size == 0, will be the key embedding dimension + * k_dim. + * Note: bias_k and bias_v are the bias of the K and V sequences to be added at + * sequence dim, distinguished from kbias and vbias, bias_kv here is not kbias and + * vbias in the linear layer, and bias_kv here will be added to the K and V at + * sequence dimensions, where K and V are the matrices of key and value after + * projection, and K and V will be used to calculate the attention matrix. + * \param[in] bias_v, (1, 1, vproj_size), where vproj_size is the projected + * dimension of value weight, if vproj_size == 0, will be the value embedding + * dimension v_dim. + * Note: see bias_k. + * \param[out] out, (N, S, oproj_size), where N is + * the batch size, S is the source sequence length, and oproj_size is the projected + * dimension of output weight, if oproj_size == 0, will be the projected + * dimension of value weight vproj_size, but if vproj_size == 0, will be the value + * embedding dimension v_dim. + * \param[out] attn_weight, (N * num_heads, L, S), where N is the batch size, + * num_heads is the number of parallel attention heads, L is the target sequence + * length, and S is the source sequence length. + * Note: attn_weight is the output of softmax. + * \param[out] mask_reservespace, when param.training=true, we need this output to + * save the mask of attention dropout and output dropout. + * \param[out] othr_reservespace, when param.training=true, we need this output to + * save the intermediate calculation results. 
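As a hedged aside (not part of this header), the expected flat element count of the packed qkvo_weight_bias tensor follows directly from the sizes listed above; the sketch below mirrors the weight_len accumulation used later in check_exec, and the helper name expected_qkvo_weight_bias_len is hypothetical:

    #include <cstddef>

    // Minimal sketch, assuming the param fields described above: expected element
    // count of the packed qkvo_weight_bias tensor. Packing order is the q/k/v/o
    // weights followed by their optional biases, and the projection uses
    // Y = X @ W + B (not Y = X @ W^T + B as in PyTorch).
    static inline std::size_t expected_qkvo_weight_bias_len(
            std::size_t embed, std::size_t ksize, std::size_t vsize,
            std::size_t qproj, std::size_t kproj, std::size_t vproj,
            std::size_t oproj, bool qbias, bool kbias, bool vbias, bool obias) {
        std::size_t len = 0;
        if (qproj > 0)
            len += embed * qproj + (qbias ? qproj : 0);
        if (kproj > 0)
            len += ksize * kproj + (kbias ? kproj : 0);
        if (vproj > 0)
            len += vsize * vproj + (vbias ? vproj : 0);
        // the output projection takes vproj (or v_size when vproj == 0) as input
        if (oproj > 0)
            len += (vproj > 0 ? vproj : vsize) * oproj + (obias ? oproj : 0);
        return len;
    }
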
+ */ virtual void exec( _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, - _megdnn_tensor_in wqkv, _megdnn_tensor_out out, - _megdnn_tensor_out reserveSpace, _megdnn_workspace workspace) = 0; - MGE_WIN_DECLSPEC_FUC void deduce_layout( + _megdnn_tensor_in qkvo_weight_bias, _megdnn_tensor_in attn_mask, + _megdnn_tensor_in bias_k, _megdnn_tensor_in bias_v, _megdnn_tensor_out out, + _megdnn_tensor_out attn_weight, _megdnn_tensor_out mask_reservespace, + _megdnn_tensor_out othr_reservespace, _megdnn_workspace workspace) = 0; + virtual void deduce_layout( const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out, - TensorLayout& reserveSpace); + const TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& bias_k, + const TensorLayout& bias_v, TensorLayout& out, TensorLayout& attn_weight, + TensorLayout& mask_reservespace, TensorLayout& othr_reservespace) = 0; virtual size_t get_workspace_in_bytes( const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, - const TensorLayout& out, const TensorLayout& reserveSpace) = 0; - virtual size_t get_reservespace_in_bytes( + const TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& bias_k, + const TensorLayout& bias_v, const TensorLayout& out, + const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, + const TensorLayout& othr_reservespace) = 0; + virtual size_t get_mask_reservespace_in_bytes( const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, - const TensorLayout& out, const TensorLayout& reserveSpace) = 0; + const TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& bias_k, + const TensorLayout& bias_v, const TensorLayout& out, + const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, + const TensorLayout& othr_reservespace) = 0; + virtual size_t get_othr_reservespace_in_bytes( + const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& bias_k, + const TensorLayout& bias_v, const TensorLayout& out, + const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, + const TensorLayout& othr_reservespace) = 0; protected: void check_exec( const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, - const TensorLayout& out, const TensorLayout& reserveSpace, - size_t workspace_in_bytes); + const TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& bias_k, + const TensorLayout& bias_v, const TensorLayout& out, + const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, + const TensorLayout& othr_reservespace, size_t workspace_in_bytes); }; using MultiHeadAttn = MultiHeadAttnForward; class MultiHeadAttnBackward : public MultiHeadAttnBase { - DEF_OPR_IMPL(MultiHeadAttnBackward, MultiHeadAttnBase, 6, 4); + DEF_OPR_IMPL(MultiHeadAttnBackward, MultiHeadAttnBase, 9, 6); public: virtual void exec( _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, - _megdnn_tensor_in values, _megdnn_tensor_in wqkv, - _megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries, - _megdnn_tensor_out 
dkeys, _megdnn_tensor_out dvalues, - _megdnn_tensor_out dweights, _megdnn_workspace workspace) = 0; + _megdnn_tensor_in values, _megdnn_tensor_in qkvo_weight_bias, + _megdnn_tensor_in attn_mask, _megdnn_tensor_in attn_weight, + _megdnn_tensor_in mask_reservespace, _megdnn_tensor_in othr_reservespace, + _megdnn_tensor_out dqueries, _megdnn_tensor_out dkeys, + _megdnn_tensor_out dvalues, _megdnn_tensor_out dqkvo_weight_bias, + _megdnn_tensor_out dbias_k, _megdnn_tensor_out dbias_v, + _megdnn_workspace workspace) = 0; MGE_WIN_DECLSPEC_FUC void deduce_layout( const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys, const TensorLayout& values, - const TensorLayout& wqkv, const TensorLayout& reserveSpace, - TensorLayout& dqueries, TensorLayout& dkeys, TensorLayout& dvalues, - TensorLayout& dweights); + const TensorLayout& qkvo_weight_bias, const TensorLayout& attn_mask, + const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, + const TensorLayout& othr_reservespace, TensorLayout& dqueries, + TensorLayout& dkeys, TensorLayout& dvalues, TensorLayout& dqkvo_weight_bias, + TensorLayout& dbias_k, TensorLayout& dbias_v); virtual size_t get_workspace_in_bytes( const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys, const TensorLayout& values, - const TensorLayout& wqkv, const TensorLayout& reserveSpace, - const TensorLayout& dqueries, const TensorLayout& dkeys, - const TensorLayout& dvalues, const TensorLayout& dweights) = 0; + const TensorLayout& qkvo_weight_bias, const TensorLayout& attn_mask, + const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, + const TensorLayout& othr_reservespace, const TensorLayout& dqueries, + const TensorLayout& dkeys, const TensorLayout& dvalues, + const TensorLayout& dqkvo_weight_bias, const TensorLayout& dbias_k, + const TensorLayout& dbias_v) = 0; protected: void check_exec( const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys, const TensorLayout& values, - const TensorLayout& wqkv, const TensorLayout& reserveSpace, - const TensorLayout& dqueries, const TensorLayout& dkeys, - const TensorLayout& dvalues, const TensorLayout& dweights, - size_t workspace_in_bytes); + const TensorLayout& qkvo_weight_bias, const TensorLayout& attn_mask, + const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, + const TensorLayout& othr_reservespace, const TensorLayout& dqueries, + const TensorLayout& dkeys, const TensorLayout& dvalues, + const TensorLayout& dqkvo_weight_bias, const TensorLayout& dbias_k, + const TensorLayout& dbias_v, size_t workspace_in_bytes); }; } // namespace megdnn #include "megdnn/internal/opr_header_epilogue.h" diff --git a/dnn/src/common/multi_head_attn.cpp b/dnn/src/common/multi_head_attn.cpp index a3219a67483c39fce374e12100972ac6c39d0c82..0394706220fc5ac4b00b3b4c87b9bdde6e27d0b9 100644 --- a/dnn/src/common/multi_head_attn.cpp +++ b/dnn/src/common/multi_head_attn.cpp @@ -1,4 +1,5 @@ #include "megdnn/basic_types.h" +#include "megdnn/dtype.h" #include "megdnn/oprs.h" #include "src/common/utils.cuh" #include "unroll_macro.h" @@ -8,128 +9,278 @@ namespace megdnn { using Param = MultiHeadAttnBase::Param; - -void MultiHeadAttnForward::deduce_layout( - const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out, - TensorLayout& reserveSpace) { - megdnn_assert( - queries.ndim == 3, - "queries.ndim should be 3[batch, sequence, embeding], but got %zu", - queries.ndim); - size_t 
size = - get_reservespace_in_bytes(queries, keys, values, wqkv, out, reserveSpace); - out = TensorLayout( - {queries.shape[0], queries.shape[1], queries.shape[2]}, queries.dtype); - reserveSpace = TensorLayout({size}, queries.dtype); -} +using INPUT_TYPE = Param::TENSOR_COMBINATION_TYPE; void MultiHeadAttnForward::check_exec( const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out, - const TensorLayout& reserveSpace, size_t workspace_in_bytes) { + const TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& bias_k, + const TensorLayout& bias_v, const TensorLayout& out, + const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, + const TensorLayout& othr_reservespace, size_t workspace_in_bytes) { Param p = param(); + // contiguous megdnn_assert_contiguous(queries); megdnn_assert_contiguous(keys); megdnn_assert_contiguous(values); - megdnn_assert_contiguous(wqkv); megdnn_assert_contiguous(out); - if (p.training) - megdnn_assert_contiguous(reserveSpace); - auto required_workspace_in_bytes = - get_workspace_in_bytes(queries, keys, values, wqkv, out, reserveSpace); + megdnn_assert_contiguous(attn_weight); + if (p.training) { + megdnn_assert_contiguous(othr_reservespace); + } + if (p.qproj_size or p.kproj_size or p.vproj_size or p.kproj_size) + megdnn_assert_contiguous(qkvo_weight_bias); + bool have_mask = false; + bool have_biaskv = false; + auto input_type = p.tensor_combination_type; + if (input_type == INPUT_TYPE::ONLY_BIASKV or input_type == INPUT_TYPE::ALL) { + have_biaskv = true; + megdnn_assert_contiguous(bias_k); + megdnn_assert_contiguous(bias_v); + } + if (input_type == INPUT_TYPE::ONLY_MASK or input_type == INPUT_TYPE::ALL) { + have_mask = true; + megdnn_assert_contiguous(attn_mask); + } + + // misc + size_t required_workspace_in_bytes = get_workspace_in_bytes( + queries, keys, values, qkvo_weight_bias, attn_mask, bias_k, bias_v, out, + attn_weight, mask_reservespace, othr_reservespace); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); + megdnn_assert( + queries.ndim == 3, "queries.ndim should be 3, but got %zu", queries.ndim); + megdnn_assert(keys.ndim == 3, "keys.ndim should be 3, but got %zu", keys.ndim); + megdnn_assert( + values.ndim == 3, "values.ndim should be 3, but got %zu", values.ndim); + auto errmsg = [&]() { + return megdnn_layout_msg(queries) + ", " + megdnn_layout_msg(keys) + ", " + + megdnn_layout_msg(values) + ", " + megdnn_layout_msg(qkvo_weight_bias) + + ", " + megdnn_layout_msg(attn_mask) + ", " + megdnn_layout_msg(bias_k) + + ", " + megdnn_layout_msg(bias_v) + ", " + megdnn_layout_msg(out) + ", " + + megdnn_layout_msg(attn_weight); + }; + // batch match megdnn_assert( - queries.ndim == 3, - "queries.ndim should be 3[batch, sequence, embeding], but got %zu", - queries.ndim); + (queries.shape[0] == out.shape[0]) and + (keys.shape[0] == values.shape[0]) and + (queries.shape[0] == keys.shape[0]), + "the batch of query(%zu), key(%zu), value(%zu) and output(%zu) do not " + "match. details: %s", + queries.shape[0], keys.shape[0], values.shape[0], out.shape[0], + errmsg().c_str()); + // sequence length match megdnn_assert( - keys.ndim == 3, - "keys.ndim should be 3[batch, sequence, embeding], but got %zu", keys.ndim); + queries.shape[1] == out.shape[1], + "the sequence length of query(%zu) does not match the sequence length of " + "output(%zu). 
details: %s", + queries.shape[1], out.shape[1], errmsg().c_str()); megdnn_assert( - values.ndim == 3, - "values.ndim should be 3[batch, sequence, embeding], but got %zu", - values.ndim); + keys.shape[1] == values.shape[1], + "the sequence length of key(%zu) does not match the sequence length of " + "value(%zu). details: %s", + keys.shape[1], values.shape[1], errmsg().c_str()); + // bias_k and bias_v layout check + if (have_biaskv) { + megdnn_assert( + bias_k.ndim == 3 and bias_v.ndim == 3, + "bias_k ndim should be 3, but got %zu, details: %s", bias_k.ndim, + errmsg().c_str()); + megdnn_assert( + (bias_k.shape[0] == 1) and (bias_k.shape[1] == 1) and + (bias_k.shape[2] == (p.kproj_size ? p.kproj_size : p.k_size)), + "bias_k.shape should be [1, 1, %u], but got [%zu, " + "%zu, %zu], details: %s", + p.kproj_size ? p.kproj_size : p.k_size, bias_k.shape[0], + bias_k.shape[1], bias_k.shape[2], errmsg().c_str()); + megdnn_assert( + (bias_v.shape[0] == 1) and (bias_v.shape[1] == 1) and + (bias_v.shape[2] == (p.vproj_size ? p.vproj_size : p.v_size)), + "bias_v.shape should be [1, 1, %u], but got [%zu, " + "%zu, %zu], details: %s", + p.vproj_size ? p.vproj_size : p.v_size, bias_v.shape[0], + bias_v.shape[1], bias_v.shape[2], errmsg().c_str()); + } + // attn mask layout check + size_t attn_add = (have_biaskv ? 1 : 0) + (p.add_zero_attn ? 1 : 0); + if (have_mask and attn_mask.ndim == 3) { + megdnn_assert( + (queries.shape[0] * p.num_heads == attn_mask.shape[0]) and + (queries.shape[1] == attn_mask.shape[1]) and + ((keys.shape[1] + attn_add) == attn_mask.shape[2]), + "attn_mask.shape should be [%zu, %zu, %zu](attn_add=%zu), but got " + "[%zu, %zu, %zu]. details: %s", + queries.shape[0] * p.num_heads, queries.shape[1], + keys.shape[1] + attn_add, attn_add, attn_mask.shape[0], + attn_mask.shape[1], attn_mask.shape[2], errmsg().c_str()); + } else if (have_mask and attn_mask.ndim == 2) { + megdnn_assert( + (queries.shape[1] == attn_mask.shape[0]) and + ((keys.shape[1] + attn_add) == attn_mask.shape[1]), + "attn_mask.shape should be [%zu, %zu](attn_add=%zu), but got " + "[%zu, %zu]. details: %s", + queries.shape[1], keys.shape[1] + attn_add, attn_add, + attn_mask.shape[0], attn_mask.shape[1], errmsg().c_str()); + } + // attn_weight layout check + megdnn_assert( + (attn_weight.shape[0] == queries.shape[0] * p.num_heads) and + (attn_weight.shape[1] == queries.shape[1]) and + (attn_weight.shape[2] == keys.shape[1] + attn_add), + "attn_weight.shape should be [%zu, %zu, %zu](attn_add=%zu), but got [%zu, " + "%zu, %zu]. 
details: %s", + queries.shape[0] * p.num_heads, queries.shape[1], keys.shape[1] + attn_add, + attn_add, attn_weight.shape[0], attn_weight.shape[1], attn_weight.shape[2], + errmsg().c_str()); - auto errmsg = [&]() { - return megdnn_layout_msg(queries) + ", " + megdnn_layout_msg(keys) + ", " + - megdnn_layout_msg(values) + ", " + megdnn_layout_msg(wqkv) + ", " + - megdnn_layout_msg(out) + ", " + megdnn_layout_msg(reserveSpace); + // weigth and bias +#define TOSTRING(data) #data "=" + std::to_string(data) + auto param_errmsg = [&]() { + return TOSTRING(p.embeding_size) + ", " + TOSTRING(p.k_size) + ", " + + TOSTRING(p.v_size) + ", " + TOSTRING(p.qproj_size) + ", " + + TOSTRING(p.kproj_size) + ", " + TOSTRING(p.vproj_size) + ", " + + TOSTRING(p.oproj_size) + ", " + TOSTRING(p.qbias) + ", " + + TOSTRING(p.kbias) + ", " + TOSTRING(p.vbias) + ", " + TOSTRING(p.obias) + + ", " + TOSTRING(p.num_heads) + ", " + TOSTRING(p.need_weights) + ", " + + TOSTRING(p.add_zero_attn) + ", " + TOSTRING(int(p.attn_mask_type)) + + ", " + TOSTRING(int(p.tensor_combination_type)) + ", " + + TOSTRING(p.sm_scaler) + ", " + TOSTRING(p.training); }; - megdnn_assert(queries.shape[0] == out.shape[0], "%s", errmsg().c_str()); - megdnn_assert(keys.shape[0] == values.shape[0], "%s", errmsg().c_str()); - megdnn_assert(queries.shape[0] == keys.shape[0], "%s", errmsg().c_str()); - megdnn_assert(queries.shape[1] == out.shape[1], "%s", errmsg().c_str()); - megdnn_assert(keys.shape[1] == values.shape[1], "%s", errmsg().c_str()); - megdnn_assert( - queries.shape[2] == keys.shape[2] and keys.shape[2] == values.shape[2] and - queries.shape[2] == out.shape[2], - "%s", errmsg().c_str()); +#undef TOSTRING + size_t weight_len = 0; + size_t embeding_size = p.embeding_size; + size_t ksize = p.k_size; + size_t vsize = p.v_size; + size_t qprojsize = p.qproj_size; + size_t kprojsize = p.kproj_size; + size_t vprojsize = p.vproj_size; + size_t oprojsize = p.oproj_size; + megdnn_assert(embeding_size == queries.shape[2], "%s", param_errmsg().c_str()); + megdnn_assert(ksize == keys.shape[2], "%s", param_errmsg().c_str()); + megdnn_assert(vsize == values.shape[2], "%s", param_errmsg().c_str()); + if (qprojsize == 0 and kprojsize == 0) + megdnn_assert(embeding_size == ksize, "%s", param_errmsg().c_str()); + if (qprojsize == 0 and kprojsize != 0) + megdnn_assert(embeding_size == kprojsize, "%s", param_errmsg().c_str()); + if (qprojsize != 0 and kprojsize == 0) + megdnn_assert(qprojsize == ksize, "%s", param_errmsg().c_str()); + if (qprojsize != 0 and kprojsize != 0) + megdnn_assert(qprojsize == kprojsize, "%s", param_errmsg().c_str()); + if (p.qbias) + megdnn_assert(p.qproj_size > 0, "%s", param_errmsg().c_str()); + if (p.kbias) + megdnn_assert(p.kproj_size > 0, "%s", param_errmsg().c_str()); + if (p.vbias) + megdnn_assert(p.vproj_size > 0, "%s", param_errmsg().c_str()); + if (p.obias) + megdnn_assert(p.oproj_size > 0, "%s", param_errmsg().c_str()); + if (p.qproj_size > 0) + weight_len += embeding_size * qprojsize + (p.qbias ? qprojsize : 0); + if (p.kproj_size > 0) + weight_len += ksize * kprojsize + (p.kbias ? kprojsize : 0); + if (p.vproj_size > 0) + weight_len += vsize * vprojsize + (p.vbias ? vprojsize : 0); + if (p.oproj_size > 0 and p.vproj_size > 0) + weight_len += vprojsize * oprojsize + (p.obias ? oprojsize : 0); + else if (p.oproj_size > 0 and p.vproj_size == 0) + weight_len += vsize * oprojsize + (p.obias ? 
oprojsize : 0); + megdnn_assert( + weight_len == qkvo_weight_bias.total_nr_elems(), + "qkvo_weight_bias length should be %zu, but got %zu. details: %s", + weight_len, qkvo_weight_bias.total_nr_elems(), param_errmsg().c_str()); } void MultiHeadAttnBackward::deduce_layout( const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, - const TensorLayout& reserveSpace, TensorLayout& dqueries, TensorLayout& dkeys, - TensorLayout& dvalues, TensorLayout& dweights) { + const TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& attn_weight, + const TensorLayout& mask_reservespace, const TensorLayout& othr_reservespace, + TensorLayout& dqueries, TensorLayout& dkeys, TensorLayout& dvalues, + TensorLayout& dqkvo_weight_bias, TensorLayout& dbias_k, TensorLayout& dbias_v) { MEGDNN_MARK_USED_VAR(diff); - MEGDNN_MARK_USED_VAR(reserveSpace); + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); dqueries = queries; dkeys = keys; dvalues = values; - dweights = wqkv; + dqkvo_weight_bias = qkvo_weight_bias; + auto input_type = param().tensor_combination_type; + if (input_type == INPUT_TYPE::ONLY_BIASKV or input_type == INPUT_TYPE::ALL) { + dbias_k = TensorLayout( + {1, 1, param().kproj_size ? param().kproj_size : param().k_size}, + keys.dtype); + dbias_v = TensorLayout( + {1, 1, param().vproj_size ? param().vproj_size : param().v_size}, + values.dtype); + } else { + dbias_k = TensorLayout(); + dbias_v = TensorLayout(); + } } void MultiHeadAttnBackward::check_exec( const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, - const TensorLayout& reserveSpace, const TensorLayout& dqueries, - const TensorLayout& dkeys, const TensorLayout& dvalues, - const TensorLayout& dweights, size_t workspace_in_bytes) { + const TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& attn_weight, + const TensorLayout& mask_reservespace, const TensorLayout& othr_reservespace, + const TensorLayout& dqueries, const TensorLayout& dkeys, + const TensorLayout& dvalues, const TensorLayout& dqkvo_weight_bias, + const TensorLayout& dbias_k, const TensorLayout& dbias_v, + size_t workspace_in_bytes) { Param p = param(); megdnn_assert( p.training, "When calling MultiHeadAttn backward, param().training must be true, " "but got false"); + // contiguous megdnn_assert_contiguous(diff); megdnn_assert_contiguous(queries); megdnn_assert_contiguous(keys); megdnn_assert_contiguous(values); - megdnn_assert_contiguous(wqkv); + megdnn_assert_contiguous(attn_weight); megdnn_assert_contiguous(dqueries); megdnn_assert_contiguous(dkeys); megdnn_assert_contiguous(dvalues); - megdnn_assert_contiguous(dweights); - if (p.training) - megdnn_assert_contiguous(reserveSpace); + if (p.training) { + megdnn_assert_contiguous(othr_reservespace); + } + if (p.qproj_size or p.kproj_size or p.vproj_size or p.oproj_size) { + megdnn_assert_contiguous(qkvo_weight_bias); + megdnn_assert_contiguous(dqkvo_weight_bias); + } + + auto input_type = p.tensor_combination_type; + bool have_mask = false; + bool have_biaskv = + input_type == INPUT_TYPE::ONLY_BIASKV or input_type == INPUT_TYPE::ALL; + if (input_type == INPUT_TYPE::ONLY_MASK or input_type == INPUT_TYPE::ALL) { + have_mask = true; + 
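// Aside (descriptive comment only, not part of this patch): the forward and
// backward check_exec enforce the same shape rules. With
//   attn_add = (have_biaskv ? 1 : 0) + (p.add_zero_attn ? 1 : 0),
// the expected layouts are:
//   attn_mask   : (N * num_heads, L, S + attn_add) when 3-dim, or (L, S + attn_add) when 2-dim
//   attn_weight : (N * num_heads, L, S + attn_add)
//   bias_k      : (1, 1, kproj_size ? kproj_size : k_size)
//   bias_v      : (1, 1, vproj_size ? vproj_size : v_size)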
megdnn_assert_contiguous(attn_mask); + } + + // misc auto required_workspace_in_bytes = get_workspace_in_bytes( - diff, queries, keys, values, wqkv, reserveSpace, dqueries, dkeys, dvalues, - dweights); + diff, queries, keys, values, qkvo_weight_bias, attn_mask, attn_weight, + mask_reservespace, othr_reservespace, dqueries, dkeys, dvalues, + dqkvo_weight_bias, dbias_k, dbias_v); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); - megdnn_assert(reserveSpace.total_nr_elems() > 0); - - megdnn_assert( - queries.ndim == 3, - "queries.ndim should be 3[batch, sequence, embeding], but got %zu", - queries.ndim); - megdnn_assert( - keys.ndim == 3, - "keys.ndim should be 3[batch, sequence, embeding], but got %zu", keys.ndim); + megdnn_assert(othr_reservespace.total_nr_elems() > 0); megdnn_assert( - values.ndim == 3, - "values.ndim should be 3[batch, sequence, embeding], but got %zu", - values.ndim); + queries.ndim == 3, "queries.ndim should be 3, but got %zu", queries.ndim); + megdnn_assert(keys.ndim == 3, "keys.ndim should be 3, but got %zu", keys.ndim); megdnn_assert( - diff.ndim == 3, - "diff.ndim should be 3[batch, sequence, embeding], but got %zu", diff.ndim); - + values.ndim == 3, "values.ndim should be 3, but got %zu", values.ndim); + megdnn_assert(diff.ndim == 3, "diff.ndim should be 3, but got %zu", diff.ndim); auto errmsg = [&]() { return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(queries) + ", " + megdnn_layout_msg(keys) + ", " + megdnn_layout_msg(values) + ", " + - megdnn_layout_msg(wqkv) + ", " + megdnn_layout_msg(reserveSpace) + ", " + - megdnn_layout_msg(dqueries) + ", " + megdnn_layout_msg(dkeys) + ", " + - megdnn_layout_msg(dvalues) + ", " + megdnn_layout_msg(dweights); + megdnn_layout_msg(qkvo_weight_bias) + ", " + + megdnn_layout_msg(attn_weight) + ", " + megdnn_layout_msg(dqueries) + + ", " + megdnn_layout_msg(dkeys) + ", " + megdnn_layout_msg(dvalues) + + ", " + megdnn_layout_msg(dqkvo_weight_bias); }; auto equal_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) -> bool { @@ -144,21 +295,151 @@ void MultiHeadAttnBackward::check_exec( return true; }; - megdnn_assert(equal_layout(queries, diff), "%s", errmsg().c_str()); + // layout check + size_t osize = p.oproj_size != 0 ? p.oproj_size + : (p.vproj_size != 0 ? p.vproj_size : p.v_size); + TensorLayout diff_expect = TensorLayout( + TensorShape{queries.shape[0], queries.shape[1], osize}, queries.dtype); + megdnn_assert(equal_layout(diff, diff_expect), "%s", errmsg().c_str()); megdnn_assert(equal_layout(queries, dqueries), "%s", errmsg().c_str()); megdnn_assert(equal_layout(keys, dkeys), "%s", errmsg().c_str()); megdnn_assert(equal_layout(values, dvalues), "%s", errmsg().c_str()); - megdnn_assert(equal_layout(wqkv, dweights), "%s", errmsg().c_str()); + megdnn_assert( + equal_layout(qkvo_weight_bias, dqkvo_weight_bias), "%s", errmsg().c_str()); + + // batch match + megdnn_assert( + (queries.shape[0] == diff.shape[0]) and + (keys.shape[0] == values.shape[0]) and + (queries.shape[0] == keys.shape[0]), + "the batch of query(%zu), key(%zu), value(%zu) and diff(%zu) do not " + "match. details: %s", + queries.shape[0], keys.shape[0], values.shape[0], diff.shape[0], + errmsg().c_str()); + // sequence length match + megdnn_assert( + queries.shape[1] == diff.shape[1], + "the sequence length of query(%zu) does not match the sequence length of " + "output(%zu). 
details: %s", + queries.shape[1], diff.shape[1], errmsg().c_str()); + megdnn_assert( + keys.shape[1] == values.shape[1], + "the sequence length of key(%zu) does not match the sequence length of " + "value(%zu). details: %s", + keys.shape[1], values.shape[1], errmsg().c_str()); - megdnn_assert(queries.shape[0] == diff.shape[0], "%s", errmsg().c_str()); - megdnn_assert(keys.shape[0] == values.shape[0], "%s", errmsg().c_str()); - megdnn_assert(queries.shape[0] == keys.shape[0], "%s", errmsg().c_str()); - megdnn_assert(queries.shape[1] == diff.shape[1], "%s", errmsg().c_str()); - megdnn_assert(keys.shape[1] == values.shape[1], "%s", errmsg().c_str()); + size_t attn_add = (have_biaskv ? 1 : 0) + (p.add_zero_attn ? 1 : 0); + // attn_weight layout check + megdnn_assert( + (attn_weight.shape[0] == queries.shape[0] * p.num_heads) and + (attn_weight.shape[1] == queries.shape[1]) and + (attn_weight.shape[2] == keys.shape[1] + attn_add), + "attn_weight.shape should be [%zu, %zu, %zu](attn_add=%zu), but got [%zu, " + "%zu, %zu]. details: %s", + queries.shape[0] * p.num_heads, queries.shape[1], keys.shape[1] + attn_add, + attn_add, attn_weight.shape[0], attn_weight.shape[1], attn_weight.shape[2], + errmsg().c_str()); + // dbias_k, dbias_v layout check + if (have_biaskv) { + megdnn_assert( + dbias_k.ndim == 3 and dbias_v.ndim == 3, + "dbias_k ndim should be 3, but got %zu, details: %s", dbias_k.ndim, + errmsg().c_str()); + megdnn_assert( + (dbias_k.shape[0] == 1) and (dbias_k.shape[1] == 1) and + (dbias_k.shape[2] == (p.kproj_size ? p.kproj_size : p.k_size)), + "dbias_k.shape should be [1, 1, %u], but got [%zu, " + "%zu, %zu], details: %s", + p.kproj_size ? p.kproj_size : p.k_size, dbias_k.shape[0], + dbias_k.shape[1], dbias_k.shape[2], errmsg().c_str()); + megdnn_assert( + (dbias_v.shape[0] == 1) and (dbias_v.shape[1] == 1) and + (dbias_v.shape[2] == (p.vproj_size ? p.vproj_size : p.v_size)), + "dbias_v.shape should be [1, 1, %u], but got [%zu, " + "%zu, %zu], details: %s", + p.vproj_size ? p.vproj_size : p.v_size, dbias_v.shape[0], + dbias_v.shape[1], dbias_v.shape[2], errmsg().c_str()); + } + // attn mask layout check + if (have_mask and attn_mask.ndim == 3) { + megdnn_assert( + (queries.shape[0] * p.num_heads == attn_mask.shape[0]) and + (queries.shape[1] == attn_mask.shape[1]) and + ((keys.shape[1] + attn_add) == attn_mask.shape[2]), + "attn_mask.shape should be [%zu, %zu, %zu](attn_add=%zu), but got " + "[%zu, %zu, %zu]. details: %s", + queries.shape[0] * p.num_heads, queries.shape[1], + keys.shape[1] + attn_add, attn_add, attn_mask.shape[0], + attn_mask.shape[1], attn_mask.shape[2], errmsg().c_str()); + } else if (have_mask and attn_mask.ndim == 2) { + megdnn_assert( + (queries.shape[1] == attn_mask.shape[0]) and + ((keys.shape[1] + attn_add) == attn_mask.shape[1]), + "attn_mask.shape should be [%zu, %zu](attn_add=%zu), but got " + "[%zu, %zu]. 
details: %s", + queries.shape[1], keys.shape[1] + attn_add, attn_add, + attn_mask.shape[0], attn_mask.shape[1], errmsg().c_str()); + } + + // weigth and bias +#define TOSTRING(data) #data "=" + std::to_string(data) + auto param_errmsg = [&]() { + return TOSTRING(p.embeding_size) + ", " + TOSTRING(p.k_size) + ", " + + TOSTRING(p.v_size) + ", " + TOSTRING(p.qproj_size) + ", " + + TOSTRING(p.kproj_size) + ", " + TOSTRING(p.vproj_size) + ", " + + TOSTRING(p.oproj_size) + ", " + TOSTRING(p.qbias) + ", " + + TOSTRING(p.kbias) + ", " + TOSTRING(p.vbias) + ", " + TOSTRING(p.obias) + + ", " + TOSTRING(p.num_heads) + ", " + TOSTRING(p.need_weights) + ", " + + TOSTRING(p.add_zero_attn) + ", " + TOSTRING(int(p.attn_mask_type)) + + ", " + TOSTRING(int(p.tensor_combination_type)) + ", " + + TOSTRING(p.sm_scaler) + ", " + TOSTRING(p.training); + }; +#undef TOSTRING + size_t weight_len = 0; + size_t embeding_size = p.embeding_size; + size_t ksize = p.k_size; + size_t vsize = p.v_size; + size_t qprojsize = p.qproj_size; + size_t kprojsize = p.kproj_size; + size_t vprojsize = p.vproj_size; + size_t oprojsize = p.oproj_size; + megdnn_assert(embeding_size == queries.shape[2], "%s", param_errmsg().c_str()); + megdnn_assert(ksize == keys.shape[2], "%s", param_errmsg().c_str()); + megdnn_assert(vsize == values.shape[2], "%s", param_errmsg().c_str()); + if (qprojsize == 0 and kprojsize == 0) + megdnn_assert(embeding_size == ksize, "%s", param_errmsg().c_str()); + if (qprojsize == 0 and kprojsize != 0) + megdnn_assert(embeding_size == kprojsize, "%s", param_errmsg().c_str()); + if (qprojsize != 0 and kprojsize == 0) + megdnn_assert(qprojsize == ksize, "%s", param_errmsg().c_str()); + if (qprojsize != 0 and kprojsize != 0) + megdnn_assert(qprojsize == kprojsize, "%s", param_errmsg().c_str()); + if (p.qbias) + megdnn_assert(p.qproj_size > 0, "%s", param_errmsg().c_str()); + if (p.kbias) + megdnn_assert(p.kproj_size > 0, "%s", param_errmsg().c_str()); + if (p.vbias) + megdnn_assert(p.vproj_size > 0, "%s", param_errmsg().c_str()); + if (p.obias) + megdnn_assert(p.oproj_size > 0, "%s", param_errmsg().c_str()); + if (p.qproj_size > 0) + weight_len += embeding_size * qprojsize + (p.qbias ? qprojsize : 0); + if (p.kproj_size > 0) + weight_len += ksize * kprojsize + (p.kbias ? kprojsize : 0); + if (p.vproj_size > 0) + weight_len += vsize * vprojsize + (p.vbias ? vprojsize : 0); + if (p.oproj_size > 0 and p.vproj_size > 0) + weight_len += vprojsize * oprojsize + (p.obias ? oprojsize : 0); + else if (p.oproj_size > 0 and p.vproj_size == 0) + weight_len += vsize * oprojsize + (p.obias ? oprojsize : 0); + megdnn_assert( + weight_len == qkvo_weight_bias.total_nr_elems(), + "qkvo_weight_bias length should be %zu, but got %zu. details: %s", + weight_len, qkvo_weight_bias.total_nr_elems(), param_errmsg().c_str()); megdnn_assert( - queries.shape[2] == keys.shape[2] and keys.shape[2] == values.shape[2] and - queries.shape[2] == diff.shape[2], - "%s", errmsg().c_str()); + weight_len == dqkvo_weight_bias.total_nr_elems(), + "dqkvo_weight_bias length should be %zu, but got %zu. 
details: %s", + weight_len, dqkvo_weight_bias.total_nr_elems(), param_errmsg().c_str()); } } // namespace megdnn diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index 65591fb9194b9653c0efe0a7e720ea69364d3cf3..b7242eaedc18a4cab1e7846deb2c0a728e6067fe 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -148,8 +148,8 @@ DEF(RegionRestrictedConvolutionBackwardFilter, 5, true, false); DEF(GroupNormForward, 6, true, true); DEF(GroupNormBackward, 8, true, true); DEF(MaskedFill, 3, false, true); -DEF(MultiHeadAttnForward, 6, true, true); -DEF(MultiHeadAttnBackward, 10, true, true); +DEF(MultiHeadAttnForward, 11, true, true); +DEF(MultiHeadAttnBackward, 15, true, true); } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/multi_head_attn/opr_impl.cpp b/dnn/src/cuda/multi_head_attn/opr_impl.cpp index 10ea4d9ac96e1f9d69c198c4d5655dad5480e0b7..3a628c1dc898c649980037287aecc822abfb71ba 100644 --- a/dnn/src/cuda/multi_head_attn/opr_impl.cpp +++ b/dnn/src/cuda/multi_head_attn/opr_impl.cpp @@ -8,54 +8,84 @@ namespace cuda { void MultiHeadAttnForwardImpl::deduce_layout( const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out, - TensorLayout& reserveSpace) { + const TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& bias_k, + const TensorLayout& bias_v, TensorLayout& out, TensorLayout& attn_weight, + TensorLayout& mask_reservespace, TensorLayout& othr_reservespace) { #if CUDNN_VERSION < 8004 // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. MEGDNN_MARK_USED_VAR(queries); MEGDNN_MARK_USED_VAR(keys); MEGDNN_MARK_USED_VAR(values); - MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(bias_k); + MEGDNN_MARK_USED_VAR(bias_v); MEGDNN_MARK_USED_VAR(out); - MEGDNN_MARK_USED_VAR(reserveSpace); + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); return; #else - MEGDNN_MARK_USED_VAR(keys); - MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(attn_mask); + auto p = param(); megdnn_assert( - queries.ndim == 3, - "queries.ndim should be 3[batch, sequence, embeding], but got %zu", - queries.ndim); + queries.ndim == 3, "queries.ndim should be 3, but got %zu", queries.ndim); - if (!desc_status.is_initialized(param(), queries, keys, values)) { - desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values); + if (!desc_status.is_initialized(p, queries, keys, values)) + desc_status.set(cudnn_handle(this->handle()), p, queries, keys, values); + + auto input_type = p.tensor_combination_type; + using INPUT_TYPE = Param::TENSOR_COMBINATION_TYPE; + bool have_biaskv = + input_type == INPUT_TYPE::ONLY_BIASKV or input_type == INPUT_TYPE::ALL; + size_t attn_seqk_dim_add = (have_biaskv ? 1 : 0) + (p.add_zero_attn ? 1 : 0); + attn_weight = TensorLayout( + TensorShape{ + queries.shape[0] * p.num_heads, queries.shape[1], + keys.shape[1] + attn_seqk_dim_add}, + queries.dtype); + + size_t osize = p.oproj_size != 0 ? p.oproj_size + : (p.vproj_size != 0 ? 
p.vproj_size : p.v_size); + out = TensorLayout( + TensorShape{queries.shape[0], queries.shape[1], osize}, queries.dtype); + mask_reservespace = TensorLayout(TensorShape{0}, dtype::Uint8()); + othr_reservespace = + TensorLayout(TensorShape{desc_status.sizeReserve}, queries.dtype); - out = TensorLayout( - TensorShape{queries.shape[0], queries.shape[1], queries.shape[2]}, - queries.dtype); - reserveSpace = - TensorLayout(TensorShape{desc_status.sizeReserve}, queries.dtype); - } #endif } size_t MultiHeadAttnForwardImpl::get_workspace_in_bytes( const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out, - const TensorLayout& reserveSpace) { + const TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& bias_k, + const TensorLayout& bias_v, const TensorLayout& out, + const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, + const TensorLayout& othr_reservespace) { #if CUDNN_VERSION < 8004 // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. MEGDNN_MARK_USED_VAR(queries); MEGDNN_MARK_USED_VAR(keys); MEGDNN_MARK_USED_VAR(values); - MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(bias_k); + MEGDNN_MARK_USED_VAR(bias_v); MEGDNN_MARK_USED_VAR(out); - MEGDNN_MARK_USED_VAR(reserveSpace); + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); return 0; #else - MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(attn_mask); MEGDNN_MARK_USED_VAR(out); - MEGDNN_MARK_USED_VAR(reserveSpace); + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); if (!desc_status.is_initialized(param(), queries, keys, values)) desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values); @@ -64,47 +94,102 @@ size_t MultiHeadAttnForwardImpl::get_workspace_in_bytes( #endif } -size_t MultiHeadAttnForwardImpl::get_reservespace_in_bytes( +size_t MultiHeadAttnForwardImpl::get_mask_reservespace_in_bytes( const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out, - const TensorLayout& reserveSpace) { + const TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& bias_k, + const TensorLayout& bias_v, const TensorLayout& out, + const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, + const TensorLayout& othr_reservespace) { #if CUDNN_VERSION < 8004 // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. 
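// Recap of the deduce_layout above (descriptive comment only, not part of this patch):
//   osize             = oproj_size ? oproj_size : (vproj_size ? vproj_size : v_size)
//   out               : (N, L, osize)
//   attn_weight       : (N * num_heads, L, S + attn_seqk_dim_add)
//   mask_reservespace : placeholder layout {0} of dtype Uint8
//   othr_reservespace : {desc_status.sizeReserve} elements of the query dtype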
MEGDNN_MARK_USED_VAR(queries); MEGDNN_MARK_USED_VAR(keys); MEGDNN_MARK_USED_VAR(values); - MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(bias_k); + MEGDNN_MARK_USED_VAR(bias_v); MEGDNN_MARK_USED_VAR(out); - MEGDNN_MARK_USED_VAR(reserveSpace); + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); return 0; #else - MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(attn_mask); MEGDNN_MARK_USED_VAR(out); - MEGDNN_MARK_USED_VAR(reserveSpace); + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); + if (!desc_status.is_initialized(param(), queries, keys, values)) + desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values); + return 0; +#endif +} +size_t MultiHeadAttnForwardImpl::get_othr_reservespace_in_bytes( + const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& bias_k, + const TensorLayout& bias_v, const TensorLayout& out, + const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, + const TensorLayout& othr_reservespace) { +#if CUDNN_VERSION < 8004 + // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. + MEGDNN_MARK_USED_VAR(queries); + MEGDNN_MARK_USED_VAR(keys); + MEGDNN_MARK_USED_VAR(values); + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(bias_k); + MEGDNN_MARK_USED_VAR(bias_v); + MEGDNN_MARK_USED_VAR(out); + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); + return 0; +#else + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(bias_k); + MEGDNN_MARK_USED_VAR(bias_v); + MEGDNN_MARK_USED_VAR(out); + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); if (!desc_status.is_initialized(param(), queries, keys, values)) desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values); return desc_status.sizeReserve; #endif } + void MultiHeadAttnForwardImpl::exec( _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, - _megdnn_tensor_in wqkv, _megdnn_tensor_out out, _megdnn_tensor_out reserveSpace, - _megdnn_workspace workspace) { + _megdnn_tensor_in qkvo_weight_bias, _megdnn_tensor_in attn_mask, + _megdnn_tensor_in bias_k, _megdnn_tensor_in bias_v, _megdnn_tensor_out out, + _megdnn_tensor_out attn_weight, _megdnn_tensor_out mask_reservespace, + _megdnn_tensor_out othr_reservespace, _megdnn_workspace workspace) { #if CUDNN_VERSION < 8004 // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. MEGDNN_MARK_USED_VAR(queries); MEGDNN_MARK_USED_VAR(keys); MEGDNN_MARK_USED_VAR(values); - MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(bias_k); + MEGDNN_MARK_USED_VAR(bias_v); MEGDNN_MARK_USED_VAR(out); - MEGDNN_MARK_USED_VAR(reserveSpace); - MEGDNN_MARK_USED_VAR(workspace); + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); megdnn_throw( "The cudnn version is lower than 8.0.4. 
Please upgrade the cudnn version."); #else check_exec( - queries.layout, keys.layout, values.layout, wqkv.layout, out.layout, - reserveSpace.layout, workspace.size); + queries.layout, keys.layout, values.layout, qkvo_weight_bias.layout, + attn_mask.layout, bias_k.layout, bias_v.layout, out.layout, + attn_weight.layout, mask_reservespace.layout, othr_reservespace.layout, + workspace.size); auto p = param(); if (!desc_status.is_initialized(p, queries.layout, keys.layout, values.layout)) @@ -112,12 +197,16 @@ void MultiHeadAttnForwardImpl::exec( cudnn_handle(this->handle()), p, queries.layout, keys.layout, values.layout); + size_t osize = + desc_status.oProjSize != 0 + ? desc_status.oProjSize + : (desc_status.vProjSize != 0 ? desc_status.vProjSize * p.num_heads + : desc_status.vSize); SeqTensorDesc q{queries.layout, desc_status.batchSize, desc_status.seqLenQ, desc_status.qSize, p.input_order, desc_status.auxArray.seqQArray}; - SeqTensorDesc o{out.layout, desc_status.batchSize, - desc_status.seqLenQ, desc_status.oProjSize, - p.input_order, desc_status.auxArray.seqQArray}; + SeqTensorDesc o{out.layout, desc_status.batchSize, desc_status.seqLenQ, + osize, p.input_order, desc_status.auxArray.seqQArray}; SeqTensorDesc k{keys.layout, desc_status.batchSize, desc_status.seqLenK, desc_status.kSize, p.input_order, desc_status.auxArray.seqKArray}; @@ -132,19 +221,22 @@ void MultiHeadAttnForwardImpl::exec( q.desc, queries.raw_ptr(), p.reslink ? queries.raw_ptr() : NULL, k.desc, keys.raw_ptr(), v.desc, values.raw_ptr(), o.desc, out.raw_ptr(), desc_status.sizeWeights, - desc_status.sizeWeights > 0 ? wqkv.raw_ptr() : NULL, + desc_status.sizeWeights > 0 ? qkvo_weight_bias.raw_ptr() : NULL, desc_status.sizeWkspace, workspace.raw_ptr, p.training ? desc_status.sizeReserve : 0, - p.training ? reserveSpace.raw_ptr() : NULL)); + p.training ? othr_reservespace.raw_ptr() : NULL)); #endif } void MultiHeadAttnBackwardImpl::exec( _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, - _megdnn_tensor_in values, _megdnn_tensor_in wqkv, - _megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries, - _megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues, - _megdnn_tensor_out dweights, _megdnn_workspace workspace) { + _megdnn_tensor_in values, _megdnn_tensor_in qkvo_weight_bias, + _megdnn_tensor_in attn_mask, _megdnn_tensor_in attn_weight, + _megdnn_tensor_in mask_reservespace, _megdnn_tensor_in othr_reservespace, + _megdnn_tensor_out dqueries, _megdnn_tensor_out dkeys, + _megdnn_tensor_out dvalues, _megdnn_tensor_out dqkvo_weight_bias, + _megdnn_tensor_out dbias_k, _megdnn_tensor_out dbias_v, + _megdnn_workspace workspace) { #if CUDNN_VERSION < 8004 // TODO: CUDNN_VERSION < 8004 and param().bias = true, we need to go to the proxy // cuda implementation. @@ -152,12 +244,17 @@ void MultiHeadAttnBackwardImpl::exec( MEGDNN_MARK_USED_VAR(queries); MEGDNN_MARK_USED_VAR(keys); MEGDNN_MARK_USED_VAR(values); - MEGDNN_MARK_USED_VAR(wqkv); - MEGDNN_MARK_USED_VAR(reserveSpace); + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); MEGDNN_MARK_USED_VAR(dqueries); MEGDNN_MARK_USED_VAR(dkeys); MEGDNN_MARK_USED_VAR(dvalues); - MEGDNN_MARK_USED_VAR(dweights); + MEGDNN_MARK_USED_VAR(dqkvo_weight_bias); + MEGDNN_MARK_USED_VAR(dbias_k); + MEGDNN_MARK_USED_VAR(dbias_v); megdnn_throw( "The cudnn version is lower than 8.0.4. 
Please upgrade the cudnn version."); #else @@ -168,11 +265,16 @@ void MultiHeadAttnBackwardImpl::exec( "but got true, because there is an error in the " "dbias result during the backward calculation."); #endif + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(dbias_k); + MEGDNN_MARK_USED_VAR(dbias_v); check_exec( - diff.layout, queries.layout, keys.layout, values.layout, wqkv.layout, - reserveSpace.layout, dqueries.layout, dkeys.layout, dvalues.layout, - dweights.layout, workspace.size); + diff.layout, queries.layout, keys.layout, values.layout, + qkvo_weight_bias.layout, attn_mask.layout, attn_weight.layout, + mask_reservespace.layout, othr_reservespace.layout, dqueries.layout, + dkeys.layout, dvalues.layout, dqkvo_weight_bias.layout, dbias_k.layout, + dbias_v.layout, workspace.size); auto p = param(); if (!desc_status.is_initialized(p, queries.layout, keys.layout, values.layout)) @@ -180,12 +282,16 @@ void MultiHeadAttnBackwardImpl::exec( cudnn_handle(this->handle()), p, queries.layout, keys.layout, values.layout); + size_t osize = + desc_status.oProjSize != 0 + ? desc_status.oProjSize + : (desc_status.vProjSize != 0 ? desc_status.vProjSize * p.num_heads + : desc_status.vSize); SeqTensorDesc q{queries.layout, desc_status.batchSize, desc_status.seqLenQ, desc_status.qSize, p.input_order, desc_status.auxArray.seqQArray}; - SeqTensorDesc d{diff.layout, desc_status.batchSize, - desc_status.seqLenQ, desc_status.oProjSize, - p.input_order, desc_status.auxArray.seqQArray}; + SeqTensorDesc d{diff.layout, desc_status.batchSize, desc_status.seqLenQ, + osize, p.input_order, desc_status.auxArray.seqQArray}; SeqTensorDesc k{keys.layout, desc_status.batchSize, desc_status.seqLenK, desc_status.kSize, p.input_order, desc_status.auxArray.seqKArray}; @@ -200,11 +306,11 @@ void MultiHeadAttnBackwardImpl::exec( d.desc, diff.raw_ptr(), q.desc, dqueries.raw_ptr(), queries.raw_ptr(), k.desc, dkeys.raw_ptr(), keys.raw_ptr(), v.desc, dvalues.raw_ptr(), values.raw_ptr(), desc_status.sizeWeights, - desc_status.sizeWeights > 0 ? wqkv.raw_ptr() : NULL, + desc_status.sizeWeights > 0 ? qkvo_weight_bias.raw_ptr() : NULL, desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve, - reserveSpace.raw_ptr())); + othr_reservespace.raw_ptr())); - cuda_check(cudaMemset(dweights.raw_ptr(), 0, desc_status.sizeWeights)); + cuda_check(cudaMemset(dqkvo_weight_bias.raw_ptr(), 0, desc_status.sizeWeights)); #if CUDNN_VERSION < 8600 cuda_check(cudaDeviceSynchronize()); #endif @@ -212,28 +318,35 @@ void MultiHeadAttnBackwardImpl::exec( cudnn_handle(this->handle()), desc_status.attn_desc, CUDNN_WGRAD_MODE_ADD, q.desc, queries.raw_ptr(), k.desc, keys.raw_ptr(), v.desc, values.raw_ptr(), d.desc, diff.raw_ptr(), desc_status.sizeWeights, - desc_status.sizeWeights > 0 ? wqkv.raw_ptr() : NULL, - desc_status.sizeWeights > 0 ? dweights.raw_ptr() : NULL, + desc_status.sizeWeights > 0 ? qkvo_weight_bias.raw_ptr() : NULL, + desc_status.sizeWeights > 0 ? 
dqkvo_weight_bias.raw_ptr() : NULL, desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve, - reserveSpace.raw_ptr())); + othr_reservespace.raw_ptr())); #endif } size_t MultiHeadAttnBackwardImpl::get_workspace_in_bytes( const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, - const TensorLayout& reserveSpace, const TensorLayout& dqueries, - const TensorLayout& dkeys, const TensorLayout& dvalues, - const TensorLayout& dweights) { + const TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& attn_weight, + const TensorLayout& mask_reservespace, const TensorLayout& othr_reservespace, + const TensorLayout& dqueries, const TensorLayout& dkeys, + const TensorLayout& dvalues, const TensorLayout& dqkvo_weight_bias, + const TensorLayout& dbias_k, const TensorLayout& dbias_v) { MEGDNN_MARK_USED_VAR(diff); MEGDNN_MARK_USED_VAR(queries); MEGDNN_MARK_USED_VAR(keys); MEGDNN_MARK_USED_VAR(values); - MEGDNN_MARK_USED_VAR(wqkv); - MEGDNN_MARK_USED_VAR(reserveSpace); + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); MEGDNN_MARK_USED_VAR(dqueries); MEGDNN_MARK_USED_VAR(dkeys); MEGDNN_MARK_USED_VAR(dvalues); - MEGDNN_MARK_USED_VAR(dweights); + MEGDNN_MARK_USED_VAR(dqkvo_weight_bias); + MEGDNN_MARK_USED_VAR(dbias_k); + MEGDNN_MARK_USED_VAR(dbias_v); return 0; } } // namespace cuda diff --git a/dnn/src/cuda/multi_head_attn/opr_impl.h b/dnn/src/cuda/multi_head_attn/opr_impl.h index 4596bd37175c50469e031701ef799cf64c871e84..b7aed9774c9d7f01b27500ca1f519d46701e220c 100644 --- a/dnn/src/cuda/multi_head_attn/opr_impl.h +++ b/dnn/src/cuda/multi_head_attn/opr_impl.h @@ -19,20 +19,37 @@ public: void exec( _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, - _megdnn_tensor_in wqkv, _megdnn_tensor_out out, - _megdnn_tensor_out reserveSpace, _megdnn_workspace workspace) override; + _megdnn_tensor_in qkvo_weight_bias, _megdnn_tensor_in attn_mask, + _megdnn_tensor_in bias_k, _megdnn_tensor_in bias_v, _megdnn_tensor_out out, + _megdnn_tensor_out attn_weight, _megdnn_tensor_out mask_reservespace, + _megdnn_tensor_out othr_reservespace, _megdnn_workspace workspace) override; void deduce_layout( const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out, - TensorLayout& reserveSpace); - size_t get_reservespace_in_bytes( + const TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& bias_k, + const TensorLayout& bias_v, TensorLayout& out, TensorLayout& attn_weight, + TensorLayout& mask_reservespace, TensorLayout& othr_reservespace) override; + size_t get_mask_reservespace_in_bytes( const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, - const TensorLayout& out, const TensorLayout& reserveSpace) override; + const TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& bias_k, + const TensorLayout& bias_v, const TensorLayout& out, + const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, + const TensorLayout& othr_reservespace) override; + size_t get_othr_reservespace_in_bytes( + const TensorLayout& queries, const TensorLayout& keys, + const 
TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& bias_k, + const TensorLayout& bias_v, const TensorLayout& out, + const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, + const TensorLayout& othr_reservespace) override; size_t get_workspace_in_bytes( const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, - const TensorLayout& out, const TensorLayout& reserveSpace) override; + const TensorLayout& values, const TensorLayout& qkvo_weight_bias, + const TensorLayout& attn_mask, const TensorLayout& bias_k, + const TensorLayout& bias_v, const TensorLayout& out, + const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, + const TensorLayout& othr_reservespace) override; }; class MultiHeadAttnBackwardImpl final : public MultiHeadAttnBackward { @@ -43,16 +60,22 @@ public: #endif void exec( _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, - _megdnn_tensor_in values, _megdnn_tensor_in wqkv, - _megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries, - _megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues, - _megdnn_tensor_out dweights, _megdnn_workspace workspace) override; + _megdnn_tensor_in values, _megdnn_tensor_in qkvo_weight_bias, + _megdnn_tensor_in attn_mask, _megdnn_tensor_in attn_weight, + _megdnn_tensor_in mask_reservespace, _megdnn_tensor_in othr_reservespace, + _megdnn_tensor_out dqueries, _megdnn_tensor_out dkeys, + _megdnn_tensor_out dvalues, _megdnn_tensor_out dqkvo_weight_bias, + _megdnn_tensor_out dbias_k, _megdnn_tensor_out dbias_v, + _megdnn_workspace workspace) override; size_t get_workspace_in_bytes( const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys, const TensorLayout& values, - const TensorLayout& wqkv, const TensorLayout& reserveSpace, - const TensorLayout& dqueries, const TensorLayout& dkeys, - const TensorLayout& dvalues, const TensorLayout& dweights) override; + const TensorLayout& qkvo_weight_bias, const TensorLayout& attn_mask, + const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, + const TensorLayout& othr_reservespace, const TensorLayout& dqueries, + const TensorLayout& dkeys, const TensorLayout& dvalues, + const TensorLayout& dqkvo_weight_bias, const TensorLayout& dbias_k, + const TensorLayout& dbias_v) override; }; } // namespace cuda } // namespace megdnn diff --git a/dnn/src/naive/multi_head_attn/opr_impl.cpp b/dnn/src/naive/multi_head_attn/opr_impl.cpp index 773810601f459c873f86e25f7db6add8c4af6d51..cb5f6d66e3564f102c60241764d2b63bf034fb8b 100644 --- a/dnn/src/naive/multi_head_attn/opr_impl.cpp +++ b/dnn/src/naive/multi_head_attn/opr_impl.cpp @@ -8,45 +8,45 @@ namespace naive { using Param = MultiHeadAttnBase::Param; size_t MultiHeadAttnForwardImpl::get_workspace_in_bytes( - const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out, - const TensorLayout& reserveSpace) { - MEGDNN_MARK_USED_VAR(queries); - MEGDNN_MARK_USED_VAR(keys); - MEGDNN_MARK_USED_VAR(values); - MEGDNN_MARK_USED_VAR(wqkv); - MEGDNN_MARK_USED_VAR(out); - MEGDNN_MARK_USED_VAR(reserveSpace); + const TensorLayout& /*queries*/, const TensorLayout& /*keys*/, + const TensorLayout& /*values*/, const TensorLayout& /*qkvo_weight_bias*/, + const TensorLayout& /*attn_mask*/, const TensorLayout& /*bias_k*/, + const TensorLayout& /*bias_v*/, const TensorLayout& /*out*/, + const TensorLayout& /*attn_weight*/, 
const TensorLayout& /*mask_reservespace*/, + const TensorLayout& /*othr_reservespace*/) { megdnn_throw("unsupported naive multiheadattn forward\n"); } void MultiHeadAttnForwardImpl::exec( _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, - _megdnn_tensor_in wqkv, _megdnn_tensor_out out, _megdnn_tensor_out reserveSpace, - _megdnn_workspace workspace) { - MEGDNN_MARK_USED_VAR(queries); - MEGDNN_MARK_USED_VAR(keys); - MEGDNN_MARK_USED_VAR(values); - MEGDNN_MARK_USED_VAR(wqkv); - MEGDNN_MARK_USED_VAR(out); - MEGDNN_MARK_USED_VAR(reserveSpace); + _megdnn_tensor_in qkvo_weight_bias, _megdnn_tensor_in attn_mask, + _megdnn_tensor_in bias_k, _megdnn_tensor_in bias_v, _megdnn_tensor_out out, + _megdnn_tensor_out attn_weight, _megdnn_tensor_out mask_reservespace, + _megdnn_tensor_out othr_reservespace, _megdnn_workspace workspace) { check_exec( - queries.layout, keys.layout, values.layout, wqkv.layout, out.layout, - reserveSpace.layout, workspace.size); + queries.layout, keys.layout, values.layout, qkvo_weight_bias.layout, + attn_mask.layout, bias_k.layout, bias_v.layout, out.layout, + attn_weight.layout, mask_reservespace.layout, othr_reservespace.layout, + workspace.size); megdnn_throw("unsupported naive multiheadattn forward\n"); } void MultiHeadAttnBackwardImpl::exec( _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, - _megdnn_tensor_in values, _megdnn_tensor_in wqkv, - _megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries, - _megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues, - _megdnn_tensor_out dweights, _megdnn_workspace workspace) { + _megdnn_tensor_in values, _megdnn_tensor_in qkvo_weight_bias, + _megdnn_tensor_in attn_mask, _megdnn_tensor_in attn_weight, + _megdnn_tensor_in mask_reservespace, _megdnn_tensor_in othr_reservespace, + _megdnn_tensor_out dqueries, _megdnn_tensor_out dkeys, + _megdnn_tensor_out dvalues, _megdnn_tensor_out dqkvo_weight_bias, + _megdnn_tensor_out dbias_k, _megdnn_tensor_out dbias_v, + _megdnn_workspace workspace) { check_exec( - diff.layout, queries.layout, keys.layout, values.layout, wqkv.layout, - reserveSpace.layout, dqueries.layout, dkeys.layout, dvalues.layout, - dweights.layout, workspace.size); + diff.layout, queries.layout, keys.layout, values.layout, + qkvo_weight_bias.layout, attn_mask.layout, attn_weight.layout, + mask_reservespace.layout, othr_reservespace.layout, dqueries.layout, + dkeys.layout, dvalues.layout, dqkvo_weight_bias.layout, dbias_k.layout, + dbias_v.layout, workspace.size); megdnn_throw("unsupported naive multiheadattn backward\n"); } diff --git a/dnn/src/naive/multi_head_attn/opr_impl.h b/dnn/src/naive/multi_head_attn/opr_impl.h index 5fb1cba9120b545b1167757045cd55e37942e16a..5c14806cdea1bd1d0da7cab021ccef9d231275dd 100644 --- a/dnn/src/naive/multi_head_attn/opr_impl.h +++ b/dnn/src/naive/multi_head_attn/opr_impl.h @@ -14,17 +14,43 @@ public: using MultiHeadAttnForward::MultiHeadAttnForward; void exec( _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, - _megdnn_tensor_in wqkv, _megdnn_tensor_out out, - _megdnn_tensor_out reserveSpace, _megdnn_workspace workspace) override; + _megdnn_tensor_in qkvo_weight_bias, _megdnn_tensor_in attn_mask, + _megdnn_tensor_in bias_k, _megdnn_tensor_in bias_v, _megdnn_tensor_out out, + _megdnn_tensor_out attn_weight, _megdnn_tensor_out mask_reservespace, + _megdnn_tensor_out othr_reservespace, _megdnn_workspace workspace) override; + void deduce_layout( + const TensorLayout& /*queries*/, const TensorLayout& /*keys*/, + 
const TensorLayout& /*values*/, const TensorLayout& /*qkvo_weight_bias*/, + const TensorLayout& /*attn_mask*/, const TensorLayout& /*bias_k*/, + const TensorLayout& /*bias_v*/, TensorLayout& /*out*/, + TensorLayout& /*attn_weight*/, TensorLayout& /*mask_reservespace*/, + TensorLayout& /*othr_reservespace*/) override {} size_t get_workspace_in_bytes( - const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& wqkv, - const TensorLayout& out, const TensorLayout& reserveSpace) override; - size_t get_reservespace_in_bytes( const TensorLayout& /*queries*/, const TensorLayout& /*keys*/, - const TensorLayout& /*values*/, const TensorLayout& /*wqkv*/, - const TensorLayout& /*out*/, - const TensorLayout& /*reserveSpace*/) override { + const TensorLayout& /*values*/, const TensorLayout& /*qkvo_weight_bias*/, + const TensorLayout& /*attn_mask*/, const TensorLayout& /*bias_k*/, + const TensorLayout& /*bias_v*/, const TensorLayout& /*out*/, + const TensorLayout& /*attn_weight*/, + const TensorLayout& /*mask_reservespace*/, + const TensorLayout& /*othr_reservespace*/) override; + size_t get_mask_reservespace_in_bytes( + const TensorLayout& /*queries*/, const TensorLayout& /*keys*/, + const TensorLayout& /*values*/, const TensorLayout& /*qkvo_weight_bias*/, + const TensorLayout& /*attn_mask*/, const TensorLayout& /*bias_k*/, + const TensorLayout& /*bias_v*/, const TensorLayout& /*out*/, + const TensorLayout& /*attn_weight*/, + const TensorLayout& /*mask_reservespace*/, + const TensorLayout& /*othr_reservespace*/) override { + return 0; + } + size_t get_othr_reservespace_in_bytes( + const TensorLayout& /*queries*/, const TensorLayout& /*keys*/, + const TensorLayout& /*values*/, const TensorLayout& /*qkvo_weight_bias*/, + const TensorLayout& /*attn_mask*/, const TensorLayout& /*bias_k*/, + const TensorLayout& /*bias_v*/, const TensorLayout& /*out*/, + const TensorLayout& /*attn_weight*/, + const TensorLayout& /*mask_reservespace*/, + const TensorLayout& /*othr_reservespace*/) override { return 0; } }; @@ -34,17 +60,23 @@ public: using MultiHeadAttnBackward::MultiHeadAttnBackward; void exec( _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, - _megdnn_tensor_in values, _megdnn_tensor_in wqkv, - _megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries, - _megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues, - _megdnn_tensor_out dweights, _megdnn_workspace workspace) override; + _megdnn_tensor_in values, _megdnn_tensor_in qkvo_weight_bias, + _megdnn_tensor_in attn_mask, _megdnn_tensor_in attn_weight, + _megdnn_tensor_in mask_reservespace, _megdnn_tensor_in othr_reservespace, + _megdnn_tensor_out dqueries, _megdnn_tensor_out dkeys, + _megdnn_tensor_out dvalues, _megdnn_tensor_out dqkvo_weight_bias, + _megdnn_tensor_out dbias_k, _megdnn_tensor_out dbias_v, + _megdnn_workspace workspace) override; size_t get_workspace_in_bytes( - const TensorLayout& /*diff*/, const TensorLayout& /* queries*/, - const TensorLayout& /*keyes*/, const TensorLayout& /* values*/, - const TensorLayout& /*wqkv*/, const TensorLayout& /* reserveSpace*/, - const TensorLayout& /*dqueries*/, const TensorLayout& /* dkeyes*/, - const TensorLayout& /*dvalues*/, - const TensorLayout& /* dweights*/) override { + const TensorLayout& /*diff*/, const TensorLayout& /*queries*/, + const TensorLayout& /*keys*/, const TensorLayout& /*values*/, + const TensorLayout& /*qkvo_weight_bias*/, const TensorLayout& /*attn_mask*/, + const TensorLayout& /*attn_weight*/, + const TensorLayout& 
/*mask_reservespace*/, + const TensorLayout& /*othr_reservespace*/, const TensorLayout& /*dqueries*/, + const TensorLayout& /*dkeys*/, const TensorLayout& /*dvalues*/, + const TensorLayout& /*dqkvo_weight_bias*/, const TensorLayout& /*dbias_k*/, + const TensorLayout& /*dbias_v*/) override { return 0; } }; diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 69fc6ead2ab0951c22790cb9d011457742cc5b30..c3568702576840681f30994c923cd40998d7f446 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -899,24 +899,24 @@ def gelu(x): def softplus(inp: Tensor) -> Tensor: r"""Applies the element-wise function: - .. math:: - \text{softplus}(x) = \log(1 + \exp(x)) + .. math:: + \text{softplus}(x) = \log(1 + \exp(x)) - softplus is a smooth approximation to the ReLU function and can be used - to constrain the output to be always positive. - For numerical stability the implementation follows this transformation: + softplus is a smooth approximation to the ReLU function and can be used + to constrain the output to be always positive. + For numerical stability the implementation follows this transformation: - .. math:: - \text{softplus}(x) = \log(1 + \exp(x)) - = \log(1 + \exp(-\text{abs}(x))) + \max(x, 0) - = \log1p(\exp(-\text{abs}(x))) + \text{relu}(x) + .. math:: + \text{softplus}(x) = \log(1 + \exp(x)) + = \log(1 + \exp(-\text{abs}(x))) + \max(x, 0) + = \log1p(\exp(-\text{abs}(x))) + \text{relu}(x) - Examples: - >>> import numpy as np - >>> x = Tensor(np.arange(-3, 3, dtype=np.float32)) - >>> y = F.softplus(x) - >>> y.numpy().round(decimals=4) - array([0.0486, 0.1269, 0.3133, 0.6931, 1.3133, 2.1269], dtype=float32) + Examples: + >>> import numpy as np + >>> x = Tensor(np.arange(-3, 3, dtype=np.float32)) + >>> y = F.softplus(x) + >>> y.numpy().round(decimals=4) + array([0.0486, 0.1269, 0.3133, 0.6931, 1.3133, 2.1269], dtype=float32) """ return _elwise(inp, mode=Elemwise.Mode.SOFTPLUS) @@ -2213,7 +2213,7 @@ def _merge_masks( ): r""" Determine mask type and combine masks if necessary. - + Note: This function will continue to improve with the iteration of MHA. Args: @@ -2224,7 +2224,7 @@ def _merge_masks( add_bias_kv: used to determine whether pad is needed on the sequence dimension of attn_mask and key_padding_mask, from MHA's ``add_bias_kv``. add_zero_attn: used to determine whether pad is needed on the sequence dimension of attn_mask and key_padding_mask, from MHA's ``add_zero_attn``. is_causal: MHA's is_causal, is_causal provides a hint that attn_mask is the causal mask. - maybe_cudnn_style_mask: MHA's maybe_cudnn_style_mask, like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is the cudnn style mask. + maybe_cudnn_style_mask: MHA's maybe_cudnn_style_mask, like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is the cudnn style mask. num_heads: MHA's head number. Returns: merged_mask: merged mask, may be None, the shape is :math:`(L, S)`, :math:`(2\cdotL + 2\cdotN)` or :math:`(N\cdot\text{num\_heads}, L, S)` @@ -2320,8 +2320,8 @@ def multi_head_attention( num_heads: parallel attention heads. attn_drop: probability of an element to be zeroed, used in attention matrix. out_drop: probability of an element to be zeroed, used in final output. - io_weight_bias: input/output projection weight/bias all in one. 
- The order of arrangement is: query weight, key weight, value weight, out weight, query bias, key bias, value bias, out bias, the following parameters will be used to indicate whether these items exist: qproj_size, kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias. + io_weight_bias: input/output projection weight/bias all in one. + The order of arrangement is: query weight, key weight, value weight, out weight, query bias, key bias, value bias, out bias, the following parameters will be used to indicate whether these items exist: qproj_size, kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias. Note: :math:`Y=X@W+B` is used here instead of :math:`Y=X@W^T+B` in pytorch. qproj_size: indicates the projection size of query weight in io_weight_bias, 0 indicates disabled query projection and no query projection weight. kproj_size: indicates the projection size of key weight in io_weight_bias, 0 indicates disabled key projection and no key projection weight. @@ -2335,7 +2335,7 @@ def multi_head_attention( Note: Should be set to None, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. add_zero_attn: if specified, adds a new batch of zeros to the key and value sequences at sequence dim. Default: ``False``. Note: should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. - key_padding_mask: if specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` to ignore for the purpose of + key_padding_mask: if specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. Note: Should be set to None, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all @@ -2353,9 +2353,22 @@ def multi_head_attention( Note: In the cudnn style, the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`. Warning: like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is a cudnn style mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. In addition, if the ``_merge_masks`` function returns ``merge_type=cudnn_style_mask``, please ensure that other conditions are correct so that it can run the implementation of cudnn, otherwise an error will be reported. Note: Should be set to False, and configuration of this parameter is not supported now. 
The reason is that the underlying implementation only accepts two types of mask type, namely "no_mask" and "default_mask", and we may try to loosen this option after submitting the commit that users can pass in custom attention mask tensors. - reslink: add input query to final output. + reslink: add input query to final output. Note: It is only valid if the shape of the input query is the same as the shape of the output. training: will apply dropout if is ``True``. + + Outputs: + - **out[0]=attn_output** - Attention outputs of shape :math:`(N, L, E)`, + where :math:`L` is the target sequence length, :math:`N` is + the batch size, and :math:`E` is the embedding dimension ``embed_dim``. + - **out[1]=attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N * \text{num\_heads}, L, S)`. + Note: Currently only None will be returned. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. + - **out[2]=mask_reservespace** - Used to save the dropout mask needed for backward propagation. + - **out[3]=othr_reservespace** - Used to save the intermediate results that need to be used in backward propagation. """ qproj_size = embed_dim if qproj_size is None else qproj_size kproj_size = embed_dim if kproj_size is None else kproj_size @@ -2448,6 +2461,21 @@ def multi_head_attention( num_heads=num_heads, ) + def get_tensor_combination_type(attn_mask_tensor, bias_k, bias_v): + bias_kv = bias_k is not None and bias_v is not None + if not bias_kv and attn_mask_tensor is None: + return "none" + elif not bias_kv and attn_mask_tensor is not None: + return "only_mask" + elif bias_kv and attn_mask_tensor is None: + return "only_biaskv" + else: + return "all" + + tensor_combination_type = get_tensor_combination_type( + attn_mask_tensor, bias_k, bias_v + ) + op = builtin.MultiHeadAttn( num_heads=num_heads, sm_scaler=smScaler, @@ -2471,11 +2499,20 @@ def multi_head_attention( vbias=vbias, obias=obias, need_weights=need_weights, - tensor_combination_type="none", + tensor_combination_type=tensor_combination_type, ) + if tensor_combination_type == "none": + out = apply(op, query, key, value, io_weight_bias) + elif tensor_combination_type == "only_mask": + out = apply(op, query, key, value, io_weight_bias, attn_mask_tensor) + elif tensor_combination_type == "only_biaskv": + out = apply(op, query, key, value, io_weight_bias, bias_k, bias_v) + else: + out = apply( + op, query, key, value, io_weight_bias, attn_mask_tensor, bias_k, bias_v + ) - out, reserveSpace = apply(op, query, key, value, io_weight_bias) - return out, None + return out[0], out[1] from .loss import * # isort:skip diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp index ca8cd3327255536083384910c94672e6ed315d8e..6b62f3e9c985bece66fbe03c593f3a5682e1b0e8 100644 --- a/imperative/src/impl/ops/rng.cpp +++ b/imperative/src/impl/ops/rng.cpp @@ -436,6 +436,28 @@ _INST_RNG_MAKER(4) #undef _FOR_EACH_OUT #undef _FOR_EACH_IN +#define _FOR_EACH_IN(subfix) \ + inputs[0] subfix, inputs[1] subfix, inputs[2] subfix, inputs[3] 
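For reference, the input dispatch added above can be restated as a standalone sketch; the helper name combination_type below is illustrative only (it is not part of the patch), and the table merely summarizes the get_tensor_combination_type/apply logic shown above.

# Restatement of how the optional inputs select tensor_combination_type and the
# tensors handed to apply():
#   attn_mask  bias_k/bias_v   type            tensors passed to apply()
#   absent     absent          "none"          q, k, v, io_weight_bias
#   present    absent          "only_mask"     ... + attn_mask
#   absent     present         "only_biaskv"   ... + bias_k, bias_v
#   present    present         "all"           ... + attn_mask, bias_k, bias_v
def combination_type(attn_mask, bias_k, bias_v):
    has_bias_kv = bias_k is not None and bias_v is not None
    if attn_mask is None:
        return "only_biaskv" if has_bias_kv else "none"
    return "all" if has_bias_kv else "only_mask"

Note that bias_k and bias_v only take effect when both are given, exactly as in the helper above.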
subfix, \ + inputs[4] subfix, +_INST_RNG_MAKER(5) +#undef _FOR_EACH_IN + +#define _FOR_EACH_IN(subfix) \ + inputs[0] subfix, inputs[1] subfix, inputs[2] subfix, inputs[3] subfix, \ + inputs[4] subfix, inputs[5] subfix, +_INST_RNG_MAKER(6) +#undef _FOR_EACH_IN + +#define _FOR_EACH_IN(subfix) \ + inputs[0] subfix, inputs[1] subfix, inputs[2] subfix, inputs[3] subfix, \ + inputs[4] subfix, inputs[5] subfix, inputs[6] subfix, +#define _FOR_EACH_OUT(subfix) \ + outputs[0] subfix, outputs[1] subfix, outputs[2] subfix, outputs[3] subfix +_INST_RNG_INVOLKER(7, 4) +_INST_RNG_MAKER(7) +#undef _FOR_EACH_OUT +#undef _FOR_EACH_IN + #undef _INST_RNG_INVOLKER #undef _INST_RNG_MAKER @@ -541,37 +563,90 @@ SmallVector infer_output_attrs( return dests; } +template +std::tuple, bool> _infer_output_attrs( + const OpDef& op, const SmallVector& inputs, const CompNode cn){}; + template <> -SmallVector infer_output_attrs( - const OpDef& op, const SmallVector& inputs) { - SmallVector dests(2); - auto&& cn = inputs[0]->comp_node(); +std::tuple, bool> _infer_output_attrs( + const OpDef& op, const SmallVector& inputs, const CompNode cn) { + bool success = inputs[0].ndim != 0; - dests[0].comp_node = cn; - dests[0].layout = TensorLayout(inputs[0]->layout()); - dests[0].layout.dtype = inputs[0]->layout().dtype; + SmallVector dests(4); - auto get_reservespace_in_bytes = [&]() -> size_t { - // retrieve dnn_op from glob cache - auto&& rng = op.cast_final_safe(); - auto handle = rng.handle; - if (!handle) { - handle = RNGDnnOpManager::get_default_handle(cn); - } - auto dnn_op_thread_safe = - RNGDnnOpManager::inst().get_dnn_op( - handle, reinterpret_cast(op.dyn_typeinfo()), cn); - auto dnn_op = std::get<1>(dnn_op_thread_safe); - dnn_op->param() = OpMeth::make_param(rng); + // retrieve dnn_op from glob cache + auto&& rng = op.cast_final_safe(); + auto handle = rng.handle; + if (!handle) { + handle = RNGDnnOpManager::get_default_handle(cn); + } + auto dnn_op_thread_safe = RNGDnnOpManager::inst().get_dnn_op( + handle, reinterpret_cast(op.dyn_typeinfo()), cn); + auto dnn_op = std::get<1>(dnn_op_thread_safe); + dnn_op->param() = OpMeth::make_param(rng); - return dnn_op->get_reservespace_in_bytes( - inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(), - inputs[3]->layout(), {}, {}); - }; + TensorLayout out, attn_weight, mask_layout, othr_layout; + dnn_op->deduce_layout( + inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5], inputs[6], + out, attn_weight, mask_layout, othr_layout); + + dests[0].comp_node = cn; + dests[0].layout = out; + dests[0].layout.dtype = inputs[0].dtype; dests[1].comp_node = cn; - dests[1].layout = - TensorLayout(TensorShape({get_reservespace_in_bytes()}), dtype::Byte()); - return dests; + dests[1].layout = attn_weight; + if (success) { + dests[2].comp_node = cn; + dests[2].layout = mask_layout; + dests[3].comp_node = cn; + dests[3].layout = othr_layout; + } else { + dests[2].comp_node = cn; + dests[2].layout = TensorLayout(dtype::Byte()); + dests[3].comp_node = cn; + dests[3].layout = TensorLayout(inputs[0].dtype); + } + + return {dests, success}; +} + +template <> +SmallVector infer_output_attrs( + const OpDef& op, const SmallVector& inputs) { + using INPUT_TYPE = opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE; + auto&& cn = inputs[0]->comp_node(); + auto input_type = op.cast_final_safe().tensor_combination_type; + + std::tuple, bool> ret; + TensorLayout empty_layout; + if (input_type == INPUT_TYPE::NONE) + ret = _infer_output_attrs( + op, + {inputs[0]->layout(), 
inputs[1]->layout(), inputs[2]->layout(), + inputs[3]->layout(), empty_layout, empty_layout, empty_layout}, + cn); + else if (input_type == INPUT_TYPE::ONLY_MASK) + ret = _infer_output_attrs( + op, + {inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(), + inputs[3]->layout(), inputs[4]->layout(), empty_layout, empty_layout}, + cn); + else if (input_type == INPUT_TYPE::ONLY_BIASKV) + ret = _infer_output_attrs( + op, + {inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(), + inputs[3]->layout(), empty_layout, inputs[4]->layout(), + inputs[5]->layout()}, + cn); + else + ret = _infer_output_attrs( + op, + {inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(), + inputs[3]->layout(), inputs[4]->layout(), inputs[5]->layout(), + inputs[6]->layout()}, + cn); + + return std::get<0>(ret); } template @@ -587,6 +662,127 @@ SmallVector apply_on_physical_tensor( return outputs; } +template <> +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs, + SmallVector& output_descs, const bool& validated) { + using INPUT_TYPE = opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE; + SmallVector outputs; + SmallVector desc = + infer_output_attrs(def, inputs); + for (auto&& i : desc) { + outputs.push_back(Tensor::make(i.layout, i.comp_node)); + } + + auto&& rng = def.cast_final_safe(); + auto dest = outputs[0]; + if (dest->layout().is_empty()) + return outputs; + auto cn = dest->comp_node(); + auto handle = rng.handle; + if (!handle) { + handle = RNGDnnOpManager::get_default_handle(cn); + } + + // retrieve dnn_op from glob cache + auto dnn_op_thread_safe = + RNGDnnOpManager::inst().get_dnn_op::DnnOp>( + handle, reinterpret_cast(def.dyn_typeinfo()), cn); + auto initialized = std::get<0>(dnn_op_thread_safe); + auto dnn_op = std::get<1>(dnn_op_thread_safe); + if (initialized) { + auto handle_seed = RNGDnnOpManager::get_seed(handle); + mgb_assert( + dnn_op->param().seed == handle_seed, + "inconsistent rng seed: handle: %lu, dnn_op: %lu", handle_seed, + dnn_op->param().seed); + } + dnn_op->param() = OpMeth::make_param(rng); + + auto input_type = rng.tensor_combination_type; + std::shared_ptr empty_dnn(nullptr); + size_t wk_size = 0; + TensorLayout empty_layout; + megdnn::TensorND empty_tensor; + + if (input_type == INPUT_TYPE::ALL) { + wk_size = dnn_op->get_workspace_in_bytes( + inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(), + inputs[3]->layout(), inputs[4]->layout(), inputs[5]->layout(), + inputs[6]->layout(), outputs[0]->layout(), outputs[1]->layout(), + outputs[2]->layout(), outputs[3]->layout()); + auto workspace = Blob::make(outputs[0]->comp_node(), wk_size); + megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); + dnn_op->exec( + inputs[0]->dev_tensor().as_megdnn(), + inputs[1]->dev_tensor().as_megdnn(), + inputs[2]->dev_tensor().as_megdnn(), + inputs[3]->dev_tensor().as_megdnn(), + inputs[4]->dev_tensor().as_megdnn(), + inputs[5]->dev_tensor().as_megdnn(), + inputs[6]->dev_tensor().as_megdnn(), + outputs[0]->dev_tensor().as_megdnn(), + outputs[1]->dev_tensor().as_megdnn(), + outputs[2]->dev_tensor().as_megdnn(), + outputs[3]->dev_tensor().as_megdnn(), dnn_wk); + } else if (input_type == INPUT_TYPE::ONLY_MASK) { + wk_size = dnn_op->get_workspace_in_bytes( + inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(), + inputs[3]->layout(), inputs[4]->layout(), empty_layout, empty_layout, + outputs[0]->layout(), outputs[1]->layout(), outputs[2]->layout(), + outputs[3]->layout()); + auto workspace = 
Blob::make(outputs[0]->comp_node(), wk_size); + megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); + dnn_op->exec( + inputs[0]->dev_tensor().as_megdnn(), + inputs[1]->dev_tensor().as_megdnn(), + inputs[2]->dev_tensor().as_megdnn(), + inputs[3]->dev_tensor().as_megdnn(), + inputs[4]->dev_tensor().as_megdnn(), empty_tensor, empty_tensor, + outputs[0]->dev_tensor().as_megdnn(), + outputs[1]->dev_tensor().as_megdnn(), + outputs[2]->dev_tensor().as_megdnn(), + outputs[3]->dev_tensor().as_megdnn(), dnn_wk); + } else if (input_type == INPUT_TYPE::ONLY_BIASKV) { + wk_size = dnn_op->get_workspace_in_bytes( + inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(), + inputs[3]->layout(), empty_layout, inputs[4]->layout(), + inputs[5]->layout(), outputs[0]->layout(), outputs[1]->layout(), + outputs[2]->layout(), outputs[3]->layout()); + auto workspace = Blob::make(outputs[0]->comp_node(), wk_size); + megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); + dnn_op->exec( + inputs[0]->dev_tensor().as_megdnn(), + inputs[1]->dev_tensor().as_megdnn(), + inputs[2]->dev_tensor().as_megdnn(), + inputs[3]->dev_tensor().as_megdnn(), empty_tensor, + inputs[4]->dev_tensor().as_megdnn(), + inputs[5]->dev_tensor().as_megdnn(), + outputs[0]->dev_tensor().as_megdnn(), + outputs[1]->dev_tensor().as_megdnn(), + outputs[2]->dev_tensor().as_megdnn(), + outputs[3]->dev_tensor().as_megdnn(), dnn_wk); + } else { + wk_size = dnn_op->get_workspace_in_bytes( + inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(), + inputs[3]->layout(), empty_layout, empty_layout, empty_layout, + outputs[0]->layout(), outputs[1]->layout(), outputs[2]->layout(), + outputs[3]->layout()); + auto workspace = Blob::make(outputs[0]->comp_node(), wk_size); + megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); + dnn_op->exec( + inputs[0]->dev_tensor().as_megdnn(), + inputs[1]->dev_tensor().as_megdnn(), + inputs[2]->dev_tensor().as_megdnn(), + inputs[3]->dev_tensor().as_megdnn(), empty_tensor, empty_tensor, + empty_tensor, outputs[0]->dev_tensor().as_megdnn(), + outputs[1]->dev_tensor().as_megdnn(), + outputs[2]->dev_tensor().as_megdnn(), + outputs[3]->dev_tensor().as_megdnn(), dnn_wk); + } + return outputs; +} + template Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { size_t nr_inp = inputs.size(); @@ -601,6 +797,23 @@ Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { return _RNGOprMaker::make(inputs, rng); } +template <> +SymbolVarArray apply_on_var_node( + const OpDef& def, const VarNodeArray& inputs) { + auto&& rng = def.cast_final_safe(); + using INPUT_TYPE = opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE; + auto input_type = rng.tensor_combination_type; + if (input_type == INPUT_TYPE::ALL) { + return _RNGOprMaker<7>::make(inputs, rng); + } else if (input_type == INPUT_TYPE::ONLY_BIASKV) { + return _RNGOprMaker<6>::make(inputs, rng); + } else if (input_type == INPUT_TYPE::ONLY_MASK) { + return _RNGOprMaker<5>::make(inputs, rng); + } else { + return _RNGOprMaker<4>::make(inputs, rng); + } +} + template std::tuple, bool> infer_output_attrs_fallible( const OpDef& def, const SmallVector& inputs) { @@ -671,39 +884,38 @@ std::tuple, bool> infer_output_attrs_fallible std::tuple, bool> infer_output_attrs_fallible< MultiHeadAttn>(const OpDef& op, const SmallVector& inputs) { - bool success = inputs[0].layout.ndim != 0; - - SmallVector dests(2); - auto cn = inputs[0].comp_node; - dests[0].comp_node = cn; - dests[0].layout = TensorLayout(inputs[0].layout); - 
dests[0].layout.dtype = inputs[0].layout.dtype; - - auto get_reservespace_in_bytes = [&]() -> size_t { - auto&& rng = op.cast_final_safe(); - auto handle = rng.handle; - if (!handle) { - handle = RNGDnnOpManager::get_default_handle(cn); - } - auto dnn_op_thread_safe = - RNGDnnOpManager::inst().get_dnn_op( - handle, reinterpret_cast(op.dyn_typeinfo()), cn); - auto dnn_op = std::get<1>(dnn_op_thread_safe); - dnn_op->param() = OpMeth::make_param(rng); - - return dnn_op->get_reservespace_in_bytes( - inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout, - {}, {}); - }; - dests[1].comp_node = cn; - if (success) { - dests[1].layout = - TensorLayout(TensorShape({get_reservespace_in_bytes()}), dtype::Byte()); - } else { - dests[1].layout = TensorLayout(dtype::Byte()); - } - - return {dests, success}; + using INPUT_TYPE = opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE; + auto&& cn = inputs[0].comp_node; + auto input_type = op.cast_final_safe().tensor_combination_type; + + std::tuple, bool> ret; + TensorLayout empty_layout; + if (input_type == INPUT_TYPE::NONE) + ret = _infer_output_attrs( + op, + {inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout, + empty_layout, empty_layout, empty_layout}, + cn); + else if (input_type == INPUT_TYPE::ONLY_MASK) + ret = _infer_output_attrs( + op, + {inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout, + inputs[4].layout, empty_layout, empty_layout}, + cn); + else if (input_type == INPUT_TYPE::ONLY_BIASKV) + ret = _infer_output_attrs( + op, + {inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout, + empty_layout, inputs[4].layout, inputs[5].layout}, + cn); + else + ret = _infer_output_attrs( + op, + {inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout, + inputs[4].layout, inputs[5].layout, inputs[6].layout}, + cn); + + return ret; } template diff --git a/imperative/tablegen/generated/hash.txt b/imperative/tablegen/generated/hash.txt index 3e050f39a392df0992cd4063c7400124868837e3..71eb57d1ce310f39603d6c48d5b9ef8e4d33b7ff 100644 --- a/imperative/tablegen/generated/hash.txt +++ b/imperative/tablegen/generated/hash.txt @@ -1,7 +1,7 @@ 0a8cd3cd50cadfaae0478ee70621618e ../../dnn/scripts/opr_param_defs.py 9e9636d66694dd7d5a7853247a5406f9 ../../src/core/include/megbrain/ir/ops.td -283dffd0e9cd28db5155c44cf4eda148 generated/opdef.h.inl -5e8d57337c3aec6f4b3b30ef9ba141f8 generated/opdef.cpp.inl -7f470236e4b5b00bdeaec321bc7187b5 generated/opdef.py.inl -003addd357423b880cd06410f5bf624b generated/opdef.cpy.inl +2c15c869c1731d1bc5f25f9b132f4f08 generated/opdef.h.inl +0dabeee4b8f81be4c1809906b99795a5 generated/opdef.cpp.inl +be20faf18eccbc56f535b012170ed90a generated/opdef.py.inl +af9ab62fe962d409bb65e66af5f44a79 generated/opdef.cpy.inl d468302f2d4b113913b76b5a181aae56 generated/enum_macro.h diff --git a/imperative/tablegen/generated/opdef.cpp.inl b/imperative/tablegen/generated/opdef.cpp.inl index b0f2d2ce49dfb85f131704664d35fd90d011f164..e44998ba4cdf4069e682b93d452c4f13047b10c3 100644 --- a/imperative/tablegen/generated/opdef.cpp.inl +++ b/imperative/tablegen/generated/opdef.cpp.inl @@ -5321,8 +5321,8 @@ std::vector> MultiHeadAttn_props_impl(const props_.emplace_back("tensor_combination_type", "INVALID"); break; } - props_.emplace_back("need_weights", std::to_string(op_.need_weights)); props_.emplace_back("add_zero_attn", std::to_string(op_.add_zero_attn)); + props_.emplace_back("need_weights", std::to_string(op_.need_weights)); props_.emplace_back("reslink", 
std::to_string(op_.reslink)); props_.emplace_back("training", std::to_string(op_.training)); props_.emplace_back("seed", std::to_string(op_.seed)); diff --git a/imperative/tablegen/generated/opdef.cpy.inl b/imperative/tablegen/generated/opdef.cpy.inl index 3673cd6ad8db5cd1a5abd33c32710878b203f8ab..f10c3a482b67b804ac8f3280c5652375150a33bd 100644 --- a/imperative/tablegen/generated/opdef.cpy.inl +++ b/imperative/tablegen/generated/opdef.cpy.inl @@ -15238,8 +15238,8 @@ PyOpDefBegin(MultiHeadAttn) // { {"input_order", serialization::dump(opdef.input_order)}, {"attn_mask_type", serialization::dump(opdef.attn_mask_type)}, {"tensor_combination_type", serialization::dump(opdef.tensor_combination_type)}, - {"need_weights", serialization::dump(opdef.need_weights)}, {"add_zero_attn", serialization::dump(opdef.add_zero_attn)}, + {"need_weights", serialization::dump(opdef.need_weights)}, {"reslink", serialization::dump(opdef.reslink)}, {"training", serialization::dump(opdef.training)}, {"seed", serialization::dump(opdef.seed)}, @@ -15369,16 +15369,16 @@ PyOpDefBegin(MultiHeadAttn) // { } { - auto&& iter = state.find("need_weights"); + auto&& iter = state.find("add_zero_attn"); if (iter != state.end()) { - opdef.need_weights = serialization::load(iter->second); + opdef.add_zero_attn = serialization::load(iter->second); } } { - auto&& iter = state.find("add_zero_attn"); + auto&& iter = state.find("need_weights"); if (iter != state.end()) { - opdef.add_zero_attn = serialization::load(iter->second); + opdef.need_weights = serialization::load(iter->second); } } @@ -15432,9 +15432,9 @@ PyOpDefBegin(MultiHeadAttn) // { PyOpDefEnd(MultiHeadAttn) int PyOp(MultiHeadAttn)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { - static const char* kwlist[] = {"num_heads", "embeding_size", "k_size", "v_size", "qproj_size", "kproj_size", "vproj_size", "oproj_size", "qbias", "kbias", "vbias", "obias", "sm_scaler", "input_order", "attn_mask_type", "tensor_combination_type", "need_weights", "add_zero_attn", "reslink", "training", "seed", "attn_prob", "out_prob", "handle", "scope", NULL}; - PyObject *num_heads = NULL, *embeding_size = NULL, *k_size = NULL, *v_size = NULL, *qproj_size = NULL, *kproj_size = NULL, *vproj_size = NULL, *oproj_size = NULL, *qbias = NULL, *kbias = NULL, *vbias = NULL, *obias = NULL, *sm_scaler = NULL, *input_order = NULL, *attn_mask_type = NULL, *tensor_combination_type = NULL, *need_weights = NULL, *add_zero_attn = NULL, *reslink = NULL, *training = NULL, *seed = NULL, *attn_prob = NULL, *out_prob = NULL, *handle = NULL, *scope = NULL; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOOOOOOOOOOOOOOOOOOOOOO", const_cast(kwlist), &num_heads, &embeding_size, &k_size, &v_size, &qproj_size, &kproj_size, &vproj_size, &oproj_size, &qbias, &kbias, &vbias, &obias, &sm_scaler, &input_order, &attn_mask_type, &tensor_combination_type, &need_weights, &add_zero_attn, &reslink, &training, &seed, &attn_prob, &out_prob, &handle, &scope)) + static const char* kwlist[] = {"num_heads", "embeding_size", "k_size", "v_size", "qproj_size", "kproj_size", "vproj_size", "oproj_size", "qbias", "kbias", "vbias", "obias", "sm_scaler", "input_order", "attn_mask_type", "tensor_combination_type", "add_zero_attn", "need_weights", "reslink", "training", "seed", "attn_prob", "out_prob", "handle", "scope", NULL}; + PyObject *num_heads = NULL, *embeding_size = NULL, *k_size = NULL, *v_size = NULL, *qproj_size = NULL, *kproj_size = NULL, *vproj_size = NULL, *oproj_size = NULL, *qbias = NULL, *kbias = NULL, *vbias = NULL, 
*obias = NULL, *sm_scaler = NULL, *input_order = NULL, *attn_mask_type = NULL, *tensor_combination_type = NULL, *add_zero_attn = NULL, *need_weights = NULL, *reslink = NULL, *training = NULL, *seed = NULL, *attn_prob = NULL, *out_prob = NULL, *handle = NULL, *scope = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOOOOOOOOOOOOOOOOOOOOOO", const_cast(kwlist), &num_heads, &embeding_size, &k_size, &v_size, &qproj_size, &kproj_size, &vproj_size, &oproj_size, &qbias, &kbias, &vbias, &obias, &sm_scaler, &input_order, &attn_mask_type, &tensor_combination_type, &add_zero_attn, &need_weights, &reslink, &training, &seed, &attn_prob, &out_prob, &handle, &scope)) return -1; if (num_heads) { @@ -15581,21 +15581,21 @@ int PyOp(MultiHeadAttn)::py_init(PyObject *self, PyObject *args, PyObject *kwds) } CATCH_ALL(-1) } - if (need_weights) { + if (add_zero_attn) { try { // TODO: remove this guard which is used for pybind11 implicit conversion py::detail::loader_life_support guard{}; - reinterpret_cast(self)->inst().need_weights = - py::cast(py::handle(need_weights)); + reinterpret_cast(self)->inst().add_zero_attn = + py::cast(py::handle(add_zero_attn)); } CATCH_ALL(-1) } - if (add_zero_attn) { + if (need_weights) { try { // TODO: remove this guard which is used for pybind11 implicit conversion py::detail::loader_life_support guard{}; - reinterpret_cast(self)->inst().add_zero_attn = - py::cast(py::handle(add_zero_attn)); + reinterpret_cast(self)->inst().need_weights = + py::cast(py::handle(need_weights)); } CATCH_ALL(-1) } @@ -15680,8 +15680,8 @@ PyGetSetDef PyOp(MultiHeadAttn)::py_getsetters[] = { {const_cast("input_order"), py_get_generic(MultiHeadAttn, input_order), py_set_generic(MultiHeadAttn, input_order), const_cast("input_order"), NULL}, {const_cast("attn_mask_type"), py_get_generic(MultiHeadAttn, attn_mask_type), py_set_generic(MultiHeadAttn, attn_mask_type), const_cast("attn_mask_type"), NULL}, {const_cast("tensor_combination_type"), py_get_generic(MultiHeadAttn, tensor_combination_type), py_set_generic(MultiHeadAttn, tensor_combination_type), const_cast("tensor_combination_type"), NULL}, - {const_cast("need_weights"), py_get_generic(MultiHeadAttn, need_weights), py_set_generic(MultiHeadAttn, need_weights), const_cast("need_weights"), NULL}, {const_cast("add_zero_attn"), py_get_generic(MultiHeadAttn, add_zero_attn), py_set_generic(MultiHeadAttn, add_zero_attn), const_cast("add_zero_attn"), NULL}, + {const_cast("need_weights"), py_get_generic(MultiHeadAttn, need_weights), py_set_generic(MultiHeadAttn, need_weights), const_cast("need_weights"), NULL}, {const_cast("reslink"), py_get_generic(MultiHeadAttn, reslink), py_set_generic(MultiHeadAttn, reslink), const_cast("reslink"), NULL}, {const_cast("training"), py_get_generic(MultiHeadAttn, training), py_set_generic(MultiHeadAttn, training), const_cast("training"), NULL}, {const_cast("seed"), py_get_generic(MultiHeadAttn, seed), py_set_generic(MultiHeadAttn, seed), const_cast("seed"), NULL}, @@ -15708,7 +15708,7 @@ PyMethodDef PyOp(MultiHeadAttn)::py_init_methoddef = { "__init__", (PyCFunction)PyOp(MultiHeadAttn)::py_init_proxy, METH_VARARGS | METH_KEYWORDS, - "__init__(self, num_heads: int = ..., embeding_size: int = ..., k_size: int = ..., v_size: int = ..., qproj_size: int = ..., kproj_size: int = ..., vproj_size: int = ..., oproj_size: int = ..., qbias: bool = ..., kbias: bool = ..., vbias: bool = ..., obias: bool = ..., sm_scaler: float = ..., input_order: int = ..., attn_mask_type: Union[str, ATTN_MASK_TYPE] = ..., tensor_combination_type: 
Union[str, TENSOR_COMBINATION_TYPE] = ..., need_weights: bool = ..., add_zero_attn: bool = ..., reslink: bool = ..., training: bool = ..., seed: int = ..., attn_prob: float = ..., out_prob: float = ..., handle: int = ...) -> None\n" + "__init__(self, num_heads: int = ..., embeding_size: int = ..., k_size: int = ..., v_size: int = ..., qproj_size: int = ..., kproj_size: int = ..., vproj_size: int = ..., oproj_size: int = ..., qbias: bool = ..., kbias: bool = ..., vbias: bool = ..., obias: bool = ..., sm_scaler: float = ..., input_order: int = ..., attn_mask_type: Union[str, ATTN_MASK_TYPE] = ..., tensor_combination_type: Union[str, TENSOR_COMBINATION_TYPE] = ..., add_zero_attn: bool = ..., need_weights: bool = ..., reslink: bool = ..., training: bool = ..., seed: int = ..., attn_prob: float = ..., out_prob: float = ..., handle: int = ...) -> None\n" }; void _init_py_MultiHeadAttn(py::module m) { diff --git a/imperative/tablegen/generated/opdef.h.inl b/imperative/tablegen/generated/opdef.h.inl index 7bb498e446818c3cd5898176354582b869ab876e..a099e0b4a50b0dc0bd02d2c5995ebf7958fb4498 100644 --- a/imperative/tablegen/generated/opdef.h.inl +++ b/imperative/tablegen/generated/opdef.h.inl @@ -1416,8 +1416,8 @@ public: uint32_t input_order = 0; ATTN_MASK_TYPE attn_mask_type = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK; TENSOR_COMBINATION_TYPE tensor_combination_type = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE; - bool need_weights = false; bool add_zero_attn = false; + bool need_weights = false; bool reslink = false; bool training = true; uint64_t seed = 0; @@ -1425,10 +1425,10 @@ public: float out_prob = 0.f; size_t handle; MultiHeadAttn() = default; - MultiHeadAttn(uint32_t num_heads_, uint32_t embeding_size_, uint32_t k_size_, uint32_t v_size_, uint32_t qproj_size_, uint32_t kproj_size_, uint32_t vproj_size_, uint32_t oproj_size_, bool qbias_, bool kbias_, bool vbias_, bool obias_, float sm_scaler_, uint32_t input_order_, ATTN_MASK_TYPE attn_mask_type_, TENSOR_COMBINATION_TYPE tensor_combination_type_, bool need_weights_, bool add_zero_attn_, bool reslink_, bool training_, uint64_t seed_, float attn_prob_, float out_prob_, size_t handle_, std::string scope_ = {}): num_heads(num_heads_), embeding_size(embeding_size_), k_size(k_size_), v_size(v_size_), qproj_size(qproj_size_), kproj_size(kproj_size_), vproj_size(vproj_size_), oproj_size(oproj_size_), qbias(qbias_), kbias(kbias_), vbias(vbias_), obias(obias_), sm_scaler(sm_scaler_), input_order(input_order_), attn_mask_type(attn_mask_type_), tensor_combination_type(tensor_combination_type_), need_weights(need_weights_), add_zero_attn(add_zero_attn_), reslink(reslink_), training(training_), seed(seed_), attn_prob(attn_prob_), out_prob(out_prob_), handle(handle_) { set_scope(scope_); } - MultiHeadAttn(::megdnn::param::MultiHeadAttn packed_param_0, size_t handle_): num_heads(packed_param_0.num_heads), embeding_size(packed_param_0.embeding_size), k_size(packed_param_0.k_size), v_size(packed_param_0.v_size), qproj_size(packed_param_0.qproj_size), kproj_size(packed_param_0.kproj_size), vproj_size(packed_param_0.vproj_size), oproj_size(packed_param_0.oproj_size), qbias(packed_param_0.qbias), kbias(packed_param_0.kbias), vbias(packed_param_0.vbias), obias(packed_param_0.obias), sm_scaler(packed_param_0.sm_scaler), input_order(packed_param_0.input_order), attn_mask_type(packed_param_0.attn_mask_type), tensor_combination_type(packed_param_0.tensor_combination_type), need_weights(packed_param_0.need_weights), 
add_zero_attn(packed_param_0.add_zero_attn), reslink(packed_param_0.reslink), training(packed_param_0.training), seed(packed_param_0.seed), attn_prob(packed_param_0.attn_prob), out_prob(packed_param_0.out_prob), handle(handle_) {} + MultiHeadAttn(uint32_t num_heads_, uint32_t embeding_size_, uint32_t k_size_, uint32_t v_size_, uint32_t qproj_size_, uint32_t kproj_size_, uint32_t vproj_size_, uint32_t oproj_size_, bool qbias_, bool kbias_, bool vbias_, bool obias_, float sm_scaler_, uint32_t input_order_, ATTN_MASK_TYPE attn_mask_type_, TENSOR_COMBINATION_TYPE tensor_combination_type_, bool add_zero_attn_, bool need_weights_, bool reslink_, bool training_, uint64_t seed_, float attn_prob_, float out_prob_, size_t handle_, std::string scope_ = {}): num_heads(num_heads_), embeding_size(embeding_size_), k_size(k_size_), v_size(v_size_), qproj_size(qproj_size_), kproj_size(kproj_size_), vproj_size(vproj_size_), oproj_size(oproj_size_), qbias(qbias_), kbias(kbias_), vbias(vbias_), obias(obias_), sm_scaler(sm_scaler_), input_order(input_order_), attn_mask_type(attn_mask_type_), tensor_combination_type(tensor_combination_type_), add_zero_attn(add_zero_attn_), need_weights(need_weights_), reslink(reslink_), training(training_), seed(seed_), attn_prob(attn_prob_), out_prob(out_prob_), handle(handle_) { set_scope(scope_); } + MultiHeadAttn(::megdnn::param::MultiHeadAttn packed_param_0, size_t handle_): num_heads(packed_param_0.num_heads), embeding_size(packed_param_0.embeding_size), k_size(packed_param_0.k_size), v_size(packed_param_0.v_size), qproj_size(packed_param_0.qproj_size), kproj_size(packed_param_0.kproj_size), vproj_size(packed_param_0.vproj_size), oproj_size(packed_param_0.oproj_size), qbias(packed_param_0.qbias), kbias(packed_param_0.kbias), vbias(packed_param_0.vbias), obias(packed_param_0.obias), sm_scaler(packed_param_0.sm_scaler), input_order(packed_param_0.input_order), attn_mask_type(packed_param_0.attn_mask_type), tensor_combination_type(packed_param_0.tensor_combination_type), add_zero_attn(packed_param_0.add_zero_attn), need_weights(packed_param_0.need_weights), reslink(packed_param_0.reslink), training(packed_param_0.training), seed(packed_param_0.seed), attn_prob(packed_param_0.attn_prob), out_prob(packed_param_0.out_prob), handle(handle_) {} ::megdnn::param::MultiHeadAttn param() const { - return {num_heads, embeding_size, k_size, v_size, qproj_size, kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias, sm_scaler, input_order, attn_mask_type, tensor_combination_type, need_weights, add_zero_attn, reslink, training, seed, attn_prob, out_prob}; + return {num_heads, embeding_size, k_size, v_size, qproj_size, kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias, sm_scaler, input_order, attn_mask_type, tensor_combination_type, add_zero_attn, need_weights, reslink, training, seed, attn_prob, out_prob}; } }; diff --git a/imperative/tablegen/generated/opdef.py.inl b/imperative/tablegen/generated/opdef.py.inl index c93f9bb50c2d4c3fd9eeb4fd5294d947eff7e265..7a7de5f8acca10297bac31b2407ff517df048df5 100644 --- a/imperative/tablegen/generated/opdef.py.inl +++ b/imperative/tablegen/generated/opdef.py.inl @@ -1510,7 +1510,7 @@ py::enum_(MultiHeadAttnInst, "TENSOR_COM py::implicitly_convertible(); MultiHeadAttnInst - .def(py::init(), py::arg("num_heads") = 1, py::arg("embeding_size") = 0, py::arg("k_size") = 0, py::arg("v_size") = 0, py::arg("qproj_size") = 0, py::arg("kproj_size") = 0, py::arg("vproj_size") = 0, py::arg("oproj_size") = 0, py::arg("qbias") = false, 
py::arg("kbias") = false, py::arg("vbias") = false, py::arg("obias") = false, py::arg("sm_scaler") = 1.f, py::arg("input_order") = 0, py::arg("attn_mask_type") = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK, py::arg("tensor_combination_type") = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE, py::arg("need_weights") = false, py::arg("add_zero_attn") = false, py::arg("reslink") = false, py::arg("training") = true, py::arg("seed") = 0, py::arg("attn_prob") = 0.f, py::arg("out_prob") = 0.f, py::arg("handle"), py::arg("scope") = {}) + .def(py::init(), py::arg("num_heads") = 1, py::arg("embeding_size") = 0, py::arg("k_size") = 0, py::arg("v_size") = 0, py::arg("qproj_size") = 0, py::arg("kproj_size") = 0, py::arg("vproj_size") = 0, py::arg("oproj_size") = 0, py::arg("qbias") = false, py::arg("kbias") = false, py::arg("vbias") = false, py::arg("obias") = false, py::arg("sm_scaler") = 1.f, py::arg("input_order") = 0, py::arg("attn_mask_type") = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK, py::arg("tensor_combination_type") = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE, py::arg("add_zero_attn") = false, py::arg("need_weights") = false, py::arg("reslink") = false, py::arg("training") = true, py::arg("seed") = 0, py::arg("attn_prob") = 0.f, py::arg("out_prob") = 0.f, py::arg("handle"), py::arg("scope") = {}) .def(py::init<>()) .def_readwrite("num_heads", &MultiHeadAttn::num_heads) .def_readwrite("embeding_size", &MultiHeadAttn::embeding_size) @@ -1528,8 +1528,8 @@ MultiHeadAttnInst .def_readwrite("input_order", &MultiHeadAttn::input_order) .def_readwrite("attn_mask_type", &MultiHeadAttn::attn_mask_type) .def_readwrite("tensor_combination_type", &MultiHeadAttn::tensor_combination_type) - .def_readwrite("need_weights", &MultiHeadAttn::need_weights) .def_readwrite("add_zero_attn", &MultiHeadAttn::add_zero_attn) + .def_readwrite("need_weights", &MultiHeadAttn::need_weights) .def_readwrite("reslink", &MultiHeadAttn::reslink) .def_readwrite("training", &MultiHeadAttn::training) .def_readwrite("seed", &MultiHeadAttn::seed) diff --git a/src/opr/impl/internal/megdnn_opr_wrapper.inl b/src/opr/impl/internal/megdnn_opr_wrapper.inl index f54ccec2f9fe19b90cdf00b04f2084503a34cdaf..c68d5f8d300024197269ab535207a4c0683a9797 100644 --- a/src/opr/impl/internal/megdnn_opr_wrapper.inl +++ b/src/opr/impl/internal/megdnn_opr_wrapper.inl @@ -184,6 +184,12 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvokerdev_tensor().as_megdnn(), {}); } -/* ==================== MultiHeadAttnForward ==================== */ +/* ==================== MultiHeadAttnForward ==================== */ +using INPUT_TYPE = MultiHeadAttnForward::Param::TENSOR_COMBINATION_TYPE; + MGB_DYN_TYPE_OBJ_FINAL_IMPL(MultiHeadAttnForward); MultiHeadAttnForward::MultiHeadAttnForward( - VarNode* queries, VarNode* keys, VarNode* values, VarNode* wqkv, + VarNode* queries, VarNode* keys, VarNode* values, VarNode* qkvo_weight_bias, + VarNode* attn_mask, VarNode* bias_k, VarNode* bias_v, const Param& param, + const OperatorNodeConfig& config) + : Super{{queries->owner_graph(), + config, + "multi_head_attn", + {queries, keys, values, qkvo_weight_bias, attn_mask, bias_k, bias_v}}, + param} { + mgb_assert( + param.tensor_combination_type == + MultiHeadAttnForward::Param::TENSOR_COMBINATION_TYPE::ALL); + add_input({queries, keys, values, qkvo_weight_bias, attn_mask, bias_k, bias_v}); + add_output(None)->dtype(queries->dtype()); + add_output(None) + ->dtype(queries->dtype()) + 
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + add_output(None)->dtype(dtype::Byte()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + add_output(None) + ->dtype(queries->dtype()) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + cg::add_workspace_output(this); + add_equivalence_component>(this); +} + +MultiHeadAttnForward::MultiHeadAttnForward( + VarNode* queries, VarNode* keys, VarNode* values, VarNode* qkvo_weight_bias, + VarNode* attn_mask, const Param& param, const OperatorNodeConfig& config) + : Super{{queries->owner_graph(), + config, + "multi_head_attn", + {queries, keys, values, qkvo_weight_bias, attn_mask}}, + param} { + mgb_assert( + param.tensor_combination_type == + MultiHeadAttnForward::Param::TENSOR_COMBINATION_TYPE::ONLY_MASK); + add_input({queries, keys, values, qkvo_weight_bias, attn_mask}); + add_output(None)->dtype(queries->dtype()); + add_output(None) + ->dtype(queries->dtype()) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + add_output(None)->dtype(dtype::Byte()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + add_output(None) + ->dtype(queries->dtype()) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + cg::add_workspace_output(this); + add_equivalence_component>(this); +} + +MultiHeadAttnForward::MultiHeadAttnForward( + VarNode* queries, VarNode* keys, VarNode* values, VarNode* qkvo_weight_bias, + VarNode* bias_k, VarNode* bias_v, const Param& param, + const OperatorNodeConfig& config) + : Super{{queries->owner_graph(), + config, + "multi_head_attn", + {queries, keys, values, qkvo_weight_bias, bias_k, bias_v}}, + param} { + mgb_assert( + param.tensor_combination_type == + MultiHeadAttnForward::Param::TENSOR_COMBINATION_TYPE::ONLY_BIASKV); + add_input({queries, keys, values, qkvo_weight_bias, bias_k, bias_v}); + add_output(None)->dtype(queries->dtype()); + add_output(None) + ->dtype(queries->dtype()) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + add_output(None)->dtype(dtype::Byte()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + add_output(None) + ->dtype(queries->dtype()) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + cg::add_workspace_output(this); + add_equivalence_component>(this); +} + +MultiHeadAttnForward::MultiHeadAttnForward( + VarNode* queries, VarNode* keys, VarNode* values, VarNode* qkvo_weight_bias, const Param& param, const OperatorNodeConfig& config) : Super{{queries->owner_graph(), config, "multi_head_attn", - {queries, keys, values, wqkv}}, + {queries, keys, values, qkvo_weight_bias}}, param} { - add_input({queries, keys, values, wqkv}); + mgb_assert( + param.tensor_combination_type == + MultiHeadAttnForward::Param::TENSOR_COMBINATION_TYPE::NONE); + add_input({queries, keys, values, qkvo_weight_bias}); + add_output(None)->dtype(queries->dtype()); add_output(None) ->dtype(queries->dtype()) .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); add_output(None)->dtype(dtype::Byte()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + add_output(None) + ->dtype(queries->dtype()) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); cg::add_workspace_output(this); add_equivalence_component>(this); } SymbolVarArray MultiHeadAttnForward::make( - SymbolVar queries, SymbolVar keys, SymbolVar values, SymbolVar wqkv, - const Param& param, const OperatorNodeConfig& config) { + SymbolVar queries, SymbolVar keys, SymbolVar values, SymbolVar qkvo_weight_bias, + SymbolVar attn_mask, SymbolVar bias_k, SymbolVar bias_v, const Param& param, + const OperatorNodeConfig& config) { auto outs = queries.node() ->owner_graph() ->insert_opr(std::make_unique( - queries.node(), keys.node(), 
values.node(), wqkv.node(), + queries.node(), keys.node(), values.node(), + qkvo_weight_bias.node(), attn_mask.node(), + bias_k.node(), bias_v.node(), param, config)) + ->output(); + mgb_assert(outs.size() == 5); + return {outs[0], outs[1], outs[2], outs[3]}; +} +SymbolVarArray MultiHeadAttnForward::make( + SymbolVar queries, SymbolVar keys, SymbolVar values, SymbolVar qkvo_weight_bias, + SymbolVar attn_mask, const Param& param, const OperatorNodeConfig& config) { + auto outs = + queries.node() + ->owner_graph() + ->insert_opr(std::make_unique( + queries.node(), keys.node(), values.node(), + qkvo_weight_bias.node(), attn_mask.node(), param, config)) + ->output(); + mgb_assert(outs.size() == 5); + return {outs[0], outs[1], outs[2], outs[3]}; +} +SymbolVarArray MultiHeadAttnForward::make( + SymbolVar queries, SymbolVar keys, SymbolVar values, SymbolVar qkvo_weight_bias, + SymbolVar bias_k, SymbolVar bias_v, const Param& param, + const OperatorNodeConfig& config) { + auto outs = queries.node() + ->owner_graph() + ->insert_opr(std::make_unique( + queries.node(), keys.node(), values.node(), + qkvo_weight_bias.node(), bias_k.node(), bias_v.node(), param, config)) ->output(); - mgb_assert(outs.size() == 3); - return {outs[0], outs[1]}; + mgb_assert(outs.size() == 5); + return {outs[0], outs[1], outs[2], outs[3]}; +} +SymbolVarArray MultiHeadAttnForward::make( + SymbolVar queries, SymbolVar keys, SymbolVar values, SymbolVar qkvo_weight_bias, + const Param& param, const OperatorNodeConfig& config) { + auto outs = queries.node() + ->owner_graph() + ->insert_opr(std::make_unique( + queries.node(), keys.node(), values.node(), + qkvo_weight_bias.node(), param, config)) + ->output(); + mgb_assert(outs.size() == 5); + return {outs[0], outs[1], outs[2], outs[3]}; } void MultiHeadAttnForward::init_output_static_infer_desc() { @@ -461,18 +585,52 @@ void MultiHeadAttnForward::init_output_static_infer_desc() { auto&& mgr = owner_graph()->static_infer_manager(); mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0))); + auto infer_oshp1 = [this](TensorShape& dest, const InpVal& iv) { + TensorLayout in0{iv.val[0].shape(), input(0)->dtype()}; + TensorLayout in1{iv.val[1].shape(), input(1)->dtype()}; + TensorLayout in2{iv.val[2].shape(), input(2)->dtype()}; + TensorLayout in3{iv.val[3].shape(), input(3)->dtype()}; + TensorLayout in4{iv.val[4].shape(), input(4)->dtype()}; + TensorLayout in5{iv.val[5].shape(), input(5)->dtype()}; + TensorLayout in6{iv.val[6].shape(), input(6)->dtype()}; + TensorLayout o0, o1, o2, o3; + m_dnn_opr->deduce_layout(in0, in1, in2, in3, in4, in5, in6, o0, o1, o2, o3); + dest = o1; + return true; + }; + mgr.register_shape_infer( + output(1), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_oshp1}); + auto infer_mask = [this](TensorShape& dest, const InpVal& iv) { ensure_megdnn_opr(); dest.ndim = 1; - dest.shape[0] = m_dnn_opr->get_reservespace_in_bytes( + dest.shape[0] = m_dnn_opr->get_mask_reservespace_in_bytes( + {iv.val[0].shape(), input(0)->dtype()}, + {iv.val[1].shape(), input(1)->dtype()}, + {iv.val[2].shape(), input(2)->dtype()}, + {iv.val[3].shape(), input(3)->dtype()}, + {iv.val[4].shape(), input(4)->dtype()}, + {iv.val[5].shape(), input(5)->dtype()}, + {iv.val[6].shape(), input(6)->dtype()}, {}, {}, {}, {}); + return true; + }; + auto infer_othr = [this](TensorShape& dest, const InpVal& iv) { + ensure_megdnn_opr(); + dest.ndim = 1; + dest.shape[0] = m_dnn_opr->get_othr_reservespace_in_bytes( {iv.val[0].shape(), input(0)->dtype()}, {iv.val[1].shape(), 
input(1)->dtype()}, {iv.val[2].shape(), input(2)->dtype()}, - {iv.val[3].shape(), input(3)->dtype()}, {}, {}); + {iv.val[3].shape(), input(3)->dtype()}, + {iv.val[4].shape(), input(4)->dtype()}, + {iv.val[5].shape(), input(5)->dtype()}, + {iv.val[6].shape(), input(6)->dtype()}, {}, {}, {}, {}); return true; }; mgr.register_shape_infer( - output(1), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_mask}); + output(2), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_mask}); + mgr.register_shape_infer( + output(3), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_othr}); } void MultiHeadAttnForward::add_input_layout_constraint() { @@ -484,15 +642,52 @@ void MultiHeadAttnForward::add_input_layout_constraint() { void MultiHeadAttnForward::scn_do_execute() { auto&& ret = output(0); + auto input_type = m_dnn_opr->param().tensor_combination_type; if (ret->layout().is_empty()) { mgb_assert(ret->dev_tensor().empty()); return; } - m_dnn_opr->exec( - input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), - input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), - output(0)->dev_tensor().as_megdnn(), output(1)->dev_tensor().as_megdnn(), - get_megdnn_workspace_from_var(output(2))); + megdnn::TensorND empty_dnn; + if (input_type == INPUT_TYPE::ALL) { + m_dnn_opr->exec( + input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), + input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), + input(4)->dev_tensor().as_megdnn(), input(5)->dev_tensor().as_megdnn(), + input(6)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), + output(1)->dev_tensor().as_megdnn(), + output(2)->dev_tensor().as_megdnn(), + output(3)->dev_tensor().as_megdnn(), + get_megdnn_workspace_from_var(output(4))); + } else if (input_type == INPUT_TYPE::ONLY_MASK) { + m_dnn_opr->exec( + input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), + input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), + input(4)->dev_tensor().as_megdnn(), empty_dnn, empty_dnn, + output(0)->dev_tensor().as_megdnn(), + output(1)->dev_tensor().as_megdnn(), + output(2)->dev_tensor().as_megdnn(), + output(3)->dev_tensor().as_megdnn(), + get_megdnn_workspace_from_var(output(4))); + } else if (input_type == INPUT_TYPE::ONLY_BIASKV) { + m_dnn_opr->exec( + input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), + input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), + empty_dnn, input(4)->dev_tensor().as_megdnn(), + input(5)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), + output(1)->dev_tensor().as_megdnn(), + output(2)->dev_tensor().as_megdnn(), + output(3)->dev_tensor().as_megdnn(), + get_megdnn_workspace_from_var(output(4))); + } else { + m_dnn_opr->exec( + input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), + input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), + empty_dnn, empty_dnn, empty_dnn, output(0)->dev_tensor().as_megdnn(), + output(1)->dev_tensor().as_megdnn(), + output(2)->dev_tensor().as_megdnn(), + output(3)->dev_tensor().as_megdnn(), + get_megdnn_workspace_from_var(output(4))); + } } cg::OperatorNodeBase::NodeProp* MultiHeadAttnForward::do_make_node_prop() const { @@ -510,12 +705,23 @@ MGB_IMPL_OPR_GRAD(MultiHeadAttnForward) { MGB_MARK_USED_VAR(out_grad); SymbolVarArray grad; VarNodeArray ret; - mgb_assert(wrt_idx < 5, "wrt_idx %zu is out of range", wrt_idx); - grad = MultiHeadAttnBackward::make( - out_grad[0], opr.input(0), opr.input(1), 
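As a shape-level summary of the four user-visible forward outputs whose static inference is registered above, the following sketch uses hypothetical sizes; the real projected dimensions depend on qproj_size, kproj_size, vproj_size and oproj_size.

# Hypothetical example sizes, assuming every projection keeps embed_dim.
N, L, S, num_heads, embed_dim = 2, 5, 7, 4, 16
out_shape = (N, L, embed_dim)              # output(0): projected attention output
attn_weight_shape = (N * num_heads, L, S)  # output(1): softmax attention weights
# output(2) (mask_reservespace) and output(3) (othr_reservespace) are 1-D reserve
# buffers sized via get_mask_reservespace_in_bytes / get_othr_reservespace_in_bytes;
# output(4) is the megdnn workspace and is never returned to the user.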
opr.input(2), opr.input(3), - opr.output(1), opr.param()); - - uint32_t nr_ret = 4; + mgb_assert(wrt_idx < 7, "wrt_idx %zu is out of range", wrt_idx); + auto input_type = opr.param().tensor_combination_type; + if (input_type == INPUT_TYPE::ALL or input_type == INPUT_TYPE::ONLY_MASK) + grad = MultiHeadAttnBackward::make( + out_grad[0], opr.input(0), opr.input(1), opr.input(2), opr.input(3), + opr.input(4), opr.output(1), opr.output(2), opr.output(3), opr.param()); + else + grad = MultiHeadAttnBackward::make( + out_grad[0], opr.input(0), opr.input(1), opr.input(2), opr.input(3), + opr.output(1), opr.output(2), opr.output(3), opr.param()); + uint32_t nr_ret = 7; + if (input_type == INPUT_TYPE::NONE) + nr_ret = 4; + if (input_type == INPUT_TYPE::ONLY_MASK) + nr_ret = 5; + if (input_type == INPUT_TYPE::ONLY_BIASKV) + nr_ret = 6; for (uint32_t i = 0; i < nr_ret; ++i) { ret.push_back(grad[i].node()); } @@ -527,29 +733,77 @@ MGB_IMPL_OPR_GRAD(MultiHeadAttnForward) { MGB_DYN_TYPE_OBJ_FINAL_IMPL(MultiHeadAttnBackward); MultiHeadAttnBackward::MultiHeadAttnBackward( - VarNode* diff, VarNode* queries, VarNode* keys, VarNode* values, VarNode* wqkv, - VarNode* reserveSpace, const Param& param, const OperatorNodeConfig& config) + VarNode* diff, VarNode* queries, VarNode* keys, VarNode* values, + VarNode* qkvo_weight_bias, VarNode* attn_mask, VarNode* attn_weight, + VarNode* mask_reservespace, VarNode* othr_reservespace, const Param& param, + const OperatorNodeConfig& config) + : Super({queries->owner_graph(), + config, + "multi_head_attn_backward", + {diff, queries, keys, values, qkvo_weight_bias, attn_mask, attn_weight, + mask_reservespace, othr_reservespace}}, + 0, true) { + init_megdnn_opr(*this, param); + add_input( + {diff, queries, keys, values, qkvo_weight_bias, attn_mask, attn_weight, + mask_reservespace, othr_reservespace}); + this->output()[3]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + this->output()[4]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + this->output()[5]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); +} +MultiHeadAttnBackward::MultiHeadAttnBackward( + VarNode* diff, VarNode* queries, VarNode* keys, VarNode* values, + VarNode* qkvo_weight_bias, VarNode* attn_weight, VarNode* mask_reservespace, + VarNode* othr_reservespace, const Param& param, + const OperatorNodeConfig& config) : Super({queries->owner_graph(), config, "multi_head_attn_backward", - {diff, queries, keys, values, wqkv, reserveSpace}}, + {diff, queries, keys, values, qkvo_weight_bias, attn_weight, + mask_reservespace, othr_reservespace}}, 0, true) { init_megdnn_opr(*this, param); - add_input({diff, queries, keys, values, wqkv, reserveSpace}); + add_input( + {diff, queries, keys, values, qkvo_weight_bias, attn_weight, + mask_reservespace, othr_reservespace}); + this->output()[3]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + this->output()[4]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + this->output()[5]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); } SymbolVarArray MultiHeadAttnBackward::make( SymbolVar diff, SymbolVar queries, SymbolVar keys, SymbolVar values, - SymbolVar wqkv, SymbolVar reserveSpace, const Param& param, + SymbolVar qkvo_weight_bias, SymbolVar attn_mask, SymbolVar attn_weight, + SymbolVar mask_reservespace, SymbolVar othr_reservespace, const Param& param, const OperatorNodeConfig& config) { auto outs = queries.node() ->owner_graph() ->insert_opr(std::make_unique( diff.node(), queries.node(), keys.node(), values.node(), - wqkv.node(), reserveSpace.node(), param, config)) + qkvo_weight_bias.node(), 
attn_mask.node(), + attn_weight.node(), mask_reservespace.node(), + othr_reservespace.node(), param, config)) ->output(); - mgb_assert(outs.size() == 5); - return {outs[0], outs[1], outs[2], outs[3]}; + mgb_assert(outs.size() == 7); + + return {outs[0], outs[1], outs[2], outs[3], outs[4], outs[5], {}}; +} +SymbolVarArray MultiHeadAttnBackward::make( + SymbolVar diff, SymbolVar queries, SymbolVar keys, SymbolVar values, + SymbolVar qkvo_weight_bias, SymbolVar attn_weight, SymbolVar mask_reservespace, + SymbolVar othr_reservespace, const Param& param, + const OperatorNodeConfig& config) { + auto outs = queries.node() + ->owner_graph() + ->insert_opr(std::make_unique( + diff.node(), queries.node(), keys.node(), values.node(), + qkvo_weight_bias.node(), attn_weight.node(), + mask_reservespace.node(), othr_reservespace.node(), + param, config)) + ->output(); + mgb_assert(outs.size() == 7); + + return {outs[0], outs[1], outs[2], outs[3], outs[4], outs[5], {}}; } void MultiHeadAttnBackward::init_output_static_infer_desc() { @@ -559,7 +813,15 @@ void MultiHeadAttnBackward::init_output_static_infer_desc() { mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2))); mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(3))); mgr.register_shape_infer(output(3), ShapeInferDesc::make_identity(input(4))); - + auto input_type = param().tensor_combination_type; + if (input_type == INPUT_TYPE::ALL or input_type == INPUT_TYPE::ONLY_BIASKV) { + mgr.register_shape_infer(output(4), ShapeInferDesc::make_identity(input(4))); + mgr.register_shape_infer(output(5), ShapeInferDesc::make_identity(input(4))); + } else { + TensorShape empty{0}; + mgr.register_shape_infer(output(4), ShapeInferDesc::make_const(empty)); + mgr.register_shape_infer(output(5), ShapeInferDesc::make_const(empty)); + } this->init_output_static_infer_desc_workspace(false); } @@ -568,25 +830,79 @@ void MultiHeadAttnBackward::init_output_dtype() { output(1)->dtype(input(2)->dtype()); output(2)->dtype(input(3)->dtype()); output(3)->dtype(input(4)->dtype()); + output(4)->dtype(input(2)->dtype()); + output(5)->dtype(input(3)->dtype()); } size_t MultiHeadAttnBackward::get_workspace_size_bytes( const TensorShapeArray& input_shapes, const TensorShapeArray& output_shapes) const { - MGB_MARK_USED_VAR(input_shapes); - MGB_MARK_USED_VAR(output_shapes); - - return 0; + auto input_type = megdnn_opr()->param().tensor_combination_type; + megdnn::TensorLayout empty_dnn; + if (input_type == INPUT_TYPE::ALL or input_type == INPUT_TYPE::ONLY_MASK) + return megdnn_opr()->get_workspace_in_bytes( + {input_shapes[0], input(0)->dtype(), input(0)->format()}, + {input_shapes[1], input(1)->dtype(), input(1)->format()}, + {input_shapes[2], input(2)->dtype(), input(2)->format()}, + {input_shapes[3], input(3)->dtype(), input(3)->format()}, + {input_shapes[4], input(4)->dtype(), input(4)->format()}, + {input_shapes[5], input(5)->dtype(), input(5)->format()}, + {input_shapes[6], input(6)->dtype(), input(6)->format()}, + {input_shapes[7], input(7)->dtype(), input(7)->format()}, + {input_shapes[8], input(8)->dtype(), input(8)->format()}, + {output_shapes[0], output(0)->dtype(), output(0)->format()}, + {output_shapes[1], output(1)->dtype(), output(1)->format()}, + {output_shapes[2], output(2)->dtype(), output(2)->format()}, + {output_shapes[3], output(3)->dtype(), output(3)->format()}, + {output_shapes[4], output(4)->dtype(), output(4)->format()}, + {output_shapes[5], output(5)->dtype(), output(5)->format()}); + else + return 
megdnn_opr()->get_workspace_in_bytes( + {input_shapes[0], input(0)->dtype(), input(0)->format()}, + {input_shapes[1], input(1)->dtype(), input(1)->format()}, + {input_shapes[2], input(2)->dtype(), input(2)->format()}, + {input_shapes[3], input(3)->dtype(), input(3)->format()}, + {input_shapes[4], input(4)->dtype(), input(4)->format()}, empty_dnn, + {input_shapes[5], input(5)->dtype(), input(5)->format()}, + {input_shapes[6], input(6)->dtype(), input(6)->format()}, + {input_shapes[7], input(7)->dtype(), input(7)->format()}, + {output_shapes[0], output(0)->dtype(), output(0)->format()}, + {output_shapes[1], output(1)->dtype(), output(1)->format()}, + {output_shapes[2], output(2)->dtype(), output(2)->format()}, + {output_shapes[3], output(3)->dtype(), output(3)->format()}, + {output_shapes[4], output(4)->dtype(), output(4)->format()}, + {output_shapes[5], output(5)->dtype(), output(5)->format()}); } void MultiHeadAttnBackward::scn_do_execute() { - megdnn_opr()->exec( - input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), - input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), - input(4)->dev_tensor().as_megdnn(), input(5)->dev_tensor().as_megdnn(), - output(0)->dev_tensor().as_megdnn(), output(1)->dev_tensor().as_megdnn(), - output(2)->dev_tensor().as_megdnn(), output(3)->dev_tensor().as_megdnn(), - {}); + auto input_type = megdnn_opr()->param().tensor_combination_type; + megdnn::TensorND empty_dnn; + if (input_type == INPUT_TYPE::ALL or input_type == INPUT_TYPE::ONLY_MASK) + megdnn_opr()->exec( + input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), + input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), + input(4)->dev_tensor().as_megdnn(), input(5)->dev_tensor().as_megdnn(), + input(6)->dev_tensor().as_megdnn(), input(7)->dev_tensor().as_megdnn(), + input(8)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), + output(1)->dev_tensor().as_megdnn(), + output(2)->dev_tensor().as_megdnn(), + output(3)->dev_tensor().as_megdnn(), + output(4)->dev_tensor().as_megdnn(), + output(5)->dev_tensor().as_megdnn(), + get_megdnn_workspace_from_var(output(6))); + else + megdnn_opr()->exec( + input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), + input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), + input(4)->dev_tensor().as_megdnn(), empty_dnn, + input(5)->dev_tensor().as_megdnn(), input(6)->dev_tensor().as_megdnn(), + input(7)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), + output(1)->dev_tensor().as_megdnn(), + output(2)->dev_tensor().as_megdnn(), + output(3)->dev_tensor().as_megdnn(), + output(4)->dev_tensor().as_megdnn(), + output(5)->dev_tensor().as_megdnn(), + get_megdnn_workspace_from_var(output(6))); } // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/rand.sereg.h b/src/opr/impl/rand.sereg.h index 0333543342cea1c3892711cc14f0b3549aaa6605..0552b2d67191dcdea1c9598c8a07e1b49ff6f9e8 100644 --- a/src/opr/impl/rand.sereg.h +++ b/src/opr/impl/rand.sereg.h @@ -33,13 +33,35 @@ struct OprMaker { template <> struct OprMaker { using Param = opr::MultiHeadAttn::Param; + using INPUT_TYPE = Param::TENSOR_COMBINATION_TYPE; static cg::OperatorNodeBase* make( const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, const OperatorNodeConfig& config) { MGB_MARK_USED_VAR(graph); - return opr::MultiHeadAttn::make(i[0], i[1], i[2], i[3], param, config)[0] - .node() - ->owner_opr(); + if (i.size() == 7) { + 
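The forward dispatch above is driven by Param's tensor_combination_type: queries, keys, values and qkvo_weight_bias are always present, while attn_mask, bias_k and bias_v are optional and are passed to megdnn as empty TensorNDs when absent. A minimal standalone sketch of that mapping follows; the local enum and the helper expected_forward_inputs() are illustrative stand-ins for opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE, not part of this patch.

    // Sketch only, not part of the patch: local stand-in enum so the snippet
    // compiles on its own, outside the MegEngine tree.
    #include <cstddef>

    enum class TensorCombinationType { NONE, ONLY_MASK, ONLY_BIASKV, ALL };

    // Mirrors the branches in MultiHeadAttnForward::scn_do_execute() above.
    inline std::size_t expected_forward_inputs(TensorCombinationType t) {
        switch (t) {
            case TensorCombinationType::ONLY_MASK:
                return 5;  // queries, keys, values, qkvo_weight_bias, attn_mask
            case TensorCombinationType::ONLY_BIASKV:
                return 6;  // bias_k and bias_v instead of attn_mask
            case TensorCombinationType::ALL:
                return 7;  // attn_mask, bias_k and bias_v all present
            case TensorCombinationType::NONE:
            default:
                return 4;  // only queries, keys, values, qkvo_weight_bias
        }
    }

The same counts explain the nr_ret values (4, 5, 6 or 7) chosen by the gradient rule: one returned gradient per input that actually exists for the chosen combination.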
diff --git a/src/opr/impl/rand.sereg.h b/src/opr/impl/rand.sereg.h
index 0333543342cea1c3892711cc14f0b3549aaa6605..0552b2d67191dcdea1c9598c8a07e1b49ff6f9e8 100644
--- a/src/opr/impl/rand.sereg.h
+++ b/src/opr/impl/rand.sereg.h
@@ -33,13 +33,35 @@ struct OprMaker {
 template <>
 struct OprMaker<opr::MultiHeadAttn, 0> {
     using Param = opr::MultiHeadAttn::Param;
+    using INPUT_TYPE = Param::TENSOR_COMBINATION_TYPE;
     static cg::OperatorNodeBase* make(
             const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
             const OperatorNodeConfig& config) {
         MGB_MARK_USED_VAR(graph);
-        return opr::MultiHeadAttn::make(i[0], i[1], i[2], i[3], param, config)[0]
-                .node()
-                ->owner_opr();
+        if (i.size() == 7) {
+            mgb_assert(INPUT_TYPE::ALL == param.tensor_combination_type);
+            return opr::MultiHeadAttn::make(
+                           i[0], i[1], i[2], i[3], i[4], i[5], i[6], param, config)[0]
+                    .node()
+                    ->owner_opr();
+        } else if (i.size() == 6) {
+            mgb_assert(INPUT_TYPE::ONLY_BIASKV == param.tensor_combination_type);
+            return opr::MultiHeadAttn::make(
+                           i[0], i[1], i[2], i[3], i[4], i[5], param, config)[0]
+                    .node()
+                    ->owner_opr();
+        } else if (i.size() == 5) {
+            mgb_assert(INPUT_TYPE::ONLY_MASK == param.tensor_combination_type);
+            return opr::MultiHeadAttn::make(
+                           i[0], i[1], i[2], i[3], i[4], param, config)[0]
+                    .node()
+                    ->owner_opr();
+        } else {
+            mgb_assert(INPUT_TYPE::NONE == param.tensor_combination_type);
+            return opr::MultiHeadAttn::make(i[0], i[1], i[2], i[3], param, config)[0]
+                    .node()
+                    ->owner_opr();
+        }
     }
 };
 
@@ -52,10 +74,18 @@ struct OprMaker {
             const OperatorNodeConfig& config) {
         MGB_MARK_USED_VAR(graph);
-        return opr::MultiHeadAttnBackward::make(
-                       i[0], i[1], i[2], i[3], i[4], i[5], param, config)[0]
-                .node()
-                ->owner_opr();
+        if (i.size() == 8)
+            return opr::MultiHeadAttnBackward::make(
+                           i[0], i[1], i[2], i[3], i[4], i[5], i[6], i[7], param,
+                           config)[0]
+                    .node()
+                    ->owner_opr();
+        else
+            return opr::MultiHeadAttnBackward::make(
+                           i[0], i[1], i[2], i[3], i[4], i[5], i[6], i[7], i[8], param,
+                           config)[0]
+                    .node()
+                    ->owner_opr();
     }
 };
 
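The deserialization hook in rand.sereg.h above recovers the combination from the serialized input count and asserts that it agrees with param.tensor_combination_type before picking a MultiHeadAttn::make overload. A hedged sketch of that inverse mapping, using the same stand-in enum as the earlier snippet (combination_from_input_count() is a hypothetical helper, not MegEngine API):

    // Sketch only: inverse of expected_forward_inputs(), matching the
    // i.size() checks in the OprMaker specialization above.
    #include <cassert>
    #include <cstddef>

    enum class TensorCombinationType { NONE, ONLY_MASK, ONLY_BIASKV, ALL };

    inline TensorCombinationType combination_from_input_count(std::size_t n) {
        switch (n) {
            case 7: return TensorCombinationType::ALL;
            case 6: return TensorCombinationType::ONLY_BIASKV;
            case 5: return TensorCombinationType::ONLY_MASK;
            default: assert(n == 4); return TensorCombinationType::NONE;
        }
    }

For MultiHeadAttnBackward the same idea applies with 8 or 9 serialized inputs, depending on whether attn_mask was recorded.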
diff --git a/src/opr/include/megbrain/opr/rand.h b/src/opr/include/megbrain/opr/rand.h
index d5d79352cd97fc806433e94321bfaf4e0d68bf98..3b02d8bdef2f20f45a3d39d91d621bcbcd7644c5 100644
--- a/src/opr/include/megbrain/opr/rand.h
+++ b/src/opr/include/megbrain/opr/rand.h
@@ -87,14 +87,50 @@ _DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG)
 #undef _OUTPUTS
 #undef _INPUTS
 
-/* ================= 4 input ================= */
-#define _INPUTS(preifx) preifx i0, preifx i1, preifx i2, preifx i3
-#define _OUTPUTS SymbolVarArray
-_DEFINE_RNG_OPR_WITH_INPUT_CLASS(MultiHeadAttnForward)
-#undef _OUTPUTS
-#undef _INPUTS
 #undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS
 
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
+        MultiHeadAttnForward, RNGOprBase<megdnn::MultiHeadAttnForward>) // {
+    void add_input_layout_constraint() override;
+    cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override;
+
+public:
+    MultiHeadAttnForward(
+            VarNode* queries, VarNode* keys, VarNode* values, VarNode* qkvo_weight_bias,
+            VarNode* attn_mask, VarNode* bias_k, VarNode* bias_v, const Param& param,
+            const OperatorNodeConfig& config);
+    MultiHeadAttnForward(
+            VarNode* queries, VarNode* keys, VarNode* values, VarNode* qkvo_weight_bias,
+            VarNode* attn_mask, const Param& param, const OperatorNodeConfig& config);
+    MultiHeadAttnForward(
+            VarNode* queries, VarNode* keys, VarNode* values, VarNode* qkvo_weight_bias,
+            VarNode* bias_k, VarNode* bias_v, const Param& param,
+            const OperatorNodeConfig& config);
+    MultiHeadAttnForward(
+            VarNode* queries, VarNode* keys, VarNode* values, VarNode* qkvo_weight_bias,
+            const Param& param, const OperatorNodeConfig& config);
+
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
+            SymbolVar queries, SymbolVar keys, SymbolVar values,
+            SymbolVar qkvo_weight_bias, SymbolVar attn_mask, SymbolVar bias_k,
+            SymbolVar bias_v, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
+            SymbolVar queries, SymbolVar keys, SymbolVar values,
+            SymbolVar qkvo_weight_bias, SymbolVar attn_mask, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
+            SymbolVar queries, SymbolVar keys, SymbolVar values,
+            SymbolVar qkvo_weight_bias, SymbolVar bias_k, SymbolVar bias_v,
+            const Param& param = {}, const OperatorNodeConfig& config = {});
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
+            SymbolVar queries, SymbolVar keys, SymbolVar values,
+            SymbolVar qkvo_weight_bias, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+    void init_output_static_infer_desc() override;
+    void scn_do_execute() override;
+};
+
 } // namespace intl
 
 using UniformRNG = intl::UniformRNG;
@@ -146,13 +182,27 @@ MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
 public:
     MGE_WIN_DECLSPEC_FUC MultiHeadAttnBackward(
             VarNode* diff, VarNode* queries, VarNode* keys, VarNode* values,
-            VarNode* wqkv, VarNode* reserveSpace, const Param& param,
+            VarNode* qkvo_weight_bias, VarNode* attn_mask, VarNode* attn_weight,
+            VarNode* mask_reservespace, VarNode* othr_reservespace, const Param& param,
+            const OperatorNodeConfig& config);
+
+    MGE_WIN_DECLSPEC_FUC MultiHeadAttnBackward(
+            VarNode* diff, VarNode* queries, VarNode* keys, VarNode* values,
+            VarNode* qkvo_weight_bias, VarNode* attn_weight, VarNode* mask_reservespace,
+            VarNode* othr_reservespace, const Param& param,
             const OperatorNodeConfig& config);
 
     MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
             SymbolVar diff, SymbolVar queries, SymbolVar keys, SymbolVar values,
-            SymbolVar wqkv, SymbolVar reserveSpace, const Param& param = {},
-            const OperatorNodeConfig& config = {});
+            SymbolVar qkvo_weight_bias, SymbolVar attn_mask, SymbolVar attn_weight,
+            SymbolVar mask_reservespace, SymbolVar othr_reservespace,
+            const Param& param = {}, const OperatorNodeConfig& config = {});
+
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
+            SymbolVar diff, SymbolVar queries, SymbolVar keys, SymbolVar values,
+            SymbolVar qkvo_weight_bias, SymbolVar attn_weight,
+            SymbolVar mask_reservespace, SymbolVar othr_reservespace,
+            const Param& param = {}, const OperatorNodeConfig& config = {});
 
 private:
     void init_output_static_infer_desc() override;