Commit e9db061e authored by Megvii Engine Team

fix(mgb): fix compiling error for cuda-11.1

GitOrigin-RevId: f63e71afa75160746f0d69c67282b3a18b544ed1
Parent cd02d7c8
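Background on the compile error: before CUDA 11, cublasLtMatmulDescCreate() took a single cudaDataType_t that served as both compute and scale type; starting with CUDA 11.0 the compute type is the separate cublasComputeType_t enum and the scale type is passed as a third argument, so the old two-argument call no longer compiles against the cuda-11.1 headers. A minimal standalone sketch of the new call (illustrative only, not part of the patch; assumes a CUDA >= 11.0 toolchain and linking against cublasLt):

```cpp
#include <cublasLt.h>
#include <cstdio>

int main() {
    cublasLtMatmulDesc_t desc = nullptr;
    // CUDA 11+: compute type (cublasComputeType_t) and scale type
    // (cudaDataType_t) are separate arguments; CUDA 10.x took only a
    // single cudaDataType_t here.
    cublasStatus_t st =
            cublasLtMatmulDescCreate(&desc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
    std::printf("cublasLtMatmulDescCreate: %d\n", static_cast<int>(st));
    if (desc)
        cublasLtMatmulDescDestroy(desc);
    return 0;
}
```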
@@ -12,6 +12,7 @@
 #include "src/common/utils.h"
 #include "src/cuda/utils.h"
 #if CUDA_VERSION >= 10010
 namespace megdnn {
 namespace cuda {
 static cudaDataType_t to_cuda_dtype(DType tp) {
@@ -31,6 +32,22 @@ static cudaDataType_t to_cuda_dtype(DType tp) {
                     "dtype must be float16/float32/int8/qs8/int32"));
     }
 }
+static cublasComputeType_t to_cublas_compute_type(DType tp) {
+    switch (tp.enumv()) {
+        case DTypeEnum::Float16:
+            return CUBLAS_COMPUTE_16F;
+        case DTypeEnum::Float32:
+            return CUBLAS_COMPUTE_32F;
+        case DTypeEnum::Int32:
+        case DTypeEnum::QuantizedS32:
+            return CUBLAS_COMPUTE_32I;
+        default:
+            megdnn_throw(megdnn_mangle(
+                    "dtype must be float16/float32/int32/Qs32"));
+    }
+}
 static const char* cuda_type_to_str(cudaDataType_t tp) {
     switch (tp) {
         case CUDA_R_16F:
@@ -46,6 +63,7 @@ static const char* cuda_type_to_str(cudaDataType_t tp) {
                     megdnn_mangle("dtype must be float16/float32/int8/int32"));
     }
 }
 static size_t cuda_dtype_size(cudaDataType_t dt) {
     switch (dt) {
         case CUDA_R_8I:
@@ -60,6 +78,7 @@ static size_t cuda_dtype_size(cudaDataType_t dt) {
                     megdnn_mangle("dtype must be float16/float32/int8/int32"));
     }
 }
 CUBLASLTMatmulDesc::~CUBLASLTMatmulDesc() {
     if (matmul_desc)
         cublas_check(cublasLtMatmulDescDestroy(matmul_desc));
@@ -86,9 +105,10 @@ void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) {
     uint32_t pm = CUBLAS_POINTER_MODE_DEVICE;
     dt_b = to_cuda_dtype(args.layout_b.dtype);
     dt_a = to_cuda_dtype(args.layout_a.dtype);
-    dt_compute = dt_c = to_cuda_dtype(args.layout_c.dtype);
+    dt_c = to_cuda_dtype(args.layout_c.dtype);
+    dt_compute = to_cublas_compute_type(args.layout_c.dtype);
     megdnn_assert(dt_a == dt_b, "matrix A and B should have same precision");
-    cublas_check(cublasLtMatmulDescCreate(&matmul_desc, dt_compute));
+    cublas_check(cublasLtMatmulDescCreate(&matmul_desc, dt_compute, dt_c));
     cublas_check(cublasLtMatmulDescSetAttribute(
             matmul_desc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pm, sizeof(pm)));
@@ -100,7 +120,7 @@ void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) {
      * So we calculate C^t = B^t * A^t by cublas. Here the transpose symbol
      * implies row-major to column-major conversion
      */
-    if (dt_compute == CUDA_R_32I) {
+    if (dt_c == CUDA_R_32I) {
         /**
          * \NOTE: To use IMMA kernels, use computeType = CUDA_R_32I and
          * CUBLASLT_ORDER_COL32 for matrices A,C,D and
@@ -209,7 +229,7 @@ void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) {
 bool CUBLASLTMatmulDesc::is_available(const SizeArgs& args, size_t ws_limit) {
     bool support;
     cublasLtMatmulAlgo_t algo;
-    switch (dt_compute) {
+    switch (dt_c) {
         case CUDA_R_16F:
             support = (dt_a == CUDA_R_16F);
             break;
@@ -239,17 +259,17 @@ WorkspaceBundle CUBLASLTMatmulDesc::get_workspace_bundle(
     cublasLtMatmulHeuristicResult_t result{};
     status = cublasLtMatmulAlgoCheck(
             cublasLt_handle, matmul_desc,
-            dt_compute == CUDA_R_32I ? layout_trans_b : layout_b,
-            dt_compute == CUDA_R_32I ? layout_trans_a : layout_a,
-            dt_compute == CUDA_R_32I ? layout_trans_c : layout_c,
-            dt_compute == CUDA_R_32I ? layout_trans_c : layout_c, &algo,
+            dt_c == CUDA_R_32I ? layout_trans_b : layout_b,
+            dt_c == CUDA_R_32I ? layout_trans_a : layout_a,
+            dt_c == CUDA_R_32I ? layout_trans_c : layout_c,
+            dt_c == CUDA_R_32I ? layout_trans_c : layout_c, &algo,
             &result);
     // return empty WorkspaceBundle if cublasLtMatmulAlgoCheck() failed
     if (status != CUBLAS_STATUS_SUCCESS)
         return {nullptr, {}};
     algo_workspace_size = result.workspaceSize;
     return {nullptr,
-            (dt_compute == CUDA_R_32I)
+            (dt_c == CUDA_R_32I)
                     ? SmallVector<size_t>{algo_workspace_size, workspace_b,
                                           workspace_a, workspace_c}
                     : SmallVector<size_t>{algo_workspace_size}};
@@ -273,7 +293,7 @@ bool CUBLASLTMatmulDesc::get_algorithm_heuristic(const SizeArgs& args,
      * \Note: algo_ws_limit must be zero if cublasLtGetVersion() <= 10100
      */
     // algo_ws_limit = 0;
-    if (dt_compute == CUDA_R_32I) {
+    if (dt_c == CUDA_R_32I) {
         //[FIXME]: cublasLt(Version 10020) produce wrong result when k in
         //[64*n+1 , 64*n+32] for small matrix
@@ -291,10 +311,10 @@ bool CUBLASLTMatmulDesc::get_algorithm_heuristic(const SizeArgs& args,
             sizeof(algo_ws_limit)));
     status = cublasLtMatmulAlgoGetHeuristic(
             cublasLt_handle, matmul_desc,
-            dt_compute == CUDA_R_32I ? layout_trans_b : layout_b,
-            dt_compute == CUDA_R_32I ? layout_trans_a : layout_a,
-            dt_compute == CUDA_R_32I ? layout_trans_c : layout_c,
-            dt_compute == CUDA_R_32I ? layout_trans_c : layout_c, algo_pref, 1,
+            dt_c == CUDA_R_32I ? layout_trans_b : layout_b,
+            dt_c == CUDA_R_32I ? layout_trans_a : layout_a,
+            dt_c == CUDA_R_32I ? layout_trans_c : layout_c,
+            dt_c == CUDA_R_32I ? layout_trans_c : layout_c, algo_pref, 1,
             &algo_result, &return_algo_count);
     if (status == CUBLAS_STATUS_SUCCESS && return_algo_count > 0 &&
         // perform cublasLtAlgoCheck() to make sure the algo is correct
......
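Note on the dt_compute -> dt_c switch in the branches above: with dt_compute now a cublasComputeType_t, comparing it against the cudaDataType_t value CUDA_R_32I would no longer be meaningful, so the IMMA/COL32 branches key off the C/D data type instead. A small illustrative sketch (not from the patch; the Int8ImmaConfig struct is invented here) of how the fields line up for the int8 IMMA path under CUDA 11:

```cpp
#include <cublasLt.h>

// Illustrative only: the int8 IMMA configuration expressed with the new
// field types. dt_c (int32 output) is what now selects the COL32 branch,
// while the compute type uses the dedicated CUDA 11 enum.
struct Int8ImmaConfig {
    cudaDataType_t dt_a = CUDA_R_8I;             // A is int8
    cudaDataType_t dt_b = CUDA_R_8I;             // B is int8
    cudaDataType_t dt_c = CUDA_R_32I;            // C/D are int32 -> IMMA/COL32 path
    cublasComputeType_t dt_compute = CUBLAS_COMPUTE_32I;
    cudaDataType_t scale_type = CUDA_R_32I;      // alpha/beta type passed as scaleType
};
```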
@@ -47,7 +47,8 @@ struct CUBLASLTMatmulDesc {
     };
     bool is_batched;
     cublasLtMatmulDesc_t matmul_desc;
-    cudaDataType_t dt_a, dt_b, dt_c, dt_compute;
+    cudaDataType_t dt_a, dt_b, dt_c;
+    cublasComputeType_t dt_compute;
     cublasLtMatrixLayout_t layout_a, layout_b, layout_c;
     cublasLtMatrixLayout_t layout_trans_a, layout_trans_b, layout_trans_c;
     size_t workspace_a, workspace_b, workspace_c;
......
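The header change means dt_compute has a different type depending on the toolkit. This patch simply adopts the CUDA 11 signature; if one also wanted to keep building against CUDA 10.x, the usual approach is a version guard around both the member type and the create call. A hypothetical, self-contained sketch (the create_matmul_desc helper is invented for illustration and is not part of the patch):

```cpp
#include <cuda.h>
#include <cublasLt.h>

// Hypothetical compatibility helper, not part of this patch.
#if CUDA_VERSION >= 11000
using matmul_compute_type = cublasComputeType_t;  // CUDA 11+: dedicated enum
#else
using matmul_compute_type = cudaDataType_t;       // CUDA 10.x: plain data type
#endif

static cublasStatus_t create_matmul_desc(cublasLtMatmulDesc_t* desc,
                                         matmul_compute_type compute,
                                         cudaDataType_t scale) {
#if CUDA_VERSION >= 11000
    // CUDA 11+: compute type and scale type are separate arguments.
    return cublasLtMatmulDescCreate(desc, compute, scale);
#else
    // CUDA 10.x: the single cudaDataType_t argument doubles as both.
    (void)scale;
    return cublasLtMatmulDescCreate(desc, compute);
#endif
}
```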