diff --git a/dnn/src/cuda/matrix_mul/naive.cpp b/dnn/src/cuda/matrix_mul/naive.cpp
index 577596405806e207bd3c2db6ea09534bc29039d5..a9145e5928d4a914e121bb5b03f3237e16353a5b 100644
--- a/dnn/src/cuda/matrix_mul/naive.cpp
+++ b/dnn/src/cuda/matrix_mul/naive.cpp
@@ -17,8 +17,21 @@
 using namespace megdnn;
 using namespace cuda;
 
+#include "midout.h"
+MIDOUT_DECL(megdnn_naive_matmul)
+
 bool MatrixMulForwardImpl::AlgoNaive::is_available(const SizeArgs& args) const {
-    return args.can_be_treated_as_int8x8x32();
+    if (args.can_be_treated_as_int8x8x32())
+        return true;
+    auto&& layout_a = args.layout_a;
+    auto&& layout_b = args.layout_b;
+    auto&& layout_c = args.layout_c;
+    return layout_a.dtype.enumv() == layout_b.dtype.enumv() &&
+           (layout_a.dtype.enumv() == DTypeEnum::Float32 ||
+            layout_a.dtype.enumv() == DTypeEnum::Float16) &&
+           (layout_c.dtype.enumv() == DTypeEnum::Float32 ||
+            layout_c.dtype.enumv() == DTypeEnum::Float16) &&
+           args.opr->param().format == param::MatrixMul::Format::DEFAULT;
 }
 void MatrixMulForwardImpl::AlgoNaive::exec(const ExecArgs& args) const {
     auto&& param = args.opr->param();
@@ -28,13 +41,45 @@ void MatrixMulForwardImpl::AlgoNaive::exec(const ExecArgs& args) const {
          LDB = args.tensor_b.layout.stride[0],
          LDC = args.tensor_c.layout.stride[0];
 
-    int8_t* A = args.tensor_a.compatible_ptr<dt_int8>();
-    int8_t* B = args.tensor_b.compatible_ptr<dt_int8>();
-    int32_t* C = args.tensor_c.compatible_ptr<dt_int32>();
-
     auto&& handle = concrete_handle(args.opr->handle());
-    exec_gemm_int8_naive(A, B, C, m, n, k, LDA, LDB, LDC, param.transposeA,
-                         param.transposeB, cuda_stream(handle));
+
+    using ComputeMode = Param::ComputeMode;
+#define DISPATCH_CMODE(in_dt, out_dt, in_ct, out_ct, comp_ct, cmode)          \
+    MIDOUT_BEGIN(megdnn_naive_matmul, midout_iv(#in_dt #out_dt #in_ct,        \
+                                                #out_ct, #comp_ct, #cmode)) { \
+        do {                                                                  \
+            using namespace dtype;                                            \
+            if (args.tensor_a.layout.dtype.enumv() ==                         \
+                        DTypeTrait<in_dt>::enumv &&                           \
+                args.tensor_c.layout.dtype.enumv() ==                         \
+                        DTypeTrait<out_dt>::enumv &&                          \
+                param.compute_mode == cmode) {                                \
+                in_ct* A = args.tensor_a.compatible_ptr<in_ct>();             \
+                in_ct* B = args.tensor_b.compatible_ptr<in_ct>();             \
+                out_ct* C = args.tensor_c.compatible_ptr<out_ct>();           \
+                exec_gemm_naive<in_ct, in_ct, out_ct, comp_ct>(               \
+                        A, B, C, m, n, k, LDA, LDB, LDC, param.transposeA,    \
+                        param.transposeB, cuda_stream(handle));               \
+                return;                                                       \
+            }                                                                 \
+        } while (0);                                                          \
+    }                                                                         \
+    MIDOUT_END();
+#define DISPATCH(in_dt, out_dt, in_ct, out_ct, comp_ct) \
+    DISPATCH_CMODE(in_dt, out_dt, in_ct, out_ct, comp_ct, ComputeMode::DEFAULT)
+
+    DISPATCH(Float32, Float32, dt_float32, dt_float32, dt_float32);
+    DISPATCH(Float16, Float16, dt_float16, dt_float16, dt_float16);
+    DISPATCH(Int8, Int32, dt_int8, dt_int32, dt_int32);
+    DISPATCH(QuantizedS8, QuantizedS32, dt_int8, dt_int32, dt_int32);
+    DNN_INC_FLOAT16(DISPATCH_CMODE(Float16, Float16, dt_float16, dt_float16,
+                                   dt_float32, ComputeMode::FLOAT32));
+#undef DISPATCH_CMODE
+#undef DISPATCH
+    megdnn_throw(ssprintf(
+            "unsupported Matmul(%s, %s) -> %s with cmode = %d",
+            args.layout_a.dtype.name(), args.layout_b.dtype.name(),
+            args.layout_c.dtype.name(), static_cast<int>(param.compute_mode)));
 }
 
 // vim: syntax=cpp.doxygen
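
Note (not part of the patch): once the MIDOUT bookkeeping is stripped away, each DISPATCH line above expands into a runtime dtype/compute-mode check that forwards to the templated launcher. A rough sketch of what DISPATCH(Float32, Float32, dt_float32, dt_float32, dt_float32) boils down to, assuming the surrounding exec() scope (args, param, m, n, k, LDA, LDB, LDC, handle) from the hunk above:

    // Sketch of one macro expansion; if the dtypes or compute mode do not
    // match, control falls through to the next DISPATCH case.
    if (args.tensor_a.layout.dtype.enumv() == DTypeTrait<dtype::Float32>::enumv &&
        args.tensor_c.layout.dtype.enumv() == DTypeTrait<dtype::Float32>::enumv &&
        param.compute_mode == ComputeMode::DEFAULT) {
        dt_float32* A = args.tensor_a.compatible_ptr<dt_float32>();
        dt_float32* B = args.tensor_b.compatible_ptr<dt_float32>();
        dt_float32* C = args.tensor_c.compatible_ptr<dt_float32>();
        exec_gemm_naive<dt_float32, dt_float32, dt_float32, dt_float32>(
                A, B, C, m, n, k, LDA, LDB, LDC, param.transposeA,
                param.transposeB, cuda_stream(handle));
        return;
    }
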
diff --git a/dnn/src/cuda/matrix_mul/naive.cu b/dnn/src/cuda/matrix_mul/naive.cu
index 49d8f6eec6c59ad6ffa20c59405756eceebd4c87..6a892a6a5523bf907c053f2aa22c9cdd887ae5ae 100644
--- a/dnn/src/cuda/matrix_mul/naive.cu
+++ b/dnn/src/cuda/matrix_mul/naive.cu
@@ -14,16 +14,18 @@
 #include "src/cuda/utils.cuh"
 
 namespace {
-__global__ void do_exec(const int8_t* A, const int8_t* B, int32_t* C, size_t M,
+
+template <typename AType, typename BType, typename CType, typename CompType>
+__global__ void do_exec(const AType* A, const BType* B, CType* C, size_t M,
                         size_t N, size_t K, size_t LDA, size_t LDB, size_t LDC,
                         bool transA, bool transB) {
     size_t m = blockIdx.x;
     for (; m < M; m += gridDim.x) {
         size_t n = threadIdx.x;
         for (; n < N; n += blockDim.x) {
-            int32_t res = 0;
+            CompType res = static_cast<CompType>(0);
             for (size_t k = 0; k < K; ++k) {
-                int8_t av = transA ? A[k * LDA + m] : A[m * LDA + k],
+                AType av = transA ? A[k * LDA + m] : A[m * LDA + k],
                        bv = transB ? B[n * LDB + k] : B[k * LDB + n];
                 res += av * bv;
             }
@@ -36,14 +38,29 @@ __global__ void do_exec(const int8_t* A, const int8_t* B, int32_t* C, size_t M,
 
 namespace megdnn {
 namespace cuda {
 
-void exec_gemm_int8_naive(const int8_t* A, const int8_t* B, int32_t* C,
-                          size_t M, size_t N, size_t K, size_t LDA, size_t LDB,
-                          size_t LDC, bool transA, bool transB,
-                          cudaStream_t stream) {
-    do_exec<<<128, 128, 0, stream>>>(A, B, C, M, N, K, LDA, LDB, LDC, transA,
-                                     transB);
+template <typename AType, typename BType, typename CType, typename CompType>
+void exec_gemm_naive(const AType* A, const BType* B, CType* C, size_t M,
+                     size_t N, size_t K, size_t LDA, size_t LDB, size_t LDC,
+                     bool transA, bool transB, cudaStream_t stream) {
+    do_exec<AType, BType, CType, CompType><<<128, 128, 0, stream>>>(
+            A, B, C, M, N, K, LDA, LDB, LDC, transA, transB);
 }
 
+#define INST(in_ct, out_ct, comp_ct)                                          \
+    template void exec_gemm_naive<in_ct, in_ct, out_ct, comp_ct>(             \
+            const in_ct* A, const in_ct* B, out_ct* C, size_t M, size_t N,    \
+            size_t K, size_t LDA, size_t LDB, size_t LDC, bool transA,        \
+            bool transB, cudaStream_t stream);
+
+INST(megdnn::dt_float32, megdnn::dt_float32, megdnn::dt_float32)
+INST(megdnn::dt_float16, megdnn::dt_float16, megdnn::dt_float16)
+INST(megdnn::dt_int8, megdnn::dt_int32, megdnn::dt_int32)
+INST(megdnn::dt_float16, megdnn::dt_float16, megdnn::dt_float32)
+
+#undef cb
+#undef INST
+
 }  // namespace cuda
 }  // namespace megdnn
diff --git a/dnn/src/cuda/matrix_mul/naive.cuh b/dnn/src/cuda/matrix_mul/naive.cuh
index 01d43d8b6598eb646195b0f907dfdcb8ba0db3a4..615d6befce6395a7a9107dcb4a20c5dfeb100ada 100644
--- a/dnn/src/cuda/matrix_mul/naive.cuh
+++ b/dnn/src/cuda/matrix_mul/naive.cuh
@@ -15,8 +15,9 @@
 namespace megdnn {
 namespace cuda {
 
-void exec_gemm_int8_naive(const int8_t* A, const int8_t* B, int32_t* C,
-                          size_t m, size_t n, size_t k, size_t ldA, size_t ldB,
+template <typename AType, typename BType, typename CType, typename CompType>
+void exec_gemm_naive(const AType* A, const BType* B, CType* C, size_t m,
+                     size_t n, size_t k, size_t ldA, size_t ldB,
                           size_t ldC, bool transA, bool transB, cudaStream_t stream);
 
 }  // namespace cuda
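
Note (not part of the patch): CompType only appears in the kernel body, so it can never be deduced from the call arguments; that is why both the dispatch macro in naive.cpp and the INST explicit instantiations spell out all four template parameters. A minimal host-side call sketch matching the INST(dt_float16, dt_float16, dt_float32) instantiation; the names d_a/d_b/d_c and stream are assumed here (device buffers of M*K, K*N and M*N elements and a valid CUDA stream), not taken from the patch:

    // fp16 inputs/outputs with fp32 accumulation; row-major, no transpose,
    // so LDA = K and LDB = LDC = N.
    megdnn::cuda::exec_gemm_naive<megdnn::dt_float16, megdnn::dt_float16,
                                  megdnn::dt_float16, megdnn::dt_float32>(
            d_a, d_b, d_c, /*M=*/12, /*N=*/16, /*K=*/20,
            /*LDA=*/20, /*LDB=*/16, /*LDC=*/16,
            /*transA=*/false, /*transB=*/false, stream);
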
diff --git a/dnn/test/cuda/matrix_mul.cpp b/dnn/test/cuda/matrix_mul.cpp
index 9ba0b3cdb762b9ccf5f0d2c3562ba01ad14ca072..4cc9cf811b1937ce70519bbe5a39225072bc7e03 100644
--- a/dnn/test/cuda/matrix_mul.cpp
+++ b/dnn/test/cuda/matrix_mul.cpp
@@ -185,6 +185,46 @@ TEST_F(CUDA, MATRIX_MUL_INT8x8x32_NAIVE) {
     }
 }
 
+TEST_F(CUDA, MATRIX_MUL_FLOAT_NAIVE) {
+    Checker<MatrixMul> checker(handle_cuda());
+    checker.set_before_exec_callback(AlgoChecker<MatrixMulForward>("NAIVE"));
+    using Param = MatrixMul::Param;
+    size_t m = 12, n = 16, k = 20;
+
+    std::vector<DType> dtype_array;
+    dtype_array.push_back(dtype::Float32());
+    dtype_array.push_back(dtype::Float16());
+
+    for (DType dtype : dtype_array) {
+        for (unsigned mask = 0; mask < 4; ++mask) {
+            Param param;
+            param.transposeA = mask & 1;
+            param.transposeB = mask & 2;
+            DType stype = dtype;
+            TensorShape A, B;
+            if (param.transposeA)
+                A = TensorShape{k, m};
+            else
+                A = TensorShape{m, k};
+            if (param.transposeB)
+                B = TensorShape{n, k};
+            else
+                B = TensorShape{k, n};
+            if (dtype == dtype::Float16()) {
+                param.compute_mode = param::MatrixMul::ComputeMode::FLOAT32;
+            }
+            checker.set_param(param)
+                    .set_dtype(0, stype)
+                    .set_dtype(1, stype)
+                    .set_dtype(2, dtype)
+                    .set_epsilon(dtype == dtype::Float16()
+                                         ? 5e-2
+                                         : 5e-3)
+                    .execs({A, B, {}});
+        }
+    }
+}
+
 TEST_F(CUDA, MATRIX_MUL) {
     if (cuda::current_device_prop().major < 6) {
         printf("Skip CUDA.MATRIX_MUL test as current device doesn't support\n");