diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h index 7d8c25d647c643f17ce05df34861b9c158e2c951..fa0be658e7c8ace2f1e6c8fdab8ed2971ac6c6cc 100644 --- a/dnn/include/megdnn/oprs/nn.h +++ b/dnn/include/megdnn/oprs/nn.h @@ -2574,6 +2574,73 @@ protected: size_t workspace_in_bytes); }; +class MultiHeadAttnBase : public OperatorBase { + DEF_OPR_IMPL_CTOR(MultiHeadAttnBase, OperatorBase); + DEF_OPR_PARAM(MultiHeadAttn); +}; + +class MultiHeadAttnForward : public MultiHeadAttnBase { + DEF_OPR_IMPL(MultiHeadAttnForward, MultiHeadAttnBase, 4, 2); + +public: + virtual void exec( + _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, + _megdnn_tensor_in wqkv, _megdnn_tensor_out out, + _megdnn_tensor_out reserveSpace, _megdnn_workspace workspace) = 0; + MGE_WIN_DECLSPEC_FUC void deduce_layout( + const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out, + TensorLayout& reserveSpace); + virtual size_t get_workspace_in_bytes( + const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, + const TensorLayout& out, const TensorLayout& reserveSpace) = 0; + virtual size_t get_reservespace_in_bytes( + const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, + const TensorLayout& out, const TensorLayout& reserveSpace) = 0; + +protected: + void check_exec( + const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, + const TensorLayout& out, const TensorLayout& reserveSpace, + size_t workspace_in_bytes); +}; +using MultiHeadAttn = MultiHeadAttnForward; + +class MultiHeadAttnBackward : public MultiHeadAttnBase { + DEF_OPR_IMPL(MultiHeadAttnBackward, MultiHeadAttnBase, 6, 4); + +public: + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, + _megdnn_tensor_in values, _megdnn_tensor_in wqkv, + _megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries, + _megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues, + _megdnn_tensor_out dweights, _megdnn_workspace workspace) = 0; + MGE_WIN_DECLSPEC_FUC void deduce_layout( + const TensorLayout& diff, const TensorLayout& queries, + const TensorLayout& keys, const TensorLayout& values, + const TensorLayout& wqkv, const TensorLayout& reserveSpace, + TensorLayout& dqueries, TensorLayout& dkeys, TensorLayout& dvalues, + TensorLayout& dweights); + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& queries, + const TensorLayout& keys, const TensorLayout& values, + const TensorLayout& wqkv, const TensorLayout& reserveSpace, + const TensorLayout& dqueries, const TensorLayout& dkeys, + const TensorLayout& dvalues, const TensorLayout& dweights) = 0; + +protected: + void check_exec( + const TensorLayout& diff, const TensorLayout& queries, + const TensorLayout& keys, const TensorLayout& values, + const TensorLayout& wqkv, const TensorLayout& reserveSpace, + const TensorLayout& dqueries, const TensorLayout& dkeys, + const TensorLayout& dvalues, const TensorLayout& dweights, + size_t workspace_in_bytes); +}; } // namespace megdnn #include "megdnn/internal/opr_header_epilogue.h" diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 78cbaf4d3b9fe14da224ef862e4a2a82a6ecf619..33fe4cbace400b96707866681e1b2fc184e42899 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -1330,3 +1330,20 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), add_fields('float32', Doc('p', 'the order of norm'), '2'). add_fields('int32', Doc('dim', 'which dim the norm performed along'), '-1'), ) + +(pdef('MultiHeadAttn') + .add_fields('uint32', Doc('num_heads', 'Number of parallel attention heads.'), '1') + .add_fields('float32', Doc('sm_scaler', 'Softmax smoothing (1.0 >= smScaler >= 0.0) or sharpening (smScaler > 1.0) coefficient.'), '1.f') + .add_fields('uint32', Doc('input_order', 'The sequence data layout, allows the user to select 3! = 6 different data layouts or permutations of BEAM, BATCH and TIME dimensions.'), '0') + .add_fields('bool', Doc('reslink', 'Whether to add input query to final output.'), 'false') + .add_fields('bool', Doc('training', 'Whether it is in training mode.'), 'true') + .add_fields('bool', Doc('bias', 'Whether to add linear bias.'), 'false') + .add_fields('bool', Doc('attn_mask', 'Whether to add attn_mask.'), 'false') + .add_fields('bool', Doc('enable_qproj', 'enable query weight projection.'), 'true') + .add_fields('bool', Doc('enable_kproj', 'enable key weight projection.'), 'true') + .add_fields('bool', Doc('enable_vproj', 'enable value weight projection.'), 'true') + .add_fields('bool', Doc('enable_oproj', 'enable output weight projection.'), 'true') + .add_fields('uint64', Doc('seed', 'Random number seed for drop'), '0') + .add_fields('float32', Doc('attn_prob', 'Dropout probability on attention, is applied directly to the softmax output'), '0.f') + .add_fields('float32', Doc('out_prob', 'Dropout probability on output, alters the multi-head attention output'), '0.f') + ) diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index 19124f780a1417f277d9f40ba78f7480e992e9ad..7cca0cbf56dbb374c475f7b647232447475d7025 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -221,7 +221,9 @@ private: cb(RegionRestrictedConvolutionBackwardFilter) \ cb(GroupNormForward) \ cb(GroupNormBackward) \ - cb(MaskedFill) + cb(MaskedFill) \ + cb(MultiHeadAttnForward)\ + cb(MultiHeadAttnBackward) // clang-format on /*! diff --git a/dnn/src/common/multi_head_attn.cpp b/dnn/src/common/multi_head_attn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a3219a67483c39fce374e12100972ac6c39d0c82 --- /dev/null +++ b/dnn/src/common/multi_head_attn.cpp @@ -0,0 +1,166 @@ +#include "megdnn/basic_types.h" +#include "megdnn/oprs.h" +#include "src/common/utils.cuh" +#include "unroll_macro.h" + +#include "src/common/utils.h" + +namespace megdnn { + +using Param = MultiHeadAttnBase::Param; + +void MultiHeadAttnForward::deduce_layout( + const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out, + TensorLayout& reserveSpace) { + megdnn_assert( + queries.ndim == 3, + "queries.ndim should be 3[batch, sequence, embeding], but got %zu", + queries.ndim); + size_t size = + get_reservespace_in_bytes(queries, keys, values, wqkv, out, reserveSpace); + out = TensorLayout( + {queries.shape[0], queries.shape[1], queries.shape[2]}, queries.dtype); + reserveSpace = TensorLayout({size}, queries.dtype); +} + +void MultiHeadAttnForward::check_exec( + const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out, + const TensorLayout& reserveSpace, size_t workspace_in_bytes) { + Param p = param(); + megdnn_assert_contiguous(queries); + megdnn_assert_contiguous(keys); + megdnn_assert_contiguous(values); + megdnn_assert_contiguous(wqkv); + megdnn_assert_contiguous(out); + if (p.training) + megdnn_assert_contiguous(reserveSpace); + auto required_workspace_in_bytes = + get_workspace_in_bytes(queries, keys, values, wqkv, out, reserveSpace); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); + + megdnn_assert( + queries.ndim == 3, + "queries.ndim should be 3[batch, sequence, embeding], but got %zu", + queries.ndim); + megdnn_assert( + keys.ndim == 3, + "keys.ndim should be 3[batch, sequence, embeding], but got %zu", keys.ndim); + megdnn_assert( + values.ndim == 3, + "values.ndim should be 3[batch, sequence, embeding], but got %zu", + values.ndim); + + auto errmsg = [&]() { + return megdnn_layout_msg(queries) + ", " + megdnn_layout_msg(keys) + ", " + + megdnn_layout_msg(values) + ", " + megdnn_layout_msg(wqkv) + ", " + + megdnn_layout_msg(out) + ", " + megdnn_layout_msg(reserveSpace); + }; + megdnn_assert(queries.shape[0] == out.shape[0], "%s", errmsg().c_str()); + megdnn_assert(keys.shape[0] == values.shape[0], "%s", errmsg().c_str()); + megdnn_assert(queries.shape[0] == keys.shape[0], "%s", errmsg().c_str()); + megdnn_assert(queries.shape[1] == out.shape[1], "%s", errmsg().c_str()); + megdnn_assert(keys.shape[1] == values.shape[1], "%s", errmsg().c_str()); + megdnn_assert( + queries.shape[2] == keys.shape[2] and keys.shape[2] == values.shape[2] and + queries.shape[2] == out.shape[2], + "%s", errmsg().c_str()); +} + +void MultiHeadAttnBackward::deduce_layout( + const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, + const TensorLayout& reserveSpace, TensorLayout& dqueries, TensorLayout& dkeys, + TensorLayout& dvalues, TensorLayout& dweights) { + MEGDNN_MARK_USED_VAR(diff); + MEGDNN_MARK_USED_VAR(reserveSpace); + dqueries = queries; + dkeys = keys; + dvalues = values; + dweights = wqkv; +} + +void MultiHeadAttnBackward::check_exec( + const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, + const TensorLayout& reserveSpace, const TensorLayout& dqueries, + const TensorLayout& dkeys, const TensorLayout& dvalues, + const TensorLayout& dweights, size_t workspace_in_bytes) { + Param p = param(); + megdnn_assert( + p.training, + "When calling MultiHeadAttn backward, param().training must be true, " + "but got false"); + megdnn_assert_contiguous(diff); + megdnn_assert_contiguous(queries); + megdnn_assert_contiguous(keys); + megdnn_assert_contiguous(values); + megdnn_assert_contiguous(wqkv); + megdnn_assert_contiguous(dqueries); + megdnn_assert_contiguous(dkeys); + megdnn_assert_contiguous(dvalues); + megdnn_assert_contiguous(dweights); + if (p.training) + megdnn_assert_contiguous(reserveSpace); + auto required_workspace_in_bytes = get_workspace_in_bytes( + diff, queries, keys, values, wqkv, reserveSpace, dqueries, dkeys, dvalues, + dweights); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); + megdnn_assert(reserveSpace.total_nr_elems() > 0); + + megdnn_assert( + queries.ndim == 3, + "queries.ndim should be 3[batch, sequence, embeding], but got %zu", + queries.ndim); + megdnn_assert( + keys.ndim == 3, + "keys.ndim should be 3[batch, sequence, embeding], but got %zu", keys.ndim); + megdnn_assert( + values.ndim == 3, + "values.ndim should be 3[batch, sequence, embeding], but got %zu", + values.ndim); + megdnn_assert( + diff.ndim == 3, + "diff.ndim should be 3[batch, sequence, embeding], but got %zu", diff.ndim); + + auto errmsg = [&]() { + return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(queries) + ", " + + megdnn_layout_msg(keys) + ", " + megdnn_layout_msg(values) + ", " + + megdnn_layout_msg(wqkv) + ", " + megdnn_layout_msg(reserveSpace) + ", " + + megdnn_layout_msg(dqueries) + ", " + megdnn_layout_msg(dkeys) + ", " + + megdnn_layout_msg(dvalues) + ", " + megdnn_layout_msg(dweights); + }; + + auto equal_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) -> bool { + if (!(lhs.ndim == rhs.ndim && lhs.dtype == rhs.dtype && + lhs.format == rhs.format)) + return false; + for (size_t i = 0; i < lhs.ndim; ++i) { + if (lhs.shape[i] != rhs.shape[i] || lhs.stride[i] != rhs.stride[i]) { + return false; + } + } + return true; + }; + + megdnn_assert(equal_layout(queries, diff), "%s", errmsg().c_str()); + megdnn_assert(equal_layout(queries, dqueries), "%s", errmsg().c_str()); + megdnn_assert(equal_layout(keys, dkeys), "%s", errmsg().c_str()); + megdnn_assert(equal_layout(values, dvalues), "%s", errmsg().c_str()); + megdnn_assert(equal_layout(wqkv, dweights), "%s", errmsg().c_str()); + + megdnn_assert(queries.shape[0] == diff.shape[0], "%s", errmsg().c_str()); + megdnn_assert(keys.shape[0] == values.shape[0], "%s", errmsg().c_str()); + megdnn_assert(queries.shape[0] == keys.shape[0], "%s", errmsg().c_str()); + megdnn_assert(queries.shape[1] == diff.shape[1], "%s", errmsg().c_str()); + megdnn_assert(keys.shape[1] == values.shape[1], "%s", errmsg().c_str()); + megdnn_assert( + queries.shape[2] == keys.shape[2] and keys.shape[2] == values.shape[2] and + queries.shape[2] == diff.shape[2], + "%s", errmsg().c_str()); +} + +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index 541bd67a3d0bad15867ce577cdfed39c806c6733..65591fb9194b9653c0efe0a7e720ea69364d3cf3 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -1,5 +1,6 @@ #pragma once #include "megdnn/oprs.h" +#include "megdnn/oprs/nn.h" #include <cstddef> @@ -147,6 +148,8 @@ DEF(RegionRestrictedConvolutionBackwardFilter, 5, true, false); DEF(GroupNormForward, 6, true, true); DEF(GroupNormBackward, 8, true, true); DEF(MaskedFill, 3, false, true); +DEF(MultiHeadAttnForward, 6, true, true); +DEF(MultiHeadAttnBackward, 10, true, true); } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/cudnn_wrapper.cpp b/dnn/src/cuda/cudnn_wrapper.cpp index b8ff39882b28b4c4b4aa19cc37e01c1a7b936067..a5b93b429c39cc525619fecc1e97dc4405132f1f 100644 --- a/dnn/src/cuda/cudnn_wrapper.cpp +++ b/dnn/src/cuda/cudnn_wrapper.cpp @@ -3,12 +3,10 @@ #include "src/common/utils.h" #include "src/cuda/utils.h" -namespace { - -using namespace megdnn; +namespace megdnn { +namespace cuda { -cudnnDataType_t to_cudnn_dtype( - DType type, const param::Convolution::Format format = {}) { +cudnnDataType_t to_cudnn_dtype(DType type, const param::Convolution::Format format) { switch (type.enumv()) { case DTypeEnum::Float32: return CUDNN_DATA_FLOAT; @@ -66,8 +64,9 @@ cudnnTensorFormat_t to_cudnn_format(const param::Convolution::Format format) { megdnn_assert_internal(0); } } +} // namespace cuda -} // namespace +} // namespace megdnn namespace megdnn { namespace cuda { @@ -558,6 +557,71 @@ const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr> CudnnAl #undef V #undef V1 +#if CUDNN_VERSION >= 8004 +SeqTensorDesc::~SeqTensorDesc() { + cudnn_check(cudnnDestroySeqDataDescriptor(desc)); +} +SeqTensorDesc::SeqTensorDesc() { + cudnnCreateSeqDataDescriptor(&desc); +} + +SeqTensorDesc::SeqTensorDesc( + const TensorLayout& layout, const size_t batchSize, const size_t seqLen, + const size_t elemSize, const size_t input_order, int* seqArray) { + cudnnCreateSeqDataDescriptor(&desc); + set(layout, batchSize, seqLen, elemSize, input_order, seqArray); +} + +void SeqTensorDesc::set( + const TensorLayout& layout, const size_t batchSize, const size_t seqLen, + const size_t elemSize, const size_t input_order, int* seqArray) { + switch (input_order) { + case 0: // dimAxes = [Batch, Beam, Time] + dimAxes[0] = CUDNN_SEQDATA_BATCH_DIM; + dimAxes[1] = CUDNN_SEQDATA_BEAM_DIM; + dimAxes[2] = CUDNN_SEQDATA_TIME_DIM; + break; + case 1: // dimAxes = [Beam, Batch, Time] + dimAxes[0] = CUDNN_SEQDATA_BEAM_DIM; + dimAxes[1] = CUDNN_SEQDATA_BATCH_DIM; + dimAxes[2] = CUDNN_SEQDATA_TIME_DIM; + break; + case 2: // dimAxes = [Batch, Time, Beam] + dimAxes[0] = CUDNN_SEQDATA_BATCH_DIM; + dimAxes[1] = CUDNN_SEQDATA_TIME_DIM; + dimAxes[2] = CUDNN_SEQDATA_BEAM_DIM; + break; + case 3: // dimAxes = [Beam, Time, Batch] + dimAxes[0] = CUDNN_SEQDATA_BEAM_DIM; + dimAxes[1] = CUDNN_SEQDATA_TIME_DIM; + dimAxes[2] = CUDNN_SEQDATA_BATCH_DIM; + break; + case 4: // dimAxes = [Time, Batch, Beam] + dimAxes[0] = CUDNN_SEQDATA_TIME_DIM; + dimAxes[1] = CUDNN_SEQDATA_BATCH_DIM; + dimAxes[2] = CUDNN_SEQDATA_BEAM_DIM; + break; + case 5: // dimAxes = [Time, Beam, Batch] + dimAxes[0] = CUDNN_SEQDATA_TIME_DIM; + dimAxes[1] = CUDNN_SEQDATA_BEAM_DIM; + dimAxes[2] = CUDNN_SEQDATA_BATCH_DIM; + break; + default: + megdnn_throw(ssprintf("ERROR: wrong attention layout %zu", input_order)); + } + dimAxes[3] = CUDNN_SEQDATA_VECT_DIM; + + dim[CUDNN_SEQDATA_BEAM_DIM] = 1; + dim[CUDNN_SEQDATA_BATCH_DIM] = batchSize; + dim[CUDNN_SEQDATA_TIME_DIM] = seqLen; + dim[CUDNN_SEQDATA_VECT_DIM] = elemSize; + + cudnnDataType_t cudnn_dtype = to_cudnn_dtype(layout.dtype); + cudnn_check(cudnnSetSeqDataDescriptor( + desc, cudnn_dtype, CUDNN_SEQDATA_DIM_COUNT, dim, dimAxes, batchSize, + seqArray, NULL)); +} +#endif } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/cudnn_wrapper.h b/dnn/src/cuda/cudnn_wrapper.h index c2f3992091114b98626213480dc16f1c0f56c8b4..a1962c836237f89b544dbb52f4451b1360cfdba9 100644 --- a/dnn/src/cuda/cudnn_wrapper.h +++ b/dnn/src/cuda/cudnn_wrapper.h @@ -2,12 +2,18 @@ #include <unordered_map> #include "megdnn/basic_types.h" +#include "megdnn/opr_param_defs.h" #include "megdnn/oprs/nn.h" #include "src/cuda/cudnn_with_check.h" namespace megdnn { namespace cuda { +cudnnDataType_t to_cudnn_dtype( + DType type, const param::Convolution::Format format = {}); + +cudnnTensorFormat_t to_cudnn_format(const param::Convolution::Format format); + /*! * \brief get compute_type of convolution operations */ @@ -85,6 +91,24 @@ public: cudnnConvolutionDescriptor_t desc; }; +#if CUDNN_VERSION >= 8004 +class SeqTensorDesc { +public: + int dim[CUDNN_SEQDATA_DIM_COUNT]; + cudnnSeqDataAxis_t dimAxes[CUDNN_SEQDATA_DIM_COUNT]; + cudnnSeqDataDescriptor_t desc; + + ~SeqTensorDesc(); + SeqTensorDesc(); + SeqTensorDesc( + const TensorLayout& layout, const size_t batchSize, const size_t seqLen, + const size_t elemSize, const size_t dataLayout, int* seqArray); + void set( + const TensorLayout& layout, const size_t batchSize, const size_t seqLen, + const size_t elemSize, const size_t dataLayout, int* seqArray); +}; +#endif + class CudnnAlgoPack { public: //! algorithm attr diff --git a/dnn/src/cuda/cudnn_wrapper_v8.h b/dnn/src/cuda/cudnn_wrapper_v8.h index 575f43d0f00ea8e6e5a3ce5f64c6303c9b9a085a..09ec3b89b7033b1a28f05178468cb8db477d427f 100644 --- a/dnn/src/cuda/cudnn_wrapper_v8.h +++ b/dnn/src/cuda/cudnn_wrapper_v8.h @@ -55,7 +55,6 @@ void run_conv_bias_act_with_plan( const cudnnHandle_t& handle, const cudnn_frontend::ExecutionPlan& plan, const TensorND& x, const TensorND& y, const TensorND& w, const TensorND& b, const TensorND& z, const Workspace& workspace); - } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/dropout/opr_impl.h b/dnn/src/cuda/dropout/opr_impl.h index 5d995eeb6103c475bb9f29c7315db146667627a2..50b0f3f7cc0c0df47e3ea5030b06e0e9094e7fd7 100644 --- a/dnn/src/cuda/dropout/opr_impl.h +++ b/dnn/src/cuda/dropout/opr_impl.h @@ -61,6 +61,9 @@ public: bool initialized() { return status != nullptr; } friend class DropoutForwardImpl; friend class DropoutBackwardImpl; +#if CUDNN_VERSION >= 8004 + friend class MultiHeadAttnStatus; +#endif }; // similar to RNG operator, dropout operator also have status diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index 01fbfba5612365310c235dd1bedd6dfbf84832e0..e1f397ce7bb13c1744f72e97b229e3324437f698 100644 --- a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -50,6 +50,7 @@ #include "src/cuda/matrix_mul/opr_impl.h" #include "src/cuda/max_tensor_diff/opr_impl.h" #include "src/cuda/mesh_indexing/opr_impl.h" +#include "src/cuda/multi_head_attn/opr_impl.h" #include "src/cuda/norm/opr_impl.h" #include "src/cuda/padding/opr_impl.h" #include "src/cuda/param_pack/opr_impl.h" @@ -230,6 +231,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(NormForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionBackwardData); MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionBackwardFilter); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(MultiHeadAttnForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(MultiHeadAttnBackward); template <typename Opr> std::unique_ptr<Opr> HandleImpl::create_operator() { diff --git a/dnn/src/cuda/multi_head_attn/helper.cpp b/dnn/src/cuda/multi_head_attn/helper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ab685629242deb8f3682214d5afe778fb135d07a --- /dev/null +++ b/dnn/src/cuda/multi_head_attn/helper.cpp @@ -0,0 +1,181 @@ +#include "src/cuda/multi_head_attn/helper.h" +#if CUDNN_VERSION >= 8004 + +namespace megdnn { +namespace cuda { + +AuxiliaryArray::~AuxiliaryArray() { + if (loWinIdx) + free(loWinIdx); + if (hiWinIdx) + free(hiWinIdx); + if (seqQArray) + free(seqQArray); + if (seqKArray) + free(seqKArray); + if (devSeqQArray) + cuda_check(cudaFree(devSeqQArray)); + if (devSeqKArray) + cuda_check(cudaFree(devSeqKArray)); +} + +bool AuxiliaryArray::is_initialized( + const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK, + bool _attnMask) { + if (_batchSize != batchSize or _seqLenQ != seqLenQ or _seqLenK != seqLenK or + _attnMask != attnMask or !seqQArray or !seqKArray or !devSeqQArray or + !devSeqKArray or !loWinIdx or !hiWinIdx) + return false; + return true; +} + +void AuxiliaryArray::set( + const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK, + bool _attnMask) { + if (_batchSize == batchSize && _seqLenQ == seqLenQ && _seqLenK == seqLenK && + _attnMask == attnMask) + return; + else { + if (loWinIdx) + free(loWinIdx); + if (hiWinIdx) + free(hiWinIdx); + if (seqQArray) + free(seqQArray); + if (seqKArray) + free(seqKArray); + if (devSeqQArray) + cuda_check(cudaFree(devSeqQArray)); + if (devSeqKArray) + cuda_check(cudaFree(devSeqKArray)); + }; + + seqLenQ = _seqLenQ; + seqLenK = _seqLenK; + batchSize = _batchSize; + attnMask = _attnMask; + size_t seqQArraySize = 1 * batchSize; + size_t seqKArraySize = batchSize; + seqQArray = (int*)calloc(seqQArraySize, sizeof(int)); + seqKArray = (int*)calloc(seqKArraySize, sizeof(int)); + for (size_t i = 0; i < seqQArraySize; ++i) + seqQArray[i] = seqLenQ; + for (size_t i = 0; i < seqKArraySize; ++i) + seqKArray[i] = seqLenK; + + cuda_check(cudaMalloc((void**)&devSeqQArray, seqQArraySize * sizeof(int))); + cuda_check(cudaMalloc((void**)&devSeqKArray, seqKArraySize * sizeof(int))); + + cuda_check(cudaMemcpy( + devSeqQArray, seqQArray, seqQArraySize * sizeof(int), + cudaMemcpyHostToDevice)); + cuda_check(cudaMemcpy( + devSeqKArray, seqKArray, seqKArraySize * sizeof(int), + cudaMemcpyHostToDevice)); + + loWinIdx = (int*)calloc(seqLenQ, sizeof(int)); + hiWinIdx = (int*)calloc(seqLenQ, sizeof(int)); + for (size_t i = 0; i < seqLenQ; ++i) { + loWinIdx[i] = 0; + if (attnMask) + hiWinIdx[i] = i + 1; + else + hiWinIdx[i] = seqLenK; + } +} + +void MultiHeadAttnStatus::set( + cudnnHandle_t handle, const Param& p, const TensorLayout& q, + const TensorLayout& k, const TensorLayout& v) { + float attn_prob = p.training ? p.attn_prob : 0.f; + float out_prob = p.training ? p.out_prob : 0.f; + if (!attn_dropout_status.initialized()) + attn_dropout_status.set(handle, p.seed, attn_prob); + if (!out_dropout_status.initialized()) + out_dropout_status.set(handle, p.seed, out_prob); + + if (attn_dropout_status.drop_prob != attn_prob) { + attn_dropout_status.drop_prob = attn_prob; + attn_dropout_status.restore_desc(handle); + } + if (out_dropout_status.drop_prob != out_prob) { + out_dropout_status.drop_prob = out_prob; + out_dropout_status.restore_desc(handle); + } + batchSize = q.shape[0]; + seqLenQ = q.shape[1]; + seqLenK = k.shape[1]; + qSize = q.shape[2]; + kSize = k.shape[2]; + vSize = v.shape[2]; + numHeads = p.num_heads; + qProjSize = p.enable_qproj ? qSize / numHeads : 0; + kProjSize = p.enable_kproj ? kSize / numHeads : 0; + vProjSize = p.enable_vproj ? vSize / numHeads : 0; + oProjSize = p.enable_oproj ? qSize : 0; + attnMask = p.attn_mask; + cudnnDataType_t cudnn_dtype = to_cudnn_dtype(q.dtype); + auto flag = CUDNN_ATTN_QUERYMAP_ONE_TO_ONE; + if (p.bias) + flag = flag | CUDNN_ATTN_ENABLE_PROJ_BIASES; +#if CUDNN_VERSION < 8600 + // TODO: CUDNN_VERSION < 8600 and out dropout > 0.0, we need to go to the proxy cuda + // implementation. + cudnn_check(cudnnSetAttnDescriptor( + attn_desc, flag, numHeads, p.sm_scaler, cudnn_dtype, cudnn_dtype, + CUDNN_DEFAULT_MATH, attn_dropout_status.desc.desc, NULL, qSize, kSize, + vSize, qProjSize, kProjSize, vProjSize, oProjSize, seqLenQ, seqLenK, + batchSize, 1)); +#else + cudnn_check(cudnnSetAttnDescriptor( + attn_desc, flag, numHeads, p.sm_scaler, cudnn_dtype, cudnn_dtype, + CUDNN_DEFAULT_MATH, attn_dropout_status.desc.desc, + out_dropout_status.desc.desc, qSize, kSize, vSize, qProjSize, kProjSize, + vProjSize, oProjSize, seqLenQ, seqLenK, batchSize, 1)); +#endif + + auxArray.set(batchSize, seqLenQ, seqLenK, p.attn_mask); + + if (p.training) + cudnnGetMultiHeadAttnBuffers( + handle, attn_desc, &sizeWeights, &sizeWkspace, &sizeReserve); + else { + cudnnGetMultiHeadAttnBuffers( + handle, attn_desc, &sizeWeights, &sizeWkspace, NULL); + sizeReserve = 0; + } +} + +bool MultiHeadAttnStatus::is_initialized( + const Param& p, const TensorLayout& q, const TensorLayout& k, + const TensorLayout& v) { + float attn_prob = p.training ? p.attn_prob : 0.f; + float out_prob = p.training ? p.out_prob : 0.f; + if (!attn_dropout_status.initialized() or !out_dropout_status.initialized() or + attn_dropout_status.drop_prob != attn_prob or + out_dropout_status.drop_prob != out_prob) + return false; + if (q.shape[0] != batchSize or q.shape[1] != seqLenQ or k.shape[1] != seqLenK or + q.shape[2] != qSize or k.shape[2] != kSize or v.shape[2] != vSize or + attnMask != p.attn_mask or numHeads != p.num_heads) { + return false; + } + if ((p.enable_qproj && (qProjSize == 0 or qProjSize != qSize / p.num_heads)) or + (p.enable_kproj && (kProjSize == 0 or kProjSize != kSize / p.num_heads)) or + (p.enable_vproj && (vProjSize == 0 or vProjSize != vSize / p.num_heads)) or + (p.enable_oproj && (oProjSize == 0 or oProjSize != q.shape[2]))) + return false; + if ((!p.enable_qproj && qProjSize != 0) or (!p.enable_kproj && kProjSize != 0) or + (!p.enable_vproj && vProjSize != 0) or (!p.enable_oproj && oProjSize != 0)) + return false; + if (!auxArray.is_initialized(batchSize, seqLenQ, seqLenK, attnMask)) + return false; + if (p.training and sizeReserve == 0) + return false; + return true; +} + +} // namespace cuda +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/multi_head_attn/helper.h b/dnn/src/cuda/multi_head_attn/helper.h new file mode 100644 index 0000000000000000000000000000000000000000..bdb37005c6c82519b349a3c277a702f9feedc851 --- /dev/null +++ b/dnn/src/cuda/multi_head_attn/helper.h @@ -0,0 +1,80 @@ +#pragma once +#include "src/cuda/cudnn_wrapper.h" +#if CUDNN_VERSION >= 8004 +#include "megdnn/basic_types.h" +#include "megdnn/oprs/nn.h" +#include "src/common/algo_chooser.h" +#include "src/common/utils.h" +#include "src/cuda/dropout/opr_impl.h" +#include "src/cuda/handle.h" + +namespace megdnn { +namespace cuda { + +struct AuxiliaryArray { +public: + int* seqQArray = nullptr; + int* seqKArray = nullptr; + int* devSeqQArray = nullptr; + int* devSeqKArray = nullptr; + int* loWinIdx = nullptr; + int* hiWinIdx = nullptr; + size_t seqLenQ = 0; + size_t seqLenK = 0; + size_t batchSize = 0; + bool attnMask = 0; + ~AuxiliaryArray(); + void set( + const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK, + bool _attnMask); + bool is_initialized( + const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK, + bool _attnMask); +}; + +using Param = megdnn::MultiHeadAttn::Param; + +class MultiHeadAttnStatus { + DropoutStatus attn_dropout_status; + DropoutStatus out_dropout_status; + + cudnnAttnDescriptor_t attn_desc; + + AuxiliaryArray auxArray; + + size_t numHeads = 0; + size_t batchSize = 0; + size_t seqLenQ = 0; + size_t seqLenK = 0; + size_t qSize = 0; + size_t kSize = 0; + size_t vSize = 0; + size_t qProjSize = 0; + size_t kProjSize = 0; + size_t vProjSize = 0; + size_t oProjSize = 0; + bool attnMask = 0; + + size_t sizeWeights = 0; + size_t sizeWkspace = 0; + size_t sizeReserve = 0; + +public: + MultiHeadAttnStatus() { cudnn_check(cudnnCreateAttnDescriptor(&attn_desc)); } + ~MultiHeadAttnStatus() { cudnn_check(cudnnDestroyAttnDescriptor(attn_desc)); } + +private: + void set( + cudnnHandle_t handle, const Param& p, const TensorLayout& q, + const TensorLayout& k, const TensorLayout& v); + bool is_initialized( + const Param& p, const TensorLayout& q, const TensorLayout& k, + const TensorLayout& v); + friend class MultiHeadAttnBase; + friend class MultiHeadAttnForwardImpl; + friend class MultiHeadAttnBackwardImpl; +}; +} // namespace cuda +} // namespace megdnn +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/multi_head_attn/opr_impl.cpp b/dnn/src/cuda/multi_head_attn/opr_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f4125cdd18db020a58fb5ca90006befc10a26e31 --- /dev/null +++ b/dnn/src/cuda/multi_head_attn/opr_impl.cpp @@ -0,0 +1,241 @@ +#include "src/cuda/multi_head_attn/opr_impl.h" +#include "src/common/utils.cuh" +#include "src/cuda/utils.cuh" +#include "src/cuda/utils.h" + +namespace megdnn { +namespace cuda { + +void MultiHeadAttnForwardImpl::deduce_layout( + const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out, + TensorLayout& reserveSpace) { +#if CUDNN_VERSION < 8004 + // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. + MEGDNN_MARK_USED_VAR(queries); + MEGDNN_MARK_USED_VAR(keys); + MEGDNN_MARK_USED_VAR(values); + MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(out); + MEGDNN_MARK_USED_VAR(reserveSpace); + return; +#else + MEGDNN_MARK_USED_VAR(keys); + MEGDNN_MARK_USED_VAR(wqkv); + megdnn_assert( + queries.ndim == 3, + "queries.ndim should be 3[batch, sequence, embeding], but got %zu", + queries.ndim); + + if (!desc_status.is_initialized(param(), queries, keys, values)) { + desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values); + + out = TensorLayout( + TensorShape{queries.shape[0], queries.shape[1], queries.shape[2]}, + queries.dtype); + reserveSpace = + TensorLayout(TensorShape{desc_status.sizeReserve}, queries.dtype); + } +#endif +} + +size_t MultiHeadAttnForwardImpl::get_workspace_in_bytes( + const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out, + const TensorLayout& reserveSpace) { +#if CUDNN_VERSION < 8004 + // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. + MEGDNN_MARK_USED_VAR(queries); + MEGDNN_MARK_USED_VAR(keys); + MEGDNN_MARK_USED_VAR(values); + MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(out); + MEGDNN_MARK_USED_VAR(reserveSpace); + return 0; +#else + MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(out); + MEGDNN_MARK_USED_VAR(reserveSpace); + + if (!desc_status.is_initialized(param(), queries, keys, values)) + desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values); + + return desc_status.sizeWkspace; +#endif +} + +size_t MultiHeadAttnForwardImpl::get_reservespace_in_bytes( + const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out, + const TensorLayout& reserveSpace) { +#if CUDNN_VERSION < 8004 + // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. + MEGDNN_MARK_USED_VAR(queries); + MEGDNN_MARK_USED_VAR(keys); + MEGDNN_MARK_USED_VAR(values); + MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(out); + MEGDNN_MARK_USED_VAR(reserveSpace); + return 0; +#else + MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(out); + MEGDNN_MARK_USED_VAR(reserveSpace); + if (!desc_status.is_initialized(param(), queries, keys, values)) + desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values); + return desc_status.sizeReserve; +#endif +} +void MultiHeadAttnForwardImpl::exec( + _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, + _megdnn_tensor_in wqkv, _megdnn_tensor_out out, _megdnn_tensor_out reserveSpace, + _megdnn_workspace workspace) { +#if CUDNN_VERSION < 8004 + // TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation. + MEGDNN_MARK_USED_VAR(queries); + MEGDNN_MARK_USED_VAR(keys); + MEGDNN_MARK_USED_VAR(values); + MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(out); + MEGDNN_MARK_USED_VAR(reserveSpace); + MEGDNN_MARK_USED_VAR(workspace); + megdnn_throw( + "The cudnn version is lower than 8.0.4. Please upgrade the cudnn version."); +#else + check_exec( + queries.layout, keys.layout, values.layout, wqkv.layout, out.layout, + reserveSpace.layout, workspace.size); + auto p = param(); + + if (!desc_status.is_initialized(p, queries.layout, keys.layout, values.layout)) + desc_status.set( + cudnn_handle(this->handle()), p, queries.layout, keys.layout, + values.layout); + + SeqTensorDesc q{queries.layout, desc_status.batchSize, + desc_status.seqLenQ, desc_status.qSize, + p.input_order, desc_status.auxArray.seqQArray}; + SeqTensorDesc o{out.layout, desc_status.batchSize, + desc_status.seqLenQ, desc_status.oProjSize, + p.input_order, desc_status.auxArray.seqQArray}; + SeqTensorDesc k{keys.layout, desc_status.batchSize, + desc_status.seqLenK, desc_status.kSize, + p.input_order, desc_status.auxArray.seqKArray}; + SeqTensorDesc v{values.layout, desc_status.batchSize, + desc_status.seqLenK, desc_status.vSize, + p.input_order, desc_status.auxArray.seqKArray}; + + cudnn_check(cudnnMultiHeadAttnForward( + cudnn_handle(this->handle()), desc_status.attn_desc, -1, + desc_status.auxArray.loWinIdx, desc_status.auxArray.hiWinIdx, + desc_status.auxArray.devSeqQArray, desc_status.auxArray.devSeqKArray, + q.desc, queries.raw_ptr(), p.reslink ? queries.raw_ptr() : NULL, k.desc, + keys.raw_ptr(), v.desc, values.raw_ptr(), o.desc, out.raw_ptr(), + desc_status.sizeWeights, + desc_status.sizeWeights > 0 ? wqkv.raw_ptr() : NULL, + desc_status.sizeWkspace, workspace.raw_ptr, + p.training ? desc_status.sizeReserve : 0, + p.training ? reserveSpace.raw_ptr() : NULL)); +#endif +} + +void MultiHeadAttnBackwardImpl::exec( + _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, + _megdnn_tensor_in values, _megdnn_tensor_in wqkv, + _megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries, + _megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues, + _megdnn_tensor_out dweights, _megdnn_workspace workspace) { +#if CUDNN_VERSION < 8004 + // TODO: CUDNN_VERSION < 8004 and param().bias = true, we need to go to the proxy + // cuda implementation. + MEGDNN_MARK_USED_VAR(diff); + MEGDNN_MARK_USED_VAR(queries); + MEGDNN_MARK_USED_VAR(keys); + MEGDNN_MARK_USED_VAR(values); + MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(reserveSpace); + MEGDNN_MARK_USED_VAR(dqueries); + MEGDNN_MARK_USED_VAR(dkeys); + MEGDNN_MARK_USED_VAR(dvalues); + MEGDNN_MARK_USED_VAR(dweights); + megdnn_throw( + "The cudnn version is lower than 8.0.4. Please upgrade the cudnn version."); +#else +#if CUDNN_VERSION < 8600 + megdnn_assert( + !param().bias, + "If the cudnn version is lower than 8.6.0, param().bias must be false, " + "but got true, because there is an error in the " + "dbias result during the backward calculation."); +#endif + + check_exec( + diff.layout, queries.layout, keys.layout, values.layout, wqkv.layout, + reserveSpace.layout, dqueries.layout, dkeys.layout, dvalues.layout, + dweights.layout, workspace.size); + auto p = param(); + + if (!desc_status.is_initialized(p, queries.layout, keys.layout, values.layout)) + desc_status.set( + cudnn_handle(this->handle()), p, queries.layout, keys.layout, + values.layout); + + SeqTensorDesc q{queries.layout, desc_status.batchSize, + desc_status.seqLenQ, desc_status.qSize, + p.input_order, desc_status.auxArray.seqQArray}; + SeqTensorDesc d{diff.layout, desc_status.batchSize, + desc_status.seqLenQ, desc_status.oProjSize, + p.input_order, desc_status.auxArray.seqQArray}; + SeqTensorDesc k{keys.layout, desc_status.batchSize, + desc_status.seqLenK, desc_status.kSize, + p.input_order, desc_status.auxArray.seqKArray}; + SeqTensorDesc v{values.layout, desc_status.batchSize, + desc_status.seqLenK, desc_status.vSize, + p.input_order, desc_status.auxArray.seqKArray}; + + cudnn_check(cudnnMultiHeadAttnBackwardData( + cudnn_handle(this->handle()), desc_status.attn_desc, + desc_status.auxArray.loWinIdx, desc_status.auxArray.hiWinIdx, + desc_status.auxArray.devSeqQArray, desc_status.auxArray.devSeqKArray, + d.desc, diff.raw_ptr(), q.desc, dqueries.raw_ptr(), queries.raw_ptr(), + k.desc, dkeys.raw_ptr(), keys.raw_ptr(), v.desc, dvalues.raw_ptr(), + values.raw_ptr(), desc_status.sizeWeights, + desc_status.sizeWeights > 0 ? wqkv.raw_ptr() : NULL, + desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve, + reserveSpace.raw_ptr())); + + cuda_check(cudaMemset(dweights.raw_ptr(), 0, desc_status.sizeWeights)); +#if CUDNN_VERSION < 8600 + cuda_check(cudaDeviceSynchronize()); +#endif + cudnn_check(cudnnMultiHeadAttnBackwardWeights( + cudnn_handle(this->handle()), desc_status.attn_desc, CUDNN_WGRAD_MODE_ADD, + q.desc, queries.raw_ptr(), k.desc, keys.raw_ptr(), v.desc, values.raw_ptr(), + d.desc, diff.raw_ptr(), desc_status.sizeWeights, + desc_status.sizeWeights > 0 ? wqkv.raw_ptr() : NULL, + desc_status.sizeWeights > 0 ? dweights.raw_ptr() : NULL, + desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve, + reserveSpace.raw_ptr())); +#endif +} +size_t MultiHeadAttnBackwardImpl::get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, + const TensorLayout& reserveSpace, const TensorLayout& dqueries, + const TensorLayout& dkeys, const TensorLayout& dvalues, + const TensorLayout& dweights) { + MEGDNN_MARK_USED_VAR(diff); + MEGDNN_MARK_USED_VAR(queries); + MEGDNN_MARK_USED_VAR(keys); + MEGDNN_MARK_USED_VAR(values); + MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(reserveSpace); + MEGDNN_MARK_USED_VAR(dqueries); + MEGDNN_MARK_USED_VAR(dkeys); + MEGDNN_MARK_USED_VAR(dvalues); + MEGDNN_MARK_USED_VAR(dweights); + return 0; +} +} // namespace cuda +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/multi_head_attn/opr_impl.h b/dnn/src/cuda/multi_head_attn/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..4596bd37175c50469e031701ef799cf64c871e84 --- /dev/null +++ b/dnn/src/cuda/multi_head_attn/opr_impl.h @@ -0,0 +1,59 @@ +#pragma once +#include "megdnn/handle.h" +#include "megdnn/oprs.h" +#include "src/common/reduce_helper.h" +#include "src/cuda/cudnn_wrapper.h" +#include "src/cuda/handle.h" +#include "src/cuda/multi_head_attn/helper.h" +#include "src/cuda/utils.h" + +namespace megdnn { +namespace cuda { + +class MultiHeadAttnForwardImpl final : public MultiHeadAttnForward { +public: + using MultiHeadAttnForward::MultiHeadAttnForward; +#if CUDNN_VERSION >= 8004 + MultiHeadAttnStatus desc_status; +#endif + + void exec( + _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, + _megdnn_tensor_in wqkv, _megdnn_tensor_out out, + _megdnn_tensor_out reserveSpace, _megdnn_workspace workspace) override; + void deduce_layout( + const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out, + TensorLayout& reserveSpace); + size_t get_reservespace_in_bytes( + const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, + const TensorLayout& out, const TensorLayout& reserveSpace) override; + size_t get_workspace_in_bytes( + const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, + const TensorLayout& out, const TensorLayout& reserveSpace) override; +}; + +class MultiHeadAttnBackwardImpl final : public MultiHeadAttnBackward { +public: + using MultiHeadAttnBackward::MultiHeadAttnBackward; +#if CUDNN_VERSION >= 8004 + MultiHeadAttnStatus desc_status; +#endif + void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, + _megdnn_tensor_in values, _megdnn_tensor_in wqkv, + _megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries, + _megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues, + _megdnn_tensor_out dweights, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& queries, + const TensorLayout& keys, const TensorLayout& values, + const TensorLayout& wqkv, const TensorLayout& reserveSpace, + const TensorLayout& dqueries, const TensorLayout& dkeys, + const TensorLayout& dvalues, const TensorLayout& dweights) override; +}; +} // namespace cuda +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index bc1db909c00094eb3c8457a7237d33ff512fae65..dc741f937669be0a9b88465d2f6a1660dcae4fa4 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -54,6 +54,7 @@ #include "src/naive/matrix_mul/opr_impl.h" #include "src/naive/max_tensor_diff/opr_impl.h" #include "src/naive/mesh_indexing/opr_impl.h" +#include "src/naive/multi_head_attn/opr_impl.h" #include "src/naive/norm/opr_impl.h" #include "src/naive/padding/opr_impl.h" #include "src/naive/param_pack/opr_impl.h" diff --git a/dnn/src/naive/multi_head_attn/opr_impl.cpp b/dnn/src/naive/multi_head_attn/opr_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..773810601f459c873f86e25f7db6add8c4af6d51 --- /dev/null +++ b/dnn/src/naive/multi_head_attn/opr_impl.cpp @@ -0,0 +1,56 @@ +#include "src/naive/multi_head_attn/opr_impl.h" +#include "megdnn/oprs/linalg.h" +#include "src/common/utils.cuh" + +namespace megdnn { +namespace naive { + +using Param = MultiHeadAttnBase::Param; + +size_t MultiHeadAttnForwardImpl::get_workspace_in_bytes( + const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out, + const TensorLayout& reserveSpace) { + MEGDNN_MARK_USED_VAR(queries); + MEGDNN_MARK_USED_VAR(keys); + MEGDNN_MARK_USED_VAR(values); + MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(out); + MEGDNN_MARK_USED_VAR(reserveSpace); + megdnn_throw("unsupported naive multiheadattn forward\n"); +} + +void MultiHeadAttnForwardImpl::exec( + _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, + _megdnn_tensor_in wqkv, _megdnn_tensor_out out, _megdnn_tensor_out reserveSpace, + _megdnn_workspace workspace) { + MEGDNN_MARK_USED_VAR(queries); + MEGDNN_MARK_USED_VAR(keys); + MEGDNN_MARK_USED_VAR(values); + MEGDNN_MARK_USED_VAR(wqkv); + MEGDNN_MARK_USED_VAR(out); + MEGDNN_MARK_USED_VAR(reserveSpace); + check_exec( + queries.layout, keys.layout, values.layout, wqkv.layout, out.layout, + reserveSpace.layout, workspace.size); + + megdnn_throw("unsupported naive multiheadattn forward\n"); +} + +void MultiHeadAttnBackwardImpl::exec( + _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, + _megdnn_tensor_in values, _megdnn_tensor_in wqkv, + _megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries, + _megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues, + _megdnn_tensor_out dweights, _megdnn_workspace workspace) { + check_exec( + diff.layout, queries.layout, keys.layout, values.layout, wqkv.layout, + reserveSpace.layout, dqueries.layout, dkeys.layout, dvalues.layout, + dweights.layout, workspace.size); + + megdnn_throw("unsupported naive multiheadattn backward\n"); +} + +} // namespace naive +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/multi_head_attn/opr_impl.h b/dnn/src/naive/multi_head_attn/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..5fb1cba9120b545b1167757045cd55e37942e16a --- /dev/null +++ b/dnn/src/naive/multi_head_attn/opr_impl.h @@ -0,0 +1,55 @@ +#pragma once +#include <memory> +#include "megdnn/oprs.h" +#include "megdnn/oprs/cv.h" +#include "megdnn/oprs/general.h" +#include "megdnn/oprs/linalg.h" +#include "megdnn/oprs/nn.h" + +namespace megdnn { +namespace naive { + +class MultiHeadAttnForwardImpl final : public MultiHeadAttnForward { +public: + using MultiHeadAttnForward::MultiHeadAttnForward; + void exec( + _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values, + _megdnn_tensor_in wqkv, _megdnn_tensor_out out, + _megdnn_tensor_out reserveSpace, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& queries, const TensorLayout& keys, + const TensorLayout& values, const TensorLayout& wqkv, + const TensorLayout& out, const TensorLayout& reserveSpace) override; + size_t get_reservespace_in_bytes( + const TensorLayout& /*queries*/, const TensorLayout& /*keys*/, + const TensorLayout& /*values*/, const TensorLayout& /*wqkv*/, + const TensorLayout& /*out*/, + const TensorLayout& /*reserveSpace*/) override { + return 0; + } +}; + +class MultiHeadAttnBackwardImpl final : public MultiHeadAttnBackward { +public: + using MultiHeadAttnBackward::MultiHeadAttnBackward; + void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys, + _megdnn_tensor_in values, _megdnn_tensor_in wqkv, + _megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries, + _megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues, + _megdnn_tensor_out dweights, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& /*diff*/, const TensorLayout& /* queries*/, + const TensorLayout& /*keyes*/, const TensorLayout& /* values*/, + const TensorLayout& /*wqkv*/, const TensorLayout& /* reserveSpace*/, + const TensorLayout& /*dqueries*/, const TensorLayout& /* dkeyes*/, + const TensorLayout& /*dvalues*/, + const TensorLayout& /* dweights*/) override { + return 0; + } +}; + +} // namespace naive +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 22a780f8e103d21bd5b201fd72310185f3037804..4bed60525e75c0a61cc314b7432641b1f3ab3ee5 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -35,7 +35,7 @@ from ..core.tensor.utils import ( subgraph, subgraph_fn, ) -from ..device import get_default_device +from ..device import get_cudnn_version, get_default_device, is_cuda_available from ..distributed import WORLD, is_distributed from ..jit import exclude_from_trace from ..logger import get_logger @@ -104,6 +104,7 @@ __all__ = [ "warp_perspective", "pixel_shuffle", "region_restricted_conv", + "multi_head_attention", ] @@ -1053,7 +1054,7 @@ def instance_norm( r"""Applies instance normalization to the input. Refer to :class:`~.InstanceNorm` for more information. - + Args: inp: input tensor. affine: whether to use learnable affine parameters (weight, bias) @@ -1083,7 +1084,7 @@ def group_norm( r"""Applies group normalization to the input. Refer to :class:`~.GroupNorm` for more information. - + Args: inp: input tensor. num_groups: number of groups to separate the channels into @@ -2052,7 +2053,82 @@ def region_restricted_conv( return output -from .quantized import conv_bias_activation # isort:skip +def multi_head_attention( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim: int, + num_heads: int, + attn_drop: float, + out_drop: float, + io_weight_bias: Optional[Tensor], + bias: bool = False, + reslink: bool = False, + training: bool = True, + attn_mask: bool = False, + enable_qproj: bool = True, + enable_kproj: bool = True, + enable_vproj: bool = True, + enable_oproj: bool = True, +): + r"""Allows the model to jointly attend to information + from different representation subspaces. + See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_. + + .. math:: + \text{MultiHeadAttn}\big(q,K,V, W_Q, W_V, W_O\big) = \sum^{nHeads-1}_{i=0}W_{O,i}h_i + + where :math:`h_i=W_{V,i}V \text{Softmax}\Big( \text{smScaler} \cdot K^TW^T_{K,i}W_{Q,i}q \Big),\text{for }i\text{ = 0 ... nHeads-1}`. + + See :class:`~.module.MultiHeadAttn` for more details. + + Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported, and if the cudnn version >=8.6.0, the calculation results are completely correct; If the cudnn version >=8.0.4 but <8.6.0, if there is a bias, only the dbias result calculated from the backward is incorrect. If there is no bias, the forward and backward calculations are correct; If the cudnn version is less than 8.0.4, this operator is not supported. + + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + attn_drop: probability of an element to be zeroed, used in attention matrix. + out_drop: probability of an element to be zeroed, used in final output. + io_weight_bias: input/output projection weight/bias all in one, used for cudnn api. + bias: used to indicate a bias in io_weight_bias, used for cudnn api. + reslink: add input query to final output. + training: will apply dropout if is ``True``. + attn_mask: used to indicate whether to add a mask to the attention matrix. + By default, the upper right triangle of the mask is -inf, and the diagonal and lower left triangle are all 0. + Default: `True` + enable_qproj: enable query weight projection. Default: ``True``. + enable_kproj: enable key weight projection. Default: ``True``. + enable_vproj: enable value weight projection. Default: ``True``. + enable_oproj: enable output weight projection. Default: ``True``. + """ + + head_dim = embed_dim // num_heads + smScaler = head_dim ** -0.5 + + op = builtin.MultiHeadAttn( + num_heads=num_heads, + sm_scaler=smScaler, + attn_prob=attn_drop, + out_prob=out_drop, + reslink=reslink, + training=training, + input_order=0, + seed=_get_global_rng_seed(), + bias=bias, + attn_mask=attn_mask, + enable_qproj=enable_qproj, + enable_kproj=enable_kproj, + enable_vproj=enable_vproj, + enable_oproj=enable_oproj, + ) + + out, reserveSpace = apply(op, query, key, value, io_weight_bias) + return out + + from .loss import * # isort:skip from .metric import * # isort:skip from .vision import * # isort:skip +from .quantized import conv_bias_activation # isort:skip diff --git a/imperative/python/megengine/module/__init__.py b/imperative/python/megengine/module/__init__.py index 985f4275e98cbca8b789b82ffe8ebe7468d5317c..89e1cbd0201ae285a65cddcd70ec6ae9aa4feafa 100644 --- a/imperative/python/megengine/module/__init__.py +++ b/imperative/python/megengine/module/__init__.py @@ -27,6 +27,7 @@ from .identity import Identity from .linear import Linear from .lrn import LocalResponseNorm from .module import Module +from .multiheadattn import MultiHeadAttention from .normalization import GeneralNorm, GroupNorm, InstanceNorm, LayerNorm from .padding import Pad from .pixel_shuffle import PixelShuffle diff --git a/imperative/python/megengine/module/activation.py b/imperative/python/megengine/module/activation.py index 5f1c7d092daca74cbe20db0393e73b1bb9d94d1e..9de1b4b44a1419eed05205a5e8336edd3a1ba8a3 100644 --- a/imperative/python/megengine/module/activation.py +++ b/imperative/python/megengine/module/activation.py @@ -3,6 +3,7 @@ import numpy as np from ..functional import gelu, leaky_relu, prelu, relu, sigmoid, silu, softmax from ..tensor import Parameter +from .init import ones_, zeros_ from .module import Module diff --git a/imperative/python/megengine/module/multiheadattn.py b/imperative/python/megengine/module/multiheadattn.py new file mode 100644 index 0000000000000000000000000000000000000000..ad80b7d62e1310578d7fd32a58aa032c1729a5cb --- /dev/null +++ b/imperative/python/megengine/module/multiheadattn.py @@ -0,0 +1,159 @@ +from typing import Optional + +import numpy as np + +import megengine as mge +import megengine.functional as F +from megengine import Parameter + +from ..device import get_cudnn_version, is_cuda_available +from ..functional.nn import multi_head_attention +from ..tensor import Tensor +from .init import ones_, zeros_ +from .module import Module + + +class MultiHeadAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces. + See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_. + + .. math:: + \text{MultiHeadAttn}\big(q,K,V, W_Q, W_V, W_O\big) = \sum^{nHeads-1}_{i=0}W_{O,i}h_i + + where :math:`h_i=W_{V,i}V \text{Softmax}\Big( \text{smScaler} \cdot K^TW^T_{K,i}W_{Q,i}q \Big),\text{for }i\text{ = 0 ... nHeads-1}`. + + Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported, and if the cudnn version >=8.6.0, the calculation results are completely correct; If the cudnn version >=8.0.4 but <8.6.0, if there is a bias, only the dbias result calculated from the backward is incorrect. If there is no bias, the forward and backward calculations are correct; If the cudnn version is less than 8.0.4, this operator is not supported. + + Args: + embed_dim: Total dimension of the model. + num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split + across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). + dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). + bias: If specified, adds bias to input / output projection layers. Default: ``True``. + kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). + vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). + enable_qproj: enable query weight projection. Default: ``True``. + enable_kproj: enable key weight projection. Default: ``True``. + enable_vproj: enable value weight projection. Default: ``True``. + enable_oproj: enable output weight projection. Default: ``True``. + + Examples:: + >>> import numpy as np + >>> batch_size, seq_len, embed_dim, num_heads = 2, 4, 4, 2 + >>> x = Tensor(np.arange(batch_size * seq_len * embed_dim).astype(np.float32).reshape(batch_size, seq_len, embed_dim)) + >>> multihead_attn = M.MultiHeadAttention(embed_dim, num_heads) + >>> if is_cuda_available() and get_cudnn_version() >= 8004: + ... out = multihead_attn(x, x, x) + ... out.numpy().shape + ... else: + ... print(np.zeros((2,4,4)).shape) + (2, 4, 4) + """ + + def __init__( + self, + embed_dim, + num_heads, + attn_dropout=0.0, + out_dropout=0.0, + kdim=None, + vdim=None, + bias=True, + enable_qproj=True, + enable_kproj=True, + enable_vproj=True, + enable_oproj=True, + **kwargs + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.attn_dropout = attn_dropout + self.out_dropout = out_dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + assert ( + self._qkv_same_embed_dim + ), "it does not support the case where q, k, and v are different." + self.bias = bias + + self.enable_qproj = enable_qproj + self.enable_kproj = enable_kproj + self.enable_vproj = enable_vproj + self.enable_oproj = enable_oproj + self.nproj = enable_qproj + enable_kproj + enable_vproj + enable_oproj + + if self.bias: + io_weight = np.ones((embed_dim, self.nproj * embed_dim)) + io_bias = np.zeros((1, self.nproj * embed_dim)) + self.io_weight_bias = Parameter( + np.concatenate((io_weight, io_bias), axis=0), dtype="float32" + ) + else: + self.io_weight_bias = Parameter( + np.ones((self.nproj * embed_dim, embed_dim), dtype="float32") + ) + self.reset_parameters() + + def reset_parameters(self): + self.attn_dropout = 0.0 + self.out_dropout = 0.0 + if self.bias: + io_weight = np.ones((self.embed_dim, self.nproj * self.embed_dim)) + io_bias = np.zeros((1, self.nproj * self.embed_dim)) + self.io_weight_bias._reset(np.concatenate((io_weight, io_bias), axis=0)) + else: + ones_(self.io_weight_bias) + + def forward( + self, query, key, value, attn_mask: bool = True, + ): + r""" + Args: + query: Query embeddings of shape :math:`(N, L, E_q)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, + and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against + key-value pairs to produce the output. See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(N, S, E_k)`, where :math:`N` is the batch size, :math:`S` is the source sequence length, and + :math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(N, S, E_v)`, where :math:`N` is the batch size, :math:`S` is the source sequence length, and + :math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape + :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + + + Outputs: + - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, + where :math:`L` is the target sequence length, :math:`N` is + the batch size, and :math:`E` is the embedding dimension ``embed_dim``. + """ + + return multi_head_attention( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.attn_dropout, + self.out_dropout, + self.io_weight_bias, + self.bias, + training=self.training, + attn_mask=attn_mask, + enable_qproj=self.enable_qproj, + enable_kproj=self.enable_kproj, + enable_vproj=self.enable_vproj, + enable_oproj=self.enable_oproj, + ) + + def _module_info_string(self) -> str: + s = "embed_dim={embed_dim}, num_heads={num_heads}, dropout={dropout}, bias={bias}, kdim={kdim}, vdim={vdim}" + return s.format(**self.__dict__) diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp index 50421e9b193d23ddeb4472af6b7002c6d391d2d6..19e012956c39392b483002b65d624df536ba61ed 100644 --- a/imperative/src/impl/ops/rng.cpp +++ b/imperative/src/impl/ops/rng.cpp @@ -285,6 +285,25 @@ struct OpMeth<Dropout> { } }; +template <> +struct OpMeth<MultiHeadAttn> { + using DnnOp = megdnn::MultiHeadAttn; + using Param = DnnOp::Param; + using OpNode = mgb::opr::MultiHeadAttn; + static Param make_param(const MultiHeadAttn& opdef) { + auto handle_seed = RNGDnnOpManager::get_seed(opdef.handle); + mgb_assert( + handle_seed == opdef.seed, + "inconsistent multiheadattn seed: dropout op: %lu handle: %lu", + handle_seed, opdef.seed); + return {opdef.num_heads, opdef.sm_scaler, opdef.input_order, + opdef.reslink, opdef.training, opdef.bias, + opdef.attn_mask, opdef.enable_qproj, opdef.enable_kproj, + opdef.enable_vproj, opdef.enable_oproj, handle_seed, + opdef.attn_prob, opdef.out_prob}; + } +}; + template <bool> struct _InferLayout; @@ -401,6 +420,14 @@ _INST_RNG_MAKER(2) #undef _FOR_EACH_OUT #undef _FOR_EACH_IN +#define _FOR_EACH_IN(subfix) \ + inputs[0] subfix, inputs[1] subfix, inputs[2] subfix, inputs[3] subfix, +#define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix +_INST_RNG_INVOLKER(4, 2) +_INST_RNG_MAKER(4) +#undef _FOR_EACH_OUT +#undef _FOR_EACH_IN + #undef _INST_RNG_INVOLKER #undef _INST_RNG_MAKER @@ -506,6 +533,39 @@ SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>( return dests; } +template <> +SmallVector<LogicalTensorDesc> infer_output_attrs<MultiHeadAttn>( + const OpDef& op, const SmallVector<TensorPtr>& inputs) { + SmallVector<LogicalTensorDesc> dests(2); + auto&& cn = inputs[0]->comp_node(); + + dests[0].comp_node = cn; + dests[0].layout = TensorLayout(inputs[0]->layout()); + dests[0].layout.dtype = inputs[0]->layout().dtype; + + auto get_reservespace_in_bytes = [&]() -> size_t { + // retrieve dnn_op from glob cache + auto&& rng = op.cast_final_safe<MultiHeadAttn>(); + auto handle = rng.handle; + if (!handle) { + handle = RNGDnnOpManager::get_default_handle(cn); + } + auto dnn_op_thread_safe = + RNGDnnOpManager::inst().get_dnn_op<megdnn::MultiHeadAttn>( + handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn); + auto dnn_op = std::get<1>(dnn_op_thread_safe); + dnn_op->param() = OpMeth<MultiHeadAttn>::make_param(rng); + + return dnn_op->get_reservespace_in_bytes( + inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(), + inputs[3]->layout(), {}, {}); + }; + dests[1].comp_node = cn; + dests[1].layout = + TensorLayout(TensorShape({get_reservespace_in_bytes()}), dtype::Byte()); + return dests; +} + template <typename Op> SmallVector<TensorPtr> apply_on_physical_tensor( const OpDef& def, const SmallVector<TensorPtr>& inputs, @@ -600,6 +660,44 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dro return {dests, success}; } +template <> +std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible< + MultiHeadAttn>(const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) { + bool success = inputs[0].layout.ndim != 0; + + SmallVector<LogicalTensorDesc> dests(2); + auto cn = inputs[0].comp_node; + dests[0].comp_node = cn; + dests[0].layout = TensorLayout(inputs[0].layout); + dests[0].layout.dtype = inputs[0].layout.dtype; + + auto get_reservespace_in_bytes = [&]() -> size_t { + auto&& rng = op.cast_final_safe<MultiHeadAttn>(); + auto handle = rng.handle; + if (!handle) { + handle = RNGDnnOpManager::get_default_handle(cn); + } + auto dnn_op_thread_safe = + RNGDnnOpManager::inst().get_dnn_op<megdnn::MultiHeadAttn>( + handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn); + auto dnn_op = std::get<1>(dnn_op_thread_safe); + dnn_op->param() = OpMeth<MultiHeadAttn>::make_param(rng); + + return dnn_op->get_reservespace_in_bytes( + inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout, + {}, {}); + }; + dests[1].comp_node = cn; + if (success) { + dests[1].layout = + TensorLayout(TensorShape({get_reservespace_in_bytes()}), dtype::Byte()); + } else { + dests[1].layout = TensorLayout(dtype::Byte()); + } + + return {dests, success}; +} + template <typename Op> SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( const OpDef& def, const SmallVector<TensorPtr>& inputs) { @@ -647,6 +745,7 @@ REG_RNG_OP(PoissonRNG, SymbolVar) REG_RNG_OP(BetaRNG, SymbolVar) REG_RNG_OP(ShuffleRNG, SymbolVarArray) REG_RNG_OP(Dropout, SymbolVarArray) +REG_RNG_OP(MultiHeadAttn, SymbolVarArray) #undef REG_RNG_OP } // namespace mgb::imperative::rng diff --git a/imperative/tablegen/generated/hash.txt b/imperative/tablegen/generated/hash.txt index 21afdaef6c39df0c0c0f2039bc66e7133a85182f..ab7a867652a7546e3a5a7787eb1999710dacef95 100644 --- a/imperative/tablegen/generated/hash.txt +++ b/imperative/tablegen/generated/hash.txt @@ -1,7 +1,7 @@ -148f3844ee8787250cd231eb3c1989c3 ../../dnn/scripts/opr_param_defs.py -b603857c46345dcb9f1693f49217b269 ../../src/core/include/megbrain/ir/ops.td -ae54e2eba267dc21d8c648963df23a90 generated/opdef.h.inl -94b20dcecd3dea69883d46ca7b8482be generated/opdef.cpp.inl -4ae5f0198e97e69eb381411f3d60e8c8 generated/opdef.py.inl -4971c6b2ba7f6fca395d73c554526a0e generated/opdef.cpy.inl +c5a5d1bd44473912f14cecee3df6409e ../../dnn/scripts/opr_param_defs.py +4ed3e8cbef0fa5f4d6824d8d55dec722 ../../src/core/include/megbrain/ir/ops.td +dc2d4ec8f4f5e203ce0a76bc20f62529 generated/opdef.h.inl +906957f12994d43c69248a6acfefa396 generated/opdef.cpp.inl +8817af8997ba0cc00048e71093755238 generated/opdef.py.inl +c43ae8b706e3f3658fe3cc0f60061981 generated/opdef.cpy.inl 71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h diff --git a/imperative/tablegen/generated/opdef.cpp.inl b/imperative/tablegen/generated/opdef.cpp.inl index 01ca5845807577001e4b52af9005ae10bebc8c21..bbde2e9b665fdb16eaa930a3819570c29d6fd314 100644 --- a/imperative/tablegen/generated/opdef.cpp.inl +++ b/imperative/tablegen/generated/opdef.cpp.inl @@ -5186,6 +5186,96 @@ OP_TRAIT_REG(MeshIndexing, MeshIndexing) .props(MeshIndexing_props_impl) .make_name(MeshIndexing_make_name_impl); +MGB_DYN_TYPE_OBJ_FINAL_IMPL(MultiHeadAttn); + +namespace { +size_t MultiHeadAttn_hash_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe<MultiHeadAttn>(); + static_cast<void>(op_); + + return mgb::hash_pair_combine( + mgb::hash(op_.dyn_typeinfo()), + mgb::hash_pair_combine( + mgb::hash(op_.handle), + mgb::hash_pair_combine( + mgb::hash(op_.num_heads), + mgb::hash_pair_combine( + mgb::hash(op_.sm_scaler), + mgb::hash_pair_combine( + mgb::hash(op_.input_order), + mgb::hash_pair_combine( + mgb::hash(op_.reslink), + mgb::hash_pair_combine( + mgb::hash(op_.training), + mgb::hash_pair_combine( + mgb::hash(op_.bias), + mgb::hash_pair_combine( + mgb::hash(op_.attn_mask), + mgb::hash_pair_combine( + mgb::hash(op_.enable_qproj), + mgb::hash_pair_combine( + mgb::hash(op_.enable_kproj), + mgb::hash_pair_combine( + mgb::hash(op_.enable_vproj), + mgb::hash_pair_combine( + mgb::hash(op_.enable_oproj), + mgb::hash_pair_combine( + mgb::hash(op_.attn_prob), + mgb::hash(op_.out_prob) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ); + } +bool MultiHeadAttn_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { + auto &&a_ = lhs_.cast_final_safe<MultiHeadAttn>(), + &&b_ = rhs_.cast_final_safe<MultiHeadAttn>(); + static_cast<void>(a_); + static_cast<void>(b_); +return a_.handle == b_.handle && a_.num_heads == b_.num_heads && a_.sm_scaler == b_.sm_scaler && a_.input_order == b_.input_order && a_.reslink == b_.reslink && a_.training == b_.training && a_.bias == b_.bias && a_.attn_mask == b_.attn_mask && a_.enable_qproj == b_.enable_qproj && a_.enable_kproj == b_.enable_kproj && a_.enable_vproj == b_.enable_vproj && a_.enable_oproj == b_.enable_oproj && a_.attn_prob == b_.attn_prob && a_.out_prob == b_.out_prob;} +std::vector<std::pair<const char*, std::string>> MultiHeadAttn_props_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe<MultiHeadAttn>(); + static_cast<void>(op_); + std::vector<std::pair<const char*, std::string>> props_; + props_.emplace_back("num_heads", std::to_string(op_.num_heads)); + props_.emplace_back("sm_scaler", std::to_string(op_.sm_scaler)); + props_.emplace_back("input_order", std::to_string(op_.input_order)); + props_.emplace_back("reslink", std::to_string(op_.reslink)); + props_.emplace_back("training", std::to_string(op_.training)); + props_.emplace_back("bias", std::to_string(op_.bias)); + props_.emplace_back("attn_mask", std::to_string(op_.attn_mask)); + props_.emplace_back("enable_qproj", std::to_string(op_.enable_qproj)); + props_.emplace_back("enable_kproj", std::to_string(op_.enable_kproj)); + props_.emplace_back("enable_vproj", std::to_string(op_.enable_vproj)); + props_.emplace_back("enable_oproj", std::to_string(op_.enable_oproj)); + props_.emplace_back("seed", std::to_string(op_.seed)); + props_.emplace_back("attn_prob", std::to_string(op_.attn_prob)); + props_.emplace_back("out_prob", std::to_string(op_.out_prob)); + props_.emplace_back("handle", std::to_string(op_.handle)); + return props_; +} +std::string MultiHeadAttn_make_name_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe<MultiHeadAttn>(); + static_cast<void>(op_); + return "MultiHeadAttn"; +} +} // anonymous namespace +OP_TRAIT_REG(MultiHeadAttn, MultiHeadAttn) + .hash(MultiHeadAttn_hash_impl) + .is_same_st(MultiHeadAttn_is_same_st_impl) + .props(MultiHeadAttn_props_impl) + .make_name(MultiHeadAttn_make_name_impl); + MGB_DYN_TYPE_OBJ_FINAL_IMPL(NMSKeep); namespace { diff --git a/imperative/tablegen/generated/opdef.cpy.inl b/imperative/tablegen/generated/opdef.cpy.inl index d92339ce9f399e2052b7861aa0ba4fd91261f65f..98a5af8f2005284d566ee74cae7b5f3eb9e3df83 100644 --- a/imperative/tablegen/generated/opdef.cpy.inl +++ b/imperative/tablegen/generated/opdef.cpy.inl @@ -15043,6 +15043,367 @@ void _init_py_MeshIndexing(py::module m) { mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MeshIndexing::typeinfo(), &py_type).second); } +PyOpDefBegin(MultiHeadAttn) // { + static PyGetSetDef py_getsetters[]; + static PyMethodDef tp_methods[]; + + static PyObject* getstate(PyObject* self, PyObject*) { + auto& opdef = reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst(); + static_cast<void>(opdef); + std::unordered_map<std::string, py::object> state { + + {"num_heads", serialization<decltype(opdef.num_heads)>::dump(opdef.num_heads)}, + {"sm_scaler", serialization<decltype(opdef.sm_scaler)>::dump(opdef.sm_scaler)}, + {"input_order", serialization<decltype(opdef.input_order)>::dump(opdef.input_order)}, + {"reslink", serialization<decltype(opdef.reslink)>::dump(opdef.reslink)}, + {"training", serialization<decltype(opdef.training)>::dump(opdef.training)}, + {"bias", serialization<decltype(opdef.bias)>::dump(opdef.bias)}, + {"attn_mask", serialization<decltype(opdef.attn_mask)>::dump(opdef.attn_mask)}, + {"enable_qproj", serialization<decltype(opdef.enable_qproj)>::dump(opdef.enable_qproj)}, + {"enable_kproj", serialization<decltype(opdef.enable_kproj)>::dump(opdef.enable_kproj)}, + {"enable_vproj", serialization<decltype(opdef.enable_vproj)>::dump(opdef.enable_vproj)}, + {"enable_oproj", serialization<decltype(opdef.enable_oproj)>::dump(opdef.enable_oproj)}, + {"seed", serialization<decltype(opdef.seed)>::dump(opdef.seed)}, + {"attn_prob", serialization<decltype(opdef.attn_prob)>::dump(opdef.attn_prob)}, + {"out_prob", serialization<decltype(opdef.out_prob)>::dump(opdef.out_prob)}, + {"handle", serialization<decltype(opdef.handle)>::dump(opdef.handle)} + }; + return py::cast(state).release().ptr(); + } + static PyObject* setstate(PyObject* self, PyObject* args) { + PyObject* dict = PyTuple_GetItem(args, 0); + if (!dict) return NULL; + auto state = py::cast<std::unordered_map<std::string, py::object>>(dict); + auto& opdef = reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst(); + static_cast<void>(opdef); + + { + auto&& iter = state.find("num_heads"); + if (iter != state.end()) { + opdef.num_heads = serialization<decltype(opdef.num_heads)>::load(iter->second); + } + } + + { + auto&& iter = state.find("sm_scaler"); + if (iter != state.end()) { + opdef.sm_scaler = serialization<decltype(opdef.sm_scaler)>::load(iter->second); + } + } + + { + auto&& iter = state.find("input_order"); + if (iter != state.end()) { + opdef.input_order = serialization<decltype(opdef.input_order)>::load(iter->second); + } + } + + { + auto&& iter = state.find("reslink"); + if (iter != state.end()) { + opdef.reslink = serialization<decltype(opdef.reslink)>::load(iter->second); + } + } + + { + auto&& iter = state.find("training"); + if (iter != state.end()) { + opdef.training = serialization<decltype(opdef.training)>::load(iter->second); + } + } + + { + auto&& iter = state.find("bias"); + if (iter != state.end()) { + opdef.bias = serialization<decltype(opdef.bias)>::load(iter->second); + } + } + + { + auto&& iter = state.find("attn_mask"); + if (iter != state.end()) { + opdef.attn_mask = serialization<decltype(opdef.attn_mask)>::load(iter->second); + } + } + + { + auto&& iter = state.find("enable_qproj"); + if (iter != state.end()) { + opdef.enable_qproj = serialization<decltype(opdef.enable_qproj)>::load(iter->second); + } + } + + { + auto&& iter = state.find("enable_kproj"); + if (iter != state.end()) { + opdef.enable_kproj = serialization<decltype(opdef.enable_kproj)>::load(iter->second); + } + } + + { + auto&& iter = state.find("enable_vproj"); + if (iter != state.end()) { + opdef.enable_vproj = serialization<decltype(opdef.enable_vproj)>::load(iter->second); + } + } + + { + auto&& iter = state.find("enable_oproj"); + if (iter != state.end()) { + opdef.enable_oproj = serialization<decltype(opdef.enable_oproj)>::load(iter->second); + } + } + + { + auto&& iter = state.find("seed"); + if (iter != state.end()) { + opdef.seed = serialization<decltype(opdef.seed)>::load(iter->second); + } + } + + { + auto&& iter = state.find("attn_prob"); + if (iter != state.end()) { + opdef.attn_prob = serialization<decltype(opdef.attn_prob)>::load(iter->second); + } + } + + { + auto&& iter = state.find("out_prob"); + if (iter != state.end()) { + opdef.out_prob = serialization<decltype(opdef.out_prob)>::load(iter->second); + } + } + + { + auto&& iter = state.find("handle"); + if (iter != state.end()) { + opdef.handle = serialization<decltype(opdef.handle)>::load(iter->second); + } + } + Py_RETURN_NONE; + } + static int py_init(PyObject *self, PyObject *args, PyObject *kwds); + static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds); + static PyMethodDef py_init_methoddef; +// }; +PyOpDefEnd(MultiHeadAttn) + +int PyOp(MultiHeadAttn)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { + static const char* kwlist[] = {"num_heads", "sm_scaler", "input_order", "reslink", "training", "bias", "attn_mask", "enable_qproj", "enable_kproj", "enable_vproj", "enable_oproj", "seed", "attn_prob", "out_prob", "handle", "scope", NULL}; + PyObject *num_heads = NULL, *sm_scaler = NULL, *input_order = NULL, *reslink = NULL, *training = NULL, *bias = NULL, *attn_mask = NULL, *enable_qproj = NULL, *enable_kproj = NULL, *enable_vproj = NULL, *enable_oproj = NULL, *seed = NULL, *attn_prob = NULL, *out_prob = NULL, *handle = NULL, *scope = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOOOOOOOOOOOOO", const_cast<char**>(kwlist), &num_heads, &sm_scaler, &input_order, &reslink, &training, &bias, &attn_mask, &enable_qproj, &enable_kproj, &enable_vproj, &enable_oproj, &seed, &attn_prob, &out_prob, &handle, &scope)) + return -1; + + if (num_heads) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().num_heads = + py::cast<decltype(MultiHeadAttn::num_heads)>(py::handle(num_heads)); + } CATCH_ALL(-1) + } + + if (sm_scaler) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().sm_scaler = + py::cast<decltype(MultiHeadAttn::sm_scaler)>(py::handle(sm_scaler)); + } CATCH_ALL(-1) + } + + if (input_order) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().input_order = + py::cast<decltype(MultiHeadAttn::input_order)>(py::handle(input_order)); + } CATCH_ALL(-1) + } + + if (reslink) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().reslink = + py::cast<decltype(MultiHeadAttn::reslink)>(py::handle(reslink)); + } CATCH_ALL(-1) + } + + if (training) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().training = + py::cast<decltype(MultiHeadAttn::training)>(py::handle(training)); + } CATCH_ALL(-1) + } + + if (bias) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().bias = + py::cast<decltype(MultiHeadAttn::bias)>(py::handle(bias)); + } CATCH_ALL(-1) + } + + if (attn_mask) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().attn_mask = + py::cast<decltype(MultiHeadAttn::attn_mask)>(py::handle(attn_mask)); + } CATCH_ALL(-1) + } + + if (enable_qproj) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().enable_qproj = + py::cast<decltype(MultiHeadAttn::enable_qproj)>(py::handle(enable_qproj)); + } CATCH_ALL(-1) + } + + if (enable_kproj) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().enable_kproj = + py::cast<decltype(MultiHeadAttn::enable_kproj)>(py::handle(enable_kproj)); + } CATCH_ALL(-1) + } + + if (enable_vproj) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().enable_vproj = + py::cast<decltype(MultiHeadAttn::enable_vproj)>(py::handle(enable_vproj)); + } CATCH_ALL(-1) + } + + if (enable_oproj) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().enable_oproj = + py::cast<decltype(MultiHeadAttn::enable_oproj)>(py::handle(enable_oproj)); + } CATCH_ALL(-1) + } + + if (seed) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().seed = + py::cast<decltype(MultiHeadAttn::seed)>(py::handle(seed)); + } CATCH_ALL(-1) + } + + if (attn_prob) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().attn_prob = + py::cast<decltype(MultiHeadAttn::attn_prob)>(py::handle(attn_prob)); + } CATCH_ALL(-1) + } + + if (out_prob) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().out_prob = + py::cast<decltype(MultiHeadAttn::out_prob)>(py::handle(out_prob)); + } CATCH_ALL(-1) + } + + if (handle) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().handle = + py::cast<decltype(MultiHeadAttn::handle)>(py::handle(handle)); + } CATCH_ALL(-1) + } + + if (scope) { + try { + reinterpret_cast<PyOp(OpDef)*>(self)->op + ->set_scope(py::cast<std::string>(py::handle(scope))); + } CATCH_ALL(-1) + } + + return 0; +} + +PyGetSetDef PyOp(MultiHeadAttn)::py_getsetters[] = { + {const_cast<char*>("num_heads"), py_get_generic(MultiHeadAttn, num_heads), py_set_generic(MultiHeadAttn, num_heads), const_cast<char*>("num_heads"), NULL}, + {const_cast<char*>("sm_scaler"), py_get_generic(MultiHeadAttn, sm_scaler), py_set_generic(MultiHeadAttn, sm_scaler), const_cast<char*>("sm_scaler"), NULL}, + {const_cast<char*>("input_order"), py_get_generic(MultiHeadAttn, input_order), py_set_generic(MultiHeadAttn, input_order), const_cast<char*>("input_order"), NULL}, + {const_cast<char*>("reslink"), py_get_generic(MultiHeadAttn, reslink), py_set_generic(MultiHeadAttn, reslink), const_cast<char*>("reslink"), NULL}, + {const_cast<char*>("training"), py_get_generic(MultiHeadAttn, training), py_set_generic(MultiHeadAttn, training), const_cast<char*>("training"), NULL}, + {const_cast<char*>("bias"), py_get_generic(MultiHeadAttn, bias), py_set_generic(MultiHeadAttn, bias), const_cast<char*>("bias"), NULL}, + {const_cast<char*>("attn_mask"), py_get_generic(MultiHeadAttn, attn_mask), py_set_generic(MultiHeadAttn, attn_mask), const_cast<char*>("attn_mask"), NULL}, + {const_cast<char*>("enable_qproj"), py_get_generic(MultiHeadAttn, enable_qproj), py_set_generic(MultiHeadAttn, enable_qproj), const_cast<char*>("enable_qproj"), NULL}, + {const_cast<char*>("enable_kproj"), py_get_generic(MultiHeadAttn, enable_kproj), py_set_generic(MultiHeadAttn, enable_kproj), const_cast<char*>("enable_kproj"), NULL}, + {const_cast<char*>("enable_vproj"), py_get_generic(MultiHeadAttn, enable_vproj), py_set_generic(MultiHeadAttn, enable_vproj), const_cast<char*>("enable_vproj"), NULL}, + {const_cast<char*>("enable_oproj"), py_get_generic(MultiHeadAttn, enable_oproj), py_set_generic(MultiHeadAttn, enable_oproj), const_cast<char*>("enable_oproj"), NULL}, + {const_cast<char*>("seed"), py_get_generic(MultiHeadAttn, seed), py_set_generic(MultiHeadAttn, seed), const_cast<char*>("seed"), NULL}, + {const_cast<char*>("attn_prob"), py_get_generic(MultiHeadAttn, attn_prob), py_set_generic(MultiHeadAttn, attn_prob), const_cast<char*>("attn_prob"), NULL}, + {const_cast<char*>("out_prob"), py_get_generic(MultiHeadAttn, out_prob), py_set_generic(MultiHeadAttn, out_prob), const_cast<char*>("out_prob"), NULL}, + {const_cast<char*>("handle"), py_get_generic(MultiHeadAttn, handle), py_set_generic(MultiHeadAttn, handle), const_cast<char*>("handle"), NULL}, + {NULL} /* Sentinel */ +}; + + PyMethodDef PyOp(MultiHeadAttn)::tp_methods[] = { + {const_cast<char*>("__getstate__"), PyOp(MultiHeadAttn)::getstate, METH_NOARGS, "MultiHeadAttn getstate"}, + {const_cast<char*>("__setstate__"), PyOp(MultiHeadAttn)::setstate, METH_VARARGS, "MultiHeadAttn setstate"}, + {NULL} /* Sentinel */ + }; + +PyObject *PyOp(MultiHeadAttn)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) { + if (PyOp(MultiHeadAttn)::py_init(self, args, kwds) < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +PyMethodDef PyOp(MultiHeadAttn)::py_init_methoddef = { + "__init__", + (PyCFunction)PyOp(MultiHeadAttn)::py_init_proxy, + METH_VARARGS | METH_KEYWORDS, + "__init__(self, num_heads: int = ..., sm_scaler: float = ..., input_order: int = ..., reslink: bool = ..., training: bool = ..., bias: bool = ..., attn_mask: bool = ..., enable_qproj: bool = ..., enable_kproj: bool = ..., enable_vproj: bool = ..., enable_oproj: bool = ..., seed: int = ..., attn_prob: float = ..., out_prob: float = ..., handle: int = ...) -> None\n" +}; + +void _init_py_MultiHeadAttn(py::module m) { + using py_op = PyOp(MultiHeadAttn); + auto& py_type = PyOpType(MultiHeadAttn); + py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; + py_type.tp_name = "megengine.core._imperative_rt.ops.MultiHeadAttn"; + py_type.tp_basicsize = sizeof(PyOp(MultiHeadAttn)); + py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + py_type.tp_doc = "MultiHeadAttn"; + py_type.tp_base = &PyOpType(OpDef); + py_type.tp_dealloc = py_dealloc_generic<py_op>; + py_type.tp_new = py_new_generic<py_op>; + py_type.tp_init = py_op::py_init; + py_type.tp_methods = py_op::tp_methods; + py_type.tp_getset = py_op::py_getsetters; + + py_type.tp_dict = PyDict_New(); + PyObject* descr = PyDescr_NewMethod(&PyOpType(MultiHeadAttn), &PyOp(MultiHeadAttn)::py_init_methoddef); + PyDict_SetItemString(py_type.tp_dict, "__init__", descr); + mgb_assert(PyType_Ready(&py_type) >= 0); + + PyType_Modified(&py_type); + m.add_object("MultiHeadAttn", reinterpret_cast<PyObject*>(&py_type)); + mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MultiHeadAttn::typeinfo(), &py_type).second); +} + PyOpDefBegin(NMSKeep) // { static PyGetSetDef py_getsetters[]; static PyMethodDef tp_methods[]; @@ -22608,6 +22969,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) { _init_py_MatrixMul(m); \ _init_py_MeshGrid(m); \ _init_py_MeshIndexing(m); \ + _init_py_MultiHeadAttn(m); \ _init_py_NMSKeep(m); \ _init_py_NvOf(m); \ _init_py_Padding(m); \ diff --git a/imperative/tablegen/generated/opdef.h.inl b/imperative/tablegen/generated/opdef.h.inl index e3bd9b9fca3964e330db267585758d288d0c02ff..5f774aaae359105370fe43db18fe10512118414b 100644 --- a/imperative/tablegen/generated/opdef.h.inl +++ b/imperative/tablegen/generated/opdef.h.inl @@ -1394,6 +1394,33 @@ public: MeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); } }; +class MultiHeadAttn : public OpDefImplBase<MultiHeadAttn> { + MGB_DYN_TYPE_OBJ_FINAL_DECL; + +public: + uint32_t num_heads = 1; + float sm_scaler = 1.f; + uint32_t input_order = 0; + bool reslink = false; + bool training = true; + bool bias = false; + bool attn_mask = false; + bool enable_qproj = true; + bool enable_kproj = true; + bool enable_vproj = true; + bool enable_oproj = true; + uint64_t seed = 0; + float attn_prob = 0.f; + float out_prob = 0.f; + size_t handle; + MultiHeadAttn() = default; + MultiHeadAttn(uint32_t num_heads_, float sm_scaler_, uint32_t input_order_, bool reslink_, bool training_, bool bias_, bool attn_mask_, bool enable_qproj_, bool enable_kproj_, bool enable_vproj_, bool enable_oproj_, uint64_t seed_, float attn_prob_, float out_prob_, size_t handle_, std::string scope_ = {}): num_heads(num_heads_), sm_scaler(sm_scaler_), input_order(input_order_), reslink(reslink_), training(training_), bias(bias_), attn_mask(attn_mask_), enable_qproj(enable_qproj_), enable_kproj(enable_kproj_), enable_vproj(enable_vproj_), enable_oproj(enable_oproj_), seed(seed_), attn_prob(attn_prob_), out_prob(out_prob_), handle(handle_) { set_scope(scope_); } + MultiHeadAttn(::megdnn::param::MultiHeadAttn packed_param_0, size_t handle_): num_heads(packed_param_0.num_heads), sm_scaler(packed_param_0.sm_scaler), input_order(packed_param_0.input_order), reslink(packed_param_0.reslink), training(packed_param_0.training), bias(packed_param_0.bias), attn_mask(packed_param_0.attn_mask), enable_qproj(packed_param_0.enable_qproj), enable_kproj(packed_param_0.enable_kproj), enable_vproj(packed_param_0.enable_vproj), enable_oproj(packed_param_0.enable_oproj), seed(packed_param_0.seed), attn_prob(packed_param_0.attn_prob), out_prob(packed_param_0.out_prob), handle(handle_) {} + ::megdnn::param::MultiHeadAttn param() const { + return {num_heads, sm_scaler, input_order, reslink, training, bias, attn_mask, enable_qproj, enable_kproj, enable_vproj, enable_oproj, seed, attn_prob, out_prob}; + } +}; + class NMSKeep : public OpDefImplBase<NMSKeep> { MGB_DYN_TYPE_OBJ_FINAL_DECL; diff --git a/imperative/tablegen/generated/opdef.py.inl b/imperative/tablegen/generated/opdef.py.inl index 39cf73a5b03880c1aec97575dd06276e30b6819e..b6591c362de4a3e22a8f2b79463ff9703abbe2b1 100644 --- a/imperative/tablegen/generated/opdef.py.inl +++ b/imperative/tablegen/generated/opdef.py.inl @@ -1477,6 +1477,27 @@ MeshIndexingInst .def(py::init<>()) .def_readwrite("items", &MeshIndexing::items); +py::class_<MultiHeadAttn, std::shared_ptr<MultiHeadAttn>, OpDef> MultiHeadAttnInst(m, "MultiHeadAttn"); + +MultiHeadAttnInst + .def(py::init<uint32_t, float, uint32_t, bool, bool, bool, bool, bool, bool, bool, bool, uint64_t, float, float, size_t, std::string>(), py::arg("num_heads") = 1, py::arg("sm_scaler") = 1.f, py::arg("input_order") = 0, py::arg("reslink") = false, py::arg("training") = true, py::arg("bias") = false, py::arg("attn_mask") = false, py::arg("enable_qproj") = true, py::arg("enable_kproj") = true, py::arg("enable_vproj") = true, py::arg("enable_oproj") = true, py::arg("seed") = 0, py::arg("attn_prob") = 0.f, py::arg("out_prob") = 0.f, py::arg("handle"), py::arg("scope") = {}) + .def(py::init<>()) + .def_readwrite("num_heads", &MultiHeadAttn::num_heads) + .def_readwrite("sm_scaler", &MultiHeadAttn::sm_scaler) + .def_readwrite("input_order", &MultiHeadAttn::input_order) + .def_readwrite("reslink", &MultiHeadAttn::reslink) + .def_readwrite("training", &MultiHeadAttn::training) + .def_readwrite("bias", &MultiHeadAttn::bias) + .def_readwrite("attn_mask", &MultiHeadAttn::attn_mask) + .def_readwrite("enable_qproj", &MultiHeadAttn::enable_qproj) + .def_readwrite("enable_kproj", &MultiHeadAttn::enable_kproj) + .def_readwrite("enable_vproj", &MultiHeadAttn::enable_vproj) + .def_readwrite("enable_oproj", &MultiHeadAttn::enable_oproj) + .def_readwrite("seed", &MultiHeadAttn::seed) + .def_readwrite("attn_prob", &MultiHeadAttn::attn_prob) + .def_readwrite("out_prob", &MultiHeadAttn::out_prob) + .def_readwrite("handle", &MultiHeadAttn::handle); + py::class_<NMSKeep, std::shared_ptr<NMSKeep>, OpDef> NMSKeepInst(m, "NMSKeep"); NMSKeepInst diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 2a672774942a09bde26c6b4b2cc26fd31c6ddc8c..c1f8eff0284b3fe810cdb824a5f0691fa3a05892 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -559,4 +559,57 @@ def RegionRestrictedConvolution: MgbHashableOp<"RegionRestrictedConvolution", [C def RegionRestrictedConvolutionBackwardData: MgbHashableOp<"RegionRestrictedConvolutionBackwardData", [ConvolutionParam]>; def MaskedFill: MgbHashableOp<"MaskedFill", [FillParam]>; +def MultiHeadAttn: MgbHashableOp<"MultiHeadAttn", [MultiHeadAttnParam]> { + let extraArguments = (ins + MgbSizeTAddr:$handle + ); + let hashFunction = [{ + return mgb::hash_pair_combine( + mgb::hash($_self.dyn_typeinfo()), + mgb::hash_pair_combine( + mgb::hash($_self.handle), + mgb::hash_pair_combine( + mgb::hash($_self.num_heads), + mgb::hash_pair_combine( + mgb::hash($_self.sm_scaler), + mgb::hash_pair_combine( + mgb::hash($_self.input_order), + mgb::hash_pair_combine( + mgb::hash($_self.reslink), + mgb::hash_pair_combine( + mgb::hash($_self.training), + mgb::hash_pair_combine( + mgb::hash($_self.bias), + mgb::hash_pair_combine( + mgb::hash($_self.attn_mask), + mgb::hash_pair_combine( + mgb::hash($_self.enable_qproj), + mgb::hash_pair_combine( + mgb::hash($_self.enable_kproj), + mgb::hash_pair_combine( + mgb::hash($_self.enable_vproj), + mgb::hash_pair_combine( + mgb::hash($_self.enable_oproj), + mgb::hash_pair_combine( + mgb::hash($_self.attn_prob), + mgb::hash($_self.out_prob) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ); + }]; + let cmpFunction = [{return $0.handle == $1.handle && $0.num_heads == $1.num_heads && $0.sm_scaler == $1.sm_scaler && $0.input_order == $1.input_order && $0.reslink == $1.reslink && $0.training == $1.training && $0.bias == $1.bias && $0.attn_mask == $1.attn_mask && $0.enable_qproj == $1.enable_qproj && $0.enable_kproj == $1.enable_kproj && $0.enable_vproj == $1.enable_vproj && $0.enable_oproj == $1.enable_oproj && $0.attn_prob == $1.attn_prob && $0.out_prob == $1.out_prob;}]; + +} + #endif // MGB_OPS diff --git a/src/opr/impl/internal/megdnn_opr_wrapper.inl b/src/opr/impl/internal/megdnn_opr_wrapper.inl index 861903663035910128d05c95552de1221de57fb6..f54ccec2f9fe19b90cdf00b04f2084503a34cdaf 100644 --- a/src/opr/impl/internal/megdnn_opr_wrapper.inl +++ b/src/opr/impl/internal/megdnn_opr_wrapper.inl @@ -159,6 +159,11 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU #define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0) #include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" +#define _NR_INPUTS 4 +#define _NR_OUTPUTS 2 +#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0), _o(1) +#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" + #define _NR_INPUTS 4 #define _NR_OUTPUTS 4 #define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0), _o(1), _o(2), _o(3) @@ -179,6 +184,12 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU #define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2) #include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" +#define _NR_INPUTS 5 +#define _NR_OUTPUTS 5 +#define _FOREACH_IO(_i, _o) \ + _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2), _o(3), _o(4) +#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" + #define _NR_INPUTS 6 #define _NR_OUTPUTS 1 #define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0) @@ -195,6 +206,12 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1), _o(2) #include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" +#define _NR_INPUTS 6 +#define _NR_OUTPUTS 4 +#define _FOREACH_IO(_i, _o) \ + _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1), _o(2), _o(3) +#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" + #define _NR_INPUTS 7 #define _NR_OUTPUTS 3 #define _FOREACH_IO(_i, _o) \ diff --git a/src/opr/impl/rand.cpp b/src/opr/impl/rand.cpp index 83dda237dc2ebcc91f3257725cc3282f5a4b6511..b5dc33ced97b5a89bc238584ea3ec2de123f2eb1 100644 --- a/src/opr/impl/rand.cpp +++ b/src/opr/impl/rand.cpp @@ -192,6 +192,8 @@ template class RNGOprBase<::megdnn::ShuffleRNGForward>; template class RNGOprBase<::megdnn::ShuffleRNGBackward>; template class RNGOprBase<::megdnn::DropoutForward>; template class RNGOprBase<::megdnn::DropoutBackward>; +template class RNGOprBase<::megdnn::MultiHeadAttnForward>; +template class RNGOprBase<::megdnn::MultiHeadAttnBackward>; #if MGB_ENABLE_GRAD IMPL(GaussianRNG); IMPL(UniformRNG); @@ -375,7 +377,7 @@ MGB_IMPL_OPR_GRAD(DropoutForward) { } #endif -/* ==================== LayerNormBackward ==================== */ +/* ==================== DropoutBackward ==================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(DropoutBackward); @@ -421,4 +423,170 @@ void DropoutBackward::scn_do_execute() { output(0)->dev_tensor().as_megdnn(), {}); } +/* ==================== MultiHeadAttnForward ==================== */ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(MultiHeadAttnForward); + +MultiHeadAttnForward::MultiHeadAttnForward( + VarNode* queries, VarNode* keys, VarNode* values, VarNode* wqkv, + const Param& param, const OperatorNodeConfig& config) + : Super{{queries->owner_graph(), + config, + "multi_head_attn", + {queries, keys, values, wqkv}}, + param} { + add_input({queries, keys, values, wqkv}); + add_output(None) + ->dtype(queries->dtype()) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + add_output(None)->dtype(dtype::Byte()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + cg::add_workspace_output(this); + add_equivalence_component<ScalarHash<void*>>(this); +} + +SymbolVarArray MultiHeadAttnForward::make( + SymbolVar queries, SymbolVar keys, SymbolVar values, SymbolVar wqkv, + const Param& param, const OperatorNodeConfig& config) { + auto outs = queries.node() + ->owner_graph() + ->insert_opr(std::make_unique<MultiHeadAttnForward>( + queries.node(), keys.node(), values.node(), wqkv.node(), + param, config)) + ->output(); + mgb_assert(outs.size() == 3); + return {outs[0], outs[1]}; +} + +void MultiHeadAttnForward::init_output_static_infer_desc() { + using namespace cg::static_infer; + auto&& mgr = owner_graph()->static_infer_manager(); + mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0))); + + auto infer_mask = [this](TensorShape& dest, const InpVal& iv) { + ensure_megdnn_opr(); + dest.ndim = 1; + dest.shape[0] = m_dnn_opr->get_reservespace_in_bytes( + {iv.val[0].shape(), input(0)->dtype()}, + {iv.val[1].shape(), input(1)->dtype()}, + {iv.val[2].shape(), input(2)->dtype()}, + {iv.val[3].shape(), input(3)->dtype()}, {}, {}); + return true; + }; + mgr.register_shape_infer( + output(1), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_mask}); +} + +void MultiHeadAttnForward::add_input_layout_constraint() { + input(0)->add_layout_constraint_contiguous(); + input(1)->add_layout_constraint_contiguous(); + input(2)->add_layout_constraint_contiguous(); + input(3)->add_layout_constraint_contiguous(); +}; + +void MultiHeadAttnForward::scn_do_execute() { + auto&& ret = output(0); + if (ret->layout().is_empty()) { + mgb_assert(ret->dev_tensor().empty()); + return; + } + m_dnn_opr->exec( + input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), + input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), + output(0)->dev_tensor().as_megdnn(), output(1)->dev_tensor().as_megdnn(), + get_megdnn_workspace_from_var(output(2))); +} + +cg::OperatorNodeBase::NodeProp* MultiHeadAttnForward::do_make_node_prop() const { + auto prop = Super::do_make_node_prop(); + prop->add_flag(NodeProp::Flag::IMPURE_FUNC); + for (auto i : input()) { + prop->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY); + } + return prop; +} + +#if MGB_ENABLE_GRAD +MGB_IMPL_OPR_GRAD(MultiHeadAttnForward) { + MGB_MARK_USED_VAR(opr); + MGB_MARK_USED_VAR(out_grad); + SymbolVarArray grad; + VarNodeArray ret; + mgb_assert(wrt_idx < 5, "wrt_idx %zu is out of range", wrt_idx); + grad = MultiHeadAttnBackward::make( + out_grad[0], opr.input(0), opr.input(1), opr.input(2), opr.input(3), + opr.output(1), opr.param()); + + uint32_t nr_ret = 4; + for (uint32_t i = 0; i < nr_ret; ++i) { + ret.push_back(grad[i].node()); + } + return ret; +} +#endif + +/* ==================== MultiHeadAttnBackwardData ==================== */ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(MultiHeadAttnBackward); + +MultiHeadAttnBackward::MultiHeadAttnBackward( + VarNode* diff, VarNode* queries, VarNode* keys, VarNode* values, VarNode* wqkv, + VarNode* reserveSpace, const Param& param, const OperatorNodeConfig& config) + : Super({queries->owner_graph(), + config, + "multi_head_attn_backward", + {diff, queries, keys, values, wqkv, reserveSpace}}, + 0, true) { + init_megdnn_opr(*this, param); + add_input({diff, queries, keys, values, wqkv, reserveSpace}); +} + +SymbolVarArray MultiHeadAttnBackward::make( + SymbolVar diff, SymbolVar queries, SymbolVar keys, SymbolVar values, + SymbolVar wqkv, SymbolVar reserveSpace, const Param& param, + const OperatorNodeConfig& config) { + auto outs = queries.node() + ->owner_graph() + ->insert_opr(std::make_unique<MultiHeadAttnBackward>( + diff.node(), queries.node(), keys.node(), values.node(), + wqkv.node(), reserveSpace.node(), param, config)) + ->output(); + mgb_assert(outs.size() == 5); + return {outs[0], outs[1], outs[2], outs[3]}; +} + +void MultiHeadAttnBackward::init_output_static_infer_desc() { + using namespace cg::static_infer; + auto&& mgr = owner_graph()->static_infer_manager(); + mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(1))); + mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2))); + mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(3))); + mgr.register_shape_infer(output(3), ShapeInferDesc::make_identity(input(4))); + + this->init_output_static_infer_desc_workspace(false); +} + +void MultiHeadAttnBackward::init_output_dtype() { + output(0)->dtype(input(1)->dtype()); + output(1)->dtype(input(2)->dtype()); + output(2)->dtype(input(3)->dtype()); + output(3)->dtype(input(4)->dtype()); +} + +size_t MultiHeadAttnBackward::get_workspace_size_bytes( + const TensorShapeArray& input_shapes, + const TensorShapeArray& output_shapes) const { + MGB_MARK_USED_VAR(input_shapes); + MGB_MARK_USED_VAR(output_shapes); + + return 0; +} + +void MultiHeadAttnBackward::scn_do_execute() { + megdnn_opr()->exec( + input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), + input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), + input(4)->dev_tensor().as_megdnn(), input(5)->dev_tensor().as_megdnn(), + output(0)->dev_tensor().as_megdnn(), output(1)->dev_tensor().as_megdnn(), + output(2)->dev_tensor().as_megdnn(), output(3)->dev_tensor().as_megdnn(), + {}); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/rand.sereg.h b/src/opr/impl/rand.sereg.h index a5d1a91521a2df67943b6b601458d2e88617ccb9..0333543342cea1c3892711cc14f0b3549aaa6605 100644 --- a/src/opr/impl/rand.sereg.h +++ b/src/opr/impl/rand.sereg.h @@ -30,6 +30,35 @@ struct OprMaker<opr::DropoutForward, 1> { } }; +template <> +struct OprMaker<opr::MultiHeadAttn, 0> { + using Param = opr::MultiHeadAttn::Param; + static cg::OperatorNodeBase* make( + const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + return opr::MultiHeadAttn::make(i[0], i[1], i[2], i[3], param, config)[0] + .node() + ->owner_opr(); + } +}; + +// OprMaker in MGB_SEREG_OPR only support unique output opr +template <> +struct OprMaker<opr::MultiHeadAttnBackward, 0> { + using Param = opr::MultiHeadAttnBackward::Param; + static cg::OperatorNodeBase* make( + const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + + return opr::MultiHeadAttnBackward::make( + i[0], i[1], i[2], i[3], i[4], i[5], param, config)[0] + .node() + ->owner_opr(); + } +}; + } // namespace serialization namespace opr { @@ -46,6 +75,8 @@ MGB_SEREG_OPR(ShuffleRNG, 1); MGB_SEREG_OPR(ShuffleRNGBackward, 3); MGB_SEREG_OPR(Dropout, 1); MGB_SEREG_OPR(DropoutBackward, 2); +MGB_SEREG_OPR(MultiHeadAttn, 0); +MGB_SEREG_OPR(MultiHeadAttnBackward, 0); } // namespace opr } // namespace mgb diff --git a/src/opr/include/megbrain/opr/rand.h b/src/opr/include/megbrain/opr/rand.h index fd482d608ffbd63913a7a5357213daca996b999a..d5d79352cd97fc806433e94321bfaf4e0d68bf98 100644 --- a/src/opr/include/megbrain/opr/rand.h +++ b/src/opr/include/megbrain/opr/rand.h @@ -86,6 +86,13 @@ _DEFINE_RNG_OPR_WITH_INPUT_CLASS(BetaRNG) _DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG) #undef _OUTPUTS #undef _INPUTS + +/* ================= 4 input ================= */ +#define _INPUTS(preifx) preifx i0, preifx i1, preifx i2, preifx i3 +#define _OUTPUTS SymbolVarArray +_DEFINE_RNG_OPR_WITH_INPUT_CLASS(MultiHeadAttnForward) +#undef _OUTPUTS +#undef _INPUTS #undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS } // namespace intl @@ -99,6 +106,7 @@ using BetaRNG = intl::BetaRNG; using ShuffleRNG = intl::ShuffleRNGForward; using Dropout = intl::DropoutForward; using DropoutForward = intl::DropoutForward; +using MultiHeadAttn = intl::MultiHeadAttnForward; MGB_DEFINE_OPR_CLASS_WITH_EXPORT( ShuffleRNGBackward, intl::MegDNNOprWrapperBwd<megdnn::ShuffleRNGBackward>) // { @@ -132,6 +140,29 @@ private: void scn_do_execute() override; }; +MGB_DEFINE_OPR_CLASS_WITH_EXPORT( + MultiHeadAttnBackward, + intl::MegDNNOprWrapperBwd<megdnn::MultiHeadAttnBackward>) // { +public: + MGE_WIN_DECLSPEC_FUC MultiHeadAttnBackward( + VarNode* diff, VarNode* queries, VarNode* keys, VarNode* values, + VarNode* wqkv, VarNode* reserveSpace, const Param& param, + const OperatorNodeConfig& config); + + MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( + SymbolVar diff, SymbolVar queries, SymbolVar keys, SymbolVar values, + SymbolVar wqkv, SymbolVar reserveSpace, const Param& param = {}, + const OperatorNodeConfig& config = {}); + +private: + void init_output_static_infer_desc() override; + void init_output_dtype() override; + size_t get_workspace_size_bytes( + const TensorShapeArray& input_shapes, + const TensorShapeArray& output_shapes) const override; + void scn_do_execute() override; +}; + } // namespace opr } // namespace mgb diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs index 9bcb8911c50057be124a531c9729cb1eb1a40a9d..8f647494f6d4e81568ce31d453d6bb6ebed94743 100644 --- a/src/serialization/impl/schema.fbs +++ b/src/serialization/impl/schema.fbs @@ -126,6 +126,7 @@ union OperatorParam { param.GroupNorm = 92, param.Fill = 93, param.GeneralNorm=94, + param.MultiHeadAttn=95, } table Operator { diff --git a/src/serialization/impl/schema_v2.fbs b/src/serialization/impl/schema_v2.fbs index 7c2c89ff036527f2e0be7eb85dc93033cfc4c953..add03d695b8a563d0f70793f2238646ebcabd757 100644 --- a/src/serialization/impl/schema_v2.fbs +++ b/src/serialization/impl/schema_v2.fbs @@ -143,6 +143,7 @@ union OperatorParam { param.GroupNorm = 92, param.Fill = 93, param.GeneralNorm=94, + param.MultiHeadAttn=95, } table Operator {