diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h
index 7d8c25d647c643f17ce05df34861b9c158e2c951..fa0be658e7c8ace2f1e6c8fdab8ed2971ac6c6cc 100644
--- a/dnn/include/megdnn/oprs/nn.h
+++ b/dnn/include/megdnn/oprs/nn.h
@@ -2574,6 +2574,73 @@ protected:
             size_t workspace_in_bytes);
 };
 
+class MultiHeadAttnBase : public OperatorBase {
+    DEF_OPR_IMPL_CTOR(MultiHeadAttnBase, OperatorBase);
+    DEF_OPR_PARAM(MultiHeadAttn);
+};
+
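+/*!
+ * \brief multi-head attention forward
+ *
+ * Inputs are the query/key/value tensors plus the packed projection weights
+ * (wqkv); outputs are the attention result and a reserveSpace buffer consumed
+ * by the backward pass when param().training is true.
+ */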
+class MultiHeadAttnForward : public MultiHeadAttnBase {
+    DEF_OPR_IMPL(MultiHeadAttnForward, MultiHeadAttnBase, 4, 2);
+
+public:
+    virtual void exec(
+            _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
+            _megdnn_tensor_in wqkv, _megdnn_tensor_out out,
+            _megdnn_tensor_out reserveSpace, _megdnn_workspace workspace) = 0;
+    MGE_WIN_DECLSPEC_FUC void deduce_layout(
+            const TensorLayout& queries, const TensorLayout& keys,
+            const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out,
+            TensorLayout& reserveSpace);
+    virtual size_t get_workspace_in_bytes(
+            const TensorLayout& queries, const TensorLayout& keys,
+            const TensorLayout& values, const TensorLayout& wqkv,
+            const TensorLayout& out, const TensorLayout& reserveSpace) = 0;
+    virtual size_t get_reservespace_in_bytes(
+            const TensorLayout& queries, const TensorLayout& keys,
+            const TensorLayout& values, const TensorLayout& wqkv,
+            const TensorLayout& out, const TensorLayout& reserveSpace) = 0;
+
+protected:
+    void check_exec(
+            const TensorLayout& queries, const TensorLayout& keys,
+            const TensorLayout& values, const TensorLayout& wqkv,
+            const TensorLayout& out, const TensorLayout& reserveSpace,
+            size_t workspace_in_bytes);
+};
+using MultiHeadAttn = MultiHeadAttnForward;
+
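+/*!
+ * \brief gradient of MultiHeadAttnForward
+ *
+ * diff is the gradient w.r.t. the forward output; reserveSpace must be the
+ * buffer produced by a forward call executed with param().training == true.
+ */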
+class MultiHeadAttnBackward : public MultiHeadAttnBase {
+    DEF_OPR_IMPL(MultiHeadAttnBackward, MultiHeadAttnBase, 6, 4);
+
+public:
+    virtual void exec(
+            _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
+            _megdnn_tensor_in values, _megdnn_tensor_in wqkv,
+            _megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries,
+            _megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues,
+            _megdnn_tensor_out dweights, _megdnn_workspace workspace) = 0;
+    MGE_WIN_DECLSPEC_FUC void deduce_layout(
+            const TensorLayout& diff, const TensorLayout& queries,
+            const TensorLayout& keys, const TensorLayout& values,
+            const TensorLayout& wqkv, const TensorLayout& reserveSpace,
+            TensorLayout& dqueries, TensorLayout& dkeys, TensorLayout& dvalues,
+            TensorLayout& dweights);
+    virtual size_t get_workspace_in_bytes(
+            const TensorLayout& diff, const TensorLayout& queries,
+            const TensorLayout& keys, const TensorLayout& values,
+            const TensorLayout& wqkv, const TensorLayout& reserveSpace,
+            const TensorLayout& dqueries, const TensorLayout& dkeys,
+            const TensorLayout& dvalues, const TensorLayout& dweights) = 0;
+
+protected:
+    void check_exec(
+            const TensorLayout& diff, const TensorLayout& queries,
+            const TensorLayout& keys, const TensorLayout& values,
+            const TensorLayout& wqkv, const TensorLayout& reserveSpace,
+            const TensorLayout& dqueries, const TensorLayout& dkeys,
+            const TensorLayout& dvalues, const TensorLayout& dweights,
+            size_t workspace_in_bytes);
+};
 }  // namespace megdnn
 #include "megdnn/internal/opr_header_epilogue.h"
 
diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py
index 78cbaf4d3b9fe14da224ef862e4a2a82a6ecf619..33fe4cbace400b96707866681e1b2fc184e42899 100755
--- a/dnn/scripts/opr_param_defs.py
+++ b/dnn/scripts/opr_param_defs.py
@@ -1330,3 +1330,20 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
  add_fields('float32', Doc('p', 'the order of norm'), '2').
  add_fields('int32', Doc('dim', 'which dim the norm performed along'), '-1'),
  )
+
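+# Parameters of the MultiHeadAttn operator. input_order selects one of the
+# 3! = 6 permutations of the TIME/BATCH/BEAM axes handled by
+# SeqTensorDesc::set in dnn/src/cuda/cudnn_wrapper.cpp.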
+(pdef('MultiHeadAttn')
+ .add_fields('uint32', Doc('num_heads', 'Number of parallel attention heads.'), '1')
+ .add_fields('float32', Doc('sm_scaler', 'Softmax smoothing (1.0 >= smScaler >= 0.0) or sharpening (smScaler > 1.0) coefficient.'), '1.f')
+ .add_fields('uint32', Doc('input_order', 'The sequence data layout; selects one of the 3! = 6 permutations of the BEAM, BATCH and TIME dimensions.'), '0')
+ .add_fields('bool', Doc('reslink', 'Whether to add the input query to the final output.'), 'false')
+ .add_fields('bool', Doc('training', 'Whether it is in training mode.'), 'true')
+ .add_fields('bool', Doc('bias', 'Whether to add linear bias.'), 'false')
+ .add_fields('bool', Doc('attn_mask', 'Whether to add attn_mask.'), 'false')
+ .add_fields('bool', Doc('enable_qproj', 'Enable query weight projection.'), 'true')
+ .add_fields('bool', Doc('enable_kproj', 'Enable key weight projection.'), 'true')
+ .add_fields('bool', Doc('enable_vproj', 'Enable value weight projection.'), 'true')
+ .add_fields('bool', Doc('enable_oproj', 'Enable output weight projection.'), 'true')
+ .add_fields('uint64', Doc('seed', 'Random number seed for dropout.'), '0')
+ .add_fields('float32', Doc('attn_prob', 'Dropout probability on attention; applied directly to the softmax output.'), '0.f')
+ .add_fields('float32', Doc('out_prob', 'Dropout probability on output; applied to the multi-head attention output.'), '0.f')
+ )
diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h
index 19124f780a1417f277d9f40ba78f7480e992e9ad..7cca0cbf56dbb374c475f7b647232447475d7025 100644
--- a/dnn/src/common/handle_impl.h
+++ b/dnn/src/common/handle_impl.h
@@ -221,7 +221,9 @@ private:
     cb(RegionRestrictedConvolutionBackwardFilter) \
     cb(GroupNormForward) \
     cb(GroupNormBackward) \
-    cb(MaskedFill)
+    cb(MaskedFill) \
+    cb(MultiHeadAttnForward)\
+    cb(MultiHeadAttnBackward)
 // clang-format on
 
 /*!
diff --git a/dnn/src/common/multi_head_attn.cpp b/dnn/src/common/multi_head_attn.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a3219a67483c39fce374e12100972ac6c39d0c82
--- /dev/null
+++ b/dnn/src/common/multi_head_attn.cpp
@@ -0,0 +1,166 @@
+#include "megdnn/basic_types.h"
+#include "megdnn/oprs.h"
+#include "src/common/utils.cuh"
+#include "unroll_macro.h"
+
+#include "src/common/utils.h"
+
+namespace megdnn {
+
+using Param = MultiHeadAttnBase::Param;
+
+void MultiHeadAttnForward::deduce_layout(
+        const TensorLayout& queries, const TensorLayout& keys,
+        const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out,
+        TensorLayout& reserveSpace) {
+    megdnn_assert(
+            queries.ndim == 3,
+            "queries.ndim should be 3[batch, sequence, embeding], but got %zu",
+            queries.ndim);
+    size_t size =
+            get_reservespace_in_bytes(queries, keys, values, wqkv, out, reserveSpace);
+    out = TensorLayout(
+            {queries.shape[0], queries.shape[1], queries.shape[2]}, queries.dtype);
+    reserveSpace = TensorLayout({size}, queries.dtype);
+}
+
+void MultiHeadAttnForward::check_exec(
+        const TensorLayout& queries, const TensorLayout& keys,
+        const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out,
+        const TensorLayout& reserveSpace, size_t workspace_in_bytes) {
+    Param p = param();
+    megdnn_assert_contiguous(queries);
+    megdnn_assert_contiguous(keys);
+    megdnn_assert_contiguous(values);
+    megdnn_assert_contiguous(wqkv);
+    megdnn_assert_contiguous(out);
+    if (p.training)
+        megdnn_assert_contiguous(reserveSpace);
+    auto required_workspace_in_bytes =
+            get_workspace_in_bytes(queries, keys, values, wqkv, out, reserveSpace);
+    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
+
+    megdnn_assert(
+            queries.ndim == 3,
+            "queries.ndim should be 3[batch, sequence, embeding], but got %zu",
+            queries.ndim);
+    megdnn_assert(
+            keys.ndim == 3,
+            "keys.ndim should be 3[batch, sequence, embeding], but got %zu", keys.ndim);
+    megdnn_assert(
+            values.ndim == 3,
+            "values.ndim should be 3[batch, sequence, embeding], but got %zu",
+            values.ndim);
+
+    auto errmsg = [&]() {
+        return megdnn_layout_msg(queries) + ", " + megdnn_layout_msg(keys) + ", " +
+               megdnn_layout_msg(values) + ", " + megdnn_layout_msg(wqkv) + ", " +
+               megdnn_layout_msg(out) + ", " + megdnn_layout_msg(reserveSpace);
+    };
+    megdnn_assert(queries.shape[0] == out.shape[0], "%s", errmsg().c_str());
+    megdnn_assert(keys.shape[0] == values.shape[0], "%s", errmsg().c_str());
+    megdnn_assert(queries.shape[0] == keys.shape[0], "%s", errmsg().c_str());
+    megdnn_assert(queries.shape[1] == out.shape[1], "%s", errmsg().c_str());
+    megdnn_assert(keys.shape[1] == values.shape[1], "%s", errmsg().c_str());
+    megdnn_assert(
+            queries.shape[2] == keys.shape[2] and keys.shape[2] == values.shape[2] and
+                    queries.shape[2] == out.shape[2],
+            "%s", errmsg().c_str());
+}
+
+void MultiHeadAttnBackward::deduce_layout(
+        const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys,
+        const TensorLayout& values, const TensorLayout& wqkv,
+        const TensorLayout& reserveSpace, TensorLayout& dqueries, TensorLayout& dkeys,
+        TensorLayout& dvalues, TensorLayout& dweights) {
+    MEGDNN_MARK_USED_VAR(diff);
+    MEGDNN_MARK_USED_VAR(reserveSpace);
+    dqueries = queries;
+    dkeys = keys;
+    dvalues = values;
+    dweights = wqkv;
+}
+
+void MultiHeadAttnBackward::check_exec(
+        const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys,
+        const TensorLayout& values, const TensorLayout& wqkv,
+        const TensorLayout& reserveSpace, const TensorLayout& dqueries,
+        const TensorLayout& dkeys, const TensorLayout& dvalues,
+        const TensorLayout& dweights, size_t workspace_in_bytes) {
+    Param p = param();
+    megdnn_assert(
+            p.training,
+            "When calling MultiHeadAttn backward, param().training must be true, "
+            "but got false");
+    megdnn_assert_contiguous(diff);
+    megdnn_assert_contiguous(queries);
+    megdnn_assert_contiguous(keys);
+    megdnn_assert_contiguous(values);
+    megdnn_assert_contiguous(wqkv);
+    megdnn_assert_contiguous(dqueries);
+    megdnn_assert_contiguous(dkeys);
+    megdnn_assert_contiguous(dvalues);
+    megdnn_assert_contiguous(dweights);
+    if (p.training)
+        megdnn_assert_contiguous(reserveSpace);
+    auto required_workspace_in_bytes = get_workspace_in_bytes(
+            diff, queries, keys, values, wqkv, reserveSpace, dqueries, dkeys, dvalues,
+            dweights);
+    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
+    megdnn_assert(reserveSpace.total_nr_elems() > 0);
+
+    megdnn_assert(
+            queries.ndim == 3,
+            "queries.ndim should be 3[batch, sequence, embeding], but got %zu",
+            queries.ndim);
+    megdnn_assert(
+            keys.ndim == 3,
+            "keys.ndim should be 3[batch, sequence, embeding], but got %zu", keys.ndim);
+    megdnn_assert(
+            values.ndim == 3,
+            "values.ndim should be 3[batch, sequence, embeding], but got %zu",
+            values.ndim);
+    megdnn_assert(
+            diff.ndim == 3,
+            "diff.ndim should be 3[batch, sequence, embeding], but got %zu", diff.ndim);
+
+    auto errmsg = [&]() {
+        return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(queries) + ", " +
+               megdnn_layout_msg(keys) + ", " + megdnn_layout_msg(values) + ", " +
+               megdnn_layout_msg(wqkv) + ", " + megdnn_layout_msg(reserveSpace) + ", " +
+               megdnn_layout_msg(dqueries) + ", " + megdnn_layout_msg(dkeys) + ", " +
+               megdnn_layout_msg(dvalues) + ", " + megdnn_layout_msg(dweights);
+    };
+
+    auto equal_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) -> bool {
+        if (!(lhs.ndim == rhs.ndim && lhs.dtype == rhs.dtype &&
+              lhs.format == rhs.format))
+            return false;
+        for (size_t i = 0; i < lhs.ndim; ++i) {
+            if (lhs.shape[i] != rhs.shape[i] || lhs.stride[i] != rhs.stride[i]) {
+                return false;
+            }
+        }
+        return true;
+    };
+
+    megdnn_assert(equal_layout(queries, diff), "%s", errmsg().c_str());
+    megdnn_assert(equal_layout(queries, dqueries), "%s", errmsg().c_str());
+    megdnn_assert(equal_layout(keys, dkeys), "%s", errmsg().c_str());
+    megdnn_assert(equal_layout(values, dvalues), "%s", errmsg().c_str());
+    megdnn_assert(equal_layout(wqkv, dweights), "%s", errmsg().c_str());
+
+    megdnn_assert(queries.shape[0] == diff.shape[0], "%s", errmsg().c_str());
+    megdnn_assert(keys.shape[0] == values.shape[0], "%s", errmsg().c_str());
+    megdnn_assert(queries.shape[0] == keys.shape[0], "%s", errmsg().c_str());
+    megdnn_assert(queries.shape[1] == diff.shape[1], "%s", errmsg().c_str());
+    megdnn_assert(keys.shape[1] == values.shape[1], "%s", errmsg().c_str());
+    megdnn_assert(
+            queries.shape[2] == keys.shape[2] and keys.shape[2] == values.shape[2] and
+                    queries.shape[2] == diff.shape[2],
+            "%s", errmsg().c_str());
+}
+
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h
index 541bd67a3d0bad15867ce577cdfed39c806c6733..65591fb9194b9653c0efe0a7e720ea69364d3cf3 100644
--- a/dnn/src/common/opr_trait.h
+++ b/dnn/src/common/opr_trait.h
@@ -1,5 +1,6 @@
 #pragma once
 #include "megdnn/oprs.h"
+#include "megdnn/oprs/nn.h"
 
 #include <cstddef>
 
@@ -147,6 +148,8 @@ DEF(RegionRestrictedConvolutionBackwardFilter, 5, true, false);
 DEF(GroupNormForward, 6, true, true);
 DEF(GroupNormBackward, 8, true, true);
 DEF(MaskedFill, 3, false, true);
+DEF(MultiHeadAttnForward, 6, true, true);
+DEF(MultiHeadAttnBackward, 10, true, true);
 }  // namespace megdnn
 
 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/cudnn_wrapper.cpp b/dnn/src/cuda/cudnn_wrapper.cpp
index b8ff39882b28b4c4b4aa19cc37e01c1a7b936067..a5b93b429c39cc525619fecc1e97dc4405132f1f 100644
--- a/dnn/src/cuda/cudnn_wrapper.cpp
+++ b/dnn/src/cuda/cudnn_wrapper.cpp
@@ -3,12 +3,10 @@
 #include "src/common/utils.h"
 #include "src/cuda/utils.h"
 
-namespace {
-
-using namespace megdnn;
+namespace megdnn {
+namespace cuda {
 
-cudnnDataType_t to_cudnn_dtype(
-        DType type, const param::Convolution::Format format = {}) {
+cudnnDataType_t to_cudnn_dtype(DType type, const param::Convolution::Format format) {
     switch (type.enumv()) {
         case DTypeEnum::Float32:
             return CUDNN_DATA_FLOAT;
@@ -66,8 +64,9 @@ cudnnTensorFormat_t to_cudnn_format(const param::Convolution::Format format) {
             megdnn_assert_internal(0);
     }
 }
+}  // namespace cuda
 
-}  // namespace
+}  // namespace megdnn
 
 namespace megdnn {
 namespace cuda {
@@ -558,6 +557,71 @@ const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr> CudnnAl
 #undef V
 #undef V1
 
+#if CUDNN_VERSION >= 8004
+SeqTensorDesc::~SeqTensorDesc() {
+    cudnn_check(cudnnDestroySeqDataDescriptor(desc));
+}
+SeqTensorDesc::SeqTensorDesc() {
+    cudnnCreateSeqDataDescriptor(&desc);
+}
+
+SeqTensorDesc::SeqTensorDesc(
+        const TensorLayout& layout, const size_t batchSize, const size_t seqLen,
+        const size_t elemSize, const size_t input_order, int* seqArray) {
+    cudnnCreateSeqDataDescriptor(&desc);
+    set(layout, batchSize, seqLen, elemSize, input_order, seqArray);
+}
+
+void SeqTensorDesc::set(
+        const TensorLayout& layout, const size_t batchSize, const size_t seqLen,
+        const size_t elemSize, const size_t input_order, int* seqArray) {
+    switch (input_order) {
+        case 0:  // dimAxes = [Batch, Beam, Time]
+            dimAxes[0] = CUDNN_SEQDATA_BATCH_DIM;
+            dimAxes[1] = CUDNN_SEQDATA_BEAM_DIM;
+            dimAxes[2] = CUDNN_SEQDATA_TIME_DIM;
+            break;
+        case 1:  // dimAxes = [Beam, Batch, Time]
+            dimAxes[0] = CUDNN_SEQDATA_BEAM_DIM;
+            dimAxes[1] = CUDNN_SEQDATA_BATCH_DIM;
+            dimAxes[2] = CUDNN_SEQDATA_TIME_DIM;
+            break;
+        case 2:  // dimAxes = [Batch, Time, Beam]
+            dimAxes[0] = CUDNN_SEQDATA_BATCH_DIM;
+            dimAxes[1] = CUDNN_SEQDATA_TIME_DIM;
+            dimAxes[2] = CUDNN_SEQDATA_BEAM_DIM;
+            break;
+        case 3:  // dimAxes = [Beam, Time, Batch]
+            dimAxes[0] = CUDNN_SEQDATA_BEAM_DIM;
+            dimAxes[1] = CUDNN_SEQDATA_TIME_DIM;
+            dimAxes[2] = CUDNN_SEQDATA_BATCH_DIM;
+            break;
+        case 4:  // dimAxes = [Time, Batch, Beam]
+            dimAxes[0] = CUDNN_SEQDATA_TIME_DIM;
+            dimAxes[1] = CUDNN_SEQDATA_BATCH_DIM;
+            dimAxes[2] = CUDNN_SEQDATA_BEAM_DIM;
+            break;
+        case 5:  // dimAxes = [Time, Beam, Batch]
+            dimAxes[0] = CUDNN_SEQDATA_TIME_DIM;
+            dimAxes[1] = CUDNN_SEQDATA_BEAM_DIM;
+            dimAxes[2] = CUDNN_SEQDATA_BATCH_DIM;
+            break;
+        default:
+            megdnn_throw(ssprintf("ERROR: wrong attention layout %zu", input_order));
+    }
+    dimAxes[3] = CUDNN_SEQDATA_VECT_DIM;
+
+    dim[CUDNN_SEQDATA_BEAM_DIM] = 1;
+    dim[CUDNN_SEQDATA_BATCH_DIM] = batchSize;
+    dim[CUDNN_SEQDATA_TIME_DIM] = seqLen;
+    dim[CUDNN_SEQDATA_VECT_DIM] = elemSize;
+
+    cudnnDataType_t cudnn_dtype = to_cudnn_dtype(layout.dtype);
+    cudnn_check(cudnnSetSeqDataDescriptor(
+            desc, cudnn_dtype, CUDNN_SEQDATA_DIM_COUNT, dim, dimAxes, batchSize,
+            seqArray, NULL));
+}
+#endif
 }  // namespace cuda
 }  // namespace megdnn
 
diff --git a/dnn/src/cuda/cudnn_wrapper.h b/dnn/src/cuda/cudnn_wrapper.h
index c2f3992091114b98626213480dc16f1c0f56c8b4..a1962c836237f89b544dbb52f4451b1360cfdba9 100644
--- a/dnn/src/cuda/cudnn_wrapper.h
+++ b/dnn/src/cuda/cudnn_wrapper.h
@@ -2,12 +2,18 @@
 
 #include <unordered_map>
 #include "megdnn/basic_types.h"
+#include "megdnn/opr_param_defs.h"
 #include "megdnn/oprs/nn.h"
 #include "src/cuda/cudnn_with_check.h"
 
 namespace megdnn {
 namespace cuda {
 
+cudnnDataType_t to_cudnn_dtype(
+        DType type, const param::Convolution::Format format = {});
+
+cudnnTensorFormat_t to_cudnn_format(const param::Convolution::Format format);
+
 /*!
  * \brief get compute_type of convolution operations
  */
@@ -85,6 +91,24 @@ public:
     cudnnConvolutionDescriptor_t desc;
 };
 
+#if CUDNN_VERSION >= 8004
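+//! thin wrapper owning a cudnnSeqDataDescriptor_t, used to describe the
+//! query/key/value/output sequence tensors of multi-head attention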
+class SeqTensorDesc {
+public:
+    int dim[CUDNN_SEQDATA_DIM_COUNT];
+    cudnnSeqDataAxis_t dimAxes[CUDNN_SEQDATA_DIM_COUNT];
+    cudnnSeqDataDescriptor_t desc;
+
+    ~SeqTensorDesc();
+    SeqTensorDesc();
+    SeqTensorDesc(
+            const TensorLayout& layout, const size_t batchSize, const size_t seqLen,
+            const size_t elemSize, const size_t dataLayout, int* seqArray);
+    void set(
+            const TensorLayout& layout, const size_t batchSize, const size_t seqLen,
+            const size_t elemSize, const size_t dataLayout, int* seqArray);
+};
+#endif
+
 class CudnnAlgoPack {
 public:
     //! algorithm attr
diff --git a/dnn/src/cuda/cudnn_wrapper_v8.h b/dnn/src/cuda/cudnn_wrapper_v8.h
index 575f43d0f00ea8e6e5a3ce5f64c6303c9b9a085a..09ec3b89b7033b1a28f05178468cb8db477d427f 100644
--- a/dnn/src/cuda/cudnn_wrapper_v8.h
+++ b/dnn/src/cuda/cudnn_wrapper_v8.h
@@ -55,7 +55,6 @@ void run_conv_bias_act_with_plan(
         const cudnnHandle_t& handle, const cudnn_frontend::ExecutionPlan& plan,
         const TensorND& x, const TensorND& y, const TensorND& w, const TensorND& b,
         const TensorND& z, const Workspace& workspace);
-
 }  // namespace cuda
 }  // namespace megdnn
 
diff --git a/dnn/src/cuda/dropout/opr_impl.h b/dnn/src/cuda/dropout/opr_impl.h
index 5d995eeb6103c475bb9f29c7315db146667627a2..50b0f3f7cc0c0df47e3ea5030b06e0e9094e7fd7 100644
--- a/dnn/src/cuda/dropout/opr_impl.h
+++ b/dnn/src/cuda/dropout/opr_impl.h
@@ -61,6 +61,9 @@ public:
     bool initialized() { return status != nullptr; }
     friend class DropoutForwardImpl;
     friend class DropoutBackwardImpl;
+#if CUDNN_VERSION >= 8004
+    friend class MultiHeadAttnStatus;
+#endif
 };
 
 // similar to RNG operator, dropout operator also have status
diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp
index 01fbfba5612365310c235dd1bedd6dfbf84832e0..e1f397ce7bb13c1744f72e97b229e3324437f698 100644
--- a/dnn/src/cuda/handle_create.cpp
+++ b/dnn/src/cuda/handle_create.cpp
@@ -50,6 +50,7 @@
 #include "src/cuda/matrix_mul/opr_impl.h"
 #include "src/cuda/max_tensor_diff/opr_impl.h"
 #include "src/cuda/mesh_indexing/opr_impl.h"
+#include "src/cuda/multi_head_attn/opr_impl.h"
 #include "src/cuda/norm/opr_impl.h"
 #include "src/cuda/padding/opr_impl.h"
 #include "src/cuda/param_pack/opr_impl.h"
@@ -230,6 +231,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(NormForward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionForward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionBackwardData);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionBackwardFilter);
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(MultiHeadAttnForward);
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(MultiHeadAttnBackward);
 
 template <typename Opr>
 std::unique_ptr<Opr> HandleImpl::create_operator() {
diff --git a/dnn/src/cuda/multi_head_attn/helper.cpp b/dnn/src/cuda/multi_head_attn/helper.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ab685629242deb8f3682214d5afe778fb135d07a
--- /dev/null
+++ b/dnn/src/cuda/multi_head_attn/helper.cpp
@@ -0,0 +1,181 @@
+#include "src/cuda/multi_head_attn/helper.h"
+#if CUDNN_VERSION >= 8004
+
+namespace megdnn {
+namespace cuda {
+
+AuxiliaryArray::~AuxiliaryArray() {
+    if (loWinIdx)
+        free(loWinIdx);
+    if (hiWinIdx)
+        free(hiWinIdx);
+    if (seqQArray)
+        free(seqQArray);
+    if (seqKArray)
+        free(seqKArray);
+    if (devSeqQArray)
+        cuda_check(cudaFree(devSeqQArray));
+    if (devSeqKArray)
+        cuda_check(cudaFree(devSeqKArray));
+}
+
+bool AuxiliaryArray::is_initialized(
+        const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK,
+        bool _attnMask) {
+    if (_batchSize != batchSize or _seqLenQ != seqLenQ or _seqLenK != seqLenK or
+        _attnMask != attnMask or !seqQArray or !seqKArray or !devSeqQArray or
+        !devSeqKArray or !loWinIdx or !hiWinIdx)
+        return false;
+    return true;
+}
+
+void AuxiliaryArray::set(
+        const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK,
+        bool _attnMask) {
+    if (_batchSize == batchSize && _seqLenQ == seqLenQ && _seqLenK == seqLenK &&
+        _attnMask == attnMask)
+        return;
+    else {
+        if (loWinIdx)
+            free(loWinIdx);
+        if (hiWinIdx)
+            free(hiWinIdx);
+        if (seqQArray)
+            free(seqQArray);
+        if (seqKArray)
+            free(seqKArray);
+        if (devSeqQArray)
+            cuda_check(cudaFree(devSeqQArray));
+        if (devSeqKArray)
+            cuda_check(cudaFree(devSeqKArray));
+    }
+
+    seqLenQ = _seqLenQ;
+    seqLenK = _seqLenK;
+    batchSize = _batchSize;
+    attnMask = _attnMask;
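+    // seqQArray/seqKArray hold the per-batch sequence lengths expected by the
+    // cuDNN seq-data descriptors; device copies are made below because
+    // cudnnMultiHeadAttnForward reads them from device memory.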
+    size_t seqQArraySize = 1 * batchSize;
+    size_t seqKArraySize = batchSize;
+    seqQArray = (int*)calloc(seqQArraySize, sizeof(int));
+    seqKArray = (int*)calloc(seqKArraySize, sizeof(int));
+    for (size_t i = 0; i < seqQArraySize; ++i)
+        seqQArray[i] = seqLenQ;
+    for (size_t i = 0; i < seqKArraySize; ++i)
+        seqKArray[i] = seqLenK;
+
+    cuda_check(cudaMalloc((void**)&devSeqQArray, seqQArraySize * sizeof(int)));
+    cuda_check(cudaMalloc((void**)&devSeqKArray, seqKArraySize * sizeof(int)));
+
+    cuda_check(cudaMemcpy(
+            devSeqQArray, seqQArray, seqQArraySize * sizeof(int),
+            cudaMemcpyHostToDevice));
+    cuda_check(cudaMemcpy(
+            devSeqKArray, seqKArray, seqKArraySize * sizeof(int),
+            cudaMemcpyHostToDevice));
+
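+    // loWinIdx/hiWinIdx give, for every query position i, the half-open window
+    // [loWinIdx[i], hiWinIdx[i]) of key positions it may attend to: a causal
+    // mask when attnMask is set, otherwise the full key sequence.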
+    loWinIdx = (int*)calloc(seqLenQ, sizeof(int));
+    hiWinIdx = (int*)calloc(seqLenQ, sizeof(int));
+    for (size_t i = 0; i < seqLenQ; ++i) {
+        loWinIdx[i] = 0;
+        if (attnMask)
+            hiWinIdx[i] = i + 1;
+        else
+            hiWinIdx[i] = seqLenK;
+    }
+}
+
+void MultiHeadAttnStatus::set(
+        cudnnHandle_t handle, const Param& p, const TensorLayout& q,
+        const TensorLayout& k, const TensorLayout& v) {
+    float attn_prob = p.training ? p.attn_prob : 0.f;
+    float out_prob = p.training ? p.out_prob : 0.f;
+    if (!attn_dropout_status.initialized())
+        attn_dropout_status.set(handle, p.seed, attn_prob);
+    if (!out_dropout_status.initialized())
+        out_dropout_status.set(handle, p.seed, out_prob);
+
+    if (attn_dropout_status.drop_prob != attn_prob) {
+        attn_dropout_status.drop_prob = attn_prob;
+        attn_dropout_status.restore_desc(handle);
+    }
+    if (out_dropout_status.drop_prob != out_prob) {
+        out_dropout_status.drop_prob = out_prob;
+        out_dropout_status.restore_desc(handle);
+    }
+    batchSize = q.shape[0];
+    seqLenQ = q.shape[1];
+    seqLenK = k.shape[1];
+    qSize = q.shape[2];
+    kSize = k.shape[2];
+    vSize = v.shape[2];
+    numHeads = p.num_heads;
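+    // A projection size of 0 disables that projection in cuDNN; enabled
+    // projections map q/k/v to per-head sizes and the output back to qSize.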
+    qProjSize = p.enable_qproj ? qSize / numHeads : 0;
+    kProjSize = p.enable_kproj ? kSize / numHeads : 0;
+    vProjSize = p.enable_vproj ? vSize / numHeads : 0;
+    oProjSize = p.enable_oproj ? qSize : 0;
+    attnMask = p.attn_mask;
+    cudnnDataType_t cudnn_dtype = to_cudnn_dtype(q.dtype);
+    auto flag = CUDNN_ATTN_QUERYMAP_ONE_TO_ONE;
+    if (p.bias)
+        flag = flag | CUDNN_ATTN_ENABLE_PROJ_BIASES;
+#if CUDNN_VERSION < 8600
+    // TODO: when CUDNN_VERSION < 8600 and out dropout > 0.0, we need to fall back to
+    // the proxy cuda implementation.
+    cudnn_check(cudnnSetAttnDescriptor(
+            attn_desc, flag, numHeads, p.sm_scaler, cudnn_dtype, cudnn_dtype,
+            CUDNN_DEFAULT_MATH, attn_dropout_status.desc.desc, NULL, qSize, kSize,
+            vSize, qProjSize, kProjSize, vProjSize, oProjSize, seqLenQ, seqLenK,
+            batchSize, 1));
+#else
+    cudnn_check(cudnnSetAttnDescriptor(
+            attn_desc, flag, numHeads, p.sm_scaler, cudnn_dtype, cudnn_dtype,
+            CUDNN_DEFAULT_MATH, attn_dropout_status.desc.desc,
+            out_dropout_status.desc.desc, qSize, kSize, vSize, qProjSize, kProjSize,
+            vProjSize, oProjSize, seqLenQ, seqLenK, batchSize, 1));
+#endif
+
+    auxArray.set(batchSize, seqLenQ, seqLenK, p.attn_mask);
+
+    if (p.training)
+        cudnnGetMultiHeadAttnBuffers(
+                handle, attn_desc, &sizeWeights, &sizeWkspace, &sizeReserve);
+    else {
+        cudnnGetMultiHeadAttnBuffers(
+                handle, attn_desc, &sizeWeights, &sizeWkspace, NULL);
+        sizeReserve = 0;
+    }
+}
+
+bool MultiHeadAttnStatus::is_initialized(
+        const Param& p, const TensorLayout& q, const TensorLayout& k,
+        const TensorLayout& v) {
+    float attn_prob = p.training ? p.attn_prob : 0.f;
+    float out_prob = p.training ? p.out_prob : 0.f;
+    if (!attn_dropout_status.initialized() or !out_dropout_status.initialized() or
+        attn_dropout_status.drop_prob != attn_prob or
+        out_dropout_status.drop_prob != out_prob)
+        return false;
+    if (q.shape[0] != batchSize or q.shape[1] != seqLenQ or k.shape[1] != seqLenK or
+        q.shape[2] != qSize or k.shape[2] != kSize or v.shape[2] != vSize or
+        attnMask != p.attn_mask or numHeads != p.num_heads) {
+        return false;
+    }
+    if ((p.enable_qproj && (qProjSize == 0 or qProjSize != qSize / p.num_heads)) or
+        (p.enable_kproj && (kProjSize == 0 or kProjSize != kSize / p.num_heads)) or
+        (p.enable_vproj && (vProjSize == 0 or vProjSize != vSize / p.num_heads)) or
+        (p.enable_oproj && (oProjSize == 0 or oProjSize != q.shape[2])))
+        return false;
+    if ((!p.enable_qproj && qProjSize != 0) or (!p.enable_kproj && kProjSize != 0) or
+        (!p.enable_vproj && vProjSize != 0) or (!p.enable_oproj && oProjSize != 0))
+        return false;
+    if (!auxArray.is_initialized(batchSize, seqLenQ, seqLenK, attnMask))
+        return false;
+    if (p.training and sizeReserve == 0)
+        return false;
+    return true;
+}
+
+}  // namespace cuda
+}  // namespace megdnn
+#endif
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/multi_head_attn/helper.h b/dnn/src/cuda/multi_head_attn/helper.h
new file mode 100644
index 0000000000000000000000000000000000000000..bdb37005c6c82519b349a3c277a702f9feedc851
--- /dev/null
+++ b/dnn/src/cuda/multi_head_attn/helper.h
@@ -0,0 +1,80 @@
+#pragma once
+#include "src/cuda/cudnn_wrapper.h"
+#if CUDNN_VERSION >= 8004
+#include "megdnn/basic_types.h"
+#include "megdnn/oprs/nn.h"
+#include "src/common/algo_chooser.h"
+#include "src/common/utils.h"
+#include "src/cuda/dropout/opr_impl.h"
+#include "src/cuda/handle.h"
+
+namespace megdnn {
+namespace cuda {
+
+struct AuxiliaryArray {
+public:
+    int* seqQArray = nullptr;
+    int* seqKArray = nullptr;
+    int* devSeqQArray = nullptr;
+    int* devSeqKArray = nullptr;
+    int* loWinIdx = nullptr;
+    int* hiWinIdx = nullptr;
+    size_t seqLenQ = 0;
+    size_t seqLenK = 0;
+    size_t batchSize = 0;
+    bool attnMask = 0;
+    ~AuxiliaryArray();
+    void set(
+            const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK,
+            bool _attnMask);
+    bool is_initialized(
+            const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK,
+            bool _attnMask);
+};
+
+using Param = megdnn::MultiHeadAttn::Param;
+
+class MultiHeadAttnStatus {
+    DropoutStatus attn_dropout_status;
+    DropoutStatus out_dropout_status;
+
+    cudnnAttnDescriptor_t attn_desc;
+
+    AuxiliaryArray auxArray;
+
+    size_t numHeads = 0;
+    size_t batchSize = 0;
+    size_t seqLenQ = 0;
+    size_t seqLenK = 0;
+    size_t qSize = 0;
+    size_t kSize = 0;
+    size_t vSize = 0;
+    size_t qProjSize = 0;
+    size_t kProjSize = 0;
+    size_t vProjSize = 0;
+    size_t oProjSize = 0;
+    bool attnMask = 0;
+
+    size_t sizeWeights = 0;
+    size_t sizeWkspace = 0;
+    size_t sizeReserve = 0;
+
+public:
+    MultiHeadAttnStatus() { cudnn_check(cudnnCreateAttnDescriptor(&attn_desc)); }
+    ~MultiHeadAttnStatus() { cudnn_check(cudnnDestroyAttnDescriptor(attn_desc)); }
+
+private:
+    void set(
+            cudnnHandle_t handle, const Param& p, const TensorLayout& q,
+            const TensorLayout& k, const TensorLayout& v);
+    bool is_initialized(
+            const Param& p, const TensorLayout& q, const TensorLayout& k,
+            const TensorLayout& v);
+    friend class MultiHeadAttnBase;
+    friend class MultiHeadAttnForwardImpl;
+    friend class MultiHeadAttnBackwardImpl;
+};
+}  // namespace cuda
+}  // namespace megdnn
+#endif
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/multi_head_attn/opr_impl.cpp b/dnn/src/cuda/multi_head_attn/opr_impl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f4125cdd18db020a58fb5ca90006befc10a26e31
--- /dev/null
+++ b/dnn/src/cuda/multi_head_attn/opr_impl.cpp
@@ -0,0 +1,241 @@
+#include "src/cuda/multi_head_attn/opr_impl.h"
+#include "src/common/utils.cuh"
+#include "src/cuda/utils.cuh"
+#include "src/cuda/utils.h"
+
+namespace megdnn {
+namespace cuda {
+
+void MultiHeadAttnForwardImpl::deduce_layout(
+        const TensorLayout& queries, const TensorLayout& keys,
+        const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out,
+        TensorLayout& reserveSpace) {
+#if CUDNN_VERSION < 8004
+    // TODO: when CUDNN_VERSION < 8004, we need to fall back to the proxy cuda implementation.
+    MEGDNN_MARK_USED_VAR(queries);
+    MEGDNN_MARK_USED_VAR(keys);
+    MEGDNN_MARK_USED_VAR(values);
+    MEGDNN_MARK_USED_VAR(wqkv);
+    MEGDNN_MARK_USED_VAR(out);
+    MEGDNN_MARK_USED_VAR(reserveSpace);
+    return;
+#else
+    MEGDNN_MARK_USED_VAR(keys);
+    MEGDNN_MARK_USED_VAR(wqkv);
+    megdnn_assert(
+            queries.ndim == 3,
+            "queries.ndim should be 3[batch, sequence, embeding], but got %zu",
+            queries.ndim);
+
+    if (!desc_status.is_initialized(param(), queries, keys, values)) {
+        desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values);
+
+        out = TensorLayout(
+                TensorShape{queries.shape[0], queries.shape[1], queries.shape[2]},
+                queries.dtype);
+        reserveSpace =
+                TensorLayout(TensorShape{desc_status.sizeReserve}, queries.dtype);
+    }
+#endif
+}
+
+size_t MultiHeadAttnForwardImpl::get_workspace_in_bytes(
+        const TensorLayout& queries, const TensorLayout& keys,
+        const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out,
+        const TensorLayout& reserveSpace) {
+#if CUDNN_VERSION < 8004
+    // TODO: when CUDNN_VERSION < 8004, we need to fall back to the proxy cuda implementation.
+    MEGDNN_MARK_USED_VAR(queries);
+    MEGDNN_MARK_USED_VAR(keys);
+    MEGDNN_MARK_USED_VAR(values);
+    MEGDNN_MARK_USED_VAR(wqkv);
+    MEGDNN_MARK_USED_VAR(out);
+    MEGDNN_MARK_USED_VAR(reserveSpace);
+    return 0;
+#else
+    MEGDNN_MARK_USED_VAR(wqkv);
+    MEGDNN_MARK_USED_VAR(out);
+    MEGDNN_MARK_USED_VAR(reserveSpace);
+
+    if (!desc_status.is_initialized(param(), queries, keys, values))
+        desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values);
+
+    return desc_status.sizeWkspace;
+#endif
+}
+
+size_t MultiHeadAttnForwardImpl::get_reservespace_in_bytes(
+        const TensorLayout& queries, const TensorLayout& keys,
+        const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out,
+        const TensorLayout& reserveSpace) {
+#if CUDNN_VERSION < 8004
+    // TODO: when CUDNN_VERSION < 8004, we need to fall back to the proxy cuda implementation.
+    MEGDNN_MARK_USED_VAR(queries);
+    MEGDNN_MARK_USED_VAR(keys);
+    MEGDNN_MARK_USED_VAR(values);
+    MEGDNN_MARK_USED_VAR(wqkv);
+    MEGDNN_MARK_USED_VAR(out);
+    MEGDNN_MARK_USED_VAR(reserveSpace);
+    return 0;
+#else
+    MEGDNN_MARK_USED_VAR(wqkv);
+    MEGDNN_MARK_USED_VAR(out);
+    MEGDNN_MARK_USED_VAR(reserveSpace);
+    if (!desc_status.is_initialized(param(), queries, keys, values))
+        desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values);
+    return desc_status.sizeReserve;
+#endif
+}
+void MultiHeadAttnForwardImpl::exec(
+        _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
+        _megdnn_tensor_in wqkv, _megdnn_tensor_out out, _megdnn_tensor_out reserveSpace,
+        _megdnn_workspace workspace) {
+#if CUDNN_VERSION < 8004
+    // TODO: when CUDNN_VERSION < 8004, we need to fall back to the proxy cuda implementation.
+    MEGDNN_MARK_USED_VAR(queries);
+    MEGDNN_MARK_USED_VAR(keys);
+    MEGDNN_MARK_USED_VAR(values);
+    MEGDNN_MARK_USED_VAR(wqkv);
+    MEGDNN_MARK_USED_VAR(out);
+    MEGDNN_MARK_USED_VAR(reserveSpace);
+    MEGDNN_MARK_USED_VAR(workspace);
+    megdnn_throw(
+            "The cudnn version is lower than 8.0.4. Please upgrade the cudnn version.");
+#else
+    check_exec(
+            queries.layout, keys.layout, values.layout, wqkv.layout, out.layout,
+            reserveSpace.layout, workspace.size);
+    auto p = param();
+
+    if (!desc_status.is_initialized(p, queries.layout, keys.layout, values.layout))
+        desc_status.set(
+                cudnn_handle(this->handle()), p, queries.layout, keys.layout,
+                values.layout);
+
+    SeqTensorDesc q{queries.layout,      desc_status.batchSize,
+                    desc_status.seqLenQ, desc_status.qSize,
+                    p.input_order,       desc_status.auxArray.seqQArray};
+    SeqTensorDesc o{out.layout,          desc_status.batchSize,
+                    desc_status.seqLenQ, desc_status.oProjSize,
+                    p.input_order,       desc_status.auxArray.seqQArray};
+    SeqTensorDesc k{keys.layout,         desc_status.batchSize,
+                    desc_status.seqLenK, desc_status.kSize,
+                    p.input_order,       desc_status.auxArray.seqKArray};
+    SeqTensorDesc v{values.layout,       desc_status.batchSize,
+                    desc_status.seqLenK, desc_status.vSize,
+                    p.input_order,       desc_status.auxArray.seqKArray};
+
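+    // currIdx = -1 lets cuDNN process every query time-step in one call; the
+    // residual input and reserveSpace are only passed when enabled by param().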
+    cudnn_check(cudnnMultiHeadAttnForward(
+            cudnn_handle(this->handle()), desc_status.attn_desc, -1,
+            desc_status.auxArray.loWinIdx, desc_status.auxArray.hiWinIdx,
+            desc_status.auxArray.devSeqQArray, desc_status.auxArray.devSeqKArray,
+            q.desc, queries.raw_ptr(), p.reslink ? queries.raw_ptr() : NULL, k.desc,
+            keys.raw_ptr(), v.desc, values.raw_ptr(), o.desc, out.raw_ptr(),
+            desc_status.sizeWeights,
+            desc_status.sizeWeights > 0 ? wqkv.raw_ptr() : NULL,
+            desc_status.sizeWkspace, workspace.raw_ptr,
+            p.training ? desc_status.sizeReserve : 0,
+            p.training ? reserveSpace.raw_ptr() : NULL));
+#endif
+}
+
+void MultiHeadAttnBackwardImpl::exec(
+        _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
+        _megdnn_tensor_in values, _megdnn_tensor_in wqkv,
+        _megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries,
+        _megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues,
+        _megdnn_tensor_out dweights, _megdnn_workspace workspace) {
+#if CUDNN_VERSION < 8004
+    // TODO: when CUDNN_VERSION < 8004 and param().bias == true, we need to fall back
+    // to the proxy cuda implementation.
+    MEGDNN_MARK_USED_VAR(diff);
+    MEGDNN_MARK_USED_VAR(queries);
+    MEGDNN_MARK_USED_VAR(keys);
+    MEGDNN_MARK_USED_VAR(values);
+    MEGDNN_MARK_USED_VAR(wqkv);
+    MEGDNN_MARK_USED_VAR(reserveSpace);
+    MEGDNN_MARK_USED_VAR(dqueries);
+    MEGDNN_MARK_USED_VAR(dkeys);
+    MEGDNN_MARK_USED_VAR(dvalues);
+    MEGDNN_MARK_USED_VAR(dweights);
+    megdnn_throw(
+            "The cudnn version is lower than 8.0.4. Please upgrade the cudnn version.");
+#else
+#if CUDNN_VERSION < 8600
+    megdnn_assert(
+            !param().bias,
+            "If the cudnn version is lower than 8.6.0, param().bias must be false, "
+            "but got true, because there is an error in the "
+            "dbias result during the backward calculation.");
+#endif
+
+    check_exec(
+            diff.layout, queries.layout, keys.layout, values.layout, wqkv.layout,
+            reserveSpace.layout, dqueries.layout, dkeys.layout, dvalues.layout,
+            dweights.layout, workspace.size);
+    auto p = param();
+
+    if (!desc_status.is_initialized(p, queries.layout, keys.layout, values.layout))
+        desc_status.set(
+                cudnn_handle(this->handle()), p, queries.layout, keys.layout,
+                values.layout);
+
+    SeqTensorDesc q{queries.layout,      desc_status.batchSize,
+                    desc_status.seqLenQ, desc_status.qSize,
+                    p.input_order,       desc_status.auxArray.seqQArray};
+    SeqTensorDesc d{diff.layout,         desc_status.batchSize,
+                    desc_status.seqLenQ, desc_status.oProjSize,
+                    p.input_order,       desc_status.auxArray.seqQArray};
+    SeqTensorDesc k{keys.layout,         desc_status.batchSize,
+                    desc_status.seqLenK, desc_status.kSize,
+                    p.input_order,       desc_status.auxArray.seqKArray};
+    SeqTensorDesc v{values.layout,       desc_status.batchSize,
+                    desc_status.seqLenK, desc_status.vSize,
+                    p.input_order,       desc_status.auxArray.seqKArray};
+
+    cudnn_check(cudnnMultiHeadAttnBackwardData(
+            cudnn_handle(this->handle()), desc_status.attn_desc,
+            desc_status.auxArray.loWinIdx, desc_status.auxArray.hiWinIdx,
+            desc_status.auxArray.devSeqQArray, desc_status.auxArray.devSeqKArray,
+            d.desc, diff.raw_ptr(), q.desc, dqueries.raw_ptr(), queries.raw_ptr(),
+            k.desc, dkeys.raw_ptr(), keys.raw_ptr(), v.desc, dvalues.raw_ptr(),
+            values.raw_ptr(), desc_status.sizeWeights,
+            desc_status.sizeWeights > 0 ? wqkv.raw_ptr() : NULL,
+            desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve,
+            reserveSpace.raw_ptr()));
+
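+    // CUDNN_WGRAD_MODE_ADD accumulates into dweights, so clear it first.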
+    cuda_check(cudaMemset(dweights.raw_ptr(), 0, desc_status.sizeWeights));
+#if CUDNN_VERSION < 8600
+    cuda_check(cudaDeviceSynchronize());
+#endif
+    cudnn_check(cudnnMultiHeadAttnBackwardWeights(
+            cudnn_handle(this->handle()), desc_status.attn_desc, CUDNN_WGRAD_MODE_ADD,
+            q.desc, queries.raw_ptr(), k.desc, keys.raw_ptr(), v.desc, values.raw_ptr(),
+            d.desc, diff.raw_ptr(), desc_status.sizeWeights,
+            desc_status.sizeWeights > 0 ? wqkv.raw_ptr() : NULL,
+            desc_status.sizeWeights > 0 ? dweights.raw_ptr() : NULL,
+            desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve,
+            reserveSpace.raw_ptr()));
+#endif
+}
+size_t MultiHeadAttnBackwardImpl::get_workspace_in_bytes(
+        const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys,
+        const TensorLayout& values, const TensorLayout& wqkv,
+        const TensorLayout& reserveSpace, const TensorLayout& dqueries,
+        const TensorLayout& dkeys, const TensorLayout& dvalues,
+        const TensorLayout& dweights) {
+    MEGDNN_MARK_USED_VAR(diff);
+    MEGDNN_MARK_USED_VAR(queries);
+    MEGDNN_MARK_USED_VAR(keys);
+    MEGDNN_MARK_USED_VAR(values);
+    MEGDNN_MARK_USED_VAR(wqkv);
+    MEGDNN_MARK_USED_VAR(reserveSpace);
+    MEGDNN_MARK_USED_VAR(dqueries);
+    MEGDNN_MARK_USED_VAR(dkeys);
+    MEGDNN_MARK_USED_VAR(dvalues);
+    MEGDNN_MARK_USED_VAR(dweights);
+    return 0;
+}
+}  // namespace cuda
+}  // namespace megdnn
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/multi_head_attn/opr_impl.h b/dnn/src/cuda/multi_head_attn/opr_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..4596bd37175c50469e031701ef799cf64c871e84
--- /dev/null
+++ b/dnn/src/cuda/multi_head_attn/opr_impl.h
@@ -0,0 +1,59 @@
+#pragma once
+#include "megdnn/handle.h"
+#include "megdnn/oprs.h"
+#include "src/common/reduce_helper.h"
+#include "src/cuda/cudnn_wrapper.h"
+#include "src/cuda/handle.h"
+#include "src/cuda/multi_head_attn/helper.h"
+#include "src/cuda/utils.h"
+
+namespace megdnn {
+namespace cuda {
+
+class MultiHeadAttnForwardImpl final : public MultiHeadAttnForward {
+public:
+    using MultiHeadAttnForward::MultiHeadAttnForward;
+#if CUDNN_VERSION >= 8004
+    MultiHeadAttnStatus desc_status;
+#endif
+
+    void exec(
+            _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
+            _megdnn_tensor_in wqkv, _megdnn_tensor_out out,
+            _megdnn_tensor_out reserveSpace, _megdnn_workspace workspace) override;
+    void deduce_layout(
+            const TensorLayout& queries, const TensorLayout& keys,
+            const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out,
+            TensorLayout& reserveSpace);
+    size_t get_reservespace_in_bytes(
+            const TensorLayout& queries, const TensorLayout& keys,
+            const TensorLayout& values, const TensorLayout& wqkv,
+            const TensorLayout& out, const TensorLayout& reserveSpace) override;
+    size_t get_workspace_in_bytes(
+            const TensorLayout& queries, const TensorLayout& keys,
+            const TensorLayout& values, const TensorLayout& wqkv,
+            const TensorLayout& out, const TensorLayout& reserveSpace) override;
+};
+
+class MultiHeadAttnBackwardImpl final : public MultiHeadAttnBackward {
+public:
+    using MultiHeadAttnBackward::MultiHeadAttnBackward;
+#if CUDNN_VERSION >= 8004
+    MultiHeadAttnStatus desc_status;
+#endif
+    void exec(
+            _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
+            _megdnn_tensor_in values, _megdnn_tensor_in wqkv,
+            _megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries,
+            _megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues,
+            _megdnn_tensor_out dweights, _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(
+            const TensorLayout& diff, const TensorLayout& queries,
+            const TensorLayout& keys, const TensorLayout& values,
+            const TensorLayout& wqkv, const TensorLayout& reserveSpace,
+            const TensorLayout& dqueries, const TensorLayout& dkeys,
+            const TensorLayout& dvalues, const TensorLayout& dweights) override;
+};
+}  // namespace cuda
+}  // namespace megdnn
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp
index bc1db909c00094eb3c8457a7237d33ff512fae65..dc741f937669be0a9b88465d2f6a1660dcae4fa4 100644
--- a/dnn/src/naive/handle.cpp
+++ b/dnn/src/naive/handle.cpp
@@ -54,6 +54,7 @@
 #include "src/naive/matrix_mul/opr_impl.h"
 #include "src/naive/max_tensor_diff/opr_impl.h"
 #include "src/naive/mesh_indexing/opr_impl.h"
+#include "src/naive/multi_head_attn/opr_impl.h"
 #include "src/naive/norm/opr_impl.h"
 #include "src/naive/padding/opr_impl.h"
 #include "src/naive/param_pack/opr_impl.h"
diff --git a/dnn/src/naive/multi_head_attn/opr_impl.cpp b/dnn/src/naive/multi_head_attn/opr_impl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..773810601f459c873f86e25f7db6add8c4af6d51
--- /dev/null
+++ b/dnn/src/naive/multi_head_attn/opr_impl.cpp
@@ -0,0 +1,56 @@
+#include "src/naive/multi_head_attn/opr_impl.h"
+#include "megdnn/oprs/linalg.h"
+#include "src/common/utils.cuh"
+
+namespace megdnn {
+namespace naive {
+
+using Param = MultiHeadAttnBase::Param;
+
+size_t MultiHeadAttnForwardImpl::get_workspace_in_bytes(
+        const TensorLayout& queries, const TensorLayout& keys,
+        const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out,
+        const TensorLayout& reserveSpace) {
+    MEGDNN_MARK_USED_VAR(queries);
+    MEGDNN_MARK_USED_VAR(keys);
+    MEGDNN_MARK_USED_VAR(values);
+    MEGDNN_MARK_USED_VAR(wqkv);
+    MEGDNN_MARK_USED_VAR(out);
+    MEGDNN_MARK_USED_VAR(reserveSpace);
+    megdnn_throw("unsupported naive multiheadattn forward\n");
+}
+
+void MultiHeadAttnForwardImpl::exec(
+        _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
+        _megdnn_tensor_in wqkv, _megdnn_tensor_out out, _megdnn_tensor_out reserveSpace,
+        _megdnn_workspace workspace) {
+    MEGDNN_MARK_USED_VAR(queries);
+    MEGDNN_MARK_USED_VAR(keys);
+    MEGDNN_MARK_USED_VAR(values);
+    MEGDNN_MARK_USED_VAR(wqkv);
+    MEGDNN_MARK_USED_VAR(out);
+    MEGDNN_MARK_USED_VAR(reserveSpace);
+    check_exec(
+            queries.layout, keys.layout, values.layout, wqkv.layout, out.layout,
+            reserveSpace.layout, workspace.size);
+
+    megdnn_throw("unsupported naive multiheadattn forward\n");
+}
+
+void MultiHeadAttnBackwardImpl::exec(
+        _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
+        _megdnn_tensor_in values, _megdnn_tensor_in wqkv,
+        _megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries,
+        _megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues,
+        _megdnn_tensor_out dweights, _megdnn_workspace workspace) {
+    check_exec(
+            diff.layout, queries.layout, keys.layout, values.layout, wqkv.layout,
+            reserveSpace.layout, dqueries.layout, dkeys.layout, dvalues.layout,
+            dweights.layout, workspace.size);
+
+    megdnn_throw("unsupported naive multiheadattn backward\n");
+}
+
+}  // namespace naive
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/naive/multi_head_attn/opr_impl.h b/dnn/src/naive/multi_head_attn/opr_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..5fb1cba9120b545b1167757045cd55e37942e16a
--- /dev/null
+++ b/dnn/src/naive/multi_head_attn/opr_impl.h
@@ -0,0 +1,55 @@
+#pragma once
+#include <memory>
+#include "megdnn/oprs.h"
+#include "megdnn/oprs/cv.h"
+#include "megdnn/oprs/general.h"
+#include "megdnn/oprs/linalg.h"
+#include "megdnn/oprs/nn.h"
+
+namespace megdnn {
+namespace naive {
+
+class MultiHeadAttnForwardImpl final : public MultiHeadAttnForward {
+public:
+    using MultiHeadAttnForward::MultiHeadAttnForward;
+    void exec(
+            _megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
+            _megdnn_tensor_in wqkv, _megdnn_tensor_out out,
+            _megdnn_tensor_out reserveSpace, _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(
+            const TensorLayout& queries, const TensorLayout& keys,
+            const TensorLayout& values, const TensorLayout& wqkv,
+            const TensorLayout& out, const TensorLayout& reserveSpace) override;
+    size_t get_reservespace_in_bytes(
+            const TensorLayout& /*queries*/, const TensorLayout& /*keys*/,
+            const TensorLayout& /*values*/, const TensorLayout& /*wqkv*/,
+            const TensorLayout& /*out*/,
+            const TensorLayout& /*reserveSpace*/) override {
+        return 0;
+    }
+};
+
+class MultiHeadAttnBackwardImpl final : public MultiHeadAttnBackward {
+public:
+    using MultiHeadAttnBackward::MultiHeadAttnBackward;
+    void exec(
+            _megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
+            _megdnn_tensor_in values, _megdnn_tensor_in wqkv,
+            _megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries,
+            _megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues,
+            _megdnn_tensor_out dweights, _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(
+            const TensorLayout& /*diff*/, const TensorLayout& /*queries*/,
+            const TensorLayout& /*keys*/, const TensorLayout& /*values*/,
+            const TensorLayout& /*wqkv*/, const TensorLayout& /*reserveSpace*/,
+            const TensorLayout& /*dqueries*/, const TensorLayout& /*dkeys*/,
+            const TensorLayout& /*dvalues*/,
+            const TensorLayout& /*dweights*/) override {
+        return 0;
+    }
+};
+
+}  // namespace naive
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 22a780f8e103d21bd5b201fd72310185f3037804..4bed60525e75c0a61cc314b7432641b1f3ab3ee5 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -35,7 +35,7 @@ from ..core.tensor.utils import (
     subgraph,
     subgraph_fn,
 )
-from ..device import get_default_device
+from ..device import get_cudnn_version, get_default_device, is_cuda_available
 from ..distributed import WORLD, is_distributed
 from ..jit import exclude_from_trace
 from ..logger import get_logger
@@ -104,6 +104,7 @@ __all__ = [
     "warp_perspective",
     "pixel_shuffle",
     "region_restricted_conv",
+    "multi_head_attention",
 ]
 
 
@@ -1053,7 +1054,7 @@ def instance_norm(
     r"""Applies instance normalization to the input.
 
     Refer to :class:`~.InstanceNorm` for more information.
-    
+
     Args:
         inp: input tensor.
         affine: whether to use learnable affine parameters (weight, bias)
@@ -1083,7 +1084,7 @@ def group_norm(
     r"""Applies group normalization to the input.
 
     Refer to :class:`~.GroupNorm` for more information.
-    
+
     Args:
         inp: input tensor.
         num_groups: number of groups to separate the channels into
@@ -2052,7 +2053,82 @@ def region_restricted_conv(
     return output
 
 
-from .quantized import conv_bias_activation  # isort:skip
+def multi_head_attention(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    embed_dim: int,
+    num_heads: int,
+    attn_drop: float,
+    out_drop: float,
+    io_weight_bias: Optional[Tensor],
+    bias: bool = False,
+    reslink: bool = False,
+    training: bool = True,
+    attn_mask: bool = False,
+    enable_qproj: bool = True,
+    enable_kproj: bool = True,
+    enable_vproj: bool = True,
+    enable_oproj: bool = True,
+):
+    r"""Allows the model to jointly attend to information
+    from different representation subspaces.
+    See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
+
+    .. math::
+        \text{MultiHeadAttn}\big(q, K, V, W_Q, W_K, W_V, W_O\big) = \sum^{nHeads-1}_{i=0}W_{O,i}h_i
+
+    where :math:`h_i=W_{V,i}V \text{Softmax}\Big( \text{smScaler} \cdot K^TW^T_{K,i}W_{Q,i}q \Big),\text{for }i\text{ = 0 ... nHeads-1}`.
+
+    See :class:`~.module.MultiHeadAttn` for more details.
+
+    Note: This API is experimental and may change. Currently only the cuda platform is supported. If the cudnn version is >= 8.6.0, the results are fully correct; if the cudnn version is >= 8.0.4 but < 8.6.0 and bias is enabled, only the dbias result computed in the backward pass is incorrect (without bias, forward and backward are correct); if the cudnn version is lower than 8.0.4, this operator is not supported.
+
+    Args:
+        query, key, value: map a query and a set of key-value pairs to an output.
+            See "Attention Is All You Need" for more details.
+        embed_dim: total dimension of the model.
+        num_heads: parallel attention heads.
+        attn_drop: probability of an element to be zeroed in the attention matrix.
+        out_drop: probability of an element to be zeroed in the final output.
+        io_weight_bias: input/output projection weight/bias packed into one tensor, as required by the cudnn api.
+        bias: whether io_weight_bias contains bias terms, used for the cudnn api.
+        reslink: add the input query to the final output.
+        training: dropout is applied only if ``True``.
+        attn_mask: whether to add a mask to the attention matrix.
+            When enabled, the upper right triangle of the mask is -inf, and the diagonal and lower left triangle are all 0.
+            Default: ``False``
+        enable_qproj: enable query weight projection. Default: ``True``.
+        enable_kproj: enable key weight projection. Default: ``True``.
+        enable_vproj: enable value weight projection. Default: ``True``.
+        enable_oproj: enable output weight projection. Default: ``True``.
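+
+    Returns:
+        the attention output, with the same shape as ``query``; the auxiliary
+        ``reserveSpace`` tensor produced by the kernel is not returned.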
+    """
+
+    head_dim = embed_dim // num_heads
+    smScaler = head_dim ** -0.5
+
+    op = builtin.MultiHeadAttn(
+        num_heads=num_heads,
+        sm_scaler=smScaler,
+        attn_prob=attn_drop,
+        out_prob=out_drop,
+        reslink=reslink,
+        training=training,
+        input_order=0,
+        seed=_get_global_rng_seed(),
+        bias=bias,
+        attn_mask=attn_mask,
+        enable_qproj=enable_qproj,
+        enable_kproj=enable_kproj,
+        enable_vproj=enable_vproj,
+        enable_oproj=enable_oproj,
+    )
+
+    out, reserveSpace = apply(op, query, key, value, io_weight_bias)
+    return out
+
+
 from .loss import *  # isort:skip
 from .metric import *  # isort:skip
 from .vision import *  # isort:skip
+from .quantized import conv_bias_activation  # isort:skip
diff --git a/imperative/python/megengine/module/__init__.py b/imperative/python/megengine/module/__init__.py
index 985f4275e98cbca8b789b82ffe8ebe7468d5317c..89e1cbd0201ae285a65cddcd70ec6ae9aa4feafa 100644
--- a/imperative/python/megengine/module/__init__.py
+++ b/imperative/python/megengine/module/__init__.py
@@ -27,6 +27,7 @@ from .identity import Identity
 from .linear import Linear
 from .lrn import LocalResponseNorm
 from .module import Module
+from .multiheadattn import MultiHeadAttention
 from .normalization import GeneralNorm, GroupNorm, InstanceNorm, LayerNorm
 from .padding import Pad
 from .pixel_shuffle import PixelShuffle
diff --git a/imperative/python/megengine/module/activation.py b/imperative/python/megengine/module/activation.py
index 5f1c7d092daca74cbe20db0393e73b1bb9d94d1e..9de1b4b44a1419eed05205a5e8336edd3a1ba8a3 100644
--- a/imperative/python/megengine/module/activation.py
+++ b/imperative/python/megengine/module/activation.py
@@ -3,6 +3,7 @@ import numpy as np
 
 from ..functional import gelu, leaky_relu, prelu, relu, sigmoid, silu, softmax
 from ..tensor import Parameter
+from .init import ones_, zeros_
 from .module import Module
 
 
diff --git a/imperative/python/megengine/module/multiheadattn.py b/imperative/python/megengine/module/multiheadattn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad80b7d62e1310578d7fd32a58aa032c1729a5cb
--- /dev/null
+++ b/imperative/python/megengine/module/multiheadattn.py
@@ -0,0 +1,159 @@
+from typing import Optional
+
+import numpy as np
+
+import megengine as mge
+import megengine.functional as F
+from megengine import Parameter
+
+from ..device import get_cudnn_version, is_cuda_available
+from ..functional.nn import multi_head_attention
+from ..tensor import Tensor
+from .init import ones_, zeros_
+from .module import Module
+
+
+class MultiHeadAttention(Module):
+    r"""Allows the model to jointly attend to information
+    from different representation subspaces.
+    See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
+
+    .. math::
+        \text{MultiHeadAttn}\big(q,K,V, W_Q, W_V, W_O\big) = \sum^{nHeads-1}_{i=0}W_{O,i}h_i
+
+    where :math:`h_i=W_{V,i}V \text{Softmax}\Big( \text{smScaler} \cdot K^TW^T_{K,i}W_{Q,i}q \Big),\text{for }i\text{ = 0 ... nHeads-1}`.
+    
+    Note: This API is experimental and may change in the future. Currently only the CUDA platform is supported. With cuDNN >= 8.6.0 the calculation results are completely correct; with cuDNN >= 8.0.4 but < 8.6.0, the forward and backward calculations are correct when there is no bias, but when a bias is present the dbias result computed by the backward pass is incorrect; with cuDNN < 8.0.4 this operator is not supported.
+
+    Args:
+        embed_dim: Total dimension of the model.
+        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
+            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
+        attn_dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
+        out_dropout: Dropout probability on the final output. Default: ``0.0`` (no dropout).
+        bias: If ``True``, adds bias to the input / output projection layers. Default: ``True``.
+        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
+        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
+        enable_qproj: enable query weight projection. Default: ``True``.
+        enable_kproj: enable key weight projection. Default: ``True``.
+        enable_vproj: enable value weight projection. Default: ``True``.
+        enable_oproj: enable output weight projection. Default: ``True``.
+
+    Examples::
+        >>> import numpy as np
+        >>> batch_size, seq_len, embed_dim, num_heads = 2, 4, 4, 2
+        >>> x = Tensor(np.arange(batch_size * seq_len * embed_dim).astype(np.float32).reshape(batch_size, seq_len, embed_dim))
+        >>> multihead_attn = M.MultiHeadAttention(embed_dim, num_heads)
+        >>> if is_cuda_available() and get_cudnn_version() >= 8004:
+        ...     out = multihead_attn(x, x, x)
+        ...     out.numpy().shape
+        ... else:
+        ...     print(np.zeros((2,4,4)).shape)
+        (2, 4, 4)
+    """
+
+    def __init__(
+        self,
+        embed_dim,
+        num_heads,
+        attn_dropout=0.0,
+        out_dropout=0.0,
+        kdim=None,
+        vdim=None,
+        bias=True,
+        enable_qproj=True,
+        enable_kproj=True,
+        enable_vproj=True,
+        enable_oproj=True,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.embed_dim = embed_dim
+        self.kdim = kdim if kdim is not None else embed_dim
+        self.vdim = vdim if vdim is not None else embed_dim
+        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+        self.num_heads = num_heads
+        self.attn_dropout = attn_dropout
+        self.out_dropout = out_dropout
+        self.head_dim = embed_dim // num_heads
+        assert (
+            self.head_dim * num_heads == self.embed_dim
+        ), "embed_dim must be divisible by num_heads"
+        assert (
+            self._qkv_same_embed_dim
+        ), "it does not support the case where q, k, and v are different."
+        self.bias = bias
+
+        self.enable_qproj = enable_qproj
+        self.enable_kproj = enable_kproj
+        self.enable_vproj = enable_vproj
+        self.enable_oproj = enable_oproj
+        self.nproj = enable_qproj + enable_kproj + enable_vproj + enable_oproj
+
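+        # The projection weights (and optional biases) for the enabled projections are
+        # packed into a single parameter, as expected by the underlying cuDNN-based kernel.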
+        if self.bias:
+            io_weight = np.ones((embed_dim, self.nproj * embed_dim))
+            io_bias = np.zeros((1, self.nproj * embed_dim))
+            self.io_weight_bias = Parameter(
+                np.concatenate((io_weight, io_bias), axis=0), dtype="float32"
+            )
+        else:
+            self.io_weight_bias = Parameter(
+                np.ones((self.nproj * embed_dim, embed_dim), dtype="float32")
+            )
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.bias:
+            io_weight = np.ones((self.embed_dim, self.nproj * self.embed_dim))
+            io_bias = np.zeros((1, self.nproj * self.embed_dim))
+            self.io_weight_bias._reset(np.concatenate((io_weight, io_bias), axis=0))
+        else:
+            ones_(self.io_weight_bias)
+
+    def forward(
+        self, query, key, value, attn_mask: bool = True,
+    ):
+        r"""
+    Args:
+        query: Query embeddings of shape :math:`(N, L, E_q)`, where :math:`N` is the batch size, :math:`L` is the target sequence length,
+            and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against
+            key-value pairs to produce the output. See "Attention Is All You Need" for more details.
+        key: Key embeddings of shape :math:`(N, S, E_k)`, where :math:`N` is the batch size, :math:`S` is the source sequence length, and
+            :math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details.
+        value: Value embeddings of shape :math:`(N, S, E_v)`, where :math:`N` is the batch size, :math:`S` is the source sequence length, and
+            :math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details.
+        attn_mask: whether to apply the default mask to the attention matrix. The upper-right
+            triangle of the mask is ``-inf``, and the diagonal and lower-left triangle are all
+            ``0``, so each position can only attend to itself and to earlier positions.
+
+    Outputs:
+        - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, 
+          where :math:`L` is the target sequence length, :math:`N` is
+          the batch size, and :math:`E` is the embedding dimension ``embed_dim``.
+        """
+
+        return multi_head_attention(
+            query,
+            key,
+            value,
+            self.embed_dim,
+            self.num_heads,
+            self.attn_dropout,
+            self.out_dropout,
+            self.io_weight_bias,
+            self.bias,
+            training=self.training,
+            attn_mask=attn_mask,
+            enable_qproj=self.enable_qproj,
+            enable_kproj=self.enable_kproj,
+            enable_vproj=self.enable_vproj,
+            enable_oproj=self.enable_oproj,
+        )
+
+    def _module_info_string(self) -> str:
+        s = "embed_dim={embed_dim}, num_heads={num_heads}, dropout={dropout}, bias={bias}, kdim={kdim}, vdim={vdim}"
+        return s.format(**self.__dict__)
diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp
index 50421e9b193d23ddeb4472af6b7002c6d391d2d6..19e012956c39392b483002b65d624df536ba61ed 100644
--- a/imperative/src/impl/ops/rng.cpp
+++ b/imperative/src/impl/ops/rng.cpp
@@ -285,6 +285,25 @@ struct OpMeth<Dropout> {
     }
 };
 
+template <>
+struct OpMeth<MultiHeadAttn> {
+    using DnnOp = megdnn::MultiHeadAttn;
+    using Param = DnnOp::Param;
+    using OpNode = mgb::opr::MultiHeadAttn;
+    static Param make_param(const MultiHeadAttn& opdef) {
+        auto handle_seed = RNGDnnOpManager::get_seed(opdef.handle);
+        mgb_assert(
+                handle_seed == opdef.seed,
+                "inconsistent multiheadattn seed: dropout op: %lu handle: %lu",
+                handle_seed, opdef.seed);
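+        // Field order below must match megdnn::param::MultiHeadAttn; the seed comes
+        // from the RNG handle (asserted above to be consistent with opdef.seed).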
+        return {opdef.num_heads,    opdef.sm_scaler,    opdef.input_order,
+                opdef.reslink,      opdef.training,     opdef.bias,
+                opdef.attn_mask,    opdef.enable_qproj, opdef.enable_kproj,
+                opdef.enable_vproj, opdef.enable_oproj, handle_seed,
+                opdef.attn_prob,    opdef.out_prob};
+    }
+};
+
 template <bool>
 struct _InferLayout;
 
@@ -401,6 +420,14 @@ _INST_RNG_MAKER(2)
 #undef _FOR_EACH_OUT
 #undef _FOR_EACH_IN
 
+#define _FOR_EACH_IN(subfix) \
+    inputs[0] subfix, inputs[1] subfix, inputs[2] subfix, inputs[3] subfix,
+#define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix
+_INST_RNG_INVOLKER(4, 2)
+_INST_RNG_MAKER(4)
+#undef _FOR_EACH_OUT
+#undef _FOR_EACH_IN
+
 #undef _INST_RNG_INVOLKER
 #undef _INST_RNG_MAKER
 
@@ -506,6 +533,39 @@ SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>(
     return dests;
 }
 
+template <>
+SmallVector<LogicalTensorDesc> infer_output_attrs<MultiHeadAttn>(
+        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
+    SmallVector<LogicalTensorDesc> dests(2);
+    auto&& cn = inputs[0]->comp_node();
+
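+    // Output 0 (attention output) keeps the layout and dtype of the queries input;
+    // output 1 is the reserve space, a byte buffer sized by the DNN op below.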
+    dests[0].comp_node = cn;
+    dests[0].layout = TensorLayout(inputs[0]->layout());
+    dests[0].layout.dtype = inputs[0]->layout().dtype;
+
+    auto get_reservespace_in_bytes = [&]() -> size_t {
+        // retrieve dnn_op from glob cache
+        auto&& rng = op.cast_final_safe<MultiHeadAttn>();
+        auto handle = rng.handle;
+        if (!handle) {
+            handle = RNGDnnOpManager::get_default_handle(cn);
+        }
+        auto dnn_op_thread_safe =
+                RNGDnnOpManager::inst().get_dnn_op<megdnn::MultiHeadAttn>(
+                        handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn);
+        auto dnn_op = std::get<1>(dnn_op_thread_safe);
+        dnn_op->param() = OpMeth<MultiHeadAttn>::make_param(rng);
+
+        return dnn_op->get_reservespace_in_bytes(
+                inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
+                inputs[3]->layout(), {}, {});
+    };
+    dests[1].comp_node = cn;
+    dests[1].layout =
+            TensorLayout(TensorShape({get_reservespace_in_bytes()}), dtype::Byte());
+    return dests;
+}
+
 template <typename Op>
 SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def, const SmallVector<TensorPtr>& inputs,
@@ -600,6 +660,44 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dro
     return {dests, success};
 }
 
+template <>
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<
+        MultiHeadAttn>(const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) {
+    bool success = inputs[0].layout.ndim != 0;
+
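+    // Same layout rules as infer_output_attrs above, except that the reserve-space
+    // size can only be queried once the query layout is known.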
+    SmallVector<LogicalTensorDesc> dests(2);
+    auto cn = inputs[0].comp_node;
+    dests[0].comp_node = cn;
+    dests[0].layout = TensorLayout(inputs[0].layout);
+    dests[0].layout.dtype = inputs[0].layout.dtype;
+
+    auto get_reservespace_in_bytes = [&]() -> size_t {
+        auto&& rng = op.cast_final_safe<MultiHeadAttn>();
+        auto handle = rng.handle;
+        if (!handle) {
+            handle = RNGDnnOpManager::get_default_handle(cn);
+        }
+        auto dnn_op_thread_safe =
+                RNGDnnOpManager::inst().get_dnn_op<megdnn::MultiHeadAttn>(
+                        handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn);
+        auto dnn_op = std::get<1>(dnn_op_thread_safe);
+        dnn_op->param() = OpMeth<MultiHeadAttn>::make_param(rng);
+
+        return dnn_op->get_reservespace_in_bytes(
+                inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout,
+                {}, {});
+    };
+    dests[1].comp_node = cn;
+    if (success) {
+        dests[1].layout =
+                TensorLayout(TensorShape({get_reservespace_in_bytes()}), dtype::Byte());
+    } else {
+        dests[1].layout = TensorLayout(dtype::Byte());
+    }
+
+    return {dests, success};
+}
+
 template <typename Op>
 SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
         const OpDef& def, const SmallVector<TensorPtr>& inputs) {
@@ -647,6 +745,7 @@ REG_RNG_OP(PoissonRNG, SymbolVar)
 REG_RNG_OP(BetaRNG, SymbolVar)
 REG_RNG_OP(ShuffleRNG, SymbolVarArray)
 REG_RNG_OP(Dropout, SymbolVarArray)
+REG_RNG_OP(MultiHeadAttn, SymbolVarArray)
 #undef REG_RNG_OP
 
 }  // namespace mgb::imperative::rng
diff --git a/imperative/tablegen/generated/hash.txt b/imperative/tablegen/generated/hash.txt
index 21afdaef6c39df0c0c0f2039bc66e7133a85182f..ab7a867652a7546e3a5a7787eb1999710dacef95 100644
--- a/imperative/tablegen/generated/hash.txt
+++ b/imperative/tablegen/generated/hash.txt
@@ -1,7 +1,7 @@
-148f3844ee8787250cd231eb3c1989c3  ../../dnn/scripts/opr_param_defs.py
-b603857c46345dcb9f1693f49217b269  ../../src/core/include/megbrain/ir/ops.td
-ae54e2eba267dc21d8c648963df23a90  generated/opdef.h.inl
-94b20dcecd3dea69883d46ca7b8482be  generated/opdef.cpp.inl
-4ae5f0198e97e69eb381411f3d60e8c8  generated/opdef.py.inl
-4971c6b2ba7f6fca395d73c554526a0e  generated/opdef.cpy.inl
+c5a5d1bd44473912f14cecee3df6409e  ../../dnn/scripts/opr_param_defs.py
+4ed3e8cbef0fa5f4d6824d8d55dec722  ../../src/core/include/megbrain/ir/ops.td
+dc2d4ec8f4f5e203ce0a76bc20f62529  generated/opdef.h.inl
+906957f12994d43c69248a6acfefa396  generated/opdef.cpp.inl
+8817af8997ba0cc00048e71093755238  generated/opdef.py.inl
+c43ae8b706e3f3658fe3cc0f60061981  generated/opdef.cpy.inl
 71e1462bf4d882e2615c3c632cb671cc  generated/enum_macro.h
diff --git a/imperative/tablegen/generated/opdef.cpp.inl b/imperative/tablegen/generated/opdef.cpp.inl
index 01ca5845807577001e4b52af9005ae10bebc8c21..bbde2e9b665fdb16eaa930a3819570c29d6fd314 100644
--- a/imperative/tablegen/generated/opdef.cpp.inl
+++ b/imperative/tablegen/generated/opdef.cpp.inl
@@ -5186,6 +5186,96 @@ OP_TRAIT_REG(MeshIndexing, MeshIndexing)
     .props(MeshIndexing_props_impl)
     .make_name(MeshIndexing_make_name_impl);
 
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(MultiHeadAttn);
+
+namespace {
+size_t MultiHeadAttn_hash_impl(const OpDef& def_) {
+    auto&& op_ = def_.cast_final_safe<MultiHeadAttn>();
+    static_cast<void>(op_);
+
+    return mgb::hash_pair_combine(
+      mgb::hash(op_.dyn_typeinfo()),
+      mgb::hash_pair_combine(
+        mgb::hash(op_.handle),
+        mgb::hash_pair_combine(
+          mgb::hash(op_.num_heads),
+          mgb::hash_pair_combine(
+            mgb::hash(op_.sm_scaler),
+            mgb::hash_pair_combine(
+              mgb::hash(op_.input_order),
+              mgb::hash_pair_combine(
+                mgb::hash(op_.reslink),
+                mgb::hash_pair_combine(
+                  mgb::hash(op_.training),
+                  mgb::hash_pair_combine(
+                    mgb::hash(op_.bias),
+                    mgb::hash_pair_combine(
+                      mgb::hash(op_.attn_mask),
+                      mgb::hash_pair_combine(
+                        mgb::hash(op_.enable_qproj),
+                        mgb::hash_pair_combine(
+                          mgb::hash(op_.enable_kproj),
+                          mgb::hash_pair_combine(
+                            mgb::hash(op_.enable_vproj),
+                            mgb::hash_pair_combine(
+                              mgb::hash(op_.enable_oproj),
+                              mgb::hash_pair_combine(
+                                mgb::hash(op_.attn_prob),
+                                mgb::hash(op_.out_prob)
+                                )
+                              )
+                            )
+                          )
+                        )
+                      )
+                    )
+                  )
+                )
+              )
+            )
+          )
+        )
+      );
+  }
+bool MultiHeadAttn_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
+    auto &&a_ = lhs_.cast_final_safe<MultiHeadAttn>(),
+         &&b_ = rhs_.cast_final_safe<MultiHeadAttn>();
+    static_cast<void>(a_);
+    static_cast<void>(b_);
+return a_.handle == b_.handle && a_.num_heads == b_.num_heads && a_.sm_scaler == b_.sm_scaler && a_.input_order == b_.input_order && a_.reslink == b_.reslink && a_.training == b_.training && a_.bias == b_.bias && a_.attn_mask == b_.attn_mask && a_.enable_qproj == b_.enable_qproj && a_.enable_kproj == b_.enable_kproj && a_.enable_vproj == b_.enable_vproj && a_.enable_oproj == b_.enable_oproj && a_.attn_prob == b_.attn_prob && a_.out_prob == b_.out_prob;}
+std::vector<std::pair<const char*, std::string>> MultiHeadAttn_props_impl(const OpDef& def_) {
+    auto&& op_ = def_.cast_final_safe<MultiHeadAttn>();
+    static_cast<void>(op_);
+    std::vector<std::pair<const char*, std::string>> props_;
+    props_.emplace_back("num_heads", std::to_string(op_.num_heads));
+    props_.emplace_back("sm_scaler", std::to_string(op_.sm_scaler));
+    props_.emplace_back("input_order", std::to_string(op_.input_order));
+    props_.emplace_back("reslink", std::to_string(op_.reslink));
+    props_.emplace_back("training", std::to_string(op_.training));
+    props_.emplace_back("bias", std::to_string(op_.bias));
+    props_.emplace_back("attn_mask", std::to_string(op_.attn_mask));
+    props_.emplace_back("enable_qproj", std::to_string(op_.enable_qproj));
+    props_.emplace_back("enable_kproj", std::to_string(op_.enable_kproj));
+    props_.emplace_back("enable_vproj", std::to_string(op_.enable_vproj));
+    props_.emplace_back("enable_oproj", std::to_string(op_.enable_oproj));
+    props_.emplace_back("seed", std::to_string(op_.seed));
+    props_.emplace_back("attn_prob", std::to_string(op_.attn_prob));
+    props_.emplace_back("out_prob", std::to_string(op_.out_prob));
+    props_.emplace_back("handle", std::to_string(op_.handle));
+    return props_;
+}
+std::string MultiHeadAttn_make_name_impl(const OpDef& def_) {
+    auto&& op_ = def_.cast_final_safe<MultiHeadAttn>();
+    static_cast<void>(op_);
+    return "MultiHeadAttn";
+}
+} // anonymous namespace
+OP_TRAIT_REG(MultiHeadAttn, MultiHeadAttn)
+    .hash(MultiHeadAttn_hash_impl)
+    .is_same_st(MultiHeadAttn_is_same_st_impl)
+    .props(MultiHeadAttn_props_impl)
+    .make_name(MultiHeadAttn_make_name_impl);
+
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(NMSKeep);
 
 namespace {
diff --git a/imperative/tablegen/generated/opdef.cpy.inl b/imperative/tablegen/generated/opdef.cpy.inl
index d92339ce9f399e2052b7861aa0ba4fd91261f65f..98a5af8f2005284d566ee74cae7b5f3eb9e3df83 100644
--- a/imperative/tablegen/generated/opdef.cpy.inl
+++ b/imperative/tablegen/generated/opdef.cpy.inl
@@ -15043,6 +15043,367 @@ void _init_py_MeshIndexing(py::module m) {
     mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MeshIndexing::typeinfo(), &py_type).second);
 }
 
+PyOpDefBegin(MultiHeadAttn) // {
+    static PyGetSetDef py_getsetters[];
+    static PyMethodDef tp_methods[];
+    
+    static PyObject* getstate(PyObject* self, PyObject*) {
+        auto& opdef = reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst();
+        static_cast<void>(opdef);
+        std::unordered_map<std::string, py::object> state {
+            
+            {"num_heads", serialization<decltype(opdef.num_heads)>::dump(opdef.num_heads)},
+            {"sm_scaler", serialization<decltype(opdef.sm_scaler)>::dump(opdef.sm_scaler)},
+            {"input_order", serialization<decltype(opdef.input_order)>::dump(opdef.input_order)},
+            {"reslink", serialization<decltype(opdef.reslink)>::dump(opdef.reslink)},
+            {"training", serialization<decltype(opdef.training)>::dump(opdef.training)},
+            {"bias", serialization<decltype(opdef.bias)>::dump(opdef.bias)},
+            {"attn_mask", serialization<decltype(opdef.attn_mask)>::dump(opdef.attn_mask)},
+            {"enable_qproj", serialization<decltype(opdef.enable_qproj)>::dump(opdef.enable_qproj)},
+            {"enable_kproj", serialization<decltype(opdef.enable_kproj)>::dump(opdef.enable_kproj)},
+            {"enable_vproj", serialization<decltype(opdef.enable_vproj)>::dump(opdef.enable_vproj)},
+            {"enable_oproj", serialization<decltype(opdef.enable_oproj)>::dump(opdef.enable_oproj)},
+            {"seed", serialization<decltype(opdef.seed)>::dump(opdef.seed)},
+            {"attn_prob", serialization<decltype(opdef.attn_prob)>::dump(opdef.attn_prob)},
+            {"out_prob", serialization<decltype(opdef.out_prob)>::dump(opdef.out_prob)},
+            {"handle", serialization<decltype(opdef.handle)>::dump(opdef.handle)}
+        };
+        return py::cast(state).release().ptr();
+    }
+    static PyObject* setstate(PyObject* self, PyObject* args) {
+        PyObject* dict = PyTuple_GetItem(args, 0);
+        if (!dict) return NULL;
+        auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
+        auto& opdef = reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst();
+        static_cast<void>(opdef);
+        
+        {
+        auto&& iter = state.find("num_heads");
+        if (iter != state.end()) {
+            opdef.num_heads = serialization<decltype(opdef.num_heads)>::load(iter->second);
+        }
+        }
+
+        {
+        auto&& iter = state.find("sm_scaler");
+        if (iter != state.end()) {
+            opdef.sm_scaler = serialization<decltype(opdef.sm_scaler)>::load(iter->second);
+        }
+        }
+
+        {
+        auto&& iter = state.find("input_order");
+        if (iter != state.end()) {
+            opdef.input_order = serialization<decltype(opdef.input_order)>::load(iter->second);
+        }
+        }
+
+        {
+        auto&& iter = state.find("reslink");
+        if (iter != state.end()) {
+            opdef.reslink = serialization<decltype(opdef.reslink)>::load(iter->second);
+        }
+        }
+
+        {
+        auto&& iter = state.find("training");
+        if (iter != state.end()) {
+            opdef.training = serialization<decltype(opdef.training)>::load(iter->second);
+        }
+        }
+
+        {
+        auto&& iter = state.find("bias");
+        if (iter != state.end()) {
+            opdef.bias = serialization<decltype(opdef.bias)>::load(iter->second);
+        }
+        }
+
+        {
+        auto&& iter = state.find("attn_mask");
+        if (iter != state.end()) {
+            opdef.attn_mask = serialization<decltype(opdef.attn_mask)>::load(iter->second);
+        }
+        }
+
+        {
+        auto&& iter = state.find("enable_qproj");
+        if (iter != state.end()) {
+            opdef.enable_qproj = serialization<decltype(opdef.enable_qproj)>::load(iter->second);
+        }
+        }
+
+        {
+        auto&& iter = state.find("enable_kproj");
+        if (iter != state.end()) {
+            opdef.enable_kproj = serialization<decltype(opdef.enable_kproj)>::load(iter->second);
+        }
+        }
+
+        {
+        auto&& iter = state.find("enable_vproj");
+        if (iter != state.end()) {
+            opdef.enable_vproj = serialization<decltype(opdef.enable_vproj)>::load(iter->second);
+        }
+        }
+
+        {
+        auto&& iter = state.find("enable_oproj");
+        if (iter != state.end()) {
+            opdef.enable_oproj = serialization<decltype(opdef.enable_oproj)>::load(iter->second);
+        }
+        }
+
+        {
+        auto&& iter = state.find("seed");
+        if (iter != state.end()) {
+            opdef.seed = serialization<decltype(opdef.seed)>::load(iter->second);
+        }
+        }
+
+        {
+        auto&& iter = state.find("attn_prob");
+        if (iter != state.end()) {
+            opdef.attn_prob = serialization<decltype(opdef.attn_prob)>::load(iter->second);
+        }
+        }
+
+        {
+        auto&& iter = state.find("out_prob");
+        if (iter != state.end()) {
+            opdef.out_prob = serialization<decltype(opdef.out_prob)>::load(iter->second);
+        }
+        }
+
+        {
+        auto&& iter = state.find("handle");
+        if (iter != state.end()) {
+            opdef.handle = serialization<decltype(opdef.handle)>::load(iter->second);
+        }
+        }
+        Py_RETURN_NONE;
+    }
+    static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
+    static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds);
+    static PyMethodDef py_init_methoddef;
+// };
+PyOpDefEnd(MultiHeadAttn)
+
+int PyOp(MultiHeadAttn)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
+    static const char* kwlist[] = {"num_heads", "sm_scaler", "input_order", "reslink", "training", "bias", "attn_mask", "enable_qproj", "enable_kproj", "enable_vproj", "enable_oproj", "seed", "attn_prob", "out_prob", "handle", "scope", NULL};
+    PyObject *num_heads = NULL, *sm_scaler = NULL, *input_order = NULL, *reslink = NULL, *training = NULL, *bias = NULL, *attn_mask = NULL, *enable_qproj = NULL, *enable_kproj = NULL, *enable_vproj = NULL, *enable_oproj = NULL, *seed = NULL, *attn_prob = NULL, *out_prob = NULL, *handle = NULL, *scope = NULL;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOOOOOOOOOOOOO", const_cast<char**>(kwlist), &num_heads, &sm_scaler, &input_order, &reslink, &training, &bias, &attn_mask, &enable_qproj, &enable_kproj, &enable_vproj, &enable_oproj, &seed, &attn_prob, &out_prob, &handle, &scope))
+    return -1;
+
+    if (num_heads) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().num_heads =
+                    py::cast<decltype(MultiHeadAttn::num_heads)>(py::handle(num_heads));
+        } CATCH_ALL(-1)
+    }
+
+    if (sm_scaler) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().sm_scaler =
+                    py::cast<decltype(MultiHeadAttn::sm_scaler)>(py::handle(sm_scaler));
+        } CATCH_ALL(-1)
+    }
+
+    if (input_order) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().input_order =
+                    py::cast<decltype(MultiHeadAttn::input_order)>(py::handle(input_order));
+        } CATCH_ALL(-1)
+    }
+
+    if (reslink) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().reslink =
+                    py::cast<decltype(MultiHeadAttn::reslink)>(py::handle(reslink));
+        } CATCH_ALL(-1)
+    }
+
+    if (training) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().training =
+                    py::cast<decltype(MultiHeadAttn::training)>(py::handle(training));
+        } CATCH_ALL(-1)
+    }
+
+    if (bias) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().bias =
+                    py::cast<decltype(MultiHeadAttn::bias)>(py::handle(bias));
+        } CATCH_ALL(-1)
+    }
+
+    if (attn_mask) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().attn_mask =
+                    py::cast<decltype(MultiHeadAttn::attn_mask)>(py::handle(attn_mask));
+        } CATCH_ALL(-1)
+    }
+
+    if (enable_qproj) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().enable_qproj =
+                    py::cast<decltype(MultiHeadAttn::enable_qproj)>(py::handle(enable_qproj));
+        } CATCH_ALL(-1)
+    }
+
+    if (enable_kproj) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().enable_kproj =
+                    py::cast<decltype(MultiHeadAttn::enable_kproj)>(py::handle(enable_kproj));
+        } CATCH_ALL(-1)
+    }
+
+    if (enable_vproj) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().enable_vproj =
+                    py::cast<decltype(MultiHeadAttn::enable_vproj)>(py::handle(enable_vproj));
+        } CATCH_ALL(-1)
+    }
+
+    if (enable_oproj) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().enable_oproj =
+                    py::cast<decltype(MultiHeadAttn::enable_oproj)>(py::handle(enable_oproj));
+        } CATCH_ALL(-1)
+    }
+
+    if (seed) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().seed =
+                    py::cast<decltype(MultiHeadAttn::seed)>(py::handle(seed));
+        } CATCH_ALL(-1)
+    }
+
+    if (attn_prob) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().attn_prob =
+                    py::cast<decltype(MultiHeadAttn::attn_prob)>(py::handle(attn_prob));
+        } CATCH_ALL(-1)
+    }
+
+    if (out_prob) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().out_prob =
+                    py::cast<decltype(MultiHeadAttn::out_prob)>(py::handle(out_prob));
+        } CATCH_ALL(-1)
+    }
+
+    if (handle) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().handle =
+                    py::cast<decltype(MultiHeadAttn::handle)>(py::handle(handle));
+        } CATCH_ALL(-1)
+    }
+
+    if (scope) {
+        try {
+            reinterpret_cast<PyOp(OpDef)*>(self)->op
+                ->set_scope(py::cast<std::string>(py::handle(scope)));
+        } CATCH_ALL(-1)
+    }
+
+    return 0;
+}
+
+PyGetSetDef PyOp(MultiHeadAttn)::py_getsetters[] = {
+    {const_cast<char*>("num_heads"), py_get_generic(MultiHeadAttn, num_heads), py_set_generic(MultiHeadAttn, num_heads), const_cast<char*>("num_heads"), NULL},
+    {const_cast<char*>("sm_scaler"), py_get_generic(MultiHeadAttn, sm_scaler), py_set_generic(MultiHeadAttn, sm_scaler), const_cast<char*>("sm_scaler"), NULL},
+    {const_cast<char*>("input_order"), py_get_generic(MultiHeadAttn, input_order), py_set_generic(MultiHeadAttn, input_order), const_cast<char*>("input_order"), NULL},
+    {const_cast<char*>("reslink"), py_get_generic(MultiHeadAttn, reslink), py_set_generic(MultiHeadAttn, reslink), const_cast<char*>("reslink"), NULL},
+    {const_cast<char*>("training"), py_get_generic(MultiHeadAttn, training), py_set_generic(MultiHeadAttn, training), const_cast<char*>("training"), NULL},
+    {const_cast<char*>("bias"), py_get_generic(MultiHeadAttn, bias), py_set_generic(MultiHeadAttn, bias), const_cast<char*>("bias"), NULL},
+    {const_cast<char*>("attn_mask"), py_get_generic(MultiHeadAttn, attn_mask), py_set_generic(MultiHeadAttn, attn_mask), const_cast<char*>("attn_mask"), NULL},
+    {const_cast<char*>("enable_qproj"), py_get_generic(MultiHeadAttn, enable_qproj), py_set_generic(MultiHeadAttn, enable_qproj), const_cast<char*>("enable_qproj"), NULL},
+    {const_cast<char*>("enable_kproj"), py_get_generic(MultiHeadAttn, enable_kproj), py_set_generic(MultiHeadAttn, enable_kproj), const_cast<char*>("enable_kproj"), NULL},
+    {const_cast<char*>("enable_vproj"), py_get_generic(MultiHeadAttn, enable_vproj), py_set_generic(MultiHeadAttn, enable_vproj), const_cast<char*>("enable_vproj"), NULL},
+    {const_cast<char*>("enable_oproj"), py_get_generic(MultiHeadAttn, enable_oproj), py_set_generic(MultiHeadAttn, enable_oproj), const_cast<char*>("enable_oproj"), NULL},
+    {const_cast<char*>("seed"), py_get_generic(MultiHeadAttn, seed), py_set_generic(MultiHeadAttn, seed), const_cast<char*>("seed"), NULL},
+    {const_cast<char*>("attn_prob"), py_get_generic(MultiHeadAttn, attn_prob), py_set_generic(MultiHeadAttn, attn_prob), const_cast<char*>("attn_prob"), NULL},
+    {const_cast<char*>("out_prob"), py_get_generic(MultiHeadAttn, out_prob), py_set_generic(MultiHeadAttn, out_prob), const_cast<char*>("out_prob"), NULL},
+    {const_cast<char*>("handle"), py_get_generic(MultiHeadAttn, handle), py_set_generic(MultiHeadAttn, handle), const_cast<char*>("handle"), NULL},
+    {NULL}  /* Sentinel */
+};
+
+    PyMethodDef PyOp(MultiHeadAttn)::tp_methods[] = {
+        {const_cast<char*>("__getstate__"), PyOp(MultiHeadAttn)::getstate, METH_NOARGS, "MultiHeadAttn getstate"},
+    {const_cast<char*>("__setstate__"), PyOp(MultiHeadAttn)::setstate, METH_VARARGS, "MultiHeadAttn setstate"},
+        {NULL}  /* Sentinel */
+    };
+    
+PyObject *PyOp(MultiHeadAttn)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) {
+    if (PyOp(MultiHeadAttn)::py_init(self, args, kwds) < 0) {
+        return NULL;
+    }
+    Py_RETURN_NONE;
+}
+
+PyMethodDef PyOp(MultiHeadAttn)::py_init_methoddef = {
+    "__init__",
+    (PyCFunction)PyOp(MultiHeadAttn)::py_init_proxy,
+    METH_VARARGS | METH_KEYWORDS,
+    "__init__(self, num_heads: int = ..., sm_scaler: float = ..., input_order: int = ..., reslink: bool = ..., training: bool = ..., bias: bool = ..., attn_mask: bool = ..., enable_qproj: bool = ..., enable_kproj: bool = ..., enable_vproj: bool = ..., enable_oproj: bool = ..., seed: int = ..., attn_prob: float = ..., out_prob: float = ..., handle: int = ...) -> None\n"
+};
+
+void _init_py_MultiHeadAttn(py::module m) {
+    using py_op = PyOp(MultiHeadAttn);
+    auto& py_type = PyOpType(MultiHeadAttn);
+    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
+    py_type.tp_name = "megengine.core._imperative_rt.ops.MultiHeadAttn";
+    py_type.tp_basicsize = sizeof(PyOp(MultiHeadAttn));
+    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
+    py_type.tp_doc = "MultiHeadAttn";
+    py_type.tp_base = &PyOpType(OpDef);
+    py_type.tp_dealloc = py_dealloc_generic<py_op>;
+    py_type.tp_new = py_new_generic<py_op>;
+    py_type.tp_init = py_op::py_init;
+    py_type.tp_methods = py_op::tp_methods;
+    py_type.tp_getset = py_op::py_getsetters;
+
+    py_type.tp_dict = PyDict_New();
+    PyObject* descr = PyDescr_NewMethod(&PyOpType(MultiHeadAttn), &PyOp(MultiHeadAttn)::py_init_methoddef);
+    PyDict_SetItemString(py_type.tp_dict, "__init__", descr);
+    mgb_assert(PyType_Ready(&py_type) >= 0);
+    
+    PyType_Modified(&py_type);
+    m.add_object("MultiHeadAttn", reinterpret_cast<PyObject*>(&py_type));
+    mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MultiHeadAttn::typeinfo(), &py_type).second);
+}
+
 PyOpDefBegin(NMSKeep) // {
     static PyGetSetDef py_getsetters[];
     static PyMethodDef tp_methods[];
@@ -22608,6 +22969,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
     _init_py_MatrixMul(m); \
     _init_py_MeshGrid(m); \
     _init_py_MeshIndexing(m); \
+    _init_py_MultiHeadAttn(m); \
     _init_py_NMSKeep(m); \
     _init_py_NvOf(m); \
     _init_py_Padding(m); \
diff --git a/imperative/tablegen/generated/opdef.h.inl b/imperative/tablegen/generated/opdef.h.inl
index e3bd9b9fca3964e330db267585758d288d0c02ff..5f774aaae359105370fe43db18fe10512118414b 100644
--- a/imperative/tablegen/generated/opdef.h.inl
+++ b/imperative/tablegen/generated/opdef.h.inl
@@ -1394,6 +1394,33 @@ public:
     MeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
 };
 
+class MultiHeadAttn : public OpDefImplBase<MultiHeadAttn> {
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+
+public:
+    uint32_t num_heads = 1;
+    float sm_scaler = 1.f;
+    uint32_t input_order = 0;
+    bool reslink = false;
+    bool training = true;
+    bool bias = false;
+    bool attn_mask = false;
+    bool enable_qproj = true;
+    bool enable_kproj = true;
+    bool enable_vproj = true;
+    bool enable_oproj = true;
+    uint64_t seed = 0;
+    float attn_prob = 0.f;
+    float out_prob = 0.f;
+    size_t handle;
+    MultiHeadAttn() = default;
+    MultiHeadAttn(uint32_t num_heads_, float sm_scaler_, uint32_t input_order_, bool reslink_, bool training_, bool bias_, bool attn_mask_, bool enable_qproj_, bool enable_kproj_, bool enable_vproj_, bool enable_oproj_, uint64_t seed_, float attn_prob_, float out_prob_, size_t handle_, std::string scope_ = {}): num_heads(num_heads_), sm_scaler(sm_scaler_), input_order(input_order_), reslink(reslink_), training(training_), bias(bias_), attn_mask(attn_mask_), enable_qproj(enable_qproj_), enable_kproj(enable_kproj_), enable_vproj(enable_vproj_), enable_oproj(enable_oproj_), seed(seed_), attn_prob(attn_prob_), out_prob(out_prob_), handle(handle_) { set_scope(scope_); }
+    MultiHeadAttn(::megdnn::param::MultiHeadAttn packed_param_0, size_t handle_): num_heads(packed_param_0.num_heads), sm_scaler(packed_param_0.sm_scaler), input_order(packed_param_0.input_order), reslink(packed_param_0.reslink), training(packed_param_0.training), bias(packed_param_0.bias), attn_mask(packed_param_0.attn_mask), enable_qproj(packed_param_0.enable_qproj), enable_kproj(packed_param_0.enable_kproj), enable_vproj(packed_param_0.enable_vproj), enable_oproj(packed_param_0.enable_oproj), seed(packed_param_0.seed), attn_prob(packed_param_0.attn_prob), out_prob(packed_param_0.out_prob), handle(handle_) {}
+    ::megdnn::param::MultiHeadAttn param() const {
+        return {num_heads, sm_scaler, input_order, reslink, training, bias, attn_mask, enable_qproj, enable_kproj, enable_vproj, enable_oproj, seed, attn_prob, out_prob};
+    }
+};
+
 class NMSKeep : public OpDefImplBase<NMSKeep> {
     MGB_DYN_TYPE_OBJ_FINAL_DECL;
 
diff --git a/imperative/tablegen/generated/opdef.py.inl b/imperative/tablegen/generated/opdef.py.inl
index 39cf73a5b03880c1aec97575dd06276e30b6819e..b6591c362de4a3e22a8f2b79463ff9703abbe2b1 100644
--- a/imperative/tablegen/generated/opdef.py.inl
+++ b/imperative/tablegen/generated/opdef.py.inl
@@ -1477,6 +1477,27 @@ MeshIndexingInst
     .def(py::init<>())
     .def_readwrite("items", &MeshIndexing::items);
 
+py::class_<MultiHeadAttn, std::shared_ptr<MultiHeadAttn>, OpDef> MultiHeadAttnInst(m, "MultiHeadAttn");
+
+MultiHeadAttnInst
+    .def(py::init<uint32_t, float, uint32_t, bool, bool, bool, bool, bool, bool, bool, bool, uint64_t, float, float, size_t, std::string>(), py::arg("num_heads") = 1, py::arg("sm_scaler") = 1.f, py::arg("input_order") = 0, py::arg("reslink") = false, py::arg("training") = true, py::arg("bias") = false, py::arg("attn_mask") = false, py::arg("enable_qproj") = true, py::arg("enable_kproj") = true, py::arg("enable_vproj") = true, py::arg("enable_oproj") = true, py::arg("seed") = 0, py::arg("attn_prob") = 0.f, py::arg("out_prob") = 0.f, py::arg("handle"), py::arg("scope") = {})
+    .def(py::init<>())
+    .def_readwrite("num_heads", &MultiHeadAttn::num_heads)
+    .def_readwrite("sm_scaler", &MultiHeadAttn::sm_scaler)
+    .def_readwrite("input_order", &MultiHeadAttn::input_order)
+    .def_readwrite("reslink", &MultiHeadAttn::reslink)
+    .def_readwrite("training", &MultiHeadAttn::training)
+    .def_readwrite("bias", &MultiHeadAttn::bias)
+    .def_readwrite("attn_mask", &MultiHeadAttn::attn_mask)
+    .def_readwrite("enable_qproj", &MultiHeadAttn::enable_qproj)
+    .def_readwrite("enable_kproj", &MultiHeadAttn::enable_kproj)
+    .def_readwrite("enable_vproj", &MultiHeadAttn::enable_vproj)
+    .def_readwrite("enable_oproj", &MultiHeadAttn::enable_oproj)
+    .def_readwrite("seed", &MultiHeadAttn::seed)
+    .def_readwrite("attn_prob", &MultiHeadAttn::attn_prob)
+    .def_readwrite("out_prob", &MultiHeadAttn::out_prob)
+    .def_readwrite("handle", &MultiHeadAttn::handle);
+
 py::class_<NMSKeep, std::shared_ptr<NMSKeep>, OpDef> NMSKeepInst(m, "NMSKeep");
 
 NMSKeepInst
diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td
index 2a672774942a09bde26c6b4b2cc26fd31c6ddc8c..c1f8eff0284b3fe810cdb824a5f0691fa3a05892 100644
--- a/src/core/include/megbrain/ir/ops.td
+++ b/src/core/include/megbrain/ir/ops.td
@@ -559,4 +559,57 @@ def RegionRestrictedConvolution: MgbHashableOp<"RegionRestrictedConvolution", [C
 def RegionRestrictedConvolutionBackwardData: MgbHashableOp<"RegionRestrictedConvolutionBackwardData", [ConvolutionParam]>;
 def MaskedFill: MgbHashableOp<"MaskedFill", [FillParam]>;
 
+def MultiHeadAttn: MgbHashableOp<"MultiHeadAttn", [MultiHeadAttnParam]> {
+  let extraArguments = (ins
+    MgbSizeTAddr:$handle
+  );
+  let hashFunction = [{
+    return mgb::hash_pair_combine(
+      mgb::hash($_self.dyn_typeinfo()),
+      mgb::hash_pair_combine(
+        mgb::hash($_self.handle),
+        mgb::hash_pair_combine(
+          mgb::hash($_self.num_heads),
+          mgb::hash_pair_combine(
+            mgb::hash($_self.sm_scaler),
+            mgb::hash_pair_combine(
+              mgb::hash($_self.input_order),
+              mgb::hash_pair_combine(
+                mgb::hash($_self.reslink),
+                mgb::hash_pair_combine(
+                  mgb::hash($_self.training),
+                  mgb::hash_pair_combine(
+                    mgb::hash($_self.bias),
+                    mgb::hash_pair_combine(
+                      mgb::hash($_self.attn_mask),
+                      mgb::hash_pair_combine(
+                        mgb::hash($_self.enable_qproj),
+                        mgb::hash_pair_combine(
+                          mgb::hash($_self.enable_kproj),
+                          mgb::hash_pair_combine(
+                            mgb::hash($_self.enable_vproj),
+                            mgb::hash_pair_combine(
+                              mgb::hash($_self.enable_oproj),
+                              mgb::hash_pair_combine(
+                                mgb::hash($_self.attn_prob),
+                                mgb::hash($_self.out_prob)
+                                )
+                              )
+                            )
+                          )
+                        )
+                      )
+                    )
+                  )
+                )
+              )
+            )
+          )
+        )
+      );
+  }];
+  let cmpFunction = [{return $0.handle == $1.handle && $0.num_heads == $1.num_heads && $0.sm_scaler == $1.sm_scaler && $0.input_order == $1.input_order && $0.reslink == $1.reslink && $0.training == $1.training && $0.bias == $1.bias && $0.attn_mask == $1.attn_mask && $0.enable_qproj == $1.enable_qproj && $0.enable_kproj == $1.enable_kproj && $0.enable_vproj == $1.enable_vproj && $0.enable_oproj == $1.enable_oproj && $0.attn_prob == $1.attn_prob && $0.out_prob == $1.out_prob;}];
+
+}
+
 #endif // MGB_OPS
diff --git a/src/opr/impl/internal/megdnn_opr_wrapper.inl b/src/opr/impl/internal/megdnn_opr_wrapper.inl
index 861903663035910128d05c95552de1221de57fb6..f54ccec2f9fe19b90cdf00b04f2084503a34cdaf 100644
--- a/src/opr/impl/internal/megdnn_opr_wrapper.inl
+++ b/src/opr/impl/internal/megdnn_opr_wrapper.inl
@@ -159,6 +159,11 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU
 #define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0)
 #include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
 
+#define _NR_INPUTS          4
+#define _NR_OUTPUTS         2
+#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0), _o(1)
+#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
+
 #define _NR_INPUTS          4
 #define _NR_OUTPUTS         4
 #define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0), _o(1), _o(2), _o(3)
@@ -179,6 +184,12 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU
 #define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2)
 #include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
 
+#define _NR_INPUTS  5
+#define _NR_OUTPUTS 5
+#define _FOREACH_IO(_i, _o) \
+    _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2), _o(3), _o(4)
+#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
+
 #define _NR_INPUTS          6
 #define _NR_OUTPUTS         1
 #define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0)
@@ -195,6 +206,12 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU
     _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1), _o(2)
 #include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
 
+#define _NR_INPUTS  6
+#define _NR_OUTPUTS 4
+#define _FOREACH_IO(_i, _o) \
+    _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1), _o(2), _o(3)
+#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
+
 #define _NR_INPUTS  7
 #define _NR_OUTPUTS 3
 #define _FOREACH_IO(_i, _o) \
diff --git a/src/opr/impl/rand.cpp b/src/opr/impl/rand.cpp
index 83dda237dc2ebcc91f3257725cc3282f5a4b6511..b5dc33ced97b5a89bc238584ea3ec2de123f2eb1 100644
--- a/src/opr/impl/rand.cpp
+++ b/src/opr/impl/rand.cpp
@@ -192,6 +192,8 @@ template class RNGOprBase<::megdnn::ShuffleRNGForward>;
 template class RNGOprBase<::megdnn::ShuffleRNGBackward>;
 template class RNGOprBase<::megdnn::DropoutForward>;
 template class RNGOprBase<::megdnn::DropoutBackward>;
+template class RNGOprBase<::megdnn::MultiHeadAttnForward>;
+template class RNGOprBase<::megdnn::MultiHeadAttnBackward>;
 #if MGB_ENABLE_GRAD
 IMPL(GaussianRNG);
 IMPL(UniformRNG);
@@ -375,7 +377,7 @@ MGB_IMPL_OPR_GRAD(DropoutForward) {
 }
 #endif
 
-/* ==================== LayerNormBackward ==================== */
+/* ==================== DropoutBackward ==================== */
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(DropoutBackward);
 
@@ -421,4 +423,170 @@ void DropoutBackward::scn_do_execute() {
             output(0)->dev_tensor().as_megdnn(), {});
 }
 
+/* ==================== MultiHeadAttnForward  ==================== */
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(MultiHeadAttnForward);
+
+MultiHeadAttnForward::MultiHeadAttnForward(
+        VarNode* queries, VarNode* keys, VarNode* values, VarNode* wqkv,
+        const Param& param, const OperatorNodeConfig& config)
+        : Super{{queries->owner_graph(),
+                 config,
+                 "multi_head_attn",
+                 {queries, keys, values, wqkv}},
+                param} {
+    add_input({queries, keys, values, wqkv});
+    add_output(None)
+            ->dtype(queries->dtype())
+            .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+    add_output(None)->dtype(dtype::Byte()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+    cg::add_workspace_output(this);
+    add_equivalence_component<ScalarHash<void*>>(this);
+}
+
+SymbolVarArray MultiHeadAttnForward::make(
+        SymbolVar queries, SymbolVar keys, SymbolVar values, SymbolVar wqkv,
+        const Param& param, const OperatorNodeConfig& config) {
+    auto outs = queries.node()
+                        ->owner_graph()
+                        ->insert_opr(std::make_unique<MultiHeadAttnForward>(
+                                queries.node(), keys.node(), values.node(), wqkv.node(),
+                                param, config))
+                        ->output();
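+    // The opr has three outputs: {out, reserveSpace, workspace}; the workspace var
+    // is internal, so only the first two are returned to the caller.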
+    mgb_assert(outs.size() == 3);
+    return {outs[0], outs[1]};
+}
+
+void MultiHeadAttnForward::init_output_static_infer_desc() {
+    using namespace cg::static_infer;
+    auto&& mgr = owner_graph()->static_infer_manager();
+    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));
+
+    auto infer_mask = [this](TensorShape& dest, const InpVal& iv) {
+        ensure_megdnn_opr();
+        dest.ndim = 1;
+        dest.shape[0] = m_dnn_opr->get_reservespace_in_bytes(
+                {iv.val[0].shape(), input(0)->dtype()},
+                {iv.val[1].shape(), input(1)->dtype()},
+                {iv.val[2].shape(), input(2)->dtype()},
+                {iv.val[3].shape(), input(3)->dtype()}, {}, {});
+        return true;
+    };
+    mgr.register_shape_infer(
+            output(1), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_mask});
+}
+
+void MultiHeadAttnForward::add_input_layout_constraint() {
+    input(0)->add_layout_constraint_contiguous();
+    input(1)->add_layout_constraint_contiguous();
+    input(2)->add_layout_constraint_contiguous();
+    input(3)->add_layout_constraint_contiguous();
+};
+
+void MultiHeadAttnForward::scn_do_execute() {
+    auto&& ret = output(0);
+    if (ret->layout().is_empty()) {
+        mgb_assert(ret->dev_tensor().empty());
+        return;
+    }
+    m_dnn_opr->exec(
+            input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
+            input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(),
+            output(0)->dev_tensor().as_megdnn(), output(1)->dev_tensor().as_megdnn(),
+            get_megdnn_workspace_from_var(output(2)));
+}
+
+cg::OperatorNodeBase::NodeProp* MultiHeadAttnForward::do_make_node_prop() const {
+    auto prop = Super::do_make_node_prop();
+    prop->add_flag(NodeProp::Flag::IMPURE_FUNC);
+    for (auto i : input()) {
+        prop->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    }
+    return prop;
+}
+
+#if MGB_ENABLE_GRAD
+MGB_IMPL_OPR_GRAD(MultiHeadAttnForward) {
+    MGB_MARK_USED_VAR(opr);
+    MGB_MARK_USED_VAR(out_grad);
+    SymbolVarArray grad;
+    VarNodeArray ret;
+    mgb_assert(wrt_idx < 5, "wrt_idx %zu is out of range", wrt_idx);
+    grad = MultiHeadAttnBackward::make(
+            out_grad[0], opr.input(0), opr.input(1), opr.input(2), opr.input(3),
+            opr.output(1), opr.param());
+
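+    // MultiHeadAttnBackward computes all four gradients (dqueries, dkeys, dvalues,
+    // dweights) in one call, so the whole set is returned regardless of wrt_idx.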
+    uint32_t nr_ret = 4;
+    for (uint32_t i = 0; i < nr_ret; ++i) {
+        ret.push_back(grad[i].node());
+    }
+    return ret;
+}
+#endif
+
+/* ==================== MultiHeadAttnBackward ==================== */
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(MultiHeadAttnBackward);
+
+MultiHeadAttnBackward::MultiHeadAttnBackward(
+        VarNode* diff, VarNode* queries, VarNode* keys, VarNode* values, VarNode* wqkv,
+        VarNode* reserveSpace, const Param& param, const OperatorNodeConfig& config)
+        : Super({queries->owner_graph(),
+                 config,
+                 "multi_head_attn_backward",
+                 {diff, queries, keys, values, wqkv, reserveSpace}},
+                0, true) {
+    init_megdnn_opr(*this, param);
+    add_input({diff, queries, keys, values, wqkv, reserveSpace});
+}
+
+SymbolVarArray MultiHeadAttnBackward::make(
+        SymbolVar diff, SymbolVar queries, SymbolVar keys, SymbolVar values,
+        SymbolVar wqkv, SymbolVar reserveSpace, const Param& param,
+        const OperatorNodeConfig& config) {
+    auto outs = queries.node()
+                        ->owner_graph()
+                        ->insert_opr(std::make_unique<MultiHeadAttnBackward>(
+                                diff.node(), queries.node(), keys.node(), values.node(),
+                                wqkv.node(), reserveSpace.node(), param, config))
+                        ->output();
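+    // Outputs are {dqueries, dkeys, dvalues, dweights, workspace}; the trailing
+    // workspace var is internal, so only the four gradients are returned.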
+    mgb_assert(outs.size() == 5);
+    return {outs[0], outs[1], outs[2], outs[3]};
+}
+
+void MultiHeadAttnBackward::init_output_static_infer_desc() {
+    using namespace cg::static_infer;
+    auto&& mgr = owner_graph()->static_infer_manager();
+    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(1)));
+    mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2)));
+    mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(3)));
+    mgr.register_shape_infer(output(3), ShapeInferDesc::make_identity(input(4)));
+
+    this->init_output_static_infer_desc_workspace(false);
+}
+
+void MultiHeadAttnBackward::init_output_dtype() {
+    output(0)->dtype(input(1)->dtype());
+    output(1)->dtype(input(2)->dtype());
+    output(2)->dtype(input(3)->dtype());
+    output(3)->dtype(input(4)->dtype());
+}
+
+size_t MultiHeadAttnBackward::get_workspace_size_bytes(
+        const TensorShapeArray& input_shapes,
+        const TensorShapeArray& output_shapes) const {
+    MGB_MARK_USED_VAR(input_shapes);
+    MGB_MARK_USED_VAR(output_shapes);
+
+    return 0;
+}
+
+void MultiHeadAttnBackward::scn_do_execute() {
+    megdnn_opr()->exec(
+            input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
+            input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(),
+            input(4)->dev_tensor().as_megdnn(), input(5)->dev_tensor().as_megdnn(),
+            output(0)->dev_tensor().as_megdnn(), output(1)->dev_tensor().as_megdnn(),
+            output(2)->dev_tensor().as_megdnn(), output(3)->dev_tensor().as_megdnn(),
+            {});
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/opr/impl/rand.sereg.h b/src/opr/impl/rand.sereg.h
index a5d1a91521a2df67943b6b601458d2e88617ccb9..0333543342cea1c3892711cc14f0b3549aaa6605 100644
--- a/src/opr/impl/rand.sereg.h
+++ b/src/opr/impl/rand.sereg.h
@@ -30,6 +30,35 @@ struct OprMaker<opr::DropoutForward, 1> {
     }
 };
 
+template <>
+struct OprMaker<opr::MultiHeadAttn, 0> {
+    using Param = opr::MultiHeadAttn::Param;
+    static cg::OperatorNodeBase* make(
+            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
+            const OperatorNodeConfig& config) {
+        MGB_MARK_USED_VAR(graph);
+        return opr::MultiHeadAttn::make(i[0], i[1], i[2], i[3], param, config)[0]
+                .node()
+                ->owner_opr();
+    }
+};
+
+// OprMaker in MGB_SEREG_OPR only support unique output opr
+template <>
+struct OprMaker<opr::MultiHeadAttnBackward, 0> {
+    using Param = opr::MultiHeadAttnBackward::Param;
+    static cg::OperatorNodeBase* make(
+            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
+            const OperatorNodeConfig& config) {
+        MGB_MARK_USED_VAR(graph);
+
+        return opr::MultiHeadAttnBackward::make(
+                       i[0], i[1], i[2], i[3], i[4], i[5], param, config)[0]
+                .node()
+                ->owner_opr();
+    }
+};
+
 }  // namespace serialization
 
 namespace opr {
@@ -46,6 +75,8 @@ MGB_SEREG_OPR(ShuffleRNG, 1);
 MGB_SEREG_OPR(ShuffleRNGBackward, 3);
 MGB_SEREG_OPR(Dropout, 1);
 MGB_SEREG_OPR(DropoutBackward, 2);
+MGB_SEREG_OPR(MultiHeadAttn, 0);
+MGB_SEREG_OPR(MultiHeadAttnBackward, 0);
 
 }  // namespace opr
 }  // namespace mgb
diff --git a/src/opr/include/megbrain/opr/rand.h b/src/opr/include/megbrain/opr/rand.h
index fd482d608ffbd63913a7a5357213daca996b999a..d5d79352cd97fc806433e94321bfaf4e0d68bf98 100644
--- a/src/opr/include/megbrain/opr/rand.h
+++ b/src/opr/include/megbrain/opr/rand.h
@@ -86,6 +86,13 @@ _DEFINE_RNG_OPR_WITH_INPUT_CLASS(BetaRNG)
 _DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG)
 #undef _OUTPUTS
 #undef _INPUTS
+
+/* ================= 4 input =================  */
+#define _INPUTS(prefix) prefix i0, prefix i1, prefix i2, prefix i3
+#define _OUTPUTS        SymbolVarArray
+_DEFINE_RNG_OPR_WITH_INPUT_CLASS(MultiHeadAttnForward)
+#undef _OUTPUTS
+#undef _INPUTS
 #undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS
 
 }  // namespace intl
@@ -99,6 +106,7 @@ using BetaRNG = intl::BetaRNG;
 using ShuffleRNG = intl::ShuffleRNGForward;
 using Dropout = intl::DropoutForward;
 using DropoutForward = intl::DropoutForward;
+using MultiHeadAttn = intl::MultiHeadAttnForward;
 
 MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         ShuffleRNGBackward, intl::MegDNNOprWrapperBwd<megdnn::ShuffleRNGBackward>) // {
@@ -132,6 +140,29 @@ private:
     void scn_do_execute() override;
 };
 
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
+        MultiHeadAttnBackward,
+        intl::MegDNNOprWrapperBwd<megdnn::MultiHeadAttnBackward>) // {
+public:
+    MGE_WIN_DECLSPEC_FUC MultiHeadAttnBackward(
+            VarNode* diff, VarNode* queries, VarNode* keys, VarNode* values,
+            VarNode* wqkv, VarNode* reserveSpace, const Param& param,
+            const OperatorNodeConfig& config);
+
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
+            SymbolVar diff, SymbolVar queries, SymbolVar keys, SymbolVar values,
+            SymbolVar wqkv, SymbolVar reserveSpace, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+
+private:
+    void init_output_static_infer_desc() override;
+    void init_output_dtype() override;
+    size_t get_workspace_size_bytes(
+            const TensorShapeArray& input_shapes,
+            const TensorShapeArray& output_shapes) const override;
+    void scn_do_execute() override;
+};
+
 }  // namespace opr
 }  // namespace mgb
 
diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs
index 9bcb8911c50057be124a531c9729cb1eb1a40a9d..8f647494f6d4e81568ce31d453d6bb6ebed94743 100644
--- a/src/serialization/impl/schema.fbs
+++ b/src/serialization/impl/schema.fbs
@@ -126,6 +126,7 @@ union OperatorParam {
     param.GroupNorm = 92,
     param.Fill = 93,
     param.GeneralNorm=94,
+    param.MultiHeadAttn=95,
 }
 
 table Operator {
diff --git a/src/serialization/impl/schema_v2.fbs b/src/serialization/impl/schema_v2.fbs
index 7c2c89ff036527f2e0be7eb85dc93033cfc4c953..add03d695b8a563d0f70793f2238646ebcabd757 100644
--- a/src/serialization/impl/schema_v2.fbs
+++ b/src/serialization/impl/schema_v2.fbs
@@ -143,6 +143,7 @@ union OperatorParam {
     param.GroupNorm = 92,
     param.Fill = 93,
     param.GeneralNorm=94,
+    param.MultiHeadAttn=95,
 }
 
 table Operator {