/**
 * \file dnn/src/cuda/matrix_mul/cublas.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./algos.h"

#include "src/cuda/handle.h"
#include "src/cuda/utils.h"

#include <cuda.h>

using namespace megdnn;
using namespace cuda;

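// cublasSgemmEx expects the half-precision element type as CUDA_R_16F
// (cudaDataType_t, CUDA >= 8.0) or under the older CUBLAS_DATA_HALF name on
// earlier toolkits; SE_CUDA_DATA_HALF selects the right spelling.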
#if CUDA_VERSION >= 8000
#define SE_CUDA_DATA_HALF CUDA_R_16F
#else
#define SE_CUDA_DATA_HALF CUBLAS_DATA_HALF
#endif

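// cublasComputeType_t (and CUBLAS_COMPUTE_32I) only exists since CUDA 11; on
// older toolkits cublasGemmEx takes a cudaDataType_t as its computeType, so
// fall back to CUDA_R_32I there.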
#if CUDA_VERSION < 11000
#define CUBLAS_COMPUTE_32I CUDA_R_32I
#endif

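// The cuBLAS algorithm only handles the plain (DEFAULT) matrix format.
// fp32/fp16 inputs are always accepted; (quantized) int8 additionally
// requires leading dimensions divisible by 4 and compute capability >= 6.1
// (see the note below).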
bool MatrixMulForwardImpl::AlgoCuBlas::is_available(
        const SizeArgs& args) const {
    if (args.opr->param().format != param::MatrixMul::Format::DEFAULT)
        return false;
    if (args.layout_a.dtype == dtype::Float32() ||
        args.layout_a.dtype == dtype::Float16()) {
        return true;
    } else if (args.layout_a.dtype.enumv() == DTypeEnum::Int8 ||
               args.layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) {
        /**
         * \note When the strides are not divisible by 4, the cuBLAS routine
         * cublasGemmEx raises CUBLAS_STATUS_INVALID_VALUE, because the
         * leading dimension of matrix A or B is then illegal.
         */
        return args.layout_a.stride[0] % 4 == 0 &&
               args.layout_b.stride[0] % 4 == 0 &&
               is_compute_capability_required(6, 1);
    }
    return false;
}

void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const {
    auto&& handle = concrete_handle(args.opr->handle());
    auto&& cublas_handle = handle->cublas_handle();
    auto&& param = args.opr->param();
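    // m and n come from the output C; k is the contracted dimension, read
    // from A according to transposeA.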
    size_t m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
           k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];

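    // plain fp32 GEMM; alpha = 1 and beta = 0 are device-side constants
    // supplied by the handle. Operands are passed as (B, A) with swapped
    // transpose flags (see the row-major/column-major note before the
    // dispatch below).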
    auto sgemm = [&]() {
        auto zero = handle->zero_device();
        auto one = handle->one_device();
        cublas_check(cublasSgemm(
                cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
                args.tensor_b.ptr<dt_float32>(), args.tensor_b.layout.stride[0],
                args.tensor_a.ptr<dt_float32>(), args.tensor_a.layout.stride[0],
                zero, args.tensor_c.ptr<dt_float32>(),
                args.tensor_c.layout.stride[0]));
    };

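    // fp16 storage with fp32 accumulation via cublasSgemmEx; tensor-op math
    // is enabled only around the call (CUDA >= 9.0) and restored afterwards.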
    auto sgemm_ex = [&]() {
        auto zero = handle->zero_device();
        auto one = handle->one_device();
#if CUDART_VERSION >= 9000
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
#endif
        auto sgemm_ex_err = cublasSgemmEx(
                cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
                args.tensor_b.raw_ptr, SE_CUDA_DATA_HALF,
                args.tensor_b.layout.stride[0], args.tensor_a.raw_ptr,
                SE_CUDA_DATA_HALF, args.tensor_a.layout.stride[0], zero,
                args.tensor_c.raw_ptr, SE_CUDA_DATA_HALF,
                args.tensor_c.layout.stride[0]);
        cublas_check(sgemm_ex_err);
#if CUDART_VERSION >= 9000
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH));
#endif
    };

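    // pure fp16 GEMM (fp16 accumulation) via cublasHgemm, likewise with
    // tensor-op math enabled only around the call.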
    auto hgemm = [&]() {
#if CUDART_VERSION >= 9000
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
#endif
        auto one_half = handle->one_device_h();
        auto zero_half = handle->zero_device_h();
        auto hgemm_ex_err = cublasHgemm(
                cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one_half,
                static_cast<const __half*>(args.tensor_b.raw_ptr),
                args.tensor_b.layout.stride[0],
                static_cast<const __half*>(args.tensor_a.raw_ptr),
                args.tensor_a.layout.stride[0], zero_half,
                static_cast<__half*>(args.tensor_c.raw_ptr),
                args.tensor_c.layout.stride[0]);
        cublas_check(hgemm_ex_err);
#if CUDART_VERSION >= 9000
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH));
#endif
    };

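    // int8 x int8 -> int32 GEMM via cublasGemmEx; relies on the stride and
    // compute-capability checks done in is_available().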
    auto igemm = [&]() {
        auto zero = handle->zero_device_i32();
        auto one = handle->one_device_i32();
        cublas_check(cublasGemmEx(
                cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
                args.tensor_b.raw_ptr, CUDA_R_8I,
                args.tensor_b.layout.stride[0], args.tensor_a.raw_ptr,
                CUDA_R_8I, args.tensor_a.layout.stride[0], zero,
                args.tensor_c.raw_ptr, CUDA_R_32I,
                args.tensor_c.layout.stride[0], CUBLAS_COMPUTE_32I,
                CUBLAS_GEMM_DFALT));
    };

    // Note that cublas takes column-major matrices as inputs,
    // but megdnn takes row-major ones.
    // So we calculate C^t = B^t * A^t by cublas. Here the transpose symbol
    // implies row-major to column-major conversion.
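    // Concretely: a row-major m x k matrix A reinterpreted as column-major
    // (keeping the same leading dimension) is A^t, and likewise for B and C,
    // so computing B^t * A^t = (A * B)^t = C^t in column-major order writes
    // exactly the row-major C.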
    if (args.tensor_a.layout.dtype == dtype::Float32()) {
        sgemm();
    } else if (args.tensor_a.layout.dtype == dtype::Float16()) {
        // use tensor cores; note that CUBLAS_TENSOR_OP_MATH also causes
        // cublasSgemm to round to fp16, so we cannot enable it unconditionally
        if (handle->device_prop().major >= 6 &&
            param.compute_mode == Param::ComputeMode::DEFAULT)
            hgemm();
        else
            sgemm_ex();
    } else if (args.can_be_treated_as_int8x8x32()) {
        igemm();
    } else {
        megdnn_throw("Unsupported data_type of matrix mul on cuda.");
    }
}

// vim: syntax=cpp.doxygen