提交 f95f2037 编写于 作者: M Megvii Engine Team

feat(opr): add multiattention cuda proxy backend

GitOrigin-RevId: d5d688db5fa44af75a8d82578edf1a783d3d6ea9
上级 7af49c98
......@@ -1346,12 +1346,12 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
.add_fields('bool', Doc('obias', 'Whether to add out bias.'), 'false')
.add_fields('float32', Doc('sm_scaler', 'Softmax smoothing (1.0 >= smScaler >= 0.0) or sharpening (smScaler > 1.0) coefficient.'), '1.f')
.add_fields('uint32', Doc('input_order', 'The sequence data layout, allows the user to select 3! = 6 different data layouts or permutations of BEAM, BATCH and TIME dimensions.'), '0')
.add_enum('ATTN_MASK_TYPE',
.add_enum('AttnMaskType',
Doc('NO_MASK = 0', 'Indicates that there is no mask.'),
Doc('DEFAULT_MASK = 1', 'Use the default mask which the upper right triangle of the mask is -inf, and the diagonal and lower left triangle are all 0.'),
Doc('CUDNN_STYLE_MASK = 2', 'Indicates the use of a cudnn style mask.'),
Doc('USER_DEFINED_MASK = 3', 'Use the user-defined mask.'), name_field="attn_mask_type")
.add_enum(Doc('TENSOR_COMBINATION_TYPE', 'Used to determine whether mask tensor and bias_kv tensor exist in the input. Note that bias_kv here is not kbias and vbias in the linear layer, and bias_kv here will be added to the K and V at sequence dimensions, where K and V are the matrices of key and value after projection, and K and V will be used to calculate the attention matrix.'),
.add_enum(Doc('TensorCombinationType', 'Used to determine whether mask tensor and bias_kv tensor exist in the input. Note that bias_kv here is not m_kbias and m_vbias in the linear layer, and bias_kv here will be added to the K and V at sequence dimensions, where K and V are the matrices of key and value after projection, and K and V will be used to calculate the attention matrix.'),
Doc('NONE = 0', 'Indicates that there are no mask tensor and bias_kv tensor in the input.'),
Doc('ONLY_MASK = 1',
'Indicates that there is only mask tensor in input.'),
......@@ -1363,5 +1363,5 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
.add_fields('bool', Doc('training', 'Whether it is in training mode.'), 'true')
.add_fields('uint64', Doc('seed', 'Random number seed for drop'), '0')
.add_fields('float32', Doc('attn_prob', 'Dropout probability on attention, is applied directly to the softmax output'), '0.f')
.add_fields('float32', Doc('out_prob', 'Dropout probability on output, alters the multi-head attention output'), '0.f')
.add_fields('float32', Doc('out_prob', 'Dropout probability on output, alters the multi-m_head attention output'), '0.f')
)
......@@ -9,7 +9,8 @@
namespace megdnn {
using Param = MultiHeadAttnBase::Param;
using INPUT_TYPE = Param::TENSOR_COMBINATION_TYPE;
using InputType = Param::TensorCombinationType;
using MaskType = Param::AttnMaskType;
void MultiHeadAttnForward::check_exec(
const TensorLayout& queries, const TensorLayout& keys,
......@@ -33,12 +34,12 @@ void MultiHeadAttnForward::check_exec(
bool have_mask = false;
bool have_biaskv = false;
auto input_type = p.tensor_combination_type;
if (input_type == INPUT_TYPE::ONLY_BIASKV or input_type == INPUT_TYPE::ALL) {
if (input_type == InputType::ONLY_BIASKV or input_type == InputType::ALL) {
have_biaskv = true;
megdnn_assert_contiguous(bias_k);
megdnn_assert_contiguous(bias_v);
}
if (input_type == INPUT_TYPE::ONLY_MASK or input_type == INPUT_TYPE::ALL) {
if (input_type == InputType::ONLY_MASK or input_type == InputType::ALL) {
have_mask = true;
megdnn_assert_contiguous(attn_mask);
}
......@@ -162,9 +163,10 @@ void MultiHeadAttnForward::check_exec(
if (qprojsize == 0 and kprojsize == 0)
megdnn_assert(embeding_size == ksize, "%s", param_errmsg().c_str());
if (qprojsize == 0 and kprojsize != 0)
megdnn_assert(embeding_size == kprojsize, "%s", param_errmsg().c_str());
megdnn_assert(
embeding_size * p.num_heads == kprojsize, "%s", param_errmsg().c_str());
if (qprojsize != 0 and kprojsize == 0)
megdnn_assert(qprojsize == ksize, "%s", param_errmsg().c_str());
megdnn_assert(qprojsize == ksize * p.num_heads, "%s", param_errmsg().c_str());
if (qprojsize != 0 and kprojsize != 0)
megdnn_assert(qprojsize == kprojsize, "%s", param_errmsg().c_str());
if (p.qbias)
......@@ -184,7 +186,7 @@ void MultiHeadAttnForward::check_exec(
if (p.oproj_size > 0 and p.vproj_size > 0)
weight_len += vprojsize * oprojsize + (p.obias ? oprojsize : 0);
else if (p.oproj_size > 0 and p.vproj_size == 0)
weight_len += vsize * oprojsize + (p.obias ? oprojsize : 0);
weight_len += p.num_heads * vsize * oprojsize + (p.obias ? oprojsize : 0);
megdnn_assert(
weight_len == qkvo_weight_bias.total_nr_elems(),
"qkvo_weight_bias length should be %zu, but got %zu. details: %s",
......@@ -208,7 +210,7 @@ void MultiHeadAttnBackward::deduce_layout(
dvalues = values;
dqkvo_weight_bias = qkvo_weight_bias;
auto input_type = param().tensor_combination_type;
if (input_type == INPUT_TYPE::ONLY_BIASKV or input_type == INPUT_TYPE::ALL) {
if (input_type == InputType::ONLY_BIASKV or input_type == InputType::ALL) {
dbias_k = TensorLayout(
{1, 1, param().kproj_size ? param().kproj_size : param().k_size},
keys.dtype);
......@@ -255,8 +257,8 @@ void MultiHeadAttnBackward::check_exec(
auto input_type = p.tensor_combination_type;
bool have_mask = false;
bool have_biaskv =
input_type == INPUT_TYPE::ONLY_BIASKV or input_type == INPUT_TYPE::ALL;
if (input_type == INPUT_TYPE::ONLY_MASK or input_type == INPUT_TYPE::ALL) {
input_type == InputType::ONLY_BIASKV or input_type == InputType::ALL;
if (input_type == InputType::ONLY_MASK or input_type == InputType::ALL) {
have_mask = true;
megdnn_assert_contiguous(attn_mask);
}
......@@ -296,8 +298,9 @@ void MultiHeadAttnBackward::check_exec(
};
// layout check
size_t osize = p.oproj_size != 0 ? p.oproj_size
: (p.vproj_size != 0 ? p.vproj_size : p.v_size);
size_t osize = p.oproj_size != 0
? p.oproj_size
: (p.vproj_size != 0 ? p.vproj_size : p.v_size * p.num_heads);
TensorLayout diff_expect = TensorLayout(
TensorShape{queries.shape[0], queries.shape[1], osize}, queries.dtype);
megdnn_assert(equal_layout(diff, diff_expect), "%s", errmsg().c_str());
......@@ -409,9 +412,10 @@ void MultiHeadAttnBackward::check_exec(
if (qprojsize == 0 and kprojsize == 0)
megdnn_assert(embeding_size == ksize, "%s", param_errmsg().c_str());
if (qprojsize == 0 and kprojsize != 0)
megdnn_assert(embeding_size == kprojsize, "%s", param_errmsg().c_str());
megdnn_assert(
embeding_size * p.num_heads == kprojsize, "%s", param_errmsg().c_str());
if (qprojsize != 0 and kprojsize == 0)
megdnn_assert(qprojsize == ksize, "%s", param_errmsg().c_str());
megdnn_assert(qprojsize == ksize * p.num_heads, "%s", param_errmsg().c_str());
if (qprojsize != 0 and kprojsize != 0)
megdnn_assert(qprojsize == kprojsize, "%s", param_errmsg().c_str());
if (p.qbias)
......@@ -431,7 +435,7 @@ void MultiHeadAttnBackward::check_exec(
if (p.oproj_size > 0 and p.vproj_size > 0)
weight_len += vprojsize * oprojsize + (p.obias ? oprojsize : 0);
else if (p.oproj_size > 0 and p.vproj_size == 0)
weight_len += vsize * oprojsize + (p.obias ? oprojsize : 0);
weight_len += p.num_heads * vsize * oprojsize + (p.obias ? oprojsize : 0);
megdnn_assert(
weight_len == qkvo_weight_bias.total_nr_elems(),
"qkvo_weight_bias length should be %zu, but got %zu. details: %s",
......
#pragma once
#include "megdnn/dtype.h"
#include "megdnn/basic_types.h"
#include "megdnn/handle.h"
#include "megdnn/oprs/linalg.h"
#include "megdnn/oprs/nn.h"
#include "src/common/utils.h"
namespace megdnn {
namespace multi_head_attn {
inline void matmul_deduce_layout(
std::unique_ptr<MatrixMulForward>& opr, const TensorLayout& A,
const TensorLayout& B, TensorLayout& C) {
megdnn_assert(A.ndim == 3 && B.ndim == 2);
auto m_param = opr->param();
size_t A1, A2, B0, B1;
A1 = A.shape[1];
A2 = A.shape[2];
B0 = B.shape[0];
B1 = B.shape[1];
if (m_param.transposeA) {
std::swap(A1, A2);
}
if (m_param.transposeB) {
std::swap(B0, B1);
}
C = TensorLayout(TensorShape({A.shape[0], A1, B1}), A.dtype);
}
inline void matmul_exec(
std::unique_ptr<MatrixMulForward>& opr, _megdnn_tensor_in A,
_megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace) {
auto Batch = A.layout.shape[0];
auto Astrd = A.layout.dtype.size() * A.layout.stride[0],
Cstrd = C.layout.dtype.size() * C.layout.stride[0];
auto Aref = A.get_ref_ptr();
auto Bref = B.get_ref_ptr();
auto Cref = C.get_ref_ptr();
rep(b, Batch) {
//! all tensors should share the same RefPtr
auto A_ref = Aref;
A_ref += b * Astrd;
auto B_ref = Bref;
auto C_ref = Cref;
C_ref += b * Cstrd;
TensorND A_{A.layout.remove_axis(0), A_ref};
TensorND B_{B.layout, B_ref};
TensorND C_{C.layout.remove_axis(0), C_ref};
opr->exec(A_, B_, C_, workspace);
}
}
using Param = MultiHeadAttnBase::Param;
using MaskType = Param::AttnMaskType;
using InputType = Param::TensorCombinationType;
/***************************** MHA base *****************************/
#define _MHA_FORWARD(INPUT_TYPE, OUTPUT_TYPE) \
INPUT_TYPE queries, INPUT_TYPE keys, INPUT_TYPE values, \
INPUT_TYPE qkvo_weight_bias, INPUT_TYPE attn_mask, INPUT_TYPE bias_k, \
INPUT_TYPE bias_v, OUTPUT_TYPE out, OUTPUT_TYPE attn_weight, \
OUTPUT_TYPE mask_reservespace, OUTPUT_TYPE othr_reservespace
#define _MHA_BACKWARD(INPUT_TYPE, OUTPUT_TYPE) \
INPUT_TYPE diff, INPUT_TYPE queries, INPUT_TYPE keys, INPUT_TYPE values, \
INPUT_TYPE qkvo_weight_bias, INPUT_TYPE attn_mask, INPUT_TYPE attn_weight, \
INPUT_TYPE mask_reservespace, INPUT_TYPE othr_reservespace, \
OUTPUT_TYPE dqueries, OUTPUT_TYPE dkeys, OUTPUT_TYPE dvalues, \
OUTPUT_TYPE dqkvo_weight_bias, OUTPUT_TYPE dbias_k, OUTPUT_TYPE dbias_v
#define _MHA_PROXY_PRE(HANDLE_TYPE, PARAM_TYPE) HANDLE_TYPE handle, PARAM_TYPE param
#define MHA_EXEC_PARAM(cb) \
cb(_megdnn_tensor_in, _megdnn_tensor_out), _megdnn_workspace workspace
#define MHA_LAYOUT_CONST_PARAM(cb) cb(const TensorLayout&, const TensorLayout&)
#define MHA_LAYOUT_PARAM(cb) cb(const TensorLayout&, TensorLayout&)
#define MHA_CALL(cb) cb(, )
#define MHA_PROXY_PRE_PARAM _MHA_PROXY_PRE(Handle*, Param&)
#define MHA_PROXY_PRE_CALL _MHA_PROXY_PRE(, )
/***************************** MHA forward *****************************/
#define MHA_FORWARD_EXEC_PARAM MHA_EXEC_PARAM(_MHA_FORWARD)
#define MHA_FORWARD_LAYOUT_CONST_PARAM MHA_LAYOUT_CONST_PARAM(_MHA_FORWARD)
#define MHA_FORWARD_LAYOUT_PARAM MHA_LAYOUT_PARAM(_MHA_FORWARD)
#define MHA_FORWARD_CALL MHA_CALL(_MHA_FORWARD)
#define MHA_PROXY_FORWARD_EXEC_PARAM MHA_PROXY_PRE_PARAM, MHA_FORWARD_EXEC_PARAM
#define MHA_PROXY_FORWARD_LAYOUT_PARAM MHA_PROXY_PRE_PARAM, MHA_FORWARD_LAYOUT_PARAM
#define MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM \
MHA_PROXY_PRE_PARAM, MHA_FORWARD_LAYOUT_CONST_PARAM
#define MHA_PROXY_FORWARD_CALL MHA_PROXY_PRE_CALL, MHA_FORWARD_CALL
/***************************** MHA backward *****************************/
#define MHA_BACKWARD_EXEC_PARAM MHA_EXEC_PARAM(_MHA_BACKWARD)
#define MHA_BACKWARD_LAYOUT_CONST_PARAM MHA_LAYOUT_CONST_PARAM(_MHA_BACKWARD)
#define MHA_BACKWARD_LAYOUT_PARAM MHA_LAYOUT_PARAM(_MHA_BACKWARD)
#define MHA_BACKWARD_CALL MHA_CALL(_MHA_BACKWARD)
#define MHA_PROXY_BACKWARD_EXEC_PARAM MHA_PROXY_PRE_PARAM, MHA_BACKWARD_EXEC_PARAM
#define MHA_PROXY_BACKWARD_LAYOUT_PARAM MHA_PROXY_PRE_PARAM, MHA_BACKWARD_LAYOUT_PARAM
#define MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM \
MHA_PROXY_PRE_PARAM, MHA_BACKWARD_LAYOUT_CONST_PARAM
#define MHA_PROXY_BACKWARD_CALL MHA_PROXY_PRE_CALL, MHA_BACKWARD_CALL
/***************************** MHA other *****************************/
#define MHA_FORWARD_TENSOR_TO_LAYOUT_CALL \
queries.layout, keys.layout, values.layout, qkvo_weight_bias.layout, \
attn_mask.layout, bias_k.layout, bias_v.layout, out.layout, \
attn_weight.layout, mask_reservespace.layout, othr_reservespace.layout
#define MHA_BACKWARD_TENSOR_TO_LAYOUT_CALL \
diff.layout, queries.layout, keys.layout, values.layout, qkvo_weight_bias.layout, \
attn_mask.layout, attn_weight.layout, mask_reservespace.layout, \
othr_reservespace.layout, dqueries.layout, dkeys.layout, dvalues.layout, \
dqkvo_weight_bias.layout, dbias_k.layout, dbias_v.layout
#define MHA_PROXY_FORWARD_TENSOR_TO_LAYOUT_CALL \
MHA_PROXY_PRE_CALL, MHA_FORWARD_TENSOR_TO_LAYOUT_CALL
#define MHA_PROXY_BACKWARD_TENSOR_TO_LAYOUT_CALL \
MHA_PROXY_PRE_CALL, MHA_BACKWARD_TENSOR_TO_LAYOUT_CALL
} // namespace multi_head_attn
} // namespace megdnn
此差异已折叠。
#pragma once
#include "megdnn/dtype.h"
#include "megdnn/basic_types.h"
#include "megdnn/handle.h"
#include "megdnn/oprs/linalg.h"
#include "megdnn/oprs/nn.h"
#include "src/common/multi_head_attn/helper.h"
#include "src/common/utils.h"
namespace megdnn {
namespace multi_head_attn {
struct MHABackwardProxyBase {
MHABackwardProxyBase() {}
virtual ~MHABackwardProxyBase() = default;
/********************** function member **********************/
template <typename T>
void exec_internal(MHA_PROXY_BACKWARD_EXEC_PARAM);
void exec(MHA_PROXY_BACKWARD_EXEC_PARAM);
// lambda
#define cb(DType) \
virtual void move_scaler_to_device( \
Handle* handle, DTypeTrait<DType>::ctype* dst, \
DTypeTrait<DType>::ctype* src) = 0;
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
size_t get_workspace_in_bytes(MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM);
size_t get_mask_reservespace_in_bytes(MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM);
size_t get_othr_reservespace_in_bytes(MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM);
void layout_refill(MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM);
bool layout_ismatch(MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM);
WorkspaceBundle get_mask_reservespace_bundle(
MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM, void* ptr = nullptr);
WorkspaceBundle get_othr_reservespace_bundle(
MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM, void* ptr = nullptr);
WorkspaceBundle get_workspace_bundle(
MHA_PROXY_BACKWARD_LAYOUT_CONST_PARAM, void* ptr = nullptr);
/************************ data member ************************/
std::unique_ptr<MatrixMulForward> m_matmul_opr;
std::unique_ptr<BatchedMatrixMul> m_bmatmul_opr;
std::unique_ptr<AddUpdate> m_add_opr;
std::unique_ptr<Elemwise> m_elem_opr;
std::unique_ptr<Reduce> m_reduce_opr;
std::unique_ptr<SoftmaxBackward> m_softmaxbw_opr;
std::unique_ptr<Dropout> m_dropout_opr;
std::unique_ptr<DropoutBackward> m_dropoutbw_opr;
std::unique_ptr<Relayout> m_relayout_opr;
// metadata
size_t m_sizeof_datatype;
megdnn::DTypeEnum m_datatype;
size_t m_wq_off, m_wk_off, m_wv_off, m_wo_off;
size_t m_bq_off, m_bk_off, m_bv_off, m_bo_off;
size_t m_head, m_embed_size, m_ksize, m_vsize, m_qproj_size, m_kproj_size,
m_vproj_size, m_oproj_size;
bool m_qbias, m_kbias, m_vbias, m_obias;
// out = dropout(out)
TensorLayout m_mask2_layout;
TensorLayout m_grad_drop2_layout;
size_t m_grad_drop2_workspacesize;
TensorLayout m_grad_out_layout;
// out = z @ wo + bo
TensorLayout m_wo_layout, m_bo_layout;
TensorLayout m_grad_z_layout, m_grad_wo_layout, m_grad_bo_layout;
size_t m_grad_z_workspacesize, m_grad_wo0_workspacesize, m_grad_wo1_workspacesize,
m_grad_bo0_workspacesize, m_grad_bo1_workspacesize;
// z = nz
TensorLayout m_grad_nz_layout;
// nz = ny @ nv
TensorLayout m_grad_nv_layout, m_grad_ny_layout;
size_t m_grad_nv_workspacesize, m_grad_ny_workspacesize;
// ny = dropout(ny)
TensorLayout m_mask1_layout;
TensorLayout m_grad_drop1_layout;
size_t m_grad_drop1_workspacesize;
// ny = softmax(nx)
TensorLayout m_grad_nx_layout;
size_t m_grad_nx_workspacesize;
// nx = nq @ nk
TensorLayout m_grad_nq_layout, m_grad_nk_layout;
size_t m_grad_nq_workspacesize, m_grad_nk_workspacesize;
// nq, nk, nv = q, k, v
TensorLayout m_grad_q_layout, m_grad_k_layout, m_grad_v_layout;
// q = qin @ wq + bq
TensorLayout m_wq_layout, m_bq_layout;
TensorLayout m_grad_qin_layout, m_grad_wq_layout, m_grad_bq_layout;
size_t m_grad_qin_workspacesize, m_grad_wq0_workspacesize, m_grad_wq1_workspacesize,
m_grad_bq0_workspacesize, m_grad_bq1_workspacesize;
size_t m_grad_qin_reduce_workspacesize, m_grad_kin_reduce_workspacesize,
m_grad_vin_reduce_workspacesize;
// k = kin @ wk + bk
TensorLayout m_wk_layout, m_bk_layout;
TensorLayout m_grad_kin_layout, m_grad_wk_layout, m_grad_bk_layout;
size_t m_grad_kin_workspacesize, m_grad_wk0_workspacesize, m_grad_wk1_workspacesize,
m_grad_bk0_workspacesize, m_grad_bk1_workspacesize;
// v = vin @ wv + bv
TensorLayout m_wv_layout, m_bv_layout;
TensorLayout m_grad_vin_layout, m_grad_wv_layout, m_grad_bv_layout;
size_t m_grad_vin_workspacesize, m_grad_wv0_workspacesize, m_grad_wv1_workspacesize,
m_grad_bv0_workspacesize, m_grad_bv1_workspacesize;
};
} // namespace multi_head_attn
} // namespace megdnn
此差异已折叠。
#pragma once
#include "megdnn/dtype.h"
#include "megdnn/basic_types.h"
#include "megdnn/handle.h"
#include "megdnn/oprs/linalg.h"
#include "megdnn/oprs/nn.h"
#include "src/common/multi_head_attn/helper.h"
#include "src/common/utils.h"
namespace megdnn {
namespace multi_head_attn {
struct MHAForwardProxyBase {
MHAForwardProxyBase() {}
virtual ~MHAForwardProxyBase() = default;
/********************** function member **********************/
template <typename T>
void exec_internal(MHA_PROXY_FORWARD_EXEC_PARAM);
void exec(MHA_PROXY_FORWARD_EXEC_PARAM);
// lambda
#define cb(DType) \
virtual void move_scaler_to_device( \
Handle* handle, DTypeTrait<DType>::ctype* dst, \
DTypeTrait<DType>::ctype* src) = 0;
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
void deduce_layout(MHA_PROXY_FORWARD_LAYOUT_PARAM);
size_t get_workspace_in_bytes(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM);
size_t get_mask_reservespace_in_bytes(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM);
size_t get_othr_reservespace_in_bytes(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM);
void layout_refill(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM);
bool layout_ismatch(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM);
WorkspaceBundle get_mask_reservespace_bundle(
MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM, void* ptr = nullptr);
WorkspaceBundle get_othr_reservespace_bundle(
MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM, void* ptr = nullptr);
WorkspaceBundle get_workspace_bundle(
MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM, void* ptr = nullptr);
/************************ data member ************************/
std::unique_ptr<MatrixMulForward> m_matmul_opr;
std::unique_ptr<BatchedMatrixMul> m_bmatmul_opr;
std::unique_ptr<AddUpdate> m_add_opr;
std::unique_ptr<Elemwise> m_elem_opr;
std::unique_ptr<Softmax> m_softmax_opr;
std::unique_ptr<Dropout> m_dropout_opr;
std::unique_ptr<Relayout> m_relayout_opr;
std::unique_ptr<RepeatForward> m_repeat_opr;
// metadata
size_t m_sizeof_datatype;
megdnn::DTypeEnum m_datatype;
size_t m_wq_off, m_wk_off, m_wv_off, m_wo_off;
size_t m_bq_off, m_bk_off, m_bv_off, m_bo_off;
size_t m_heads, m_embed_size, m_ksize, m_vsize, m_qproj_size, m_kproj_size,
m_vproj_size, m_oproj_size;
bool m_qbias, m_kbias, m_vbias, m_obias;
// q/k/v = matmul(qu/ky/va, wq/wk/wv, bq/bk/bv)
// nq/nk/nv = dimshuffle(q/k/v) (norm to multihead)
TensorLayout m_wq_layout, m_wk_layout, m_wv_layout;
TensorLayout m_bq_layout, m_bk_layout, m_bv_layout;
TensorLayout m_q_layout, m_k_layout, m_v_layout;
TensorLayout m_nq_layout, m_nk_layout, m_nv_layout;
size_t m_q_workspacesize, m_k_workspacesize, m_v_workspacesize;
size_t m_q_head_repeat_workspacesize, m_k_head_repeat_workspacesize,
m_v_head_repeat_workspacesize;
// nx = matmul(nq, nk)
// ny = softmax(nx), ny_layout = m_nx_layout;
// ny = dropout(ny), dropout1_layout = m_nx_layout;
TensorLayout m_nx_layout;
TensorLayout m_mask1_layout;
size_t m_nx_workspacesize, m_softmax_workspacesize, m_dropout1_workspacesize;
// nz = matmul(ny, v)
// z = dimshuffle(nz) (multihead to norm)
TensorLayout m_nz_layout;
TensorLayout m_z_layout;
size_t m_nz_workspacesize;
// out = matmul(z, wo, bo)
// out = dropout(out), dropout2_layout = m_out_layout;
TensorLayout m_wo_layout, m_bo_layout;
TensorLayout m_out_layout;
TensorLayout m_mask2_layout;
size_t m_out_workspacesize, m_dropout2_workspacesize;
};
} // namespace multi_head_attn
} // namespace megdnn
#include "src/cuda/multi_head_attn/cudnn_fwbw.h"
#include <vector>
#include "megdnn/handle.h"
#include "src/cuda/utils.h"
#if CUDNN_VERSION >= 8004
#include "megdnn/dtype.h"
namespace megdnn {
namespace cuda {
/***************************** AuxiliaryArray *****************************/
AuxiliaryArray::~AuxiliaryArray() {
if (attnMaskType != MaskType::CUDNN_STYLE_MASK) {
if (devSeqQArray) {
cuda_check(cudaFree(devSeqQArray));
}
if (devSeqKArray) {
cuda_check(cudaFree(devSeqKArray));
}
}
}
bool AuxiliaryArray::is_initialized(
const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK,
MaskType _attnMaskType) {
if (_batchSize != batchSize or _seqLenQ != seqLenQ or _seqLenK != seqLenK or
_attnMaskType != attnMaskType or (seqQArray.size() != _batchSize) or
(seqKArray.size() != _batchSize) or !devSeqQArray or !devSeqKArray or
(loWinIdx.size() != _seqLenQ) or (hiWinIdx.size() != _seqLenQ)) {
return false;
}
return true;
}
void AuxiliaryArray::set_cudnn_style_mask(Handle* handle, const TensorND& attn_mask) {
megdnn_assert(attnMaskType == MaskType::CUDNN_STYLE_MASK);
auto stream = cuda_stream(handle);
#define T DTypeTrait<::megdnn::dtype::Int32>::ctype
devSeqQArray = attn_mask.ptr<T>() + 2 * seqLenQ;
devSeqKArray = attn_mask.ptr<T>() + 2 * seqLenQ + batchSize;
cuda_check(cudaMemcpyAsync(
seqQArray.data(), devSeqQArray, batchSize * sizeof(int),
cudaMemcpyDeviceToHost, stream));
cuda_check(cudaMemcpyAsync(
seqKArray.data(), devSeqKArray, batchSize * sizeof(int),
cudaMemcpyDeviceToHost, stream));
cuda_check(cudaMemcpyAsync(
loWinIdx.data(), attn_mask.ptr<T>(), seqLenQ * sizeof(int),
cudaMemcpyDeviceToHost, stream));
cuda_check(cudaMemcpyAsync(
hiWinIdx.data(), attn_mask.ptr<T>() + seqLenQ, seqLenQ * sizeof(int),
cudaMemcpyDeviceToHost, stream));
#undef T
}
void AuxiliaryArray::set(
Handle* handle, const size_t _batchSize, const size_t _seqLenQ,
const size_t _seqLenK, MaskType _attnMaskType) {
if (_batchSize == batchSize && _seqLenQ == seqLenQ && _seqLenK == seqLenK &&
_attnMaskType == attnMaskType) {
return;
} else {
if (attnMaskType != MaskType::CUDNN_STYLE_MASK) {
if (devSeqQArray) {
cuda_check(cudaFree(devSeqQArray));
}
if (devSeqKArray) {
cuda_check(cudaFree(devSeqKArray));
}
}
};
seqLenQ = _seqLenQ;
seqLenK = _seqLenK;
batchSize = _batchSize;
attnMaskType = _attnMaskType;
loWinIdx.resize(seqLenQ);
hiWinIdx.resize(seqLenQ);
size_t seqQArraySize = 1 * batchSize;
size_t seqKArraySize = batchSize;
seqQArray.resize(seqQArraySize);
seqKArray.resize(seqKArraySize);
if (attnMaskType == MaskType::CUDNN_STYLE_MASK) {
return;
}
for (size_t i = 0; i < seqQArraySize; ++i) {
seqQArray[i] = seqLenQ;
}
for (size_t i = 0; i < seqKArraySize; ++i) {
seqKArray[i] = seqLenK;
}
cuda_check(cudaMalloc((void**)&devSeqQArray, seqQArraySize * sizeof(int)));
cuda_check(cudaMalloc((void**)&devSeqKArray, seqKArraySize * sizeof(int)));
auto stream = cuda_stream(handle);
cuda_check(cudaMemcpyAsync(
devSeqQArray, seqQArray.data(), seqQArraySize * sizeof(int),
cudaMemcpyHostToDevice, stream));
cuda_check(cudaMemcpyAsync(
devSeqKArray, seqKArray.data(), seqKArraySize * sizeof(int),
cudaMemcpyHostToDevice, stream));
for (size_t i = 0; i < seqLenQ; ++i) {
loWinIdx[i] = 0;
if (attnMaskType == MaskType::DEFAULT_MASK) {
hiWinIdx[i] = i + 1;
} else if (attnMaskType == MaskType::NO_MASK) {
hiWinIdx[i] = seqLenK;
}
}
}
/***************************** MultiHeadAttnStatus *****************************/
void MultiHeadAttnStatus::set(
Handle* handle, const Param& p, const TensorLayout& q, const TensorLayout& k,
const TensorLayout& v) {
// It is consistent with the conditions judged in is_initialized.
// dropout
float attn_prob = p.training ? p.attn_prob : 0.f;
float out_prob = p.training ? p.out_prob : 0.f;
if (!attn_dropout_status.initialized()) {
attn_dropout_status.set(cudnn_handle(handle), p.seed, attn_prob);
}
if (!out_dropout_status.initialized()) {
out_dropout_status.set(cudnn_handle(handle), p.seed, out_prob);
}
if (attn_dropout_status.drop_prob != attn_prob) {
attn_dropout_status.drop_prob = attn_prob;
attn_dropout_status.restore_desc(cudnn_handle(handle));
}
if (out_dropout_status.drop_prob != out_prob) {
out_dropout_status.drop_prob = out_prob;
out_dropout_status.restore_desc(cudnn_handle(handle));
}
// size
batchSize = q.shape[0];
seqLenQ = q.shape[1];
seqLenK = k.shape[1];
numHeads = p.num_heads;
qSize = p.embeding_size;
kSize = p.k_size;
vSize = p.v_size;
qProjSize = p.qproj_size / numHeads;
kProjSize = p.kproj_size / numHeads;
vProjSize = p.vproj_size / numHeads;
oProjSize = p.oproj_size;
attnMaskType = p.attn_mask_type;
bias = p.qbias or p.kbias or p.vbias or p.obias;
cudnnDataType_t cudnn_dtype = to_cudnn_dtype(q.dtype);
auto flag = CUDNN_ATTN_QUERYMAP_ONE_TO_ONE;
if (bias) {
flag = flag | CUDNN_ATTN_ENABLE_PROJ_BIASES;
}
#if CUDNN_VERSION < 8600
cudnn_check(cudnnSetAttnDescriptor(
attn_desc, flag, numHeads, p.sm_scaler, cudnn_dtype, cudnn_dtype,
CUDNN_DEFAULT_MATH, attn_dropout_status.desc.desc, NULL, qSize, kSize,
vSize, qProjSize, kProjSize, vProjSize, oProjSize, seqLenQ, seqLenK,
batchSize, 1));
#else
cudnn_check(cudnnSetAttnDescriptor(
attn_desc, flag, numHeads, p.sm_scaler, cudnn_dtype, cudnn_dtype,
CUDNN_DEFAULT_MATH, attn_dropout_status.desc.desc,
out_dropout_status.desc.desc, qSize, kSize, vSize, qProjSize, kProjSize,
vProjSize, oProjSize, seqLenQ, seqLenK, batchSize, 1));
#endif
// misc
auxArray.set(handle, batchSize, seqLenQ, seqLenK, p.attn_mask_type);
if (p.training) {
cudnnGetMultiHeadAttnBuffers(
cudnn_handle(handle), attn_desc, &sizeWeights, &sizeWkspace,
&sizeReserve);
} else {
cudnnGetMultiHeadAttnBuffers(
cudnn_handle(handle), attn_desc, &sizeWeights, &sizeWkspace, NULL);
sizeReserve = 0;
}
}
void MultiHeadAttnStatus::set_cudnn_style_mask(
Handle* handle, const TensorND& attn_mask) {
auxArray.set_cudnn_style_mask(handle, attn_mask);
}
bool MultiHeadAttnStatus::is_initialized(
const Param& p, const TensorLayout& q, const TensorLayout& k,
const TensorLayout& v) {
// By default, the info of q, k and v must be consistent with the corresponding
// parameters in param respectively, otherwise an error will occur (so, check is not
// done here, mainly by check_exec).
// dropout
float attn_prob = p.training ? p.attn_prob : 0.f;
float out_prob = p.training ? p.out_prob : 0.f;
if (!attn_dropout_status.initialized() or !out_dropout_status.initialized() or
attn_dropout_status.drop_prob != attn_prob or
out_dropout_status.drop_prob != out_prob) {
return false;
}
// size
if (q.shape[0] != batchSize or q.shape[1] != seqLenQ or k.shape[1] != seqLenK or
q.shape[2] != qSize or k.shape[2] != kSize or v.shape[2] != vSize or
attnMaskType != p.attn_mask_type or numHeads != p.num_heads) {
return false;
}
bool pbias = p.qbias or p.kbias or p.vbias or p.obias;
if (qSize != p.embeding_size or kSize != p.k_size or vSize != p.v_size) {
return false;
}
if (bias != pbias) {
return false;
}
if ((qProjSize != (p.qproj_size / p.num_heads)) or
(kProjSize != (p.kproj_size / p.num_heads)) or
(vProjSize != (p.vproj_size / p.num_heads)) or (oProjSize != p.oproj_size)) {
return false;
}
// misc
if (!auxArray.is_initialized(batchSize, seqLenQ, seqLenK, attnMaskType)) {
return false;
}
if (p.training and sizeReserve == 0) {
return false;
}
return true;
}
/***************************** MHA forward *****************************/
void MHAForwardCudnnOpr::deduce_layout(MHA_PROXY_FORWARD_LAYOUT_PARAM) {
MEGDNN_MARK_USED_VAR(qkvo_weight_bias);
MEGDNN_MARK_USED_VAR(attn_mask);
MEGDNN_MARK_USED_VAR(bias_k);
MEGDNN_MARK_USED_VAR(bias_v);
megdnn_assert(
queries.ndim == 3,
"queries.ndim should be 3[batch, sequence, embeding], but got %zu",
queries.ndim);
if (!desc_status.is_initialized(param, queries, keys, values)) {
desc_status.set(handle, param, queries, keys, values);
}
attn_weight = TensorLayout(
TensorShape{
queries.shape[0] * param.num_heads, queries.shape[1],
keys.shape[1]},
queries.dtype);
size_t osize = param.oproj_size != 0
? param.oproj_size
: (param.vproj_size != 0 ? param.vproj_size
: (param.v_size * param.num_heads));
out = TensorLayout(
TensorShape{queries.shape[0], queries.shape[1], osize}, queries.dtype);
mask_reservespace = TensorLayout(TensorShape{0}, dtype::Uint8());
othr_reservespace = TensorLayout(
TensorShape{desc_status.sizeReserve / queries.dtype.size()}, queries.dtype);
}
void MHAForwardCudnnOpr::exec(MHA_PROXY_FORWARD_EXEC_PARAM) {
if (!desc_status.is_initialized(
param, queries.layout, keys.layout, values.layout)) {
desc_status.set(handle, param, queries.layout, keys.layout, values.layout);
}
if (param.attn_mask_type == MaskType::CUDNN_STYLE_MASK) {
desc_status.set_cudnn_style_mask(handle, attn_mask);
}
size_t osize = desc_status.oProjSize != 0
? desc_status.oProjSize
: (desc_status.vProjSize != 0
? desc_status.vProjSize * param.num_heads
: desc_status.vSize * param.num_heads);
SeqTensorDesc q{queries.layout, desc_status.batchSize,
desc_status.seqLenQ, desc_status.qSize,
param.input_order, desc_status.auxArray.seqQArray.data()};
SeqTensorDesc o{out.layout, desc_status.batchSize,
desc_status.seqLenQ, osize,
param.input_order, desc_status.auxArray.seqQArray.data()};
SeqTensorDesc k{keys.layout, desc_status.batchSize,
desc_status.seqLenK, desc_status.kSize,
param.input_order, desc_status.auxArray.seqKArray.data()};
SeqTensorDesc v{values.layout, desc_status.batchSize,
desc_status.seqLenK, desc_status.vSize,
param.input_order, desc_status.auxArray.seqKArray.data()};
cudnn_check(cudnnMultiHeadAttnForward(
cudnn_handle(handle), desc_status.attn_desc, -1,
desc_status.auxArray.loWinIdx.data(), desc_status.auxArray.hiWinIdx.data(),
desc_status.auxArray.devSeqQArray, desc_status.auxArray.devSeqKArray,
q.desc, queries.raw_ptr(), param.reslink ? queries.raw_ptr() : NULL, k.desc,
keys.raw_ptr(), v.desc, values.raw_ptr(), o.desc, out.raw_ptr(),
desc_status.sizeWeights,
desc_status.sizeWeights > 0 ? qkvo_weight_bias.raw_ptr() : NULL,
desc_status.sizeWkspace, workspace.raw_ptr,
param.training ? desc_status.sizeReserve : 0,
param.training ? othr_reservespace.raw_ptr() : NULL));
}
size_t MHAForwardCudnnOpr::get_workspace_in_bytes(
MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM) {
MEGDNN_MARK_USED_VAR(qkvo_weight_bias);
MEGDNN_MARK_USED_VAR(attn_mask);
MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(attn_weight);
MEGDNN_MARK_USED_VAR(mask_reservespace);
MEGDNN_MARK_USED_VAR(othr_reservespace);
if (!desc_status.is_initialized(param, queries, keys, values)) {
desc_status.set(handle, param, queries, keys, values);
}
return desc_status.sizeWkspace;
}
size_t MHAForwardCudnnOpr::get_mask_reservespace_in_bytes(
MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM) {
MEGDNN_MARK_USED_VAR(qkvo_weight_bias);
MEGDNN_MARK_USED_VAR(attn_mask);
MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(attn_weight);
MEGDNN_MARK_USED_VAR(mask_reservespace);
MEGDNN_MARK_USED_VAR(othr_reservespace);
if (!desc_status.is_initialized(param, queries, keys, values)) {
desc_status.set(handle, param, queries, keys, values);
}
return 0;
}
size_t MHAForwardCudnnOpr::get_othr_reservespace_in_bytes(
MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM) {
MEGDNN_MARK_USED_VAR(qkvo_weight_bias);
MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(mask_reservespace);
MEGDNN_MARK_USED_VAR(othr_reservespace);
if (!desc_status.is_initialized(param, queries, keys, values)) {
desc_status.set(handle, param, queries, keys, values);
}
return desc_status.sizeReserve;
}
/***************************** MHA backward *****************************/
void MHABackwardCudnnOpr::exec(MHA_PROXY_BACKWARD_EXEC_PARAM) {
if (!desc_status.is_initialized(
param, queries.layout, keys.layout, values.layout)) {
desc_status.set(handle, param, queries.layout, keys.layout, values.layout);
}
if (param.attn_mask_type == MaskType::CUDNN_STYLE_MASK) {
desc_status.set_cudnn_style_mask(handle, attn_mask);
}
size_t osize = desc_status.oProjSize != 0
? desc_status.oProjSize
: (desc_status.vProjSize != 0
? (desc_status.vProjSize * param.num_heads)
: (desc_status.vSize * param.num_heads));
SeqTensorDesc q{queries.layout, desc_status.batchSize,
desc_status.seqLenQ, desc_status.qSize,
param.input_order, desc_status.auxArray.seqQArray.data()};
SeqTensorDesc d{diff.layout, desc_status.batchSize,
desc_status.seqLenQ, osize,
param.input_order, desc_status.auxArray.seqQArray.data()};
SeqTensorDesc k{keys.layout, desc_status.batchSize,
desc_status.seqLenK, desc_status.kSize,
param.input_order, desc_status.auxArray.seqKArray.data()};
SeqTensorDesc v{values.layout, desc_status.batchSize,
desc_status.seqLenK, desc_status.vSize,
param.input_order, desc_status.auxArray.seqKArray.data()};
cudnn_check(cudnnMultiHeadAttnBackwardData(
cudnn_handle(handle), desc_status.attn_desc,
desc_status.auxArray.loWinIdx.data(), desc_status.auxArray.hiWinIdx.data(),
desc_status.auxArray.devSeqQArray, desc_status.auxArray.devSeqKArray,
d.desc, diff.raw_ptr(), q.desc, dqueries.raw_ptr(), queries.raw_ptr(),
k.desc, dkeys.raw_ptr(), keys.raw_ptr(), v.desc, dvalues.raw_ptr(),
values.raw_ptr(), desc_status.sizeWeights,
desc_status.sizeWeights > 0 ? qkvo_weight_bias.raw_ptr() : NULL,
desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve,
othr_reservespace.raw_ptr()));
cuda_check(cudaMemset(dqkvo_weight_bias.raw_ptr(), 0, desc_status.sizeWeights));
#if CUDNN_VERSION < 8600
cuda_check(cudaDeviceSynchronize());
#endif
cudnn_check(cudnnMultiHeadAttnBackwardWeights(
cudnn_handle(handle), desc_status.attn_desc, CUDNN_WGRAD_MODE_ADD, q.desc,
queries.raw_ptr(), k.desc, keys.raw_ptr(), v.desc, values.raw_ptr(), d.desc,
diff.raw_ptr(), desc_status.sizeWeights,
desc_status.sizeWeights > 0 ? qkvo_weight_bias.raw_ptr() : NULL,
desc_status.sizeWeights > 0 ? dqkvo_weight_bias.raw_ptr() : NULL,
desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve,
othr_reservespace.raw_ptr()));
}
} // namespace cuda
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
#pragma once
#include <vector>
#include "megdnn/handle.h"
#include "megdnn/thin/small_vector.h"
#include "src/cuda/cudnn_wrapper.h"
#if CUDNN_VERSION >= 8004
#include "megdnn/basic_types.h"
#include "megdnn/oprs/nn.h"
#include "src/common/algo_chooser.h"
#include "src/common/multi_head_attn/helper.h"
#include "src/common/utils.h"
#include "src/cuda/dropout/opr_impl.h"
#include "src/cuda/handle.h"
using Param = megdnn::MultiHeadAttn::Param;
using MaskType = Param::AttnMaskType;
using InputType = Param::TensorCombinationType;
namespace megdnn {
namespace cuda {
struct AuxiliaryArray {
public:
int* seqQArray = nullptr;
int* seqKArray = nullptr;
SmallVector<int> seqQArray;
SmallVector<int> seqKArray;
int* devSeqQArray = nullptr;
int* devSeqKArray = nullptr;
int* loWinIdx = nullptr;
int* hiWinIdx = nullptr;
SmallVector<int> loWinIdx;
SmallVector<int> hiWinIdx;
size_t seqLenQ = 0;
size_t seqLenK = 0;
size_t batchSize = 0;
bool attnMask = 0;
MaskType attnMaskType = MaskType::NO_MASK;
~AuxiliaryArray();
void set(
const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK,
bool _attnMask);
Handle* handle, const size_t _batchSize, const size_t _seqLenQ,
const size_t _seqLenK, MaskType _attnMaskType);
void set_cudnn_style_mask(Handle* handle, const TensorND& attn_mask);
bool is_initialized(
const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK,
bool _attnMask);
MaskType _attnMaskType);
};
using Param = megdnn::MultiHeadAttn::Param;
class MultiHeadAttnStatus {
DropoutStatus attn_dropout_status;
DropoutStatus out_dropout_status;
......@@ -53,7 +60,8 @@ class MultiHeadAttnStatus {
size_t kProjSize = 0;
size_t vProjSize = 0;
size_t oProjSize = 0;
bool attnMask = 0;
MaskType attnMaskType = MaskType::NO_MASK;
bool bias = false;
size_t sizeWeights = 0;
size_t sizeWkspace = 0;
......@@ -65,16 +73,44 @@ public:
private:
void set(
cudnnHandle_t handle, const Param& p, const TensorLayout& q,
Handle* handle, const Param& p, const TensorLayout& q,
const TensorLayout& k, const TensorLayout& v);
void set_cudnn_style_mask(Handle* handle, const TensorND& attn_mask);
bool is_initialized(
const Param& p, const TensorLayout& q, const TensorLayout& k,
const TensorLayout& v);
friend class MultiHeadAttnBase;
friend class MultiHeadAttnForwardImpl;
friend class MultiHeadAttnBackwardImpl;
friend class MHAForwardCudnnOpr;
friend class MHABackwardCudnnOpr;
};
class MHAForwardCudnnOpr {
public:
MHAForwardCudnnOpr(){};
void exec(MHA_PROXY_FORWARD_EXEC_PARAM);
void deduce_layout(MHA_PROXY_FORWARD_LAYOUT_PARAM);
size_t get_workspace_in_bytes(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM);
size_t get_mask_reservespace_in_bytes(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM);
size_t get_othr_reservespace_in_bytes(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM);
private:
MultiHeadAttnStatus desc_status;
};
class MHABackwardCudnnOpr {
public:
MHABackwardCudnnOpr(){};
void exec(MHA_PROXY_BACKWARD_EXEC_PARAM);
private:
MultiHeadAttnStatus desc_status;
};
} // namespace cuda
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
#include "src/cuda/multi_head_attn/helper.h"
#if CUDNN_VERSION >= 8004
namespace megdnn {
namespace cuda {
AuxiliaryArray::~AuxiliaryArray() {
if (loWinIdx)
free(loWinIdx);
if (hiWinIdx)
free(hiWinIdx);
if (seqQArray)
free(seqQArray);
if (seqKArray)
free(seqKArray);
if (devSeqQArray)
cuda_check(cudaFree(devSeqQArray));
if (devSeqKArray)
cuda_check(cudaFree(devSeqKArray));
}
bool AuxiliaryArray::is_initialized(
const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK,
bool _attnMask) {
if (_batchSize != batchSize or _seqLenQ != seqLenQ or _seqLenK != seqLenK or
_attnMask != attnMask or !seqQArray or !seqKArray or !devSeqQArray or
!devSeqKArray or !loWinIdx or !hiWinIdx)
return false;
return true;
}
void AuxiliaryArray::set(
const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK,
bool _attnMask) {
if (_batchSize == batchSize && _seqLenQ == seqLenQ && _seqLenK == seqLenK &&
_attnMask == attnMask)
return;
else {
if (loWinIdx)
free(loWinIdx);
if (hiWinIdx)
free(hiWinIdx);
if (seqQArray)
free(seqQArray);
if (seqKArray)
free(seqKArray);
if (devSeqQArray)
cuda_check(cudaFree(devSeqQArray));
if (devSeqKArray)
cuda_check(cudaFree(devSeqKArray));
};
seqLenQ = _seqLenQ;
seqLenK = _seqLenK;
batchSize = _batchSize;
attnMask = _attnMask;
size_t seqQArraySize = 1 * batchSize;
size_t seqKArraySize = batchSize;
seqQArray = (int*)calloc(seqQArraySize, sizeof(int));
seqKArray = (int*)calloc(seqKArraySize, sizeof(int));
for (size_t i = 0; i < seqQArraySize; ++i)
seqQArray[i] = seqLenQ;
for (size_t i = 0; i < seqKArraySize; ++i)
seqKArray[i] = seqLenK;
cuda_check(cudaMalloc((void**)&devSeqQArray, seqQArraySize * sizeof(int)));
cuda_check(cudaMalloc((void**)&devSeqKArray, seqKArraySize * sizeof(int)));
cuda_check(cudaMemcpy(
devSeqQArray, seqQArray, seqQArraySize * sizeof(int),
cudaMemcpyHostToDevice));
cuda_check(cudaMemcpy(
devSeqKArray, seqKArray, seqKArraySize * sizeof(int),
cudaMemcpyHostToDevice));
loWinIdx = (int*)calloc(seqLenQ, sizeof(int));
hiWinIdx = (int*)calloc(seqLenQ, sizeof(int));
for (size_t i = 0; i < seqLenQ; ++i) {
loWinIdx[i] = 0;
if (attnMask)
hiWinIdx[i] = i + 1;
else
hiWinIdx[i] = seqLenK;
}
}
void MultiHeadAttnStatus::set(
cudnnHandle_t handle, const Param& p, const TensorLayout& q,
const TensorLayout& k, const TensorLayout& v) {
float attn_prob = p.training ? p.attn_prob : 0.f;
float out_prob = p.training ? p.out_prob : 0.f;
if (!attn_dropout_status.initialized())
attn_dropout_status.set(handle, p.seed, attn_prob);
if (!out_dropout_status.initialized())
out_dropout_status.set(handle, p.seed, out_prob);
if (attn_dropout_status.drop_prob != attn_prob) {
attn_dropout_status.drop_prob = attn_prob;
attn_dropout_status.restore_desc(handle);
}
if (out_dropout_status.drop_prob != out_prob) {
out_dropout_status.drop_prob = out_prob;
out_dropout_status.restore_desc(handle);
}
batchSize = q.shape[0];
seqLenQ = q.shape[1];
seqLenK = k.shape[1];
qSize = q.shape[2];
kSize = k.shape[2];
vSize = v.shape[2];
numHeads = p.num_heads;
qProjSize = p.qproj_size ? qSize / numHeads : 0;
kProjSize = p.kproj_size ? kSize / numHeads : 0;
vProjSize = p.vproj_size ? vSize / numHeads : 0;
oProjSize = p.oproj_size ? qSize : 0;
attnMask = p.attn_mask_type >= param::MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK;
cudnnDataType_t cudnn_dtype = to_cudnn_dtype(q.dtype);
auto flag = CUDNN_ATTN_QUERYMAP_ONE_TO_ONE;
if (p.qbias or p.kbias or p.vbias or p.obias)
flag = flag | CUDNN_ATTN_ENABLE_PROJ_BIASES;
#if CUDNN_VERSION < 8600
// TODO: CUDNN_VERSION < 8600 and out dropout > 0.0, we need to go to the proxy cuda
// implementation.
cudnn_check(cudnnSetAttnDescriptor(
attn_desc, flag, numHeads, p.sm_scaler, cudnn_dtype, cudnn_dtype,
CUDNN_DEFAULT_MATH, attn_dropout_status.desc.desc, NULL, qSize, kSize,
vSize, qProjSize, kProjSize, vProjSize, oProjSize, seqLenQ, seqLenK,
batchSize, 1));
#else
cudnn_check(cudnnSetAttnDescriptor(
attn_desc, flag, numHeads, p.sm_scaler, cudnn_dtype, cudnn_dtype,
CUDNN_DEFAULT_MATH, attn_dropout_status.desc.desc,
out_dropout_status.desc.desc, qSize, kSize, vSize, qProjSize, kProjSize,
vProjSize, oProjSize, seqLenQ, seqLenK, batchSize, 1));
#endif
auxArray.set(
batchSize, seqLenQ, seqLenK,
p.attn_mask_type >= param::MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK);
if (p.training)
cudnnGetMultiHeadAttnBuffers(
handle, attn_desc, &sizeWeights, &sizeWkspace, &sizeReserve);
else {
cudnnGetMultiHeadAttnBuffers(
handle, attn_desc, &sizeWeights, &sizeWkspace, NULL);
sizeReserve = 0;
}
}
bool MultiHeadAttnStatus::is_initialized(
const Param& p, const TensorLayout& q, const TensorLayout& k,
const TensorLayout& v) {
float attn_prob = p.training ? p.attn_prob : 0.f;
float out_prob = p.training ? p.out_prob : 0.f;
if (!attn_dropout_status.initialized() or !out_dropout_status.initialized() or
attn_dropout_status.drop_prob != attn_prob or
out_dropout_status.drop_prob != out_prob)
return false;
if (q.shape[0] != batchSize or q.shape[1] != seqLenQ or k.shape[1] != seqLenK or
q.shape[2] != qSize or k.shape[2] != kSize or v.shape[2] != vSize or
attnMask != (p.attn_mask_type >=
param::MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK) or
numHeads != p.num_heads) {
return false;
}
if ((p.qproj_size && (qProjSize == 0 or qProjSize != qSize / p.num_heads)) or
(p.kproj_size && (kProjSize == 0 or kProjSize != kSize / p.num_heads)) or
(p.vproj_size && (vProjSize == 0 or vProjSize != vSize / p.num_heads)) or
(p.oproj_size && (oProjSize == 0 or oProjSize != q.shape[2])))
return false;
if ((!p.qproj_size && qProjSize != 0) or (!p.kproj_size && kProjSize != 0) or
(!p.vproj_size && vProjSize != 0) or (!p.oproj_size && oProjSize != 0))
return false;
if (!auxArray.is_initialized(batchSize, seqLenQ, seqLenK, attnMask))
return false;
if (p.training and sizeReserve == 0)
return false;
return true;
}
} // namespace cuda
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
......@@ -4,78 +4,45 @@
#include "src/common/reduce_helper.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/handle.h"
#include "src/cuda/multi_head_attn/helper.h"
#include "src/cuda/multi_head_attn/cudnn_fwbw.h"
#include "src/cuda/multi_head_attn/proxy_bw.h"
#include "src/cuda/multi_head_attn/proxy_fw.h"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
using Param = megdnn::MultiHeadAttn::Param;
using MaskType = Param::AttnMaskType;
using InputType = Param::TensorCombinationType;
bool can_use_mha_cudnn(const Param& param);
class MultiHeadAttnForwardImpl final : public MultiHeadAttnForward {
public:
using MultiHeadAttnForward::MultiHeadAttnForward;
#if CUDNN_VERSION >= 8004
MultiHeadAttnStatus desc_status;
MHAForwardCudnnOpr cudnn_opr;
#endif
MHAForwardProxyOpr proxy_opr;
void exec(
_megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
_megdnn_tensor_in qkvo_weight_bias, _megdnn_tensor_in attn_mask,
_megdnn_tensor_in bias_k, _megdnn_tensor_in bias_v, _megdnn_tensor_out out,
_megdnn_tensor_out attn_weight, _megdnn_tensor_out mask_reservespace,
_megdnn_tensor_out othr_reservespace, _megdnn_workspace workspace) override;
void deduce_layout(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
const TensorLayout& attn_mask, const TensorLayout& bias_k,
const TensorLayout& bias_v, TensorLayout& out, TensorLayout& attn_weight,
TensorLayout& mask_reservespace, TensorLayout& othr_reservespace) override;
size_t get_mask_reservespace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
const TensorLayout& attn_mask, const TensorLayout& bias_k,
const TensorLayout& bias_v, const TensorLayout& out,
const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& othr_reservespace) override;
size_t get_othr_reservespace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
const TensorLayout& attn_mask, const TensorLayout& bias_k,
const TensorLayout& bias_v, const TensorLayout& out,
const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& othr_reservespace) override;
size_t get_workspace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& qkvo_weight_bias,
const TensorLayout& attn_mask, const TensorLayout& bias_k,
const TensorLayout& bias_v, const TensorLayout& out,
const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& othr_reservespace) override;
void exec(MHA_FORWARD_EXEC_PARAM) override;
void deduce_layout(MHA_FORWARD_LAYOUT_PARAM) override;
size_t get_workspace_in_bytes(MHA_FORWARD_LAYOUT_CONST_PARAM) override;
size_t get_mask_reservespace_in_bytes(MHA_FORWARD_LAYOUT_CONST_PARAM) override;
size_t get_othr_reservespace_in_bytes(MHA_FORWARD_LAYOUT_CONST_PARAM) override;
};
class MultiHeadAttnBackwardImpl final : public MultiHeadAttnBackward {
public:
using MultiHeadAttnBackward::MultiHeadAttnBackward;
#if CUDNN_VERSION >= 8004
MultiHeadAttnStatus desc_status;
MHABackwardCudnnOpr cudnn_opr;
#endif
void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
_megdnn_tensor_in values, _megdnn_tensor_in qkvo_weight_bias,
_megdnn_tensor_in attn_mask, _megdnn_tensor_in attn_weight,
_megdnn_tensor_in mask_reservespace, _megdnn_tensor_in othr_reservespace,
_megdnn_tensor_out dqueries, _megdnn_tensor_out dkeys,
_megdnn_tensor_out dvalues, _megdnn_tensor_out dqkvo_weight_bias,
_megdnn_tensor_out dbias_k, _megdnn_tensor_out dbias_v,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& queries,
const TensorLayout& keys, const TensorLayout& values,
const TensorLayout& qkvo_weight_bias, const TensorLayout& attn_mask,
const TensorLayout& attn_weight, const TensorLayout& mask_reservespace,
const TensorLayout& othr_reservespace, const TensorLayout& dqueries,
const TensorLayout& dkeys, const TensorLayout& dvalues,
const TensorLayout& dqkvo_weight_bias, const TensorLayout& dbias_k,
const TensorLayout& dbias_v) override;
MHABackwardProxyOpr proxy_opr;
void exec(MHA_BACKWARD_EXEC_PARAM) override;
size_t get_workspace_in_bytes(MHA_BACKWARD_LAYOUT_CONST_PARAM) override;
};
} // namespace cuda
} // namespace megdnn
......
#include "src/cuda/multi_head_attn/proxy_bw.h"
#include "megdnn/basic_types.h"
#include "megdnn/handle.h"
#include "megdnn/oprs/nn.h"
namespace megdnn {
namespace cuda {
#define cb(DType) \
void MHABackwardProxyOpr::move_scaler_to_device( \
Handle* handle, DTypeTrait<DType>::ctype* dst, \
DTypeTrait<DType>::ctype* src) { \
cudaMemcpyAsync( \
dst, src, sizeof(DTypeTrait<DType>::ctype), cudaMemcpyHostToDevice, \
cuda_stream(handle)); \
};
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
#pragma once
#include "megdnn/handle.h"
#include "megdnn/oprs.h"
#include "megdnn/oprs/general.h"
#include "megdnn/oprs/nn.h"
#include "src/common/multi_head_attn/proxy_backward_base.h"
#include "src/common/reduce_helper.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
using Param = megdnn::MultiHeadAttn::Param;
using MaskType = Param::AttnMaskType;
using InputType = Param::TensorCombinationType;
using multi_head_attn::matmul_deduce_layout;
using multi_head_attn::matmul_exec;
class MHABackwardProxyOpr final : public multi_head_attn::MHABackwardProxyBase {
public:
MHABackwardProxyOpr() : MHABackwardProxyBase() {}
#define cb(DType) \
void move_scaler_to_device( \
Handle*, DTypeTrait<DType>::ctype*, DTypeTrait<DType>::ctype*) override;
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
};
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
#include "src/cuda/multi_head_attn/proxy_fw.h"
#include "megdnn/basic_types.h"
#include "megdnn/dtype.h"
#include "src/cuda/matrix_mul/opr_impl.h"
namespace megdnn {
namespace cuda {
#define cb(DType) \
void MHAForwardProxyOpr::move_scaler_to_device( \
Handle* handle, DTypeTrait<DType>::ctype* dst, \
DTypeTrait<DType>::ctype* src) { \
cudaMemcpyAsync( \
dst, src, sizeof(DTypeTrait<DType>::ctype), cudaMemcpyHostToDevice, \
cuda_stream(handle)); \
};
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
#pragma once
#include "megdnn/handle.h"
#include "megdnn/oprs.h"
#include "megdnn/oprs/general.h"
#include "src/common/multi_head_attn/helper.h"
#include "src/common/multi_head_attn/proxy_forward_base.h"
#include "src/common/reduce_helper.h"
#include "src/common/utils.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/opr_impl.h"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
using Param = megdnn::MultiHeadAttn::Param;
using MaskType = Param::AttnMaskType;
using InputType = Param::TensorCombinationType;
using multi_head_attn::matmul_deduce_layout;
using multi_head_attn::matmul_exec;
class MHAForwardProxyOpr final : public multi_head_attn::MHAForwardProxyBase {
public:
MHAForwardProxyOpr() : MHAForwardProxyBase() {}
#define cb(DType) \
void move_scaler_to_device( \
Handle*, DTypeTrait<DType>::ctype*, DTypeTrait<DType>::ctype*) override;
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
};
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -22,23 +22,22 @@ class MultiHeadAttention(Module):
\text{MultiHeadAttn}\big(q,K,V, W_Q, W_V, W_O\big) = \sum^{nHeads-1}_{i=0}W_{O,i}h_i
where :math:`h_i=W_{V,i}V \text{Softmax}\Big( \text{smScaler} \cdot K^TW^T_{K,i}W_{Q,i}q \Big),\text{for }i\text{ = 0 ... nHeads-1}`.
Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported, and if the cudnn version >=8.6.0, the calculation results are completely correct; If the cudnn version >=8.0.4 but <8.6.0, if there is a bias, only the dbias result calculated from the backward is incorrect. If there is no bias, the forward and backward calculations are correct; If the cudnn version is less than 8.0.4, this operator is not supported.
When the following conditions are met, you can go to the cudnn backend:
- ``cudnn version`` greater than or equal to 8.0.4 and ``bias`` is ``False`` and ``training`` is ``False``
- ``cudnn version`` greater than or equal to 8.6.0
Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported.
The implementation of cudnn can be run if the following conditions are met:
- cuda is available, cudnn is available, and the version of cudnn is greater than or equal to 8.0.4.
- ``bias`` is ``False``, when ``training`` is ``False`` and ``cudnn version`` greater than or equal to 8.0.4. if the ``cudnn version`` greater than or equal to 8.6.0, this point can be ignored.
- ``add_bias_kv`` is ``False``
- ``add_zero_attn`` is ``False``
- ``need_weights`` is ``False``
- ``average_attn_weights`` is ``False``
- ``maybe_cudnn_style_mask`` is ``True`` if support else ``False``
- ``maybe_cudnn_style_mask`` is ``True``
- ``attn_mask`` and ``key_padding_mask`` is cudnn style mask, i.e. the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`.
- The shape of attn_mask is :math:`(2, L)`, where :math:`(0, :)` elements specify the start index, :math:`(1, :)` elements specify the end index, the start index is inclusive, and the end index is not exclusive. The start index (i.e. elements in `attn_mask[0, x]`) must be less than the corresponding end index (i.e. elements in `attn_mask[1, x]`). The end index must be less than or equal to :math:`S`, where :math:`S` is the source sequence length, :math:`L` is the target sequence length.
- The shape of key_padding_mask is :math:`(2, N)`, where :math:`(0, :)` elements specify the target sequence padding in cudnn style mask and the element must equal to or less than :math:`L`, :math:`(1, :)` elements specify the source sequence padding in cudnn style mask and the element must equal to or less than :math:`S`, where :math:`S` is the source sequence length, :math:`L` is the target sequence length.
- ``qbias``, ``kbias``, ``vbias`` and ``obias`` are equal
Note: If there is no mask or the default mask is used, cudnn impl will also be used. At this time, cudnn will automatically generate the corresponding cudnn style mask.
Args:
embed_dim: Total dimension of the model.
......@@ -139,11 +138,9 @@ class MultiHeadAttention(Module):
if self.add_bias_kv:
xavier_uniform_(self.bias_k)
else:
self.bias_k = None
if self.add_bias_kv:
xavier_uniform_(self.bias_v)
else:
self.bias_k = None
self.bias_v = None
def forward(
......@@ -159,59 +156,36 @@ class MultiHeadAttention(Module):
maybe_cudnn_style_mask: bool = False,
):
r"""
Args:
query: Query embeddings of shape :math:`(N, L, E_q)`,
where :math:`N` is the batch size, :math:`L` is the target sequence length, and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against key-value pairs to produce the output. See "Attention Is All You Need" for more details.
key: Key embeddings of shape :math:`(N, S, E_k)`,
where :math:`N` is the batch size, :math:`S` is the source sequence length, and :math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details.
value: Value embeddings of shape :math:`(N, S, E_v)`,
where :math:`N` is the batch size, :math:`S` is the source sequence length, and :math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details.
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` to ignore for the purpose of
attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
Note: Should be set to None, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Note: User-defined mask not supported now, only support no mask or default mask, where the upper right triangle is all -inf, and the diagonal and lower left triangle are all 0. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
need_weights: indicates whether to return the attention weight, which is the output result of softmax. Default: `True`
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
is_causal: If specified, applies a causal mask as attention mask. Default: ``False``
Warning: ``is_causal`` provides a hint that ``attn_mask`` is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.
maybe_cudnn_style_mask: if specified, applies a cudnn style mask as attention mask. Default: ``False``
Note: In the cudnn style, the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`.
Warning: like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is a cudnn style mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. In addition, if the ``_merge_masks`` function returns ``merge_type=cudnn_style_mask``, please ensure that other conditions are correct so that it can run the implementation of cudnn, otherwise an error will be reported.
Note: Should be set to False, and configuration of this parameter is not supported now. The reason is that the underlying implementation only accepts two types of mask type, namely "no_mask" and "default_mask", and we may try to loosen this option after submitting the commit that users can pass in custom attention mask tensors.
Outputs:
- **attn_output** - Attention outputs of shape :math:`(N, L, E)`,
where :math:`L` is the target sequence length, :math:`N` is
the batch size, and :math:`E` is the embedding dimension ``embed_dim``.
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N * \text{num\_heads}, L, S)`.
Note: Now only None will be returned. The reason is that there is only cudnn implementation now, and we may try to loosen this option after submitting the commit that adds MHA proxy implementation.
Args:
query: Query embeddings of shape :math:`(N, L, E_q)`,
where :math:`N` is the batch size, :math:`L` is the target sequence length, and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against key-value pairs to produce the output. See "Attention Is All You Need" for more details.
key: Key embeddings of shape :math:`(N, S, E_k)`,
where :math:`N` is the batch size, :math:`S` is the source sequence length, and :math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details.
value: Value embeddings of shape :math:`(N, S, E_v)`,
where :math:`N` is the batch size, :math:`S` is the source sequence length, and :math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details.
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` to ignore for the purpose of
attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
need_weights: indicates whether to return the attention weight, which is the output result of softmax. Default: `False`
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: ``False`` (i.e. not average weights across heads)
is_causal: If specified, applies a causal mask as attention mask. Default: ``False``
Warning: ``is_causal`` provides a hint that ``attn_mask`` is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.
maybe_cudnn_style_mask: if specified, applies a cudnn style mask as attention mask. Default: ``False``
Note: In the cudnn style, the shape of the attn_mask is :math:`(2, L)`, and the shape of the key_padding_mask is :math:`(2, N)`.
Warning: like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask is a cudnn style mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility. In addition, if the ``_merge_masks`` function returns ``merge_type=cudnn_style_mask``, please ensure that other conditions are correct so that it can run the implementation of cudnn, otherwise an error will be reported.
Outputs:
- **attn_output** - Attention outputs of shape :math:`(N, L, E)`,
where :math:`L` is the target sequence length, :math:`N` is
the batch size, and :math:`E` is the embedding dimension ``embed_dim``.
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N * \text{num\_heads}, L, S)`.
"""
assert key_padding_mask is None, (
"key_padding_mask should be None, and configuration of this parameter is not supported now."
+ self.unsupport_reason
)
assert need_weights == False, (
"need_weights should be set to False, and configuration of this parameter is not supported now."
+ self.unsupport_reason
)
assert average_attn_weights == False, (
"average_attn_weights should be set to False, and configuration of this parameter is not supported now."
+ self.unsupport_reason
)
assert maybe_cudnn_style_mask == False, (
"maybe_cudnn_style_mask should be set to False, and configuration of this parameter is not supported now."
+ self.unsupport_reason
)
return multi_head_attention(
query,
......
......@@ -613,25 +613,25 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> _infer_output_attrs<MultiHeadAt
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<MultiHeadAttn>(
const OpDef& op, const SmallVector<TensorPtr>& inputs) {
using INPUT_TYPE = opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE;
using InputType = opr::MultiHeadAttn::Param::TensorCombinationType;
auto&& cn = inputs[0]->comp_node();
auto input_type = op.cast_final_safe<MultiHeadAttn>().tensor_combination_type;
std::tuple<SmallVector<LogicalTensorDesc>, bool> ret;
TensorLayout empty_layout;
if (input_type == INPUT_TYPE::NONE)
if (input_type == InputType::NONE)
ret = _infer_output_attrs<MultiHeadAttn>(
op,
{inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
inputs[3]->layout(), empty_layout, empty_layout, empty_layout},
cn);
else if (input_type == INPUT_TYPE::ONLY_MASK)
else if (input_type == InputType::ONLY_MASK)
ret = _infer_output_attrs<MultiHeadAttn>(
op,
{inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
inputs[3]->layout(), inputs[4]->layout(), empty_layout, empty_layout},
cn);
else if (input_type == INPUT_TYPE::ONLY_BIASKV)
else if (input_type == InputType::ONLY_BIASKV)
ret = _infer_output_attrs<MultiHeadAttn>(
op,
{inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
......@@ -666,7 +666,7 @@ template <>
SmallVector<TensorPtr> apply_on_physical_tensor<MultiHeadAttn>(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
using INPUT_TYPE = opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE;
using InputType = opr::MultiHeadAttn::Param::TensorCombinationType;
SmallVector<TensorPtr> outputs;
SmallVector<LogicalTensorDesc> desc =
infer_output_attrs<MultiHeadAttn>(def, inputs);
......@@ -705,7 +705,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor<MultiHeadAttn>(
TensorLayout empty_layout;
megdnn::TensorND empty_tensor;
if (input_type == INPUT_TYPE::ALL) {
if (input_type == InputType::ALL) {
wk_size = dnn_op->get_workspace_in_bytes(
inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
inputs[3]->layout(), inputs[4]->layout(), inputs[5]->layout(),
......@@ -725,7 +725,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor<MultiHeadAttn>(
outputs[1]->dev_tensor().as_megdnn(),
outputs[2]->dev_tensor().as_megdnn(),
outputs[3]->dev_tensor().as_megdnn(), dnn_wk);
} else if (input_type == INPUT_TYPE::ONLY_MASK) {
} else if (input_type == InputType::ONLY_MASK) {
wk_size = dnn_op->get_workspace_in_bytes(
inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
inputs[3]->layout(), inputs[4]->layout(), empty_layout, empty_layout,
......@@ -743,7 +743,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor<MultiHeadAttn>(
outputs[1]->dev_tensor().as_megdnn(),
outputs[2]->dev_tensor().as_megdnn(),
outputs[3]->dev_tensor().as_megdnn(), dnn_wk);
} else if (input_type == INPUT_TYPE::ONLY_BIASKV) {
} else if (input_type == InputType::ONLY_BIASKV) {
wk_size = dnn_op->get_workspace_in_bytes(
inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
inputs[3]->layout(), empty_layout, inputs[4]->layout(),
......@@ -801,13 +801,13 @@ template <>
SymbolVarArray apply_on_var_node<MultiHeadAttn, SymbolVarArray>(
const OpDef& def, const VarNodeArray& inputs) {
auto&& rng = def.cast_final_safe<MultiHeadAttn>();
using INPUT_TYPE = opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE;
using InputType = opr::MultiHeadAttn::Param::TensorCombinationType;
auto input_type = rng.tensor_combination_type;
if (input_type == INPUT_TYPE::ALL) {
if (input_type == InputType::ALL) {
return _RNGOprMaker<7>::make(inputs, rng);
} else if (input_type == INPUT_TYPE::ONLY_BIASKV) {
} else if (input_type == InputType::ONLY_BIASKV) {
return _RNGOprMaker<6>::make(inputs, rng);
} else if (input_type == INPUT_TYPE::ONLY_MASK) {
} else if (input_type == InputType::ONLY_MASK) {
return _RNGOprMaker<5>::make(inputs, rng);
} else {
return _RNGOprMaker<4>::make(inputs, rng);
......@@ -884,25 +884,25 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dro
template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<
MultiHeadAttn>(const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) {
using INPUT_TYPE = opr::MultiHeadAttn::Param::TENSOR_COMBINATION_TYPE;
using InputType = opr::MultiHeadAttn::Param::TensorCombinationType;
auto&& cn = inputs[0].comp_node;
auto input_type = op.cast_final_safe<MultiHeadAttn>().tensor_combination_type;
std::tuple<SmallVector<LogicalTensorDesc>, bool> ret;
TensorLayout empty_layout;
if (input_type == INPUT_TYPE::NONE)
if (input_type == InputType::NONE)
ret = _infer_output_attrs<MultiHeadAttn>(
op,
{inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout,
empty_layout, empty_layout, empty_layout},
cn);
else if (input_type == INPUT_TYPE::ONLY_MASK)
else if (input_type == InputType::ONLY_MASK)
ret = _infer_output_attrs<MultiHeadAttn>(
op,
{inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout,
inputs[4].layout, empty_layout, empty_layout},
cn);
else if (input_type == INPUT_TYPE::ONLY_BIASKV)
else if (input_type == InputType::ONLY_BIASKV)
ret = _infer_output_attrs<MultiHeadAttn>(
op,
{inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout,
......
......@@ -20,8 +20,8 @@
cb(::megdnn::param::CvtColor::Mode); \
cb(::megdnn::param::Elemwise::Mode); \
cb(::megdnn::param::ElemwiseMultiType::Mode); \
cb(::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE); \
cb(::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE); \
cb(::megdnn::param::MultiHeadAttn::AttnMaskType); \
cb(::megdnn::param::MultiHeadAttn::TensorCombinationType); \
cb(::megdnn::param::Padding::PaddingMode); \
cb(::megdnn::param::RNNCell::NonlineMode); \
cb(::megdnn::param::ROIAlignV0::Mode); \
......
0a8cd3cd50cadfaae0478ee70621618e ../../dnn/scripts/opr_param_defs.py
20aa8ae7e128c1e24564ce68389307cc ../../dnn/scripts/opr_param_defs.py
9e9636d66694dd7d5a7853247a5406f9 ../../src/core/include/megbrain/ir/ops.td
2c15c869c1731d1bc5f25f9b132f4f08 generated/opdef.h.inl
0dabeee4b8f81be4c1809906b99795a5 generated/opdef.cpp.inl
be20faf18eccbc56f535b012170ed90a generated/opdef.py.inl
af9ab62fe962d409bb65e66af5f44a79 generated/opdef.cpy.inl
d468302f2d4b113913b76b5a181aae56 generated/enum_macro.h
e4489c2e1ea2b680d61c352842e56929 generated/opdef.h.inl
fd27534146a1cfcc791e40b2bb532076 generated/opdef.cpp.inl
6754eaa59ef19178eba41e99e418790c generated/opdef.py.inl
df66a3089aa6c12e5b1d943cd3d20e80 generated/opdef.cpy.inl
911001ef0dd771024919f7a1a3a009db generated/enum_macro.h
......@@ -5288,16 +5288,16 @@ std::vector<std::pair<const char*, std::string>> MultiHeadAttn_props_impl(const
props_.emplace_back("sm_scaler", std::to_string(op_.sm_scaler));
props_.emplace_back("input_order", std::to_string(op_.input_order));
switch (op_.attn_mask_type){
case MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK:
case MultiHeadAttn::AttnMaskType::NO_MASK:
props_.emplace_back("attn_mask_type", "NO_MASK");
break;
case MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK:
case MultiHeadAttn::AttnMaskType::DEFAULT_MASK:
props_.emplace_back("attn_mask_type", "DEFAULT_MASK");
break;
case MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK:
case MultiHeadAttn::AttnMaskType::CUDNN_STYLE_MASK:
props_.emplace_back("attn_mask_type", "CUDNN_STYLE_MASK");
break;
case MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK:
case MultiHeadAttn::AttnMaskType::USER_DEFINED_MASK:
props_.emplace_back("attn_mask_type", "USER_DEFINED_MASK");
break;
default:
......@@ -5305,16 +5305,16 @@ std::vector<std::pair<const char*, std::string>> MultiHeadAttn_props_impl(const
break;
}
switch (op_.tensor_combination_type){
case MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE:
case MultiHeadAttn::TensorCombinationType::NONE:
props_.emplace_back("tensor_combination_type", "NONE");
break;
case MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK:
case MultiHeadAttn::TensorCombinationType::ONLY_MASK:
props_.emplace_back("tensor_combination_type", "ONLY_MASK");
break;
case MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV:
case MultiHeadAttn::TensorCombinationType::ONLY_BIASKV:
props_.emplace_back("tensor_combination_type", "ONLY_BIASKV");
break;
case MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL:
case MultiHeadAttn::TensorCombinationType::ALL:
props_.emplace_back("tensor_combination_type", "ALL");
break;
default:
......
......@@ -15043,39 +15043,39 @@ void _init_py_MeshIndexing(py::module m) {
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MeshIndexing::typeinfo(), &py_type).second);
}
template<> struct EnumTrait<MultiHeadAttn::ATTN_MASK_TYPE> {
static constexpr const char *name = "MultiHeadAttn.ATTN_MASK_TYPE";
static constexpr std::underlying_type_t<MultiHeadAttn::ATTN_MASK_TYPE> max = 4 - 1;
template<> struct EnumTrait<MultiHeadAttn::AttnMaskType> {
static constexpr const char *name = "MultiHeadAttn.AttnMaskType";
static constexpr std::underlying_type_t<MultiHeadAttn::AttnMaskType> max = 4 - 1;
};
template<> PyTypeObject* EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>::type = nullptr;
template<> PyTypeObject* EnumWrapper<MultiHeadAttn::AttnMaskType>::type = nullptr;
template<> const char*
EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>::members[] = {"NO_MASK", "DEFAULT_MASK", "CUDNN_STYLE_MASK", "USER_DEFINED_MASK"};
EnumWrapper<MultiHeadAttn::AttnMaskType>::members[] = {"NO_MASK", "DEFAULT_MASK", "CUDNN_STYLE_MASK", "USER_DEFINED_MASK"};
template<> std::unordered_map<std::string, MultiHeadAttn::ATTN_MASK_TYPE>
EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>::mem2value = {{normalize_enum("NO_MASK"), MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK}, {normalize_enum("DEFAULT_MASK"), MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK}, {normalize_enum("CUDNN_STYLE_MASK"), MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK}, {normalize_enum("USER_DEFINED_MASK"), MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK}};
template<> PyObject* EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>::pyobj_insts[4] = {nullptr};
template<> std::unordered_map<std::string, MultiHeadAttn::AttnMaskType>
EnumWrapper<MultiHeadAttn::AttnMaskType>::mem2value = {{normalize_enum("NO_MASK"), MultiHeadAttn::AttnMaskType::NO_MASK}, {normalize_enum("DEFAULT_MASK"), MultiHeadAttn::AttnMaskType::DEFAULT_MASK}, {normalize_enum("CUDNN_STYLE_MASK"), MultiHeadAttn::AttnMaskType::CUDNN_STYLE_MASK}, {normalize_enum("USER_DEFINED_MASK"), MultiHeadAttn::AttnMaskType::USER_DEFINED_MASK}};
template<> PyObject* EnumWrapper<MultiHeadAttn::AttnMaskType>::pyobj_insts[4] = {nullptr};
void _init_py_MultiHeadAttn_ATTN_MASK_TYPE(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>::type;
void _init_py_MultiHeadAttn_AttnMaskType(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<MultiHeadAttn::AttnMaskType>::type;
static PyMethodDef tp_methods[] = {
{const_cast<char*>("dump"), (PyCFunction)EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>::py_dump, METH_NOARGS, NULL},
{const_cast<char*>("dump"), (PyCFunction)EnumWrapper<MultiHeadAttn::AttnMaskType>::py_dump, METH_NOARGS, NULL},
{NULL} /* Sentinel */
};
static PyType_Slot slots[] = {
{Py_tp_repr, (void*)EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>::py_repr},
{Py_tp_richcompare, (void*)EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>::tp_richcompare},
{Py_tp_repr, (void*)EnumWrapper<MultiHeadAttn::AttnMaskType>::py_repr},
{Py_tp_richcompare, (void*)EnumWrapper<MultiHeadAttn::AttnMaskType>::tp_richcompare},
{Py_tp_methods, tp_methods},
{0, NULL}
};
static PyType_Spec spec = {
// name
"megengine.core._imperative_rt.ops.MultiHeadAttn.ATTN_MASK_TYPE",
"megengine.core._imperative_rt.ops.MultiHeadAttn.AttnMaskType",
// basicsize
sizeof(EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>),
sizeof(EnumWrapper<MultiHeadAttn::AttnMaskType>),
// itemsize
0,
// flags
......@@ -15089,7 +15089,7 @@ void _init_py_MultiHeadAttn_ATTN_MASK_TYPE(PyTypeObject& py_type) {
e_type->tp_setattro(
reinterpret_cast<PyObject*>(e_type),
py::cast("__name__").release().ptr(),
py::cast("ATTN_MASK_TYPE").release().ptr()) >= 0);
py::cast("AttnMaskType").release().ptr()) >= 0);
mgb_assert(
e_type->tp_setattro(
......@@ -15101,66 +15101,66 @@ void _init_py_MultiHeadAttn_ATTN_MASK_TYPE(PyTypeObject& py_type) {
e_type->tp_setattro(
reinterpret_cast<PyObject*>(e_type),
py::cast("__qualname__").release().ptr(),
py::cast("MultiHeadAttn.ATTN_MASK_TYPE").release().ptr()) >= 0);
py::cast("MultiHeadAttn.AttnMaskType").release().ptr()) >= 0);
{
PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>*>(inst)->value = MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK;
reinterpret_cast<EnumWrapper<MultiHeadAttn::AttnMaskType>*>(inst)->value = MultiHeadAttn::AttnMaskType::NO_MASK;
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "NO_MASK", inst) >= 0);
EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>::pyobj_insts[0] = inst;
EnumWrapper<MultiHeadAttn::AttnMaskType>::pyobj_insts[0] = inst;
}{
PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>*>(inst)->value = MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK;
reinterpret_cast<EnumWrapper<MultiHeadAttn::AttnMaskType>*>(inst)->value = MultiHeadAttn::AttnMaskType::DEFAULT_MASK;
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "DEFAULT_MASK", inst) >= 0);
EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>::pyobj_insts[1] = inst;
EnumWrapper<MultiHeadAttn::AttnMaskType>::pyobj_insts[1] = inst;
}{
PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>*>(inst)->value = MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK;
reinterpret_cast<EnumWrapper<MultiHeadAttn::AttnMaskType>*>(inst)->value = MultiHeadAttn::AttnMaskType::CUDNN_STYLE_MASK;
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "CUDNN_STYLE_MASK", inst) >= 0);
EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>::pyobj_insts[2] = inst;
EnumWrapper<MultiHeadAttn::AttnMaskType>::pyobj_insts[2] = inst;
}{
PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>*>(inst)->value = MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK;
reinterpret_cast<EnumWrapper<MultiHeadAttn::AttnMaskType>*>(inst)->value = MultiHeadAttn::AttnMaskType::USER_DEFINED_MASK;
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "USER_DEFINED_MASK", inst) >= 0);
EnumWrapper<MultiHeadAttn::ATTN_MASK_TYPE>::pyobj_insts[3] = inst;
EnumWrapper<MultiHeadAttn::AttnMaskType>::pyobj_insts[3] = inst;
}
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "ATTN_MASK_TYPE", reinterpret_cast<PyObject*>(e_type)) >= 0);
py_type.tp_dict, "AttnMaskType", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
template<> struct EnumTrait<MultiHeadAttn::TENSOR_COMBINATION_TYPE> {
static constexpr const char *name = "MultiHeadAttn.TENSOR_COMBINATION_TYPE";
static constexpr std::underlying_type_t<MultiHeadAttn::TENSOR_COMBINATION_TYPE> max = 4 - 1;
template<> struct EnumTrait<MultiHeadAttn::TensorCombinationType> {
static constexpr const char *name = "MultiHeadAttn.TensorCombinationType";
static constexpr std::underlying_type_t<MultiHeadAttn::TensorCombinationType> max = 4 - 1;
};
template<> PyTypeObject* EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>::type = nullptr;
template<> PyTypeObject* EnumWrapper<MultiHeadAttn::TensorCombinationType>::type = nullptr;
template<> const char*
EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>::members[] = {"NONE", "ONLY_MASK", "ONLY_BIASKV", "ALL"};
EnumWrapper<MultiHeadAttn::TensorCombinationType>::members[] = {"NONE", "ONLY_MASK", "ONLY_BIASKV", "ALL"};
template<> std::unordered_map<std::string, MultiHeadAttn::TENSOR_COMBINATION_TYPE>
EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>::mem2value = {{normalize_enum("NONE"), MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE}, {normalize_enum("ONLY_MASK"), MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK}, {normalize_enum("ONLY_BIASKV"), MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV}, {normalize_enum("ALL"), MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL}};
template<> PyObject* EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>::pyobj_insts[4] = {nullptr};
template<> std::unordered_map<std::string, MultiHeadAttn::TensorCombinationType>
EnumWrapper<MultiHeadAttn::TensorCombinationType>::mem2value = {{normalize_enum("NONE"), MultiHeadAttn::TensorCombinationType::NONE}, {normalize_enum("ONLY_MASK"), MultiHeadAttn::TensorCombinationType::ONLY_MASK}, {normalize_enum("ONLY_BIASKV"), MultiHeadAttn::TensorCombinationType::ONLY_BIASKV}, {normalize_enum("ALL"), MultiHeadAttn::TensorCombinationType::ALL}};
template<> PyObject* EnumWrapper<MultiHeadAttn::TensorCombinationType>::pyobj_insts[4] = {nullptr};
void _init_py_MultiHeadAttn_TENSOR_COMBINATION_TYPE(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>::type;
void _init_py_MultiHeadAttn_TensorCombinationType(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<MultiHeadAttn::TensorCombinationType>::type;
static PyMethodDef tp_methods[] = {
{const_cast<char*>("dump"), (PyCFunction)EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>::py_dump, METH_NOARGS, NULL},
{const_cast<char*>("dump"), (PyCFunction)EnumWrapper<MultiHeadAttn::TensorCombinationType>::py_dump, METH_NOARGS, NULL},
{NULL} /* Sentinel */
};
static PyType_Slot slots[] = {
{Py_tp_repr, (void*)EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>::py_repr},
{Py_tp_richcompare, (void*)EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>::tp_richcompare},
{Py_tp_repr, (void*)EnumWrapper<MultiHeadAttn::TensorCombinationType>::py_repr},
{Py_tp_richcompare, (void*)EnumWrapper<MultiHeadAttn::TensorCombinationType>::tp_richcompare},
{Py_tp_methods, tp_methods},
{0, NULL}
};
static PyType_Spec spec = {
// name
"megengine.core._imperative_rt.ops.MultiHeadAttn.TENSOR_COMBINATION_TYPE",
"megengine.core._imperative_rt.ops.MultiHeadAttn.TensorCombinationType",
// basicsize
sizeof(EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>),
sizeof(EnumWrapper<MultiHeadAttn::TensorCombinationType>),
// itemsize
0,
// flags
......@@ -15174,7 +15174,7 @@ void _init_py_MultiHeadAttn_TENSOR_COMBINATION_TYPE(PyTypeObject& py_type) {
e_type->tp_setattro(
reinterpret_cast<PyObject*>(e_type),
py::cast("__name__").release().ptr(),
py::cast("TENSOR_COMBINATION_TYPE").release().ptr()) >= 0);
py::cast("TensorCombinationType").release().ptr()) >= 0);
mgb_assert(
e_type->tp_setattro(
......@@ -15186,31 +15186,31 @@ void _init_py_MultiHeadAttn_TENSOR_COMBINATION_TYPE(PyTypeObject& py_type) {
e_type->tp_setattro(
reinterpret_cast<PyObject*>(e_type),
py::cast("__qualname__").release().ptr(),
py::cast("MultiHeadAttn.TENSOR_COMBINATION_TYPE").release().ptr()) >= 0);
py::cast("MultiHeadAttn.TensorCombinationType").release().ptr()) >= 0);
{
PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>*>(inst)->value = MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE;
reinterpret_cast<EnumWrapper<MultiHeadAttn::TensorCombinationType>*>(inst)->value = MultiHeadAttn::TensorCombinationType::NONE;
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "NONE", inst) >= 0);
EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>::pyobj_insts[0] = inst;
EnumWrapper<MultiHeadAttn::TensorCombinationType>::pyobj_insts[0] = inst;
}{
PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>*>(inst)->value = MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK;
reinterpret_cast<EnumWrapper<MultiHeadAttn::TensorCombinationType>*>(inst)->value = MultiHeadAttn::TensorCombinationType::ONLY_MASK;
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ONLY_MASK", inst) >= 0);
EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>::pyobj_insts[1] = inst;
EnumWrapper<MultiHeadAttn::TensorCombinationType>::pyobj_insts[1] = inst;
}{
PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>*>(inst)->value = MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV;
reinterpret_cast<EnumWrapper<MultiHeadAttn::TensorCombinationType>*>(inst)->value = MultiHeadAttn::TensorCombinationType::ONLY_BIASKV;
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ONLY_BIASKV", inst) >= 0);
EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>::pyobj_insts[2] = inst;
EnumWrapper<MultiHeadAttn::TensorCombinationType>::pyobj_insts[2] = inst;
}{
PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>*>(inst)->value = MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL;
reinterpret_cast<EnumWrapper<MultiHeadAttn::TensorCombinationType>*>(inst)->value = MultiHeadAttn::TensorCombinationType::ALL;
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ALL", inst) >= 0);
EnumWrapper<MultiHeadAttn::TENSOR_COMBINATION_TYPE>::pyobj_insts[3] = inst;
EnumWrapper<MultiHeadAttn::TensorCombinationType>::pyobj_insts[3] = inst;
}
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "TENSOR_COMBINATION_TYPE", reinterpret_cast<PyObject*>(e_type)) >= 0);
py_type.tp_dict, "TensorCombinationType", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
PyOpDefBegin(MultiHeadAttn) // {
......@@ -15708,7 +15708,7 @@ PyMethodDef PyOp(MultiHeadAttn)::py_init_methoddef = {
"__init__",
(PyCFunction)PyOp(MultiHeadAttn)::py_init_proxy,
METH_VARARGS | METH_KEYWORDS,
"__init__(self, num_heads: int = ..., embeding_size: int = ..., k_size: int = ..., v_size: int = ..., qproj_size: int = ..., kproj_size: int = ..., vproj_size: int = ..., oproj_size: int = ..., qbias: bool = ..., kbias: bool = ..., vbias: bool = ..., obias: bool = ..., sm_scaler: float = ..., input_order: int = ..., attn_mask_type: Union[str, ATTN_MASK_TYPE] = ..., tensor_combination_type: Union[str, TENSOR_COMBINATION_TYPE] = ..., add_zero_attn: bool = ..., need_weights: bool = ..., reslink: bool = ..., training: bool = ..., seed: int = ..., attn_prob: float = ..., out_prob: float = ..., handle: int = ...) -> None\n"
"__init__(self, num_heads: int = ..., embeding_size: int = ..., k_size: int = ..., v_size: int = ..., qproj_size: int = ..., kproj_size: int = ..., vproj_size: int = ..., oproj_size: int = ..., qbias: bool = ..., kbias: bool = ..., vbias: bool = ..., obias: bool = ..., sm_scaler: float = ..., input_order: int = ..., attn_mask_type: Union[str, AttnMaskType] = ..., tensor_combination_type: Union[str, TensorCombinationType] = ..., add_zero_attn: bool = ..., need_weights: bool = ..., reslink: bool = ..., training: bool = ..., seed: int = ..., attn_prob: float = ..., out_prob: float = ..., handle: int = ...) -> None\n"
};
void _init_py_MultiHeadAttn(py::module m) {
......@@ -15730,8 +15730,8 @@ void _init_py_MultiHeadAttn(py::module m) {
PyObject* descr = PyDescr_NewMethod(&PyOpType(MultiHeadAttn), &PyOp(MultiHeadAttn)::py_init_methoddef);
PyDict_SetItemString(py_type.tp_dict, "__init__", descr);
mgb_assert(PyType_Ready(&py_type) >= 0);
_init_py_MultiHeadAttn_ATTN_MASK_TYPE(py_type);
_init_py_MultiHeadAttn_TENSOR_COMBINATION_TYPE(py_type);
_init_py_MultiHeadAttn_AttnMaskType(py_type);
_init_py_MultiHeadAttn_TensorCombinationType(py_type);
PyType_Modified(&py_type);
m.add_object("MultiHeadAttn", reinterpret_cast<PyObject*>(&py_type));
......
......@@ -1398,8 +1398,8 @@ class MultiHeadAttn : public OpDefImplBase<MultiHeadAttn> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using ATTN_MASK_TYPE = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE;
using TENSOR_COMBINATION_TYPE = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE;
using AttnMaskType = ::megdnn::param::MultiHeadAttn::AttnMaskType;
using TensorCombinationType = ::megdnn::param::MultiHeadAttn::TensorCombinationType;
uint32_t num_heads = 1;
uint32_t embeding_size = 0;
uint32_t k_size = 0;
......@@ -1414,8 +1414,8 @@ public:
bool obias = false;
float sm_scaler = 1.f;
uint32_t input_order = 0;
ATTN_MASK_TYPE attn_mask_type = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK;
TENSOR_COMBINATION_TYPE tensor_combination_type = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE;
AttnMaskType attn_mask_type = ::megdnn::param::MultiHeadAttn::AttnMaskType::NO_MASK;
TensorCombinationType tensor_combination_type = ::megdnn::param::MultiHeadAttn::TensorCombinationType::NONE;
bool add_zero_attn = false;
bool need_weights = false;
bool reslink = false;
......@@ -1425,7 +1425,7 @@ public:
float out_prob = 0.f;
size_t handle;
MultiHeadAttn() = default;
MultiHeadAttn(uint32_t num_heads_, uint32_t embeding_size_, uint32_t k_size_, uint32_t v_size_, uint32_t qproj_size_, uint32_t kproj_size_, uint32_t vproj_size_, uint32_t oproj_size_, bool qbias_, bool kbias_, bool vbias_, bool obias_, float sm_scaler_, uint32_t input_order_, ATTN_MASK_TYPE attn_mask_type_, TENSOR_COMBINATION_TYPE tensor_combination_type_, bool add_zero_attn_, bool need_weights_, bool reslink_, bool training_, uint64_t seed_, float attn_prob_, float out_prob_, size_t handle_, std::string scope_ = {}): num_heads(num_heads_), embeding_size(embeding_size_), k_size(k_size_), v_size(v_size_), qproj_size(qproj_size_), kproj_size(kproj_size_), vproj_size(vproj_size_), oproj_size(oproj_size_), qbias(qbias_), kbias(kbias_), vbias(vbias_), obias(obias_), sm_scaler(sm_scaler_), input_order(input_order_), attn_mask_type(attn_mask_type_), tensor_combination_type(tensor_combination_type_), add_zero_attn(add_zero_attn_), need_weights(need_weights_), reslink(reslink_), training(training_), seed(seed_), attn_prob(attn_prob_), out_prob(out_prob_), handle(handle_) { set_scope(scope_); }
MultiHeadAttn(uint32_t num_heads_, uint32_t embeding_size_, uint32_t k_size_, uint32_t v_size_, uint32_t qproj_size_, uint32_t kproj_size_, uint32_t vproj_size_, uint32_t oproj_size_, bool qbias_, bool kbias_, bool vbias_, bool obias_, float sm_scaler_, uint32_t input_order_, AttnMaskType attn_mask_type_, TensorCombinationType tensor_combination_type_, bool add_zero_attn_, bool need_weights_, bool reslink_, bool training_, uint64_t seed_, float attn_prob_, float out_prob_, size_t handle_, std::string scope_ = {}): num_heads(num_heads_), embeding_size(embeding_size_), k_size(k_size_), v_size(v_size_), qproj_size(qproj_size_), kproj_size(kproj_size_), vproj_size(vproj_size_), oproj_size(oproj_size_), qbias(qbias_), kbias(kbias_), vbias(vbias_), obias(obias_), sm_scaler(sm_scaler_), input_order(input_order_), attn_mask_type(attn_mask_type_), tensor_combination_type(tensor_combination_type_), add_zero_attn(add_zero_attn_), need_weights(need_weights_), reslink(reslink_), training(training_), seed(seed_), attn_prob(attn_prob_), out_prob(out_prob_), handle(handle_) { set_scope(scope_); }
MultiHeadAttn(::megdnn::param::MultiHeadAttn packed_param_0, size_t handle_): num_heads(packed_param_0.num_heads), embeding_size(packed_param_0.embeding_size), k_size(packed_param_0.k_size), v_size(packed_param_0.v_size), qproj_size(packed_param_0.qproj_size), kproj_size(packed_param_0.kproj_size), vproj_size(packed_param_0.vproj_size), oproj_size(packed_param_0.oproj_size), qbias(packed_param_0.qbias), kbias(packed_param_0.kbias), vbias(packed_param_0.vbias), obias(packed_param_0.obias), sm_scaler(packed_param_0.sm_scaler), input_order(packed_param_0.input_order), attn_mask_type(packed_param_0.attn_mask_type), tensor_combination_type(packed_param_0.tensor_combination_type), add_zero_attn(packed_param_0.add_zero_attn), need_weights(packed_param_0.need_weights), reslink(packed_param_0.reslink), training(packed_param_0.training), seed(packed_param_0.seed), attn_prob(packed_param_0.attn_prob), out_prob(packed_param_0.out_prob), handle(handle_) {}
::megdnn::param::MultiHeadAttn param() const {
return {num_heads, embeding_size, k_size, v_size, qproj_size, kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias, sm_scaler, input_order, attn_mask_type, tensor_combination_type, add_zero_attn, need_weights, reslink, training, seed, attn_prob, out_prob};
......
......@@ -1479,38 +1479,38 @@ MeshIndexingInst
py::class_<MultiHeadAttn, std::shared_ptr<MultiHeadAttn>, OpDef> MultiHeadAttnInst(m, "MultiHeadAttn");
py::enum_<MultiHeadAttn::ATTN_MASK_TYPE>(MultiHeadAttnInst, "ATTN_MASK_TYPE")
.value("NO_MASK", MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK)
.value("DEFAULT_MASK", MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK)
.value("CUDNN_STYLE_MASK", MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK)
.value("USER_DEFINED_MASK", MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK)
py::enum_<MultiHeadAttn::AttnMaskType>(MultiHeadAttnInst, "AttnMaskType")
.value("NO_MASK", MultiHeadAttn::AttnMaskType::NO_MASK)
.value("DEFAULT_MASK", MultiHeadAttn::AttnMaskType::DEFAULT_MASK)
.value("CUDNN_STYLE_MASK", MultiHeadAttn::AttnMaskType::CUDNN_STYLE_MASK)
.value("USER_DEFINED_MASK", MultiHeadAttn::AttnMaskType::USER_DEFINED_MASK)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "NO_MASK") return MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK;
if (str == "DEFAULT_MASK") return MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK;
if (str == "CUDNN_STYLE_MASK") return MultiHeadAttn::ATTN_MASK_TYPE::CUDNN_STYLE_MASK;
if (str == "USER_DEFINED_MASK") return MultiHeadAttn::ATTN_MASK_TYPE::USER_DEFINED_MASK;
if (str == "NO_MASK") return MultiHeadAttn::AttnMaskType::NO_MASK;
if (str == "DEFAULT_MASK") return MultiHeadAttn::AttnMaskType::DEFAULT_MASK;
if (str == "CUDNN_STYLE_MASK") return MultiHeadAttn::AttnMaskType::CUDNN_STYLE_MASK;
if (str == "USER_DEFINED_MASK") return MultiHeadAttn::AttnMaskType::USER_DEFINED_MASK;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, MultiHeadAttn::ATTN_MASK_TYPE>();
py::implicitly_convertible<std::string, MultiHeadAttn::AttnMaskType>();
py::enum_<MultiHeadAttn::TENSOR_COMBINATION_TYPE>(MultiHeadAttnInst, "TENSOR_COMBINATION_TYPE")
.value("NONE", MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE)
.value("ONLY_MASK", MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK)
.value("ONLY_BIASKV", MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV)
.value("ALL", MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL)
py::enum_<MultiHeadAttn::TensorCombinationType>(MultiHeadAttnInst, "TensorCombinationType")
.value("NONE", MultiHeadAttn::TensorCombinationType::NONE)
.value("ONLY_MASK", MultiHeadAttn::TensorCombinationType::ONLY_MASK)
.value("ONLY_BIASKV", MultiHeadAttn::TensorCombinationType::ONLY_BIASKV)
.value("ALL", MultiHeadAttn::TensorCombinationType::ALL)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "NONE") return MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE;
if (str == "ONLY_MASK") return MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_MASK;
if (str == "ONLY_BIASKV") return MultiHeadAttn::TENSOR_COMBINATION_TYPE::ONLY_BIASKV;
if (str == "ALL") return MultiHeadAttn::TENSOR_COMBINATION_TYPE::ALL;
if (str == "NONE") return MultiHeadAttn::TensorCombinationType::NONE;
if (str == "ONLY_MASK") return MultiHeadAttn::TensorCombinationType::ONLY_MASK;
if (str == "ONLY_BIASKV") return MultiHeadAttn::TensorCombinationType::ONLY_BIASKV;
if (str == "ALL") return MultiHeadAttn::TensorCombinationType::ALL;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, MultiHeadAttn::TENSOR_COMBINATION_TYPE>();
py::implicitly_convertible<std::string, MultiHeadAttn::TensorCombinationType>();
MultiHeadAttnInst
.def(py::init<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, bool, bool, bool, bool, float, uint32_t, ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE, ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE, bool, bool, bool, bool, uint64_t, float, float, size_t, std::string>(), py::arg("num_heads") = 1, py::arg("embeding_size") = 0, py::arg("k_size") = 0, py::arg("v_size") = 0, py::arg("qproj_size") = 0, py::arg("kproj_size") = 0, py::arg("vproj_size") = 0, py::arg("oproj_size") = 0, py::arg("qbias") = false, py::arg("kbias") = false, py::arg("vbias") = false, py::arg("obias") = false, py::arg("sm_scaler") = 1.f, py::arg("input_order") = 0, py::arg("attn_mask_type") = ::megdnn::param::MultiHeadAttn::ATTN_MASK_TYPE::NO_MASK, py::arg("tensor_combination_type") = ::megdnn::param::MultiHeadAttn::TENSOR_COMBINATION_TYPE::NONE, py::arg("add_zero_attn") = false, py::arg("need_weights") = false, py::arg("reslink") = false, py::arg("training") = true, py::arg("seed") = 0, py::arg("attn_prob") = 0.f, py::arg("out_prob") = 0.f, py::arg("handle"), py::arg("scope") = {})
.def(py::init<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, bool, bool, bool, bool, float, uint32_t, ::megdnn::param::MultiHeadAttn::AttnMaskType, ::megdnn::param::MultiHeadAttn::TensorCombinationType, bool, bool, bool, bool, uint64_t, float, float, size_t, std::string>(), py::arg("num_heads") = 1, py::arg("embeding_size") = 0, py::arg("k_size") = 0, py::arg("v_size") = 0, py::arg("qproj_size") = 0, py::arg("kproj_size") = 0, py::arg("vproj_size") = 0, py::arg("oproj_size") = 0, py::arg("qbias") = false, py::arg("kbias") = false, py::arg("vbias") = false, py::arg("obias") = false, py::arg("sm_scaler") = 1.f, py::arg("input_order") = 0, py::arg("attn_mask_type") = ::megdnn::param::MultiHeadAttn::AttnMaskType::NO_MASK, py::arg("tensor_combination_type") = ::megdnn::param::MultiHeadAttn::TensorCombinationType::NONE, py::arg("add_zero_attn") = false, py::arg("need_weights") = false, py::arg("reslink") = false, py::arg("training") = true, py::arg("seed") = 0, py::arg("attn_prob") = 0.f, py::arg("out_prob") = 0.f, py::arg("handle"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("num_heads", &MultiHeadAttn::num_heads)
.def_readwrite("embeding_size", &MultiHeadAttn::embeding_size)
......
此差异已折叠。
......@@ -33,31 +33,31 @@ struct OprMaker<opr::DropoutForward, 1> {
template <>
struct OprMaker<opr::MultiHeadAttn, 0> {
using Param = opr::MultiHeadAttn::Param;
using INPUT_TYPE = Param::TENSOR_COMBINATION_TYPE;
using InputType = Param::TensorCombinationType;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (i.size() == 7) {
mgb_assert(INPUT_TYPE::ALL == param.tensor_combination_type);
mgb_assert(InputType::ALL == param.tensor_combination_type);
return opr::MultiHeadAttn::make(
i[0], i[1], i[2], i[3], i[4], i[5], i[6], param, config)[0]
.node()
->owner_opr();
} else if (i.size() == 6) {
mgb_assert(INPUT_TYPE::ONLY_BIASKV == param.tensor_combination_type);
mgb_assert(InputType::ONLY_BIASKV == param.tensor_combination_type);
return opr::MultiHeadAttn::make(
i[0], i[1], i[2], i[3], i[4], i[5], param, config)[0]
.node()
->owner_opr();
} else if (i.size() == 5) {
mgb_assert(INPUT_TYPE::ONLY_MASK == param.tensor_combination_type);
mgb_assert(InputType::ONLY_MASK == param.tensor_combination_type);
return opr::MultiHeadAttn::make(
i[0], i[1], i[2], i[3], i[4], param, config)[0]
.node()
->owner_opr();
} else {
mgb_assert(INPUT_TYPE::NONE == param.tensor_combination_type);
mgb_assert(InputType::NONE == param.tensor_combination_type);
return opr::MultiHeadAttn::make(i[0], i[1], i[2], i[3], param, config)[0]
.node()
->owner_opr();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册