// cublas.cpp -- cuBLAS-backed matrix multiplication algorithm for megdnn/cuda.
#include "./algos.h"

#include "src/cuda/handle.h"
#include "src/cuda/utils.h"

#include <cuda.h>

using namespace megdnn;
using namespace cuda;

#if CUDA_VERSION >= 8000
#define SE_CUDA_DATA_HALF CUDA_R_16F
#else
#define SE_CUDA_DATA_HALF CUBLAS_DATA_HALF
#endif

#if CUDA_VERSION < 11000
#define CUBLAS_COMPUTE_32I CUDA_R_32I
#endif

bool MatrixMulForwardImpl::AlgoCuBlas::is_available(const SizeArgs& args) const {
22 23 24 25 26
    if (args.opr->param().format != param::MatrixMul::Format::DEFAULT)
        return false;
    if (args.layout_a.dtype == dtype::Float32() ||
        args.layout_a.dtype == dtype::Float16()) {
        return true;
M
Megvii Engine Team 已提交
27 28 29
    } else if (
            args.layout_a.dtype.enumv() == DTypeEnum::Int8 ||
            args.layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) {
30 31 32 33 34 35
        /**
         * \note When passing in the strides which can not be divided by 4, the
         * cublas rontine cublasGemmEx will raise a Error
         * CUBLAS_STATUS_INVALID_VALUE. The error occured because the leading
         * dimension of matrix A or B is illegal.
         */
M
Megvii Engine Team 已提交
36
        return args.layout_a.stride[0] % 4 == 0 && args.layout_b.stride[0] % 4 == 0 &&
37
               is_compute_capability_required(6, 1);
38 39 40 41 42 43 44 45 46 47 48 49 50 51
    }
    return false;
}

void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const {
    auto&& handle = concrete_handle(args.opr->handle());
    auto&& cublas_handle = handle->cublas_handle();
    auto&& param = args.opr->param();
    size_t m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
           k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];

    auto sgemm = [&]() {
        auto zero = handle->zero_device();
        auto one = handle->one_device();
M
Megvii Engine Team 已提交
52 53 54
#if CUDART_VERSION >= 11000
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
#endif
55 56 57 58
        cublas_check(cublasSgemm(
                cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
                args.tensor_b.ptr<dt_float32>(), args.tensor_b.layout.stride[0],
M
Megvii Engine Team 已提交
59 60
                args.tensor_a.ptr<dt_float32>(), args.tensor_a.layout.stride[0], zero,
                args.tensor_c.ptr<dt_float32>(), args.tensor_c.layout.stride[0]));
M
Megvii Engine Team 已提交
61 62 63
#if CUDART_VERSION >= 11000
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH));
#endif
64 65 66 67 68
    };

    auto sgemm_ex = [&]() {
        auto zero = handle->zero_device();
        auto one = handle->one_device();
M
Megvii Engine Team 已提交
69 70 71
#if CUDART_VERSION >= 11000
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
#elif CUDART_VERSION >= 9000
72 73 74 75 76
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
#endif
        auto sgemm_ex_err = cublasSgemmEx(
                cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
77 78
                args.tensor_b.raw_ptr(), SE_CUDA_DATA_HALF,
                args.tensor_b.layout.stride[0], args.tensor_a.raw_ptr(),
79
                SE_CUDA_DATA_HALF, args.tensor_a.layout.stride[0], zero,
80
                args.tensor_c.raw_ptr(), SE_CUDA_DATA_HALF,
81 82 83 84 85 86 87 88
                args.tensor_c.layout.stride[0]);
        cublas_check(sgemm_ex_err);
#if CUDART_VERSION >= 9000
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH));
#endif
    };

    auto hgemm = [&]() {
M
Megvii Engine Team 已提交
89 90 91
#if CUDART_VERSION >= 11000
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
#elif CUDART_VERSION >= 9000
92 93 94 95 96 97 98
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
#endif
        auto one_half = handle->one_device_h();
        auto zero_half = handle->zero_device_h();
        auto hgemm_ex_err = cublasHgemm(
                cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one_half,
99
                static_cast<const __half*>(args.tensor_b.raw_ptr()),
100
                args.tensor_b.layout.stride[0],
101
                static_cast<const __half*>(args.tensor_a.raw_ptr()),
102
                args.tensor_a.layout.stride[0], zero_half,
103
                static_cast<__half*>(args.tensor_c.raw_ptr()),
104 105 106 107 108 109 110 111 112 113 114 115 116
                args.tensor_c.layout.stride[0]);
        cublas_check(hgemm_ex_err);
#if CUDART_VERSION >= 9000
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH));
#endif
    };

    auto igemm = [&]() {
        auto zero = handle->zero_device_i32();
        auto one = handle->one_device_i32();
        cublas_check(cublasGemmEx(
                cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
117 118 119 120
                args.tensor_b.raw_ptr(), CUDA_R_8I, args.tensor_b.layout.stride[0],
                args.tensor_a.raw_ptr(), CUDA_R_8I, args.tensor_a.layout.stride[0],
                zero, args.tensor_c.raw_ptr(), CUDA_R_32I,
                args.tensor_c.layout.stride[0], CUBLAS_COMPUTE_32I, CUBLAS_GEMM_DFALT));
121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
    };

    // Note that cublas takes column-major matrices as inputs,
    // but megdnn takes row-major ones.
    // So we calculate C^t = B^t * A^t by cublas. Here the transpose symbol
    // implies row-major to column-major conversion.
    if (args.tensor_a.layout.dtype == dtype::Float32()) {
        sgemm();
    } else if (args.tensor_a.layout.dtype == dtype::Float16()) {
        // use tensor core; note that CUBLAS_TENSOR_OP_MATH also causes
        // cublasSgemm to round to fp16, so we can not always enable it
        if (handle->device_prop().major >= 6 &&
            param.compute_mode == Param::ComputeMode::DEFAULT)
            hgemm();
        else
            sgemm_ex();
    } else if (args.can_be_treated_as_int8x8x32()) {
        igemm();
    } else {
        megdnn_throw("Unsupported data_type of matrix mul on cuda.");
    }
}

// vim: syntax=cpp.doxygen