#pragma once
#include <vector>
#include "megdnn/handle.h"
#include "megdnn/thin/small_vector.h"
#include "src/cuda/cudnn_wrapper.h"
#if CUDNN_VERSION >= 8004
#include "megdnn/basic_types.h"
#include "megdnn/oprs/nn.h"
#include "src/common/algo_chooser.h"
#include "src/common/multi_head_attn/helper.h"
#include "src/common/utils.h"
#include "src/cuda/dropout/opr_impl.h"
#include "src/cuda/handle.h"

using Param = megdnn::MultiHeadAttn::Param;
using MaskType = Param::AttnMaskType;
using InputType = Param::TensorCombinationType;

namespace megdnn {
namespace cuda {

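// Host- and device-side metadata fed to the cuDNN multi-head attention API:
// per-batch sequence lengths for Q and K (seqQArray/seqKArray, with device
// copies devSeqQArray/devSeqKArray) and the per-query attention window bounds
// (loWinIdx/hiWinIdx) that express the attention mask. set() fills these for a
// given batch/sequence-length/mask configuration, set_cudnn_style_mask()
// converts an explicit mask tensor into that form, and is_initialized() lets
// callers skip the refill when the configuration is unchanged.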
struct AuxiliaryArray {
public:
    SmallVector<int> seqQArray;
    SmallVector<int> seqKArray;
    int* devSeqQArray = nullptr;
    int* devSeqKArray = nullptr;
    SmallVector<int> loWinIdx;
    SmallVector<int> hiWinIdx;
    size_t seqLenQ = 0;
    size_t seqLenK = 0;
    size_t batchSize = 0;
    MaskType attnMaskType = MaskType::NO_MASK;
    ~AuxiliaryArray();
    void set(
            Handle* handle, const size_t _batchSize, const size_t _seqLenQ,
            const size_t _seqLenK, MaskType _attnMaskType);
    void set_cudnn_style_mask(Handle* handle, const TensorND& attn_mask);
    bool is_initialized(
            const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK,
            MaskType _attnMaskType);
};

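// Cached cuDNN state for one attention configuration: the attention descriptor,
// dropout state for the attention and output dropout layers, the auxiliary
// sequence/window arrays, the shape and projection sizes the descriptor was
// built with, and the byte counts for the packed weights, workspace and reserve
// space. set() (re)configures the descriptor from a Param and the Q/K/V
// layouts; is_initialized() reports whether the cached configuration still
// matches, so the befriended impl/proxy classes can reuse it across calls.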
class MultiHeadAttnStatus {
    DropoutStatus attn_dropout_status;
    DropoutStatus out_dropout_status;

    cudnnAttnDescriptor_t attn_desc;

    AuxiliaryArray auxArray;

    size_t numHeads = 0;
    size_t batchSize = 0;
    size_t seqLenQ = 0;
    size_t seqLenK = 0;
    size_t qSize = 0;
    size_t kSize = 0;
    size_t vSize = 0;
    size_t qProjSize = 0;
    size_t kProjSize = 0;
    size_t vProjSize = 0;
    size_t oProjSize = 0;
    MaskType attnMaskType = MaskType::NO_MASK;
    bool bias = false;

    size_t sizeWeights = 0;
    size_t sizeWkspace = 0;
    size_t sizeReserve = 0;

public:
    MultiHeadAttnStatus() { cudnn_check(cudnnCreateAttnDescriptor(&attn_desc)); }
    ~MultiHeadAttnStatus() { cudnn_check(cudnnDestroyAttnDescriptor(attn_desc)); }

private:
    void set(
            Handle* handle, const Param& p, const TensorLayout& q,
            const TensorLayout& k, const TensorLayout& v);
    void set_cudnn_style_mask(Handle* handle, const TensorND& attn_mask);
    bool is_initialized(
            const Param& p, const TensorLayout& q, const TensorLayout& k,
            const TensorLayout& v);
    friend class MultiHeadAttnBase;
    friend class MultiHeadAttnForwardImpl;
    friend class MultiHeadAttnBackwardImpl;
    friend class MHAForwardCudnnOpr;
    friend class MHABackwardCudnnOpr;
};

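// Forward proxy around the cuDNN multi-head attention forward pass, built on
// the cached descriptor state above. The MHA_PROXY_* parameter-list macros are
// not defined here; they presumably come from the multi_head_attn helper
// headers included above. Typical use mirrors the usual MegDNN operator flow:
// deduce_layout(), then the get_*_in_bytes() queries to size the workspace and
// reserve space, then exec().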
class MHAForwardCudnnOpr {
public:
    MHAForwardCudnnOpr(){};

    void exec(MHA_PROXY_FORWARD_EXEC_PARAM);
    void deduce_layout(MHA_PROXY_FORWARD_LAYOUT_PARAM);
    size_t get_workspace_in_bytes(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM);
    size_t get_mask_reservespace_in_bytes(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM);
    size_t get_othr_reservespace_in_bytes(MHA_PROXY_FORWARD_LAYOUT_CONST_PARAM);

private:
    MultiHeadAttnStatus desc_status;
};

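// Backward counterpart: exec() runs the cuDNN multi-head attention backward
// pass using its own cached descriptor state; no layout deduction or size
// queries are exposed here.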
class MHABackwardCudnnOpr {
public:
    MHABackwardCudnnOpr(){};

    void exec(MHA_PROXY_BACKWARD_EXEC_PARAM);

private:
    MultiHeadAttnStatus desc_status;
};

}  // namespace cuda
}  // namespace megdnn
#endif
// vim: syntax=cpp.doxygen