#include "./algo.h"
#include "./helper.cuh"
#include "src/common/utils.cuh"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace batched_matrix_mul;

bool BatchedMatrixMulForwardImpl::AlgoCublas::is_available(const SizeArgs& args) const {
    auto dtype = args.layout_a.dtype;
    auto&& param = args.opr->param();
    auto&& handle = concrete_handle(args.opr->handle());
    // fix: cublasSgemmBatched from cuBLAS builds before CUDA 11.1 (CUBLAS_VERSION <
    // 11200) can give wrong results when batch == 1 and matrix A's width > 8191, so
    // on those versions only use this algo when
    // args.layout_a.shape[transposeA ? 1 : 2] <= 8191 || args.layout_a.shape[0] != 1
    if (dtype == dtype::Float32()
#if CUBLAS_VERSION < 11200
        && (args.layout_a.shape[args.opr->param().transposeA ? 1 : 2] <= 8191 ||
            args.layout_a.shape[0] != 1)
#endif
    )
        return true;
    if (dtype != dtype::Float16())
        return false;
    else {
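        // FP16 inputs: pseudo-FP32 compute needs CUDA >= 9.1 and sm_50+, true FP16
        // compute needs CUDA >= 9.0 and sm_60+ (matching the guards in exec())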
        auto&& cuda_cap = handle->device_prop();
        if (param.compute_mode == Param::ComputeMode::FLOAT32) {
#if CUDART_VERSION >= 9010
            return cuda_cap.major >= 5;
#else
            MEGDNN_MARK_USED_VAR(cuda_cap);
            return false;
#endif
        } else {
#if CUDART_VERSION >= 9000
            return cuda_cap.major >= 6;
#else
            MEGDNN_MARK_USED_VAR(cuda_cap);
            return false;
#endif
        }
    }
}

size_t BatchedMatrixMulForwardImpl::AlgoCublas::get_workspace_in_bytes(
        const SizeArgs& args) const {
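    // one device pointer per batch for each of A, B and C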
    return args.layout_a.shape[0] * 3 * sizeof(uintptr_t);
}

void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const {
    auto param = args.opr->param();
    auto dtype = args.layout_a.dtype;
    auto handle = concrete_handle(args.opr->handle());
    auto cublas_handle = handle->cublas_handle();
    auto stream = cuda_stream(handle);
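    // problem sizes from the row-major layouts: C is (batch, m, n); k is A's
    // contracted dimension, which depends on transposeA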
    auto batch = args.layout_a.shape[0];
    auto m = args.layout_c.shape[1], n = args.layout_c.shape[2];
    auto k = args.layout_a.shape[param.transposeA ? 1 : 2];
    auto workspace = args.workspace;

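    // carve the workspace into three per-batch device-pointer arrays (A, B, C)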
    uintptr_t* As = static_cast<uintptr_t*>(
            static_cast<void*>(workspace.raw_ptr + 0 * batch * sizeof(uintptr_t)));
    uintptr_t* Bs = static_cast<uintptr_t*>(
            static_cast<void*>(workspace.raw_ptr + 1 * batch * sizeof(uintptr_t)));
    uintptr_t* Cs = static_cast<uintptr_t*>(
            static_cast<void*>(workspace.raw_ptr + 2 * batch * sizeof(uintptr_t)));

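    // fill each pointer array on device: entry i holds the address of batch i's
    // matrix, i.e. the base pointer offset by i * stride[0] * dtype.size()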
    arange<uintptr_t>(
            As, reinterpret_cast<uintptr_t>(args.tensor_a.raw_ptr()),
            args.layout_a.stride[0] * dtype.size(), batch, stream);
    arange<uintptr_t>(
            Bs, reinterpret_cast<uintptr_t>(args.tensor_b.raw_ptr()),
            args.layout_b.stride[0] * dtype.size(), batch, stream);
    arange<uintptr_t>(
            Cs, reinterpret_cast<uintptr_t>(args.tensor_c.raw_ptr()),
            args.layout_c.stride[0] * dtype.size(), batch, stream);

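    // FP32 batched GEMM; megdnn layouts are row-major, so A/B (and m/n) are swapped
    // to match cuBLAS's column-major convention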
    auto io32_c32 = [&]() {
        auto zero = handle->zero_device();
        auto one = handle->one_device();
        cublas_check(cublasSgemmBatched(
                cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
                reinterpret_cast<const dt_float32**>(Bs), args.layout_b.stride[1],
                reinterpret_cast<const dt_float32**>(As), args.layout_a.stride[1], zero,
                reinterpret_cast<dt_float32**>(Cs), args.layout_c.stride[1], batch));
    };

#if CUDART_VERSION >= 9010
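    // FP16 inputs with FP32 accumulation via cublasGemmBatchedEx (CUDA >= 9.1)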
    auto io16_c32 = [&]() {
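        // run this batched GEMM with tensor-op math (TF32 mode on CUDA 11+), then
        // restore CUBLAS_DEFAULT_MATH afterwards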
#if CUDART_VERSION >= 11000
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
#else
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
#endif
        auto zero = handle->zero_device();
        auto one = handle->one_device();
        cublas_check(cublasGemmBatchedEx(
                cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
                reinterpret_cast<const void**>(Bs), CUDA_R_16F, args.layout_b.stride[1],
                reinterpret_cast<const void**>(As), CUDA_R_16F, args.layout_a.stride[1],
                zero, reinterpret_cast<void**>(Cs), CUDA_R_16F, args.layout_c.stride[1],
                batch, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH));
    };
#endif

#if CUDART_VERSION >= 9000
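    // pure FP16 batched GEMM via cublasHgemmBatched (CUDA >= 9.0)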
    auto io16_c16 = [&]() {
#if CUDART_VERSION >= 11000
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
#else
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
#endif
        auto zero = handle->zero_device_h();
        auto one = handle->one_device_h();
        cublas_check(cublasHgemmBatched(
                cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
                reinterpret_cast<const __half**>(Bs), args.layout_b.stride[1],
                reinterpret_cast<const __half**>(As), args.layout_a.stride[1], zero,
                reinterpret_cast<__half**>(Cs), args.layout_c.stride[1], batch));
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH));
    };
#endif

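    // dispatch on dtype and compute mode; the #if guards mirror is_available()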
    if (dtype == dtype::Float32()) {
        io32_c32();
    } else {
        if (param.compute_mode == Param::ComputeMode::FLOAT32) {
#if CUDART_VERSION >= 9010
            io16_c32();
#endif
        } else {
#if CUDART_VERSION >= 9000
            io16_c16();
#endif
        }
    }
}