diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 17bac96bf1a396e2cbd7d68e55d78bcc19ac57a5..2428ea7949a90d9dea252cf2fd74f99521f2df93 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -1346,12 +1346,12 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), .add_fields('bool', Doc('obias', 'Whether to add out bias.'), 'false') .add_fields('float32', Doc('sm_scaler', 'Softmax smoothing (1.0 >= smScaler >= 0.0) or sharpening (smScaler > 1.0) coefficient.'), '1.f') .add_fields('uint32', Doc('input_order', 'The sequence data layout, allows the user to select 3! = 6 different data layouts or permutations of BEAM, BATCH and TIME dimensions.'), '0') - .add_enum('ATTN_MASK_TYPE', + .add_enum('AttnMaskType', Doc('NO_MASK = 0', 'Indicates that there is no mask.'), Doc('DEFAULT_MASK = 1', 'Use the default mask which the upper right triangle of the mask is -inf, and the diagonal and lower left triangle are all 0.'), Doc('CUDNN_STYLE_MASK = 2', 'Indicates the use of a cudnn style mask.'), Doc('USER_DEFINED_MASK = 3', 'Use the user-defined mask.'), name_field="attn_mask_type") - .add_enum(Doc('TENSOR_COMBINATION_TYPE', 'Used to determine whether mask tensor and bias_kv tensor exist in the input. Note that bias_kv here is not kbias and vbias in the linear layer, and bias_kv here will be added to the K and V at sequence dimensions, where K and V are the matrices of key and value after projection, and K and V will be used to calculate the attention matrix.'), + .add_enum(Doc('TensorCombinationType', 'Used to determine whether mask tensor and bias_kv tensor exist in the input. Note that bias_kv here is not kbias and vbias in the linear layer, and bias_kv here will be added to the K and V at sequence dimensions, where K and V are the matrices of key and value after projection, and K and V will be used to calculate the attention matrix.'), Doc('NONE = 0', 'Indicates that there are no mask tensor and bias_kv tensor in the input.'), Doc('ONLY_MASK = 1', 'Indicates that there is only mask tensor in input.'), @@ -1363,5 +1363,5 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), .add_fields('bool', Doc('training', 'Whether it is in training mode.'), 'true') .add_fields('uint64', Doc('seed', 'Random number seed for drop'), '0') .add_fields('float32', Doc('attn_prob', 'Dropout probability on attention, is applied directly to the softmax output'), '0.f') - .add_fields('float32', Doc('out_prob', 'Dropout probability on output, alters the multi-head attention output'), '0.f') + .add_fields('float32', Doc('out_prob', 'Dropout probability on output, alters the multi-head attention output'), '0.f') ) diff --git a/dnn/src/common/multi_head_attn.cpp b/dnn/src/common/multi_head_attn.cpp index 0394706220fc5ac4b00b3b4c87b9bdde6e27d0b9..90f4f3e3313f2a4b19791003134fff0391f73d4a 100644 --- a/dnn/src/common/multi_head_attn.cpp +++ b/dnn/src/common/multi_head_attn.cpp @@ -9,7 +9,8 @@ namespace megdnn { using Param = MultiHeadAttnBase::Param; -using INPUT_TYPE = Param::TENSOR_COMBINATION_TYPE; +using InputType = Param::TensorCombinationType; +using MaskType = Param::AttnMaskType; void MultiHeadAttnForward::check_exec( const TensorLayout& queries, const TensorLayout& keys, @@ -33,12 +34,12 @@ void MultiHeadAttnForward::check_exec( bool have_mask = false; bool have_biaskv = false; auto input_type = p.tensor_combination_type; - if (input_type == INPUT_TYPE::ONLY_BIASKV or input_type == INPUT_TYPE::ALL) { + if (input_type == 
InputType::ONLY_BIASKV or input_type == InputType::ALL) { have_biaskv = true; megdnn_assert_contiguous(bias_k); megdnn_assert_contiguous(bias_v); } - if (input_type == INPUT_TYPE::ONLY_MASK or input_type == INPUT_TYPE::ALL) { + if (input_type == InputType::ONLY_MASK or input_type == InputType::ALL) { have_mask = true; megdnn_assert_contiguous(attn_mask); } @@ -162,9 +163,10 @@ void MultiHeadAttnForward::check_exec( if (qprojsize == 0 and kprojsize == 0) megdnn_assert(embeding_size == ksize, "%s", param_errmsg().c_str()); if (qprojsize == 0 and kprojsize != 0) - megdnn_assert(embeding_size == kprojsize, "%s", param_errmsg().c_str()); + megdnn_assert( + embeding_size * p.num_heads == kprojsize, "%s", param_errmsg().c_str()); if (qprojsize != 0 and kprojsize == 0) - megdnn_assert(qprojsize == ksize, "%s", param_errmsg().c_str()); + megdnn_assert(qprojsize == ksize * p.num_heads, "%s", param_errmsg().c_str()); if (qprojsize != 0 and kprojsize != 0) megdnn_assert(qprojsize == kprojsize, "%s", param_errmsg().c_str()); if (p.qbias) @@ -184,7 +186,7 @@ void MultiHeadAttnForward::check_exec( if (p.oproj_size > 0 and p.vproj_size > 0) weight_len += vprojsize * oprojsize + (p.obias ? oprojsize : 0); else if (p.oproj_size > 0 and p.vproj_size == 0) - weight_len += vsize * oprojsize + (p.obias ? oprojsize : 0); + weight_len += p.num_heads * vsize * oprojsize + (p.obias ? oprojsize : 0); megdnn_assert( weight_len == qkvo_weight_bias.total_nr_elems(), "qkvo_weight_bias length should be %zu, but got %zu. details: %s", @@ -208,7 +210,7 @@ void MultiHeadAttnBackward::deduce_layout( dvalues = values; dqkvo_weight_bias = qkvo_weight_bias; auto input_type = param().tensor_combination_type; - if (input_type == INPUT_TYPE::ONLY_BIASKV or input_type == INPUT_TYPE::ALL) { + if (input_type == InputType::ONLY_BIASKV or input_type == InputType::ALL) { dbias_k = TensorLayout( {1, 1, param().kproj_size ? param().kproj_size : param().k_size}, keys.dtype); @@ -255,8 +257,8 @@ void MultiHeadAttnBackward::check_exec( auto input_type = p.tensor_combination_type; bool have_mask = false; bool have_biaskv = - input_type == INPUT_TYPE::ONLY_BIASKV or input_type == INPUT_TYPE::ALL; - if (input_type == INPUT_TYPE::ONLY_MASK or input_type == INPUT_TYPE::ALL) { + input_type == InputType::ONLY_BIASKV or input_type == InputType::ALL; + if (input_type == InputType::ONLY_MASK or input_type == InputType::ALL) { have_mask = true; megdnn_assert_contiguous(attn_mask); } @@ -296,8 +298,9 @@ void MultiHeadAttnBackward::check_exec( }; // layout check - size_t osize = p.oproj_size != 0 ? p.oproj_size - : (p.vproj_size != 0 ? p.vproj_size : p.v_size); + size_t osize = p.oproj_size != 0 + ? p.oproj_size + : (p.vproj_size != 0 ? 
p.vproj_size : p.v_size * p.num_heads); TensorLayout diff_expect = TensorLayout( TensorShape{queries.shape[0], queries.shape[1], osize}, queries.dtype); megdnn_assert(equal_layout(diff, diff_expect), "%s", errmsg().c_str()); @@ -409,9 +412,10 @@ void MultiHeadAttnBackward::check_exec( if (qprojsize == 0 and kprojsize == 0) megdnn_assert(embeding_size == ksize, "%s", param_errmsg().c_str()); if (qprojsize == 0 and kprojsize != 0) - megdnn_assert(embeding_size == kprojsize, "%s", param_errmsg().c_str()); + megdnn_assert( + embeding_size * p.num_heads == kprojsize, "%s", param_errmsg().c_str()); if (qprojsize != 0 and kprojsize == 0) - megdnn_assert(qprojsize == ksize, "%s", param_errmsg().c_str()); + megdnn_assert(qprojsize == ksize * p.num_heads, "%s", param_errmsg().c_str()); if (qprojsize != 0 and kprojsize != 0) megdnn_assert(qprojsize == kprojsize, "%s", param_errmsg().c_str()); if (p.qbias) @@ -431,7 +435,7 @@ void MultiHeadAttnBackward::check_exec( if (p.oproj_size > 0 and p.vproj_size > 0) weight_len += vprojsize * oprojsize + (p.obias ? oprojsize : 0); else if (p.oproj_size > 0 and p.vproj_size == 0) - weight_len += vsize * oprojsize + (p.obias ? oprojsize : 0); + weight_len += p.num_heads * vsize * oprojsize + (p.obias ? oprojsize : 0); megdnn_assert( weight_len == qkvo_weight_bias.total_nr_elems(), "qkvo_weight_bias length should be %zu, but got %zu. details: %s", diff --git a/dnn/src/common/multi_head_attn/helper.h b/dnn/src/common/multi_head_attn/helper.h new file mode 100644 index 0000000000000000000000000000000000000000..ea3d4d61b2aac5390f7961b583f3f795098a0858 --- /dev/null +++ b/dnn/src/common/multi_head_attn/helper.h @@ -0,0 +1,126 @@ +#pragma once +#include "megdnn/dtype.h" + +#include "megdnn/basic_types.h" +#include "megdnn/handle.h" +#include "megdnn/oprs/linalg.h" +#include "megdnn/oprs/nn.h" +#include "src/common/utils.h" + +namespace megdnn { + +namespace multi_head_attn { + +inline void matmul_deduce_layout( + std::unique_ptr& opr, const TensorLayout& A, + const TensorLayout& B, TensorLayout& C) { + megdnn_assert(A.ndim == 3 && B.ndim == 2); + auto m_param = opr->param(); + + size_t A1, A2, B0, B1; + A1 = A.shape[1]; + A2 = A.shape[2]; + B0 = B.shape[0]; + B1 = B.shape[1]; + if (m_param.transposeA) { + std::swap(A1, A2); + } + if (m_param.transposeB) { + std::swap(B0, B1); + } + C = TensorLayout(TensorShape({A.shape[0], A1, B1}), A.dtype); +} + +inline void matmul_exec( + std::unique_ptr& opr, _megdnn_tensor_in A, + _megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace) { + auto Batch = A.layout.shape[0]; + + auto Astrd = A.layout.dtype.size() * A.layout.stride[0], + Cstrd = C.layout.dtype.size() * C.layout.stride[0]; + + auto Aref = A.get_ref_ptr(); + auto Bref = B.get_ref_ptr(); + auto Cref = C.get_ref_ptr(); + + rep(b, Batch) { + //! 
all tensors should share the same RefPtr + auto A_ref = Aref; + A_ref += b * Astrd; + auto B_ref = Bref; + auto C_ref = Cref; + C_ref += b * Cstrd; + TensorND A_{A.layout.remove_axis(0), A_ref}; + TensorND B_{B.layout, B_ref}; + TensorND C_{C.layout.remove_axis(0), C_ref}; + opr->exec(A_, B_, C_, workspace); + } +} + +using Param = MultiHeadAttnBase::Param; +using MaskType = Param::AttnMaskType; +using InputType = Param::TensorCombinationType; + +/***************************** MHA base *****************************/ +#define _MHA_FORWARD(INPUT_TYPE, OUTPUT_TYPE) \ + INPUT_TYPE queries, INPUT_TYPE keys, INPUT_TYPE values, \ + INPUT_TYPE qkvo_weight_bias, INPUT_TYPE attn_mask, INPUT_TYPE bias_k, \ + INPUT_TYPE bias_v, OUTPUT_TYPE out, OUTPUT_TYPE attn_weight, \ + OUTPUT_TYPE mask_reservespace, OUTPUT_TYPE othr_reservespace +#define _MHA_BACKWARD(INPUT_TYPE, OUTPUT_TYPE) \ + INPUT_TYPE diff, INPUT_TYPE queries, INPUT_TYPE keys, INPUT_TYPE values, \ + INPUT_TYPE qkvo_weight_bias, INPUT_TYPE attn_mask, INPUT_TYPE attn_weight, \ + INPUT_TYPE mask_reservespace, INPUT_TYPE othr_reservespace, \ + OUTPUT_TYPE dqueries, OUTPUT_TYPE dkeys, OUTPUT_TYPE dvalues, \ + OUTPUT_TYPE dqkvo_weight_bias, OUTPUT_TYPE dbias_k, OUTPUT_TYPE dbias_v +#define _MHA_PROXY_PRE(HANDLE_TYPE, PARAM_TYPE) HANDLE_TYPE handle, PARAM_TYPE param + +#define MHA_EXEC_PARAM(cb) \ + cb(_megdnn_tensor_in, _megdnn_tensor_out), _megdnn_workspace workspace +#define MHA_LAYOUT_CONST_PARAM(cb) cb(const TensorLayout&, const TensorLayout&) +#define MHA_LAYOUT_PARAM(cb) cb(const TensorLayout&, TensorLayout&) +#define MHA_CALL(cb) cb(, ) +#define MHA_PROXY_PRE_PARAM _MHA_PROXY_PRE(Handle*, Param&) +#define MHA_PROXY_PRE_CALL _MHA_PROXY_PRE(, ) + +/***************************** MHA forward *****************************/ +#define MHA_FORWARD_EXEC_PARAM MHA_EXEC_PARAM(_MHA_FORWARD) +#define MHA_FORWARD_LAYOUT_CONST_PARAM MHA_LAYOUT_CONST_PARAM(_MHA_FORWARD) +#define MHA_FORWARD_LAYOUT_PARAM MHA_LAYOUT_PARAM(_MHA_FORWARD) +#define MHA_FORWARD_CALL MHA_CALL(_MHA_FORWARD) + +#define MHA_PROXY_FORWARD_EXEC_PARAM MHA_PROXY_PRE_PARAM, MHA_FORWARD_EXEC_PARAM +#define MHA_PROXY_FORWARD_LAYOUT_PARAM MHA_PROXY_PRE_PARAM, MHA_FORWARD_LAYOUT_PARAM +#define MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM \ + MHA_PROXY_PRE_PARAM, MHA_FORWARD_LAYOUT_CONST_PARAM +#define MHA_PROXY_FORWARD_CALL MHA_PROXY_PRE_CALL, MHA_FORWARD_CALL + +/***************************** MHA backward *****************************/ +#define MHA_BACKWARD_EXEC_PARAM MHA_EXEC_PARAM(_MHA_BACKWARD) +#define MHA_BACKWARD_LAYOUT_CONST_PARAM MHA_LAYOUT_CONST_PARAM(_MHA_BACKWARD) +#define MHA_BACKWARD_LAYOUT_PARAM MHA_LAYOUT_PARAM(_MHA_BACKWARD) +#define MHA_BACKWARD_CALL MHA_CALL(_MHA_BACKWARD) + +#define MHA_PROXY_BACKWARD_EXEC_PARAM MHA_PROXY_PRE_PARAM, MHA_BACKWARD_EXEC_PARAM +#define MHA_PROXY_BACKWARD_LAYOUT_PARAM MHA_PROXY_PRE_PARAM, MHA_BACKWARD_LAYOUT_PARAM +#define MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM \ + MHA_PROXY_PRE_PARAM, MHA_BACKWARD_LAYOUT_CONST_PARAM +#define MHA_PROXY_BACKWARD_CALL MHA_PROXY_PRE_CALL, MHA_BACKWARD_CALL + +/***************************** MHA other *****************************/ +#define MHA_FORWARD_TENSOR_TO_LAYOUT_CALL \ + queries.layout, keys.layout, values.layout, qkvo_weight_bias.layout, \ + attn_mask.layout, bias_k.layout, bias_v.layout, out.layout, \ + attn_weight.layout, mask_reservespace.layout, othr_reservespace.layout +#define MHA_BACKWARD_TENSOR_TO_LAYOUT_CALL \ + diff.layout, queries.layout, keys.layout, values.layout, qkvo_weight_bias.layout, \ + 
attn_mask.layout, attn_weight.layout, mask_reservespace.layout, \ + othr_reservespace.layout, dqueries.layout, dkeys.layout, dvalues.layout, \ + dqkvo_weight_bias.layout, dbias_k.layout, dbias_v.layout +#define MHA_PROXY_FORWARD_TENSOR_TO_LAYOUT_CALL \ + MHA_PROXY_PRE_CALL, MHA_FORWARD_TENSOR_TO_LAYOUT_CALL +#define MHA_PROXY_BACKWARD_TENSOR_TO_LAYOUT_CALL \ + MHA_PROXY_PRE_CALL, MHA_BACKWARD_TENSOR_TO_LAYOUT_CALL + +} // namespace multi_head_attn +} // namespace megdnn diff --git a/dnn/src/common/multi_head_attn/proxy_backward_base.cpp b/dnn/src/common/multi_head_attn/proxy_backward_base.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b93d4c7bd5d179b0730e4f261b0ca66d76974b50 --- /dev/null +++ b/dnn/src/common/multi_head_attn/proxy_backward_base.cpp @@ -0,0 +1,796 @@ +#include "src/common/multi_head_attn/proxy_backward_base.h" +#include "megdnn/basic_types.h" +#include "megdnn/oprs/nn.h" + +namespace megdnn { + +namespace multi_head_attn { + +bool MHABackwardProxyBase::layout_ismatch(MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM) { + MEGDNN_MARK_USED_VAR(handle); + MEGDNN_MARK_USED_VAR(diff); + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(dqueries); + MEGDNN_MARK_USED_VAR(dkeys); + MEGDNN_MARK_USED_VAR(dvalues); + MEGDNN_MARK_USED_VAR(dqkvo_weight_bias); + MEGDNN_MARK_USED_VAR(dbias_k); + MEGDNN_MARK_USED_VAR(dbias_v); + if (m_matmul_opr == nullptr or m_bmatmul_opr == nullptr or m_add_opr == nullptr or + m_elem_opr == nullptr or m_reduce_opr == nullptr or + m_softmaxbw_opr == nullptr or m_dropout_opr == nullptr or + m_dropoutbw_opr == nullptr or m_relayout_opr == nullptr) { + megdnn_assert( + m_matmul_opr == nullptr and m_bmatmul_opr == nullptr and + m_add_opr == nullptr and m_elem_opr == nullptr and + m_reduce_opr == nullptr and m_softmaxbw_opr == nullptr and + m_dropout_opr == nullptr and m_dropoutbw_opr == nullptr and + m_relayout_opr == nullptr, + "All the sub-opr are either not constructed or all constructed, but " + "now only a part is constructed."); + m_matmul_opr = handle->create_operator(); + m_bmatmul_opr = handle->create_operator(); + m_add_opr = handle->create_operator(); + m_elem_opr = handle->create_operator(); + m_reduce_opr = handle->create_operator(); + m_softmaxbw_opr = handle->create_operator(); + m_dropout_opr = handle->create_operator(); + m_dropoutbw_opr = handle->create_operator(); + m_relayout_opr = handle->create_operator(); + } + auto matmul_layout = [](const TensorLayout& A, const TensorLayout& B, + const TensorLayout& C, bool enable) -> bool { + if (!enable) { + return true; + } + // [A0, A1, A2]@[B0, B1] = [C0, C1, C2] + if (A[2] != B[0] || C[0] != A[0] || A[1] != C[1] || C[2] != B[1]) { + return false; + } + return true; + }; + + auto equal_metadata = [&](const Param& param) -> bool { + return m_head == param.num_heads && m_embed_size == param.embeding_size && + m_ksize == param.k_size && m_vsize == param.v_size && + m_qproj_size == param.qproj_size && m_kproj_size == param.kproj_size && + m_vproj_size == param.vproj_size && m_oproj_size == param.oproj_size && + m_qbias == param.qbias && m_kbias == param.kbias && + m_vbias == param.vbias && m_obias == param.obias; + }; + + return equal_metadata(param) && m_datatype == queries.dtype.enumv() && + matmul_layout( + queries, m_wq_layout, m_grad_q_layout, param.qproj_size != 0) && + matmul_layout(keys, m_wk_layout, 
m_grad_k_layout, param.kproj_size != 0) && + matmul_layout(values, m_wv_layout, m_grad_v_layout, param.vproj_size != 0) && + diff.eq_layout(m_grad_out_layout) && diff.eq_layout(m_grad_drop2_layout) && + dqueries.eq_layout(m_grad_qin_layout) && + dkeys.eq_layout(m_grad_kin_layout) && dvalues.eq_layout(m_grad_vin_layout); +} + +void MHABackwardProxyBase::layout_refill(MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM) { + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(handle); + MEGDNN_MARK_USED_VAR(diff); + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(othr_reservespace); + MEGDNN_MARK_USED_VAR(dqueries); + MEGDNN_MARK_USED_VAR(dkeys); + MEGDNN_MARK_USED_VAR(dvalues); + MEGDNN_MARK_USED_VAR(dqkvo_weight_bias); + MEGDNN_MARK_USED_VAR(dbias_k); + MEGDNN_MARK_USED_VAR(dbias_v); + // proxy opr + m_softmaxbw_opr->param().axis = -1; + m_matmul_opr->param().format = param::MatrixMul::Format::DEFAULT; + m_bmatmul_opr->param().format = param::MatrixMul::Format::DEFAULT; + m_dropoutbw_opr->param().seed = param.seed; + m_dropoutbw_opr->param().seed = param.seed; + m_dropout_opr->param().seed = param.seed; + m_reduce_opr->param().mode = param::Reduce::Mode::SUM; + m_reduce_opr->param().data_type = param::Reduce::DataType::DEFAULT; + + m_head = param.num_heads; + m_embed_size = param.embeding_size; + m_ksize = param.k_size; + m_vsize = param.v_size; + m_qproj_size = param.qproj_size; + m_kproj_size = param.kproj_size; + m_vproj_size = param.vproj_size; + m_oproj_size = param.oproj_size; + m_qbias = param.qbias; + m_kbias = param.kbias; + m_vbias = param.vbias; + m_obias = param.obias; + auto cal_type = qkvo_weight_bias.dtype; + m_grad_qin_layout = queries; + m_grad_kin_layout = keys; + m_grad_vin_layout = values; + + auto reflash_dtype = [&](DType dtype) { + m_grad_drop2_layout.dtype = dtype; + m_grad_out_layout.dtype = dtype; + m_grad_z_layout.dtype = dtype; + m_grad_wo_layout.dtype = dtype; + m_grad_bo_layout.dtype = dtype; + m_grad_nz_layout.dtype = dtype; + m_grad_nv_layout.dtype = dtype; + m_grad_ny_layout.dtype = dtype; + m_grad_drop1_layout.dtype = dtype; + m_grad_nx_layout.dtype = dtype; + m_grad_nq_layout.dtype = dtype; + m_grad_nk_layout.dtype = dtype; + m_grad_q_layout.dtype = dtype; + m_grad_k_layout.dtype = dtype; + m_grad_v_layout.dtype = dtype; + m_grad_qin_layout.dtype = dtype; + m_grad_wq_layout.dtype = dtype; + m_grad_bq_layout.dtype = dtype; + m_grad_kin_layout.dtype = dtype; + m_grad_wk_layout.dtype = dtype; + m_grad_bk_layout.dtype = dtype; + m_grad_vin_layout.dtype = dtype; + m_grad_wv_layout.dtype = dtype; + m_grad_bv_layout.dtype = dtype; + }; + reflash_dtype(queries.dtype); + m_datatype = queries.dtype.enumv(); +#define cb(DType) \ + if (queries.dtype.enumv() == DTypeTrait::enumv) { \ + m_sizeof_datatype = sizeof(DTypeTrait::ctype); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb + + // weight/bias + m_wq_layout = TensorLayout{{m_embed_size, m_qproj_size}, cal_type}; + m_wk_layout = TensorLayout{{m_ksize, m_kproj_size}, cal_type}; + m_wv_layout = TensorLayout{{m_vsize, m_vproj_size}, cal_type}; + if (m_vproj_size > 0) { + m_wo_layout = TensorLayout{{m_vproj_size, m_oproj_size}, cal_type}; + } else { + m_wo_layout = TensorLayout{{m_vsize * m_head, m_oproj_size}, cal_type}; + } + m_bq_layout = TensorLayout{{1, 1, m_qproj_size}, cal_type}; + m_bk_layout = TensorLayout{{1, 1, m_kproj_size}, cal_type}; + m_bv_layout = TensorLayout{{1, 1, m_vproj_size}, cal_type}; + m_bo_layout = TensorLayout{{1, 1, m_oproj_size}, cal_type}; + 
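// Worked example (illustrative sketch, not taken from the patch itself): the flat
// qkvo_weight_bias buffer is packed as [wq | wk | wv | wo | bq | bk | bv | bo],
// in exactly the order the offsets just below are filled in, and a slot is only
// present when the corresponding projection (and, for a bias, its bias flag) is
// enabled. Assuming embeding_size = k_size = v_size = 8, num_heads = 2, all four
// projection sizes set to 8 and all bias flags true, the packed length is
//     8*8 (wq) + 8*8 (wk) + 8*8 (wv) + 8*8 (wo) + 8 + 8 + 8 + 8 = 288,
// i.e. under these assumptions (hypothetical local `packed`):
//     size_t packed = m_wq_layout.total_nr_elems() + m_wk_layout.total_nr_elems() +
//                     m_wv_layout.total_nr_elems() + m_wo_layout.total_nr_elems() +
//                     m_bq_layout.total_nr_elems() + m_bk_layout.total_nr_elems() +
//                     m_bv_layout.total_nr_elems() + m_bo_layout.total_nr_elems();
// which is the same total that MultiHeadAttnForward::check_exec compares against
// qkvo_weight_bias.total_nr_elems().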
+ size_t end = 0; + m_wq_off = 0, m_wk_off = 0, m_wv_off = 0, m_wo_off = 0; + m_bq_off = 0, m_bk_off = 0, m_bv_off = 0, m_bo_off = 0; + if (param.qproj_size) { + m_wq_off = end; + end += m_wq_layout.total_nr_elems(); + } + if (param.kproj_size) { + m_wk_off = end; + end += m_wk_layout.total_nr_elems(); + } + if (param.vproj_size) { + m_wv_off = end; + end += m_wv_layout.total_nr_elems(); + } + if (param.oproj_size) { + m_wo_off = end; + end += m_wo_layout.total_nr_elems(); + } + + if (param.qbias && param.qproj_size) { + m_bq_off = end; + end += m_bq_layout.total_nr_elems(); + } + if (param.kbias && param.kproj_size) { + m_bk_off = end; + end += m_bk_layout.total_nr_elems(); + } + if (param.vbias && param.vproj_size) { + m_bv_off = end; + end += m_bv_layout.total_nr_elems(); + } + if (param.obias && param.oproj_size) { + m_bo_off = end; + end += m_bo_layout.total_nr_elems(); + } + + // q/k/v + m_matmul_opr->param().transposeA = false; + m_matmul_opr->param().transposeB = false; + if (param.qproj_size) { + matmul_deduce_layout(m_matmul_opr, queries, m_wq_layout, m_grad_q_layout); + m_grad_nq_layout = TensorLayout{ + {m_grad_q_layout.shape[0] * m_head, m_grad_q_layout.shape[1], + m_grad_q_layout.shape[2] / m_head}, + m_grad_q_layout.dtype}; + } else { + m_grad_q_layout = queries; + m_grad_nq_layout = TensorLayout{ + {m_grad_q_layout[0] * m_head, m_grad_q_layout[1], m_grad_q_layout[2]}, + m_grad_q_layout.dtype}; + } + if (param.kproj_size) { + matmul_deduce_layout(m_matmul_opr, keys, m_wk_layout, m_grad_k_layout); + m_grad_nk_layout = TensorLayout{ + {m_grad_k_layout.shape[0] * m_head, m_grad_k_layout.shape[1], + m_grad_k_layout.shape[2] / m_head}, + m_grad_k_layout.dtype}; + } else { + m_grad_k_layout = keys; + m_grad_nk_layout = TensorLayout{ + {m_grad_k_layout[0] * m_head, m_grad_k_layout[1], m_grad_k_layout[2]}, + m_grad_k_layout.dtype}; + } + if (param.vproj_size) { + matmul_deduce_layout(m_matmul_opr, values, m_wv_layout, m_grad_v_layout); + m_grad_nv_layout = TensorLayout{ + {m_grad_v_layout.shape[0] * m_head, m_grad_v_layout.shape[1], + m_grad_v_layout.shape[2] / m_head}, + m_grad_v_layout.dtype}; + } else { + m_grad_v_layout = values; + m_grad_nv_layout = TensorLayout{ + {m_grad_v_layout[0] * m_head, m_grad_v_layout[1], m_grad_v_layout[2]}, + m_grad_v_layout.dtype}; + } + + // nx + m_bmatmul_opr->param().transposeA = false; + m_bmatmul_opr->param().transposeB = true; + m_bmatmul_opr->deduce_layout(m_grad_nq_layout, m_grad_nk_layout, m_grad_nx_layout); + m_bmatmul_opr->param().transposeA = false; + m_bmatmul_opr->param().transposeB = false; + m_grad_nq_workspacesize = m_bmatmul_opr->get_workspace_in_bytes( + m_grad_nx_layout, m_grad_nk_layout, m_grad_nq_layout); + m_bmatmul_opr->param().transposeA = true; + m_bmatmul_opr->param().transposeB = false; + m_grad_nk_workspacesize = m_bmatmul_opr->get_workspace_in_bytes( + m_grad_nx_layout, m_grad_nq_layout, m_grad_nk_layout); + // softmax + m_grad_ny_layout = m_grad_nx_layout; + m_grad_nx_workspacesize = m_softmaxbw_opr->get_workspace_in_bytes( + m_grad_nx_layout, m_grad_ny_layout, m_grad_nx_layout); + // dropout + m_dropout_opr->param().drop_prob = param.attn_prob; + m_dropoutbw_opr->param().drop_prob = param.attn_prob; + m_dropout_opr->deduce_layout(m_grad_ny_layout, m_grad_drop1_layout, m_mask1_layout); + m_grad_drop1_workspacesize = m_dropoutbw_opr->get_workspace_in_bytes( + m_grad_drop1_layout, m_mask1_layout, m_grad_ny_layout); + + // nz + m_bmatmul_opr->param().transposeA = false; + m_bmatmul_opr->param().transposeB = false; + 
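// Shape sketch (illustrative, with assumed sizes): for batch = 2, num_heads = 4,
// seq_q = 5, seq_k = 7 and qproj_size = kproj_size = 64, the projected query has
// layout {2, 5, 64}; splitting out the heads above gives
//     m_grad_nq_layout = {2 * 4, 5, 64 / 4} = {8, 5, 16}
//     m_grad_nk_layout = {8, 7, 16}
// and the batched matmul nq @ nk^T (transposeB = true) deduces the attention
// weight layout
//     m_grad_nx_layout = {8, 5, 7}   // [batch * heads, seq_q, seq_k]
// which is also the layout that the softmax backward and the attention dropout
// set up above operate on.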
m_bmatmul_opr->deduce_layout(m_grad_ny_layout, m_grad_nv_layout, m_grad_nz_layout); + m_bmatmul_opr->param().transposeA = false; + m_bmatmul_opr->param().transposeB = true; + m_grad_ny_workspacesize = m_bmatmul_opr->get_workspace_in_bytes( + m_grad_nz_layout, m_grad_nv_layout, m_grad_ny_layout); + m_bmatmul_opr->param().transposeA = true; + m_bmatmul_opr->param().transposeB = false; + m_grad_nv_workspacesize = m_bmatmul_opr->get_workspace_in_bytes( + m_grad_ny_layout, m_grad_nz_layout, m_grad_nv_layout); + + // z + m_grad_z_layout = TensorLayout{ + {m_grad_nz_layout.shape[0] / m_head, m_grad_nz_layout.shape[1], + m_grad_nz_layout.shape[2] * m_head}, + m_grad_nz_layout.dtype}; + + // out + m_matmul_opr->param().transposeA = false; + m_matmul_opr->param().transposeB = false; + if (param.oproj_size) { + matmul_deduce_layout( + m_matmul_opr, m_grad_z_layout, m_wo_layout, m_grad_out_layout); + } else { + m_grad_out_layout = m_grad_z_layout; + } + + // dropout + m_dropout_opr->param().drop_prob = param.out_prob; + m_dropoutbw_opr->param().drop_prob = param.out_prob; + m_dropout_opr->deduce_layout( + m_grad_out_layout, m_grad_drop2_layout, m_mask2_layout); + m_grad_drop2_workspacesize = m_dropoutbw_opr->get_workspace_in_bytes( + m_grad_drop2_layout, m_mask2_layout, m_grad_out_layout); + + // q = qin @ wq + bq + // k = kin @ wk + bk + // v = vin @ wv + bv + m_matmul_opr->param().transposeA = false; + m_matmul_opr->param().transposeB = true; + m_grad_z_workspacesize = 0; + m_grad_qin_workspacesize = 0; + m_grad_kin_workspacesize = 0; + m_grad_vin_workspacesize = 0; + + m_bmatmul_opr->param().transposeA = true; + m_bmatmul_opr->param().transposeB = false; + m_bmatmul_opr->deduce_layout(m_grad_z_layout, m_grad_out_layout, m_grad_wo_layout); + m_grad_wo0_workspacesize = m_bmatmul_opr->get_workspace_in_bytes( + m_grad_z_layout, m_grad_out_layout, m_grad_wo_layout); + m_reduce_opr->param().axis = 0; + m_grad_wo1_workspacesize = + m_reduce_opr->get_workspace_in_bytes(m_grad_wo_layout, m_wo_layout); + m_grad_bo_layout = m_grad_out_layout; + m_grad_bo_layout.shape[0] = 1; + m_grad_bo0_workspacesize = + m_reduce_opr->get_workspace_in_bytes(m_grad_out_layout, m_grad_bo_layout); + m_reduce_opr->param().axis = 1; + m_grad_bo1_workspacesize = + m_reduce_opr->get_workspace_in_bytes(m_grad_bo_layout, m_bo_layout); + + m_bmatmul_opr->deduce_layout(queries, m_grad_q_layout, m_grad_wq_layout); + m_grad_wq0_workspacesize = m_bmatmul_opr->get_workspace_in_bytes( + queries, m_grad_q_layout, m_grad_wq_layout); + m_reduce_opr->param().axis = 0; + m_grad_wq1_workspacesize = + m_reduce_opr->get_workspace_in_bytes(m_grad_wq_layout, m_wq_layout); + m_grad_bq_layout = m_grad_q_layout; + m_grad_bq_layout.shape[0] = 1; + m_grad_bq0_workspacesize = + m_reduce_opr->get_workspace_in_bytes(m_grad_q_layout, m_grad_bq_layout); + m_reduce_opr->param().axis = 1; + m_grad_bq1_workspacesize = + m_reduce_opr->get_workspace_in_bytes(m_grad_bq_layout, m_bq_layout); + + m_bmatmul_opr->deduce_layout(keys, m_grad_k_layout, m_grad_wk_layout); + m_grad_wk0_workspacesize = m_bmatmul_opr->get_workspace_in_bytes( + keys, m_grad_k_layout, m_grad_wk_layout); + m_reduce_opr->param().axis = 0; + m_grad_wk1_workspacesize = + m_reduce_opr->get_workspace_in_bytes(m_grad_wk_layout, m_wk_layout); + m_grad_bk_layout = m_grad_k_layout; + m_grad_bk_layout.shape[0] = 1; + m_grad_bk0_workspacesize = + m_reduce_opr->get_workspace_in_bytes(m_grad_k_layout, m_grad_bk_layout); + m_reduce_opr->param().axis = 1; + m_grad_bk1_workspacesize = + 
m_reduce_opr->get_workspace_in_bytes(m_grad_bk_layout, m_bk_layout); + + m_bmatmul_opr->deduce_layout(values, m_grad_v_layout, m_grad_wv_layout); + m_grad_wv0_workspacesize = m_bmatmul_opr->get_workspace_in_bytes( + values, m_grad_v_layout, m_grad_wv_layout); + m_reduce_opr->param().axis = 0; + m_grad_wv1_workspacesize = + m_reduce_opr->get_workspace_in_bytes(m_grad_wv_layout, m_wv_layout); + m_grad_bv_layout = m_grad_v_layout; + m_grad_bv_layout.shape[0] = 1; + m_grad_bv0_workspacesize = + m_reduce_opr->get_workspace_in_bytes(m_grad_v_layout, m_grad_bv_layout); + m_reduce_opr->param().axis = 1; + m_grad_bv1_workspacesize = + m_reduce_opr->get_workspace_in_bytes(m_grad_bv_layout, m_bv_layout); + + m_reduce_opr->param().axis = 1; + m_grad_qin_reduce_workspacesize = m_reduce_opr->get_workspace_in_bytes( + {{m_grad_nq_layout[0] / m_head, m_head, m_grad_nq_layout[1], + m_grad_nq_layout[2]}, + m_grad_nq_layout.dtype}, + {{m_grad_nq_layout[0] / m_head, 1, m_grad_nq_layout[1], + m_grad_nq_layout[2]}, + m_grad_nq_layout.dtype}); + m_grad_kin_reduce_workspacesize = m_reduce_opr->get_workspace_in_bytes( + {{m_grad_nk_layout[0] / m_head, m_head, m_grad_nk_layout[1], + m_grad_nk_layout[2]}, + m_grad_nk_layout.dtype}, + {{m_grad_nk_layout[0] / m_head, 1, m_grad_nk_layout[1], + m_grad_nk_layout[2]}, + m_grad_nk_layout.dtype}); + m_grad_vin_reduce_workspacesize = m_reduce_opr->get_workspace_in_bytes( + {{m_grad_nv_layout[0] / m_head, m_head, m_grad_nv_layout[1], + m_grad_nv_layout[2]}, + m_grad_nv_layout.dtype}, + {{m_grad_nv_layout[0] / m_head, 1, m_grad_nv_layout[1], + m_grad_nv_layout[2]}, + m_grad_nv_layout.dtype}); +} + +WorkspaceBundle MHABackwardProxyBase::get_mask_reservespace_bundle( + MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM, void* ptr) { + if (!layout_ismatch(MHA_PROXY_BACKWARD_CALL)) { + layout_refill(MHA_PROXY_BACKWARD_CALL); + } + return WorkspaceBundle( + ptr, {m_mask1_layout.span().dist_byte(), m_mask2_layout.span().dist_byte()}, + 4); +} + +WorkspaceBundle MHABackwardProxyBase::get_othr_reservespace_bundle( + MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM, void* ptr) { + if (!layout_ismatch(MHA_PROXY_BACKWARD_CALL)) { + layout_refill(MHA_PROXY_BACKWARD_CALL); + } + return WorkspaceBundle( + ptr, + {param.num_heads > 1 or param.qproj_size + ? m_grad_nq_layout.span().dist_byte() + : 0, + param.num_heads > 1 or param.kproj_size + ? m_grad_nk_layout.span().dist_byte() + : 0, + param.num_heads > 1 or param.vproj_size + ? m_grad_nv_layout.span().dist_byte() + : 0, + m_grad_nx_layout.span().dist_byte(), + param.oproj_size ? m_grad_z_layout.span().dist_byte() : 0}, + queries.dtype.size()); +} + +WorkspaceBundle MHABackwardProxyBase::get_workspace_bundle( + MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM, void* ptr) { + if (!layout_ismatch(MHA_PROXY_BACKWARD_CALL)) { + layout_refill(MHA_PROXY_BACKWARD_CALL); + } + return WorkspaceBundle( + ptr, + {m_grad_drop2_layout.span().dist_byte(), + m_grad_drop2_workspacesize, + param.oproj_size ? m_grad_z_layout.span().dist_byte() : 0, + param.oproj_size ? m_grad_wo_layout.span().dist_byte() : 0, + param.oproj_size ? m_grad_z_workspacesize : 0, + param.oproj_size ? m_grad_wo0_workspacesize : 0, + param.oproj_size ? m_grad_wo1_workspacesize : 0, + (param.oproj_size and param.obias) ? m_grad_bo_layout.span().dist_byte() + : 0, + (param.oproj_size and param.obias) ? m_grad_bo0_workspacesize : 0, + (param.oproj_size and param.obias) ? m_grad_bo1_workspacesize : 0, + param.num_heads > 1 ? 
m_grad_nz_layout.span().dist_byte() : 0, + m_grad_ny_layout.span().dist_byte(), + m_grad_nv_layout.span().dist_byte(), + m_grad_ny_workspacesize, + m_grad_nv_workspacesize, + m_grad_drop1_layout.span().dist_byte(), + m_grad_drop1_workspacesize, + m_grad_nx_layout.span().dist_byte(), + m_grad_nx_workspacesize, + m_sizeof_datatype, + m_grad_nq_layout.span().dist_byte(), + m_grad_nk_layout.span().dist_byte(), + m_grad_nq_workspacesize, + m_grad_nk_workspacesize, + (param.qproj_size and param.num_heads > 1) + ? m_grad_q_layout.span().dist_byte() + : 0, + param.qproj_size ? m_grad_wq_layout.span().dist_byte() : 0, + param.qproj_size ? m_grad_bq_layout.span().dist_byte() : 0, + param.qproj_size ? m_grad_qin_workspacesize : 0, + param.qproj_size ? m_grad_wq0_workspacesize : 0, + param.qproj_size ? m_grad_wq1_workspacesize : 0, + (param.qproj_size and param.qbias) ? m_grad_bq0_workspacesize : 0, + (param.qproj_size and param.qbias) ? m_grad_bq1_workspacesize : 0, + param.qproj_size == 0 ? m_grad_qin_reduce_workspacesize : 0, + (param.kproj_size and param.num_heads > 1) + ? m_grad_k_layout.span().dist_byte() + : 0, + param.kproj_size ? m_grad_wk_layout.span().dist_byte() : 0, + param.kproj_size ? m_grad_bk_layout.span().dist_byte() : 0, + param.kproj_size ? m_grad_kin_workspacesize : 0, + param.kproj_size ? m_grad_wk0_workspacesize : 0, + param.kproj_size ? m_grad_wk1_workspacesize : 0, + (param.kproj_size and param.kbias) ? m_grad_bk0_workspacesize : 0, + (param.kproj_size and param.kbias) ? m_grad_bk1_workspacesize : 0, + param.kproj_size == 0 ? m_grad_kin_reduce_workspacesize : 0, + (param.vproj_size and param.num_heads > 1) + ? m_grad_v_layout.span().dist_byte() + : 0, + param.vproj_size ? m_grad_wv_layout.span().dist_byte() : 0, + param.vproj_size ? m_grad_bv_layout.span().dist_byte() : 0, + param.vproj_size ? m_grad_vin_workspacesize : 0, + param.vproj_size ? m_grad_wv0_workspacesize : 0, + param.vproj_size ? m_grad_wv1_workspacesize : 0, + (param.vproj_size and param.vbias) ? m_grad_bv0_workspacesize : 0, + (param.vproj_size and param.vbias) ? m_grad_bv1_workspacesize : 0, + param.vproj_size == 0 ? 
m_grad_vin_reduce_workspacesize : 0}); +} + +size_t MHABackwardProxyBase::get_mask_reservespace_in_bytes( + MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM) { + auto bundle = get_mask_reservespace_bundle(MHA_PROXY_BACKWARD_CALL); + return bundle.total_size_in_bytes(); +} + +size_t MHABackwardProxyBase::get_othr_reservespace_in_bytes( + MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM) { + auto bundle = get_othr_reservespace_bundle(MHA_PROXY_BACKWARD_CALL); + return bundle.total_size_in_bytes(); +} + +size_t MHABackwardProxyBase::get_workspace_in_bytes( + MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM) { + auto bundle = get_workspace_bundle(MHA_PROXY_BACKWARD_CALL); + return bundle.total_size_in_bytes(); +} + +void MHABackwardProxyBase::exec(MHA_PROXY_BACKWARD_EXEC_PARAM) { +#define cb(DType) \ + if (queries.layout.dtype.enumv() == DTypeTrait::enumv) { \ + using ctype = typename DTypeTrait::ctype; \ + exec_internal(MHA_PROXY_BACKWARD_CALL, workspace); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb +} + +template +void MHABackwardProxyBase::exec_internal(MHA_PROXY_BACKWARD_EXEC_PARAM) { + auto wksp_bundle = get_workspace_bundle( + MHA_PROXY_BACKWARD_TENSOR_TO_LAYOUT_CALL, workspace.raw_ptr); + auto mask_bundle = get_mask_reservespace_bundle( + MHA_PROXY_BACKWARD_TENSOR_TO_LAYOUT_CALL, mask_reservespace.raw_ptr()); + auto othr_bundle = get_othr_reservespace_bundle( + MHA_PROXY_BACKWARD_TENSOR_TO_LAYOUT_CALL, othr_reservespace.raw_ptr()); + size_t head = param.num_heads; + size_t one = 1; + TensorND mask1{mask_bundle.get_workspace(0).raw_ptr, m_mask1_layout}; + TensorND mask2{mask_bundle.get_workspace(1).raw_ptr, m_mask2_layout}; + TensorND nq, nk, nv; + if (param.qproj_size == 0 and param.num_heads == 1) { + nq = queries; + } else { + nq = TensorND{othr_bundle.get_workspace(0).raw_ptr, m_grad_nq_layout}; + } + if (param.kproj_size == 0 and param.num_heads == 1) { + nk = keys; + } else { + nk = TensorND{othr_bundle.get_workspace(1).raw_ptr, m_grad_nk_layout}; + } + if (param.vproj_size == 0 and param.num_heads == 1) { + nv = values; + } else { + nv = TensorND{othr_bundle.get_workspace(2).raw_ptr, m_grad_nv_layout}; + } + TensorND nx{othr_bundle.get_workspace(3).raw_ptr, m_grad_nx_layout}; + + // out = dropout(out) + TensorND grad_drop2{wksp_bundle.get_workspace(0).raw_ptr, m_grad_drop2_layout}; + m_dropoutbw_opr->param().drop_prob = param.out_prob; + m_dropoutbw_opr->exec(diff, mask2, grad_drop2, wksp_bundle.get_workspace(1)); + + // out = z @ wo + bo + TensorND grad_z; + if (param.oproj_size) { + TensorND z{othr_bundle.get_workspace(4).raw_ptr, m_grad_z_layout}; + TensorND oweight{qkvo_weight_bias.ptr() + m_wo_off, m_wo_layout}; + grad_z = TensorND{wksp_bundle.get_workspace(2).raw_ptr, m_grad_z_layout}; + TensorND grad_wo{wksp_bundle.get_workspace(3).raw_ptr, m_grad_wo_layout}; + m_matmul_opr->param().transposeA = false; + m_matmul_opr->param().transposeB = true; + matmul_exec( + m_matmul_opr, grad_drop2, oweight, grad_z, + wksp_bundle.get_workspace(4)); + m_bmatmul_opr->param().transposeA = true; + m_bmatmul_opr->param().transposeB = false; + m_bmatmul_opr->exec(z, grad_drop2, grad_wo, wksp_bundle.get_workspace(5)); + std::swap(m_grad_wo_layout.shape[0], one); + TensorND doweight{dqkvo_weight_bias.ptr() + m_wo_off, m_grad_wo_layout}; + std::swap(m_grad_wo_layout.shape[0], one); + m_reduce_opr->param().axis = 0; + m_reduce_opr->exec(grad_wo, doweight, wksp_bundle.get_workspace(6)); + if (param.obias) { + TensorND dobias{dqkvo_weight_bias.ptr() + m_bo_off, m_bo_layout}; + TensorND 
grad_bo{wksp_bundle.get_workspace(7).raw_ptr, m_grad_bo_layout}; + m_reduce_opr->exec(grad_drop2, grad_bo, wksp_bundle.get_workspace(8)); + m_reduce_opr->param().axis = 1; + m_reduce_opr->exec(grad_bo, dobias, wksp_bundle.get_workspace(9)); + } + } else { + grad_z = grad_drop2; + } + + // z = nz + TensorND grad_nz; + if (param.num_heads > 1) { + grad_nz = TensorND{wksp_bundle.get_workspace(10).raw_ptr, m_grad_nz_layout}; + auto to_multihead_layout = [&](size_t head, + const TensorLayout& layout) -> TensorLayout { + size_t batch = layout.shape[0]; + size_t seq = layout.shape[1]; + size_t embeding_size = layout.shape[2]; + TensorLayout ret; + ret = TensorLayout{{batch, seq, head, embeding_size / head}, layout.dtype}; + ret = ret.dimshuffle({0, 2, 1, 3}); + return ret; + }; + m_relayout_opr->exec( + {grad_z.raw_ptr(), to_multihead_layout(head, grad_z.layout)}, grad_nz); + } else { + grad_nz = grad_z; + } + + // nz = ny @ nv + TensorND grad_ny{wksp_bundle.get_workspace(11).raw_ptr, m_grad_ny_layout}; + TensorND grad_nv{wksp_bundle.get_workspace(12).raw_ptr, m_grad_nv_layout}; + m_bmatmul_opr->param().transposeA = false; + m_bmatmul_opr->param().transposeB = true; + m_bmatmul_opr->exec(grad_nz, nv, grad_ny, wksp_bundle.get_workspace(13)); + m_bmatmul_opr->param().transposeA = true; + m_bmatmul_opr->param().transposeB = false; + m_bmatmul_opr->exec(attn_weight, grad_nz, grad_nv, wksp_bundle.get_workspace(14)); + + // ny = dropout(ny) + TensorND grad_drop1{wksp_bundle.get_workspace(15).raw_ptr, m_grad_drop1_layout}; + m_dropoutbw_opr->param().drop_prob = param.attn_prob; + m_dropoutbw_opr->exec(grad_ny, mask1, grad_drop1, wksp_bundle.get_workspace(16)); + // ny = softmax(nx) + TensorND grad_nx{wksp_bundle.get_workspace(17).raw_ptr, m_grad_nx_layout}; + m_softmaxbw_opr->param().axis = -1; + m_softmaxbw_opr->exec(nx, grad_drop1, grad_nx, wksp_bundle.get_workspace(18)); + // nx = nx * scaler + T* d_scaler = wksp_bundle.get_workspace(19).ptr(); + T param_scaler = static_cast(param.sm_scaler); + move_scaler_to_device(handle, d_scaler, ¶m_scaler); + m_elem_opr->param().mode = Elemwise::Mode::MUL; + m_elem_opr->exec( + {grad_nx, TensorND{d_scaler, {{1}, queries.layout.dtype}}}, grad_nx); + + // nx = nq @ nk + TensorND grad_nq{wksp_bundle.get_workspace(20).raw_ptr, m_grad_nq_layout}; + TensorND grad_nk{wksp_bundle.get_workspace(21).raw_ptr, m_grad_nk_layout}; + m_bmatmul_opr->param().transposeA = false; + m_bmatmul_opr->param().transposeB = false; + m_bmatmul_opr->exec(grad_nx, nk, grad_nq, wksp_bundle.get_workspace(22)); + m_bmatmul_opr->param().transposeA = true; + m_bmatmul_opr->param().transposeB = false; + m_bmatmul_opr->exec(grad_nx, nq, grad_nk, wksp_bundle.get_workspace(23)); + + // nq, nk, nv = q, k, v + auto from_multihead_layout = [&](size_t head, + const TensorLayout& layout) -> TensorLayout { + size_t batch = layout.shape[0]; + size_t seq = layout.shape[1]; + size_t embeding_size = layout.shape[2]; + TensorLayout ret; + ret = TensorLayout{{batch / head, head, seq, embeding_size}, layout.dtype}; + ret = ret.dimshuffle({0, 2, 1, 3}); + return ret; + }; + TensorND grad_k; + TensorND grad_q; + TensorND grad_v; + + // q = qin @ wq + bq + if (param.qproj_size) { + if (param.num_heads > 1) { + grad_q = TensorND{wksp_bundle.get_workspace(24).raw_ptr, m_grad_q_layout}; + m_relayout_opr->exec( + {grad_nq.raw_ptr(), from_multihead_layout(head, m_grad_nq_layout)}, + grad_q); + } else { + grad_q = grad_nq; + } + TensorND qweight{qkvo_weight_bias.ptr() + m_wq_off, m_wq_layout}; + TensorND 
grad_wq{wksp_bundle.get_workspace(25).raw_ptr, m_grad_wq_layout}; + TensorND grad_bq{wksp_bundle.get_workspace(26).raw_ptr, m_grad_bq_layout}; + m_matmul_opr->param().transposeA = false; + m_matmul_opr->param().transposeB = true; + matmul_exec( + m_matmul_opr, grad_q, qweight, dqueries, wksp_bundle.get_workspace(27)); + + m_bmatmul_opr->param().transposeA = true; + m_bmatmul_opr->param().transposeB = false; + m_bmatmul_opr->exec(queries, grad_q, grad_wq, wksp_bundle.get_workspace(28)); + std::swap(m_grad_wq_layout.shape[0], one); + TensorND dqweight{dqkvo_weight_bias.ptr() + m_wq_off, m_grad_wq_layout}; + std::swap(m_grad_wq_layout.shape[0], one); + m_reduce_opr->param().axis = 0; + m_reduce_opr->exec(grad_wq, dqweight, wksp_bundle.get_workspace(29)); + if (param.qbias) { + TensorND dqbias{dqkvo_weight_bias.ptr() + m_bq_off, m_bq_layout}; + m_reduce_opr->exec(grad_q, grad_bq, wksp_bundle.get_workspace(30)); + m_reduce_opr->param().axis = 1; + m_reduce_opr->exec(grad_bq, dqbias, wksp_bundle.get_workspace(31)); + } + } else { + m_reduce_opr->param().axis = 1; + grad_nq.layout = TensorLayout{ + {grad_nq.layout[0] / head, head, grad_nq.layout[1], grad_nq.layout[2]}, + grad_nq.layout.dtype}; + m_reduce_opr->exec( + grad_nq, + {dqueries.raw_ptr(), + {{dqueries.layout[0], 1, dqueries.layout[1], dqueries.layout[2]}, + dqueries.layout.dtype}}, + wksp_bundle.get_workspace(32)); + } + + // k = kin @ wk + bk + if (param.kproj_size) { + if (param.num_heads > 1) { + grad_k = TensorND{wksp_bundle.get_workspace(33).raw_ptr, m_grad_k_layout}; + m_relayout_opr->exec( + {grad_nk.raw_ptr(), from_multihead_layout(head, m_grad_nk_layout)}, + grad_k); + } else { + grad_k = grad_nk; + } + + TensorND kweight{qkvo_weight_bias.ptr() + m_wk_off, m_wk_layout}; + TensorND grad_wk{wksp_bundle.get_workspace(34).raw_ptr, m_grad_wk_layout}; + TensorND grad_bk{wksp_bundle.get_workspace(35).raw_ptr, m_grad_bk_layout}; + m_matmul_opr->param().transposeA = false; + m_matmul_opr->param().transposeB = true; + matmul_exec( + m_matmul_opr, grad_k, kweight, dkeys, wksp_bundle.get_workspace(36)); + m_bmatmul_opr->param().transposeA = true; + m_bmatmul_opr->param().transposeB = false; + m_bmatmul_opr->exec(keys, grad_k, grad_wk, wksp_bundle.get_workspace(37)); + std::swap(m_grad_wk_layout.shape[0], one); + TensorND dkweight{dqkvo_weight_bias.ptr() + m_wk_off, m_grad_wk_layout}; + std::swap(m_grad_wk_layout.shape[0], one); + m_reduce_opr->param().axis = 0; + m_reduce_opr->exec(grad_wk, dkweight, wksp_bundle.get_workspace(38)); + if (param.kbias) { + TensorND dkbias{dqkvo_weight_bias.ptr() + m_bk_off, m_bk_layout}; + m_reduce_opr->exec(grad_k, grad_bk, wksp_bundle.get_workspace(39)); + m_reduce_opr->param().axis = 1; + m_reduce_opr->exec(grad_bk, dkbias, wksp_bundle.get_workspace(40)); + } + } else { + m_reduce_opr->param().axis = 1; + grad_nk.layout = TensorLayout{ + {grad_nk.layout[0] / head, head, grad_nk.layout[1], grad_nk.layout[2]}, + grad_nk.layout.dtype}; + m_reduce_opr->exec( + grad_nk, + {dkeys.raw_ptr(), + {{dkeys.layout[0], 1, dkeys.layout[1], dkeys.layout[2]}, + dkeys.layout.dtype}}, + wksp_bundle.get_workspace(41)); + } + + // v = vin @ wv + bv + if (param.vproj_size) { + if (param.num_heads > 1) { + grad_v = TensorND{wksp_bundle.get_workspace(42).raw_ptr, m_grad_v_layout}; + m_relayout_opr->exec( + {grad_nv.raw_ptr(), from_multihead_layout(head, m_grad_nv_layout)}, + grad_v); + } else { + grad_v = grad_nv; + } + + TensorND vweight{qkvo_weight_bias.ptr() + m_wv_off, m_wv_layout}; + TensorND 
grad_wv{wksp_bundle.get_workspace(43).raw_ptr, m_grad_wv_layout}; + TensorND grad_bv{wksp_bundle.get_workspace(44).raw_ptr, m_grad_bv_layout}; + m_matmul_opr->param().transposeA = false; + m_matmul_opr->param().transposeB = true; + matmul_exec( + m_matmul_opr, grad_v, vweight, dvalues, wksp_bundle.get_workspace(45)); + m_bmatmul_opr->param().transposeA = true; + m_bmatmul_opr->param().transposeB = false; + m_bmatmul_opr->exec(values, grad_v, grad_wv, wksp_bundle.get_workspace(46)); + std::swap(m_grad_wv_layout.shape[0], one); + TensorND dvweight{dqkvo_weight_bias.ptr() + m_wv_off, m_grad_wv_layout}; + std::swap(m_grad_wv_layout.shape[0], one); + m_reduce_opr->param().axis = 0; + m_reduce_opr->exec(grad_wv, dvweight, wksp_bundle.get_workspace(47)); + if (param.vbias) { + TensorND dvbias{dqkvo_weight_bias.ptr() + m_bv_off, m_bv_layout}; + m_reduce_opr->exec(grad_v, grad_bv, wksp_bundle.get_workspace(48)); + m_reduce_opr->param().axis = 1; + m_reduce_opr->exec(grad_bv, dvbias, wksp_bundle.get_workspace(49)); + } + } else { + m_reduce_opr->param().axis = 1; + grad_nv.layout = TensorLayout{ + {grad_nv.layout[0] / head, head, grad_nv.layout[1], grad_nv.layout[2]}, + grad_nv.layout.dtype}; + m_reduce_opr->exec( + grad_nv, + {dvalues.raw_ptr(), + {{dvalues.layout[0], 1, dvalues.layout[1], dvalues.layout[2]}, + dvalues.layout.dtype}}, + wksp_bundle.get_workspace(50)); + } +} + +} // namespace multi_head_attn +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/common/multi_head_attn/proxy_backward_base.h b/dnn/src/common/multi_head_attn/proxy_backward_base.h new file mode 100644 index 0000000000000000000000000000000000000000..e6ca41f901302b09c23e375ebbcc4a2039e5240d --- /dev/null +++ b/dnn/src/common/multi_head_attn/proxy_backward_base.h @@ -0,0 +1,122 @@ +#pragma once +#include "megdnn/dtype.h" + +#include "megdnn/basic_types.h" +#include "megdnn/handle.h" +#include "megdnn/oprs/linalg.h" +#include "megdnn/oprs/nn.h" +#include "src/common/multi_head_attn/helper.h" +#include "src/common/utils.h" + +namespace megdnn { + +namespace multi_head_attn { + +struct MHABackwardProxyBase { + MHABackwardProxyBase() {} + virtual ~MHABackwardProxyBase() = default; + + /********************** function member **********************/ + template + void exec_internal(MHA_PROXY_BACKWARD_EXEC_PARAM); + void exec(MHA_PROXY_BACKWARD_EXEC_PARAM); + + // lambda +#define cb(DType) \ + virtual void move_scaler_to_device( \ + Handle* handle, DTypeTrait::ctype* dst, \ + DTypeTrait::ctype* src) = 0; + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb + + size_t get_workspace_in_bytes(MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM); + size_t get_mask_reservespace_in_bytes(MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM); + size_t get_othr_reservespace_in_bytes(MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM); + + void layout_refill(MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM); + bool layout_ismatch(MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM); + WorkspaceBundle get_mask_reservespace_bundle( + MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM, void* ptr = nullptr); + WorkspaceBundle get_othr_reservespace_bundle( + MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM, void* ptr = nullptr); + WorkspaceBundle get_workspace_bundle( + MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM, void* ptr = nullptr); + + /************************ data member ************************/ + std::unique_ptr m_matmul_opr; + std::unique_ptr m_bmatmul_opr; + std::unique_ptr m_add_opr; + std::unique_ptr m_elem_opr; + std::unique_ptr m_reduce_opr; + std::unique_ptr m_softmaxbw_opr; + std::unique_ptr m_dropout_opr; + 
std::unique_ptr m_dropoutbw_opr; + std::unique_ptr m_relayout_opr; + + // metadata + size_t m_sizeof_datatype; + megdnn::DTypeEnum m_datatype; + size_t m_wq_off, m_wk_off, m_wv_off, m_wo_off; + size_t m_bq_off, m_bk_off, m_bv_off, m_bo_off; + size_t m_head, m_embed_size, m_ksize, m_vsize, m_qproj_size, m_kproj_size, + m_vproj_size, m_oproj_size; + bool m_qbias, m_kbias, m_vbias, m_obias; + + // out = dropout(out) + TensorLayout m_mask2_layout; + TensorLayout m_grad_drop2_layout; + size_t m_grad_drop2_workspacesize; + TensorLayout m_grad_out_layout; + + // out = z @ wo + bo + TensorLayout m_wo_layout, m_bo_layout; + TensorLayout m_grad_z_layout, m_grad_wo_layout, m_grad_bo_layout; + size_t m_grad_z_workspacesize, m_grad_wo0_workspacesize, m_grad_wo1_workspacesize, + m_grad_bo0_workspacesize, m_grad_bo1_workspacesize; + + // z = nz + TensorLayout m_grad_nz_layout; + + // nz = ny @ nv + TensorLayout m_grad_nv_layout, m_grad_ny_layout; + size_t m_grad_nv_workspacesize, m_grad_ny_workspacesize; + + // ny = dropout(ny) + TensorLayout m_mask1_layout; + TensorLayout m_grad_drop1_layout; + size_t m_grad_drop1_workspacesize; + + // ny = softmax(nx) + TensorLayout m_grad_nx_layout; + size_t m_grad_nx_workspacesize; + + // nx = nq @ nk + TensorLayout m_grad_nq_layout, m_grad_nk_layout; + size_t m_grad_nq_workspacesize, m_grad_nk_workspacesize; + + // nq, nk, nv = q, k, v + TensorLayout m_grad_q_layout, m_grad_k_layout, m_grad_v_layout; + + // q = qin @ wq + bq + TensorLayout m_wq_layout, m_bq_layout; + TensorLayout m_grad_qin_layout, m_grad_wq_layout, m_grad_bq_layout; + size_t m_grad_qin_workspacesize, m_grad_wq0_workspacesize, m_grad_wq1_workspacesize, + m_grad_bq0_workspacesize, m_grad_bq1_workspacesize; + size_t m_grad_qin_reduce_workspacesize, m_grad_kin_reduce_workspacesize, + m_grad_vin_reduce_workspacesize; + + // k = kin @ wk + bk + TensorLayout m_wk_layout, m_bk_layout; + TensorLayout m_grad_kin_layout, m_grad_wk_layout, m_grad_bk_layout; + size_t m_grad_kin_workspacesize, m_grad_wk0_workspacesize, m_grad_wk1_workspacesize, + m_grad_bk0_workspacesize, m_grad_bk1_workspacesize; + + // v = vin @ wv + bv + TensorLayout m_wv_layout, m_bv_layout; + TensorLayout m_grad_vin_layout, m_grad_wv_layout, m_grad_bv_layout; + size_t m_grad_vin_workspacesize, m_grad_wv0_workspacesize, m_grad_wv1_workspacesize, + m_grad_bv0_workspacesize, m_grad_bv1_workspacesize; +}; + +} // namespace multi_head_attn +} // namespace megdnn diff --git a/dnn/src/common/multi_head_attn/proxy_forward_base.cpp b/dnn/src/common/multi_head_attn/proxy_forward_base.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2d7c4f38d55674d42a00af55e07ea0a4c6634c45 --- /dev/null +++ b/dnn/src/common/multi_head_attn/proxy_forward_base.cpp @@ -0,0 +1,587 @@ +#include "megdnn/basic_types.h" +#include "megdnn/dtype.h" +#include "megdnn/oprs.h" +#include "src/common/utils.cuh" + +#include "src/common/multi_head_attn/proxy_forward_base.h" +#include "src/common/utils.h" + +namespace megdnn { + +namespace multi_head_attn { + +bool MHAForwardProxyBase::layout_ismatch(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM) { + MEGDNN_MARK_USED_VAR(handle); + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(bias_k); + MEGDNN_MARK_USED_VAR(bias_v); + MEGDNN_MARK_USED_VAR(out); + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); + if (m_matmul_opr == nullptr or m_bmatmul_opr == nullptr or m_add_opr == nullptr or + m_elem_opr == 
nullptr or m_softmax_opr == nullptr or m_dropout_opr == nullptr or + m_relayout_opr == nullptr or m_repeat_opr == nullptr) { + megdnn_assert( + m_matmul_opr == nullptr and m_bmatmul_opr == nullptr and + m_add_opr == nullptr and m_elem_opr == nullptr and + m_softmax_opr == nullptr and m_dropout_opr == nullptr and + m_relayout_opr == nullptr and m_repeat_opr == nullptr, + "All the sub-opr are either not constructed or all constructed, but " + "now only a part is constructed."); + m_matmul_opr = handle->create_operator(); + m_bmatmul_opr = handle->create_operator(); + m_add_opr = handle->create_operator(); + m_elem_opr = handle->create_operator(); + m_softmax_opr = handle->create_operator(); + m_dropout_opr = handle->create_operator(); + m_relayout_opr = handle->create_operator(); + m_repeat_opr = handle->create_operator(); + } + + auto matmul_layout = [](const TensorLayout& A, const TensorLayout& B, + const TensorLayout& C, bool enable) -> bool { + if (!enable) { + return true; + } + // [A0, A1, A2]@[B0, B1] = [C0, C1, C2] + if (A[2] != B[0] || C[0] != A[0] || A[1] != C[1] || C[2] != B[1]) { + return false; + } + return true; + }; + + auto ndim_valid = [&](const Param& param) -> bool { + if (param.num_heads > 1 && param.training) { + return m_q_layout.ndim != 0 && m_k_layout.ndim != 0 && + m_v_layout.ndim != 0 && m_nq_layout.ndim != 0 && + m_nk_layout.ndim != 0 && m_nv_layout.ndim != 0 && + m_nx_layout.ndim != 0 && m_nz_layout.ndim != 0 && + m_z_layout.ndim != 0 && m_out_layout.ndim != 0 && + m_mask1_layout.ndim != 0 && m_mask2_layout.ndim != 0; + } else if (param.num_heads > 1 && !param.training) { + return m_q_layout.ndim != 0 && m_k_layout.ndim != 0 && + m_v_layout.ndim != 0 && m_nq_layout.ndim != 0 && + m_nk_layout.ndim != 0 && m_nv_layout.ndim != 0 && + m_nx_layout.ndim != 0 && m_nz_layout.ndim != 0 && + m_z_layout.ndim != 0 && m_out_layout.ndim != 0; + } else if (param.num_heads == 1 && param.training) { + return m_q_layout.ndim != 0 && m_k_layout.ndim != 0 && + m_v_layout.ndim != 0 && m_nx_layout.ndim != 0 && + m_z_layout.ndim != 0 && m_out_layout.ndim != 0 && + m_mask1_layout.ndim != 0 && m_mask2_layout.ndim != 0; + } else { + return m_q_layout.ndim != 0 && m_k_layout.ndim != 0 && + m_v_layout.ndim != 0 && m_nx_layout.ndim != 0 && + m_z_layout.ndim != 0 && m_out_layout.ndim != 0; + } + }; + + auto equal_metadata = [&](const Param& param) -> bool { + return m_heads == param.num_heads && m_embed_size == param.embeding_size && + m_ksize == param.k_size && m_vsize == param.v_size && + m_qproj_size == param.qproj_size && m_kproj_size == param.kproj_size && + m_vproj_size == param.vproj_size && m_oproj_size == param.oproj_size && + m_qbias == param.qbias && m_kbias == param.kbias && + m_vbias == param.vbias && m_obias == param.obias; + }; + + return equal_metadata(param) && ndim_valid(param) && + m_datatype == queries.dtype.enumv() && + matmul_layout(queries, m_wq_layout, m_q_layout, param.qproj_size != 0) && + matmul_layout(keys, m_wk_layout, m_k_layout, param.kproj_size != 0) && + matmul_layout(values, m_wv_layout, m_v_layout, param.vproj_size != 0); +} + +void MHAForwardProxyBase::layout_refill(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM) { + MEGDNN_MARK_USED_VAR(handle); + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(bias_k); + MEGDNN_MARK_USED_VAR(bias_v); + MEGDNN_MARK_USED_VAR(out); + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); + + m_heads = param.num_heads; + m_embed_size = param.embeding_size; + 
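// Shape-check sketch (illustrative, assumed sizes): the matmul_layout lambda in
// layout_ismatch above accepts, for example,
//     A = queries     = {2, 5, 8}     // [batch, seq, embeding_size]
//     B = m_wq_layout = {8, 16}       // [embeding_size, qproj_size]
//     C = m_q_layout  = {2, 5, 16}    // [batch, seq, qproj_size]
// since A[2] == B[0] and C == {A[0], A[1], B[1]}; this is exactly the shape that
// matmul_deduce_layout in helper.h produces for the q/k/v projections filled in
// later in this function.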
m_ksize = param.k_size; + m_vsize = param.v_size; + m_qproj_size = param.qproj_size; + m_kproj_size = param.kproj_size; + m_vproj_size = param.vproj_size; + m_oproj_size = param.oproj_size; + m_qbias = param.qbias; + m_kbias = param.kbias; + m_vbias = param.vbias; + m_obias = param.obias; + auto cal_type = qkvo_weight_bias.dtype; + TensorLayout placeholder_layout; + + auto reflash_dtype = [&](DType dtype) { + m_q_layout.dtype = dtype; + m_k_layout.dtype = dtype; + m_v_layout.dtype = dtype; + m_nq_layout.dtype = dtype; + m_nk_layout.dtype = dtype; + m_nv_layout.dtype = dtype; + m_nx_layout.dtype = dtype; + m_mask1_layout.dtype = dtype; + m_nz_layout.dtype = dtype; + m_z_layout.dtype = dtype; + m_out_layout.dtype = dtype; + m_mask2_layout.dtype = dtype; + }; + reflash_dtype(queries.dtype); + m_datatype = queries.dtype.enumv(); +#define cb(DType) \ + if (queries.dtype.enumv() == DTypeTrait::enumv) { \ + m_sizeof_datatype = sizeof(DTypeTrait::ctype); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb + + // proxy opr + m_matmul_opr->param().format = param::MatrixMul::Format::DEFAULT; + m_bmatmul_opr->param().format = param::MatrixMul::Format::DEFAULT; + m_softmax_opr->param().axis = -1; + m_dropout_opr->param().seed = param.seed; + + // wq/wk/wv/wo + m_wq_layout = TensorLayout{{m_embed_size, m_qproj_size}, cal_type}; + m_wk_layout = TensorLayout{{m_ksize, m_kproj_size}, cal_type}; + m_wv_layout = TensorLayout{{m_vsize, m_vproj_size}, cal_type}; + if (m_vproj_size > 0) { + m_wo_layout = TensorLayout{{m_vproj_size, m_oproj_size}, cal_type}; + } else { + m_wo_layout = TensorLayout{{m_vsize * m_heads, m_oproj_size}, cal_type}; + } + // bq/bk/bv/bo + m_bq_layout = TensorLayout{{m_qproj_size}, cal_type}; + m_bk_layout = TensorLayout{{m_kproj_size}, cal_type}; + m_bv_layout = TensorLayout{{m_vproj_size}, cal_type}; + m_bo_layout = TensorLayout{{m_oproj_size}, cal_type}; + + // wq/wk/wv/wo/bq/bk/bv/bo offset + size_t end = 0; + m_wq_off = 0, m_wk_off = 0, m_wv_off = 0, m_wo_off = 0; + m_bq_off = 0, m_bk_off = 0, m_bv_off = 0, m_bo_off = 0; + if (param.qproj_size) { + m_wq_off = end; + end += m_wq_layout.total_nr_elems(); + } + if (param.kproj_size) { + m_wk_off = end; + end += m_wk_layout.total_nr_elems(); + } + if (param.vproj_size) { + m_wv_off = end; + end += m_wv_layout.total_nr_elems(); + } + if (param.oproj_size) { + m_wo_off = end; + end += m_wo_layout.total_nr_elems(); + } + + if (param.qbias && param.qproj_size) { + m_bq_off = end; + end += m_bq_layout.total_nr_elems(); + } + if (param.kbias && param.kproj_size) { + m_bk_off = end; + end += m_bk_layout.total_nr_elems(); + } + if (param.vbias && param.vproj_size) { + m_bv_off = end; + end += m_bv_layout.total_nr_elems(); + } + if (param.obias && param.oproj_size) { + m_bo_off = end; + end += m_bo_layout.total_nr_elems(); + } + + // q/k/v, nq/nk/nv + auto head_repeat = [&](TensorLayout& m_q_layout, TensorLayout& m_nq_layout) { + m_repeat_opr->param().times = + TensorLayout({1, m_heads, 1, 1}, m_q_layout.dtype); + return m_repeat_opr->get_workspace_in_bytes( + {{m_q_layout[0], 1, m_q_layout[1], m_q_layout[2]}, m_q_layout.dtype}, + {{m_nq_layout[0] / m_heads, m_heads, m_nq_layout[1], m_nq_layout[2]}, + m_nq_layout.dtype}); + }; + m_matmul_opr->param().transposeA = false; + m_matmul_opr->param().transposeB = false; + if (param.qproj_size) { + matmul_deduce_layout(m_matmul_opr, queries, m_wq_layout, m_q_layout); + m_nq_layout = TensorLayout{ + {m_q_layout.shape[0] * m_heads, m_q_layout.shape[1], + m_q_layout.shape[2] / m_heads}, + 
m_q_layout.dtype}; + m_q_head_repeat_workspacesize = 0; + } else { + m_q_layout = queries; + m_nq_layout = TensorLayout{ + {m_q_layout[0] * m_heads, m_q_layout[1], m_q_layout[2]}, + m_q_layout.dtype}; + m_q_head_repeat_workspacesize = head_repeat(m_q_layout, m_nq_layout); + } + if (param.kproj_size) { + matmul_deduce_layout(m_matmul_opr, keys, m_wk_layout, m_k_layout); + m_nk_layout = TensorLayout{ + {m_k_layout.shape[0] * m_heads, m_k_layout.shape[1], + m_k_layout.shape[2] / m_heads}, + m_k_layout.dtype}; + m_k_head_repeat_workspacesize = 0; + } else { + m_k_layout = keys; + m_nk_layout = TensorLayout{ + {m_k_layout[0] * m_heads, m_k_layout[1], m_k_layout[2]}, + m_k_layout.dtype}; + m_k_head_repeat_workspacesize = head_repeat(m_k_layout, m_nk_layout); + } + if (param.vproj_size) { + matmul_deduce_layout(m_matmul_opr, values, m_wv_layout, m_v_layout); + m_nv_layout = TensorLayout{ + {m_v_layout.shape[0] * m_heads, m_v_layout.shape[1], + m_v_layout.shape[2] / m_heads}, + m_v_layout.dtype}; + m_v_head_repeat_workspacesize = 0; + } else { + m_v_layout = values; + m_nv_layout = TensorLayout{ + {m_v_layout[0] * m_heads, m_v_layout[1], m_v_layout[2]}, + m_v_layout.dtype}; + m_v_head_repeat_workspacesize = head_repeat(m_v_layout, m_nv_layout); + } + m_q_workspacesize = 0; + m_k_workspacesize = 0; + m_v_workspacesize = 0; + + // nx + m_bmatmul_opr->param().transposeA = false; + m_bmatmul_opr->param().transposeB = true; + m_dropout_opr->param().drop_prob = param.attn_prob; + m_bmatmul_opr->deduce_layout(m_nq_layout, m_nk_layout, m_nx_layout); + m_dropout_opr->deduce_layout(m_nx_layout, placeholder_layout, m_mask1_layout); + m_nx_workspacesize = m_bmatmul_opr->get_workspace_in_bytes( + m_nq_layout, m_nk_layout, m_nx_layout); + m_softmax_workspacesize = m_softmax_opr->get_workspace_in_bytes(m_nx_layout, {}); + m_dropout1_workspacesize = m_dropout_opr->get_workspace_in_bytes( + m_nx_layout, placeholder_layout, m_mask1_layout); + + // nz + m_bmatmul_opr->param().transposeA = false; + m_bmatmul_opr->param().transposeB = false; + m_bmatmul_opr->deduce_layout(m_nx_layout, m_nv_layout, m_nz_layout); + m_nz_workspacesize = m_bmatmul_opr->get_workspace_in_bytes( + m_nx_layout, m_nv_layout, m_nz_layout); + + // z + m_z_layout = TensorLayout{ + {m_nz_layout.shape[0] / m_heads, m_nz_layout.shape[1], + m_nz_layout.shape[2] * m_heads}, + m_nz_layout.dtype}; + + // out + m_dropout_opr->param().drop_prob = param.out_prob; + if (param.oproj_size) { + matmul_deduce_layout(m_matmul_opr, m_z_layout, m_wo_layout, m_out_layout); + } else { + m_out_layout = m_z_layout; + } + m_dropout_opr->deduce_layout(m_out_layout, placeholder_layout, m_mask2_layout); + m_out_workspacesize = 0; + m_dropout2_workspacesize = m_dropout_opr->get_workspace_in_bytes( + m_out_layout, placeholder_layout, m_mask2_layout); +} + +WorkspaceBundle MHAForwardProxyBase::get_mask_reservespace_bundle( + MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM, void* ptr) { + if (!layout_ismatch(MHA_PROXY_FORWARD_CALL)) { + layout_refill(MHA_PROXY_FORWARD_CALL); + } + return WorkspaceBundle( + ptr, + {param.training ? m_mask1_layout.span().dist_byte() : 0, + param.training ? m_mask2_layout.span().dist_byte() : 0}, + 4); +} + +WorkspaceBundle MHAForwardProxyBase::get_othr_reservespace_bundle( + MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM, void* ptr) { + if (!layout_ismatch(MHA_PROXY_FORWARD_CALL)) { + layout_refill(MHA_PROXY_FORWARD_CALL); + } + return WorkspaceBundle( + ptr, + {param.num_heads > 1 or param.qproj_size ? 
m_nq_layout.span().dist_byte() + : 0, + param.num_heads > 1 or param.kproj_size ? m_nk_layout.span().dist_byte() + : 0, + param.num_heads > 1 or param.vproj_size ? m_nv_layout.span().dist_byte() + : 0, + param.training ? m_nx_layout.span().dist_byte() : 0, + param.training or param.oproj_size ? m_z_layout.span().dist_byte() : 0}, + queries.dtype.size()); +} + +WorkspaceBundle MHAForwardProxyBase::get_workspace_bundle( + MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM, void* ptr) { + if (!layout_ismatch(MHA_PROXY_FORWARD_CALL)) { + layout_refill(MHA_PROXY_FORWARD_CALL); + } + return WorkspaceBundle( + ptr, + {param.num_heads > 1 and param.qproj_size ? m_q_layout.span().dist_byte() + : 0, + param.num_heads > 1 and param.qproj_size ? m_q_workspacesize : 0, + param.num_heads > 1 and param.kproj_size ? m_k_layout.span().dist_byte() + : 0, + param.num_heads > 1 and param.kproj_size ? m_k_workspacesize : 0, + param.num_heads > 1 and param.vproj_size ? m_v_layout.span().dist_byte() + : 0, + param.num_heads > 1 and param.vproj_size ? m_v_workspacesize : 0, + param.num_heads > 1 and !param.qproj_size ? m_q_head_repeat_workspacesize + : 0, + param.num_heads > 1 and !param.kproj_size ? m_k_head_repeat_workspacesize + : 0, + param.num_heads > 1 and !param.vproj_size ? m_v_head_repeat_workspacesize + : 0, + m_nx_layout.span().dist_byte(), m_nx_workspacesize, m_sizeof_datatype, + m_softmax_workspacesize, param.training ? m_dropout1_workspacesize : 0, + param.num_heads > 1 ? m_nz_layout.span().dist_byte() : 0, + m_nz_workspacesize, + (param.oproj_size and param.training) ? m_out_layout.span().dist_byte() + : 0, + param.oproj_size ? m_out_workspacesize : 0, + param.training ? m_dropout2_workspacesize : 0}); +} + +size_t MHAForwardProxyBase::get_mask_reservespace_in_bytes( + MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM) { + auto bundle = get_mask_reservespace_bundle(MHA_PROXY_FORWARD_CALL); + return bundle.total_size_in_bytes(); +} + +size_t MHAForwardProxyBase::get_othr_reservespace_in_bytes( + MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM) { + auto bundle = get_othr_reservespace_bundle(MHA_PROXY_FORWARD_CALL); + return bundle.total_size_in_bytes(); +} + +size_t MHAForwardProxyBase::get_workspace_in_bytes( + MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM) { + auto bundle = get_workspace_bundle(MHA_PROXY_FORWARD_CALL); + return bundle.total_size_in_bytes(); +} + +void MHAForwardProxyBase::deduce_layout(MHA_PROXY_FORWARD_LAYOUT_PARAM) { + if (!layout_ismatch(MHA_PROXY_FORWARD_CALL)) { + layout_refill(MHA_PROXY_FORWARD_CALL); + } + attn_weight = m_nx_layout; + out = m_out_layout; + size_t mask_size = get_mask_reservespace_in_bytes(MHA_PROXY_FORWARD_CALL); + size_t othr_size = get_othr_reservespace_in_bytes(MHA_PROXY_FORWARD_CALL); + mask_reservespace = TensorLayout{{mask_size}, dtype::Uint8()}; + othr_reservespace = TensorLayout{{othr_size / queries.dtype.size()}, queries.dtype}; +} + +void MHAForwardProxyBase::exec(MHA_PROXY_FORWARD_EXEC_PARAM) { +#define cb(DType) \ + if (queries.layout.dtype.enumv() == DTypeTrait::enumv) { \ + using ctype = typename DTypeTrait::ctype; \ + exec_internal(MHA_PROXY_FORWARD_CALL, workspace); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb +} + +template +void MHAForwardProxyBase::exec_internal(MHA_PROXY_FORWARD_EXEC_PARAM) { + auto wksp_bundle = get_workspace_bundle( + MHA_PROXY_FORWARD_TENSOR_TO_LAYOUT_CALL, workspace.raw_ptr); + auto mask_bundle = get_mask_reservespace_bundle( + MHA_PROXY_FORWARD_TENSOR_TO_LAYOUT_CALL, mask_reservespace.raw_ptr()); + auto othr_bundle = get_othr_reservespace_bundle( + 
MHA_PROXY_FORWARD_TENSOR_TO_LAYOUT_CALL, othr_reservespace.raw_ptr()); + + m_matmul_opr->param().transposeA = false; + m_matmul_opr->param().transposeB = false; + TensorND q, k, v; + if (param.qproj_size) { + if (param.num_heads == 1) { + q = TensorND{othr_bundle.get_workspace(0).raw_ptr, m_q_layout}; + } else { + q = TensorND{wksp_bundle.get_workspace(0).raw_ptr, m_q_layout}; + } + TensorND qweight{qkvo_weight_bias.ptr() + m_wq_off, m_wq_layout}; + matmul_exec(m_matmul_opr, queries, qweight, q, wksp_bundle.get_workspace(1)); + if (param.qbias) { + m_add_opr->exec(q, {qkvo_weight_bias.ptr() + m_bq_off, m_bq_layout}); + } + } else { + q = TensorND{queries.raw_ptr(), queries.layout}; + } + if (param.kproj_size) { + if (param.num_heads == 1) { + k = TensorND{othr_bundle.get_workspace(1).raw_ptr, m_k_layout}; + } else { + k = TensorND{wksp_bundle.get_workspace(2).raw_ptr, m_k_layout}; + } + TensorND kweight{qkvo_weight_bias.ptr() + m_wk_off, m_wk_layout}; + matmul_exec(m_matmul_opr, keys, kweight, k, wksp_bundle.get_workspace(3)); + if (param.kbias) { + m_add_opr->exec(k, {qkvo_weight_bias.ptr() + m_bk_off, m_bk_layout}); + } + } else { + k = TensorND{keys.raw_ptr(), keys.layout}; + } + if (param.vproj_size) { + if (param.num_heads == 1) { + v = TensorND{othr_bundle.get_workspace(2).raw_ptr, m_v_layout}; + } else { + v = TensorND{wksp_bundle.get_workspace(4).raw_ptr, m_v_layout}; + } + TensorND vweight{qkvo_weight_bias.ptr() + m_wv_off, m_wv_layout}; + matmul_exec(m_matmul_opr, values, vweight, v, wksp_bundle.get_workspace(5)); + if (param.vbias) { + m_add_opr->exec(v, {qkvo_weight_bias.ptr() + m_bv_off, m_bv_layout}); + } + } else { + v = TensorND{values.raw_ptr(), values.layout}; + } + + // nq/nk/nv: norm to multihead + auto relayout_to_multihead = [&](TensorND& q, TensorND& nq) { + size_t batch = q.layout[0]; + size_t seq = q.layout[1]; + size_t embeding_size = q.layout[2]; + TensorLayout nlayout{ + {batch, seq, m_heads, embeding_size / m_heads}, q.layout.dtype}; + nlayout = nlayout.dimshuffle({0, 2, 1, 3}); + m_relayout_opr->exec({q.raw_ptr(), nlayout}, nq); + }; + auto repeat_to_multihead = [&](TensorND& q, TensorND& nq, size_t idx) { + q.layout = TensorLayout( + {q.layout[0], 1, q.layout[1], q.layout[2]}, q.layout.dtype); + nq.layout = TensorLayout( + {nq.layout[0] / m_heads, m_heads, nq.layout[1], nq.layout[2]}, + nq.layout.dtype); + m_repeat_opr->param().times = TensorLayout({1, m_heads, 1, 1}, q.layout.dtype); + m_repeat_opr->exec(q, nq, wksp_bundle.get_workspace(idx)); + nq.layout = TensorLayout( + {nq.layout[0] * nq.layout[1], nq.layout[2], nq.layout[3]}, + nq.layout.dtype); + }; + TensorND nq = q, nk = k, nv = v; + if (param.num_heads > 1) { + nq = TensorND{othr_bundle.get_workspace(0).raw_ptr, m_nq_layout}; + if (param.qproj_size) { + relayout_to_multihead(q, nq); + } else { + repeat_to_multihead(q, nq, 6); + } + } + if (param.num_heads > 1) { + nk = TensorND{othr_bundle.get_workspace(1).raw_ptr, m_nk_layout}; + if (param.kproj_size) { + relayout_to_multihead(k, nk); + } else { + repeat_to_multihead(k, nk, 7); + } + } + if (param.num_heads > 1) { + nv = TensorND{othr_bundle.get_workspace(2).raw_ptr, m_nv_layout}; + if (param.vproj_size) { + relayout_to_multihead(v, nv); + } else { + repeat_to_multihead(v, nv, 8); + } + } + + // nx + TensorND nx{wksp_bundle.get_workspace(9).raw_ptr, m_nx_layout}; + TensorND ny{othr_bundle.get_workspace(3).raw_ptr, m_nx_layout}; + TensorND mask1{mask_bundle.get_workspace(0).raw_ptr, m_mask1_layout}; + m_bmatmul_opr->param().transposeA = false; + 
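The relayout_to_multihead and repeat_to_multihead lambdas above implement the head split: a projected [batch, seq, heads * dim] tensor is viewed as [batch, seq, heads, dim], dimshuffled to [batch, heads, seq, dim] and written out contiguously as [batch * heads, seq, dim]; when the corresponding projection is disabled, the single input head is instead repeated num_heads times. A plain row-major sketch of the same index mapping (illustrative only, hypothetical helper, no MegDNN types):

    #include <cstddef>
    #include <vector>

    // Split heads: `in` is [batch, seq, heads * dim] row-major; the result is
    // [batch * heads, seq, dim] with out[b*heads + h][s][d] = in[b][s][h*dim + d].
    static std::vector<float> split_heads(
            const std::vector<float>& in, std::size_t batch, std::size_t seq,
            std::size_t heads, std::size_t dim) {
        std::vector<float> out(batch * heads * seq * dim);
        for (std::size_t b = 0; b < batch; ++b)
            for (std::size_t s = 0; s < seq; ++s)
                for (std::size_t h = 0; h < heads; ++h)
                    for (std::size_t d = 0; d < dim; ++d)
                        out[((b * heads + h) * seq + s) * dim + d] =
                                in[(b * seq + s) * (heads * dim) + h * dim + d];
        return out;
    }

In the proxy the same permutation is expressed as a dimshuffled TensorLayout fed to the relayout operator, so no hand-written loop is needed on the device.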
    m_bmatmul_opr->param().transposeB = true;
+    m_bmatmul_opr->exec(nq, nk, nx, wksp_bundle.get_workspace(10));
+    // scale
+    auto d_scaler = wksp_bundle.get_workspace(11).ptr();
+    T param_scaler = static_cast<T>(param.sm_scaler);
+    move_scaler_to_device(handle, d_scaler, &param_scaler);
+    m_elem_opr->param().mode = Elemwise::Mode::MUL;
+    m_elem_opr->exec({nx, TensorND{d_scaler, {{1}, queries.layout.dtype}}}, nx);
+    // mask
+    if (param.attn_mask_type == MaskType::DEFAULT_MASK or
+        param.attn_mask_type == MaskType::USER_DEFINED_MASK) {
+        m_elem_opr->param().mode = Elemwise::Mode::ADD;
+        m_elem_opr->exec({nx, attn_mask}, nx);
+    }
+    if (param.training) {
+        // softmax
+        m_softmax_opr->exec(nx, ny, wksp_bundle.get_workspace(12));
+        // dropout
+        m_dropout_opr->param().drop_prob = param.attn_prob;
+        m_dropout_opr->exec(ny, attn_weight, mask1, wksp_bundle.get_workspace(13));
+    } else {
+        m_softmax_opr->exec(nx, attn_weight, wksp_bundle.get_workspace(12));
+    }
+    // nz
+    TensorND nz{wksp_bundle.get_workspace(14).raw_ptr, m_nz_layout};
+    TensorND z{othr_bundle.get_workspace(4).raw_ptr, m_z_layout};
+    m_bmatmul_opr->param().transposeA = false;
+    m_bmatmul_opr->param().transposeB = false;
+    if (param.num_heads > 1) {
+        m_bmatmul_opr->exec(attn_weight, nv, nz, wksp_bundle.get_workspace(15));
+        // z: multihead to norm
+        auto relayout_from_multihead = [&](const TensorND& nq, const TensorND& q) {
+            size_t batch = nq.layout[0];
+            size_t seq = nq.layout[1];
+            size_t embeding_size = nq.layout[2];
+            TensorLayout layout{
+                    {batch / m_heads, m_heads, seq, embeding_size}, nq.layout.dtype};
+            layout = layout.dimshuffle({0, 2, 1, 3});
+            m_relayout_opr->exec({nq.raw_ptr(), layout}, q);
+        };
+        if ((param.training == false) and (param.oproj_size == 0)) {
+            relayout_from_multihead(nz, out);
+        } else {
+            relayout_from_multihead(nz, z);
+        }
+    } else if ((param.training == false) and (param.oproj_size == 0)) {
+        m_bmatmul_opr->exec(attn_weight, nv, out, wksp_bundle.get_workspace(15));
+    } else {
+        m_bmatmul_opr->exec(attn_weight, nv, z, wksp_bundle.get_workspace(15));
+    }
+
+    // o
+    TensorND o;
+    TensorND mask2{mask_bundle.get_workspace(1).raw_ptr, m_mask2_layout};
+    m_matmul_opr->param().transposeA = false;
+    m_matmul_opr->param().transposeB = false;
+    if (param.oproj_size) {
+        if (param.training) {
+            o = TensorND{wksp_bundle.get_workspace(16).raw_ptr, m_out_layout};
+        } else {
+            o = out;
+        }
+        TensorND oweight{qkvo_weight_bias.ptr() + m_wo_off, m_wo_layout};
+        matmul_exec(m_matmul_opr, z, oweight, o, wksp_bundle.get_workspace(17));
+        if (param.obias) {
+            m_add_opr->exec(o, {qkvo_weight_bias.ptr() + m_bo_off, m_bo_layout});
+        }
+    } else {
+        o = TensorND{z.raw_ptr(), m_z_layout};
+    }
+    if (param.training) {
+        m_dropout_opr->param().drop_prob = param.out_prob;
+        m_dropout_opr->exec(o, out, mask2, wksp_bundle.get_workspace(18));
+    }
+}
+} // namespace multi_head_attn
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/common/multi_head_attn/proxy_forward_base.h b/dnn/src/common/multi_head_attn/proxy_forward_base.h
new file mode 100644
index 0000000000000000000000000000000000000000..ba4fccc913574e25f474be2594f8a67d5d3bbd08
--- /dev/null
+++ b/dnn/src/common/multi_head_attn/proxy_forward_base.h
@@ -0,0 +1,96 @@
+#pragma once
+#include "megdnn/dtype.h"
+
+#include "megdnn/basic_types.h"
+#include "megdnn/handle.h"
+#include "megdnn/oprs/linalg.h"
+#include "megdnn/oprs/nn.h"
+#include "src/common/multi_head_attn/helper.h"
+#include "src/common/utils.h"
+
+namespace megdnn {
+
+namespace multi_head_attn {
+
+struct
MHAForwardProxyBase { + MHAForwardProxyBase() {} + virtual ~MHAForwardProxyBase() = default; + + /********************** function member **********************/ + template + void exec_internal(MHA_PROXY_FORWARD_EXEC_PARAM); + void exec(MHA_PROXY_FORWARD_EXEC_PARAM); + + // lambda +#define cb(DType) \ + virtual void move_scaler_to_device( \ + Handle* handle, DTypeTrait::ctype* dst, \ + DTypeTrait::ctype* src) = 0; + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb + + void deduce_layout(MHA_PROXY_FORWARD_LAYOUT_PARAM); + size_t get_workspace_in_bytes(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM); + size_t get_mask_reservespace_in_bytes(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM); + size_t get_othr_reservespace_in_bytes(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM); + + void layout_refill(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM); + bool layout_ismatch(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM); + WorkspaceBundle get_mask_reservespace_bundle( + MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM, void* ptr = nullptr); + WorkspaceBundle get_othr_reservespace_bundle( + MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM, void* ptr = nullptr); + WorkspaceBundle get_workspace_bundle( + MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM, void* ptr = nullptr); + + /************************ data member ************************/ + std::unique_ptr m_matmul_opr; + std::unique_ptr m_bmatmul_opr; + std::unique_ptr m_add_opr; + std::unique_ptr m_elem_opr; + std::unique_ptr m_softmax_opr; + std::unique_ptr m_dropout_opr; + std::unique_ptr m_relayout_opr; + std::unique_ptr m_repeat_opr; + + // metadata + size_t m_sizeof_datatype; + megdnn::DTypeEnum m_datatype; + size_t m_wq_off, m_wk_off, m_wv_off, m_wo_off; + size_t m_bq_off, m_bk_off, m_bv_off, m_bo_off; + size_t m_heads, m_embed_size, m_ksize, m_vsize, m_qproj_size, m_kproj_size, + m_vproj_size, m_oproj_size; + bool m_qbias, m_kbias, m_vbias, m_obias; + + // q/k/v = matmul(qu/ky/va, wq/wk/wv, bq/bk/bv) + // nq/nk/nv = dimshuffle(q/k/v) (norm to multihead) + TensorLayout m_wq_layout, m_wk_layout, m_wv_layout; + TensorLayout m_bq_layout, m_bk_layout, m_bv_layout; + TensorLayout m_q_layout, m_k_layout, m_v_layout; + TensorLayout m_nq_layout, m_nk_layout, m_nv_layout; + size_t m_q_workspacesize, m_k_workspacesize, m_v_workspacesize; + size_t m_q_head_repeat_workspacesize, m_k_head_repeat_workspacesize, + m_v_head_repeat_workspacesize; + + // nx = matmul(nq, nk) + // ny = softmax(nx), ny_layout = m_nx_layout; + // ny = dropout(ny), dropout1_layout = m_nx_layout; + TensorLayout m_nx_layout; + TensorLayout m_mask1_layout; + size_t m_nx_workspacesize, m_softmax_workspacesize, m_dropout1_workspacesize; + + // nz = matmul(ny, v) + // z = dimshuffle(nz) (multihead to norm) + TensorLayout m_nz_layout; + TensorLayout m_z_layout; + size_t m_nz_workspacesize; + + // out = matmul(z, wo, bo) + // out = dropout(out), dropout2_layout = m_out_layout; + TensorLayout m_wo_layout, m_bo_layout; + TensorLayout m_out_layout; + TensorLayout m_mask2_layout; + size_t m_out_workspacesize, m_dropout2_workspacesize; +}; +} // namespace multi_head_attn +} // namespace megdnn diff --git a/dnn/src/cuda/multi_head_attn/cudnn_fwbw.cpp b/dnn/src/cuda/multi_head_attn/cudnn_fwbw.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2ab067f58583229d8ee1dc95ec669e7172a11897 --- /dev/null +++ b/dnn/src/cuda/multi_head_attn/cudnn_fwbw.cpp @@ -0,0 +1,398 @@ +#include "src/cuda/multi_head_attn/cudnn_fwbw.h" +#include +#include "megdnn/handle.h" +#include "src/cuda/utils.h" +#if CUDNN_VERSION >= 8004 +#include "megdnn/dtype.h" + +namespace megdnn { 
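Stepping back to the pipeline documented by the MHAForwardProxyBase members above (projection, head split, scaled q*k^T plus mask, softmax, dropout, attention times v, head merge, output projection), a single-head reference of the core attention step may help when reading exec_internal. This is an illustrative sketch only; it ignores batching, dropout, projections and biases, and is not MegDNN code:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Reference: x = sm_scaler * (q @ k^T) + mask; y = softmax(x); out = y @ v.
    // q: [sq, d], k: [sk, d], v: [sk, d], mask: [sq, sk] (0 or -inf), out: [sq, d].
    static std::vector<float> attention_one_head(
            const std::vector<float>& q, const std::vector<float>& k,
            const std::vector<float>& v, const std::vector<float>& mask,
            std::size_t sq, std::size_t sk, std::size_t d, float sm_scaler) {
        std::vector<float> score(sq * sk), out(sq * d, 0.f);
        for (std::size_t i = 0; i < sq; ++i) {
            for (std::size_t j = 0; j < sk; ++j) {
                float acc = 0.f;
                for (std::size_t t = 0; t < d; ++t)
                    acc += q[i * d + t] * k[j * d + t];
                score[i * sk + j] = sm_scaler * acc + mask[i * sk + j];
            }
            // numerically stable softmax over the key axis
            float mx = score[i * sk];
            for (std::size_t j = 1; j < sk; ++j)
                mx = std::max(mx, score[i * sk + j]);
            float denom = 0.f;
            for (std::size_t j = 0; j < sk; ++j) {
                score[i * sk + j] = std::exp(score[i * sk + j] - mx);
                denom += score[i * sk + j];
            }
            for (std::size_t j = 0; j < sk; ++j) {
                score[i * sk + j] /= denom;
                for (std::size_t t = 0; t < d; ++t)
                    out[i * d + t] += score[i * sk + j] * v[j * d + t];
            }
        }
        return out;
    }

The proxy realizes the same math with BatchedMatrixMul, Elemwise (MUL for the scale, ADD for the mask), Softmax and Dropout operators applied to [batch * heads, seq, dim] tensors.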
+namespace cuda { + +/***************************** AuxiliaryArray *****************************/ +AuxiliaryArray::~AuxiliaryArray() { + if (attnMaskType != MaskType::CUDNN_STYLE_MASK) { + if (devSeqQArray) { + cuda_check(cudaFree(devSeqQArray)); + } + if (devSeqKArray) { + cuda_check(cudaFree(devSeqKArray)); + } + } +} + +bool AuxiliaryArray::is_initialized( + const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK, + MaskType _attnMaskType) { + if (_batchSize != batchSize or _seqLenQ != seqLenQ or _seqLenK != seqLenK or + _attnMaskType != attnMaskType or (seqQArray.size() != _batchSize) or + (seqKArray.size() != _batchSize) or !devSeqQArray or !devSeqKArray or + (loWinIdx.size() != _seqLenQ) or (hiWinIdx.size() != _seqLenQ)) { + return false; + } + return true; +} + +void AuxiliaryArray::set_cudnn_style_mask(Handle* handle, const TensorND& attn_mask) { + megdnn_assert(attnMaskType == MaskType::CUDNN_STYLE_MASK); + auto stream = cuda_stream(handle); +#define T DTypeTrait<::megdnn::dtype::Int32>::ctype + devSeqQArray = attn_mask.ptr() + 2 * seqLenQ; + devSeqKArray = attn_mask.ptr() + 2 * seqLenQ + batchSize; + cuda_check(cudaMemcpyAsync( + seqQArray.data(), devSeqQArray, batchSize * sizeof(int), + cudaMemcpyDeviceToHost, stream)); + cuda_check(cudaMemcpyAsync( + seqKArray.data(), devSeqKArray, batchSize * sizeof(int), + cudaMemcpyDeviceToHost, stream)); + cuda_check(cudaMemcpyAsync( + loWinIdx.data(), attn_mask.ptr(), seqLenQ * sizeof(int), + cudaMemcpyDeviceToHost, stream)); + cuda_check(cudaMemcpyAsync( + hiWinIdx.data(), attn_mask.ptr() + seqLenQ, seqLenQ * sizeof(int), + cudaMemcpyDeviceToHost, stream)); +#undef T +} + +void AuxiliaryArray::set( + Handle* handle, const size_t _batchSize, const size_t _seqLenQ, + const size_t _seqLenK, MaskType _attnMaskType) { + if (_batchSize == batchSize && _seqLenQ == seqLenQ && _seqLenK == seqLenK && + _attnMaskType == attnMaskType) { + return; + } else { + if (attnMaskType != MaskType::CUDNN_STYLE_MASK) { + if (devSeqQArray) { + cuda_check(cudaFree(devSeqQArray)); + } + if (devSeqKArray) { + cuda_check(cudaFree(devSeqKArray)); + } + } + }; + seqLenQ = _seqLenQ; + seqLenK = _seqLenK; + batchSize = _batchSize; + attnMaskType = _attnMaskType; + loWinIdx.resize(seqLenQ); + hiWinIdx.resize(seqLenQ); + size_t seqQArraySize = 1 * batchSize; + size_t seqKArraySize = batchSize; + seqQArray.resize(seqQArraySize); + seqKArray.resize(seqKArraySize); + if (attnMaskType == MaskType::CUDNN_STYLE_MASK) { + return; + } + for (size_t i = 0; i < seqQArraySize; ++i) { + seqQArray[i] = seqLenQ; + } + for (size_t i = 0; i < seqKArraySize; ++i) { + seqKArray[i] = seqLenK; + } + + cuda_check(cudaMalloc((void**)&devSeqQArray, seqQArraySize * sizeof(int))); + cuda_check(cudaMalloc((void**)&devSeqKArray, seqKArraySize * sizeof(int))); + + auto stream = cuda_stream(handle); + cuda_check(cudaMemcpyAsync( + devSeqQArray, seqQArray.data(), seqQArraySize * sizeof(int), + cudaMemcpyHostToDevice, stream)); + cuda_check(cudaMemcpyAsync( + devSeqKArray, seqKArray.data(), seqKArraySize * sizeof(int), + cudaMemcpyHostToDevice, stream)); + + for (size_t i = 0; i < seqLenQ; ++i) { + loWinIdx[i] = 0; + if (attnMaskType == MaskType::DEFAULT_MASK) { + hiWinIdx[i] = i + 1; + } else if (attnMaskType == MaskType::NO_MASK) { + hiWinIdx[i] = seqLenK; + } + } +} + +/***************************** MultiHeadAttnStatus *****************************/ +void MultiHeadAttnStatus::set( + Handle* handle, const Param& p, const TensorLayout& q, const TensorLayout& k, + const TensorLayout& v) { 
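Regarding AuxiliaryArray::set_cudnn_style_mask above: with CUDNN_STYLE_MASK the attn_mask input is read back as a device Int32 buffer packed as [loWinIdx(seqLenQ) | hiWinIdx(seqLenQ) | seqQArray(batchSize) | seqKArray(batchSize)], which is exactly the pointer arithmetic used for devSeqQArray and devSeqKArray. A host-side sketch of how such a buffer could be assembled (hypothetical helper; the window values simply mirror the DEFAULT_MASK and NO_MASK branches of AuxiliaryArray::set, and the buffer still has to be copied to the device before being passed as attn_mask):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Layout expected back by set_cudnn_style_mask:
    // [loWinIdx(seq_q) | hiWinIdx(seq_q) | seqQArray(batch) | seqKArray(batch)].
    static std::vector<int32_t> pack_cudnn_style_mask(
            std::size_t batch, std::size_t seq_q, std::size_t seq_k, bool causal) {
        std::vector<int32_t> buf(2 * seq_q + 2 * batch);
        for (std::size_t i = 0; i < seq_q; ++i) {
            buf[i] = 0;  // loWinIdx
            buf[seq_q + i] =
                    causal ? int32_t(i + 1) : int32_t(seq_k);  // hiWinIdx
        }
        for (std::size_t b = 0; b < batch; ++b) {
            buf[2 * seq_q + b] = int32_t(seq_q);          // seqQArray
            buf[2 * seq_q + batch + b] = int32_t(seq_k);  // seqKArray
        }
        return buf;
    }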
+ // It is consistent with the conditions judged in is_initialized. + // dropout + float attn_prob = p.training ? p.attn_prob : 0.f; + float out_prob = p.training ? p.out_prob : 0.f; + if (!attn_dropout_status.initialized()) { + attn_dropout_status.set(cudnn_handle(handle), p.seed, attn_prob); + } + if (!out_dropout_status.initialized()) { + out_dropout_status.set(cudnn_handle(handle), p.seed, out_prob); + } + + if (attn_dropout_status.drop_prob != attn_prob) { + attn_dropout_status.drop_prob = attn_prob; + attn_dropout_status.restore_desc(cudnn_handle(handle)); + } + if (out_dropout_status.drop_prob != out_prob) { + out_dropout_status.drop_prob = out_prob; + out_dropout_status.restore_desc(cudnn_handle(handle)); + } + // size + batchSize = q.shape[0]; + seqLenQ = q.shape[1]; + seqLenK = k.shape[1]; + numHeads = p.num_heads; + qSize = p.embeding_size; + kSize = p.k_size; + vSize = p.v_size; + qProjSize = p.qproj_size / numHeads; + kProjSize = p.kproj_size / numHeads; + vProjSize = p.vproj_size / numHeads; + oProjSize = p.oproj_size; + attnMaskType = p.attn_mask_type; + bias = p.qbias or p.kbias or p.vbias or p.obias; + cudnnDataType_t cudnn_dtype = to_cudnn_dtype(q.dtype); + auto flag = CUDNN_ATTN_QUERYMAP_ONE_TO_ONE; + if (bias) { + flag = flag | CUDNN_ATTN_ENABLE_PROJ_BIASES; + } +#if CUDNN_VERSION < 8600 + cudnn_check(cudnnSetAttnDescriptor( + attn_desc, flag, numHeads, p.sm_scaler, cudnn_dtype, cudnn_dtype, + CUDNN_DEFAULT_MATH, attn_dropout_status.desc.desc, NULL, qSize, kSize, + vSize, qProjSize, kProjSize, vProjSize, oProjSize, seqLenQ, seqLenK, + batchSize, 1)); +#else + cudnn_check(cudnnSetAttnDescriptor( + attn_desc, flag, numHeads, p.sm_scaler, cudnn_dtype, cudnn_dtype, + CUDNN_DEFAULT_MATH, attn_dropout_status.desc.desc, + out_dropout_status.desc.desc, qSize, kSize, vSize, qProjSize, kProjSize, + vProjSize, oProjSize, seqLenQ, seqLenK, batchSize, 1)); +#endif + + // misc + auxArray.set(handle, batchSize, seqLenQ, seqLenK, p.attn_mask_type); + + if (p.training) { + cudnnGetMultiHeadAttnBuffers( + cudnn_handle(handle), attn_desc, &sizeWeights, &sizeWkspace, + &sizeReserve); + } else { + cudnnGetMultiHeadAttnBuffers( + cudnn_handle(handle), attn_desc, &sizeWeights, &sizeWkspace, NULL); + sizeReserve = 0; + } +} + +void MultiHeadAttnStatus::set_cudnn_style_mask( + Handle* handle, const TensorND& attn_mask) { + auxArray.set_cudnn_style_mask(handle, attn_mask); +} + +bool MultiHeadAttnStatus::is_initialized( + const Param& p, const TensorLayout& q, const TensorLayout& k, + const TensorLayout& v) { + // By default, the info of q, k and v must be consistent with the corresponding + // parameters in param respectively, otherwise an error will occur (so, check is not + // done here, mainly by check_exec). + // dropout + float attn_prob = p.training ? p.attn_prob : 0.f; + float out_prob = p.training ? 
p.out_prob : 0.f; + if (!attn_dropout_status.initialized() or !out_dropout_status.initialized() or + attn_dropout_status.drop_prob != attn_prob or + out_dropout_status.drop_prob != out_prob) { + return false; + } + // size + if (q.shape[0] != batchSize or q.shape[1] != seqLenQ or k.shape[1] != seqLenK or + q.shape[2] != qSize or k.shape[2] != kSize or v.shape[2] != vSize or + attnMaskType != p.attn_mask_type or numHeads != p.num_heads) { + return false; + } + bool pbias = p.qbias or p.kbias or p.vbias or p.obias; + if (qSize != p.embeding_size or kSize != p.k_size or vSize != p.v_size) { + return false; + } + if (bias != pbias) { + return false; + } + if ((qProjSize != (p.qproj_size / p.num_heads)) or + (kProjSize != (p.kproj_size / p.num_heads)) or + (vProjSize != (p.vproj_size / p.num_heads)) or (oProjSize != p.oproj_size)) { + return false; + } + // misc + if (!auxArray.is_initialized(batchSize, seqLenQ, seqLenK, attnMaskType)) { + return false; + } + if (p.training and sizeReserve == 0) { + return false; + } + return true; +} + +/***************************** MHA forward *****************************/ +void MHAForwardCudnnOpr::deduce_layout(MHA_PROXY_FORWARD_LAYOUT_PARAM) { + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(bias_k); + MEGDNN_MARK_USED_VAR(bias_v); + megdnn_assert( + queries.ndim == 3, + "queries.ndim should be 3[batch, sequence, embeding], but got %zu", + queries.ndim); + if (!desc_status.is_initialized(param, queries, keys, values)) { + desc_status.set(handle, param, queries, keys, values); + } + + attn_weight = TensorLayout( + TensorShape{ + queries.shape[0] * param.num_heads, queries.shape[1], + keys.shape[1]}, + queries.dtype); + size_t osize = param.oproj_size != 0 + ? param.oproj_size + : (param.vproj_size != 0 ? param.vproj_size + : (param.v_size * param.num_heads)); + out = TensorLayout( + TensorShape{queries.shape[0], queries.shape[1], osize}, queries.dtype); + mask_reservespace = TensorLayout(TensorShape{0}, dtype::Uint8()); + othr_reservespace = TensorLayout( + TensorShape{desc_status.sizeReserve / queries.dtype.size()}, queries.dtype); +} + +void MHAForwardCudnnOpr::exec(MHA_PROXY_FORWARD_EXEC_PARAM) { + if (!desc_status.is_initialized( + param, queries.layout, keys.layout, values.layout)) { + desc_status.set(handle, param, queries.layout, keys.layout, values.layout); + } + if (param.attn_mask_type == MaskType::CUDNN_STYLE_MASK) { + desc_status.set_cudnn_style_mask(handle, attn_mask); + } + + size_t osize = desc_status.oProjSize != 0 + ? desc_status.oProjSize + : (desc_status.vProjSize != 0 + ? 
desc_status.vProjSize * param.num_heads + : desc_status.vSize * param.num_heads); + SeqTensorDesc q{queries.layout, desc_status.batchSize, + desc_status.seqLenQ, desc_status.qSize, + param.input_order, desc_status.auxArray.seqQArray.data()}; + SeqTensorDesc o{out.layout, desc_status.batchSize, + desc_status.seqLenQ, osize, + param.input_order, desc_status.auxArray.seqQArray.data()}; + SeqTensorDesc k{keys.layout, desc_status.batchSize, + desc_status.seqLenK, desc_status.kSize, + param.input_order, desc_status.auxArray.seqKArray.data()}; + SeqTensorDesc v{values.layout, desc_status.batchSize, + desc_status.seqLenK, desc_status.vSize, + param.input_order, desc_status.auxArray.seqKArray.data()}; + + cudnn_check(cudnnMultiHeadAttnForward( + cudnn_handle(handle), desc_status.attn_desc, -1, + desc_status.auxArray.loWinIdx.data(), desc_status.auxArray.hiWinIdx.data(), + desc_status.auxArray.devSeqQArray, desc_status.auxArray.devSeqKArray, + q.desc, queries.raw_ptr(), param.reslink ? queries.raw_ptr() : NULL, k.desc, + keys.raw_ptr(), v.desc, values.raw_ptr(), o.desc, out.raw_ptr(), + desc_status.sizeWeights, + desc_status.sizeWeights > 0 ? qkvo_weight_bias.raw_ptr() : NULL, + desc_status.sizeWkspace, workspace.raw_ptr, + param.training ? desc_status.sizeReserve : 0, + param.training ? othr_reservespace.raw_ptr() : NULL)); +} + +size_t MHAForwardCudnnOpr::get_workspace_in_bytes( + MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM) { + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(out); + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); + + if (!desc_status.is_initialized(param, queries, keys, values)) { + desc_status.set(handle, param, queries, keys, values); + } + + return desc_status.sizeWkspace; +} + +size_t MHAForwardCudnnOpr::get_mask_reservespace_in_bytes( + MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM) { + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(attn_mask); + MEGDNN_MARK_USED_VAR(out); + MEGDNN_MARK_USED_VAR(attn_weight); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); + if (!desc_status.is_initialized(param, queries, keys, values)) { + desc_status.set(handle, param, queries, keys, values); + } + return 0; +} + +size_t MHAForwardCudnnOpr::get_othr_reservespace_in_bytes( + MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM) { + MEGDNN_MARK_USED_VAR(qkvo_weight_bias); + MEGDNN_MARK_USED_VAR(out); + MEGDNN_MARK_USED_VAR(mask_reservespace); + MEGDNN_MARK_USED_VAR(othr_reservespace); + if (!desc_status.is_initialized(param, queries, keys, values)) { + desc_status.set(handle, param, queries, keys, values); + } + + return desc_status.sizeReserve; +} + +/***************************** MHA backward *****************************/ +void MHABackwardCudnnOpr::exec(MHA_PROXY_BACKWARD_EXEC_PARAM) { + if (!desc_status.is_initialized( + param, queries.layout, keys.layout, values.layout)) { + desc_status.set(handle, param, queries.layout, keys.layout, values.layout); + } + if (param.attn_mask_type == MaskType::CUDNN_STYLE_MASK) { + desc_status.set_cudnn_style_mask(handle, attn_mask); + } + + size_t osize = desc_status.oProjSize != 0 + ? desc_status.oProjSize + : (desc_status.vProjSize != 0 + ? 
(desc_status.vProjSize * param.num_heads) + : (desc_status.vSize * param.num_heads)); + SeqTensorDesc q{queries.layout, desc_status.batchSize, + desc_status.seqLenQ, desc_status.qSize, + param.input_order, desc_status.auxArray.seqQArray.data()}; + SeqTensorDesc d{diff.layout, desc_status.batchSize, + desc_status.seqLenQ, osize, + param.input_order, desc_status.auxArray.seqQArray.data()}; + SeqTensorDesc k{keys.layout, desc_status.batchSize, + desc_status.seqLenK, desc_status.kSize, + param.input_order, desc_status.auxArray.seqKArray.data()}; + SeqTensorDesc v{values.layout, desc_status.batchSize, + desc_status.seqLenK, desc_status.vSize, + param.input_order, desc_status.auxArray.seqKArray.data()}; + + cudnn_check(cudnnMultiHeadAttnBackwardData( + cudnn_handle(handle), desc_status.attn_desc, + desc_status.auxArray.loWinIdx.data(), desc_status.auxArray.hiWinIdx.data(), + desc_status.auxArray.devSeqQArray, desc_status.auxArray.devSeqKArray, + d.desc, diff.raw_ptr(), q.desc, dqueries.raw_ptr(), queries.raw_ptr(), + k.desc, dkeys.raw_ptr(), keys.raw_ptr(), v.desc, dvalues.raw_ptr(), + values.raw_ptr(), desc_status.sizeWeights, + desc_status.sizeWeights > 0 ? qkvo_weight_bias.raw_ptr() : NULL, + desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve, + othr_reservespace.raw_ptr())); + + cuda_check(cudaMemset(dqkvo_weight_bias.raw_ptr(), 0, desc_status.sizeWeights)); +#if CUDNN_VERSION < 8600 + cuda_check(cudaDeviceSynchronize()); +#endif + cudnn_check(cudnnMultiHeadAttnBackwardWeights( + cudnn_handle(handle), desc_status.attn_desc, CUDNN_WGRAD_MODE_ADD, q.desc, + queries.raw_ptr(), k.desc, keys.raw_ptr(), v.desc, values.raw_ptr(), d.desc, + diff.raw_ptr(), desc_status.sizeWeights, + desc_status.sizeWeights > 0 ? qkvo_weight_bias.raw_ptr() : NULL, + desc_status.sizeWeights > 0 ? 
dqkvo_weight_bias.raw_ptr() : NULL, + desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve, + othr_reservespace.raw_ptr())); +} +} // namespace cuda +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/multi_head_attn/helper.h b/dnn/src/cuda/multi_head_attn/cudnn_fwbw.h similarity index 53% rename from dnn/src/cuda/multi_head_attn/helper.h rename to dnn/src/cuda/multi_head_attn/cudnn_fwbw.h index bdb37005c6c82519b349a3c277a702f9feedc851..688fa7090fe8a44832fea147352c745d269b44d9 100644 --- a/dnn/src/cuda/multi_head_attn/helper.h +++ b/dnn/src/cuda/multi_head_attn/cudnn_fwbw.h @@ -1,39 +1,46 @@ #pragma once +#include +#include "megdnn/handle.h" +#include "megdnn/thin/small_vector.h" #include "src/cuda/cudnn_wrapper.h" #if CUDNN_VERSION >= 8004 #include "megdnn/basic_types.h" #include "megdnn/oprs/nn.h" #include "src/common/algo_chooser.h" +#include "src/common/multi_head_attn/helper.h" #include "src/common/utils.h" #include "src/cuda/dropout/opr_impl.h" #include "src/cuda/handle.h" +using Param = megdnn::MultiHeadAttn::Param; +using MaskType = Param::AttnMaskType; +using InputType = Param::TensorCombinationType; + namespace megdnn { namespace cuda { struct AuxiliaryArray { public: - int* seqQArray = nullptr; - int* seqKArray = nullptr; + SmallVector seqQArray; + SmallVector seqKArray; int* devSeqQArray = nullptr; int* devSeqKArray = nullptr; - int* loWinIdx = nullptr; - int* hiWinIdx = nullptr; + SmallVector loWinIdx; + SmallVector hiWinIdx; size_t seqLenQ = 0; size_t seqLenK = 0; size_t batchSize = 0; - bool attnMask = 0; + MaskType attnMaskType = MaskType::NO_MASK; ~AuxiliaryArray(); void set( - const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK, - bool _attnMask); + Handle* handle, const size_t _batchSize, const size_t _seqLenQ, + const size_t _seqLenK, MaskType _attnMaskType); + void set_cudnn_style_mask(Handle* handle, const TensorND& attn_mask); bool is_initialized( const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK, - bool _attnMask); + MaskType _attnMaskType); }; -using Param = megdnn::MultiHeadAttn::Param; - class MultiHeadAttnStatus { DropoutStatus attn_dropout_status; DropoutStatus out_dropout_status; @@ -53,7 +60,8 @@ class MultiHeadAttnStatus { size_t kProjSize = 0; size_t vProjSize = 0; size_t oProjSize = 0; - bool attnMask = 0; + MaskType attnMaskType = MaskType::NO_MASK; + bool bias = false; size_t sizeWeights = 0; size_t sizeWkspace = 0; @@ -65,16 +73,44 @@ public: private: void set( - cudnnHandle_t handle, const Param& p, const TensorLayout& q, + Handle* handle, const Param& p, const TensorLayout& q, const TensorLayout& k, const TensorLayout& v); + void set_cudnn_style_mask(Handle* handle, const TensorND& attn_mask); bool is_initialized( const Param& p, const TensorLayout& q, const TensorLayout& k, const TensorLayout& v); friend class MultiHeadAttnBase; friend class MultiHeadAttnForwardImpl; friend class MultiHeadAttnBackwardImpl; + friend class MHAForwardCudnnOpr; + friend class MHABackwardCudnnOpr; +}; + +class MHAForwardCudnnOpr { +public: + MHAForwardCudnnOpr(){}; + + void exec(MHA_PROXY_FORWARD_EXEC_PARAM); + void deduce_layout(MHA_PROXY_FORWARD_LAYOUT_PARAM); + size_t get_workspace_in_bytes(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM); + size_t get_mask_reservespace_in_bytes(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM); + size_t get_othr_reservespace_in_bytes(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM); + +private: + MultiHeadAttnStatus desc_status; }; + +class MHABackwardCudnnOpr { +public: + 
MHABackwardCudnnOpr(){}; + + void exec(MHA_PROXY_BACKWARD_EXEC_PARAM); + +private: + MultiHeadAttnStatus desc_status; +}; + } // namespace cuda } // namespace megdnn #endif -// vim: syntax=cpp.doxygen + // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/multi_head_attn/helper.cpp b/dnn/src/cuda/multi_head_attn/helper.cpp deleted file mode 100644 index 03e7c8e90f0fe1fcd4aa19a9dd0e8b9bb42c467c..0000000000000000000000000000000000000000 --- a/dnn/src/cuda/multi_head_attn/helper.cpp +++ /dev/null @@ -1,185 +0,0 @@ -#include "src/cuda/multi_head_attn/helper.h" -#if CUDNN_VERSION >= 8004 - -namespace megdnn { -namespace cuda { - -AuxiliaryArray::~AuxiliaryArray() { - if (loWinIdx) - free(loWinIdx); - if (hiWinIdx) - free(hiWinIdx); - if (seqQArray) - free(seqQArray); - if (seqKArray) - free(seqKArray); - if (devSeqQArray) - cuda_check(cudaFree(devSeqQArray)); - if (devSeqKArray) - cuda_check(cudaFree(devSeqKArray)); -} - -bool AuxiliaryArray::is_initialized( - const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK, - bool _attnMask) { - if (_batchSize != batchSize or _seqLenQ != seqLenQ or _seqLenK != seqLenK or - _attnMask != attnMask or !seqQArray or !seqKArray or !devSeqQArray or - !devSeqKArray or !loWinIdx or !hiWinIdx) - return false; - return true; -} - -void AuxiliaryArray::set( - const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK, - bool _attnMask) { - if (_batchSize == batchSize && _seqLenQ == seqLenQ && _seqLenK == seqLenK && - _attnMask == attnMask) - return; - else { - if (loWinIdx) - free(loWinIdx); - if (hiWinIdx) - free(hiWinIdx); - if (seqQArray) - free(seqQArray); - if (seqKArray) - free(seqKArray); - if (devSeqQArray) - cuda_check(cudaFree(devSeqQArray)); - if (devSeqKArray) - cuda_check(cudaFree(devSeqKArray)); - }; - - seqLenQ = _seqLenQ; - seqLenK = _seqLenK; - batchSize = _batchSize; - attnMask = _attnMask; - size_t seqQArraySize = 1 * batchSize; - size_t seqKArraySize = batchSize; - seqQArray = (int*)calloc(seqQArraySize, sizeof(int)); - seqKArray = (int*)calloc(seqKArraySize, sizeof(int)); - for (size_t i = 0; i < seqQArraySize; ++i) - seqQArray[i] = seqLenQ; - for (size_t i = 0; i < seqKArraySize; ++i) - seqKArray[i] = seqLenK; - - cuda_check(cudaMalloc((void**)&devSeqQArray, seqQArraySize * sizeof(int))); - cuda_check(cudaMalloc((void**)&devSeqKArray, seqKArraySize * sizeof(int))); - - cuda_check(cudaMemcpy( - devSeqQArray, seqQArray, seqQArraySize * sizeof(int), - cudaMemcpyHostToDevice)); - cuda_check(cudaMemcpy( - devSeqKArray, seqKArray, seqKArraySize * sizeof(int), - cudaMemcpyHostToDevice)); - - loWinIdx = (int*)calloc(seqLenQ, sizeof(int)); - hiWinIdx = (int*)calloc(seqLenQ, sizeof(int)); - for (size_t i = 0; i < seqLenQ; ++i) { - loWinIdx[i] = 0; - if (attnMask) - hiWinIdx[i] = i + 1; - else - hiWinIdx[i] = seqLenK; - } -} - -void MultiHeadAttnStatus::set( - cudnnHandle_t handle, const Param& p, const TensorLayout& q, - const TensorLayout& k, const TensorLayout& v) { - float attn_prob = p.training ? p.attn_prob : 0.f; - float out_prob = p.training ? 
p.out_prob : 0.f; - if (!attn_dropout_status.initialized()) - attn_dropout_status.set(handle, p.seed, attn_prob); - if (!out_dropout_status.initialized()) - out_dropout_status.set(handle, p.seed, out_prob); - - if (attn_dropout_status.drop_prob != attn_prob) { - attn_dropout_status.drop_prob = attn_prob; - attn_dropout_status.restore_desc(handle); - } - if (out_dropout_status.drop_prob != out_prob) { - out_dropout_status.drop_prob = out_prob; - out_dropout_status.restore_desc(handle); - } - batchSize = q.shape[0]; - seqLenQ = q.shape[1]; - seqLenK = k.shape[1]; - qSize = q.shape[2]; - kSize = k.shape[2]; - vSize = v.shape[2]; - numHeads = p.num_heads; - qProjSize = p.qproj_size ? qSize / numHeads : 0; - kProjSize = p.kproj_size ? kSize / numHeads : 0; - vProjSize = p.vproj_size ? vSize / numHeads : 0; - oProjSize = p.oproj_size ? qSize : 0; - attnMask = p.attn_mask_type >= param::MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK; - cudnnDataType_t cudnn_dtype = to_cudnn_dtype(q.dtype); - auto flag = CUDNN_ATTN_QUERYMAP_ONE_TO_ONE; - if (p.qbias or p.kbias or p.vbias or p.obias) - flag = flag | CUDNN_ATTN_ENABLE_PROJ_BIASES; -#if CUDNN_VERSION < 8600 - // TODO: CUDNN_VERSION < 8600 and out dropout > 0.0, we need to go to the proxy cuda - // implementation. - cudnn_check(cudnnSetAttnDescriptor( - attn_desc, flag, numHeads, p.sm_scaler, cudnn_dtype, cudnn_dtype, - CUDNN_DEFAULT_MATH, attn_dropout_status.desc.desc, NULL, qSize, kSize, - vSize, qProjSize, kProjSize, vProjSize, oProjSize, seqLenQ, seqLenK, - batchSize, 1)); -#else - cudnn_check(cudnnSetAttnDescriptor( - attn_desc, flag, numHeads, p.sm_scaler, cudnn_dtype, cudnn_dtype, - CUDNN_DEFAULT_MATH, attn_dropout_status.desc.desc, - out_dropout_status.desc.desc, qSize, kSize, vSize, qProjSize, kProjSize, - vProjSize, oProjSize, seqLenQ, seqLenK, batchSize, 1)); -#endif - - auxArray.set( - batchSize, seqLenQ, seqLenK, - p.attn_mask_type >= param::MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK); - - if (p.training) - cudnnGetMultiHeadAttnBuffers( - handle, attn_desc, &sizeWeights, &sizeWkspace, &sizeReserve); - else { - cudnnGetMultiHeadAttnBuffers( - handle, attn_desc, &sizeWeights, &sizeWkspace, NULL); - sizeReserve = 0; - } -} - -bool MultiHeadAttnStatus::is_initialized( - const Param& p, const TensorLayout& q, const TensorLayout& k, - const TensorLayout& v) { - float attn_prob = p.training ? p.attn_prob : 0.f; - float out_prob = p.training ? 
p.out_prob : 0.f; - if (!attn_dropout_status.initialized() or !out_dropout_status.initialized() or - attn_dropout_status.drop_prob != attn_prob or - out_dropout_status.drop_prob != out_prob) - return false; - if (q.shape[0] != batchSize or q.shape[1] != seqLenQ or k.shape[1] != seqLenK or - q.shape[2] != qSize or k.shape[2] != kSize or v.shape[2] != vSize or - attnMask != (p.attn_mask_type >= - param::MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK) or - numHeads != p.num_heads) { - return false; - } - if ((p.qproj_size && (qProjSize == 0 or qProjSize != qSize / p.num_heads)) or - (p.kproj_size && (kProjSize == 0 or kProjSize != kSize / p.num_heads)) or - (p.vproj_size && (vProjSize == 0 or vProjSize != vSize / p.num_heads)) or - (p.oproj_size && (oProjSize == 0 or oProjSize != q.shape[2]))) - return false; - if ((!p.qproj_size && qProjSize != 0) or (!p.kproj_size && kProjSize != 0) or - (!p.vproj_size && vProjSize != 0) or (!p.oproj_size && oProjSize != 0)) - return false; - if (!auxArray.is_initialized(batchSize, seqLenQ, seqLenK, attnMask)) - return false; - if (p.training and sizeReserve == 0) - return false; - return true; -} - -} // namespace cuda -} // namespace megdnn -#endif -// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/multi_head_attn/opr_impl.cpp b/dnn/src/cuda/multi_head_attn/opr_impl.cpp index 3a628c1dc898c649980037287aecc822abfb71ba..99d4e73a72ed12950c29b072ab81b38bb742b862 100644 --- a/dnn/src/cuda/multi_head_attn/opr_impl.cpp +++ b/dnn/src/cuda/multi_head_attn/opr_impl.cpp @@ -6,349 +6,196 @@ namespace megdnn { namespace cuda { -void MultiHeadAttnForwardImpl::deduce_layout( - const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& qkvo_weight_bias, - const TensorLayout& attn_mask, const TensorLayout& bias_k, - const TensorLayout& bias_v, TensorLayout& out, TensorLayout& attn_weight, - TensorLayout& mask_reservespace, TensorLayout& othr_reservespace) { +bool can_use_mha_cudnn(const Param& param) { #if CUDNN_VERSION < 8004 - // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. - MEGDNN_MARK_USED_VAR(queries); - MEGDNN_MARK_USED_VAR(keys); - MEGDNN_MARK_USED_VAR(values); - MEGDNN_MARK_USED_VAR(qkvo_weight_bias); - MEGDNN_MARK_USED_VAR(attn_mask); - MEGDNN_MARK_USED_VAR(bias_k); - MEGDNN_MARK_USED_VAR(bias_v); - MEGDNN_MARK_USED_VAR(out); - MEGDNN_MARK_USED_VAR(attn_weight); - MEGDNN_MARK_USED_VAR(mask_reservespace); - MEGDNN_MARK_USED_VAR(othr_reservespace); - return; + MEGDNN_MARK_USED_VAR(param); + return false; #else - MEGDNN_MARK_USED_VAR(qkvo_weight_bias); - MEGDNN_MARK_USED_VAR(attn_mask); - auto p = param(); - megdnn_assert( - queries.ndim == 3, "queries.ndim should be 3, but got %zu", queries.ndim); - - if (!desc_status.is_initialized(p, queries, keys, values)) - desc_status.set(cudnn_handle(this->handle()), p, queries, keys, values); - - auto input_type = p.tensor_combination_type; - using INPUT_TYPE = Param::TENSOR_COMBINATION_TYPE; - bool have_biaskv = - input_type == INPUT_TYPE::ONLY_BIASKV or input_type == INPUT_TYPE::ALL; - size_t attn_seqk_dim_add = (have_biaskv ? 1 : 0) + (p.add_zero_attn ? 1 : 0); - attn_weight = TensorLayout( - TensorShape{ - queries.shape[0] * p.num_heads, queries.shape[1], - keys.shape[1] + attn_seqk_dim_add}, - queries.dtype); - - size_t osize = p.oproj_size != 0 ? p.oproj_size - : (p.vproj_size != 0 ? 
p.vproj_size : p.v_size); - out = TensorLayout( - TensorShape{queries.shape[0], queries.shape[1], osize}, queries.dtype); - mask_reservespace = TensorLayout(TensorShape{0}, dtype::Uint8()); - othr_reservespace = - TensorLayout(TensorShape{desc_status.sizeReserve}, queries.dtype); + bool flag = true; + size_t bias_num = 0; + size_t weight_num = 0; + bias_num += (param.qbias ? 1 : 0); + bias_num += (param.kbias ? 1 : 0); + bias_num += (param.vbias ? 1 : 0); + bias_num += (param.obias ? 1 : 0); + weight_num += (param.qproj_size > 0 ? 1 : 0); + weight_num += (param.kproj_size > 0 ? 1 : 0); + weight_num += (param.vproj_size > 0 ? 1 : 0); + weight_num += (param.oproj_size > 0 ? 1 : 0); + if (bias_num != weight_num && bias_num != 0) { + flag = false; + } +#if CUDNN_VERSION < 8600 + if (bias_num > 0 && param.training == true) { + flag = false; + } + if (param.out_prob > 0) { + flag = false; + } +#endif + if (param.need_weights) { + flag = false; + } + if (param.attn_mask_type == MaskType::USER_DEFINED_MASK) { + flag = false; + } + if (param.attn_mask_type == MaskType::CUDNN_STYLE_MASK) { + megdnn_assert( + flag == true, + "maybe_cudnn_style_mask=True, but can not run cudnn impl, Please make " + "sure that cuda is available, and check you parameter or do not use " + "cudnn style mask."); + } + return flag; +#endif +} +void MultiHeadAttnForwardImpl::deduce_layout(MHA_FORWARD_LAYOUT_PARAM) { + Param p = param(); +#if CUDNN_VERSION < 8004 + proxy_opr.deduce_layout( + this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask, + bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace); +#else + if (can_use_mha_cudnn(p)) { + cudnn_opr.deduce_layout( + this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask, + bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace); + } else { + proxy_opr.deduce_layout( + this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask, + bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace); + } #endif } size_t MultiHeadAttnForwardImpl::get_workspace_in_bytes( - const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& qkvo_weight_bias, - const TensorLayout& attn_mask, const TensorLayout& bias_k, - const TensorLayout& bias_v, const TensorLayout& out, - const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, - const TensorLayout& othr_reservespace) { + MHA_FORWARD_LAYOUT_CONST_PARAM) { + Param p = param(); #if CUDNN_VERSION < 8004 - // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. 
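For orientation, can_use_mha_cudnn above gates the cuDNN fast path: biases must be either absent or present for exactly the enabled projections, need_weights must be off, the mask must not be user-defined, and with cuDNN older than 8.6.0 neither biases in training mode nor a non-zero out_prob are accepted; any other combination falls back to the proxy implementation, while CUDNN_STYLE_MASK additionally asserts that the cuDNN path is actually usable. A simplified mirror of the cuDNN >= 8.6 part of that predicate (hypothetical name, illustration only):

    #include <cstddef>

    // Mirrors the shape of can_use_mha_cudnn for cuDNN >= 8.6: bias/projection
    // counts must be consistent, and need_weights or a user-defined mask forces
    // the proxy path. The real function adds extra restrictions for older cuDNN.
    static bool cudnn_path_usable_sketch(
            std::size_t bias_num, std::size_t weight_num, bool need_weights,
            bool user_defined_mask) {
        if (bias_num != weight_num && bias_num != 0)
            return false;
        return !need_weights && !user_defined_mask;
    }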
- MEGDNN_MARK_USED_VAR(queries); - MEGDNN_MARK_USED_VAR(keys); - MEGDNN_MARK_USED_VAR(values); - MEGDNN_MARK_USED_VAR(qkvo_weight_bias); - MEGDNN_MARK_USED_VAR(attn_mask); - MEGDNN_MARK_USED_VAR(bias_k); - MEGDNN_MARK_USED_VAR(bias_v); - MEGDNN_MARK_USED_VAR(out); - MEGDNN_MARK_USED_VAR(attn_weight); - MEGDNN_MARK_USED_VAR(mask_reservespace); - MEGDNN_MARK_USED_VAR(othr_reservespace); - return 0; + return proxy_opr.get_workspace_in_bytes( + this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask, + bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace); #else - MEGDNN_MARK_USED_VAR(qkvo_weight_bias); - MEGDNN_MARK_USED_VAR(attn_mask); - MEGDNN_MARK_USED_VAR(out); - MEGDNN_MARK_USED_VAR(attn_weight); - MEGDNN_MARK_USED_VAR(mask_reservespace); - MEGDNN_MARK_USED_VAR(othr_reservespace); - - if (!desc_status.is_initialized(param(), queries, keys, values)) - desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values); - - return desc_status.sizeWkspace; + if (can_use_mha_cudnn(p)) { + return cudnn_opr.get_workspace_in_bytes( + this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask, + bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace); + } else { + return proxy_opr.get_workspace_in_bytes( + this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask, + bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace); + } #endif } size_t MultiHeadAttnForwardImpl::get_mask_reservespace_in_bytes( - const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& qkvo_weight_bias, - const TensorLayout& attn_mask, const TensorLayout& bias_k, - const TensorLayout& bias_v, const TensorLayout& out, - const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, - const TensorLayout& othr_reservespace) { + MHA_FORWARD_LAYOUT_CONST_PARAM) { + Param p = param(); #if CUDNN_VERSION < 8004 - // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. 
- MEGDNN_MARK_USED_VAR(queries); - MEGDNN_MARK_USED_VAR(keys); - MEGDNN_MARK_USED_VAR(values); - MEGDNN_MARK_USED_VAR(qkvo_weight_bias); - MEGDNN_MARK_USED_VAR(attn_mask); - MEGDNN_MARK_USED_VAR(bias_k); - MEGDNN_MARK_USED_VAR(bias_v); - MEGDNN_MARK_USED_VAR(out); - MEGDNN_MARK_USED_VAR(attn_weight); - MEGDNN_MARK_USED_VAR(mask_reservespace); - MEGDNN_MARK_USED_VAR(othr_reservespace); - return 0; + return proxy_opr.get_mask_reservespace_in_bytes( + this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask, + bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace); #else - MEGDNN_MARK_USED_VAR(qkvo_weight_bias); - MEGDNN_MARK_USED_VAR(attn_mask); - MEGDNN_MARK_USED_VAR(out); - MEGDNN_MARK_USED_VAR(attn_weight); - MEGDNN_MARK_USED_VAR(mask_reservespace); - MEGDNN_MARK_USED_VAR(othr_reservespace); - if (!desc_status.is_initialized(param(), queries, keys, values)) - desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values); - return 0; + if (can_use_mha_cudnn(p)) { + return cudnn_opr.get_mask_reservespace_in_bytes( + this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask, + bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace); + } else { + return proxy_opr.get_mask_reservespace_in_bytes( + this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask, + bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace); + } #endif } + size_t MultiHeadAttnForwardImpl::get_othr_reservespace_in_bytes( - const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& qkvo_weight_bias, - const TensorLayout& attn_mask, const TensorLayout& bias_k, - const TensorLayout& bias_v, const TensorLayout& out, - const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, - const TensorLayout& othr_reservespace) { + MHA_FORWARD_LAYOUT_CONST_PARAM) { + Param p = param(); #if CUDNN_VERSION < 8004 - // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. 
- MEGDNN_MARK_USED_VAR(queries); - MEGDNN_MARK_USED_VAR(keys); - MEGDNN_MARK_USED_VAR(values); - MEGDNN_MARK_USED_VAR(qkvo_weight_bias); - MEGDNN_MARK_USED_VAR(attn_mask); - MEGDNN_MARK_USED_VAR(bias_k); - MEGDNN_MARK_USED_VAR(bias_v); - MEGDNN_MARK_USED_VAR(out); - MEGDNN_MARK_USED_VAR(attn_weight); - MEGDNN_MARK_USED_VAR(mask_reservespace); - MEGDNN_MARK_USED_VAR(othr_reservespace); - return 0; + return proxy_opr.get_othr_reservespace_in_bytes( + this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask, + bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace); #else - MEGDNN_MARK_USED_VAR(qkvo_weight_bias); - MEGDNN_MARK_USED_VAR(attn_mask); - MEGDNN_MARK_USED_VAR(bias_k); - MEGDNN_MARK_USED_VAR(bias_v); - MEGDNN_MARK_USED_VAR(out); - MEGDNN_MARK_USED_VAR(attn_weight); - MEGDNN_MARK_USED_VAR(mask_reservespace); - MEGDNN_MARK_USED_VAR(othr_reservespace); - if (!desc_status.is_initialized(param(), queries, keys, values)) - desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values); - return desc_status.sizeReserve; + if (can_use_mha_cudnn(p)) { + return cudnn_opr.get_othr_reservespace_in_bytes( + this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask, + bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace); + } else { + return proxy_opr.get_othr_reservespace_in_bytes( + this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask, + bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace); + } #endif } -void MultiHeadAttnForwardImpl::exec( - _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, - _megdnn_tensor_in qkvo_weight_bias, _megdnn_tensor_in attn_mask, - _megdnn_tensor_in bias_k, _megdnn_tensor_in bias_v, _megdnn_tensor_out out, - _megdnn_tensor_out attn_weight, _megdnn_tensor_out mask_reservespace, - _megdnn_tensor_out othr_reservespace, _megdnn_workspace workspace) { -#if CUDNN_VERSION < 8004 - // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. - MEGDNN_MARK_USED_VAR(queries); - MEGDNN_MARK_USED_VAR(keys); - MEGDNN_MARK_USED_VAR(values); - MEGDNN_MARK_USED_VAR(qkvo_weight_bias); - MEGDNN_MARK_USED_VAR(attn_mask); - MEGDNN_MARK_USED_VAR(bias_k); - MEGDNN_MARK_USED_VAR(bias_v); - MEGDNN_MARK_USED_VAR(out); - MEGDNN_MARK_USED_VAR(attn_weight); - MEGDNN_MARK_USED_VAR(mask_reservespace); - MEGDNN_MARK_USED_VAR(othr_reservespace); - megdnn_throw( - "The cudnn version is lower than 8.0.4. Please upgrade the cudnn version."); -#else +void MultiHeadAttnForwardImpl::exec(MHA_FORWARD_EXEC_PARAM) { check_exec( queries.layout, keys.layout, values.layout, qkvo_weight_bias.layout, attn_mask.layout, bias_k.layout, bias_v.layout, out.layout, attn_weight.layout, mask_reservespace.layout, othr_reservespace.layout, workspace.size); - auto p = param(); - - if (!desc_status.is_initialized(p, queries.layout, keys.layout, values.layout)) - desc_status.set( - cudnn_handle(this->handle()), p, queries.layout, keys.layout, - values.layout); - - size_t osize = - desc_status.oProjSize != 0 - ? desc_status.oProjSize - : (desc_status.vProjSize != 0 ? 
desc_status.vProjSize * p.num_heads - : desc_status.vSize); - SeqTensorDesc q{queries.layout, desc_status.batchSize, - desc_status.seqLenQ, desc_status.qSize, - p.input_order, desc_status.auxArray.seqQArray}; - SeqTensorDesc o{out.layout, desc_status.batchSize, desc_status.seqLenQ, - osize, p.input_order, desc_status.auxArray.seqQArray}; - SeqTensorDesc k{keys.layout, desc_status.batchSize, - desc_status.seqLenK, desc_status.kSize, - p.input_order, desc_status.auxArray.seqKArray}; - SeqTensorDesc v{values.layout, desc_status.batchSize, - desc_status.seqLenK, desc_status.vSize, - p.input_order, desc_status.auxArray.seqKArray}; - - cudnn_check(cudnnMultiHeadAttnForward( - cudnn_handle(this->handle()), desc_status.attn_desc, -1, - desc_status.auxArray.loWinIdx, desc_status.auxArray.hiWinIdx, - desc_status.auxArray.devSeqQArray, desc_status.auxArray.devSeqKArray, - q.desc, queries.raw_ptr(), p.reslink ? queries.raw_ptr() : NULL, k.desc, - keys.raw_ptr(), v.desc, values.raw_ptr(), o.desc, out.raw_ptr(), - desc_status.sizeWeights, - desc_status.sizeWeights > 0 ? qkvo_weight_bias.raw_ptr() : NULL, - desc_status.sizeWkspace, workspace.raw_ptr, - p.training ? desc_status.sizeReserve : 0, - p.training ? othr_reservespace.raw_ptr() : NULL)); -#endif -} - -void MultiHeadAttnBackwardImpl::exec( - _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, - _megdnn_tensor_in values, _megdnn_tensor_in qkvo_weight_bias, - _megdnn_tensor_in attn_mask, _megdnn_tensor_in attn_weight, - _megdnn_tensor_in mask_reservespace, _megdnn_tensor_in othr_reservespace, - _megdnn_tensor_out dqueries, _megdnn_tensor_out dkeys, - _megdnn_tensor_out dvalues, _megdnn_tensor_out dqkvo_weight_bias, - _megdnn_tensor_out dbias_k, _megdnn_tensor_out dbias_v, - _megdnn_workspace workspace) { + Param p = param(); #if CUDNN_VERSION < 8004 - // TODO: CUDNN_VERSION < 8004 and param().bias = true, we need to go to the proxy - // cuda implementation. - MEGDNN_MARK_USED_VAR(diff); - MEGDNN_MARK_USED_VAR(queries); - MEGDNN_MARK_USED_VAR(keys); - MEGDNN_MARK_USED_VAR(values); - MEGDNN_MARK_USED_VAR(qkvo_weight_bias); - MEGDNN_MARK_USED_VAR(attn_mask); - MEGDNN_MARK_USED_VAR(attn_weight); - MEGDNN_MARK_USED_VAR(mask_reservespace); - MEGDNN_MARK_USED_VAR(othr_reservespace); - MEGDNN_MARK_USED_VAR(dqueries); - MEGDNN_MARK_USED_VAR(dkeys); - MEGDNN_MARK_USED_VAR(dvalues); - MEGDNN_MARK_USED_VAR(dqkvo_weight_bias); - MEGDNN_MARK_USED_VAR(dbias_k); - MEGDNN_MARK_USED_VAR(dbias_v); - megdnn_throw( - "The cudnn version is lower than 8.0.4. 
Please upgrade the cudnn version."); + proxy_opr.exec( + this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask, + bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace, + workspace); #else -#if CUDNN_VERSION < 8600 - megdnn_assert( - !(param().qbias or param().kbias or param().vbias or param().obias), - "If the cudnn version is lower than 8.6.0, param().bias must be false, " - "but got true, because there is an error in the " - "dbias result during the backward calculation."); + if (can_use_mha_cudnn(p)) { + cudnn_opr.exec( + this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask, + bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace, + workspace); + } else { + proxy_opr.exec( + this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask, + bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace, + workspace); + } #endif - MEGDNN_MARK_USED_VAR(attn_mask); - MEGDNN_MARK_USED_VAR(dbias_k); - MEGDNN_MARK_USED_VAR(dbias_v); +} +void MultiHeadAttnBackwardImpl::exec(MHA_BACKWARD_EXEC_PARAM) { check_exec( diff.layout, queries.layout, keys.layout, values.layout, qkvo_weight_bias.layout, attn_mask.layout, attn_weight.layout, mask_reservespace.layout, othr_reservespace.layout, dqueries.layout, dkeys.layout, dvalues.layout, dqkvo_weight_bias.layout, dbias_k.layout, dbias_v.layout, workspace.size); - auto p = param(); - - if (!desc_status.is_initialized(p, queries.layout, keys.layout, values.layout)) - desc_status.set( - cudnn_handle(this->handle()), p, queries.layout, keys.layout, - values.layout); - - size_t osize = - desc_status.oProjSize != 0 - ? desc_status.oProjSize - : (desc_status.vProjSize != 0 ? desc_status.vProjSize * p.num_heads - : desc_status.vSize); - SeqTensorDesc q{queries.layout, desc_status.batchSize, - desc_status.seqLenQ, desc_status.qSize, - p.input_order, desc_status.auxArray.seqQArray}; - SeqTensorDesc d{diff.layout, desc_status.batchSize, desc_status.seqLenQ, - osize, p.input_order, desc_status.auxArray.seqQArray}; - SeqTensorDesc k{keys.layout, desc_status.batchSize, - desc_status.seqLenK, desc_status.kSize, - p.input_order, desc_status.auxArray.seqKArray}; - SeqTensorDesc v{values.layout, desc_status.batchSize, - desc_status.seqLenK, desc_status.vSize, - p.input_order, desc_status.auxArray.seqKArray}; - - cudnn_check(cudnnMultiHeadAttnBackwardData( - cudnn_handle(this->handle()), desc_status.attn_desc, - desc_status.auxArray.loWinIdx, desc_status.auxArray.hiWinIdx, - desc_status.auxArray.devSeqQArray, desc_status.auxArray.devSeqKArray, - d.desc, diff.raw_ptr(), q.desc, dqueries.raw_ptr(), queries.raw_ptr(), - k.desc, dkeys.raw_ptr(), keys.raw_ptr(), v.desc, dvalues.raw_ptr(), - values.raw_ptr(), desc_status.sizeWeights, - desc_status.sizeWeights > 0 ? qkvo_weight_bias.raw_ptr() : NULL, - desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve, - othr_reservespace.raw_ptr())); - - cuda_check(cudaMemset(dqkvo_weight_bias.raw_ptr(), 0, desc_status.sizeWeights)); -#if CUDNN_VERSION < 8600 - cuda_check(cudaDeviceSynchronize()); -#endif - cudnn_check(cudnnMultiHeadAttnBackwardWeights( - cudnn_handle(this->handle()), desc_status.attn_desc, CUDNN_WGRAD_MODE_ADD, - q.desc, queries.raw_ptr(), k.desc, keys.raw_ptr(), v.desc, values.raw_ptr(), - d.desc, diff.raw_ptr(), desc_status.sizeWeights, - desc_status.sizeWeights > 0 ? qkvo_weight_bias.raw_ptr() : NULL, - desc_status.sizeWeights > 0 ? 
dqkvo_weight_bias.raw_ptr() : NULL, - desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve, - othr_reservespace.raw_ptr())); + Param p = param(); +#if CUDNN_VERSION < 8004 + proxy_opr.exec( + this->handle(), p, diff, queries, keys, values, qkvo_weight_bias, attn_mask, + attn_weight, mask_reservespace, othr_reservespace, dqueries, dkeys, dvalues, + dqkvo_weight_bias, dbias_k, dbias_v, workspace); +#else + if (can_use_mha_cudnn(p)) { + cudnn_opr.exec( + this->handle(), p, diff, queries, keys, values, qkvo_weight_bias, + attn_mask, attn_weight, mask_reservespace, othr_reservespace, dqueries, + dkeys, dvalues, dqkvo_weight_bias, dbias_k, dbias_v, workspace); + } else { + proxy_opr.exec( + this->handle(), p, diff, queries, keys, values, qkvo_weight_bias, + attn_mask, attn_weight, mask_reservespace, othr_reservespace, dqueries, + dkeys, dvalues, dqkvo_weight_bias, dbias_k, dbias_v, workspace); + } #endif } + size_t MultiHeadAttnBackwardImpl::get_workspace_in_bytes( - const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& qkvo_weight_bias, - const TensorLayout& attn_mask, const TensorLayout& attn_weight, - const TensorLayout& mask_reservespace, const TensorLayout& othr_reservespace, - const TensorLayout& dqueries, const TensorLayout& dkeys, - const TensorLayout& dvalues, const TensorLayout& dqkvo_weight_bias, - const TensorLayout& dbias_k, const TensorLayout& dbias_v) { - MEGDNN_MARK_USED_VAR(diff); - MEGDNN_MARK_USED_VAR(queries); - MEGDNN_MARK_USED_VAR(keys); - MEGDNN_MARK_USED_VAR(values); - MEGDNN_MARK_USED_VAR(qkvo_weight_bias); - MEGDNN_MARK_USED_VAR(attn_mask); - MEGDNN_MARK_USED_VAR(attn_weight); - MEGDNN_MARK_USED_VAR(mask_reservespace); - MEGDNN_MARK_USED_VAR(othr_reservespace); - MEGDNN_MARK_USED_VAR(dqueries); - MEGDNN_MARK_USED_VAR(dkeys); - MEGDNN_MARK_USED_VAR(dvalues); - MEGDNN_MARK_USED_VAR(dqkvo_weight_bias); - MEGDNN_MARK_USED_VAR(dbias_k); - MEGDNN_MARK_USED_VAR(dbias_v); - return 0; + MHA_BACKWARD_LAYOUT_CONST_PARAM) { + Param p = param(); + if (can_use_mha_cudnn(p)) { + return 0; + } else { + return proxy_opr.get_workspace_in_bytes( + this->handle(), p, diff, queries, keys, values, qkvo_weight_bias, + attn_mask, attn_weight, mask_reservespace, othr_reservespace, dqueries, + dkeys, dvalues, dqkvo_weight_bias, dbias_k, dbias_v); + } } } // namespace cuda } // namespace megdnn -// vim: syntax=cpp.doxygen + // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/multi_head_attn/opr_impl.h b/dnn/src/cuda/multi_head_attn/opr_impl.h index b7aed9774c9d7f01b27500ca1f519d46701e220c..ac2cbac5fe794c7451710417f832ed71fc4a3dec 100644 --- a/dnn/src/cuda/multi_head_attn/opr_impl.h +++ b/dnn/src/cuda/multi_head_attn/opr_impl.h @@ -4,78 +4,45 @@ #include "src/common/reduce_helper.h" #include "src/cuda/cudnn_wrapper.h" #include "src/cuda/handle.h" -#include "src/cuda/multi_head_attn/helper.h" +#include "src/cuda/multi_head_attn/cudnn_fwbw.h" +#include "src/cuda/multi_head_attn/proxy_bw.h" +#include "src/cuda/multi_head_attn/proxy_fw.h" #include "src/cuda/utils.h" namespace megdnn { namespace cuda { +using Param = megdnn::MultiHeadAttn::Param; +using MaskType = Param::AttnMaskType; +using InputType = Param::TensorCombinationType; + +bool can_use_mha_cudnn(const Param& param); + class MultiHeadAttnForwardImpl final : public MultiHeadAttnForward { public: using MultiHeadAttnForward::MultiHeadAttnForward; #if CUDNN_VERSION >= 8004 - MultiHeadAttnStatus desc_status; + MHAForwardCudnnOpr cudnn_opr; #endif + 
MHAForwardProxyOpr proxy_opr; - void exec( - _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, - _megdnn_tensor_in qkvo_weight_bias, _megdnn_tensor_in attn_mask, - _megdnn_tensor_in bias_k, _megdnn_tensor_in bias_v, _megdnn_tensor_out out, - _megdnn_tensor_out attn_weight, _megdnn_tensor_out mask_reservespace, - _megdnn_tensor_out othr_reservespace, _megdnn_workspace workspace) override; - void deduce_layout( - const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& qkvo_weight_bias, - const TensorLayout& attn_mask, const TensorLayout& bias_k, - const TensorLayout& bias_v, TensorLayout& out, TensorLayout& attn_weight, - TensorLayout& mask_reservespace, TensorLayout& othr_reservespace) override; - size_t get_mask_reservespace_in_bytes( - const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& qkvo_weight_bias, - const TensorLayout& attn_mask, const TensorLayout& bias_k, - const TensorLayout& bias_v, const TensorLayout& out, - const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, - const TensorLayout& othr_reservespace) override; - size_t get_othr_reservespace_in_bytes( - const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& qkvo_weight_bias, - const TensorLayout& attn_mask, const TensorLayout& bias_k, - const TensorLayout& bias_v, const TensorLayout& out, - const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, - const TensorLayout& othr_reservespace) override; - size_t get_workspace_in_bytes( - const TensorLayout& queries, const TensorLayout& keys, - const TensorLayout& values, const TensorLayout& qkvo_weight_bias, - const TensorLayout& attn_mask, const TensorLayout& bias_k, - const TensorLayout& bias_v, const TensorLayout& out, - const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, - const TensorLayout& othr_reservespace) override; + void exec(MHA_FORWARD_EXEC_PARAM) override; + void deduce_layout(MHA_FORWARD_LAYOUT_PARAM) override; + size_t get_workspace_in_bytes(MHA_FORWARD_LAYOUT_CONST_PARAM) override; + size_t get_mask_reservespace_in_bytes(MHA_FORWARD_LAYOUT_CONST_PARAM) override; + size_t get_othr_reservespace_in_bytes(MHA_FORWARD_LAYOUT_CONST_PARAM) override; }; class MultiHeadAttnBackwardImpl final : public MultiHeadAttnBackward { public: using MultiHeadAttnBackward::MultiHeadAttnBackward; #if CUDNN_VERSION >= 8004 - MultiHeadAttnStatus desc_status; + MHABackwardCudnnOpr cudnn_opr; #endif - void exec( - _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, - _megdnn_tensor_in values, _megdnn_tensor_in qkvo_weight_bias, - _megdnn_tensor_in attn_mask, _megdnn_tensor_in attn_weight, - _megdnn_tensor_in mask_reservespace, _megdnn_tensor_in othr_reservespace, - _megdnn_tensor_out dqueries, _megdnn_tensor_out dkeys, - _megdnn_tensor_out dvalues, _megdnn_tensor_out dqkvo_weight_bias, - _megdnn_tensor_out dbias_k, _megdnn_tensor_out dbias_v, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes( - const TensorLayout& diff, const TensorLayout& queries, - const TensorLayout& keys, const TensorLayout& values, - const TensorLayout& qkvo_weight_bias, const TensorLayout& attn_mask, - const TensorLayout& attn_weight, const TensorLayout& mask_reservespace, - const TensorLayout& othr_reservespace, const TensorLayout& dqueries, - const TensorLayout& dkeys, const TensorLayout& dvalues, - const TensorLayout& dqkvo_weight_bias, const 
TensorLayout& dbias_k, - const TensorLayout& dbias_v) override; + MHABackwardProxyOpr proxy_opr; + + void exec(MHA_BACKWARD_EXEC_PARAM) override; + size_t get_workspace_in_bytes(MHA_BACKWARD_LAYOUT_CONST_PARAM) override; }; } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/multi_head_attn/proxy_bw.cpp b/dnn/src/cuda/multi_head_attn/proxy_bw.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1e546ee26879b41da49ea88f2e36d9f5db0e428a --- /dev/null +++ b/dnn/src/cuda/multi_head_attn/proxy_bw.cpp @@ -0,0 +1,23 @@ +#include "src/cuda/multi_head_attn/proxy_bw.h" +#include "megdnn/basic_types.h" +#include "megdnn/handle.h" +#include "megdnn/oprs/nn.h" + +namespace megdnn { +namespace cuda { + +#define cb(DType) \ + void MHABackwardProxyOpr::move_scaler_to_device( \ + Handle* handle, DTypeTrait::ctype* dst, \ + DTypeTrait::ctype* src) { \ + cudaMemcpyAsync( \ + dst, src, sizeof(DTypeTrait::ctype), cudaMemcpyHostToDevice, \ + cuda_stream(handle)); \ + }; +MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb + +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/multi_head_attn/proxy_bw.h b/dnn/src/cuda/multi_head_attn/proxy_bw.h new file mode 100644 index 0000000000000000000000000000000000000000..ab281b5171fc55e53b1a99335a662fa3bdd66117 --- /dev/null +++ b/dnn/src/cuda/multi_head_attn/proxy_bw.h @@ -0,0 +1,33 @@ +#pragma once +#include "megdnn/handle.h" +#include "megdnn/oprs.h" +#include "megdnn/oprs/general.h" +#include "megdnn/oprs/nn.h" +#include "src/common/multi_head_attn/proxy_backward_base.h" +#include "src/common/reduce_helper.h" +#include "src/cuda/cudnn_wrapper.h" +#include "src/cuda/handle.h" +#include "src/cuda/utils.h" + +namespace megdnn { +namespace cuda { + +using Param = megdnn::MultiHeadAttn::Param; +using MaskType = Param::AttnMaskType; +using InputType = Param::TensorCombinationType; +using multi_head_attn::matmul_deduce_layout; +using multi_head_attn::matmul_exec; + +class MHABackwardProxyOpr final : public multi_head_attn::MHABackwardProxyBase { +public: + MHABackwardProxyOpr() : MHABackwardProxyBase() {} + +#define cb(DType) \ + void move_scaler_to_device( \ + Handle*, DTypeTrait::ctype*, DTypeTrait::ctype*) override; + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb +}; +} // namespace cuda +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/multi_head_attn/proxy_fw.cpp b/dnn/src/cuda/multi_head_attn/proxy_fw.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d857460b82c5f1f301148be8b7e58ccb0860b6dc --- /dev/null +++ b/dnn/src/cuda/multi_head_attn/proxy_fw.cpp @@ -0,0 +1,22 @@ +#include "src/cuda/multi_head_attn/proxy_fw.h" +#include "megdnn/basic_types.h" +#include "megdnn/dtype.h" +#include "src/cuda/matrix_mul/opr_impl.h" + +namespace megdnn { +namespace cuda { + +#define cb(DType) \ + void MHAForwardProxyOpr::move_scaler_to_device( \ + Handle* handle, DTypeTrait::ctype* dst, \ + DTypeTrait::ctype* src) { \ + cudaMemcpyAsync( \ + dst, src, sizeof(DTypeTrait::ctype), cudaMemcpyHostToDevice, \ + cuda_stream(handle)); \ + }; +MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb + +} // namespace cuda +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/multi_head_attn/proxy_fw.h b/dnn/src/cuda/multi_head_attn/proxy_fw.h new file mode 100644 index 0000000000000000000000000000000000000000..ab0c35bb2edeb5bdbdb8b0afced3dd9ae914e02e --- /dev/null +++ b/dnn/src/cuda/multi_head_attn/proxy_fw.h @@ -0,0 +1,35 @@ +#pragma once +#include 
"megdnn/handle.h" +#include "megdnn/oprs.h" +#include "megdnn/oprs/general.h" +#include "src/common/multi_head_attn/helper.h" +#include "src/common/multi_head_attn/proxy_forward_base.h" +#include "src/common/reduce_helper.h" +#include "src/common/utils.h" +#include "src/cuda/cudnn_wrapper.h" +#include "src/cuda/handle.h" +#include "src/cuda/matrix_mul/opr_impl.h" +#include "src/cuda/utils.h" + +namespace megdnn { +namespace cuda { + +using Param = megdnn::MultiHeadAttn::Param; +using MaskType = Param::AttnMaskType; +using InputType = Param::TensorCombinationType; +using multi_head_attn::matmul_deduce_layout; +using multi_head_attn::matmul_exec; + +class MHAForwardProxyOpr final : public multi_head_attn::MHAForwardProxyBase { +public: + MHAForwardProxyOpr() : MHAForwardProxyBase() {} + +#define cb(DType) \ + void move_scaler_to_device( \ + Handle*, DTypeTrait::ctype*, DTypeTrait::ctype*) override; + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb +}; +} // namespace cuda +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index c3568702576840681f30994c923cd40998d7f446..2d5e44f9e7baf85c45d932898b00d103ac60c214 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -52,6 +52,7 @@ from .tensor import ( concat, expand_dims, ones, + repeat, reshape, squeeze, transpose, @@ -2071,9 +2072,9 @@ def _mha_shape_check( k_dim = key.ndim v_dim = value.ndim kpm_dim = key_padding_mask.ndim if key_padding_mask is not None else 0 - kpm_shape = key_padding_mask.shape if key_padding_mask is not None else None + kpm_shape = tuple(key_padding_mask.shape) if key_padding_mask is not None else None am_dim = attn_mask.ndim if attn_mask is not None else 0 - am_shape = attn_mask.shape if attn_mask is not None else None + am_shape = tuple(attn_mask.shape) if attn_mask is not None else None # Shape check. 
if q_dim == 3: # Batched Inputs @@ -2098,10 +2099,10 @@ def _mha_shape_check( "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D" f" but found {kpm_dim}-D tensor instead" ) - expected_shape0 = (k_shape0, k_shape1) # norm style - expected_shape1 = (2, k_shape0) # cudnn style - assert expected_shape0 == kpm_shape and expected_shape1 == kpm_shape, ( - f"For batched (3-D) `query`, expected `key_padding_mask.shape` equal {expected_shape0} or {expected_shape1}" + assert (kpm_shape[0] == k_shape0 and kpm_shape[1] == k_shape1) or ( + kpm_shape[0] == 2 and kpm_shape[1] == k_shape0 + ), ( + f"For batched (3-D) `query`, expected `key_padding_mask.shape` equal {k_shape0, k_shape1} or {2, k_shape0}" f" but found {kpm_shape} instead" ) if attn_mask is not None: @@ -2110,16 +2111,15 @@ def _mha_shape_check( f" but found {am_dim}-D tensor instead" ) if am_dim == 2: - expected_shape0 = (q_shape1, k_shape1) # norm style - expected_shape1 = (2, q_shape1) # cudnn style - assert ( - am_shape == expected_shape0 or am_shape == expected_shape1 - ), f"Expected `attn_mask` shape to be {expected_shape0} or {expected_shape1} but got {am_shape}" + assert (am_shape[0] == q_shape1 and am_shape[1] == k_shape1) or ( + am_shape[0] == 2 and am_shape[1] == q_shape1 + ), f"Expected `attn_mask` shape to be {q_shape1, k_shape1} or {2, q_shape1} but got {am_shape}" if am_dim == 3: - expected_shape = (q_shape0 * num_heads, q_shape1, k_shape1) assert ( - am_shape == expected_shape - ), f"Expected `attn_mask` shape to be {expected_shape0} but got {am_shape}" + am_shape[0] == q_shape0 * num_heads + and am_shape[1] == q_shape1 + and am_shape[2] == k_shape1 + ), f"Expected `attn_mask` shape to be {q_shape0 * num_heads, q_shape1, k_shape1} but got {am_shape}" elif q_dim == 2: # Unbatched Inputs is_batched = False @@ -2139,10 +2139,11 @@ def _mha_shape_check( "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None`, 1-D or 2-D" f" but found {kpm_dim}-D tensor instead" ) - expected_shape0 = k_shape0 # norm style - expected_shape1 = (2, 1) # cudnn style - assert expected_shape0 == kpm_shape or expected_shape1 == kpm_shape, ( - f"For batched (3-D) `query`, expected `key_padding_mask.shape` equal {expected_shape0} or {expected_shape1}" + assert (kpm_dim == 1 and kpm_shape[0] == k_shape0) or ( + kpm_dim == 2 and kpm_shape[0] == 2 + and kpm_shape[1] == 1 + ), ( + f"For unbatched (2-D) `query`, expected `key_padding_mask.shape` equal {k_shape0} or {2,1}" f" but found {kpm_shape} tensor instead" ) if attn_mask is not None: @@ -2151,16 +2152,15 @@ def _mha_shape_check( f" but found {am_dim}-D tensor instead" ) if am_dim == 2: - expected_shape0 = (q_shape0, k_shape0) # normal style mask - expected_shape1 = (2, q_shape0) # cudnn style mask - assert ( - am_shape == expected_shape0 or am_shape == expected_shape1 - ), f"Expected `attn_mask` shape to be {expected_shape0} or {expected_shape1} but got {am_shape}" + assert (am_shape[0] == q_shape0 and am_shape[1] == k_shape0) or ( + am_shape[0] == 2 and am_shape[1] == q_shape0 + ), f"Expected `attn_mask` shape to be {q_shape0, k_shape0} or {2, q_shape0} but got {am_shape}" if am_dim == 3: - expected_shape = (num_heads, q_shape0, k_shape0) assert ( - am_shape == expected_shape - ), f"Expected `attn_mask` shape to be {expected_shape} but got {am_shape}" + am_shape[0] == num_heads + and am_shape[1] == q_shape0 + and am_shape[2] == k_shape0 + ), f"Expected `attn_mask` shape to be {num_heads, q_shape0, k_shape0} but got {am_shape}" else: raise AssertionError( f"query should 
be unbatched 2D or batched 3D tensor but received {q_dim}-D query tensor" @@ -2176,8 +2176,9 @@ def _canonical_mask( other_name: str, target_type, check_other: bool = True, + maybe_cudnn_style_mask=False, ) -> Optional[Tensor]: - if mask is not None: + if mask is not None and not maybe_cudnn_style_mask: _mask_dtype = mask.dtype _mask_is_float = ( _mask_dtype == np.float16 @@ -2232,42 +2233,96 @@ def _merge_masks( """ mask_type = "no_mask" merged_mask = None + batch_size = query.shape[0] seq_qlen = query.shape[1] seq_klen = key.shape[1] - attn_mask_np = attn_mask.numpy() if attn_mask is not None else None + attn_mask_dim = attn_mask.ndim if attn_mask is not None else 0 + attn_mask_shape = attn_mask.shape if attn_mask is not None else 0 + key_padding_mask_shape = ( + key_padding_mask.shape if key_padding_mask is not None else 0 + ) # is_causal is used to hint whether to use a causal mask, where the upper right triangle is all -inf, # and the diagonal and lower left triangle are all 0. But if attn_mask is given, attn_mask is used first. - if is_causal and attn_mask is None and key_padding_mask is None: + if ( + not maybe_cudnn_style_mask + and is_causal + and attn_mask is None + and key_padding_mask is None + ): # At this point, merged_mask = None mask_type = "default_mask" - elif is_causal and attn_mask is not None and key_padding_mask is None: + elif ( + not maybe_cudnn_style_mask + and is_causal + and attn_mask is not None + and key_padding_mask is None + ): # At this point, merged_mask = attn_mask - default_mask_np = np.triu( - -float("inf") * np.ones((seq_qlen, seq_klen)), k=1 - ).astype("float32") - if (attn_mask_np == default_mask_np).all(): - mask_type = "default_mask" - else: - mask_type = "user_defined_mask" + mask_type = "default_mask" merged_mask = attn_mask + elif maybe_cudnn_style_mask and not add_zero_attn and not add_bias_kv: + # Note: only the mask shapes are checked here; the values in + # attn_mask and key_padding_mask are not validated. + assert ( + attn_mask is not None and key_padding_mask is not None + ), "if maybe_cudnn_style_mask is True, attn_mask and key_padding_mask must be given." 
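For orientation, here is a minimal sketch of the cudnn style mask pair that this branch expects and of the flattened buffer it is merged into, matching the shape checks above and the concat/reshape added in the next hunk. NumPy arrays stand in for the MegEngine tensors, and the sizes N, L, S are hypothetical:

    import numpy as np

    N, L, S = 4, 6, 8  # batch size, target length, source length (hypothetical)

    # cudnn style attn_mask, shape (2, L): row 0 holds the (inclusive) start index and
    # row 1 the end index of the attention window for each target position; the end
    # index must not exceed S.
    attn_mask = np.stack([np.zeros(L, dtype=np.int32), np.full(L, S, dtype=np.int32)])

    # cudnn style key_padding_mask, shape (2, N): row 0 holds the valid target length
    # (<= L) and row 1 the valid source length (<= S) of each sample.
    key_padding_mask = np.stack([np.full(N, L, dtype=np.int32), np.full(N, S, dtype=np.int32)])

    # _merge_masks flattens both tensors and concatenates them into one int32 vector of
    # length 2*L + 2*N, which is then handed to the MultiHeadAttn op as the mask tensor.
    merged = np.concatenate(
        [attn_mask.reshape(2 * L), key_padding_mask.reshape(2 * N)]
    ).astype(np.int32)
    assert merged.shape == (2 * L + 2 * N,)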
+ assert attn_mask_shape == (2, seq_qlen) and key_padding_mask_shape == ( + 2, + batch_size, + ) + merged_mask = concat( + ( + reshape(attn_mask, (2 * seq_qlen)), + reshape(key_padding_mask, (2 * batch_size)), + ) + ).astype("int32") + mask_type = "cudnn_style_mask" else: - if attn_mask is not None: + if attn_mask is not None and key_padding_mask is None: # At this point, merged_mask = attn_mask - default_mask_np = np.triu( - -float("inf") * np.ones((seq_qlen, seq_klen)), k=1 - ).astype("float32") - if ( - attn_mask_np == default_mask_np - and (attn_mask_np == default_mask_np).all() - ): - mask_type = "default_mask" - merged_mask = attn_mask - elif np.all(attn_mask_np == 0): - mask_type = "no_mask" + mask_type = "user_defined_mask" + merged_mask = attn_mask + elif key_padding_mask is not None and attn_mask is None: + mask_type = "user_defined_mask" + # At this point, merged_mask.ndim = 4 + key_padding_mask_expanded = reshape( + key_padding_mask, (batch_size, 1, 1, seq_klen) + ) + key_padding_mask_expanded = broadcast_to( + key_padding_mask_expanded, (None, num_heads, seq_qlen, None) + ) + merged_mask = key_padding_mask_expanded + merged_mask = reshape( + merged_mask, (batch_size * num_heads, seq_qlen, seq_klen) + ) + elif (attn_mask is not None) and (key_padding_mask is not None): + # At this point, merged_mask.ndim = 3 + mask_type = "user_defined_mask" + if attn_mask_dim == 2: + key_padding_mask_expanded = reshape( + key_padding_mask, (batch_size, 1, seq_klen) + ) + merged_mask = (attn_mask + key_padding_mask_expanded).reshape( + (batch_size, 1, seq_qlen, seq_klen) + ) + attn_mask_expanded = broadcast_to( + merged_mask, (None, num_heads, None, None) + ) + merged_mask = reshape( + attn_mask_expanded, (batch_size * num_heads, seq_qlen, seq_klen) + ) else: - mask_type = "user_defined_mask" - merged_mask = attn_mask + key_padding_mask_expanded = reshape( + key_padding_mask, (batch_size, 1, 1, seq_klen) + ) + attn_mask_expanded = reshape( + attn_mask, (batch_size, num_heads, seq_qlen, seq_klen) + ) + merged_mask = attn_mask_expanded + key_padding_mask_expanded + merged_mask = reshape( + merged_mask, (batch_size * num_heads, seq_qlen, seq_klen) + ) return merged_mask, mask_type @@ -2311,7 +2366,7 @@ def multi_head_attention( See :class:`~.module.MultiHeadAttn` for more details. - Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported, and if the cudnn version >=8.6.0, the calculation results are completely correct; If the cudnn version >=8.0.4 but <8.6.0, if there is a bias, only the dbias result calculated from the backward is incorrect. If there is no bias, the forward and backward calculations are correct; If the cudnn version is less than 8.0.4, this operator is not supported. + Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported. Args: query, key, value: map a query and a set of key-value pairs to an output. @@ -2337,24 +2392,19 @@ def multi_head_attention( Note: should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. key_padding_mask: if specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. 
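As a shape walkthrough for the non cudnn branches of `_merge_masks` above: an additive `key_padding_mask` of shape (N, S) is expanded, broadcast across the heads and folded into the per-head mask of shape (N * num_heads, L, S) that the operator consumes, optionally combined with a 2-D `attn_mask`. A rough NumPy equivalent, with hypothetical sizes:

    import numpy as np

    N, L, S, num_heads = 4, 6, 8, 2                       # hypothetical sizes
    attn_mask = np.zeros((L, S), dtype="float32")         # additive 2-D mask
    key_padding_mask = np.zeros((N, S), dtype="float32")  # additive padding mask

    # key_padding_mask only: (N, S) -> (N, 1, 1, S) -> (N, num_heads, L, S) -> fold heads
    kpm = np.broadcast_to(key_padding_mask.reshape(N, 1, 1, S), (N, num_heads, L, S))
    merged_kpm = kpm.reshape(N * num_heads, L, S)

    # 2-D attn_mask plus key_padding_mask: add them, then broadcast and fold the head axis
    both = (attn_mask + key_padding_mask.reshape(N, 1, S)).reshape(N, 1, L, S)
    merged_both = np.broadcast_to(both, (N, num_heads, L, S)).reshape(N * num_heads, L, S)

    assert merged_kpm.shape == merged_both.shape == (N * num_heads, L, S)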
Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. - Note: Should be set to None, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. - Note: User-defined mask not supported now, only support no mask or default mask, where the upper right triangle is all -inf, and the diagonal and lower left triangle are all 0. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. - need_weights: indicates whether to return the attention weight, which is the output result of softmax. Default: `True` - Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. + need_weights: indicates whether to return the attention weight, which is the output result of softmax. Default: `False` average_attn_weights: if true, indicates that the returned ``attn_weights`` should be averaged across heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an - effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) - Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. + effect when ``need_weights=True``. Default: ``False`` (i.e. average weights across heads) is_causal: if specified, applies a causal mask as attention mask. Default: ``False`` Warning: ``is_causal`` provides a hint that ``attn_mask`` is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. maybe_cudnn_style_mask: if specified, applies a cudnn style mask as attention mask. Default: ``False`` Note: In the cudnn style, the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`. Warning: like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is a cudnn style mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. In addition, if the ``_merge_masks`` function returns ``merge_type=cudnn_style_mask``, please ensure that other conditions are correct so that it can run the implementation of cudnn, otherwise an error will be reported. - Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that the underlying implementation only accepts two types of mask type, namely "no_mask" and "default_mask", and we may try to loosen this option after submitting the commit that users can pass in custom attention mask tensors. reslink: add input query to final output. 
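The patch also changes what the functional entry point returns (see the `need_weights` branch added near the end of `multi_head_attention` later in this diff): the raw attention weights come back per head as (N * num_heads, L, S), and `average_attn_weights=True` folds the head axis out and takes the mean. A minimal sketch of that reduction, with NumPy standing in for the returned tensor and hypothetical sizes:

    import numpy as np

    N, L, S, num_heads = 4, 6, 8, 2                    # hypothetical sizes
    attn_weight = np.random.rand(N * num_heads, L, S)  # stand-in for out[1]

    # (N*num_heads, L, S) -> (N, num_heads, L, S) -> mean over the head axis
    averaged = attn_weight.reshape(-1, num_heads, L, S).mean(axis=1)
    assert averaged.shape == (N, L, S)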
- Note: It is only valid if the input query is the same as the shape of the output. + Note: It is only valid if the shape of the input query is the same as the shape of the output. Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. training: will apply dropout if is ``True``. Outputs: @@ -2366,9 +2416,8 @@ def multi_head_attention( :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N * \text{num\_heads}, L, S)`. - Note: Now only None will be returned. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. - - **out[2]=mask_reversespace** - Used to save the dropout mask needed for backward propagation.`, - - **out[3]=othr_reversespace** - Used to save the intermediate results that need to be used in backward propagation.`, + - **out[2]=mask_reservespace** - Used to save the dropout mask needed for backward propagation. + - **out[3]=othr_reservespace** - Used to save the intermediate results that need to be used in backward propagation. """ qproj_size = embed_dim if qproj_size is None else qproj_size kproj_size = embed_dim if kproj_size is None else kproj_size @@ -2395,22 +2444,6 @@ def multi_head_attention( "add_zero_attn should be False, and configuration of this parameter is not supported now." + unsupport_reason ) - assert key_padding_mask is None, ( - "key_padding_mask should be None, and configuration of this parameter is not supported now." - + unsupport_reason - ) - assert need_weights == False, ( - "need_weights should be set to False, and configuration of this parameter is not supported now." - + unsupport_reason - ) - assert average_attn_weights == False, ( - "average_attn_weights should be set to False, and configuration of this parameter is not supported now." - + unsupport_reason - ) - assert maybe_cudnn_style_mask == False, ( - "maybe_cudnn_style_mask should be set to False, and configuration of this parameter is not supported now." - + unsupport_reason - ) assert bias_k is None, ( "bias_k should be None, and configuration of this parameter is not supported now." + unsupport_reason @@ -2419,7 +2452,11 @@ def multi_head_attention( "bias_v should be None, and configuration of this parameter is not supported now." + unsupport_reason ) - head_dim = (qproj_size if qproj_size != 0 else embed_dim) // num_heads + assert reslink is False, ( + "reslink should be False, and configuration of this parameter is not supported now." 
+ + unsupport_reason + ) + head_dim = embed_dim if qproj_size == 0 else embed_dim // num_heads smScaler = head_dim ** -0.5 k_size = key.shape[2] v_size = value.shape[2] @@ -2440,6 +2477,7 @@ def multi_head_attention( other_type=attn_mask, other_name="attn_mask", target_type=query.dtype, + maybe_cudnn_style_mask=maybe_cudnn_style_mask, ) attn_mask = _canonical_mask( mask=attn_mask, @@ -2448,6 +2486,7 @@ def multi_head_attention( other_name="", target_type=query.dtype, check_other=False, + maybe_cudnn_style_mask=maybe_cudnn_style_mask, ) attn_mask_tensor, attn_mask_type = _merge_masks( attn_mask=attn_mask, @@ -2511,8 +2550,15 @@ def multi_head_attention( out = apply( op, query, key, value, io_weight_bias, attn_mask_tensor, bias_k, bias_v ) - - return out[0], out[1] + if need_weights: + if average_attn_weights: + shape = out[1].shape + out_weight = out[1].reshape(-1, num_heads, shape[-2], shape[-1]) + return out[0], out_weight.mean(axis=1) + else: + return out[0], out[1] + else: + return out[0], None from .loss import * # isort:skip diff --git a/imperative/python/megengine/module/multiheadattn.py b/imperative/python/megengine/module/multiheadattn.py index cc26620a874682c7bbbec9dd1d32cf1250771718..85648081cea71f1d3dad556deac2ae72e50b1954 100644 --- a/imperative/python/megengine/module/multiheadattn.py +++ b/imperative/python/megengine/module/multiheadattn.py @@ -22,23 +22,22 @@ class MultiHeadAttention(Module): \text{MultiHeadAttn}\big(q,K,V, W_Q, W_V, W_O\big) = \sum^{nHeads-1}_{i=0}W_{O,i}h_i where :math:`h_i=W_{V,i}V \text{Softmax}\Big( \text{smScaler} \cdot K^TW^T_{K,i}W_{Q,i}q \Big),\text{for }i\text{ = 0 ... nHeads-1}`. - - Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported, and if the cudnn version >=8.6.0, the calculation results are completely correct; If the cudnn version >=8.0.4 but <8.6.0, if there is a bias, only the dbias result calculated from the backward is incorrect. If there is no bias, the forward and backward calculations are correct; If the cudnn version is less than 8.0.4, this operator is not supported. - - When the following conditions are met, you can go to the cudnn backend: - - ``cudnn version`` greater than or equal to 8.0.4 and ``bias`` is ``False`` and ``training`` is ``False`` - - ``cudnn version`` greater than or equal to 8.6.0 + Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported. + + The implementation of cudnn can be run if the following conditions are met: + + - cuda is available, cudnn is available, and the version of cudnn is greater than or equal to 8.0.4. + - ``bias`` is ``False``, when ``training`` is ``False`` and ``cudnn version`` greater than or equal to 8.0.4. if the ``cudnn version`` greater than or equal to 8.6.0, this point can be ignored. - ``add_bias_kv`` is ``False`` - ``add_zero_attn`` is ``False`` - ``need_weights`` is ``False`` - ``average_attn_weights`` is ``False`` - - ``maybe_cudnn_style_mask`` is ``True`` if support else ``False`` + - ``maybe_cudnn_style_mask`` is ``True`` - ``attn_mask`` and ``key_padding_mask`` is cudnn style mask, i.e. the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`. - The shape of attn_mask is :math:`(2, L)`, where :math:`(0, :)` elements specify the start index, :math:`(1, :)` elements specify the end index, the start index is inclusive, and the end index is not exclusive. The start index (i.e. 
elements in `attn_mask[0, x]`) must be less than the corresponding end index (i.e. elements in `attn_mask[1, x]`). The end index must be less than or equal to :math:`S`, where :math:`S` is the source sequence length, :math:`L` is the target sequence length. - The shape of key_padding_mask is :math:`(2, N)`, where :math:`(0, :)` elements specify the target sequence padding in cudnn style mask and the element must equal to or less than :math:`L`, :math:`(1, :)` elements specify the source sequence padding in cudnn style mask and the element must equal to or less than :math:`S`, where :math:`S` is the source sequence length, :math:`L` is the target sequence length. - - ``qbias``, ``kbias``, ``vbias`` and ``obias`` are equal - + Note: If there is no mask or the default mask is used, cudnn impl will also be used. At this time, cudnn will automatically generate the corresponding cudnn style mask. Args: embed_dim: Total dimension of the model. @@ -139,11 +138,9 @@ class MultiHeadAttention(Module): if self.add_bias_kv: xavier_uniform_(self.bias_k) - else: - self.bias_k = None - if self.add_bias_kv: xavier_uniform_(self.bias_v) else: + self.bias_k = None self.bias_v = None def forward( @@ -159,59 +156,36 @@ class MultiHeadAttention(Module): maybe_cudnn_style_mask: bool = False, ): r""" - Args: - query: Query embeddings of shape :math:`(N, L, E_q)`, - where :math:`N` is the batch size, :math:`L` is the target sequence length, and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against key-value pairs to produce the output. See "Attention Is All You Need" for more details. - key: Key embeddings of shape :math:`(N, S, E_k)`, - where :math:`N` is the batch size, :math:`S` is the source sequence length, and :math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details. - value: Value embeddings of shape :math:`(N, S, E_v)`, - where :math:`N` is the batch size, :math:`S` is the source sequence length, and :math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details. - key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` to ignore for the purpose of - attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. - Note: Should be set to None, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - Note: User-defined mask not supported now, only support no mask or default mask, where the upper right triangle is all -inf, and the diagonal and lower left triangle are all 0. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. - need_weights: indicates whether to return the attention weight, which is the output result of softmax. 
Default: `True` - Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across - heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an - effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) - Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. - is_causal: If specified, applies a causal mask as attention mask. Default: ``False`` - Warning: ``is_causal`` provides a hint that ``attn_mask`` is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. - maybe_cudnn_style_mask: if specified, applies a cudnn style mask as attention mask. Default: ``False`` - Note: In the cudnn style, the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`. - Warning: like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is a cudnn style mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. In addition, if the ``_merge_masks`` function returns ``merge_type=cudnn_style_mask``, please ensure that other conditions are correct so that it can run the implementation of cudnn, otherwise an error will be reported. - Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that the underlying implementation only accepts two types of mask type, namely "no_mask" and "default_mask", and we may try to loosen this option after submitting the commit that users can pass in custom attention mask tensors. - - Outputs: - - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, - where :math:`L` is the target sequence length, :math:`N` is - the batch size, and :math:`E` is the embedding dimension ``embed_dim``. - - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, - returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or - :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and - :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per - head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N * \text{num\_heads}, L, S)`. - Note: Now only None will be returned. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation. + Args: + query: Query embeddings of shape :math:`(N, L, E_q)`, + where :math:`N` is the batch size, :math:`L` is the target sequence length, and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against key-value pairs to produce the output. See "Attention Is All You Need" for more details. 
+ key: Key embeddings of shape :math:`(N, S, E_k)`, + where :math:`N` is the batch size, :math:`S` is the source sequence length, and :math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(N, S, E_v)`, + where :math:`N` is the batch size, :math:`S` is the source sequence length, and :math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` to ignore for the purpose of + attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + need_weights: indicates whether to return the attention weight, which is the output result of softmax. Default: `False` + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``False`` (i.e. not average weights across heads) + is_causal: If specified, applies a causal mask as attention mask. Default: ``False`` + Warning: ``is_causal`` provides a hint that ``attn_mask`` is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. + maybe_cudnn_style_mask: if specified, applies a cudnn style mask as attention mask. Default: ``False`` + Note: In the cudnn style, the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`. + Warning: like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is a cudnn style mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. In addition, if the ``_merge_masks`` function returns ``merge_type=cudnn_style_mask``, please ensure that other conditions are correct so that it can run the implementation of cudnn, otherwise an error will be reported. + + Outputs: + - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, + where :math:`L` is the target sequence length, :math:`N` is + the batch size, and :math:`E` is the embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N * \text{num\_heads}, L, S)`. """ - assert key_padding_mask is None, ( - "key_padding_mask should be None, and configuration of this parameter is not supported now." 
- + self.unsupport_reason - ) - assert need_weights == False, ( - "need_weights should be set to False, and configuration of this parameter is not supported now." - + self.unsupport_reason - ) - assert average_attn_weights == False, ( - "average_attn_weights should be set to False, and configuration of this parameter is not supported now." - + self.unsupport_reason - ) - assert maybe_cudnn_style_mask == False, ( - "maybe_cudnn_style_mask should be set to False, and configuration of this parameter is not supported now." - + self.unsupport_reason - ) return multi_head_attention( query, diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp index 6b62f3e9c985bece66fbe03c593f3a5682e1b0e8..189ce96d25af740fe0b2f953b1fc7c810bcd6516 100644 --- a/imperative/src/impl/ops/rng.cpp +++ b/imperative/src/impl/ops/rng.cpp @@ -613,25 +613,25 @@ std::tuple, bool> _infer_output_attrs SmallVector infer_output_attrs( const OpDef& op, const SmallVector& inputs) { - using INPUT_TYPE = opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE; + using InputType = opr::MultiHeadAttn::Param::TensorCombinationType; auto&& cn = inputs[0]->comp_node(); auto input_type = op.cast_final_safe().tensor_combination_type; std::tuple, bool> ret; TensorLayout empty_layout; - if (input_type == INPUT_TYPE::NONE) + if (input_type == InputType::NONE) ret = _infer_output_attrs( op, {inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(), inputs[3]->layout(), empty_layout, empty_layout, empty_layout}, cn); - else if (input_type == INPUT_TYPE::ONLY_MASK) + else if (input_type == InputType::ONLY_MASK) ret = _infer_output_attrs( op, {inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(), inputs[3]->layout(), inputs[4]->layout(), empty_layout, empty_layout}, cn); - else if (input_type == INPUT_TYPE::ONLY_BIASKV) + else if (input_type == InputType::ONLY_BIASKV) ret = _infer_output_attrs( op, {inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(), @@ -666,7 +666,7 @@ template <> SmallVector apply_on_physical_tensor( const OpDef& def, const SmallVector& inputs, SmallVector& output_descs, const bool& validated) { - using INPUT_TYPE = opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE; + using InputType = opr::MultiHeadAttn::Param::TensorCombinationType; SmallVector outputs; SmallVector desc = infer_output_attrs(def, inputs); @@ -705,7 +705,7 @@ SmallVector apply_on_physical_tensor( TensorLayout empty_layout; megdnn::TensorND empty_tensor; - if (input_type == INPUT_TYPE::ALL) { + if (input_type == InputType::ALL) { wk_size = dnn_op->get_workspace_in_bytes( inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(), inputs[3]->layout(), inputs[4]->layout(), inputs[5]->layout(), @@ -725,7 +725,7 @@ SmallVector apply_on_physical_tensor( outputs[1]->dev_tensor().as_megdnn(), outputs[2]->dev_tensor().as_megdnn(), outputs[3]->dev_tensor().as_megdnn(), dnn_wk); - } else if (input_type == INPUT_TYPE::ONLY_MASK) { + } else if (input_type == InputType::ONLY_MASK) { wk_size = dnn_op->get_workspace_in_bytes( inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(), inputs[3]->layout(), inputs[4]->layout(), empty_layout, empty_layout, @@ -743,7 +743,7 @@ SmallVector apply_on_physical_tensor( outputs[1]->dev_tensor().as_megdnn(), outputs[2]->dev_tensor().as_megdnn(), outputs[3]->dev_tensor().as_megdnn(), dnn_wk); - } else if (input_type == INPUT_TYPE::ONLY_BIASKV) { + } else if (input_type == InputType::ONLY_BIASKV) { wk_size = dnn_op->get_workspace_in_bytes( inputs[0]->layout(), inputs[1]->layout(), 
inputs[2]->layout(), inputs[3]->layout(), empty_layout, inputs[4]->layout(), @@ -801,13 +801,13 @@ template <> SymbolVarArray apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& rng = def.cast_final_safe(); - using INPUT_TYPE = opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE; + using InputType = opr::MultiHeadAttn::Param::TensorCombinationType; auto input_type = rng.tensor_combination_type; - if (input_type == INPUT_TYPE::ALL) { + if (input_type == InputType::ALL) { return _RNGOprMaker<7>::make(inputs, rng); - } else if (input_type == INPUT_TYPE::ONLY_BIASKV) { + } else if (input_type == InputType::ONLY_BIASKV) { return _RNGOprMaker<6>::make(inputs, rng); - } else if (input_type == INPUT_TYPE::ONLY_MASK) { + } else if (input_type == InputType::ONLY_MASK) { return _RNGOprMaker<5>::make(inputs, rng); } else { return _RNGOprMaker<4>::make(inputs, rng); @@ -884,25 +884,25 @@ std::tuple, bool> infer_output_attrs_fallible std::tuple, bool> infer_output_attrs_fallible< MultiHeadAttn>(const OpDef& op, const SmallVector& inputs) { - using INPUT_TYPE = opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE; + using InputType = opr::MultiHeadAttn::Param::TensorCombinationType; auto&& cn = inputs[0].comp_node; auto input_type = op.cast_final_safe().tensor_combination_type; std::tuple, bool> ret; TensorLayout empty_layout; - if (input_type == INPUT_TYPE::NONE) + if (input_type == InputType::NONE) ret = _infer_output_attrs( op, {inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout, empty_layout, empty_layout, empty_layout}, cn); - else if (input_type == INPUT_TYPE::ONLY_MASK) + else if (input_type == InputType::ONLY_MASK) ret = _infer_output_attrs( op, {inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout, inputs[4].layout, empty_layout, empty_layout}, cn); - else if (input_type == INPUT_TYPE::ONLY_BIASKV) + else if (input_type == InputType::ONLY_BIASKV) ret = _infer_output_attrs( op, {inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout, diff --git a/imperative/tablegen/generated/enum_macro.h b/imperative/tablegen/generated/enum_macro.h index b3fbfd2dc4938742fc1f10e9f2dfbeff39b0b2c1..97a874623d3daffbd1e7b8859d3f99bc49888375 100644 --- a/imperative/tablegen/generated/enum_macro.h +++ b/imperative/tablegen/generated/enum_macro.h @@ -20,8 +20,8 @@ cb(::megdnn::param::CvtColor::Mode); \ cb(::megdnn::param::Elemwise::Mode); \ cb(::megdnn::param::ElemwiseMultiType::Mode); \ - cb(::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE); \ - cb(::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE); \ + cb(::megdnn::param::MultiHeadAttn::AttnMaskType); \ + cb(::megdnn::param::MultiHeadAttn::TensorCombinationType); \ cb(::megdnn::param::Padding::PaddingMode); \ cb(::megdnn::param::RNNCell::NonlineMode); \ cb(::megdnn::param::ROIAlignV0::Mode); \ diff --git a/imperative/tablegen/generated/hash.txt b/imperative/tablegen/generated/hash.txt index 71eb57d1ce310f39603d6c48d5b9ef8e4d33b7ff..395aadc3e114d84d998d48849f05e926e570e8e1 100644 --- a/imperative/tablegen/generated/hash.txt +++ b/imperative/tablegen/generated/hash.txt @@ -1,7 +1,7 @@ -0a8cd3cd50cadfaae0478ee70621618e ../../dnn/scripts/opr_param_defs.py +20aa8ae7e128c1e24564ce68389307cc ../../dnn/scripts/opr_param_defs.py 9e9636d66694dd7d5a7853247a5406f9 ../../src/core/include/megbrain/ir/ops.td -2c15c869c1731d1bc5f25f9b132f4f08 generated/opdef.h.inl -0dabeee4b8f81be4c1809906b99795a5 generated/opdef.cpp.inl -be20faf18eccbc56f535b012170ed90a generated/opdef.py.inl 
-af9ab62fe962d409bb65e66af5f44a79 generated/opdef.cpy.inl -d468302f2d4b113913b76b5a181aae56 generated/enum_macro.h +e4489c2e1ea2b680d61c352842e56929 generated/opdef.h.inl +fd27534146a1cfcc791e40b2bb532076 generated/opdef.cpp.inl +6754eaa59ef19178eba41e99e418790c generated/opdef.py.inl +df66a3089aa6c12e5b1d943cd3d20e80 generated/opdef.cpy.inl +911001ef0dd771024919f7a1a3a009db generated/enum_macro.h diff --git a/imperative/tablegen/generated/opdef.cpp.inl b/imperative/tablegen/generated/opdef.cpp.inl index e44998ba4cdf4069e682b93d452c4f13047b10c3..09155086cad11f92cd6cccd7d8bec1db54a91296 100644 --- a/imperative/tablegen/generated/opdef.cpp.inl +++ b/imperative/tablegen/generated/opdef.cpp.inl @@ -5288,16 +5288,16 @@ std::vector> MultiHeadAttn_props_impl(const props_.emplace_back("sm_scaler", std::to_string(op_.sm_scaler)); props_.emplace_back("input_order", std::to_string(op_.input_order)); switch (op_.attn_mask_type){ - case MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK: + case MultiHeadAttn::AttnMaskType::NO_MASK: props_.emplace_back("attn_mask_type", "NO_MASK"); break; - case MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK: + case MultiHeadAttn::AttnMaskType::DEFAULT_MASK: props_.emplace_back("attn_mask_type", "DEFAULT_MASK"); break; - case MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK: + case MultiHeadAttn::AttnMaskType::CUDNN_STYLE_MASK: props_.emplace_back("attn_mask_type", "CUDNN_STYLE_MASK"); break; - case MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK: + case MultiHeadAttn::AttnMaskType::USER_DEFINED_MASK: props_.emplace_back("attn_mask_type", "USER_DEFINED_MASK"); break; default: @@ -5305,16 +5305,16 @@ std::vector> MultiHeadAttn_props_impl(const break; } switch (op_.tensor_combination_type){ - case MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE: + case MultiHeadAttn::TensorCombinationType::NONE: props_.emplace_back("tensor_combination_type", "NONE"); break; - case MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK: + case MultiHeadAttn::TensorCombinationType::ONLY_MASK: props_.emplace_back("tensor_combination_type", "ONLY_MASK"); break; - case MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV: + case MultiHeadAttn::TensorCombinationType::ONLY_BIASKV: props_.emplace_back("tensor_combination_type", "ONLY_BIASKV"); break; - case MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL: + case MultiHeadAttn::TensorCombinationType::ALL: props_.emplace_back("tensor_combination_type", "ALL"); break; default: diff --git a/imperative/tablegen/generated/opdef.cpy.inl b/imperative/tablegen/generated/opdef.cpy.inl index f10c3a482b67b804ac8f3280c5652375150a33bd..b79b2e82330fec426a4c8251e3826bf2bd4180d0 100644 --- a/imperative/tablegen/generated/opdef.cpy.inl +++ b/imperative/tablegen/generated/opdef.cpy.inl @@ -15043,39 +15043,39 @@ void _init_py_MeshIndexing(py::module m) { mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MeshIndexing::typeinfo(), &py_type).second); } -template<> struct EnumTrait { - static constexpr const char *name = "MultiHeadAttn.ATTN_MASK_TYPE"; - static constexpr std::underlying_type_t max = 4 - 1; +template<> struct EnumTrait { + static constexpr const char *name = "MultiHeadAttn.AttnMaskType"; + static constexpr std::underlying_type_t max = 4 - 1; }; -template<> PyTypeObject* EnumWrapper::type = nullptr; +template<> PyTypeObject* EnumWrapper::type = nullptr; template<> const char* -EnumWrapper::members[] = {"NO_MASK", "DEFAULT_MASK", "CUDNN_STYLE_MASK", "USER_DEFINED_MASK"}; +EnumWrapper::members[] = {"NO_MASK", "DEFAULT_MASK", "CUDNN_STYLE_MASK", "USER_DEFINED_MASK"}; -template<> 
std::unordered_map -EnumWrapper::mem2value = {{normalize_enum("NO_MASK"), MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK}, {normalize_enum("DEFAULT_MASK"), MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK}, {normalize_enum("CUDNN_STYLE_MASK"), MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK}, {normalize_enum("USER_DEFINED_MASK"), MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK}}; -template<> PyObject* EnumWrapper::pyobj_insts[4] = {nullptr}; +template<> std::unordered_map +EnumWrapper::mem2value = {{normalize_enum("NO_MASK"), MultiHeadAttn::AttnMaskType::NO_MASK}, {normalize_enum("DEFAULT_MASK"), MultiHeadAttn::AttnMaskType::DEFAULT_MASK}, {normalize_enum("CUDNN_STYLE_MASK"), MultiHeadAttn::AttnMaskType::CUDNN_STYLE_MASK}, {normalize_enum("USER_DEFINED_MASK"), MultiHeadAttn::AttnMaskType::USER_DEFINED_MASK}}; +template<> PyObject* EnumWrapper::pyobj_insts[4] = {nullptr}; -void _init_py_MultiHeadAttn_ATTN_MASK_TYPE(PyTypeObject& py_type) { - auto& e_type = EnumWrapper::type; +void _init_py_MultiHeadAttn_AttnMaskType(PyTypeObject& py_type) { + auto& e_type = EnumWrapper::type; static PyMethodDef tp_methods[] = { - {const_cast("dump"), (PyCFunction)EnumWrapper::py_dump, METH_NOARGS, NULL}, + {const_cast("dump"), (PyCFunction)EnumWrapper::py_dump, METH_NOARGS, NULL}, {NULL} /* Sentinel */ }; static PyType_Slot slots[] = { - {Py_tp_repr, (void*)EnumWrapper::py_repr}, - {Py_tp_richcompare, (void*)EnumWrapper::tp_richcompare}, + {Py_tp_repr, (void*)EnumWrapper::py_repr}, + {Py_tp_richcompare, (void*)EnumWrapper::tp_richcompare}, {Py_tp_methods, tp_methods}, {0, NULL} }; static PyType_Spec spec = { // name - "megengine.core._imperative_rt.ops.MultiHeadAttn.ATTN_MASK_TYPE", + "megengine.core._imperative_rt.ops.MultiHeadAttn.AttnMaskType", // basicsize - sizeof(EnumWrapper), + sizeof(EnumWrapper), // itemsize 0, // flags @@ -15089,7 +15089,7 @@ void _init_py_MultiHeadAttn_ATTN_MASK_TYPE(PyTypeObject& py_type) { e_type->tp_setattro( reinterpret_cast(e_type), py::cast("__name__").release().ptr(), - py::cast("ATTN_MASK_TYPE").release().ptr()) >= 0); + py::cast("AttnMaskType").release().ptr()) >= 0); mgb_assert( e_type->tp_setattro( @@ -15101,66 +15101,66 @@ void _init_py_MultiHeadAttn_ATTN_MASK_TYPE(PyTypeObject& py_type) { e_type->tp_setattro( reinterpret_cast(e_type), py::cast("__qualname__").release().ptr(), - py::cast("MultiHeadAttn.ATTN_MASK_TYPE").release().ptr()) >= 0); + py::cast("MultiHeadAttn.AttnMaskType").release().ptr()) >= 0); { PyObject* inst = e_type->tp_alloc(e_type, 0); - reinterpret_cast*>(inst)->value = MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK; + reinterpret_cast*>(inst)->value = MultiHeadAttn::AttnMaskType::NO_MASK; mgb_assert(PyDict_SetItemString(e_type->tp_dict, "NO_MASK", inst) >= 0); - EnumWrapper::pyobj_insts[0] = inst; + EnumWrapper::pyobj_insts[0] = inst; }{ PyObject* inst = e_type->tp_alloc(e_type, 0); - reinterpret_cast*>(inst)->value = MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK; + reinterpret_cast*>(inst)->value = MultiHeadAttn::AttnMaskType::DEFAULT_MASK; mgb_assert(PyDict_SetItemString(e_type->tp_dict, "DEFAULT_MASK", inst) >= 0); - EnumWrapper::pyobj_insts[1] = inst; + EnumWrapper::pyobj_insts[1] = inst; }{ PyObject* inst = e_type->tp_alloc(e_type, 0); - reinterpret_cast*>(inst)->value = MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK; + reinterpret_cast*>(inst)->value = MultiHeadAttn::AttnMaskType::CUDNN_STYLE_MASK; mgb_assert(PyDict_SetItemString(e_type->tp_dict, "CUDNN_STYLE_MASK", inst) >= 0); - EnumWrapper::pyobj_insts[2] = inst; + EnumWrapper::pyobj_insts[2] = inst; }{ 
PyObject* inst = e_type->tp_alloc(e_type, 0); - reinterpret_cast*>(inst)->value = MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK; + reinterpret_cast*>(inst)->value = MultiHeadAttn::AttnMaskType::USER_DEFINED_MASK; mgb_assert(PyDict_SetItemString(e_type->tp_dict, "USER_DEFINED_MASK", inst) >= 0); - EnumWrapper::pyobj_insts[3] = inst; + EnumWrapper::pyobj_insts[3] = inst; } Py_INCREF(e_type); mgb_assert(PyDict_SetItemString( - py_type.tp_dict, "ATTN_MASK_TYPE", reinterpret_cast(e_type)) >= 0); + py_type.tp_dict, "AttnMaskType", reinterpret_cast(e_type)) >= 0); } -template<> struct EnumTrait { - static constexpr const char *name = "MultiHeadAttn.TENSOR_COMBINATION_TYPE"; - static constexpr std::underlying_type_t max = 4 - 1; +template<> struct EnumTrait { + static constexpr const char *name = "MultiHeadAttn.TensorCombinationType"; + static constexpr std::underlying_type_t max = 4 - 1; }; -template<> PyTypeObject* EnumWrapper::type = nullptr; +template<> PyTypeObject* EnumWrapper::type = nullptr; template<> const char* -EnumWrapper::members[] = {"NONE", "ONLY_MASK", "ONLY_BIASKV", "ALL"}; +EnumWrapper::members[] = {"NONE", "ONLY_MASK", "ONLY_BIASKV", "ALL"}; -template<> std::unordered_map -EnumWrapper::mem2value = {{normalize_enum("NONE"), MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE}, {normalize_enum("ONLY_MASK"), MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK}, {normalize_enum("ONLY_BIASKV"), MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV}, {normalize_enum("ALL"), MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL}}; -template<> PyObject* EnumWrapper::pyobj_insts[4] = {nullptr}; +template<> std::unordered_map +EnumWrapper::mem2value = {{normalize_enum("NONE"), MultiHeadAttn::TensorCombinationType::NONE}, {normalize_enum("ONLY_MASK"), MultiHeadAttn::TensorCombinationType::ONLY_MASK}, {normalize_enum("ONLY_BIASKV"), MultiHeadAttn::TensorCombinationType::ONLY_BIASKV}, {normalize_enum("ALL"), MultiHeadAttn::TensorCombinationType::ALL}}; +template<> PyObject* EnumWrapper::pyobj_insts[4] = {nullptr}; -void _init_py_MultiHeadAttn_TENSOR_COMBINATION_TYPE(PyTypeObject& py_type) { - auto& e_type = EnumWrapper::type; +void _init_py_MultiHeadAttn_TensorCombinationType(PyTypeObject& py_type) { + auto& e_type = EnumWrapper::type; static PyMethodDef tp_methods[] = { - {const_cast("dump"), (PyCFunction)EnumWrapper::py_dump, METH_NOARGS, NULL}, + {const_cast("dump"), (PyCFunction)EnumWrapper::py_dump, METH_NOARGS, NULL}, {NULL} /* Sentinel */ }; static PyType_Slot slots[] = { - {Py_tp_repr, (void*)EnumWrapper::py_repr}, - {Py_tp_richcompare, (void*)EnumWrapper::tp_richcompare}, + {Py_tp_repr, (void*)EnumWrapper::py_repr}, + {Py_tp_richcompare, (void*)EnumWrapper::tp_richcompare}, {Py_tp_methods, tp_methods}, {0, NULL} }; static PyType_Spec spec = { // name - "megengine.core._imperative_rt.ops.MultiHeadAttn.TENSOR_COMBINATION_TYPE", + "megengine.core._imperative_rt.ops.MultiHeadAttn.TensorCombinationType", // basicsize - sizeof(EnumWrapper), + sizeof(EnumWrapper), // itemsize 0, // flags @@ -15174,7 +15174,7 @@ void _init_py_MultiHeadAttn_TENSOR_COMBINATION_TYPE(PyTypeObject& py_type) { e_type->tp_setattro( reinterpret_cast(e_type), py::cast("__name__").release().ptr(), - py::cast("TENSOR_COMBINATION_TYPE").release().ptr()) >= 0); + py::cast("TensorCombinationType").release().ptr()) >= 0); mgb_assert( e_type->tp_setattro( @@ -15186,31 +15186,31 @@ void _init_py_MultiHeadAttn_TENSOR_COMBINATION_TYPE(PyTypeObject& py_type) { e_type->tp_setattro( reinterpret_cast(e_type), 
                    py::cast("__qualname__").release().ptr(),
-                    py::cast("MultiHeadAttn.TENSOR_COMBINATION_TYPE").release().ptr()) >= 0);
+                    py::cast("MultiHeadAttn.TensorCombinationType").release().ptr()) >= 0);
     {
         PyObject* inst = e_type->tp_alloc(e_type, 0);
-        reinterpret_cast<EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>*>(inst)->value = MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE;
+        reinterpret_cast<EnumWrapper<MultiHeadAttn::TensorCombinationType>*>(inst)->value = MultiHeadAttn::TensorCombinationType::NONE;
         mgb_assert(PyDict_SetItemString(e_type->tp_dict, "NONE", inst) >= 0);
-        EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>::pyobj_insts[0] = inst;
+        EnumWrapper<MultiHeadAttn::TensorCombinationType>::pyobj_insts[0] = inst;
     }{
         PyObject* inst = e_type->tp_alloc(e_type, 0);
-        reinterpret_cast<EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>*>(inst)->value = MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK;
+        reinterpret_cast<EnumWrapper<MultiHeadAttn::TensorCombinationType>*>(inst)->value = MultiHeadAttn::TensorCombinationType::ONLY_MASK;
         mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ONLY_MASK", inst) >= 0);
-        EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>::pyobj_insts[1] = inst;
+        EnumWrapper<MultiHeadAttn::TensorCombinationType>::pyobj_insts[1] = inst;
     }{
         PyObject* inst = e_type->tp_alloc(e_type, 0);
-        reinterpret_cast<EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>*>(inst)->value = MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV;
+        reinterpret_cast<EnumWrapper<MultiHeadAttn::TensorCombinationType>*>(inst)->value = MultiHeadAttn::TensorCombinationType::ONLY_BIASKV;
         mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ONLY_BIASKV", inst) >= 0);
-        EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>::pyobj_insts[2] = inst;
+        EnumWrapper<MultiHeadAttn::TensorCombinationType>::pyobj_insts[2] = inst;
     }{
         PyObject* inst = e_type->tp_alloc(e_type, 0);
-        reinterpret_cast<EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>*>(inst)->value = MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL;
+        reinterpret_cast<EnumWrapper<MultiHeadAttn::TensorCombinationType>*>(inst)->value = MultiHeadAttn::TensorCombinationType::ALL;
         mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ALL", inst) >= 0);
-        EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>::pyobj_insts[3] = inst;
+        EnumWrapper<MultiHeadAttn::TensorCombinationType>::pyobj_insts[3] = inst;
     }
     Py_INCREF(e_type);
     mgb_assert(PyDict_SetItemString(
-            py_type.tp_dict, "TENSOR_COMBINATION_TYPE", reinterpret_cast<PyObject*>(e_type)) >= 0);
+            py_type.tp_dict, "TensorCombinationType", reinterpret_cast<PyObject*>(e_type)) >= 0);
 }
 
 PyOpDefBegin(MultiHeadAttn) // {
@@ -15708,7 +15708,7 @@ PyMethodDef PyOp(MultiHeadAttn)::py_init_methoddef = {
     "__init__",
     (PyCFunction)PyOp(MultiHeadAttn)::py_init_proxy,
     METH_VARARGS | METH_KEYWORDS,
-    "__init__(self, num_heads: int = ..., embeding_size: int = ..., k_size: int = ..., v_size: int = ..., qproj_size: int = ..., kproj_size: int = ..., vproj_size: int = ..., oproj_size: int = ..., qbias: bool = ..., kbias: bool = ..., vbias: bool = ..., obias: bool = ..., sm_scaler: float = ..., input_order: int = ..., attn_mask_type: Union[str, ATTN_MASK_TYPE] = ..., tensor_combination_type: Union[str, TENSOR_COMBINATION_TYPE] = ..., add_zero_attn: bool = ..., need_weights: bool = ..., reslink: bool = ..., training: bool = ..., seed: int = ..., attn_prob: float = ..., out_prob: float = ..., handle: int = ...) -> None\n"
+    "__init__(self, num_heads: int = ..., embeding_size: int = ..., k_size: int = ..., v_size: int = ..., qproj_size: int = ..., kproj_size: int = ..., vproj_size: int = ..., oproj_size: int = ..., qbias: bool = ..., kbias: bool = ..., vbias: bool = ..., obias: bool = ..., sm_scaler: float = ..., input_order: int = ..., attn_mask_type: Union[str, AttnMaskType] = ..., tensor_combination_type: Union[str, TensorCombinationType] = ..., add_zero_attn: bool = ..., need_weights: bool = ..., reslink: bool = ..., training: bool = ..., seed: int = ..., attn_prob: float = ..., out_prob: float = ..., handle: int = ...) 
-> None\n" }; void _init_py_MultiHeadAttn(py::module m) { @@ -15730,8 +15730,8 @@ void _init_py_MultiHeadAttn(py::module m) { PyObject* descr = PyDescr_NewMethod(&PyOpType(MultiHeadAttn), &PyOp(MultiHeadAttn)::py_init_methoddef); PyDict_SetItemString(py_type.tp_dict, "__init__", descr); mgb_assert(PyType_Ready(&py_type) >= 0); - _init_py_MultiHeadAttn_ATTN_MASK_TYPE(py_type); - _init_py_MultiHeadAttn_TENSOR_COMBINATION_TYPE(py_type); + _init_py_MultiHeadAttn_AttnMaskType(py_type); + _init_py_MultiHeadAttn_TensorCombinationType(py_type); PyType_Modified(&py_type); m.add_object("MultiHeadAttn", reinterpret_cast(&py_type)); diff --git a/imperative/tablegen/generated/opdef.h.inl b/imperative/tablegen/generated/opdef.h.inl index a099e0b4a50b0dc0bd02d2c5995ebf7958fb4498..2b20899e4431003dc6fd4c3e581834711c2bcb37 100644 --- a/imperative/tablegen/generated/opdef.h.inl +++ b/imperative/tablegen/generated/opdef.h.inl @@ -1398,8 +1398,8 @@ class MultiHeadAttn : public OpDefImplBase { MGB_DYN_TYPE_OBJ_FINAL_DECL; public: - using ATTN_MASK_TYPE = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE; - using TENSOR_COMBINATION_TYPE = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE; + using AttnMaskType = ::megdnn::param::MultiHeadAttn::AttnMaskType; + using TensorCombinationType = ::megdnn::param::MultiHeadAttn::TensorCombinationType; uint32_t num_heads = 1; uint32_t embeding_size = 0; uint32_t k_size = 0; @@ -1414,8 +1414,8 @@ public: bool obias = false; float sm_scaler = 1.f; uint32_t input_order = 0; - ATTN_MASK_TYPE attn_mask_type = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK; - TENSOR_COMBINATION_TYPE tensor_combination_type = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE; + AttnMaskType attn_mask_type = ::megdnn::param::MultiHeadAttn::AttnMaskType::NO_MASK; + TensorCombinationType tensor_combination_type = ::megdnn::param::MultiHeadAttn::TensorCombinationType::NONE; bool add_zero_attn = false; bool need_weights = false; bool reslink = false; @@ -1425,7 +1425,7 @@ public: float out_prob = 0.f; size_t handle; MultiHeadAttn() = default; - MultiHeadAttn(uint32_t num_heads_, uint32_t embeding_size_, uint32_t k_size_, uint32_t v_size_, uint32_t qproj_size_, uint32_t kproj_size_, uint32_t vproj_size_, uint32_t oproj_size_, bool qbias_, bool kbias_, bool vbias_, bool obias_, float sm_scaler_, uint32_t input_order_, ATTN_MASK_TYPE attn_mask_type_, TENSOR_COMBINATION_TYPE tensor_combination_type_, bool add_zero_attn_, bool need_weights_, bool reslink_, bool training_, uint64_t seed_, float attn_prob_, float out_prob_, size_t handle_, std::string scope_ = {}): num_heads(num_heads_), embeding_size(embeding_size_), k_size(k_size_), v_size(v_size_), qproj_size(qproj_size_), kproj_size(kproj_size_), vproj_size(vproj_size_), oproj_size(oproj_size_), qbias(qbias_), kbias(kbias_), vbias(vbias_), obias(obias_), sm_scaler(sm_scaler_), input_order(input_order_), attn_mask_type(attn_mask_type_), tensor_combination_type(tensor_combination_type_), add_zero_attn(add_zero_attn_), need_weights(need_weights_), reslink(reslink_), training(training_), seed(seed_), attn_prob(attn_prob_), out_prob(out_prob_), handle(handle_) { set_scope(scope_); } + MultiHeadAttn(uint32_t num_heads_, uint32_t embeding_size_, uint32_t k_size_, uint32_t v_size_, uint32_t qproj_size_, uint32_t kproj_size_, uint32_t vproj_size_, uint32_t oproj_size_, bool qbias_, bool kbias_, bool vbias_, bool obias_, float sm_scaler_, uint32_t input_order_, AttnMaskType attn_mask_type_, TensorCombinationType tensor_combination_type_, 
bool add_zero_attn_, bool need_weights_, bool reslink_, bool training_, uint64_t seed_, float attn_prob_, float out_prob_, size_t handle_, std::string scope_ = {}): num_heads(num_heads_), embeding_size(embeding_size_), k_size(k_size_), v_size(v_size_), qproj_size(qproj_size_), kproj_size(kproj_size_), vproj_size(vproj_size_), oproj_size(oproj_size_), qbias(qbias_), kbias(kbias_), vbias(vbias_), obias(obias_), sm_scaler(sm_scaler_), input_order(input_order_), attn_mask_type(attn_mask_type_), tensor_combination_type(tensor_combination_type_), add_zero_attn(add_zero_attn_), need_weights(need_weights_), reslink(reslink_), training(training_), seed(seed_), attn_prob(attn_prob_), out_prob(out_prob_), handle(handle_) { set_scope(scope_); } MultiHeadAttn(::megdnn::param::MultiHeadAttn packed_param_0, size_t handle_): num_heads(packed_param_0.num_heads), embeding_size(packed_param_0.embeding_size), k_size(packed_param_0.k_size), v_size(packed_param_0.v_size), qproj_size(packed_param_0.qproj_size), kproj_size(packed_param_0.kproj_size), vproj_size(packed_param_0.vproj_size), oproj_size(packed_param_0.oproj_size), qbias(packed_param_0.qbias), kbias(packed_param_0.kbias), vbias(packed_param_0.vbias), obias(packed_param_0.obias), sm_scaler(packed_param_0.sm_scaler), input_order(packed_param_0.input_order), attn_mask_type(packed_param_0.attn_mask_type), tensor_combination_type(packed_param_0.tensor_combination_type), add_zero_attn(packed_param_0.add_zero_attn), need_weights(packed_param_0.need_weights), reslink(packed_param_0.reslink), training(packed_param_0.training), seed(packed_param_0.seed), attn_prob(packed_param_0.attn_prob), out_prob(packed_param_0.out_prob), handle(handle_) {} ::megdnn::param::MultiHeadAttn param() const { return {num_heads, embeding_size, k_size, v_size, qproj_size, kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias, sm_scaler, input_order, attn_mask_type, tensor_combination_type, add_zero_attn, need_weights, reslink, training, seed, attn_prob, out_prob}; diff --git a/imperative/tablegen/generated/opdef.py.inl b/imperative/tablegen/generated/opdef.py.inl index 7a7de5f8acca10297bac31b2407ff517df048df5..44fd2a388b88b85a8f5324306579b96723152fda 100644 --- a/imperative/tablegen/generated/opdef.py.inl +++ b/imperative/tablegen/generated/opdef.py.inl @@ -1479,38 +1479,38 @@ MeshIndexingInst py::class_, OpDef> MultiHeadAttnInst(m, "MultiHeadAttn"); -py::enum_(MultiHeadAttnInst, "ATTN_MASK_TYPE") - .value("NO_MASK", MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK) - .value("DEFAULT_MASK", MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK) - .value("CUDNN_STYLE_MASK", MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK) - .value("USER_DEFINED_MASK", MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK) +py::enum_(MultiHeadAttnInst, "AttnMaskType") + .value("NO_MASK", MultiHeadAttn::AttnMaskType::NO_MASK) + .value("DEFAULT_MASK", MultiHeadAttn::AttnMaskType::DEFAULT_MASK) + .value("CUDNN_STYLE_MASK", MultiHeadAttn::AttnMaskType::CUDNN_STYLE_MASK) + .value("USER_DEFINED_MASK", MultiHeadAttn::AttnMaskType::USER_DEFINED_MASK) .def(py::init([](const std::string& in) { auto&& str = normalize_enum(in); - if (str == "NO_MASK") return MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK; - if (str == "DEFAULT_MASK") return MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK; - if (str == "CUDNN_STYLE_MASK") return MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK; - if (str == "USER_DEFINED_MASK") return MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK; + if (str == "NO_MASK") return MultiHeadAttn::AttnMaskType::NO_MASK; + if 
(str == "DEFAULT_MASK") return MultiHeadAttn::AttnMaskType::DEFAULT_MASK;
+        if (str == "CUDNN_STYLE_MASK") return MultiHeadAttn::AttnMaskType::CUDNN_STYLE_MASK;
+        if (str == "USER_DEFINED_MASK") return MultiHeadAttn::AttnMaskType::USER_DEFINED_MASK;
         throw py::cast_error("invalid enum value " + in);
     }));
-py::implicitly_convertible<std::string, MultiHeadAttn::ATTN_MASK_TYPE>();
+py::implicitly_convertible<std::string, MultiHeadAttn::AttnMaskType>();
 
-py::enum_<MultiHeadAttn::TENSOR_COMBINATION_TYPE>(MultiHeadAttnInst, "TENSOR_COMBINATION_TYPE")
-    .value("NONE", MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE)
-    .value("ONLY_MASK", MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK)
-    .value("ONLY_BIASKV", MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV)
-    .value("ALL", MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL)
+py::enum_<MultiHeadAttn::TensorCombinationType>(MultiHeadAttnInst, "TensorCombinationType")
+    .value("NONE", MultiHeadAttn::TensorCombinationType::NONE)
+    .value("ONLY_MASK", MultiHeadAttn::TensorCombinationType::ONLY_MASK)
+    .value("ONLY_BIASKV", MultiHeadAttn::TensorCombinationType::ONLY_BIASKV)
+    .value("ALL", MultiHeadAttn::TensorCombinationType::ALL)
     .def(py::init([](const std::string& in) {
         auto&& str = normalize_enum(in);
-        if (str == "NONE") return MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE;
-        if (str == "ONLY_MASK") return MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK;
-        if (str == "ONLY_BIASKV") return MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV;
-        if (str == "ALL") return MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL;
+        if (str == "NONE") return MultiHeadAttn::TensorCombinationType::NONE;
+        if (str == "ONLY_MASK") return MultiHeadAttn::TensorCombinationType::ONLY_MASK;
+        if (str == "ONLY_BIASKV") return MultiHeadAttn::TensorCombinationType::ONLY_BIASKV;
+        if (str == "ALL") return MultiHeadAttn::TensorCombinationType::ALL;
         throw py::cast_error("invalid enum value " + in);
     }));
-py::implicitly_convertible<std::string, MultiHeadAttn::TENSOR_COMBINATION_TYPE>();
+py::implicitly_convertible<std::string, MultiHeadAttn::TensorCombinationType>();
 
 MultiHeadAttnInst
-    .def(py::init<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, bool, bool, bool, bool, float, uint32_t, ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE, ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE, bool, bool, bool, bool, uint64_t, float, float, size_t, std::string>(), py::arg("num_heads") = 1, py::arg("embeding_size") = 0, py::arg("k_size") = 0, py::arg("v_size") = 0, py::arg("qproj_size") = 0, py::arg("kproj_size") = 0, py::arg("vproj_size") = 0, py::arg("oproj_size") = 0, py::arg("qbias") = false, py::arg("kbias") = false, py::arg("vbias") = false, py::arg("obias") = false, py::arg("sm_scaler") = 1.f, py::arg("input_order") = 0, py::arg("attn_mask_type") = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK, py::arg("tensor_combination_type") = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE, py::arg("add_zero_attn") = false, py::arg("need_weights") = false, py::arg("reslink") = false, py::arg("training") = true, py::arg("seed") = 0, py::arg("attn_prob") = 0.f, py::arg("out_prob") = 0.f, py::arg("handle"), py::arg("scope") = {})
+    .def(py::init<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, bool, bool, bool, bool, float, uint32_t, ::megdnn::param::MultiHeadAttn::AttnMaskType, ::megdnn::param::MultiHeadAttn::TensorCombinationType, bool, bool, bool, bool, uint64_t, float, float, size_t, std::string>(), py::arg("num_heads") = 1, py::arg("embeding_size") = 0, py::arg("k_size") = 0, py::arg("v_size") = 0, py::arg("qproj_size") = 0, py::arg("kproj_size") = 0, py::arg("vproj_size") = 0, py::arg("oproj_size") = 0, py::arg("qbias") = false, py::arg("kbias") = false, py::arg("vbias") = false, py::arg("obias") = false, py::arg("sm_scaler") = 1.f, py::arg("input_order") = 0, py::arg("attn_mask_type") = ::megdnn::param::MultiHeadAttn::AttnMaskType::NO_MASK, py::arg("tensor_combination_type") = ::megdnn::param::MultiHeadAttn::TensorCombinationType::NONE, py::arg("add_zero_attn") = false, py::arg("need_weights") = false, py::arg("reslink") = false, py::arg("training") = true, py::arg("seed") = 0, py::arg("attn_prob") = 0.f, py::arg("out_prob") = 0.f, py::arg("handle"), py::arg("scope") = {})
     .def(py::init<>())
     .def_readwrite("num_heads", 
&MultiHeadAttn::num_heads) .def_readwrite("embeding_size", &MultiHeadAttn::embeding_size) diff --git a/src/opr/impl/rand.cpp b/src/opr/impl/rand.cpp index 8782633d777dee742403f93ad3641002ad8c1b65..91c3144bb489a03e6da892453472224506924372 100644 --- a/src/opr/impl/rand.cpp +++ b/src/opr/impl/rand.cpp @@ -1,8 +1,10 @@ #include "megbrain/opr/rand.h" #include "megbrain/graph/grad_impl.h" +#include "megbrain/graph/static_infer.h" #include "megbrain/opr/utility.h" #include "./internal/megdnn_opr_wrapper.inl" +#include "megdnn/basic_types.h" using namespace mgb; using namespace opr; @@ -424,7 +426,7 @@ void DropoutBackward::scn_do_execute() { } /* ==================== MultiHeadAttnForward ==================== */ -using INPUT_TYPE = MultiHeadAttnForward::Param::TENSOR_COMBINATION_TYPE; +using InputType = MultiHeadAttnForward::Param::TensorCombinationType; MGB_DYN_TYPE_OBJ_FINAL_IMPL(MultiHeadAttnForward); @@ -439,9 +441,11 @@ MultiHeadAttnForward::MultiHeadAttnForward( param} { mgb_assert( param.tensor_combination_type == - MultiHeadAttnForward::Param::TENSOR_COMBINATION_TYPE::ALL); + MultiHeadAttnForward::Param::TensorCombinationType::ALL); add_input({queries, keys, values, qkvo_weight_bias, attn_mask, bias_k, bias_v}); - add_output(None)->dtype(queries->dtype()); + add_output(None) + ->dtype(queries->dtype()) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); add_output(None) ->dtype(queries->dtype()) .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); @@ -463,9 +467,11 @@ MultiHeadAttnForward::MultiHeadAttnForward( param} { mgb_assert( param.tensor_combination_type == - MultiHeadAttnForward::Param::TENSOR_COMBINATION_TYPE::ONLY_MASK); + MultiHeadAttnForward::Param::TensorCombinationType::ONLY_MASK); add_input({queries, keys, values, qkvo_weight_bias, attn_mask}); - add_output(None)->dtype(queries->dtype()); + add_output(None) + ->dtype(queries->dtype()) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); add_output(None) ->dtype(queries->dtype()) .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); @@ -488,9 +494,11 @@ MultiHeadAttnForward::MultiHeadAttnForward( param} { mgb_assert( param.tensor_combination_type == - MultiHeadAttnForward::Param::TENSOR_COMBINATION_TYPE::ONLY_BIASKV); + MultiHeadAttnForward::Param::TensorCombinationType::ONLY_BIASKV); add_input({queries, keys, values, qkvo_weight_bias, bias_k, bias_v}); - add_output(None)->dtype(queries->dtype()); + add_output(None) + ->dtype(queries->dtype()) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); add_output(None) ->dtype(queries->dtype()) .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); @@ -512,9 +520,11 @@ MultiHeadAttnForward::MultiHeadAttnForward( param} { mgb_assert( param.tensor_combination_type == - MultiHeadAttnForward::Param::TENSOR_COMBINATION_TYPE::NONE); + MultiHeadAttnForward::Param::TensorCombinationType::NONE); add_input({queries, keys, values, qkvo_weight_bias}); - add_output(None)->dtype(queries->dtype()); + add_output(None) + ->dtype(queries->dtype()) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); add_output(None) ->dtype(queries->dtype()) .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); @@ -540,6 +550,7 @@ SymbolVarArray MultiHeadAttnForward::make( mgb_assert(outs.size() == 5); return {outs[0], outs[1], outs[2], outs[3]}; } + SymbolVarArray MultiHeadAttnForward::make( SymbolVar queries, SymbolVar keys, SymbolVar values, SymbolVar qkvo_weight_bias, SymbolVar attn_mask, const Param& param, const OperatorNodeConfig& config) { @@ -553,6 +564,7 @@ SymbolVarArray MultiHeadAttnForward::make( mgb_assert(outs.size() == 5); return {outs[0], outs[1], outs[2], 
outs[3]}; } + SymbolVarArray MultiHeadAttnForward::make( SymbolVar queries, SymbolVar keys, SymbolVar values, SymbolVar qkvo_weight_bias, SymbolVar bias_k, SymbolVar bias_v, const Param& param, @@ -567,6 +579,7 @@ SymbolVarArray MultiHeadAttnForward::make( mgb_assert(outs.size() == 5); return {outs[0], outs[1], outs[2], outs[3]}; } + SymbolVarArray MultiHeadAttnForward::make( SymbolVar queries, SymbolVar keys, SymbolVar values, SymbolVar qkvo_weight_bias, const Param& param, const OperatorNodeConfig& config) { @@ -583,54 +596,98 @@ SymbolVarArray MultiHeadAttnForward::make( void MultiHeadAttnForward::init_output_static_infer_desc() { using namespace cg::static_infer; auto&& mgr = owner_graph()->static_infer_manager(); - mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0))); + auto input_type = param().tensor_combination_type; - auto infer_oshp1 = [this](TensorShape& dest, const InpVal& iv) { - TensorLayout in0{iv.val[0].shape(), input(0)->dtype()}; - TensorLayout in1{iv.val[1].shape(), input(1)->dtype()}; - TensorLayout in2{iv.val[2].shape(), input(2)->dtype()}; - TensorLayout in3{iv.val[3].shape(), input(3)->dtype()}; - TensorLayout in4{iv.val[4].shape(), input(4)->dtype()}; - TensorLayout in5{iv.val[5].shape(), input(5)->dtype()}; - TensorLayout in6{iv.val[6].shape(), input(6)->dtype()}; +#define DECLARE_LAYOUT_FROM_INPVAL(iv) \ + TensorLayout in0{iv.val[0].shape(), input(0)->dtype()}; \ + TensorLayout in1{iv.val[1].shape(), input(1)->dtype()}; \ + TensorLayout in2{iv.val[2].shape(), input(2)->dtype()}; \ + TensorLayout in3{iv.val[3].shape(), input(3)->dtype()}; \ + TensorLayout in4, in5, in6; \ + if (input_type == InputType::ONLY_MASK) { \ + in4 = {iv.val[4].shape(), input(4)->dtype()}; \ + } \ + if (input_type == InputType::ONLY_BIASKV) { \ + in5 = {iv.val[4].shape(), input(4)->dtype()}; \ + in6 = {iv.val[5].shape(), input(5)->dtype()}; \ + } \ + if (input_type == InputType::ALL) { \ + in4 = {iv.val[4].shape(), input(4)->dtype()}; \ + in5 = {iv.val[5].shape(), input(5)->dtype()}; \ + in6 = {iv.val[6].shape(), input(6)->dtype()}; \ + } + + auto infer_oshp0 = [this, input_type](TensorShape& dest, const InpVal& iv) { + ensure_megdnn_opr(); + DECLARE_LAYOUT_FROM_INPVAL(iv) + TensorLayout o0, o1, o2, o3; + m_dnn_opr->deduce_layout(in0, in1, in2, in3, in4, in5, in6, o0, o1, o2, o3); + dest = o0; + return true; + }; + auto infer_oshp1 = [this, input_type](TensorShape& dest, const InpVal& iv) { + ensure_megdnn_opr(); + DECLARE_LAYOUT_FROM_INPVAL(iv) TensorLayout o0, o1, o2, o3; m_dnn_opr->deduce_layout(in0, in1, in2, in3, in4, in5, in6, o0, o1, o2, o3); dest = o1; return true; }; - mgr.register_shape_infer( - output(1), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_oshp1}); - - auto infer_mask = [this](TensorShape& dest, const InpVal& iv) { + auto infer_mask = [this, input_type](TensorShape& dest, const InpVal& iv) { ensure_megdnn_opr(); dest.ndim = 1; + DECLARE_LAYOUT_FROM_INPVAL(iv) dest.shape[0] = m_dnn_opr->get_mask_reservespace_in_bytes( - {iv.val[0].shape(), input(0)->dtype()}, - {iv.val[1].shape(), input(1)->dtype()}, - {iv.val[2].shape(), input(2)->dtype()}, - {iv.val[3].shape(), input(3)->dtype()}, - {iv.val[4].shape(), input(4)->dtype()}, - {iv.val[5].shape(), input(5)->dtype()}, - {iv.val[6].shape(), input(6)->dtype()}, {}, {}, {}, {}); + in0, in1, in2, in3, in4, in5, in6, {}, {}, {}, {}); return true; }; - auto infer_othr = [this](TensorShape& dest, const InpVal& iv) { + auto infer_othr = [this, input_type](TensorShape& dest, const InpVal& iv) { 
ensure_megdnn_opr(); dest.ndim = 1; - dest.shape[0] = m_dnn_opr->get_othr_reservespace_in_bytes( - {iv.val[0].shape(), input(0)->dtype()}, - {iv.val[1].shape(), input(1)->dtype()}, - {iv.val[2].shape(), input(2)->dtype()}, - {iv.val[3].shape(), input(3)->dtype()}, - {iv.val[4].shape(), input(4)->dtype()}, - {iv.val[5].shape(), input(5)->dtype()}, - {iv.val[6].shape(), input(6)->dtype()}, {}, {}, {}, {}); + DECLARE_LAYOUT_FROM_INPVAL(iv) + size_t size = m_dnn_opr->get_othr_reservespace_in_bytes( + in0, in1, in2, in3, in4, in5, in6, {}, {}, {}, {}); + dest.shape[0] = size / input(0)->dtype().size(); return true; }; - mgr.register_shape_infer( - output(2), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_mask}); - mgr.register_shape_infer( - output(3), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_othr}); + auto infer_wk = [this, input_type](TensorShape& dest, const InpVal& iv) { + ensure_megdnn_opr(); + dest.ndim = 1; + DECLARE_LAYOUT_FROM_INPVAL(iv) + dest.shape[0] = m_dnn_opr->get_workspace_in_bytes( + in0, in1, in2, in3, in4, in5, in6, {}, {}, {}, {}); + return true; + }; + + DepElement inp0{input(0), DepType::SHAPE}; + DepElement inp1{input(1), DepType::SHAPE}; + DepElement inp2{input(2), DepType::SHAPE}; + DepElement inp3{input(3), DepType::SHAPE}; + DepVal out_dep; + if (input_type == InputType::NONE) { + out_dep = {inp0, inp1, inp2, inp3}; + } + if (input_type == InputType::ONLY_MASK) { + DepElement inp4 = {input(4), DepType::SHAPE}; + out_dep = {inp0, inp1, inp2, inp3, inp4}; + } + if (input_type == InputType::ONLY_BIASKV) { + DepElement inp5 = {input(4), DepType::SHAPE}; + DepElement inp6 = {input(5), DepType::SHAPE}; + out_dep = {inp0, inp1, inp2, inp3, inp5, inp6}; + } + if (input_type == InputType::ALL) { + DepElement inp4 = {input(4), DepType::SHAPE}; + DepElement inp5 = {input(5), DepType::SHAPE}; + DepElement inp6 = {input(6), DepType::SHAPE}; + out_dep = {inp0, inp1, inp2, inp3, inp4, inp5, inp6}; + } + mgr.register_shape_infer(output(0), {SourceType::DEP, out_dep, infer_oshp0}); + mgr.register_shape_infer(output(1), {SourceType::DEP, out_dep, infer_oshp1}); + mgr.register_shape_infer(output(2), {SourceType::DEP, out_dep, infer_mask}); + mgr.register_shape_infer(output(3), {SourceType::DEP, out_dep, infer_othr}); + mgr.register_shape_infer(output(4), {SourceType::DEP, out_dep, infer_wk}); +#undef DECLARE_LAYOUT_FROM_INPVAL } void MultiHeadAttnForward::add_input_layout_constraint() { @@ -647,47 +704,30 @@ void MultiHeadAttnForward::scn_do_execute() { mgb_assert(ret->dev_tensor().empty()); return; } - megdnn::TensorND empty_dnn; - if (input_type == INPUT_TYPE::ALL) { - m_dnn_opr->exec( - input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), - input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), - input(4)->dev_tensor().as_megdnn(), input(5)->dev_tensor().as_megdnn(), - input(6)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), - output(1)->dev_tensor().as_megdnn(), - output(2)->dev_tensor().as_megdnn(), - output(3)->dev_tensor().as_megdnn(), - get_megdnn_workspace_from_var(output(4))); - } else if (input_type == INPUT_TYPE::ONLY_MASK) { - m_dnn_opr->exec( - input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), - input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), - input(4)->dev_tensor().as_megdnn(), empty_dnn, empty_dnn, - output(0)->dev_tensor().as_megdnn(), - output(1)->dev_tensor().as_megdnn(), - output(2)->dev_tensor().as_megdnn(), - output(3)->dev_tensor().as_megdnn(), - 
get_megdnn_workspace_from_var(output(4))); - } else if (input_type == INPUT_TYPE::ONLY_BIASKV) { - m_dnn_opr->exec( - input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), - input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), - empty_dnn, input(4)->dev_tensor().as_megdnn(), - input(5)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), - output(1)->dev_tensor().as_megdnn(), - output(2)->dev_tensor().as_megdnn(), - output(3)->dev_tensor().as_megdnn(), - get_megdnn_workspace_from_var(output(4))); - } else { - m_dnn_opr->exec( - input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), - input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), - empty_dnn, empty_dnn, empty_dnn, output(0)->dev_tensor().as_megdnn(), - output(1)->dev_tensor().as_megdnn(), - output(2)->dev_tensor().as_megdnn(), - output(3)->dev_tensor().as_megdnn(), - get_megdnn_workspace_from_var(output(4))); + + megdnn::TensorND in4; + megdnn::TensorND in5; + megdnn::TensorND in6; + if (input_type == InputType::ONLY_MASK) { + in4 = input(4)->dev_tensor().as_megdnn(); + } + if (input_type == InputType::ONLY_BIASKV) { + in5 = input(4)->dev_tensor().as_megdnn(); + in6 = input(5)->dev_tensor().as_megdnn(); + } + if (input_type == InputType::ALL) { + in4 = input(4)->dev_tensor().as_megdnn(); + in5 = input(5)->dev_tensor().as_megdnn(); + in6 = input(6)->dev_tensor().as_megdnn(); } + + m_dnn_opr->exec( + input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), + input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), in4, + in5, in6, output(0)->dev_tensor().as_megdnn(), + output(1)->dev_tensor().as_megdnn(), output(2)->dev_tensor().as_megdnn(), + output(3)->dev_tensor().as_megdnn(), + get_megdnn_workspace_from_var(output(4))); } cg::OperatorNodeBase::NodeProp* MultiHeadAttnForward::do_make_node_prop() const { @@ -707,7 +747,7 @@ MGB_IMPL_OPR_GRAD(MultiHeadAttnForward) { VarNodeArray ret; mgb_assert(wrt_idx < 7, "wrt_idx %zu is out of range", wrt_idx); auto input_type = opr.param().tensor_combination_type; - if (input_type == INPUT_TYPE::ALL or input_type == INPUT_TYPE::ONLY_MASK) + if (input_type == InputType::ALL or input_type == InputType::ONLY_MASK) grad = MultiHeadAttnBackward::make( out_grad[0], opr.input(0), opr.input(1), opr.input(2), opr.input(3), opr.input(4), opr.output(1), opr.output(2), opr.output(3), opr.param()); @@ -716,11 +756,11 @@ MGB_IMPL_OPR_GRAD(MultiHeadAttnForward) { out_grad[0], opr.input(0), opr.input(1), opr.input(2), opr.input(3), opr.output(1), opr.output(2), opr.output(3), opr.param()); uint32_t nr_ret = 7; - if (input_type == INPUT_TYPE::NONE) + if (input_type == InputType::NONE) nr_ret = 4; - if (input_type == INPUT_TYPE::ONLY_MASK) + if (input_type == InputType::ONLY_MASK) nr_ret = 5; - if (input_type == INPUT_TYPE::ONLY_BIASKV) + if (input_type == InputType::ONLY_BIASKV) nr_ret = 6; for (uint32_t i = 0; i < nr_ret; ++i) { ret.push_back(grad[i].node()); @@ -747,10 +787,14 @@ MultiHeadAttnBackward::MultiHeadAttnBackward( add_input( {diff, queries, keys, values, qkvo_weight_bias, attn_mask, attn_weight, mask_reservespace, othr_reservespace}); + this->output()[0]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + this->output()[1]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + this->output()[2]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); this->output()[3]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); this->output()[4]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); 
this->output()[5]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); } + MultiHeadAttnBackward::MultiHeadAttnBackward( VarNode* diff, VarNode* queries, VarNode* keys, VarNode* values, VarNode* qkvo_weight_bias, VarNode* attn_weight, VarNode* mask_reservespace, @@ -766,6 +810,9 @@ MultiHeadAttnBackward::MultiHeadAttnBackward( add_input( {diff, queries, keys, values, qkvo_weight_bias, attn_weight, mask_reservespace, othr_reservespace}); + this->output()[0]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + this->output()[1]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + this->output()[2]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); this->output()[3]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); this->output()[4]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); this->output()[5]->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); @@ -788,6 +835,7 @@ SymbolVarArray MultiHeadAttnBackward::make( return {outs[0], outs[1], outs[2], outs[3], outs[4], outs[5], {}}; } + SymbolVarArray MultiHeadAttnBackward::make( SymbolVar diff, SymbolVar queries, SymbolVar keys, SymbolVar values, SymbolVar qkvo_weight_bias, SymbolVar attn_weight, SymbolVar mask_reservespace, @@ -814,7 +862,7 @@ void MultiHeadAttnBackward::init_output_static_infer_desc() { mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(3))); mgr.register_shape_infer(output(3), ShapeInferDesc::make_identity(input(4))); auto input_type = param().tensor_combination_type; - if (input_type == INPUT_TYPE::ALL or input_type == INPUT_TYPE::ONLY_BIASKV) { + if (input_type == InputType::ALL or input_type == InputType::ONLY_BIASKV) { mgr.register_shape_infer(output(4), ShapeInferDesc::make_identity(input(4))); mgr.register_shape_infer(output(5), ShapeInferDesc::make_identity(input(4))); } else { @@ -838,46 +886,36 @@ size_t MultiHeadAttnBackward::get_workspace_size_bytes( const TensorShapeArray& input_shapes, const TensorShapeArray& output_shapes) const { auto input_type = megdnn_opr()->param().tensor_combination_type; - megdnn::TensorLayout empty_dnn; - if (input_type == INPUT_TYPE::ALL or input_type == INPUT_TYPE::ONLY_MASK) - return megdnn_opr()->get_workspace_in_bytes( - {input_shapes[0], input(0)->dtype(), input(0)->format()}, - {input_shapes[1], input(1)->dtype(), input(1)->format()}, - {input_shapes[2], input(2)->dtype(), input(2)->format()}, - {input_shapes[3], input(3)->dtype(), input(3)->format()}, - {input_shapes[4], input(4)->dtype(), input(4)->format()}, - {input_shapes[5], input(5)->dtype(), input(5)->format()}, - {input_shapes[6], input(6)->dtype(), input(6)->format()}, - {input_shapes[7], input(7)->dtype(), input(7)->format()}, - {input_shapes[8], input(8)->dtype(), input(8)->format()}, - {output_shapes[0], output(0)->dtype(), output(0)->format()}, - {output_shapes[1], output(1)->dtype(), output(1)->format()}, - {output_shapes[2], output(2)->dtype(), output(2)->format()}, - {output_shapes[3], output(3)->dtype(), output(3)->format()}, - {output_shapes[4], output(4)->dtype(), output(4)->format()}, - {output_shapes[5], output(5)->dtype(), output(5)->format()}); - else - return megdnn_opr()->get_workspace_in_bytes( - {input_shapes[0], input(0)->dtype(), input(0)->format()}, - {input_shapes[1], input(1)->dtype(), input(1)->format()}, - {input_shapes[2], input(2)->dtype(), input(2)->format()}, - {input_shapes[3], input(3)->dtype(), input(3)->format()}, - {input_shapes[4], input(4)->dtype(), input(4)->format()}, empty_dnn, - {input_shapes[5], input(5)->dtype(), input(5)->format()}, - {input_shapes[6], input(6)->dtype(), input(6)->format()}, - 
{input_shapes[7], input(7)->dtype(), input(7)->format()}, - {output_shapes[0], output(0)->dtype(), output(0)->format()}, - {output_shapes[1], output(1)->dtype(), output(1)->format()}, - {output_shapes[2], output(2)->dtype(), output(2)->format()}, - {output_shapes[3], output(3)->dtype(), output(3)->format()}, - {output_shapes[4], output(4)->dtype(), output(4)->format()}, - {output_shapes[5], output(5)->dtype(), output(5)->format()}); + megdnn::TensorLayout in0{input_shapes[0], input(0)->dtype(), input(0)->format()}; + megdnn::TensorLayout in1{input_shapes[1], input(1)->dtype(), input(1)->format()}; + megdnn::TensorLayout in2{input_shapes[2], input(2)->dtype(), input(2)->format()}; + megdnn::TensorLayout in3{input_shapes[3], input(3)->dtype(), input(3)->format()}; + megdnn::TensorLayout in4{input_shapes[4], input(4)->dtype(), input(4)->format()}; + megdnn::TensorLayout in5, in6, in7, in8; + if (input_type == InputType::ALL or input_type == InputType::ONLY_MASK) { + in5 = {input_shapes[5], input(5)->dtype(), input(5)->format()}; + in6 = {input_shapes[6], input(6)->dtype(), input(6)->format()}; + in7 = {input_shapes[7], input(7)->dtype(), input(7)->format()}; + in8 = {input_shapes[8], input(8)->dtype(), input(8)->format()}; + } else { + in6 = {input_shapes[5], input(5)->dtype(), input(5)->format()}; + in7 = {input_shapes[6], input(6)->dtype(), input(6)->format()}; + in8 = {input_shapes[7], input(7)->dtype(), input(7)->format()}; + } + return megdnn_opr()->get_workspace_in_bytes( + in0, in1, in2, in3, in4, in5, in6, in7, in8, + {output_shapes[0], output(0)->dtype(), output(0)->format()}, + {output_shapes[1], output(1)->dtype(), output(1)->format()}, + {output_shapes[2], output(2)->dtype(), output(2)->format()}, + {output_shapes[3], output(3)->dtype(), output(3)->format()}, + {output_shapes[4], output(4)->dtype(), output(4)->format()}, + {output_shapes[5], output(5)->dtype(), output(5)->format()}); } void MultiHeadAttnBackward::scn_do_execute() { auto input_type = megdnn_opr()->param().tensor_combination_type; megdnn::TensorND empty_dnn; - if (input_type == INPUT_TYPE::ALL or input_type == INPUT_TYPE::ONLY_MASK) + if (input_type == InputType::ALL or input_type == InputType::ONLY_MASK) megdnn_opr()->exec( input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), diff --git a/src/opr/impl/rand.sereg.h b/src/opr/impl/rand.sereg.h index 0552b2d67191dcdea1c9598c8a07e1b49ff6f9e8..bb38514a1b23c7611bcea31cc55f836d6fb54ecb 100644 --- a/src/opr/impl/rand.sereg.h +++ b/src/opr/impl/rand.sereg.h @@ -33,31 +33,31 @@ struct OprMaker { template <> struct OprMaker { using Param = opr::MultiHeadAttn::Param; - using INPUT_TYPE = Param::TENSOR_COMBINATION_TYPE; + using InputType = Param::TensorCombinationType; static cg::OperatorNodeBase* make( const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, const OperatorNodeConfig& config) { MGB_MARK_USED_VAR(graph); if (i.size() == 7) { - mgb_assert(INPUT_TYPE::ALL == param.tensor_combination_type); + mgb_assert(InputType::ALL == param.tensor_combination_type); return opr::MultiHeadAttn::make( i[0], i[1], i[2], i[3], i[4], i[5], i[6], param, config)[0] .node() ->owner_opr(); } else if (i.size() == 6) { - mgb_assert(INPUT_TYPE::ONLY_BIASKV == param.tensor_combination_type); + mgb_assert(InputType::ONLY_BIASKV == param.tensor_combination_type); return opr::MultiHeadAttn::make( i[0], i[1], i[2], i[3], i[4], i[5], param, config)[0] .node() ->owner_opr(); } else if (i.size() 
== 5) {
-            mgb_assert(INPUT_TYPE::ONLY_MASK == param.tensor_combination_type);
+            mgb_assert(InputType::ONLY_MASK == param.tensor_combination_type);
             return opr::MultiHeadAttn::make(
                            i[0], i[1], i[2], i[3], i[4], param, config)[0]
                     .node()
                     ->owner_opr();
         } else {
-            mgb_assert(INPUT_TYPE::NONE == param.tensor_combination_type);
+            mgb_assert(InputType::NONE == param.tensor_combination_type);
             return opr::MultiHeadAttn::make(i[0], i[1], i[2], i[3], param, config)[0]
                     .node()
                     ->owner_opr();