// Host-side helpers for the cuDNN multi-head attention implementation.
#include "src/cuda/multi_head_attn/helper.h"
#if CUDNN_VERSION >= 8004

namespace megdnn {
namespace cuda {

AuxiliaryArray::~AuxiliaryArray() {
    // Host-side buffers were obtained with calloc(); release them with free().
    auto release_host = [](int* ptr) {
        if (ptr)
            free(ptr);
    };
    release_host(loWinIdx);
    release_host(hiWinIdx);
    release_host(seqQArray);
    release_host(seqKArray);
    // Device-side buffers were obtained with cudaMalloc(); release via cudaFree().
    if (devSeqQArray)
        cuda_check(cudaFree(devSeqQArray));
    if (devSeqKArray)
        cuda_check(cudaFree(devSeqKArray));
}

bool AuxiliaryArray::is_initialized(
        const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK,
        bool _attnMask) {
    // Initialized iff every cached parameter matches the request AND every
    // buffer (host and device) has actually been allocated.
    const bool params_match = _batchSize == batchSize && _seqLenQ == seqLenQ &&
                              _seqLenK == seqLenK && _attnMask == attnMask;
    const bool buffers_ready = seqQArray && seqKArray && devSeqQArray &&
                               devSeqKArray && loWinIdx && hiWinIdx;
    return params_match && buffers_ready;
}

void AuxiliaryArray::set(
        const size_t _batchSize, const size_t _seqLenQ, const size_t _seqLenK,
        bool _attnMask) {
    // Fast path: the cached configuration already matches; buffers are reused.
    if (_batchSize == batchSize && _seqLenQ == seqLenQ && _seqLenK == seqLenK &&
        _attnMask == attnMask)
        return;

    // Release any previous buffers. Null each pointer immediately after
    // freeing it so that, should a later allocation abort (cuda_check),
    // ~AuxiliaryArray cannot double-free a stale pointer.
    if (loWinIdx) {
        free(loWinIdx);
        loWinIdx = nullptr;
    }
    if (hiWinIdx) {
        free(hiWinIdx);
        hiWinIdx = nullptr;
    }
    if (seqQArray) {
        free(seqQArray);
        seqQArray = nullptr;
    }
    if (seqKArray) {
        free(seqKArray);
        seqKArray = nullptr;
    }
    if (devSeqQArray) {
        cuda_check(cudaFree(devSeqQArray));
        devSeqQArray = nullptr;
    }
    if (devSeqKArray) {
        cuda_check(cudaFree(devSeqKArray));
        devSeqKArray = nullptr;
    }

    seqLenQ = _seqLenQ;
    seqLenK = _seqLenK;
    batchSize = _batchSize;
    attnMask = _attnMask;
    // One sequence-length entry per batch element: every sample uses the full
    // (padded) sequence length.
    size_t seqQArraySize = batchSize;
    size_t seqKArraySize = batchSize;
    seqQArray = (int*)calloc(seqQArraySize, sizeof(int));
    seqKArray = (int*)calloc(seqKArraySize, sizeof(int));
    for (size_t i = 0; i < seqQArraySize; ++i)
        seqQArray[i] = static_cast<int>(seqLenQ);
    for (size_t i = 0; i < seqKArraySize; ++i)
        seqKArray[i] = static_cast<int>(seqLenK);

    // Mirror the per-batch sequence lengths on the device, as required by the
    // cuDNN multi-head attention API.
    cuda_check(cudaMalloc((void**)&devSeqQArray, seqQArraySize * sizeof(int)));
    cuda_check(cudaMalloc((void**)&devSeqKArray, seqKArraySize * sizeof(int)));

    cuda_check(cudaMemcpy(
            devSeqQArray, seqQArray, seqQArraySize * sizeof(int),
            cudaMemcpyHostToDevice));
    cuda_check(cudaMemcpy(
            devSeqKArray, seqKArray, seqKArraySize * sizeof(int),
            cudaMemcpyHostToDevice));

    // Per-query attention window [loWinIdx[i], hiWinIdx[i]):
    //   attnMask == true  -> causal-style window, query i attends keys [0, i].
    //   attnMask == false -> full window, every query attends all seqLenK keys.
    loWinIdx = (int*)calloc(seqLenQ, sizeof(int));
    hiWinIdx = (int*)calloc(seqLenQ, sizeof(int));
    for (size_t i = 0; i < seqLenQ; ++i) {
        loWinIdx[i] = 0;
        if (attnMask)
            hiWinIdx[i] = static_cast<int>(i + 1);
        else
            hiWinIdx[i] = static_cast<int>(seqLenK);
    }
}

void MultiHeadAttnStatus::set(
        cudnnHandle_t handle, const Param& p, const TensorLayout& q,
        const TensorLayout& k, const TensorLayout& v) {
    // Dropout is only active during training; inference runs with probability 0.
    float attn_prob = p.training ? p.attn_prob : 0.f;
    float out_prob = p.training ? p.out_prob : 0.f;
    if (!attn_dropout_status.initialized())
        attn_dropout_status.set(handle, p.seed, attn_prob);
    if (!out_dropout_status.initialized())
        out_dropout_status.set(handle, p.seed, out_prob);

    // Rebuild the dropout descriptors if the probability changed (e.g. when
    // switching between training and inference).
    if (attn_dropout_status.drop_prob != attn_prob) {
        attn_dropout_status.drop_prob = attn_prob;
        attn_dropout_status.restore_desc(handle);
    }
    if (out_dropout_status.drop_prob != out_prob) {
        out_dropout_status.drop_prob = out_prob;
        out_dropout_status.restore_desc(handle);
    }
    // Cache the problem geometry. Layouts are (batch, seq, embed).
    batchSize = q.shape[0];
    seqLenQ = q.shape[1];
    seqLenK = k.shape[1];
    qSize = q.shape[2];
    kSize = k.shape[2];
    vSize = v.shape[2];
    numHeads = p.num_heads;
    // A projection size of 0 tells cuDNN the corresponding projection is
    // disabled; otherwise each head projects to size/numHeads.
    qProjSize = p.qproj_size ? qSize / numHeads : 0;
    kProjSize = p.kproj_size ? kSize / numHeads : 0;
    vProjSize = p.vproj_size ? vSize / numHeads : 0;
    oProjSize = p.oproj_size ? qSize : 0;
    attnMask = p.attn_mask_type >= param::MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK;

    cudnnDataType_t cudnn_dtype = to_cudnn_dtype(q.dtype);
    auto flag = CUDNN_ATTN_QUERYMAP_ONE_TO_ONE;
    if (p.qbias or p.kbias or p.vbias or p.obias)
        flag = flag | CUDNN_ATTN_ENABLE_PROJ_BIASES;
#if CUDNN_VERSION < 8600
    // Output-projection dropout is only supported from cuDNN 8.6 on; pass NULL
    // for the post dropout descriptor here.
    // TODO: CUDNN_VERSION < 8600 and out dropout > 0.0, we need to go to the
    // proxy cuda implementation.
    cudnn_check(cudnnSetAttnDescriptor(
            attn_desc, flag, numHeads, p.sm_scaler, cudnn_dtype, cudnn_dtype,
            CUDNN_DEFAULT_MATH, attn_dropout_status.desc.desc, NULL, qSize, kSize,
            vSize, qProjSize, kProjSize, vProjSize, oProjSize, seqLenQ, seqLenK,
            batchSize, 1));
#else
    cudnn_check(cudnnSetAttnDescriptor(
            attn_desc, flag, numHeads, p.sm_scaler, cudnn_dtype, cudnn_dtype,
            CUDNN_DEFAULT_MATH, attn_dropout_status.desc.desc,
            out_dropout_status.desc.desc, qSize, kSize, vSize, qProjSize, kProjSize,
            vProjSize, oProjSize, seqLenQ, seqLenK, batchSize, 1));
#endif

    auxArray.set(
            batchSize, seqLenQ, seqLenK,
            p.attn_mask_type >= param::MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK);

    // The reserve space is only needed for the backward pass, so query it only
    // when training.
    if (p.training)
        cudnn_check(cudnnGetMultiHeadAttnBuffers(
                handle, attn_desc, &sizeWeights, &sizeWkspace, &sizeReserve));
    else {
        cudnn_check(cudnnGetMultiHeadAttnBuffers(
                handle, attn_desc, &sizeWeights, &sizeWkspace, NULL));
        sizeReserve = 0;
    }
}

bool MultiHeadAttnStatus::is_initialized(
        const Param& p, const TensorLayout& q, const TensorLayout& k,
        const TensorLayout& v) {
    float attn_prob = p.training ? p.attn_prob : 0.f;
    float out_prob = p.training ? p.out_prob : 0.f;
    if (!attn_dropout_status.initialized() or !out_dropout_status.initialized() or
        attn_dropout_status.drop_prob != attn_prob or
        out_dropout_status.drop_prob != out_prob)
        return false;
    if (q.shape[0] != batchSize or q.shape[1] != seqLenQ or k.shape[1] != seqLenK or
        q.shape[2] != qSize or k.shape[2] != kSize or v.shape[2] != vSize or
162 163 164
        attnMask != (p.attn_mask_type >=
                     param::MultiHeadAttn::ATTN_MASK_TYPE::DEFAULT_MASK) or
        numHeads != p.num_heads) {
165 166
        return false;
    }
167 168 169 170
    if ((p.qproj_size && (qProjSize == 0 or qProjSize != qSize / p.num_heads)) or
        (p.kproj_size && (kProjSize == 0 or kProjSize != kSize / p.num_heads)) or
        (p.vproj_size && (vProjSize == 0 or vProjSize != vSize / p.num_heads)) or
        (p.oproj_size && (oProjSize == 0 or oProjSize != q.shape[2])))
171
        return false;
172 173
    if ((!p.qproj_size && qProjSize != 0) or (!p.kproj_size && kProjSize != 0) or
        (!p.vproj_size && vProjSize != 0) or (!p.oproj_size && oProjSize != 0))
174 175 176 177 178 179 180 181 182 183 184 185
        return false;
    if (!auxArray.is_initialized(batchSize, seqLenQ, seqLenK, attnMask))
        return false;
    if (p.training and sizeReserve == 0)
        return false;
    return true;
}

}  // namespace cuda
}  // namespace megdnn
#endif
// vim: syntax=cpp.doxygen