Commit e9db061e authored by Megvii Engine Team

fix(mgb): fix compiling error for cuda-11.1

GitOrigin-RevId: f63e71afa75160746f0d69c67282b3a18b544ed1
Parent cd02d7c8
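
Background for the fix (commentary, not part of the commit): the compile break on CUDA 11.x comes from cuBLAS 11 splitting the compute type out of cudaDataType_t. cublasLtMatmulDescCreate() now takes a cublasComputeType_t plus a separate cudaDataType_t scale type, whereas the cuBLAS 10.x signature took a single cudaDataType_t. Below is a minimal, hedged sketch of the two call shapes; the CUDA_VERSION guard and the helper name make_matmul_desc are assumptions for illustration only (the commit itself switches unconditionally to the new signature).

// Illustrative sketch only -- not code from this commit.
#include <cassert>
#include <cublasLt.h>
#include <cuda.h>  // CUDA_VERSION

static cublasLtMatmulDesc_t make_matmul_desc(cudaDataType_t dt_c) {
    cublasLtMatmulDesc_t desc = nullptr;
#if CUDA_VERSION >= 11000
    // cuBLAS 11+: the compute type is its own enum; the scale type
    // (here the output dtype) is passed as a separate third argument.
    cublasComputeType_t compute;
    switch (dt_c) {
        case CUDA_R_16F: compute = CUBLAS_COMPUTE_16F; break;
        case CUDA_R_32I: compute = CUBLAS_COMPUTE_32I; break;
        default:         compute = CUBLAS_COMPUTE_32F; break;
    }
    cublasStatus_t st = cublasLtMatmulDescCreate(&desc, compute, dt_c);
#else
    // cuBLAS 10.x: a single cudaDataType_t doubles as the compute type.
    cublasStatus_t st = cublasLtMatmulDescCreate(&desc, dt_c);
#endif
    assert(st == CUBLAS_STATUS_SUCCESS);
    return desc;
}
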
@@ -12,6 +12,7 @@
#include "src/common/utils.h"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 10010
namespace megdnn {
namespace cuda {
static cudaDataType_t to_cuda_dtype(DType tp) {
@@ -31,6 +32,22 @@ static cudaDataType_t to_cuda_dtype(DType tp) {
"dtype must be float16/float32/int8/qs8/int32"));
}
}
static cublasComputeType_t to_cublas_compute_type(DType tp) {
    switch (tp.enumv()) {
        case DTypeEnum::Float16:
            return CUBLAS_COMPUTE_16F;
        case DTypeEnum::Float32:
            return CUBLAS_COMPUTE_32F;
        case DTypeEnum::Int32:
        case DTypeEnum::QuantizedS32:
            return CUBLAS_COMPUTE_32I;
        default:
            megdnn_throw(megdnn_mangle(
                    "dtype must be float16/float32/int32/Qs32"));
    }
}
static const char* cuda_type_to_str(cudaDataType_t tp) {
    switch (tp) {
        case CUDA_R_16F:
@@ -46,6 +63,7 @@ static const char* cuda_type_to_str(cudaDataType_t tp) {
                    megdnn_mangle("dtype must be float16/float32/int8/int32"));
    }
}
static size_t cuda_dtype_size(cudaDataType_t dt) {
    switch (dt) {
        case CUDA_R_8I:
@@ -60,6 +78,7 @@ static size_t cuda_dtype_size(cudaDataType_t dt) {
                    megdnn_mangle("dtype must be float16/float32/int8/int32"));
    }
}
CUBLASLTMatmulDesc::~CUBLASLTMatmulDesc() {
    if (matmul_desc)
        cublas_check(cublasLtMatmulDescDestroy(matmul_desc));
@@ -86,9 +105,10 @@ void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) {
    uint32_t pm = CUBLAS_POINTER_MODE_DEVICE;
    dt_b = to_cuda_dtype(args.layout_b.dtype);
    dt_a = to_cuda_dtype(args.layout_a.dtype);
    dt_compute = dt_c = to_cuda_dtype(args.layout_c.dtype);
    dt_c = to_cuda_dtype(args.layout_c.dtype);
    dt_compute = to_cublas_compute_type(args.layout_c.dtype);
    megdnn_assert(dt_a == dt_b, "matrix A and B should have same precision");
    cublas_check(cublasLtMatmulDescCreate(&matmul_desc, dt_compute));
    cublas_check(cublasLtMatmulDescCreate(&matmul_desc, dt_compute, dt_c));
    cublas_check(cublasLtMatmulDescSetAttribute(
            matmul_desc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pm, sizeof(pm)));
@@ -100,7 +120,7 @@ void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) {
     * So we calculate C^t = B^t * A^t by cublas. Here the transpose symbol
     * implies row-major to column-major conversion
     */
    if (dt_compute == CUDA_R_32I) {
    if (dt_c == CUDA_R_32I) {
        /**
         * \NOTE: To use IMMA kernels, use computeType = CUDA_R_32I and
         * CUBLASLT_ORDER_COL32 for matrices A,C,D and
@@ -209,7 +229,7 @@ void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) {
bool CUBLASLTMatmulDesc::is_available(const SizeArgs& args, size_t ws_limit) {
    bool support;
    cublasLtMatmulAlgo_t algo;
    switch (dt_compute) {
    switch (dt_c) {
        case CUDA_R_16F:
            support = (dt_a == CUDA_R_16F);
            break;
@@ -239,17 +259,17 @@ WorkspaceBundle CUBLASLTMatmulDesc::get_workspace_bundle(
    cublasLtMatmulHeuristicResult_t result{};
    status = cublasLtMatmulAlgoCheck(
            cublasLt_handle, matmul_desc,
            dt_compute == CUDA_R_32I ? layout_trans_b : layout_b,
            dt_compute == CUDA_R_32I ? layout_trans_a : layout_a,
            dt_compute == CUDA_R_32I ? layout_trans_c : layout_c,
            dt_compute == CUDA_R_32I ? layout_trans_c : layout_c, &algo,
            dt_c == CUDA_R_32I ? layout_trans_b : layout_b,
            dt_c == CUDA_R_32I ? layout_trans_a : layout_a,
            dt_c == CUDA_R_32I ? layout_trans_c : layout_c,
            dt_c == CUDA_R_32I ? layout_trans_c : layout_c, &algo,
            &result);
    // return empty WorkspaceBundle if cublasLtMatmulAlgoCheck() failed
    if (status != CUBLAS_STATUS_SUCCESS)
        return {nullptr, {}};
    algo_workspace_size = result.workspaceSize;
    return {nullptr,
            (dt_compute == CUDA_R_32I)
            (dt_c == CUDA_R_32I)
                    ? SmallVector<size_t>{algo_workspace_size, workspace_b,
                                          workspace_a, workspace_c}
                    : SmallVector<size_t>{algo_workspace_size}};
@@ -273,7 +293,7 @@ bool CUBLASLTMatmulDesc::get_algorithm_heuristic(const SizeArgs& args,
     * \Note: algo_ws_limit must be zero if cublasLtGetVersion() <= 10100
     */
    // algo_ws_limit = 0;
    if (dt_compute == CUDA_R_32I) {
    if (dt_c == CUDA_R_32I) {
        //[FIXME]: cublasLt(Version 10020) produce wrong result when k in
        //[64*n+1 , 64*n+32] for small matrix
@@ -291,10 +311,10 @@ bool CUBLASLTMatmulDesc::get_algorithm_heuristic(const SizeArgs& args,
            sizeof(algo_ws_limit)));
    status = cublasLtMatmulAlgoGetHeuristic(
            cublasLt_handle, matmul_desc,
            dt_compute == CUDA_R_32I ? layout_trans_b : layout_b,
            dt_compute == CUDA_R_32I ? layout_trans_a : layout_a,
            dt_compute == CUDA_R_32I ? layout_trans_c : layout_c,
            dt_compute == CUDA_R_32I ? layout_trans_c : layout_c, algo_pref, 1,
            dt_c == CUDA_R_32I ? layout_trans_b : layout_b,
            dt_c == CUDA_R_32I ? layout_trans_a : layout_a,
            dt_c == CUDA_R_32I ? layout_trans_c : layout_c,
            dt_c == CUDA_R_32I ? layout_trans_c : layout_c, algo_pref, 1,
            &algo_result, &return_algo_count);
    if (status == CUBLAS_STATUS_SUCCESS && return_algo_count > 0 &&
        // perform cublasLtAlgoCheck() to make sure the algo is correct

@@ -47,7 +47,8 @@ struct CUBLASLTMatmulDesc {
    };
    bool is_batched;
    cublasLtMatmulDesc_t matmul_desc;
    cudaDataType_t dt_a, dt_b, dt_c, dt_compute;
    cudaDataType_t dt_a, dt_b, dt_c;
    cublasComputeType_t dt_compute;
    cublasLtMatrixLayout_t layout_a, layout_b, layout_c;
    cublasLtMatrixLayout_t layout_trans_a, layout_trans_b, layout_trans_c;
    size_t workspace_a, workspace_b, workspace_c;
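
A note on the header change above (commentary, not part of the commit): declaring dt_compute as cublasComputeType_t means this header only builds against cuBLAS 11+. If the CUDA 10.x toolchain still had to be supported, one hedged option would be a version-guarded alias, sketched below; the alias name compute_type_t and the struct name are invented for illustration.

// Sketch only: one member name, with the underlying type following the toolkit.
#include <cublasLt.h>
#include <cuda.h>  // CUDA_VERSION

#if CUDA_VERSION >= 11000
using compute_type_t = cublasComputeType_t;  // cuBLAS 11+: dedicated compute enum
#else
using compute_type_t = cudaDataType_t;  // cuBLAS 10.x: data type doubles as compute type
#endif

struct MatmulDescFields {  // illustrative stand-in for CUBLASLTMatmulDesc's members
    cudaDataType_t dt_a, dt_b, dt_c;
    compute_type_t dt_compute;
};
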