#include "src/cuda/matrix_mul/cublasLt_wrapper.h"
#include "src/common/utils.h"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 10010

namespace megdnn {
namespace cuda {
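// Map a MegDNN DType to the corresponding cudaDataType_t.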
static cudaDataType_t to_cuda_dtype(DType tp) {
    switch (tp.enumv()) {
        case DTypeEnum::Float16:
            return CUDA_R_16F;
        case DTypeEnum::Float32:
            return CUDA_R_32F;
        case DTypeEnum::Int8:
        case DTypeEnum::QuantizedS8:
            return CUDA_R_8I;
        case DTypeEnum::Int32:
        case DTypeEnum::QuantizedS32:
            return CUDA_R_32I;
        default:
            megdnn_throw("dtype must be float16/float32/int8/qs8/int32");
    }
}

#if CUDA_VERSION >= 11000
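// Map a MegDNN DType to the cublasComputeType_t used for the matmul
// (CUDA >= 11.0 only; f32 uses the TF32 fast path).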
static cublasComputeType_t to_cublas_compute_type(DType tp) {
    switch (tp.enumv()) {
        case DTypeEnum::Float16:
            return CUBLAS_COMPUTE_16F;
        case DTypeEnum::Float32:
            return CUBLAS_COMPUTE_32F_FAST_TF32;
        case DTypeEnum::Int32:
        case DTypeEnum::QuantizedS32:
            return CUBLAS_COMPUTE_32I;
        default:
            megdnn_throw("dtype must be float16/float32/int32/Qs32");
    }
}
#endif

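// Human-readable name of a cudaDataType_t.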
static const char* cuda_type_to_str(cudaDataType_t tp) {
    switch (tp) {
        case CUDA_R_16F:
            return "CUDA_R_16F";
        case CUDA_R_32F:
            return "CUDA_R_32F";
        case CUDA_R_8I:
            return "CUDA_R_8I";
        case CUDA_R_32I:
            return "CUDA_R_32I";
        default:
            megdnn_throw("dtype must be float16/float32/int8/int32");
    }
}

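// Size in bytes of a single element of the given cudaDataType_t.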
static size_t cuda_dtype_size(cudaDataType_t dt) {
    switch (dt) {
        case CUDA_R_8I:
            return 1_z;
        case CUDA_R_16F:
            return 2_z;
        case CUDA_R_32F:
        case CUDA_R_32I:
            return 4_z;
        default:
            megdnn_throw("dtype must be float16/float32/int8/int32");
    }
}

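// Destroy every cuBLASLt descriptor/layout that was created in set().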
CUBLASLTMatmulDesc::~CUBLASLTMatmulDesc() {
    if (matmul_desc)
        cublas_check(cublasLtMatmulDescDestroy(matmul_desc));
    if (layout_a)
        cublas_check(cublasLtMatrixLayoutDestroy(layout_a));
    if (layout_b)
        cublas_check(cublasLtMatrixLayoutDestroy(layout_b));
    if (layout_c)
        cublas_check(cublasLtMatrixLayoutDestroy(layout_c));
    if (layout_trans_a)
        cublas_check(cublasLtMatrixLayoutDestroy(layout_trans_a));
    if (layout_trans_b)
        cublas_check(cublasLtMatrixLayoutDestroy(layout_trans_b));
    if (layout_trans_c)
        cublas_check(cublasLtMatrixLayoutDestroy(layout_trans_c));
}
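// Build the matmul descriptor and matrix layouts for the problem described
// by args; `batched` selects strided-batched matrix multiplication.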
void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) {
    cublasOperation_t trans_a, trans_b;
    auto m = args.layout_c.shape[batched ? 1 : 0],
         n = args.layout_c.shape[batched ? 2 : 1];
    auto k = batched ? args.layout_a.shape[args.transposeA ? 1 : 2]
                     : args.layout_a.shape[args.transposeA ? 0 : 1];
    int batch = (batched ? args.layout_a.shape[0] : 1);
    uint32_t pm = CUBLAS_POINTER_MODE_DEVICE;
    dt_b = to_cuda_dtype(args.layout_b.dtype);
    dt_a = to_cuda_dtype(args.layout_a.dtype);
    dt_c = to_cuda_dtype(args.layout_c.dtype);

    megdnn_assert(dt_a == dt_b, "matrix A and B should have the same precision");
#if CUDA_VERSION >= 11000
    dt_compute = to_cublas_compute_type(args.layout_c.dtype);
    cublas_check(cublasLtMatmulDescCreate(&matmul_desc, dt_compute, dt_c));
#else
    dt_compute = dt_c;
    cublas_check(cublasLtMatmulDescCreate(&matmul_desc, dt_compute));
#endif
    cublas_check(cublasLtMatmulDescSetAttribute(
            matmul_desc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pm, sizeof(pm)));

    cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
    cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C;
    /**
     * \note cuBLAS takes column-major matrices as inputs, while MegDNN uses
     * row-major ones, so we compute C^t = B^t * A^t with cuBLAS. Here the
     * transpose symbol denotes the implicit row-major to column-major
     * conversion.
     */
    if (dt_c == CUDA_R_32I) {
        /**
         * \note To use IMMA kernels, set computeType = CUDA_R_32I, use
         * CUBLASLT_ORDER_COL32 for matrices A, C and D, and
         * CUBLASLT_ORDER_COL4_4R2_8C for matrix B.
         */
        int ldbtransform, ldatransform, ldctransform;
        size_t stride_b_trans, stride_a_trans, stride_c_trans;
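        // leading dimensions and per-batch strides of the transformed
        // matrices, following the COL32 / COL4_4R2_8C packing rules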
        ldbtransform = 32 * n;
        ldatransform = 32 * round_up<int32_t>(m, 8);
        ldctransform = 32 * n;
        stride_b_trans = round_up<int32_t>(k, 32) / 32 * ldbtransform;
        stride_a_trans = round_up<int32_t>(k, 32) / 32 * ldatransform;
        stride_c_trans = round_up<int32_t>(m, 32) / 32 * ldctransform;
        trans_b = CUBLAS_OP_T;
        cublas_check(cublasLtMatmulDescSetAttribute(
                matmul_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b)));
        // origin layout
        cublas_check(cublasLtMatrixLayoutCreate(
                &layout_b, dt_b, n, k, args.layout_b.stride[batched ? 1 : 0]));
        cublas_check(cublasLtMatrixLayoutCreate(
                &layout_a, dt_a, k, m, args.layout_a.stride[batched ? 1 : 0]));
        cublas_check(cublasLtMatrixLayoutCreate(
                &layout_c, dt_c, n, m, args.layout_c.stride[batched ? 1 : 0]));
        // transformed layout
        cublas_check(
                cublasLtMatrixLayoutCreate(&layout_trans_b, dt_b, n, k, ldbtransform));
        cublas_check(
                cublasLtMatrixLayoutCreate(&layout_trans_a, dt_a, m, k, ldatransform));
        cublas_check(
                cublasLtMatrixLayoutCreate(&layout_trans_c, dt_c, n, m, ldctransform));
        cublas_check(cublasLtMatrixLayoutSetAttribute(
                layout_trans_b, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32,
                sizeof(order_COL32)));
        cublas_check(cublasLtMatrixLayoutSetAttribute(
                layout_trans_a, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL4_4R2_8C,
                sizeof(order_COL4_4R2_8C)));
        cublas_check(cublasLtMatrixLayoutSetAttribute(
                layout_trans_c, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32,
                sizeof(order_COL32)));
        if (batched) {
            cublas_check(cublasLtMatrixLayoutSetAttribute(
                    layout_trans_b, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch,
                    sizeof(batch)));
            cublas_check(cublasLtMatrixLayoutSetAttribute(
                    layout_trans_a, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch,
                    sizeof(batch)));
            cublas_check(cublasLtMatrixLayoutSetAttribute(
                    layout_trans_c, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch,
                    sizeof(batch)));
            cublas_check(cublasLtMatrixLayoutSetAttribute(
                    layout_trans_b, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
                    &stride_b_trans, sizeof(stride_b_trans)));
            cublas_check(cublasLtMatrixLayoutSetAttribute(
                    layout_trans_a, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
                    &stride_a_trans, sizeof(stride_a_trans)));
            cublas_check(cublasLtMatrixLayoutSetAttribute(
                    layout_trans_c, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
                    &stride_c_trans, sizeof(stride_c_trans)));
        }
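        // bytes needed to hold the COL32 / COL4_4R2_8C transformed copies of
        // B, A and C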
        workspace_b = batch * cuda_dtype_size(dt_b) * stride_b_trans;
        workspace_a = batch * cuda_dtype_size(dt_a) * stride_a_trans;
        workspace_c = batch * cuda_dtype_size(dt_c) * stride_c_trans;
    } else {
        trans_b = args.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N;
        trans_a = args.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N;
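        // TRANSA is set from trans_b and TRANSB from trans_a on purpose:
        // since we compute C^t = B^t * A^t (see the note above), B is the
        // first operand passed to cuBLASLt and A the second.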
        cublas_check(cublasLtMatmulDescSetAttribute(
                matmul_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_b, sizeof(trans_b)));
        cublas_check(cublasLtMatmulDescSetAttribute(
                matmul_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_a, sizeof(trans_a)));
        cublas_check(cublasLtMatrixLayoutCreate(
                &layout_b, dt_b, trans_b == CUBLAS_OP_N ? n : k,
                trans_b == CUBLAS_OP_N ? k : n, args.layout_b.stride[batched ? 1 : 0]));
        cublas_check(cublasLtMatrixLayoutCreate(
                &layout_a, dt_a, trans_a == CUBLAS_OP_N ? k : m,
                trans_a == CUBLAS_OP_N ? m : k, args.layout_a.stride[batched ? 1 : 0]));
        cublas_check(cublasLtMatrixLayoutCreate(
                &layout_c, dt_c, n, m, args.layout_c.stride[batched ? 1 : 0]));
    }
    size_t stride_b = args.layout_b.stride[0];
    size_t stride_a = args.layout_a.stride[0];
    size_t stride_c = args.layout_c.stride[0];
    cublas_check(cublasLtMatrixLayoutSetAttribute(
            layout_b, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
    cublas_check(cublasLtMatrixLayoutSetAttribute(
            layout_a, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
    cublas_check(cublasLtMatrixLayoutSetAttribute(
            layout_c, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
    cublas_check(cublasLtMatrixLayoutSetAttribute(
            layout_b, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b,
            sizeof(stride_b)));
    cublas_check(cublasLtMatrixLayoutSetAttribute(
            layout_a, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a,
            sizeof(stride_a)));
    cublas_check(cublasLtMatrixLayoutSetAttribute(
            layout_c, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c,
            sizeof(stride_c)));
}
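// A problem is available if the dtype combination is supported and a
// heuristic algorithm fitting into ws_limit can be found.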
bool CUBLASLTMatmulDesc::is_available(const SizeArgs& args, size_t ws_limit) {
    bool support;
    cublasLtMatmulAlgo_t algo;
    switch (dt_c) {
        case CUDA_R_16F:
            support = (dt_a == CUDA_R_16F);
            break;
        case CUDA_R_32I: {
            support = (dt_a == CUDA_R_8I) && (!args.transposeA && !args.transposeB);
            break;
        }
        case CUDA_R_32F:
            support = (dt_a == CUDA_R_16F || dt_a == CUDA_R_32F);
            break;
        case CUDA_R_64F: /* not supported */
        default:
            support = false;
            break;
    }
    support = support && dt_a == dt_b;
    support = support && get_algorithm_heuristic(args, ws_limit, algo);
    return support;
}
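// Workspace needed to run `algo`: the algorithm's own workspace plus, for
// int32 output (IMMA), the buffers holding the transformed B, A and C.
// Returns an empty bundle if cublasLtMatmulAlgoCheck() rejects the algo.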
WorkspaceBundle CUBLASLTMatmulDesc::get_workspace_bundle(
        const SizeArgs& args, const cublasLtMatmulAlgo_t& algo) {
    size_t algo_workspace_size;
    auto&& handle = args.handle;
    auto&& cublasLt_handle = handle->cublasLt_handle();
    cublasStatus_t status;
    cublasLtMatmulHeuristicResult_t result{};
    status = cublasLtMatmulAlgoCheck(
            cublasLt_handle, matmul_desc,
            dt_c == CUDA_R_32I ? layout_trans_b : layout_b,
            dt_c == CUDA_R_32I ? layout_trans_a : layout_a,
            dt_c == CUDA_R_32I ? layout_trans_c : layout_c,
            dt_c == CUDA_R_32I ? layout_trans_c : layout_c, &algo, &result);
    // return empty WorkspaceBundle if cublasLtMatmulAlgoCheck() failed
    if (status != CUBLAS_STATUS_SUCCESS)
        return {nullptr, {}};
    algo_workspace_size = result.workspaceSize;
    return {nullptr,
            (dt_c == CUDA_R_32I)
                    ? SmallVector<size_t>{
                              algo_workspace_size, workspace_b, workspace_a,
                              workspace_c}
                    : SmallVector<size_t>{algo_workspace_size}};
}
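// Pick an algorithm via cublasLtMatmulAlgoGetHeuristic; the workspace limit
// passed to cuBLASLt excludes the transform buffers needed by the int32 path.
// Returns false if no suitable algorithm is found.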
bool CUBLASLTMatmulDesc::get_algorithm_heuristic(
        const SizeArgs& args, size_t ws_limit, cublasLtMatmulAlgo_t& algo) {
    bool result;
    int return_algo_count;
    size_t algo_ws_limit;
    cublasStatus_t status;
    cublasLtMatmulPreference_t algo_pref;
    cublasLtMatmulHeuristicResult_t algo_result{};
    auto&& handle = concrete_handle(args.handle);
    auto&& cublasLt_handle = handle->cublasLt_handle();

    size_t temp = workspace_b + workspace_a + workspace_c;
    algo_ws_limit = (ws_limit > temp) ? (ws_limit - temp) : 0;

    /**
     * \note algo_ws_limit must be zero if cublasLtGetVersion() <= 10100.
     */
    // algo_ws_limit = 0;
    if (dt_c == CUDA_R_32I) {
        // [FIXME]: cublasLt (version 10020) produces wrong results when k is
        // in [64*n+1, 64*n+32] for small matrices.

        // [TODO]: check whether this bug is fixed in later cublasLt versions.
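        // reject the affected sizes: flt is false exactly when k >= 65 and
        // k falls in [64*n+1, 64*n+32]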
        size_t k_pos = (is_batched ? 1 : 0) + (args.transposeA ? 0 : 1);
        size_t k = args.layout_a.shape[k_pos];
        bool flt = (k < 65 || ((k - 1) / 32) % 2 == 1);
        if (!flt)
            return false;
    }
    result = false;
    cublas_check(cublasLtMatmulPreferenceCreate(&algo_pref));
    cublas_check(cublasLtMatmulPreferenceSetAttribute(
            algo_pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &algo_ws_limit,
            sizeof(algo_ws_limit)));
#if CUDA_VERSION < 11000
    bool is_f32_config = args.layout_a.dtype == dtype::Float32() &&
                         args.layout_b.dtype == dtype::Float32() &&
                         args.layout_c.dtype == dtype::Float32();
    if (is_f32_config) {
        // disable HMMA tensor op matmul when inputs and output are all f32
        // tensors, to avoid the potential accuracy loss
        uint32_t math_mode = CUBLAS_DEFAULT_MATH;
        cublas_check(cublasLtMatmulPreferenceSetAttribute(
                algo_pref, CUBLASLT_MATMUL_PREF_MATH_MODE_MASK, &math_mode,
                sizeof(math_mode)));
    }
#endif
    status = cublasLtMatmulAlgoGetHeuristic(
            cublasLt_handle, matmul_desc,
            dt_c == CUDA_R_32I ? layout_trans_b : layout_b,
            dt_c == CUDA_R_32I ? layout_trans_a : layout_a,
            dt_c == CUDA_R_32I ? layout_trans_c : layout_c,
            dt_c == CUDA_R_32I ? layout_trans_c : layout_c, algo_pref, 1, &algo_result,
            &return_algo_count);
    if (status == CUBLAS_STATUS_SUCCESS && return_algo_count > 0 &&
        // perform cublasLtMatmulAlgoCheck() to make sure the algo is correct
        get_workspace_bundle(args, algo_result.algo).nr_workspace() > 0) {
        result = true;
        algo = algo_result.algo;
    }
    cublas_check(cublasLtMatmulPreferenceDestroy(algo_pref));
    return result;
}
}  // namespace cuda
}  // namespace megdnn
#endif
// vim: syntax=cpp.doxygen