#include "src/cuda/multi_head_attn/opr_impl.h"
#include "src/common/utils.cuh"
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.h"

namespace megdnn {
namespace cuda {

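// Decide whether the cuDNN multi-head attention implementation can handle the
// given param; callers fall back to the generic proxy implementation when it
// cannot.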
bool can_use_mha_cudnn(const Param& param) {
#if CUDNN_VERSION < 8004
    MEGDNN_MARK_USED_VAR(param);
    return false;
#else
    bool flag = true;
    size_t bias_num = 0;
    size_t weight_num = 0;
    bias_num += (param.qbias ? 1 : 0);
    bias_num += (param.kbias ? 1 : 0);
    bias_num += (param.vbias ? 1 : 0);
    bias_num += (param.obias ? 1 : 0);
    weight_num += (param.qproj_size > 0 ? 1 : 0);
    weight_num += (param.kproj_size > 0 ? 1 : 0);
    weight_num += (param.vproj_size > 0 ? 1 : 0);
    weight_num += (param.oproj_size > 0 ? 1 : 0);
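    // cuDNN accepts either no projection biases at all, or exactly one bias
    // per enabled projection; a partial set of biases is rejected.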
    if (bias_num != weight_num && bias_num != 0) {
        flag = false;
    }
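    // cuDNN builds older than 8.6 support neither biases in training mode nor
    // a nonzero output dropout probability.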
#if CUDNN_VERSION < 8600
    if (bias_num > 0 && param.training) {
        flag = false;
    }
    if (param.out_prob > 0) {
        flag = false;
    }
#endif
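    // The cuDNN path does not expose the attention-weight matrix, so requests
    // that need it must take the proxy path.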
    if (param.need_weights) {
        flag = false;
    }
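    // Arbitrary user-defined masks cannot be expressed through cuDNN.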
    if (param.attn_mask_type == MaskType::USER_DEFINED_MASK) {
        flag = false;
    }
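    // A cudnn-style mask only makes sense on the cuDNN path, so none of the
    // checks above may have disabled it.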
    if (param.attn_mask_type == MaskType::CUDNN_STYLE_MASK) {
        megdnn_assert(
                flag,
                "maybe_cudnn_style_mask=True, but the cuDNN implementation cannot be "
                "used. Please make sure that cuDNN is available, and check your "
                "parameters or do not use a cudnn-style mask.");
    }
    return flag;
#endif
}

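// Each entry point below follows the same pattern: cuDNN older than 8.0.4
// lacks the required MHA API entirely, so only the proxy implementation is
// compiled in; otherwise the backend is chosen at runtime via
// can_use_mha_cudnn.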
void MultiHeadAttnForwardImpl::deduce_layout(MHA_FORWARD_LAYOUT_PARAM) {
    Param p = param();
#if CUDNN_VERSION < 8004
    proxy_opr.deduce_layout(
            this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask,
            bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace);
#else
    if (can_use_mha_cudnn(p)) {
        cudnn_opr.deduce_layout(
                this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask,
                bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace);
    } else {
        proxy_opr.deduce_layout(
                this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask,
                bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace);
    }
#endif
}

size_t MultiHeadAttnForwardImpl::get_workspace_in_bytes(
        MHA_FORWARD_LAYOUT_CONST_PARAM) {
    Param p = param();
#if CUDNN_VERSION < 8004
    return proxy_opr.get_workspace_in_bytes(
            this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask,
            bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace);
#else
    if (can_use_mha_cudnn(p)) {
        return cudnn_opr.get_workspace_in_bytes(
                this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask,
                bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace);
    } else {
        return proxy_opr.get_workspace_in_bytes(
                this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask,
                bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace);
    }
#endif
}

size_t MultiHeadAttnForwardImpl::get_mask_reservespace_in_bytes(
        MHA_FORWARD_LAYOUT_CONST_PARAM) {
    Param p = param();
#if CUDNN_VERSION < 8004
    return proxy_opr.get_mask_reservespace_in_bytes(
            this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask,
            bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace);
#else
    if (can_use_mha_cudnn(p)) {
        return cudnn_opr.get_mask_reservespace_in_bytes(
                this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask,
                bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace);
    } else {
        return proxy_opr.get_mask_reservespace_in_bytes(
                this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask,
                bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace);
    }
#endif
}

size_t MultiHeadAttnForwardImpl::get_othr_reservespace_in_bytes(
        MHA_FORWARD_LAYOUT_CONST_PARAM) {
    Param p = param();
#if CUDNN_VERSION < 8004
    return proxy_opr.get_othr_reservespace_in_bytes(
            this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask,
            bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace);
#else
    if (can_use_mha_cudnn(p)) {
        return cudnn_opr.get_othr_reservespace_in_bytes(
                this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask,
                bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace);
    } else {
        return proxy_opr.get_othr_reservespace_in_bytes(
                this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask,
                bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace);
    }
#endif
}

void MultiHeadAttnForwardImpl::exec(MHA_FORWARD_EXEC_PARAM) {
    check_exec(
            queries.layout, keys.layout, values.layout, qkvo_weight_bias.layout,
            attn_mask.layout, bias_k.layout, bias_v.layout, out.layout,
            attn_weight.layout, mask_reservespace.layout, othr_reservespace.layout,
            workspace.size);
    Param p = param();
#if CUDNN_VERSION < 8004
    proxy_opr.exec(
            this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask,
            bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace,
            workspace);
#else
    if (can_use_mha_cudnn(p)) {
        cudnn_opr.exec(
                this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask,
                bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace,
                workspace);
    } else {
        proxy_opr.exec(
                this->handle(), p, queries, keys, values, qkvo_weight_bias, attn_mask,
                bias_k, bias_v, out, attn_weight, mask_reservespace, othr_reservespace,
                workspace);
    }
#endif
}

void MultiHeadAttnBackwardImpl::exec(MHA_BACKWARD_EXEC_PARAM) {
    check_exec(
            diff.layout, queries.layout, keys.layout, values.layout,
            qkvo_weight_bias.layout, attn_mask.layout, attn_weight.layout,
            mask_reservespace.layout, othr_reservespace.layout, dqueries.layout,
            dkeys.layout, dvalues.layout, dqkvo_weight_bias.layout, dbias_k.layout,
            dbias_v.layout, workspace.size);
    Param p = param();
#if CUDNN_VERSION < 8004
    proxy_opr.exec(
            this->handle(), p, diff, queries, keys, values, qkvo_weight_bias, attn_mask,
            attn_weight, mask_reservespace, othr_reservespace, dqueries, dkeys, dvalues,
            dqkvo_weight_bias, dbias_k, dbias_v, workspace);
#else
    if (can_use_mha_cudnn(p)) {
        cudnn_opr.exec(
                this->handle(), p, diff, queries, keys, values, qkvo_weight_bias,
                attn_mask, attn_weight, mask_reservespace, othr_reservespace, dqueries,
                dkeys, dvalues, dqkvo_weight_bias, dbias_k, dbias_v, workspace);
    } else {
        proxy_opr.exec(
                this->handle(), p, diff, queries, keys, values, qkvo_weight_bias,
                attn_mask, attn_weight, mask_reservespace, othr_reservespace, dqueries,
                dkeys, dvalues, dqkvo_weight_bias, dbias_k, dbias_v, workspace);
    }
#endif
}

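// No version guard is needed here: below cuDNN 8.0.4, can_use_mha_cudnn
// always returns false, so the proxy branch is taken unconditionally.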
size_t MultiHeadAttnBackwardImpl::get_workspace_in_bytes(
        MHA_BACKWARD_LAYOUT_CONST_PARAM) {
    Param p = param();
    if (can_use_mha_cudnn(p)) {
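        // The cuDNN backward path needs no additional workspace.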
        return 0;
    } else {
        return proxy_opr.get_workspace_in_bytes(
                this->handle(), p, diff, queries, keys, values, qkvo_weight_bias,
                attn_mask, attn_weight, mask_reservespace, othr_reservespace, dqueries,
                dkeys, dvalues, dqkvo_weight_bias, dbias_k, dbias_v);
    }
}
}  // namespace cuda
}  // namespace megdnn
   // vim: syntax=cpp.doxygen