#include "./algos.h"
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/cublasLt_wrapper.h"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 10010
using namespace megdnn;
using namespace cuda;

bool MatrixMulForwardImpl::AlgoCuBlasLt::is_available(const SizeArgs& args) const {
    if (args.opr->param().format != param::MatrixMul::Format::DEFAULT)
        return false;
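    // input dtypes this cuBLASLt-based algo does not handle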
    if (args.layout_a.dtype.enumv() == DTypeEnum::Quantized4Asymm ||
        args.layout_a.dtype.enumv() == DTypeEnum::BFloat16)
        return false;
    CUBLASLTMatmulDesc::SizeArgs ltArgs(args);
    return CUBLASLTMatmulDesc(ltArgs).is_available(ltArgs, INT_MAX);
}
size_t MatrixMulForwardImpl::AlgoCuBlasLt::get_workspace_in_bytes(
        const SizeArgs& args) const {
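    // the workspace requirement depends on which algorithm the heuristic picks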
    CUBLASLTMatmulDesc::SizeArgs ltArgs(args);
    cublasLtMatmulAlgo_t algo;
    CUBLASLTMatmulDesc desc(ltArgs);
    desc.get_algorithm_heuristic(ltArgs, INT_MAX, algo);
    return desc.get_workspace_bundle(ltArgs, algo).total_size_in_bytes();
}
void MatrixMulForwardImpl::AlgoCuBlasLt::exec(const ExecArgs& args) const {
    CUBLASLTMatmulDesc::SizeArgs ltArgs(args);
    cublasLtMatmulAlgo_t algo;
    CUBLASLTMatmulDesc desc(ltArgs);
    auto&& handle = ltArgs.handle;
    auto&& stream = handle->stream();
    auto&& cublasLt_handle = handle->cublasLt_handle();
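    // pick an algorithm via the cuBLASLt heuristic (INT_MAX leaves the workspace size unconstrained)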
    desc.get_algorithm_heuristic(ltArgs, INT_MAX, algo);
    auto&& ws_bundle = desc.get_workspace_bundle(ltArgs, algo);
    ws_bundle.set(args.workspace.raw_ptr);

    auto sgemm = [&]() {
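        // FP32 GEMM (TF32 on CUDA 11+): alpha/beta are device-side fp32 scalars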
        auto zero = handle->zero_device();
        auto one = handle->one_device();
        megdnn_assert(
                ws_bundle.nr_workspace() == 1,
                "workspace bundle size should be 1(ws_algo)");
        cublas_check(cublasLtMatmul(
                cublasLt_handle, desc.matmul_desc, one,
                static_cast<void*>(args.tensor_b.ptr<dt_float32>()), desc.layout_b,
                static_cast<void*>(args.tensor_a.ptr<dt_float32>()), desc.layout_a,
                zero, static_cast<void*>(args.tensor_c.ptr<dt_float32>()),
                desc.layout_c, static_cast<void*>(args.tensor_c.ptr<dt_float32>()),
                desc.layout_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0), stream));
    };
    auto hgemm = [&]() {
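        // FP16 GEMM: alpha/beta are device-side __half scalars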
        auto zero_half = handle->zero_device_h();
        auto one_half = handle->one_device_h();
        megdnn_assert(
                ws_bundle.nr_workspace() == 1,
                "workspace bundle size should be 1(ws_algo)");
        cublas_check(cublasLtMatmul(
                cublasLt_handle, desc.matmul_desc, one_half,
                static_cast<const __half*>(args.tensor_b.raw_ptr()), desc.layout_b,
                static_cast<const __half*>(args.tensor_a.raw_ptr()), desc.layout_a,
                zero_half, static_cast<const __half*>(args.tensor_c.raw_ptr()),
                desc.layout_c, static_cast<__half*>(args.tensor_c.raw_ptr()),
                desc.layout_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0), stream));
    };
    auto igemm = [&]() {
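        // INT8 GEMM: A/B/C are first converted with cublasLtMatrixTransform into the
        // transformed layouts described by desc.layout_trans_*, then transformed back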
        auto zero = handle->zero_device();
        auto one = handle->one_device();
        megdnn_assert(
                ws_bundle.nr_workspace() == 4,
                "workspace bundle size should be 4(ws_algo, ws_a, ws_b, ws_c)");
        void* ws_b = ws_bundle.get(1);
        void* ws_a = ws_bundle.get(2);
        void* ws_c = ws_bundle.get(3);
        int32_t pm = CUBLAS_POINTER_MODE_DEVICE;
        cublasOperation_t trans_a = CUBLAS_OP_T, trans_c = CUBLAS_OP_N;
        cublasLtMatrixTransformDesc_t transform_desc = nullptr;
        cublas_check(cublasLtMatrixTransformDescCreate(&transform_desc, CUDA_R_32F));
        cublas_check(cublasLtMatrixTransformDescSetAttribute(
                transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, &pm,
                sizeof(pm)));
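        // transform B, then A (with an extra transpose via TRANSA), into the transformed layouts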
        cublas_check(cublasLtMatrixTransform(
                cublasLt_handle, transform_desc, one, args.tensor_b.raw_ptr(),
                desc.layout_b, zero, nullptr, nullptr, ws_b, desc.layout_trans_b,
                stream));
        cublas_check(cublasLtMatrixTransformDescSetAttribute(
                transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &trans_a,
                sizeof(trans_a)));
        cublas_check(cublasLtMatrixTransform(
                cublasLt_handle, transform_desc, one, args.tensor_a.raw_ptr(),
                desc.layout_a, zero, nullptr, nullptr, ws_a, desc.layout_trans_a,
                stream));
        cublas_check(cublasLtMatmul(
                cublasLt_handle, desc.matmul_desc, one, ws_b, desc.layout_trans_b, ws_a,
                desc.layout_trans_a, zero, ws_c, desc.layout_trans_c, ws_c,
                desc.layout_trans_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0),
                stream));
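        // transform the result back to the caller's layout for tensor C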
        cublas_check(cublasLtMatrixTransformDescSetAttribute(
                transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &trans_c,
                sizeof(trans_c)));
        cublas_check(cublasLtMatrixTransform(
                cublasLt_handle, transform_desc, one, ws_c, desc.layout_trans_c, zero,
                nullptr, nullptr, args.tensor_c.raw_ptr(), desc.layout_c, stream));
        cublas_check(cublasLtMatrixTransformDescDestroy(transform_desc));
    };
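
    // CUDA 11 replaced the cudaDataType_t compute types with cublasComputeType_t,
    // so dispatch on desc.dt_compute using the enum that matches the toolkit version.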
#if CUDA_VERSION >= 11000
    switch (desc.dt_compute) {
        case CUBLAS_COMPUTE_16F:
            hgemm();
            break;
        case CUBLAS_COMPUTE_32F_FAST_TF32:
            sgemm();
            break;
        case CUBLAS_COMPUTE_32I:
            igemm();
            break;
        default:
            megdnn_throw("compute type must be float16/float32/int32");
    }
#else
    switch (desc.dt_compute) {
        case CUDA_R_16F:
            hgemm();
            break;
        case CUDA_R_32F:
            sgemm();
            break;
        case CUDA_R_32I:
            igemm();
            break;
        default:
            megdnn_throw("compute type must be float16/float32/int32");
    }
#endif
}
#endif
// vim: syntax=cpp.doxygen