cublas_lt.cpp 6.9 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./algo.h"
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/cublasLt_wrapper.h"
M
Megvii Engine Team 已提交
14
#include "src/cuda/utils.h"
15 16 17 18 19 20 21 22 23 24 25

using namespace megdnn;
using namespace cuda;

#if CUDA_VERSION >= 10010
/*!
 * \brief Translate batched-matmul SizeArgs into CUBLASLTMatmulDesc::SizeArgs.
 *
 * The result is brace-initialised from, in order: the handle returned by
 * concrete_handle() for the operator's handle, the transposeA and transposeB
 * flags read from the operator param, and the a/b/c layouts taken from
 * \p args.
 */
static inline CUBLASLTMatmulDesc::SizeArgs from_local_size_args(
        const BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs& args) {
    auto&& opr_param = args.opr->param();
    auto&& cuda_handle = concrete_handle(args.opr->handle());
    bool trans_a = opr_param.transposeA;
    bool trans_b = opr_param.transposeB;
    return {cuda_handle,   trans_a,       trans_b,
            args.layout_a, args.layout_b, args.layout_c};
}
28

29 30 31 32 33 34
/*!
 * \brief Tell whether the cuBLASLt batched algorithm can serve \p args.
 *
 * Returns false straight away when the current device's major compute
 * capability is below 7. Otherwise a CUBLASLTMatmulDesc is constructed from
 * the converted arguments (second constructor argument true) and its own
 * is_available() is queried with INT_MAX as the second argument
 * (presumably the workspace limit -- confirm against the wrapper header).
 */
bool BatchedMatrixMulForwardImpl::AlgoCublasLt::is_available(
        const SizeArgs& args) const {
    auto lt_args = from_local_size_args(args);
    auto&& device = current_device_prop();
    if (device.major < 7) {
        // Device too old: skip building the descriptor altogether.
        return false;
    }
    bool desc_ok = CUBLASLTMatmulDesc(lt_args, true).is_available(lt_args, INT_MAX);
    return desc_ok;
}
38

39 40 41 42 43 44 45 46
size_t BatchedMatrixMulForwardImpl::AlgoCublasLt::get_workspace_in_bytes(
        const SizeArgs& args) const {
    auto cublasLt_args = from_local_size_args(args);
    cublasLtMatmulAlgo_t algo;
    CUBLASLTMatmulDesc desc(cublasLt_args, true);
    desc.get_algorithm_heuristic(cublasLt_args, INT_MAX, algo);
    return desc.get_workspace_bundle(cublasLt_args, algo).total_size_in_bytes();
}
47

M
Megvii Engine Team 已提交
48
void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const {
49 50 51 52 53 54 55 56 57 58 59
    auto cublasLt_args = from_local_size_args(args);
    cublasLtMatmulAlgo_t algo;
    CUBLASLTMatmulDesc desc(cublasLt_args, true);
    desc.get_algorithm_heuristic(cublasLt_args, INT_MAX, algo);
    auto ws_bundle = desc.get_workspace_bundle(cublasLt_args, algo);
    auto&& handle = concrete_handle(args.opr->handle());
    auto&& stream = handle->stream();
    auto&& cublasLt_handle = handle->cublasLt_handle();
    auto batched_hgemm = [&]() {
        auto zero_half = handle->zero_device_h();
        auto one_half = handle->one_device_h();
M
Megvii Engine Team 已提交
60 61 62
        megdnn_assert(
                ws_bundle.nr_workspace() == 1,
                "workspace bundle size should be 1(ws_algo)");
63 64
        cublas_check(cublasLtMatmul(
                cublasLt_handle, desc.matmul_desc, one_half,
M
Megvii Engine Team 已提交
65 66 67
                static_cast<const __half*>(args.tensor_b.raw_ptr), desc.layout_b,
                static_cast<const __half*>(args.tensor_a.raw_ptr), desc.layout_a,
                zero_half, static_cast<const __half*>(args.tensor_c.raw_ptr),
68
                desc.layout_c, static_cast<__half*>(args.tensor_c.raw_ptr),
M
Megvii Engine Team 已提交
69
                desc.layout_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0), stream));
70 71 72 73
    };
    auto batched_sgemm = [&]() {
        auto zero = handle->zero_device();
        auto one = handle->one_device();
M
Megvii Engine Team 已提交
74 75 76 77 78 79
        auto dev_b = (desc.dt_b == CUDA_R_16F)
                           ? static_cast<void*>(args.tensor_b.ptr<dt_float16>())
                           : static_cast<void*>(args.tensor_b.ptr<dt_float32>());
        auto dev_a = (desc.dt_a == CUDA_R_16F)
                           ? static_cast<void*>(args.tensor_a.ptr<dt_float16>())
                           : static_cast<void*>(args.tensor_a.ptr<dt_float32>());
80
        auto dev_c = static_cast<void*>(args.tensor_c.raw_ptr);
M
Megvii Engine Team 已提交
81 82 83 84 85 86 87
        megdnn_assert(
                ws_bundle.nr_workspace() == 1,
                "workspace bundle size should be 1(ws_algo)");
        cublas_check(cublasLtMatmul(
                cublasLt_handle, desc.matmul_desc, one, dev_b, desc.layout_b, dev_a,
                desc.layout_a, zero, dev_c, desc.layout_c, dev_c, desc.layout_c, &algo,
                ws_bundle.get(0), ws_bundle.get_size(0), stream));
88
    };
M
Megvii Engine Team 已提交
89

90 91 92 93 94 95 96 97 98 99 100 101
    auto batched_igemm = [&]() {
        auto zero = handle->zero_device();
        auto one = handle->one_device();
        megdnn_assert(
                ws_bundle.nr_workspace() == 4,
                "workspace bundle size should be 4(ws_algo, ws_a, ws_b, ws_c)");
        void* ws_b = ws_bundle.get(1);
        void* ws_a = ws_bundle.get(2);
        void* ws_c = ws_bundle.get(3);
        int32_t pm = CUBLAS_POINTER_MODE_DEVICE;
        cublasOperation_t trans_a = CUBLAS_OP_T, trans_c = CUBLAS_OP_N;
        cublasLtMatrixTransformDesc_t transform_desc = nullptr;
M
Megvii Engine Team 已提交
102
        cublas_check(cublasLtMatrixTransformDescCreate(&transform_desc, CUDA_R_32F));
103
        cublas_check(cublasLtMatrixTransformDescSetAttribute(
M
Megvii Engine Team 已提交
104 105
                transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, &pm,
                sizeof(pm)));
106 107
        cublas_check(cublasLtMatrixTransform(
                cublasLt_handle, transform_desc, one, args.tensor_b.raw_ptr,
M
Megvii Engine Team 已提交
108 109
                desc.layout_b, zero, nullptr, nullptr, ws_b, desc.layout_trans_b,
                stream));
110 111 112 113 114
        cublas_check(cublasLtMatrixTransformDescSetAttribute(
                transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &trans_a,
                sizeof(trans_a)));
        cublas_check(cublasLtMatrixTransform(
                cublasLt_handle, transform_desc, one, args.tensor_a.raw_ptr,
M
Megvii Engine Team 已提交
115 116
                desc.layout_a, zero, nullptr, nullptr, ws_a, desc.layout_trans_a,
                stream));
117
        cublas_check(cublasLtMatmul(
M
Megvii Engine Team 已提交
118 119 120 121
                cublasLt_handle, desc.matmul_desc, one, ws_b, desc.layout_trans_b, ws_a,
                desc.layout_trans_a, zero, ws_c, desc.layout_trans_c, ws_c,
                desc.layout_trans_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0),
                stream));
122 123 124 125
        cublas_check(cublasLtMatrixTransformDescSetAttribute(
                transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &trans_c,
                sizeof(trans_c)));
        cublas_check(cublasLtMatrixTransform(
M
Megvii Engine Team 已提交
126 127
                cublasLt_handle, transform_desc, one, ws_c, desc.layout_trans_c, zero,
                nullptr, nullptr, args.tensor_c.raw_ptr, desc.layout_c, stream));
128 129 130 131
        cublas_check(cublasLtMatrixTransformDescDestroy(transform_desc));
    };

    ws_bundle.set(args.workspace.raw_ptr);
132 133 134 135 136 137 138 139
#if CUDA_VERSION >= 11000
    if (desc.dt_compute == CUBLAS_COMPUTE_32I) {
        batched_igemm();
    } else if (desc.dt_compute == CUBLAS_COMPUTE_16F) {
        batched_hgemm();
    } else if (desc.dt_compute == CUBLAS_COMPUTE_32F) {
        batched_sgemm();
    } else {
M
Megvii Engine Team 已提交
140
        megdnn_throw("compute_type must be int32/float16/float32");
141 142
    }
#else
143 144 145 146 147 148 149
    if (desc.dt_compute == CUDA_R_32I) {
        batched_igemm();
    } else if (desc.dt_compute == CUDA_R_16F) {
        batched_hgemm();
    } else if (desc.dt_compute == CUDA_R_32F) {
        batched_sgemm();
    } else {
M
Megvii Engine Team 已提交
150
        megdnn_throw("compute_type must be int32/float16/float32");
151
    }
152
#endif
153 154
}
#endif