#pragma once #include #include #include #include #include #include #include #include #include "StopWatch.h" #include "cublas_wrappers.h" template void check(T result, char const* const func, const char* const file, int const line) { if (result) { std::cout << (std::string("CUDA runtime error: ") + +file + ":" + std::to_string(line) + " \n"); } } #define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) template class GemmTest { public: GemmTest(int m, int n, int k, cublasOperation_t ta, cublasOperation_t tb, cublasHandle_t h) : M(m), N(n), K(k), transa(ta), transb(tb), handle(h) { check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K)); check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N)); check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N)); } ~GemmTest() { check_cuda_error(cudaFree(A)); check_cuda_error(cudaFree(B)); check_cuda_error(cudaFree(C)); } std::array TestAlgo(int loops) { float alpha = (T)1.0f; float beta = (T)0.0f; int algo_fw = Run(loops, [=](int algo) { cublas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N, N, M, K, &alpha, &beta, B, A, C, static_cast(algo)); }); int algo_bw1 = Run(loops, [=](int algo) { cublas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T, K, N, M, &alpha, &beta, A, C, B, static_cast(algo)); }); int algo_bw2 = Run(loops, [=](int algo) { cublas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N, K, M, N, &alpha, &beta, B, C, A, static_cast(algo)); }); return std::array({algo_fw, algo_bw1, algo_bw2}); } template int Run(int loops, Func f) { float fast_latency = (std::numeric_limits::max)(); int fast_algo = 0; for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; algo++) { int warm_up = 5; for (int i = 0; i < warm_up; ++i) f(algo); cudaDeviceSynchronize(); Stopwatch timer; timer.Restart(); for (int i = 0; i < loops; ++i) f(algo); cudaDeviceSynchronize(); timer.Stop(); float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops; printf("algo-%d: %.3fms\n", algo, avg_latency); if (avg_latency < fast_latency) { fast_latency = avg_latency; fast_algo = algo; } } printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency); return fast_algo; } private: int M, N, K; cublasHandle_t handle; cublasOperation_t transa, transb; T *A, *B, *C; }; template class StridedGemmTest { public: StridedGemmTest(int b, int m, int n, int k, cublasOperation_t ta, cublasOperation_t tb, cublasHandle_t h) : bsz(b), M(m), N(n), K(k), transa(ta), transb(tb), handle(h) { check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K * bsz)); check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N * bsz)); check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N * bsz)); } ~StridedGemmTest() { check_cuda_error(cudaFree(A)); check_cuda_error(cudaFree(B)); check_cuda_error(cudaFree(C)); } std::array TestAlgo(int loops) { float alpha = (T)1.0f; float beta = (T)0.0f; int algo_fw = Run(loops, [=](int algo) { int stride_a = M * K; int stride_b = N * K; int stride_c = M * N; cublas_strided_batched_gemm(handle, M, N, K, &alpha, &beta, A, B, C, transa, transb, stride_a, stride_b, stride_c, bsz, static_cast(algo)); }); int algo_bw1 = Run(loops, [=](int algo) { int mb = (transa == CUBLAS_OP_T ? K : M); int kb = (transa == CUBLAS_OP_T ? M : K); int stride_a = mb * N; int stride_b = N * kb; int stride_c = M * K; // B need to transpose. cublasOperation_t op_b = (transb == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); // Calculate d_A. cublas_strided_batched_gemm(handle, mb, kb, N, &alpha, &beta, (transa == CUBLAS_OP_T ? B : C), (transa == CUBLAS_OP_T ? C : B), A, CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz, static_cast(algo)); }); int algo_bw2 = Run(loops, [=](int algo) { // A need to transpose. cublasOperation_t op_a = (transa == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); int stride_a = M * K; int stride_b = M * N; int stride_c = N * K; // Calculate d_B. cublas_strided_batched_gemm(handle, K, N, M, &alpha, &beta, A, C, B, op_a, CUBLAS_OP_N, stride_a, stride_b, stride_c, bsz, static_cast(algo)); }); return std::array({algo_fw, algo_bw1, algo_bw2}); } template int Run(int loops, Func f) { float fast_latency = (std::numeric_limits::max)(); int fast_algo = 0; for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; algo++) { int warm_up = 5; for (int i = 0; i < warm_up; ++i) f(algo); cudaDeviceSynchronize(); Stopwatch timer; timer.Restart(); for (int i = 0; i < loops; ++i) f(algo); cudaDeviceSynchronize(); timer.Stop(); float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops; printf("algo-%d: %.3fms\n", algo, avg_latency); if (avg_latency < fast_latency) { fast_latency = avg_latency; fast_algo = algo; } } printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency); return fast_algo; } private: int bsz, M, N, K; cublasHandle_t handle; cublasOperation_t transa, transb; T *A, *B, *C; };