提交 a7ca0588 编写于 作者: M Megvii Engine Team

feat(opr): add multiattention cuda backend

GitOrigin-RevId: 9f5613cf3effd0691652e6e2395f73e4fe38f660
上级 41d58740
......@@ -2574,6 +2574,73 @@ protected:
size_t workspace_in_bytes);
};
class MultiHeadAttnBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(MultiHeadAttnBase, OperatorBase);
DEF_OPR_PARAM(MultiHeadAttn);
};
class MultiHeadAttnForward : public MultiHeadAttnBase {
DEF_OPR_IMPL(MultiHeadAttnForward, MultiHeadAttnBase, 4, 2);
public:
virtual void exec(
_megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
_megdnn_tensor_in wqkv, _megdnn_tensor_out out,
_megdnn_tensor_out reserveSpace, _megdnn_workspace workspace) = 0;
MGE_WIN_DECLSPEC_FUC void deduce_layout(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out,
TensorLayout& reserveSpace);
virtual size_t get_workspace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv,
const TensorLayout& out, const TensorLayout& reserveSpace) = 0;
virtual size_t get_reservespace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv,
const TensorLayout& out, const TensorLayout& reserveSpace) = 0;
protected:
void check_exec(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv,
const TensorLayout& out, const TensorLayout& reserveSpace,
size_t workspace_in_bytes);
};
using MultiHeadAttn = MultiHeadAttnForward;
class MultiHeadAttnBackward : public MultiHeadAttnBase {
DEF_OPR_IMPL(MultiHeadAttnBackward, MultiHeadAttnBase, 6, 4);
public:
virtual void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
_megdnn_tensor_in values, _megdnn_tensor_in wqkv,
_megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries,
_megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues,
_megdnn_tensor_out dweights, _megdnn_workspace workspace) = 0;
MGE_WIN_DECLSPEC_FUC void deduce_layout(
const TensorLayout& diff, const TensorLayout& queries,
const TensorLayout& keys, const TensorLayout& values,
const TensorLayout& wqkv, const TensorLayout& reserveSpace,
TensorLayout& dqueries, TensorLayout& dkeys, TensorLayout& dvalues,
TensorLayout& dweights);
virtual size_t get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& queries,
const TensorLayout& keys, const TensorLayout& values,
const TensorLayout& wqkv, const TensorLayout& reserveSpace,
const TensorLayout& dqueries, const TensorLayout& dkeys,
const TensorLayout& dvalues, const TensorLayout& dweights) = 0;
protected:
void check_exec(
const TensorLayout& diff, const TensorLayout& queries,
const TensorLayout& keys, const TensorLayout& values,
const TensorLayout& wqkv, const TensorLayout& reserveSpace,
const TensorLayout& dqueries, const TensorLayout& dkeys,
const TensorLayout& dvalues, const TensorLayout& dweights,
size_t workspace_in_bytes);
};
} // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h"
......
......@@ -1330,3 +1330,20 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
add_fields('float32', Doc('p', 'the order of norm'), '2').
add_fields('int32', Doc('dim', 'which dim the norm performed along'), '-1'),
)
(pdef('MultiHeadAttn')
.add_fields('uint32', Doc('num_heads', 'Number of parallel attention heads.'), '1')
.add_fields('float32', Doc('sm_scaler', 'Softmax smoothing (1.0 >= smScaler >= 0.0) or sharpening (smScaler > 1.0) coefficient.'), '1.f')
.add_fields('uint32', Doc('input_order', 'The sequence data layout, allows the user to select 3! = 6 different data layouts or permutations of BEAM, BATCH and TIME dimensions.'), '0')
.add_fields('bool', Doc('reslink', 'Whether to add input query to final output.'), 'false')
.add_fields('bool', Doc('training', 'Whether it is in training mode.'), 'true')
.add_fields('bool', Doc('bias', 'Whether to add linear bias.'), 'false')
.add_fields('bool', Doc('attn_mask', 'Whether to add attn_mask.'), 'false')
.add_fields('bool', Doc('enable_qproj', 'enable query weight projection.'), 'true')
.add_fields('bool', Doc('enable_kproj', 'enable key weight projection.'), 'true')
.add_fields('bool', Doc('enable_vproj', 'enable value weight projection.'), 'true')
.add_fields('bool', Doc('enable_oproj', 'enable output weight projection.'), 'true')
.add_fields('uint64', Doc('seed', 'Random number seed for drop'), '0')
.add_fields('float32', Doc('attn_prob', 'Dropout probability on attention, is applied directly to the softmax output'), '0.f')
.add_fields('float32', Doc('out_prob', 'Dropout probability on output, alters the multi-head attention output'), '0.f')
)
......@@ -221,7 +221,9 @@ private:
cb(RegionRestrictedConvolutionBackwardFilter) \
cb(GroupNormForward) \
cb(GroupNormBackward) \
cb(MaskedFill)
cb(MaskedFill) \
cb(MultiHeadAttnForward)\
cb(MultiHeadAttnBackward)
// clang-format on
/*!
......
#include "megdnn/basic_types.h"
#include "megdnn/oprs.h"
#include "src/common/utils.cuh"
#include "unroll_macro.h"
#include "src/common/utils.h"
namespace megdnn {
using Param = MultiHeadAttnBase::Param;
void MultiHeadAttnForward::deduce_layout(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out,
TensorLayout& reserveSpace) {
megdnn_assert(
queries.ndim == 3,
"queries.ndim should be 3[batch, sequence, embeding], but got %zu",
queries.ndim);
size_t size =
get_reservespace_in_bytes(queries, keys, values, wqkv, out, reserveSpace);
out = TensorLayout(
{queries.shape[0], queries.shape[1], queries.shape[2]}, queries.dtype);
reserveSpace = TensorLayout({size}, queries.dtype);
}
void MultiHeadAttnForward::check_exec(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out,
const TensorLayout& reserveSpace, size_t workspace_in_bytes) {
Param p = param();
megdnn_assert_contiguous(queries);
megdnn_assert_contiguous(keys);
megdnn_assert_contiguous(values);
megdnn_assert_contiguous(wqkv);
megdnn_assert_contiguous(out);
if (p.training)
megdnn_assert_contiguous(reserveSpace);
auto required_workspace_in_bytes =
get_workspace_in_bytes(queries, keys, values, wqkv, out, reserveSpace);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
megdnn_assert(
queries.ndim == 3,
"queries.ndim should be 3[batch, sequence, embeding], but got %zu",
queries.ndim);
megdnn_assert(
keys.ndim == 3,
"keys.ndim should be 3[batch, sequence, embeding], but got %zu", keys.ndim);
megdnn_assert(
values.ndim == 3,
"values.ndim should be 3[batch, sequence, embeding], but got %zu",
values.ndim);
auto errmsg = [&]() {
return megdnn_layout_msg(queries) + ", " + megdnn_layout_msg(keys) + ", " +
megdnn_layout_msg(values) + ", " + megdnn_layout_msg(wqkv) + ", " +
megdnn_layout_msg(out) + ", " + megdnn_layout_msg(reserveSpace);
};
megdnn_assert(queries.shape[0] == out.shape[0], "%s", errmsg().c_str());
megdnn_assert(keys.shape[0] == values.shape[0], "%s", errmsg().c_str());
megdnn_assert(queries.shape[0] == keys.shape[0], "%s", errmsg().c_str());
megdnn_assert(queries.shape[1] == out.shape[1], "%s", errmsg().c_str());
megdnn_assert(keys.shape[1] == values.shape[1], "%s", errmsg().c_str());
megdnn_assert(
queries.shape[2] == keys.shape[2] and keys.shape[2] == values.shape[2] and
queries.shape[2] == out.shape[2],
"%s", errmsg().c_str());
}
void MultiHeadAttnBackward::deduce_layout(
const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv,
const TensorLayout& reserveSpace, TensorLayout& dqueries, TensorLayout& dkeys,
TensorLayout& dvalues, TensorLayout& dweights) {
MEGDNN_MARK_USED_VAR(diff);
MEGDNN_MARK_USED_VAR(reserveSpace);
dqueries = queries;
dkeys = keys;
dvalues = values;
dweights = wqkv;
}
void MultiHeadAttnBackward::check_exec(
const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv,
const TensorLayout& reserveSpace, const TensorLayout& dqueries,
const TensorLayout& dkeys, const TensorLayout& dvalues,
const TensorLayout& dweights, size_t workspace_in_bytes) {
Param p = param();
megdnn_assert(
p.training,
"When calling MultiHeadAttn backward, param().training must be true, "
"but got false");
megdnn_assert_contiguous(diff);
megdnn_assert_contiguous(queries);
megdnn_assert_contiguous(keys);
megdnn_assert_contiguous(values);
megdnn_assert_contiguous(wqkv);
megdnn_assert_contiguous(dqueries);
megdnn_assert_contiguous(dkeys);
megdnn_assert_contiguous(dvalues);
megdnn_assert_contiguous(dweights);
if (p.training)
megdnn_assert_contiguous(reserveSpace);
auto required_workspace_in_bytes = get_workspace_in_bytes(
diff, queries, keys, values, wqkv, reserveSpace, dqueries, dkeys, dvalues,
dweights);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
megdnn_assert(reserveSpace.total_nr_elems() > 0);
megdnn_assert(
queries.ndim == 3,
"queries.ndim should be 3[batch, sequence, embeding], but got %zu",
queries.ndim);
megdnn_assert(
keys.ndim == 3,
"keys.ndim should be 3[batch, sequence, embeding], but got %zu", keys.ndim);
megdnn_assert(
values.ndim == 3,
"values.ndim should be 3[batch, sequence, embeding], but got %zu",
values.ndim);
megdnn_assert(
diff.ndim == 3,
"diff.ndim should be 3[batch, sequence, embeding], but got %zu", diff.ndim);
auto errmsg = [&]() {
return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(queries) + ", " +
megdnn_layout_msg(keys) + ", " + megdnn_layout_msg(values) + ", " +
megdnn_layout_msg(wqkv) + ", " + megdnn_layout_msg(reserveSpace) + ", " +
megdnn_layout_msg(dqueries) + ", " + megdnn_layout_msg(dkeys) + ", " +
megdnn_layout_msg(dvalues) + ", " + megdnn_layout_msg(dweights);
};
auto equal_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) -> bool {
if (!(lhs.ndim == rhs.ndim && lhs.dtype == rhs.dtype &&
lhs.format == rhs.format))
return false;
for (size_t i = 0; i < lhs.ndim; ++i) {
if (lhs.shape[i] != rhs.shape[i] || lhs.stride[i] != rhs.stride[i]) {
return false;
}
}
return true;
};
megdnn_assert(equal_layout(queries, diff), "%s", errmsg().c_str());
megdnn_assert(equal_layout(queries, dqueries), "%s", errmsg().c_str());
megdnn_assert(equal_layout(keys, dkeys), "%s", errmsg().c_str());
megdnn_assert(equal_layout(values, dvalues), "%s", errmsg().c_str());
megdnn_assert(equal_layout(wqkv, dweights), "%s", errmsg().c_str());
megdnn_assert(queries.shape[0] == diff.shape[0], "%s", errmsg().c_str());
megdnn_assert(keys.shape[0] == values.shape[0], "%s", errmsg().c_str());
megdnn_assert(queries.shape[0] == keys.shape[0], "%s", errmsg().c_str());
megdnn_assert(queries.shape[1] == diff.shape[1], "%s", errmsg().c_str());
megdnn_assert(keys.shape[1] == values.shape[1], "%s", errmsg().c_str());
megdnn_assert(
queries.shape[2] == keys.shape[2] and keys.shape[2] == values.shape[2] and
queries.shape[2] == diff.shape[2],
"%s", errmsg().c_str());
}
} // namespace megdnn
// vim: syntax=cpp.doxygen
#pragma once
#include "megdnn/oprs.h"
#include "megdnn/oprs/nn.h"
#include <cstddef>
......@@ -147,6 +148,8 @@ DEF(RegionRestrictedConvolutionBackwardFilter, 5, true, false);
DEF(GroupNormForward, 6, true, true);
DEF(GroupNormBackward, 8, true, true);
DEF(MaskedFill, 3, false, true);
DEF(MultiHeadAttnForward, 6, true, true);
DEF(MultiHeadAttnBackward, 10, true, true);
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -3,12 +3,10 @@
#include "src/common/utils.h"
#include "src/cuda/utils.h"
namespace {
using namespace megdnn;
namespace megdnn {
namespace cuda {
cudnnDataType_t to_cudnn_dtype(
DType type, const param::Convolution::Format format = {}) {
cudnnDataType_t to_cudnn_dtype(DType type, const param::Convolution::Format format) {
switch (type.enumv()) {
case DTypeEnum::Float32:
return CUDNN_DATA_FLOAT;
......@@ -66,8 +64,9 @@ cudnnTensorFormat_t to_cudnn_format(const param::Convolution::Format format) {
megdnn_assert_internal(0);
}
}
} // namespace cuda
} // namespace
} // namespace megdnn
namespace megdnn {
namespace cuda {
......@@ -558,6 +557,71 @@ const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr> CudnnAl
#undef V
#undef V1
#if CUDNN_VERSION >= 8004
SeqTensorDesc::~SeqTensorDesc() {
cudnn_check(cudnnDestroySeqDataDescriptor(desc));
}
SeqTensorDesc::SeqTensorDesc() {
cudnnCreateSeqDataDescriptor(&desc);
}
SeqTensorDesc::SeqTensorDesc(
const TensorLayout& layout, const size_t batchSize, const size_t seqLen,
const size_t elemSize, const size_t input_order, int* seqArray) {
cudnnCreateSeqDataDescriptor(&desc);
set(layout, batchSize, seqLen, elemSize, input_order, seqArray);
}
void SeqTensorDesc::set(
const TensorLayout& layout, const size_t batchSize, const size_t seqLen,
const size_t elemSize, const size_t input_order, int* seqArray) {
switch (input_order) {
case 0: // dimAxes = [Batch, Beam, Time]
dimAxes[0] = CUDNN_SEQDATA_BATCH_DIM;
dimAxes[1] = CUDNN_SEQDATA_BEAM_DIM;
dimAxes[2] = CUDNN_SEQDATA_TIME_DIM;
break;
case 1: // dimAxes = [Beam, Batch, Time]
dimAxes[0] = CUDNN_SEQDATA_BEAM_DIM;
dimAxes[1] = CUDNN_SEQDATA_BATCH_DIM;
dimAxes[2] = CUDNN_SEQDATA_TIME_DIM;
break;
case 2: // dimAxes = [Batch, Time, Beam]
dimAxes[0] = CUDNN_SEQDATA_BATCH_DIM;
dimAxes[1] = CUDNN_SEQDATA_TIME_DIM;
dimAxes[2] = CUDNN_SEQDATA_BEAM_DIM;
break;
case 3: // dimAxes = [Beam, Time, Batch]
dimAxes[0] = CUDNN_SEQDATA_BEAM_DIM;
dimAxes[1] = CUDNN_SEQDATA_TIME_DIM;
dimAxes[2] = CUDNN_SEQDATA_BATCH_DIM;
break;
case 4: // dimAxes = [Time, Batch, Beam]
dimAxes[0] = CUDNN_SEQDATA_TIME_DIM;
dimAxes[1] = CUDNN_SEQDATA_BATCH_DIM;
dimAxes[2] = CUDNN_SEQDATA_BEAM_DIM;
break;
case 5: // dimAxes = [Time, Beam, Batch]
dimAxes[0] = CUDNN_SEQDATA_TIME_DIM;
dimAxes[1] = CUDNN_SEQDATA_BEAM_DIM;
dimAxes[2] = CUDNN_SEQDATA_BATCH_DIM;
break;
default:
megdnn_throw(ssprintf("ERROR: wrong attention layout %zu", input_order));
}
dimAxes[3] = CUDNN_SEQDATA_VECT_DIM;
dim[CUDNN_SEQDATA_BEAM_DIM] = 1;
dim[CUDNN_SEQDATA_BATCH_DIM] = batchSize;
dim[CUDNN_SEQDATA_TIME_DIM] = seqLen;
dim[CUDNN_SEQDATA_VECT_DIM] = elemSize;
cudnnDataType_t cudnn_dtype = to_cudnn_dtype(layout.dtype);
cudnn_check(cudnnSetSeqDataDescriptor(
desc, cudnn_dtype, CUDNN_SEQDATA_DIM_COUNT, dim, dimAxes, batchSize,
seqArray, NULL));
}
#endif
} // namespace cuda
} // namespace megdnn
......
......@@ -2,12 +2,18 @@
#include <unordered_map>
#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/nn.h"
#include "src/cuda/cudnn_with_check.h"
namespace megdnn {
namespace cuda {
cudnnDataType_t to_cudnn_dtype(
DType type, const param::Convolution::Format format = {});
cudnnTensorFormat_t to_cudnn_format(const param::Convolution::Format format);
/*!
* \brief get compute_type of convolution operations
*/
......@@ -85,6 +91,24 @@ public:
cudnnConvolutionDescriptor_t desc;
};
#if CUDNN_VERSION >= 8004
class SeqTensorDesc {
public:
int dim[CUDNN_SEQDATA_DIM_COUNT];
cudnnSeqDataAxis_t dimAxes[CUDNN_SEQDATA_DIM_COUNT];
cudnnSeqDataDescriptor_t desc;
~SeqTensorDesc();
SeqTensorDesc();
SeqTensorDesc(
const TensorLayout& layout, const size_t batchSize, const size_t seqLen,
const size_t elemSize, const size_t dataLayout, int* seqArray);
void set(
const TensorLayout& layout, const size_t batchSize, const size_t seqLen,
const size_t elemSize, const size_t dataLayout, int* seqArray);
};
#endif
class CudnnAlgoPack {
public:
//! algorithm attr
......
......@@ -55,7 +55,6 @@ void run_conv_bias_act_with_plan(
const cudnnHandle_t& handle, const cudnn_frontend::ExecutionPlan& plan,
const TensorND& x, const TensorND& y, const TensorND& w, const TensorND& b,
const TensorND& z, const Workspace& workspace);
} // namespace cuda
} // namespace megdnn
......
......@@ -61,6 +61,9 @@ public:
bool initialized() { return status != nullptr; }
friend class DropoutForwardImpl;
friend class DropoutBackwardImpl;
#if CUDNN_VERSION >= 8004
friend class MultiHeadAttnStatus;
#endif
};
// similar to RNG operator, dropout operator also have status
......
......@@ -50,6 +50,7 @@
#include "src/cuda/matrix_mul/opr_impl.h"
#include "src/cuda/max_tensor_diff/opr_impl.h"
#include "src/cuda/mesh_indexing/opr_impl.h"
#include "src/cuda/multi_head_attn/opr_impl.h"
#include "src/cuda/norm/opr_impl.h"
#include "src/cuda/padding/opr_impl.h"
#include "src/cuda/param_pack/opr_impl.h"
......@@ -230,6 +231,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(NormForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionBackwardData);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionBackwardFilter);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MultiHeadAttnForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MultiHeadAttnBackward);
template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() {
......
#include "src/cuda/multi_head_attn/helper.h"
#if CUDNN_VERSION >= 8004
namespace megdnn {
namespace cuda {
AuxiliaryArray::~AuxiliaryArray() {
if (loWinIdx)
free(loWinIdx);
if (hiWinIdx)
free(hiWinIdx);
if (seqQArray)
free(seqQArray);
if (seqKArray)
free(seqKArray);
if (devSeqQArray)
cuda_check(cudaFree(devSeqQArray));
if (devSeqKArray)
cuda_check(cudaFree(devSeqKArray));
}
bool AuxiliaryArray::is_initialized(
const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK,
bool _attnMask) {
if (_batchSize != batchSize or _seqLenQ != seqLenQ or _seqLenK != seqLenK or
_attnMask != attnMask or !seqQArray or !seqKArray or !devSeqQArray or
!devSeqKArray or !loWinIdx or !hiWinIdx)
return false;
return true;
}
void AuxiliaryArray::set(
const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK,
bool _attnMask) {
if (_batchSize == batchSize && _seqLenQ == seqLenQ && _seqLenK == seqLenK &&
_attnMask == attnMask)
return;
else {
if (loWinIdx)
free(loWinIdx);
if (hiWinIdx)
free(hiWinIdx);
if (seqQArray)
free(seqQArray);
if (seqKArray)
free(seqKArray);
if (devSeqQArray)
cuda_check(cudaFree(devSeqQArray));
if (devSeqKArray)
cuda_check(cudaFree(devSeqKArray));
};
seqLenQ = _seqLenQ;
seqLenK = _seqLenK;
batchSize = _batchSize;
attnMask = _attnMask;
size_t seqQArraySize = 1 * batchSize;
size_t seqKArraySize = batchSize;
seqQArray = (int*)calloc(seqQArraySize, sizeof(int));
seqKArray = (int*)calloc(seqKArraySize, sizeof(int));
for (size_t i = 0; i < seqQArraySize; ++i)
seqQArray[i] = seqLenQ;
for (size_t i = 0; i < seqKArraySize; ++i)
seqKArray[i] = seqLenK;
cuda_check(cudaMalloc((void**)&devSeqQArray, seqQArraySize * sizeof(int)));
cuda_check(cudaMalloc((void**)&devSeqKArray, seqKArraySize * sizeof(int)));
cuda_check(cudaMemcpy(
devSeqQArray, seqQArray, seqQArraySize * sizeof(int),
cudaMemcpyHostToDevice));
cuda_check(cudaMemcpy(
devSeqKArray, seqKArray, seqKArraySize * sizeof(int),
cudaMemcpyHostToDevice));
loWinIdx = (int*)calloc(seqLenQ, sizeof(int));
hiWinIdx = (int*)calloc(seqLenQ, sizeof(int));
for (size_t i = 0; i < seqLenQ; ++i) {
loWinIdx[i] = 0;
if (attnMask)
hiWinIdx[i] = i + 1;
else
hiWinIdx[i] = seqLenK;
}
}
void MultiHeadAttnStatus::set(
cudnnHandle_t handle, const Param& p, const TensorLayout& q,
const TensorLayout& k, const TensorLayout& v) {
float attn_prob = p.training ? p.attn_prob : 0.f;
float out_prob = p.training ? p.out_prob : 0.f;
if (!attn_dropout_status.initialized())
attn_dropout_status.set(handle, p.seed, attn_prob);
if (!out_dropout_status.initialized())
out_dropout_status.set(handle, p.seed, out_prob);
if (attn_dropout_status.drop_prob != attn_prob) {
attn_dropout_status.drop_prob = attn_prob;
attn_dropout_status.restore_desc(handle);
}
if (out_dropout_status.drop_prob != out_prob) {
out_dropout_status.drop_prob = out_prob;
out_dropout_status.restore_desc(handle);
}
batchSize = q.shape[0];
seqLenQ = q.shape[1];
seqLenK = k.shape[1];
qSize = q.shape[2];
kSize = k.shape[2];
vSize = v.shape[2];
numHeads = p.num_heads;
qProjSize = p.enable_qproj ? qSize / numHeads : 0;
kProjSize = p.enable_kproj ? kSize / numHeads : 0;
vProjSize = p.enable_vproj ? vSize / numHeads : 0;
oProjSize = p.enable_oproj ? qSize : 0;
attnMask = p.attn_mask;
cudnnDataType_t cudnn_dtype = to_cudnn_dtype(q.dtype);
auto flag = CUDNN_ATTN_QUERYMAP_ONE_TO_ONE;
if (p.bias)
flag = flag | CUDNN_ATTN_ENABLE_PROJ_BIASES;
#if CUDNN_VERSION < 8600
// TODO: CUDNN_VERSION < 8600 and out dropout > 0.0, we need to go to the proxy cuda
// implementation.
cudnn_check(cudnnSetAttnDescriptor(
attn_desc, flag, numHeads, p.sm_scaler, cudnn_dtype, cudnn_dtype,
CUDNN_DEFAULT_MATH, attn_dropout_status.desc.desc, NULL, qSize, kSize,
vSize, qProjSize, kProjSize, vProjSize, oProjSize, seqLenQ, seqLenK,
batchSize, 1));
#else
cudnn_check(cudnnSetAttnDescriptor(
attn_desc, flag, numHeads, p.sm_scaler, cudnn_dtype, cudnn_dtype,
CUDNN_DEFAULT_MATH, attn_dropout_status.desc.desc,
out_dropout_status.desc.desc, qSize, kSize, vSize, qProjSize, kProjSize,
vProjSize, oProjSize, seqLenQ, seqLenK, batchSize, 1));
#endif
auxArray.set(batchSize, seqLenQ, seqLenK, p.attn_mask);
if (p.training)
cudnnGetMultiHeadAttnBuffers(
handle, attn_desc, &sizeWeights, &sizeWkspace, &sizeReserve);
else {
cudnnGetMultiHeadAttnBuffers(
handle, attn_desc, &sizeWeights, &sizeWkspace, NULL);
sizeReserve = 0;
}
}
bool MultiHeadAttnStatus::is_initialized(
const Param& p, const TensorLayout& q, const TensorLayout& k,
const TensorLayout& v) {
float attn_prob = p.training ? p.attn_prob : 0.f;
float out_prob = p.training ? p.out_prob : 0.f;
if (!attn_dropout_status.initialized() or !out_dropout_status.initialized() or
attn_dropout_status.drop_prob != attn_prob or
out_dropout_status.drop_prob != out_prob)
return false;
if (q.shape[0] != batchSize or q.shape[1] != seqLenQ or k.shape[1] != seqLenK or
q.shape[2] != qSize or k.shape[2] != kSize or v.shape[2] != vSize or
attnMask != p.attn_mask or numHeads != p.num_heads) {
return false;
}
if ((p.enable_qproj && (qProjSize == 0 or qProjSize != qSize / p.num_heads)) or
(p.enable_kproj && (kProjSize == 0 or kProjSize != kSize / p.num_heads)) or
(p.enable_vproj && (vProjSize == 0 or vProjSize != vSize / p.num_heads)) or
(p.enable_oproj && (oProjSize == 0 or oProjSize != q.shape[2])))
return false;
if ((!p.enable_qproj && qProjSize != 0) or (!p.enable_kproj && kProjSize != 0) or
(!p.enable_vproj && vProjSize != 0) or (!p.enable_oproj && oProjSize != 0))
return false;
if (!auxArray.is_initialized(batchSize, seqLenQ, seqLenK, attnMask))
return false;
if (p.training and sizeReserve == 0)
return false;
return true;
}
} // namespace cuda
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
#pragma once
#include "src/cuda/cudnn_wrapper.h"
#if CUDNN_VERSION >= 8004
#include "megdnn/basic_types.h"
#include "megdnn/oprs/nn.h"
#include "src/common/algo_chooser.h"
#include "src/common/utils.h"
#include "src/cuda/dropout/opr_impl.h"
#include "src/cuda/handle.h"
namespace megdnn {
namespace cuda {
struct AuxiliaryArray {
public:
int* seqQArray = nullptr;
int* seqKArray = nullptr;
int* devSeqQArray = nullptr;
int* devSeqKArray = nullptr;
int* loWinIdx = nullptr;
int* hiWinIdx = nullptr;
size_t seqLenQ = 0;
size_t seqLenK = 0;
size_t batchSize = 0;
bool attnMask = 0;
~AuxiliaryArray();
void set(
const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK,
bool _attnMask);
bool is_initialized(
const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK,
bool _attnMask);
};
using Param = megdnn::MultiHeadAttn::Param;
class MultiHeadAttnStatus {
DropoutStatus attn_dropout_status;
DropoutStatus out_dropout_status;
cudnnAttnDescriptor_t attn_desc;
AuxiliaryArray auxArray;
size_t numHeads = 0;
size_t batchSize = 0;
size_t seqLenQ = 0;
size_t seqLenK = 0;
size_t qSize = 0;
size_t kSize = 0;
size_t vSize = 0;
size_t qProjSize = 0;
size_t kProjSize = 0;
size_t vProjSize = 0;
size_t oProjSize = 0;
bool attnMask = 0;
size_t sizeWeights = 0;
size_t sizeWkspace = 0;
size_t sizeReserve = 0;
public:
MultiHeadAttnStatus() { cudnn_check(cudnnCreateAttnDescriptor(&attn_desc)); }
~MultiHeadAttnStatus() { cudnn_check(cudnnDestroyAttnDescriptor(attn_desc)); }
private:
void set(
cudnnHandle_t handle, const Param& p, const TensorLayout& q,
const TensorLayout& k, const TensorLayout& v);
bool is_initialized(
const Param& p, const TensorLayout& q, const TensorLayout& k,
const TensorLayout& v);
friend class MultiHeadAttnBase;
friend class MultiHeadAttnForwardImpl;
friend class MultiHeadAttnBackwardImpl;
};
} // namespace cuda
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
#include "src/cuda/multi_head_attn/opr_impl.h"
#include "src/common/utils.cuh"
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
void MultiHeadAttnForwardImpl::deduce_layout(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out,
TensorLayout& reserveSpace) {
#if CUDNN_VERSION < 8004
// TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation.
MEGDNN_MARK_USED_VAR(queries);
MEGDNN_MARK_USED_VAR(keys);
MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(wqkv);
MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(reserveSpace);
return;
#else
MEGDNN_MARK_USED_VAR(keys);
MEGDNN_MARK_USED_VAR(wqkv);
megdnn_assert(
queries.ndim == 3,
"queries.ndim should be 3[batch, sequence, embeding], but got %zu",
queries.ndim);
if (!desc_status.is_initialized(param(), queries, keys, values)) {
desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values);
out = TensorLayout(
TensorShape{queries.shape[0], queries.shape[1], queries.shape[2]},
queries.dtype);
reserveSpace =
TensorLayout(TensorShape{desc_status.sizeReserve}, queries.dtype);
}
#endif
}
size_t MultiHeadAttnForwardImpl::get_workspace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out,
const TensorLayout& reserveSpace) {
#if CUDNN_VERSION < 8004
// TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation.
MEGDNN_MARK_USED_VAR(queries);
MEGDNN_MARK_USED_VAR(keys);
MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(wqkv);
MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(reserveSpace);
return 0;
#else
MEGDNN_MARK_USED_VAR(wqkv);
MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(reserveSpace);
if (!desc_status.is_initialized(param(), queries, keys, values))
desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values);
return desc_status.sizeWkspace;
#endif
}
size_t MultiHeadAttnForwardImpl::get_reservespace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out,
const TensorLayout& reserveSpace) {
#if CUDNN_VERSION < 8004
// TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation.
MEGDNN_MARK_USED_VAR(queries);
MEGDNN_MARK_USED_VAR(keys);
MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(wqkv);
MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(reserveSpace);
return 0;
#else
MEGDNN_MARK_USED_VAR(wqkv);
MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(reserveSpace);
if (!desc_status.is_initialized(param(), queries, keys, values))
desc_status.set(cudnn_handle(this->handle()), param(), queries, keys, values);
return desc_status.sizeReserve;
#endif
}
void MultiHeadAttnForwardImpl::exec(
_megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
_megdnn_tensor_in wqkv, _megdnn_tensor_out out, _megdnn_tensor_out reserveSpace,
_megdnn_workspace workspace) {
#if CUDNN_VERSION < 8004
// TODO: CUDNN_VERSION < 8004, we need to go to the proxy cuda implementation.
MEGDNN_MARK_USED_VAR(queries);
MEGDNN_MARK_USED_VAR(keys);
MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(wqkv);
MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(reserveSpace);
MEGDNN_MARK_USED_VAR(workspace);
megdnn_throw(
"The cudnn version is lower than 8.0.4. Please upgrade the cudnn version.");
#else
check_exec(
queries.layout, keys.layout, values.layout, wqkv.layout, out.layout,
reserveSpace.layout, workspace.size);
auto p = param();
if (!desc_status.is_initialized(p, queries.layout, keys.layout, values.layout))
desc_status.set(
cudnn_handle(this->handle()), p, queries.layout, keys.layout,
values.layout);
SeqTensorDesc q{queries.layout, desc_status.batchSize,
desc_status.seqLenQ, desc_status.qSize,
p.input_order, desc_status.auxArray.seqQArray};
SeqTensorDesc o{out.layout, desc_status.batchSize,
desc_status.seqLenQ, desc_status.oProjSize,
p.input_order, desc_status.auxArray.seqQArray};
SeqTensorDesc k{keys.layout, desc_status.batchSize,
desc_status.seqLenK, desc_status.kSize,
p.input_order, desc_status.auxArray.seqKArray};
SeqTensorDesc v{values.layout, desc_status.batchSize,
desc_status.seqLenK, desc_status.vSize,
p.input_order, desc_status.auxArray.seqKArray};
cudnn_check(cudnnMultiHeadAttnForward(
cudnn_handle(this->handle()), desc_status.attn_desc, -1,
desc_status.auxArray.loWinIdx, desc_status.auxArray.hiWinIdx,
desc_status.auxArray.devSeqQArray, desc_status.auxArray.devSeqKArray,
q.desc, queries.raw_ptr(), p.reslink ? queries.raw_ptr() : NULL, k.desc,
keys.raw_ptr(), v.desc, values.raw_ptr(), o.desc, out.raw_ptr(),
desc_status.sizeWeights,
desc_status.sizeWeights > 0 ? wqkv.raw_ptr() : NULL,
desc_status.sizeWkspace, workspace.raw_ptr,
p.training ? desc_status.sizeReserve : 0,
p.training ? reserveSpace.raw_ptr() : NULL));
#endif
}
void MultiHeadAttnBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
_megdnn_tensor_in values, _megdnn_tensor_in wqkv,
_megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries,
_megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues,
_megdnn_tensor_out dweights, _megdnn_workspace workspace) {
#if CUDNN_VERSION < 8004
// TODO: CUDNN_VERSION < 8004 and param().bias = true, we need to go to the proxy
// cuda implementation.
MEGDNN_MARK_USED_VAR(diff);
MEGDNN_MARK_USED_VAR(queries);
MEGDNN_MARK_USED_VAR(keys);
MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(wqkv);
MEGDNN_MARK_USED_VAR(reserveSpace);
MEGDNN_MARK_USED_VAR(dqueries);
MEGDNN_MARK_USED_VAR(dkeys);
MEGDNN_MARK_USED_VAR(dvalues);
MEGDNN_MARK_USED_VAR(dweights);
megdnn_throw(
"The cudnn version is lower than 8.0.4. Please upgrade the cudnn version.");
#else
#if CUDNN_VERSION < 8600
megdnn_assert(
!param().bias,
"If the cudnn version is lower than 8.6.0, param().bias must be false, "
"but got true, because there is an error in the "
"dbias result during the backward calculation.");
#endif
check_exec(
diff.layout, queries.layout, keys.layout, values.layout, wqkv.layout,
reserveSpace.layout, dqueries.layout, dkeys.layout, dvalues.layout,
dweights.layout, workspace.size);
auto p = param();
if (!desc_status.is_initialized(p, queries.layout, keys.layout, values.layout))
desc_status.set(
cudnn_handle(this->handle()), p, queries.layout, keys.layout,
values.layout);
SeqTensorDesc q{queries.layout, desc_status.batchSize,
desc_status.seqLenQ, desc_status.qSize,
p.input_order, desc_status.auxArray.seqQArray};
SeqTensorDesc d{diff.layout, desc_status.batchSize,
desc_status.seqLenQ, desc_status.oProjSize,
p.input_order, desc_status.auxArray.seqQArray};
SeqTensorDesc k{keys.layout, desc_status.batchSize,
desc_status.seqLenK, desc_status.kSize,
p.input_order, desc_status.auxArray.seqKArray};
SeqTensorDesc v{values.layout, desc_status.batchSize,
desc_status.seqLenK, desc_status.vSize,
p.input_order, desc_status.auxArray.seqKArray};
cudnn_check(cudnnMultiHeadAttnBackwardData(
cudnn_handle(this->handle()), desc_status.attn_desc,
desc_status.auxArray.loWinIdx, desc_status.auxArray.hiWinIdx,
desc_status.auxArray.devSeqQArray, desc_status.auxArray.devSeqKArray,
d.desc, diff.raw_ptr(), q.desc, dqueries.raw_ptr(), queries.raw_ptr(),
k.desc, dkeys.raw_ptr(), keys.raw_ptr(), v.desc, dvalues.raw_ptr(),
values.raw_ptr(), desc_status.sizeWeights,
desc_status.sizeWeights > 0 ? wqkv.raw_ptr() : NULL,
desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve,
reserveSpace.raw_ptr()));
cuda_check(cudaMemset(dweights.raw_ptr(), 0, desc_status.sizeWeights));
#if CUDNN_VERSION < 8600
cuda_check(cudaDeviceSynchronize());
#endif
cudnn_check(cudnnMultiHeadAttnBackwardWeights(
cudnn_handle(this->handle()), desc_status.attn_desc, CUDNN_WGRAD_MODE_ADD,
q.desc, queries.raw_ptr(), k.desc, keys.raw_ptr(), v.desc, values.raw_ptr(),
d.desc, diff.raw_ptr(), desc_status.sizeWeights,
desc_status.sizeWeights > 0 ? wqkv.raw_ptr() : NULL,
desc_status.sizeWeights > 0 ? dweights.raw_ptr() : NULL,
desc_status.sizeWkspace, workspace.raw_ptr, desc_status.sizeReserve,
reserveSpace.raw_ptr()));
#endif
}
size_t MultiHeadAttnBackwardImpl::get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv,
const TensorLayout& reserveSpace, const TensorLayout& dqueries,
const TensorLayout& dkeys, const TensorLayout& dvalues,
const TensorLayout& dweights) {
MEGDNN_MARK_USED_VAR(diff);
MEGDNN_MARK_USED_VAR(queries);
MEGDNN_MARK_USED_VAR(keys);
MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(wqkv);
MEGDNN_MARK_USED_VAR(reserveSpace);
MEGDNN_MARK_USED_VAR(dqueries);
MEGDNN_MARK_USED_VAR(dkeys);
MEGDNN_MARK_USED_VAR(dvalues);
MEGDNN_MARK_USED_VAR(dweights);
return 0;
}
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
#pragma once
#include "megdnn/handle.h"
#include "megdnn/oprs.h"
#include "src/common/reduce_helper.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/handle.h"
#include "src/cuda/multi_head_attn/helper.h"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
class MultiHeadAttnForwardImpl final : public MultiHeadAttnForward {
public:
using MultiHeadAttnForward::MultiHeadAttnForward;
#if CUDNN_VERSION >= 8004
MultiHeadAttnStatus desc_status;
#endif
void exec(
_megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
_megdnn_tensor_in wqkv, _megdnn_tensor_out out,
_megdnn_tensor_out reserveSpace, _megdnn_workspace workspace) override;
void deduce_layout(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, TensorLayout& out,
TensorLayout& reserveSpace);
size_t get_reservespace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv,
const TensorLayout& out, const TensorLayout& reserveSpace) override;
size_t get_workspace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv,
const TensorLayout& out, const TensorLayout& reserveSpace) override;
};
class MultiHeadAttnBackwardImpl final : public MultiHeadAttnBackward {
public:
using MultiHeadAttnBackward::MultiHeadAttnBackward;
#if CUDNN_VERSION >= 8004
MultiHeadAttnStatus desc_status;
#endif
void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
_megdnn_tensor_in values, _megdnn_tensor_in wqkv,
_megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries,
_megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues,
_megdnn_tensor_out dweights, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& queries,
const TensorLayout& keys, const TensorLayout& values,
const TensorLayout& wqkv, const TensorLayout& reserveSpace,
const TensorLayout& dqueries, const TensorLayout& dkeys,
const TensorLayout& dvalues, const TensorLayout& dweights) override;
};
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -54,6 +54,7 @@
#include "src/naive/matrix_mul/opr_impl.h"
#include "src/naive/max_tensor_diff/opr_impl.h"
#include "src/naive/mesh_indexing/opr_impl.h"
#include "src/naive/multi_head_attn/opr_impl.h"
#include "src/naive/norm/opr_impl.h"
#include "src/naive/padding/opr_impl.h"
#include "src/naive/param_pack/opr_impl.h"
......
#include "src/naive/multi_head_attn/opr_impl.h"
#include "megdnn/oprs/linalg.h"
#include "src/common/utils.cuh"
namespace megdnn {
namespace naive {
using Param = MultiHeadAttnBase::Param;
size_t MultiHeadAttnForwardImpl::get_workspace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv, const TensorLayout& out,
const TensorLayout& reserveSpace) {
MEGDNN_MARK_USED_VAR(queries);
MEGDNN_MARK_USED_VAR(keys);
MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(wqkv);
MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(reserveSpace);
megdnn_throw("unsupported naive multiheadattn forward\n");
}
void MultiHeadAttnForwardImpl::exec(
_megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
_megdnn_tensor_in wqkv, _megdnn_tensor_out out, _megdnn_tensor_out reserveSpace,
_megdnn_workspace workspace) {
MEGDNN_MARK_USED_VAR(queries);
MEGDNN_MARK_USED_VAR(keys);
MEGDNN_MARK_USED_VAR(values);
MEGDNN_MARK_USED_VAR(wqkv);
MEGDNN_MARK_USED_VAR(out);
MEGDNN_MARK_USED_VAR(reserveSpace);
check_exec(
queries.layout, keys.layout, values.layout, wqkv.layout, out.layout,
reserveSpace.layout, workspace.size);
megdnn_throw("unsupported naive multiheadattn forward\n");
}
void MultiHeadAttnBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
_megdnn_tensor_in values, _megdnn_tensor_in wqkv,
_megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries,
_megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues,
_megdnn_tensor_out dweights, _megdnn_workspace workspace) {
check_exec(
diff.layout, queries.layout, keys.layout, values.layout, wqkv.layout,
reserveSpace.layout, dqueries.layout, dkeys.layout, dvalues.layout,
dweights.layout, workspace.size);
megdnn_throw("unsupported naive multiheadattn backward\n");
}
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
#pragma once
#include <memory>
#include "megdnn/oprs.h"
#include "megdnn/oprs/cv.h"
#include "megdnn/oprs/general.h"
#include "megdnn/oprs/linalg.h"
#include "megdnn/oprs/nn.h"
namespace megdnn {
namespace naive {
class MultiHeadAttnForwardImpl final : public MultiHeadAttnForward {
public:
using MultiHeadAttnForward::MultiHeadAttnForward;
void exec(
_megdnn_tensor_in queries, _megdnn_tensor_in keys, _megdnn_tensor_in values,
_megdnn_tensor_in wqkv, _megdnn_tensor_out out,
_megdnn_tensor_out reserveSpace, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& queries, const TensorLayout& keys,
const TensorLayout& values, const TensorLayout& wqkv,
const TensorLayout& out, const TensorLayout& reserveSpace) override;
size_t get_reservespace_in_bytes(
const TensorLayout& /*queries*/, const TensorLayout& /*keys*/,
const TensorLayout& /*values*/, const TensorLayout& /*wqkv*/,
const TensorLayout& /*out*/,
const TensorLayout& /*reserveSpace*/) override {
return 0;
}
};
class MultiHeadAttnBackwardImpl final : public MultiHeadAttnBackward {
public:
using MultiHeadAttnBackward::MultiHeadAttnBackward;
void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in queries, _megdnn_tensor_in keys,
_megdnn_tensor_in values, _megdnn_tensor_in wqkv,
_megdnn_tensor_in reserveSpace, _megdnn_tensor_out dqueries,
_megdnn_tensor_out dkeys, _megdnn_tensor_out dvalues,
_megdnn_tensor_out dweights, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& /*diff*/, const TensorLayout& /* queries*/,
const TensorLayout& /*keyes*/, const TensorLayout& /* values*/,
const TensorLayout& /*wqkv*/, const TensorLayout& /* reserveSpace*/,
const TensorLayout& /*dqueries*/, const TensorLayout& /* dkeyes*/,
const TensorLayout& /*dvalues*/,
const TensorLayout& /* dweights*/) override {
return 0;
}
};
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -35,7 +35,7 @@ from ..core.tensor.utils import (
subgraph,
subgraph_fn,
)
from ..device import get_default_device
from ..device import get_cudnn_version, get_default_device, is_cuda_available
from ..distributed import WORLD, is_distributed
from ..jit import exclude_from_trace
from ..logger import get_logger
......@@ -104,6 +104,7 @@ __all__ = [
"warp_perspective",
"pixel_shuffle",
"region_restricted_conv",
"multi_head_attention",
]
......@@ -2052,7 +2053,82 @@ def region_restricted_conv(
return output
from .quantized import conv_bias_activation # isort:skip
def multi_head_attention(
query: Tensor,
key: Tensor,
value: Tensor,
embed_dim: int,
num_heads: int,
attn_drop: float,
out_drop: float,
io_weight_bias: Optional[Tensor],
bias: bool = False,
reslink: bool = False,
training: bool = True,
attn_mask: bool = False,
enable_qproj: bool = True,
enable_kproj: bool = True,
enable_vproj: bool = True,
enable_oproj: bool = True,
):
r"""Allows the model to jointly attend to information
from different representation subspaces.
See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
.. math::
\text{MultiHeadAttn}\big(q,K,V, W_Q, W_V, W_O\big) = \sum^{nHeads-1}_{i=0}W_{O,i}h_i
where :math:`h_i=W_{V,i}V \text{Softmax}\Big( \text{smScaler} \cdot K^TW^T_{K,i}W_{Q,i}q \Big),\text{for }i\text{ = 0 ... nHeads-1}`.
See :class:`~.module.MultiHeadAttn` for more details.
Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported, and if the cudnn version >=8.6.0, the calculation results are completely correct; If the cudnn version >=8.0.4 but <8.6.0, if there is a bias, only the dbias result calculated from the backward is incorrect. If there is no bias, the forward and backward calculations are correct; If the cudnn version is less than 8.0.4, this operator is not supported.
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
attn_drop: probability of an element to be zeroed, used in attention matrix.
out_drop: probability of an element to be zeroed, used in final output.
io_weight_bias: input/output projection weight/bias all in one, used for cudnn api.
bias: used to indicate a bias in io_weight_bias, used for cudnn api.
reslink: add input query to final output.
training: will apply dropout if is ``True``.
attn_mask: used to indicate whether to add a mask to the attention matrix.
By default, the upper right triangle of the mask is -inf, and the diagonal and lower left triangle are all 0.
Default: `True`
enable_qproj: enable query weight projection. Default: ``True``.
enable_kproj: enable key weight projection. Default: ``True``.
enable_vproj: enable value weight projection. Default: ``True``.
enable_oproj: enable output weight projection. Default: ``True``.
"""
head_dim = embed_dim // num_heads
smScaler = head_dim ** -0.5
op = builtin.MultiHeadAttn(
num_heads=num_heads,
sm_scaler=smScaler,
attn_prob=attn_drop,
out_prob=out_drop,
reslink=reslink,
training=training,
input_order=0,
seed=_get_global_rng_seed(),
bias=bias,
attn_mask=attn_mask,
enable_qproj=enable_qproj,
enable_kproj=enable_kproj,
enable_vproj=enable_vproj,
enable_oproj=enable_oproj,
)
out, reserveSpace = apply(op, query, key, value, io_weight_bias)
return out
from .loss import * # isort:skip
from .metric import * # isort:skip
from .vision import * # isort:skip
from .quantized import conv_bias_activation # isort:skip
......@@ -27,6 +27,7 @@ from .identity import Identity
from .linear import Linear
from .lrn import LocalResponseNorm
from .module import Module
from .multiheadattn import MultiHeadAttention
from .normalization import GeneralNorm, GroupNorm, InstanceNorm, LayerNorm
from .padding import Pad
from .pixel_shuffle import PixelShuffle
......
......@@ -3,6 +3,7 @@ import numpy as np
from ..functional import gelu, leaky_relu, prelu, relu, sigmoid, silu, softmax
from ..tensor import Parameter
from .init import ones_, zeros_
from .module import Module
......
from typing import Optional
import numpy as np
import megengine as mge
import megengine.functional as F
from megengine import Parameter
from ..device import get_cudnn_version, is_cuda_available
from ..functional.nn import multi_head_attention
from ..tensor import Tensor
from .init import ones_, zeros_
from .module import Module
class MultiHeadAttention(Module):
r"""Allows the model to jointly attend to information
from different representation subspaces.
See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
.. math::
\text{MultiHeadAttn}\big(q,K,V, W_Q, W_V, W_O\big) = \sum^{nHeads-1}_{i=0}W_{O,i}h_i
where :math:`h_i=W_{V,i}V \text{Softmax}\Big( \text{smScaler} \cdot K^TW^T_{K,i}W_{Q,i}q \Big),\text{for }i\text{ = 0 ... nHeads-1}`.
Note: This API is experimental, and there is a possibility of subsequent changes. Currently, only the cuda platform is supported, and if the cudnn version >=8.6.0, the calculation results are completely correct; If the cudnn version >=8.0.4 but <8.6.0, if there is a bias, only the dbias result calculated from the backward is incorrect. If there is no bias, the forward and backward calculations are correct; If the cudnn version is less than 8.0.4, this operator is not supported.
Args:
embed_dim: Total dimension of the model.
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
enable_qproj: enable query weight projection. Default: ``True``.
enable_kproj: enable key weight projection. Default: ``True``.
enable_vproj: enable value weight projection. Default: ``True``.
enable_oproj: enable output weight projection. Default: ``True``.
Examples::
>>> import numpy as np
>>> batch_size, seq_len, embed_dim, num_heads = 2, 4, 4, 2
>>> x = Tensor(np.arange(batch_size * seq_len * embed_dim).astype(np.float32).reshape(batch_size, seq_len, embed_dim))
>>> multihead_attn = M.MultiHeadAttention(embed_dim, num_heads)
>>> if is_cuda_available() and get_cudnn_version() >= 8004:
... out = multihead_attn(x, x, x)
... out.numpy().shape
... else:
... print(np.zeros((2,4,4)).shape)
(2, 4, 4)
"""
def __init__(
self,
embed_dim,
num_heads,
attn_dropout=0.0,
out_dropout=0.0,
kdim=None,
vdim=None,
bias=True,
enable_qproj=True,
enable_kproj=True,
enable_vproj=True,
enable_oproj=True,
**kwargs
):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.attn_dropout = attn_dropout
self.out_dropout = out_dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
assert (
self._qkv_same_embed_dim
), "it does not support the case where q, k, and v are different."
self.bias = bias
self.enable_qproj = enable_qproj
self.enable_kproj = enable_kproj
self.enable_vproj = enable_vproj
self.enable_oproj = enable_oproj
self.nproj = enable_qproj + enable_kproj + enable_vproj + enable_oproj
if self.bias:
io_weight = np.ones((embed_dim, self.nproj * embed_dim))
io_bias = np.zeros((1, self.nproj * embed_dim))
self.io_weight_bias = Parameter(
np.concatenate((io_weight, io_bias), axis=0), dtype="float32"
)
else:
self.io_weight_bias = Parameter(
np.ones((self.nproj * embed_dim, embed_dim), dtype="float32")
)
self.reset_parameters()
def reset_parameters(self):
self.attn_dropout = 0.0
self.out_dropout = 0.0
if self.bias:
io_weight = np.ones((self.embed_dim, self.nproj * self.embed_dim))
io_bias = np.zeros((1, self.nproj * self.embed_dim))
self.io_weight_bias._reset(np.concatenate((io_weight, io_bias), axis=0))
else:
ones_(self.io_weight_bias)
def forward(
self, query, key, value, attn_mask: bool = True,
):
r"""
Args:
query: Query embeddings of shape :math:`(N, L, E_q)`, where :math:`N` is the batch size, :math:`L` is the target sequence length,
and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against
key-value pairs to produce the output. See "Attention Is All You Need" for more details.
key: Key embeddings of shape :math:`(N, S, E_k)`, where :math:`N` is the batch size, :math:`S` is the source sequence length, and
:math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details.
value: Value embeddings of shape :math:`(N, S, E_v)`, where :math:`N` is the batch size, :math:`S` is the source sequence length, and
:math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details.
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
Outputs:
- **attn_output** - Attention outputs of shape :math:`(N, L, E)`,
where :math:`L` is the target sequence length, :math:`N` is
the batch size, and :math:`E` is the embedding dimension ``embed_dim``.
"""
return multi_head_attention(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.attn_dropout,
self.out_dropout,
self.io_weight_bias,
self.bias,
training=self.training,
attn_mask=attn_mask,
enable_qproj=self.enable_qproj,
enable_kproj=self.enable_kproj,
enable_vproj=self.enable_vproj,
enable_oproj=self.enable_oproj,
)
def _module_info_string(self) -> str:
s = "embed_dim={embed_dim}, num_heads={num_heads}, dropout={dropout}, bias={bias}, kdim={kdim}, vdim={vdim}"
return s.format(**self.__dict__)
......@@ -285,6 +285,25 @@ struct OpMeth<Dropout> {
}
};
template <>
struct OpMeth<MultiHeadAttn> {
using DnnOp = megdnn::MultiHeadAttn;
using Param = DnnOp::Param;
using OpNode = mgb::opr::MultiHeadAttn;
static Param make_param(const MultiHeadAttn& opdef) {
auto handle_seed = RNGDnnOpManager::get_seed(opdef.handle);
mgb_assert(
handle_seed == opdef.seed,
"inconsistent multiheadattn seed: dropout op: %lu handle: %lu",
handle_seed, opdef.seed);
return {opdef.num_heads, opdef.sm_scaler, opdef.input_order,
opdef.reslink, opdef.training, opdef.bias,
opdef.attn_mask, opdef.enable_qproj, opdef.enable_kproj,
opdef.enable_vproj, opdef.enable_oproj, handle_seed,
opdef.attn_prob, opdef.out_prob};
}
};
template <bool>
struct _InferLayout;
......@@ -401,6 +420,14 @@ _INST_RNG_MAKER(2)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN
#define _FOR_EACH_IN(subfix) \
inputs[0] subfix, inputs[1] subfix, inputs[2] subfix, inputs[3] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix
_INST_RNG_INVOLKER(4, 2)
_INST_RNG_MAKER(4)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN
#undef _INST_RNG_INVOLKER
#undef _INST_RNG_MAKER
......@@ -506,6 +533,39 @@ SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>(
return dests;
}
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<MultiHeadAttn>(
const OpDef& op, const SmallVector<TensorPtr>& inputs) {
SmallVector<LogicalTensorDesc> dests(2);
auto&& cn = inputs[0]->comp_node();
dests[0].comp_node = cn;
dests[0].layout = TensorLayout(inputs[0]->layout());
dests[0].layout.dtype = inputs[0]->layout().dtype;
auto get_reservespace_in_bytes = [&]() -> size_t {
// retrieve dnn_op from glob cache
auto&& rng = op.cast_final_safe<MultiHeadAttn>();
auto handle = rng.handle;
if (!handle) {
handle = RNGDnnOpManager::get_default_handle(cn);
}
auto dnn_op_thread_safe =
RNGDnnOpManager::inst().get_dnn_op<megdnn::MultiHeadAttn>(
handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn);
auto dnn_op = std::get<1>(dnn_op_thread_safe);
dnn_op->param() = OpMeth<MultiHeadAttn>::make_param(rng);
return dnn_op->get_reservespace_in_bytes(
inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
inputs[3]->layout(), {}, {});
};
dests[1].comp_node = cn;
dests[1].layout =
TensorLayout(TensorShape({get_reservespace_in_bytes()}), dtype::Byte());
return dests;
}
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
......@@ -600,6 +660,44 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dro
return {dests, success};
}
template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<
MultiHeadAttn>(const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) {
bool success = inputs[0].layout.ndim != 0;
SmallVector<LogicalTensorDesc> dests(2);
auto cn = inputs[0].comp_node;
dests[0].comp_node = cn;
dests[0].layout = TensorLayout(inputs[0].layout);
dests[0].layout.dtype = inputs[0].layout.dtype;
auto get_reservespace_in_bytes = [&]() -> size_t {
auto&& rng = op.cast_final_safe<MultiHeadAttn>();
auto handle = rng.handle;
if (!handle) {
handle = RNGDnnOpManager::get_default_handle(cn);
}
auto dnn_op_thread_safe =
RNGDnnOpManager::inst().get_dnn_op<megdnn::MultiHeadAttn>(
handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn);
auto dnn_op = std::get<1>(dnn_op_thread_safe);
dnn_op->param() = OpMeth<MultiHeadAttn>::make_param(rng);
return dnn_op->get_reservespace_in_bytes(
inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout,
{}, {});
};
dests[1].comp_node = cn;
if (success) {
dests[1].layout =
TensorLayout(TensorShape({get_reservespace_in_bytes()}), dtype::Byte());
} else {
dests[1].layout = TensorLayout(dtype::Byte());
}
return {dests, success};
}
template <typename Op>
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
......@@ -647,6 +745,7 @@ REG_RNG_OP(PoissonRNG, SymbolVar)
REG_RNG_OP(BetaRNG, SymbolVar)
REG_RNG_OP(ShuffleRNG, SymbolVarArray)
REG_RNG_OP(Dropout, SymbolVarArray)
REG_RNG_OP(MultiHeadAttn, SymbolVarArray)
#undef REG_RNG_OP
} // namespace mgb::imperative::rng
......
148f3844ee8787250cd231eb3c1989c3 ../../dnn/scripts/opr_param_defs.py
b603857c46345dcb9f1693f49217b269 ../../src/core/include/megbrain/ir/ops.td
ae54e2eba267dc21d8c648963df23a90 generated/opdef.h.inl
94b20dcecd3dea69883d46ca7b8482be generated/opdef.cpp.inl
4ae5f0198e97e69eb381411f3d60e8c8 generated/opdef.py.inl
4971c6b2ba7f6fca395d73c554526a0e generated/opdef.cpy.inl
c5a5d1bd44473912f14cecee3df6409e ../../dnn/scripts/opr_param_defs.py
4ed3e8cbef0fa5f4d6824d8d55dec722 ../../src/core/include/megbrain/ir/ops.td
dc2d4ec8f4f5e203ce0a76bc20f62529 generated/opdef.h.inl
906957f12994d43c69248a6acfefa396 generated/opdef.cpp.inl
8817af8997ba0cc00048e71093755238 generated/opdef.py.inl
c43ae8b706e3f3658fe3cc0f60061981 generated/opdef.cpy.inl
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
......@@ -5186,6 +5186,96 @@ OP_TRAIT_REG(MeshIndexing, MeshIndexing)
.props(MeshIndexing_props_impl)
.make_name(MeshIndexing_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MultiHeadAttn);
namespace {
size_t MultiHeadAttn_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<MultiHeadAttn>();
static_cast<void>(op_);
return mgb::hash_pair_combine(
mgb::hash(op_.dyn_typeinfo()),
mgb::hash_pair_combine(
mgb::hash(op_.handle),
mgb::hash_pair_combine(
mgb::hash(op_.num_heads),
mgb::hash_pair_combine(
mgb::hash(op_.sm_scaler),
mgb::hash_pair_combine(
mgb::hash(op_.input_order),
mgb::hash_pair_combine(
mgb::hash(op_.reslink),
mgb::hash_pair_combine(
mgb::hash(op_.training),
mgb::hash_pair_combine(
mgb::hash(op_.bias),
mgb::hash_pair_combine(
mgb::hash(op_.attn_mask),
mgb::hash_pair_combine(
mgb::hash(op_.enable_qproj),
mgb::hash_pair_combine(
mgb::hash(op_.enable_kproj),
mgb::hash_pair_combine(
mgb::hash(op_.enable_vproj),
mgb::hash_pair_combine(
mgb::hash(op_.enable_oproj),
mgb::hash_pair_combine(
mgb::hash(op_.attn_prob),
mgb::hash(op_.out_prob)
)
)
)
)
)
)
)
)
)
)
)
)
)
);
}
bool MultiHeadAttn_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<MultiHeadAttn>(),
&&b_ = rhs_.cast_final_safe<MultiHeadAttn>();
static_cast<void>(a_);
static_cast<void>(b_);
return a_.handle == b_.handle && a_.num_heads == b_.num_heads && a_.sm_scaler == b_.sm_scaler && a_.input_order == b_.input_order && a_.reslink == b_.reslink && a_.training == b_.training && a_.bias == b_.bias && a_.attn_mask == b_.attn_mask && a_.enable_qproj == b_.enable_qproj && a_.enable_kproj == b_.enable_kproj && a_.enable_vproj == b_.enable_vproj && a_.enable_oproj == b_.enable_oproj && a_.attn_prob == b_.attn_prob && a_.out_prob == b_.out_prob;}
std::vector<std::pair<const char*, std::string>> MultiHeadAttn_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<MultiHeadAttn>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
props_.emplace_back("num_heads", std::to_string(op_.num_heads));
props_.emplace_back("sm_scaler", std::to_string(op_.sm_scaler));
props_.emplace_back("input_order", std::to_string(op_.input_order));
props_.emplace_back("reslink", std::to_string(op_.reslink));
props_.emplace_back("training", std::to_string(op_.training));
props_.emplace_back("bias", std::to_string(op_.bias));
props_.emplace_back("attn_mask", std::to_string(op_.attn_mask));
props_.emplace_back("enable_qproj", std::to_string(op_.enable_qproj));
props_.emplace_back("enable_kproj", std::to_string(op_.enable_kproj));
props_.emplace_back("enable_vproj", std::to_string(op_.enable_vproj));
props_.emplace_back("enable_oproj", std::to_string(op_.enable_oproj));
props_.emplace_back("seed", std::to_string(op_.seed));
props_.emplace_back("attn_prob", std::to_string(op_.attn_prob));
props_.emplace_back("out_prob", std::to_string(op_.out_prob));
props_.emplace_back("handle", std::to_string(op_.handle));
return props_;
}
std::string MultiHeadAttn_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<MultiHeadAttn>();
static_cast<void>(op_);
return "MultiHeadAttn";
}
} // anonymous namespace
OP_TRAIT_REG(MultiHeadAttn, MultiHeadAttn)
.hash(MultiHeadAttn_hash_impl)
.is_same_st(MultiHeadAttn_is_same_st_impl)
.props(MultiHeadAttn_props_impl)
.make_name(MultiHeadAttn_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NMSKeep);
namespace {
......
......@@ -15043,6 +15043,367 @@ void _init_py_MeshIndexing(py::module m) {
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MeshIndexing::typeinfo(), &py_type).second);
}
PyOpDefBegin(MultiHeadAttn) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"num_heads", serialization<decltype(opdef.num_heads)>::dump(opdef.num_heads)},
{"sm_scaler", serialization<decltype(opdef.sm_scaler)>::dump(opdef.sm_scaler)},
{"input_order", serialization<decltype(opdef.input_order)>::dump(opdef.input_order)},
{"reslink", serialization<decltype(opdef.reslink)>::dump(opdef.reslink)},
{"training", serialization<decltype(opdef.training)>::dump(opdef.training)},
{"bias", serialization<decltype(opdef.bias)>::dump(opdef.bias)},
{"attn_mask", serialization<decltype(opdef.attn_mask)>::dump(opdef.attn_mask)},
{"enable_qproj", serialization<decltype(opdef.enable_qproj)>::dump(opdef.enable_qproj)},
{"enable_kproj", serialization<decltype(opdef.enable_kproj)>::dump(opdef.enable_kproj)},
{"enable_vproj", serialization<decltype(opdef.enable_vproj)>::dump(opdef.enable_vproj)},
{"enable_oproj", serialization<decltype(opdef.enable_oproj)>::dump(opdef.enable_oproj)},
{"seed", serialization<decltype(opdef.seed)>::dump(opdef.seed)},
{"attn_prob", serialization<decltype(opdef.attn_prob)>::dump(opdef.attn_prob)},
{"out_prob", serialization<decltype(opdef.out_prob)>::dump(opdef.out_prob)},
{"handle", serialization<decltype(opdef.handle)>::dump(opdef.handle)}
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst();
static_cast<void>(opdef);
{
auto&& iter = state.find("num_heads");
if (iter != state.end()) {
opdef.num_heads = serialization<decltype(opdef.num_heads)>::load(iter->second);
}
}
{
auto&& iter = state.find("sm_scaler");
if (iter != state.end()) {
opdef.sm_scaler = serialization<decltype(opdef.sm_scaler)>::load(iter->second);
}
}
{
auto&& iter = state.find("input_order");
if (iter != state.end()) {
opdef.input_order = serialization<decltype(opdef.input_order)>::load(iter->second);
}
}
{
auto&& iter = state.find("reslink");
if (iter != state.end()) {
opdef.reslink = serialization<decltype(opdef.reslink)>::load(iter->second);
}
}
{
auto&& iter = state.find("training");
if (iter != state.end()) {
opdef.training = serialization<decltype(opdef.training)>::load(iter->second);
}
}
{
auto&& iter = state.find("bias");
if (iter != state.end()) {
opdef.bias = serialization<decltype(opdef.bias)>::load(iter->second);
}
}
{
auto&& iter = state.find("attn_mask");
if (iter != state.end()) {
opdef.attn_mask = serialization<decltype(opdef.attn_mask)>::load(iter->second);
}
}
{
auto&& iter = state.find("enable_qproj");
if (iter != state.end()) {
opdef.enable_qproj = serialization<decltype(opdef.enable_qproj)>::load(iter->second);
}
}
{
auto&& iter = state.find("enable_kproj");
if (iter != state.end()) {
opdef.enable_kproj = serialization<decltype(opdef.enable_kproj)>::load(iter->second);
}
}
{
auto&& iter = state.find("enable_vproj");
if (iter != state.end()) {
opdef.enable_vproj = serialization<decltype(opdef.enable_vproj)>::load(iter->second);
}
}
{
auto&& iter = state.find("enable_oproj");
if (iter != state.end()) {
opdef.enable_oproj = serialization<decltype(opdef.enable_oproj)>::load(iter->second);
}
}
{
auto&& iter = state.find("seed");
if (iter != state.end()) {
opdef.seed = serialization<decltype(opdef.seed)>::load(iter->second);
}
}
{
auto&& iter = state.find("attn_prob");
if (iter != state.end()) {
opdef.attn_prob = serialization<decltype(opdef.attn_prob)>::load(iter->second);
}
}
{
auto&& iter = state.find("out_prob");
if (iter != state.end()) {
opdef.out_prob = serialization<decltype(opdef.out_prob)>::load(iter->second);
}
}
{
auto&& iter = state.find("handle");
if (iter != state.end()) {
opdef.handle = serialization<decltype(opdef.handle)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds);
static PyMethodDef py_init_methoddef;
// };
PyOpDefEnd(MultiHeadAttn)
int PyOp(MultiHeadAttn)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"num_heads", "sm_scaler", "input_order", "reslink", "training", "bias", "attn_mask", "enable_qproj", "enable_kproj", "enable_vproj", "enable_oproj", "seed", "attn_prob", "out_prob", "handle", "scope", NULL};
PyObject *num_heads = NULL, *sm_scaler = NULL, *input_order = NULL, *reslink = NULL, *training = NULL, *bias = NULL, *attn_mask = NULL, *enable_qproj = NULL, *enable_kproj = NULL, *enable_vproj = NULL, *enable_oproj = NULL, *seed = NULL, *attn_prob = NULL, *out_prob = NULL, *handle = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOOOOOOOOOOOOO", const_cast<char**>(kwlist), &num_heads, &sm_scaler, &input_order, &reslink, &training, &bias, &attn_mask, &enable_qproj, &enable_kproj, &enable_vproj, &enable_oproj, &seed, &attn_prob, &out_prob, &handle, &scope))
return -1;
if (num_heads) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().num_heads =
py::cast<decltype(MultiHeadAttn::num_heads)>(py::handle(num_heads));
} CATCH_ALL(-1)
}
if (sm_scaler) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().sm_scaler =
py::cast<decltype(MultiHeadAttn::sm_scaler)>(py::handle(sm_scaler));
} CATCH_ALL(-1)
}
if (input_order) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().input_order =
py::cast<decltype(MultiHeadAttn::input_order)>(py::handle(input_order));
} CATCH_ALL(-1)
}
if (reslink) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().reslink =
py::cast<decltype(MultiHeadAttn::reslink)>(py::handle(reslink));
} CATCH_ALL(-1)
}
if (training) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().training =
py::cast<decltype(MultiHeadAttn::training)>(py::handle(training));
} CATCH_ALL(-1)
}
if (bias) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().bias =
py::cast<decltype(MultiHeadAttn::bias)>(py::handle(bias));
} CATCH_ALL(-1)
}
if (attn_mask) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().attn_mask =
py::cast<decltype(MultiHeadAttn::attn_mask)>(py::handle(attn_mask));
} CATCH_ALL(-1)
}
if (enable_qproj) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().enable_qproj =
py::cast<decltype(MultiHeadAttn::enable_qproj)>(py::handle(enable_qproj));
} CATCH_ALL(-1)
}
if (enable_kproj) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().enable_kproj =
py::cast<decltype(MultiHeadAttn::enable_kproj)>(py::handle(enable_kproj));
} CATCH_ALL(-1)
}
if (enable_vproj) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().enable_vproj =
py::cast<decltype(MultiHeadAttn::enable_vproj)>(py::handle(enable_vproj));
} CATCH_ALL(-1)
}
if (enable_oproj) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().enable_oproj =
py::cast<decltype(MultiHeadAttn::enable_oproj)>(py::handle(enable_oproj));
} CATCH_ALL(-1)
}
if (seed) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().seed =
py::cast<decltype(MultiHeadAttn::seed)>(py::handle(seed));
} CATCH_ALL(-1)
}
if (attn_prob) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().attn_prob =
py::cast<decltype(MultiHeadAttn::attn_prob)>(py::handle(attn_prob));
} CATCH_ALL(-1)
}
if (out_prob) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().out_prob =
py::cast<decltype(MultiHeadAttn::out_prob)>(py::handle(out_prob));
} CATCH_ALL(-1)
}
if (handle) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(MultiHeadAttn)*>(self)->inst().handle =
py::cast<decltype(MultiHeadAttn::handle)>(py::handle(handle));
} CATCH_ALL(-1)
}
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}
return 0;
}
PyGetSetDef PyOp(MultiHeadAttn)::py_getsetters[] = {
{const_cast<char*>("num_heads"), py_get_generic(MultiHeadAttn, num_heads), py_set_generic(MultiHeadAttn, num_heads), const_cast<char*>("num_heads"), NULL},
{const_cast<char*>("sm_scaler"), py_get_generic(MultiHeadAttn, sm_scaler), py_set_generic(MultiHeadAttn, sm_scaler), const_cast<char*>("sm_scaler"), NULL},
{const_cast<char*>("input_order"), py_get_generic(MultiHeadAttn, input_order), py_set_generic(MultiHeadAttn, input_order), const_cast<char*>("input_order"), NULL},
{const_cast<char*>("reslink"), py_get_generic(MultiHeadAttn, reslink), py_set_generic(MultiHeadAttn, reslink), const_cast<char*>("reslink"), NULL},
{const_cast<char*>("training"), py_get_generic(MultiHeadAttn, training), py_set_generic(MultiHeadAttn, training), const_cast<char*>("training"), NULL},
{const_cast<char*>("bias"), py_get_generic(MultiHeadAttn, bias), py_set_generic(MultiHeadAttn, bias), const_cast<char*>("bias"), NULL},
{const_cast<char*>("attn_mask"), py_get_generic(MultiHeadAttn, attn_mask), py_set_generic(MultiHeadAttn, attn_mask), const_cast<char*>("attn_mask"), NULL},
{const_cast<char*>("enable_qproj"), py_get_generic(MultiHeadAttn, enable_qproj), py_set_generic(MultiHeadAttn, enable_qproj), const_cast<char*>("enable_qproj"), NULL},
{const_cast<char*>("enable_kproj"), py_get_generic(MultiHeadAttn, enable_kproj), py_set_generic(MultiHeadAttn, enable_kproj), const_cast<char*>("enable_kproj"), NULL},
{const_cast<char*>("enable_vproj"), py_get_generic(MultiHeadAttn, enable_vproj), py_set_generic(MultiHeadAttn, enable_vproj), const_cast<char*>("enable_vproj"), NULL},
{const_cast<char*>("enable_oproj"), py_get_generic(MultiHeadAttn, enable_oproj), py_set_generic(MultiHeadAttn, enable_oproj), const_cast<char*>("enable_oproj"), NULL},
{const_cast<char*>("seed"), py_get_generic(MultiHeadAttn, seed), py_set_generic(MultiHeadAttn, seed), const_cast<char*>("seed"), NULL},
{const_cast<char*>("attn_prob"), py_get_generic(MultiHeadAttn, attn_prob), py_set_generic(MultiHeadAttn, attn_prob), const_cast<char*>("attn_prob"), NULL},
{const_cast<char*>("out_prob"), py_get_generic(MultiHeadAttn, out_prob), py_set_generic(MultiHeadAttn, out_prob), const_cast<char*>("out_prob"), NULL},
{const_cast<char*>("handle"), py_get_generic(MultiHeadAttn, handle), py_set_generic(MultiHeadAttn, handle), const_cast<char*>("handle"), NULL},
{NULL} /* Sentinel */
};
PyMethodDef PyOp(MultiHeadAttn)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(MultiHeadAttn)::getstate, METH_NOARGS, "MultiHeadAttn getstate"},
{const_cast<char*>("__setstate__"), PyOp(MultiHeadAttn)::setstate, METH_VARARGS, "MultiHeadAttn setstate"},
{NULL} /* Sentinel */
};
PyObject *PyOp(MultiHeadAttn)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) {
if (PyOp(MultiHeadAttn)::py_init(self, args, kwds) < 0) {
return NULL;
}
Py_RETURN_NONE;
}
PyMethodDef PyOp(MultiHeadAttn)::py_init_methoddef = {
"__init__",
(PyCFunction)PyOp(MultiHeadAttn)::py_init_proxy,
METH_VARARGS | METH_KEYWORDS,
"__init__(self, num_heads: int = ..., sm_scaler: float = ..., input_order: int = ..., reslink: bool = ..., training: bool = ..., bias: bool = ..., attn_mask: bool = ..., enable_qproj: bool = ..., enable_kproj: bool = ..., enable_vproj: bool = ..., enable_oproj: bool = ..., seed: int = ..., attn_prob: float = ..., out_prob: float = ..., handle: int = ...) -> None\n"
};
void _init_py_MultiHeadAttn(py::module m) {
using py_op = PyOp(MultiHeadAttn);
auto& py_type = PyOpType(MultiHeadAttn);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.MultiHeadAttn";
py_type.tp_basicsize = sizeof(PyOp(MultiHeadAttn));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "MultiHeadAttn";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
py_type.tp_dict = PyDict_New();
PyObject* descr = PyDescr_NewMethod(&PyOpType(MultiHeadAttn), &PyOp(MultiHeadAttn)::py_init_methoddef);
PyDict_SetItemString(py_type.tp_dict, "__init__", descr);
mgb_assert(PyType_Ready(&py_type) >= 0);
PyType_Modified(&py_type);
m.add_object("MultiHeadAttn", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MultiHeadAttn::typeinfo(), &py_type).second);
}
PyOpDefBegin(NMSKeep) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
......@@ -22608,6 +22969,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
_init_py_MatrixMul(m); \
_init_py_MeshGrid(m); \
_init_py_MeshIndexing(m); \
_init_py_MultiHeadAttn(m); \
_init_py_NMSKeep(m); \
_init_py_NvOf(m); \
_init_py_Padding(m); \
......
......@@ -1394,6 +1394,33 @@ public:
MeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
class MultiHeadAttn : public OpDefImplBase<MultiHeadAttn> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
uint32_t num_heads = 1;
float sm_scaler = 1.f;
uint32_t input_order = 0;
bool reslink = false;
bool training = true;
bool bias = false;
bool attn_mask = false;
bool enable_qproj = true;
bool enable_kproj = true;
bool enable_vproj = true;
bool enable_oproj = true;
uint64_t seed = 0;
float attn_prob = 0.f;
float out_prob = 0.f;
size_t handle;
MultiHeadAttn() = default;
MultiHeadAttn(uint32_t num_heads_, float sm_scaler_, uint32_t input_order_, bool reslink_, bool training_, bool bias_, bool attn_mask_, bool enable_qproj_, bool enable_kproj_, bool enable_vproj_, bool enable_oproj_, uint64_t seed_, float attn_prob_, float out_prob_, size_t handle_, std::string scope_ = {}): num_heads(num_heads_), sm_scaler(sm_scaler_), input_order(input_order_), reslink(reslink_), training(training_), bias(bias_), attn_mask(attn_mask_), enable_qproj(enable_qproj_), enable_kproj(enable_kproj_), enable_vproj(enable_vproj_), enable_oproj(enable_oproj_), seed(seed_), attn_prob(attn_prob_), out_prob(out_prob_), handle(handle_) { set_scope(scope_); }
MultiHeadAttn(::megdnn::param::MultiHeadAttn packed_param_0, size_t handle_): num_heads(packed_param_0.num_heads), sm_scaler(packed_param_0.sm_scaler), input_order(packed_param_0.input_order), reslink(packed_param_0.reslink), training(packed_param_0.training), bias(packed_param_0.bias), attn_mask(packed_param_0.attn_mask), enable_qproj(packed_param_0.enable_qproj), enable_kproj(packed_param_0.enable_kproj), enable_vproj(packed_param_0.enable_vproj), enable_oproj(packed_param_0.enable_oproj), seed(packed_param_0.seed), attn_prob(packed_param_0.attn_prob), out_prob(packed_param_0.out_prob), handle(handle_) {}
::megdnn::param::MultiHeadAttn param() const {
return {num_heads, sm_scaler, input_order, reslink, training, bias, attn_mask, enable_qproj, enable_kproj, enable_vproj, enable_oproj, seed, attn_prob, out_prob};
}
};
class NMSKeep : public OpDefImplBase<NMSKeep> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
......
......@@ -1477,6 +1477,27 @@ MeshIndexingInst
.def(py::init<>())
.def_readwrite("items", &MeshIndexing::items);
py::class_<MultiHeadAttn, std::shared_ptr<MultiHeadAttn>, OpDef> MultiHeadAttnInst(m, "MultiHeadAttn");
MultiHeadAttnInst
.def(py::init<uint32_t, float, uint32_t, bool, bool, bool, bool, bool, bool, bool, bool, uint64_t, float, float, size_t, std::string>(), py::arg("num_heads") = 1, py::arg("sm_scaler") = 1.f, py::arg("input_order") = 0, py::arg("reslink") = false, py::arg("training") = true, py::arg("bias") = false, py::arg("attn_mask") = false, py::arg("enable_qproj") = true, py::arg("enable_kproj") = true, py::arg("enable_vproj") = true, py::arg("enable_oproj") = true, py::arg("seed") = 0, py::arg("attn_prob") = 0.f, py::arg("out_prob") = 0.f, py::arg("handle"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("num_heads", &MultiHeadAttn::num_heads)
.def_readwrite("sm_scaler", &MultiHeadAttn::sm_scaler)
.def_readwrite("input_order", &MultiHeadAttn::input_order)
.def_readwrite("reslink", &MultiHeadAttn::reslink)
.def_readwrite("training", &MultiHeadAttn::training)
.def_readwrite("bias", &MultiHeadAttn::bias)
.def_readwrite("attn_mask", &MultiHeadAttn::attn_mask)
.def_readwrite("enable_qproj", &MultiHeadAttn::enable_qproj)
.def_readwrite("enable_kproj", &MultiHeadAttn::enable_kproj)
.def_readwrite("enable_vproj", &MultiHeadAttn::enable_vproj)
.def_readwrite("enable_oproj", &MultiHeadAttn::enable_oproj)
.def_readwrite("seed", &MultiHeadAttn::seed)
.def_readwrite("attn_prob", &MultiHeadAttn::attn_prob)
.def_readwrite("out_prob", &MultiHeadAttn::out_prob)
.def_readwrite("handle", &MultiHeadAttn::handle);
py::class_<NMSKeep, std::shared_ptr<NMSKeep>, OpDef> NMSKeepInst(m, "NMSKeep");
NMSKeepInst
......
......@@ -559,4 +559,57 @@ def RegionRestrictedConvolution: MgbHashableOp<"RegionRestrictedConvolution", [C
def RegionRestrictedConvolutionBackwardData: MgbHashableOp<"RegionRestrictedConvolutionBackwardData", [ConvolutionParam]>;
def MaskedFill: MgbHashableOp<"MaskedFill", [FillParam]>;
def MultiHeadAttn: MgbHashableOp<"MultiHeadAttn", [MultiHeadAttnParam]> {
let extraArguments = (ins
MgbSizeTAddr:$handle
);
let hashFunction = [{
return mgb::hash_pair_combine(
mgb::hash($_self.dyn_typeinfo()),
mgb::hash_pair_combine(
mgb::hash($_self.handle),
mgb::hash_pair_combine(
mgb::hash($_self.num_heads),
mgb::hash_pair_combine(
mgb::hash($_self.sm_scaler),
mgb::hash_pair_combine(
mgb::hash($_self.input_order),
mgb::hash_pair_combine(
mgb::hash($_self.reslink),
mgb::hash_pair_combine(
mgb::hash($_self.training),
mgb::hash_pair_combine(
mgb::hash($_self.bias),
mgb::hash_pair_combine(
mgb::hash($_self.attn_mask),
mgb::hash_pair_combine(
mgb::hash($_self.enable_qproj),
mgb::hash_pair_combine(
mgb::hash($_self.enable_kproj),
mgb::hash_pair_combine(
mgb::hash($_self.enable_vproj),
mgb::hash_pair_combine(
mgb::hash($_self.enable_oproj),
mgb::hash_pair_combine(
mgb::hash($_self.attn_prob),
mgb::hash($_self.out_prob)
)
)
)
)
)
)
)
)
)
)
)
)
)
);
}];
let cmpFunction = [{return $0.handle == $1.handle && $0.num_heads == $1.num_heads && $0.sm_scaler == $1.sm_scaler && $0.input_order == $1.input_order && $0.reslink == $1.reslink && $0.training == $1.training && $0.bias == $1.bias && $0.attn_mask == $1.attn_mask && $0.enable_qproj == $1.enable_qproj && $0.enable_kproj == $1.enable_kproj && $0.enable_vproj == $1.enable_vproj && $0.enable_oproj == $1.enable_oproj && $0.attn_prob == $1.attn_prob && $0.out_prob == $1.out_prob;}];
}
#endif // MGB_OPS
......@@ -159,6 +159,11 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 4
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 4
#define _NR_OUTPUTS 4
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0), _o(1), _o(2), _o(3)
......@@ -179,6 +184,12 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 5
#define _NR_OUTPUTS 5
#define _FOREACH_IO(_i, _o) \
_i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2), _o(3), _o(4)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 6
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0)
......@@ -195,6 +206,12 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU
_i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 6
#define _NR_OUTPUTS 4
#define _FOREACH_IO(_i, _o) \
_i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1), _o(2), _o(3)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 7
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) \
......
......@@ -192,6 +192,8 @@ template class RNGOprBase<::megdnn::ShuffleRNGForward>;
template class RNGOprBase<::megdnn::ShuffleRNGBackward>;
template class RNGOprBase<::megdnn::DropoutForward>;
template class RNGOprBase<::megdnn::DropoutBackward>;
template class RNGOprBase<::megdnn::MultiHeadAttnForward>;
template class RNGOprBase<::megdnn::MultiHeadAttnBackward>;
#if MGB_ENABLE_GRAD
IMPL(GaussianRNG);
IMPL(UniformRNG);
......@@ -375,7 +377,7 @@ MGB_IMPL_OPR_GRAD(DropoutForward) {
}
#endif
/* ==================== LayerNormBackward ==================== */
/* ==================== DropoutBackward ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(DropoutBackward);
......@@ -421,4 +423,170 @@ void DropoutBackward::scn_do_execute() {
output(0)->dev_tensor().as_megdnn(), {});
}
/* ==================== MultiHeadAttnForward ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MultiHeadAttnForward);
MultiHeadAttnForward::MultiHeadAttnForward(
VarNode* queries, VarNode* keys, VarNode* values, VarNode* wqkv,
const Param& param, const OperatorNodeConfig& config)
: Super{{queries->owner_graph(),
config,
"multi_head_attn",
{queries, keys, values, wqkv}},
param} {
add_input({queries, keys, values, wqkv});
add_output(None)
->dtype(queries->dtype())
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
add_output(None)->dtype(dtype::Byte()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
cg::add_workspace_output(this);
add_equivalence_component<ScalarHash<void*>>(this);
}
SymbolVarArray MultiHeadAttnForward::make(
SymbolVar queries, SymbolVar keys, SymbolVar values, SymbolVar wqkv,
const Param& param, const OperatorNodeConfig& config) {
auto outs = queries.node()
->owner_graph()
->insert_opr(std::make_unique<MultiHeadAttnForward>(
queries.node(), keys.node(), values.node(), wqkv.node(),
param, config))
->output();
mgb_assert(outs.size() == 3);
return {outs[0], outs[1]};
}
void MultiHeadAttnForward::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));
auto infer_mask = [this](TensorShape& dest, const InpVal& iv) {
ensure_megdnn_opr();
dest.ndim = 1;
dest.shape[0] = m_dnn_opr->get_reservespace_in_bytes(
{iv.val[0].shape(), input(0)->dtype()},
{iv.val[1].shape(), input(1)->dtype()},
{iv.val[2].shape(), input(2)->dtype()},
{iv.val[3].shape(), input(3)->dtype()}, {}, {});
return true;
};
mgr.register_shape_infer(
output(1), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_mask});
}
void MultiHeadAttnForward::add_input_layout_constraint() {
input(0)->add_layout_constraint_contiguous();
input(1)->add_layout_constraint_contiguous();
input(2)->add_layout_constraint_contiguous();
input(3)->add_layout_constraint_contiguous();
};
void MultiHeadAttnForward::scn_do_execute() {
auto&& ret = output(0);
if (ret->layout().is_empty()) {
mgb_assert(ret->dev_tensor().empty());
return;
}
m_dnn_opr->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(),
output(0)->dev_tensor().as_megdnn(), output(1)->dev_tensor().as_megdnn(),
get_megdnn_workspace_from_var(output(2)));
}
cg::OperatorNodeBase::NodeProp* MultiHeadAttnForward::do_make_node_prop() const {
auto prop = Super::do_make_node_prop();
prop->add_flag(NodeProp::Flag::IMPURE_FUNC);
for (auto i : input()) {
prop->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY);
}
return prop;
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MultiHeadAttnForward) {
MGB_MARK_USED_VAR(opr);
MGB_MARK_USED_VAR(out_grad);
SymbolVarArray grad;
VarNodeArray ret;
mgb_assert(wrt_idx < 5, "wrt_idx %zu is out of range", wrt_idx);
grad = MultiHeadAttnBackward::make(
out_grad[0], opr.input(0), opr.input(1), opr.input(2), opr.input(3),
opr.output(1), opr.param());
uint32_t nr_ret = 4;
for (uint32_t i = 0; i < nr_ret; ++i) {
ret.push_back(grad[i].node());
}
return ret;
}
#endif
/* ==================== MultiHeadAttnBackwardData ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MultiHeadAttnBackward);
MultiHeadAttnBackward::MultiHeadAttnBackward(
VarNode* diff, VarNode* queries, VarNode* keys, VarNode* values, VarNode* wqkv,
VarNode* reserveSpace, const Param& param, const OperatorNodeConfig& config)
: Super({queries->owner_graph(),
config,
"multi_head_attn_backward",
{diff, queries, keys, values, wqkv, reserveSpace}},
0, true) {
init_megdnn_opr(*this, param);
add_input({diff, queries, keys, values, wqkv, reserveSpace});
}
SymbolVarArray MultiHeadAttnBackward::make(
SymbolVar diff, SymbolVar queries, SymbolVar keys, SymbolVar values,
SymbolVar wqkv, SymbolVar reserveSpace, const Param& param,
const OperatorNodeConfig& config) {
auto outs = queries.node()
->owner_graph()
->insert_opr(std::make_unique<MultiHeadAttnBackward>(
diff.node(), queries.node(), keys.node(), values.node(),
wqkv.node(), reserveSpace.node(), param, config))
->output();
mgb_assert(outs.size() == 5);
return {outs[0], outs[1], outs[2], outs[3]};
}
void MultiHeadAttnBackward::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(1)));
mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2)));
mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(3)));
mgr.register_shape_infer(output(3), ShapeInferDesc::make_identity(input(4)));
this->init_output_static_infer_desc_workspace(false);
}
void MultiHeadAttnBackward::init_output_dtype() {
output(0)->dtype(input(1)->dtype());
output(1)->dtype(input(2)->dtype());
output(2)->dtype(input(3)->dtype());
output(3)->dtype(input(4)->dtype());
}
size_t MultiHeadAttnBackward::get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const {
MGB_MARK_USED_VAR(input_shapes);
MGB_MARK_USED_VAR(output_shapes);
return 0;
}
void MultiHeadAttnBackward::scn_do_execute() {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(),
input(4)->dev_tensor().as_megdnn(), input(5)->dev_tensor().as_megdnn(),
output(0)->dev_tensor().as_megdnn(), output(1)->dev_tensor().as_megdnn(),
output(2)->dev_tensor().as_megdnn(), output(3)->dev_tensor().as_megdnn(),
{});
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -30,6 +30,35 @@ struct OprMaker<opr::DropoutForward, 1> {
}
};
template <>
struct OprMaker<opr::MultiHeadAttn, 0> {
using Param = opr::MultiHeadAttn::Param;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
return opr::MultiHeadAttn::make(i[0], i[1], i[2], i[3], param, config)[0]
.node()
->owner_opr();
}
};
// OprMaker in MGB_SEREG_OPR only support unique output opr
template <>
struct OprMaker<opr::MultiHeadAttnBackward, 0> {
using Param = opr::MultiHeadAttnBackward::Param;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
return opr::MultiHeadAttnBackward::make(
i[0], i[1], i[2], i[3], i[4], i[5], param, config)[0]
.node()
->owner_opr();
}
};
} // namespace serialization
namespace opr {
......@@ -46,6 +75,8 @@ MGB_SEREG_OPR(ShuffleRNG, 1);
MGB_SEREG_OPR(ShuffleRNGBackward, 3);
MGB_SEREG_OPR(Dropout, 1);
MGB_SEREG_OPR(DropoutBackward, 2);
MGB_SEREG_OPR(MultiHeadAttn, 0);
MGB_SEREG_OPR(MultiHeadAttnBackward, 0);
} // namespace opr
} // namespace mgb
......
......@@ -86,6 +86,13 @@ _DEFINE_RNG_OPR_WITH_INPUT_CLASS(BetaRNG)
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG)
#undef _OUTPUTS
#undef _INPUTS
/* ================= 4 input ================= */
#define _INPUTS(preifx) preifx i0, preifx i1, preifx i2, preifx i3
#define _OUTPUTS SymbolVarArray
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(MultiHeadAttnForward)
#undef _OUTPUTS
#undef _INPUTS
#undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS
} // namespace intl
......@@ -99,6 +106,7 @@ using BetaRNG = intl::BetaRNG;
using ShuffleRNG = intl::ShuffleRNGForward;
using Dropout = intl::DropoutForward;
using DropoutForward = intl::DropoutForward;
using MultiHeadAttn = intl::MultiHeadAttnForward;
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
ShuffleRNGBackward, intl::MegDNNOprWrapperBwd<megdnn::ShuffleRNGBackward>) // {
......@@ -132,6 +140,29 @@ private:
void scn_do_execute() override;
};
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
MultiHeadAttnBackward,
intl::MegDNNOprWrapperBwd<megdnn::MultiHeadAttnBackward>) // {
public:
MGE_WIN_DECLSPEC_FUC MultiHeadAttnBackward(
VarNode* diff, VarNode* queries, VarNode* keys, VarNode* values,
VarNode* wqkv, VarNode* reserveSpace, const Param& param,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar diff, SymbolVar queries, SymbolVar keys, SymbolVar values,
SymbolVar wqkv, SymbolVar reserveSpace, const Param& param = {},
const OperatorNodeConfig& config = {});
private:
void init_output_static_infer_desc() override;
void init_output_dtype() override;
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void scn_do_execute() override;
};
} // namespace opr
} // namespace mgb
......
......@@ -126,6 +126,7 @@ union OperatorParam {
param.GroupNorm = 92,
param.Fill = 93,
param.GeneralNorm=94,
param.MultiHeadAttn=95,
}
table Operator {
......
......@@ -143,6 +143,7 @@ union OperatorParam {
param.GroupNorm = 92,
param.Fill = 93,
param.GeneralNorm=94,
param.MultiHeadAttn=95,
}
table Operator {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册