From d4217fc6b23190f23cef48deae6fd16317d9c118 Mon Sep 17 00:00:00 2001
From: limingshu <61349199+JamesLim-sy@users.noreply.github.com>
Date: Sun, 26 Feb 2023 10:28:47 +0800
Subject: [PATCH] Matmul performance optimization with cuBlasLt (#46431)

* implement matmul using cublasLt instead of cublas

* Update matmul_kernel_impl_via_blasLt.h

---------

Co-authored-by: zhangbopd <1299246947@qq.com>
Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com>
Co-authored-by: Liu Yiqun
---
 paddle/phi/kernels/autotune/auto_tune_base.h |  46 +-
 paddle/phi/kernels/autotune/cache.h          |  26 +-
 paddle/phi/kernels/autotune/cache_base.h     |  54 ++
 .../phi/kernels/funcs/blas/blaslt_impl.cu.h  | 385 +++++++++++++
 paddle/phi/kernels/impl/matmul_kernel_impl.h | 514 +++++++++++++++++-
 5 files changed, 992 insertions(+), 33 deletions(-)
 create mode 100644 paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h

diff --git a/paddle/phi/kernels/autotune/auto_tune_base.h b/paddle/phi/kernels/autotune/auto_tune_base.h
index e01fdd7eb5c..908b6560e28 100644
--- a/paddle/phi/kernels/autotune/auto_tune_base.h
+++ b/paddle/phi/kernels/autotune/auto_tune_base.h
@@ -141,7 +141,43 @@ class AutoTuneBase {
   }
 };
 
-// To init the auto_tuner object.
+template <typename T, typename ReturnType, typename... Args>
+class MatmulAutoTuner
+    : public AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>> {
+ public:
+  static MatmulAutoTuner<T, ReturnType, Args...>* Instance(
+      ReturnType (*func)(Args...)) {
+    static std::once_flag matmul_init_flag;
+    static std::unique_ptr<MatmulAutoTuner<T, ReturnType, Args...>> instance;
+    std::call_once(matmul_init_flag, [&] {
+      auto obj = MakeCallback<T>(func);
+      instance.reset(new MatmulAutoTuner<T, ReturnType, Args...>);
+      instance->AddCallBack(func);
+    });
+    return instance.get();
+  }
+
+  // Hit the cache by key; otherwise benchmark all registered kernels once
+  // (when autotune is enabled) and cache the index of the fastest one.
+  template <typename Context>
+  void Run(const Context& ctx, const size_t key, Args... args) {
+    this->is_init_ = true;
+    this->CheckKernelSize();
+    auto& cache = AutoTuneCache::Instance().GetMatmul();
+    if (cache.Find(key)) {
+      auto best_idx = cache.Get(key);
+      this->kernels_[best_idx].Run(args...);
+    } else {
+      bool use_autotune = AutoTuneStatus::Instance().UseAutoTune();
+      if (use_autotune) {
+        auto best_idx = this->PickBestKernel(ctx, args...);
+        cache.Set(key, best_idx);
+      } else {
+        this->kernels_[0].Run(args...);
+      }
+    }
+  }
+};
+
+// Define the auto_tuner initial object.
 #define DEFINE_AUTOTUNER_COMMON_OBJ(name)                         \
   template <typename T, typename ReturnType, typename... Args>   \
   class name##AutoTuner                                           \
@@ -161,7 +197,7 @@ class AutoTuneBase {
   }                                                               \
   };
 
-// To init auto_tuner inital function.
+// Define the auto_tuner initial function.
 #define DEFINE_AUTOTUNER_FN(name)                                    \
   template <typename T, typename ReturnType, typename... Args>       \
   static name##AutoTuner<T, ReturnType, Args...>* Make##name##Tuner( \
@@ -169,10 +205,12 @@ class AutoTuneBase {
       ReturnType (*func)(Args...)) {                                  \
     return name##AutoTuner<T, ReturnType, Args...>::Instance(func);   \
   }
 
-#define DEFINE_AUTOTUNER(name) \
-  DEFINE_AUTOTUNER_COMMON_OBJ(name) DEFINE_AUTOTUNER_FN(name)
+#define DEFINE_AUTOTUNER(name)      \
+  DEFINE_AUTOTUNER_COMMON_OBJ(name) \
+  DEFINE_AUTOTUNER_FN(name)
 
 DEFINE_AUTOTUNER(Transpose)
+DEFINE_AUTOTUNER_FN(Matmul)
 
 #undef DEFINE_AUTOTUNER_COMMON_OBJ
 #undef DEFINE_AUTOTUNER_FN

diff --git a/paddle/phi/kernels/autotune/cache.h b/paddle/phi/kernels/autotune/cache.h
index e58b1ff00c9..711c8a063f7 100644
--- a/paddle/phi/kernels/autotune/cache.h
+++ b/paddle/phi/kernels/autotune/cache.h
@@ -44,28 +44,33 @@ enum class AlgorithmType {
   kConvBackwardData = 2,
   kConvBackwardFilter = 3,
   kTranspose = 4,
-#ifdef PADDLE_WITH_CUDNN_FRONTEND
-  kConvForwardV8 = 5,
-  kConvBackwardDataV8 = 6,
-  kConvBackwardFilterV8 = 7,
-  kAlgorithmCount = 8
+  kMatmul = 5,
+#if !defined(PADDLE_WITH_CUDNN_FRONTEND)
+  kAlgorithmCount = 6
 #else
-  kAlgorithmCount = 5
+  kConvForwardV8 = 6,
+  kConvBackwardDataV8 = 7,
+  kConvBackwardFilterV8 = 8,
+  kAlgorithmCount = 9
 #endif
 };
 
 // AlgorithmsConfigKey -> AlgorithmsID
-// (todo. hong) use cudnnConvolutionFwdAlgo_t
-using AlgorithmsCacheMap = AlgorithmsCache<size_t, int64_t>;
 // AlgorithmType -> AlgorithmsCache
+using AlgorithmsCacheMap = AlgorithmsCache<size_t, int64_t>;
 using AlgorithmsTypeMap = std::unordered_map<int64_t, AlgorithmsCacheMap>;
+
+// (todo. hong) use cudnnConvolutionFwdAlgo_t
 using ConvAlgorithmsCacheMap = ConvAlgorithmsCache<ConvAutoTuneResult>;
 using ConvAlgorithmsTypeMap =
     std::unordered_map<int64_t, ConvAlgorithmsCacheMap>;
+
+using MatmulAlgorithmsCacheMap = MatmulAlgorithmsCache<size_t, int64_t>;
 #ifdef PADDLE_WITH_CUDNN_FRONTEND
 using CudnnV8AlgorithmsTypeMap =
     std::unordered_map<int64_t, CudnnFrontendPlanCache>;
 #endif
+
 class AutoTuneCache {
  public:
   static AutoTuneCache& Instance() {
@@ -77,6 +82,8 @@ class AutoTuneCache {
     return auto_tune_map_[static_cast<int64_t>(algo_type)];
   }
 
+  MatmulAlgorithmsCacheMap& GetMatmul() { return matmul_auto_tune_map_; }
+
   ConvAlgorithmsCacheMap& GetConv(const AlgorithmType& algo_type) {
     return conv_auto_tune_map_[static_cast<int64_t>(algo_type)];
   }
@@ -87,8 +94,6 @@ class AutoTuneCache {
   }
 #endif
 
-  AlgorithmsCacheMap& GetTranspose() { return Get(AlgorithmType::kTranspose); }
-
   void Clean() {
     for (auto& v : auto_tune_map_) {
       v.second.Clean();
@@ -162,6 +167,7 @@ class AutoTuneCache {
 
   AlgorithmsTypeMap auto_tune_map_;
   ConvAlgorithmsTypeMap conv_auto_tune_map_;
+  MatmulAlgorithmsCacheMap matmul_auto_tune_map_;
 #ifdef PADDLE_WITH_CUDNN_FRONTEND
   CudnnV8AlgorithmsTypeMap cudnn_v8_auto_tune_map_;
 #endif

diff --git a/paddle/phi/kernels/autotune/cache_base.h b/paddle/phi/kernels/autotune/cache_base.h
index 7fc1c3c3c83..e165cd76df7 100644
--- a/paddle/phi/kernels/autotune/cache_base.h
+++ b/paddle/phi/kernels/autotune/cache_base.h
@@ -60,6 +60,31 @@ size_t GenKey(Args&&... args) {
   return seed;
 }
 
+struct MatmulHashValueType {
+  uint64_t data[8];
+};
+
+struct MatmulCacheKey {
+ public:
+  MatmulCacheKey() {}
+  MatmulCacheKey(const std::vector<int64_t>& x_dims,
+                 const std::vector<int64_t>& y_dims,
+                 const bool trans_x,
+                 const bool trans_y,
+                 phi::DataType dtype) {
+    key = GenKey(x_dims,
+                 y_dims,
+                 static_cast<int64_t>(trans_x),
+                 static_cast<int64_t>(trans_y),
+                 static_cast<int64_t>(dtype));
+  }
+  size_t GetKey() const { return key; }
+  size_t GetSubKey(int64_t idx) const { return GenKey(key, idx); }
+
+ private:
+  size_t key;
+};
+
 struct ConvCacheKey {
   ConvCacheKey() {}
   ConvCacheKey(const std::vector<int64_t>& arg_x_dims,
@@ -213,5 +238,34 @@ class ConvAlgorithmsCache
   }
 };
 
+template <typename KeyT, typename AlgorithmT>
+class MatmulAlgorithmsCache : public AlgorithmsCache<KeyT, AlgorithmT> {
+ public:
+  MatmulAlgorithmsCache() : AlgorithmsCache<KeyT, AlgorithmT>() {}
+
+  bool FindSubKey(const KeyT& sub_key) {
+    std::lock_guard<std::mutex> lock(*(this->cache_mutex_));
+    return sub_hash_.find(sub_key) != sub_hash_.end();
+  }
+
+  void SetSubKey(const KeyT& sub_key, const MatmulHashValueType* algo) {
+    std::lock_guard<std::mutex> lock(*(this->cache_mutex_));
+    sub_hash_[sub_key] = *algo;
+  }
+
+  MatmulHashValueType* GetSubKey(const KeyT& sub_key) {
+    std::lock_guard<std::mutex> lock(*(this->cache_mutex_));
+    PADDLE_ENFORCE_NE(
+        sub_hash_.find(sub_key),
+        sub_hash_.end(),
+        phi::errors::PreconditionNotMet("The key does not exist."));
+    return &(sub_hash_[sub_key]);
+  }
+
+ private:
+  std::unordered_map<KeyT, MatmulHashValueType> sub_hash_;
+};
+
 }  // namespace autotune
 }  // namespace phi
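The caching scheme above is two-level. MatmulCacheKey hashes the (x_dims, y_dims, trans_x, trans_y, dtype) tuple into a primary key, under which the tuner records which implementation won; GetSubKey then mixes an implementation id into that key so each backend can cache its own opaque algorithm blob. MatmulHashValueType's uint64_t data[8] is sized to hold the 64-byte cublasLtMatmulAlgo_t. A minimal standalone sketch of the same pattern (the hash mixer and all names are illustrative, not Paddle's):

    #include <cstdint>
    #include <unordered_map>
    #include <vector>

    // Boost-style hash mixer, similar in spirit to what GenKey does per field.
    static size_t HashCombine(size_t seed, size_t v) {
      return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2));
    }

    static size_t MakeKey(const std::vector<int64_t>& x_dims,
                          const std::vector<int64_t>& y_dims,
                          bool trans_x, bool trans_y, int dtype) {
      size_t seed = 0;
      for (int64_t d : x_dims) seed = HashCombine(seed, static_cast<size_t>(d));
      for (int64_t d : y_dims) seed = HashCombine(seed, static_cast<size_t>(d));
      seed = HashCombine(seed, trans_x);
      seed = HashCombine(seed, trans_y);
      return HashCombine(seed, static_cast<size_t>(dtype));
    }

    struct AlgoBlob { uint64_t data[8]; };  // same size as cublasLtMatmulAlgo_t

    int main() {
      std::unordered_map<size_t, int> best_impl;        // key -> winning impl id
      std::unordered_map<size_t, AlgoBlob> algo_cache;  // sub_key -> algo blob

      size_t key = MakeKey({16, 64, 128}, {16, 128, 32}, false, false, /*fp32*/ 0);
      size_t sub_key = HashCombine(key, /*kImplWithCublasLt*/ 2);

      algo_cache[sub_key] = AlgoBlob{};  // cache the tuned algorithm blob
      best_impl[key] = 1;                // remember the winning kernel index
      return (algo_cache.count(sub_key) == 1 && best_impl.count(key) == 1) ? 0 : 1;
    }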
diff --git a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h
new file mode 100644
index 00000000000..d3a2ead2847
--- /dev/null
+++ b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h
@@ -0,0 +1,385 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
+
+#include <cuda_runtime_api.h>
+#include "cuda.h"  // NOLINT
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/kernels/autotune/cache.h"
+#include "paddle/phi/kernels/autotune/gpu_timer.h"
+
+namespace phi {
+namespace funcs {
+
+enum MatmulImplType { kImplWithCublas = 1, kImplWithCublasLt = 2 };
+
+template <typename T>
+cudaDataType_t ConvertToCudaDataType() {
+  if (std::is_same<T, float>::value) {
+    return CUDA_R_32F;
+  } else if (std::is_same<T, double>::value) {
+    return CUDA_R_64F;
+  } else if (std::is_same<T, phi::dtype::float16>::value) {
+    return CUDA_R_16F;
+  } else if (std::is_same<T, phi::dtype::bfloat16>::value) {
+    return CUDA_R_16BF;
+  }
+  PADDLE_THROW(phi::errors::InvalidArgument(
+      "Unsupported data type for cublasLt matmul."));
+}
+
+template <typename T>
+cublasComputeType_t GetCudaComputeType() {
+  if (std::is_same<T, double>::value) {
+    return CUBLAS_COMPUTE_64F;
+  } else {
+    return CUBLAS_COMPUTE_32F;
+  }
+}
+
+struct MatmulDescriptor {
+ public:
+  cublasLtMatmulDesc_t op_desc{nullptr};
+  cublasLtMatrixLayout_t x_desc{nullptr};
+  cublasLtMatrixLayout_t y_desc{nullptr};
+  cublasLtMatrixLayout_t out_desc{nullptr};
+
+  template <typename T>
+  void Create(const int M,
+              const int N,
+              const int K,
+              const bool trans_x,
+              const bool trans_y,
+              const int batch_size = 1,
+              int64_t stride_x = 0,
+              int64_t stride_y = 0,
+              int64_t stride_out = 0) {
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+
+    cudaDataType_t mat_type = ConvertToCudaDataType<T>();
+    cudaDataType_t scale_type = ConvertToCudaDataType<MT>();
+    cublasComputeType_t compute_type = GetCudaComputeType<T>();
+
+    // Create the operation descriptor; see cublasLtMatmulDescAttributes_t
+    // for details about defaults. Here we just need to set the transforms
+    // for A and B.
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        dynload::cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type));
+    cublasOperation_t cublas_trans_x = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
+    cublasOperation_t cublas_trans_y = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        dynload::cublasLtMatmulDescSetAttribute(op_desc,
+                                                CUBLASLT_MATMUL_DESC_TRANSB,
+                                                &cublas_trans_x,
+                                                sizeof(cublas_trans_x)));
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        dynload::cublasLtMatmulDescSetAttribute(op_desc,
+                                                CUBLASLT_MATMUL_DESC_TRANSA,
+                                                &cublas_trans_y,
+                                                sizeof(cublas_trans_y)));
+
+    // Create matrix descriptors.
+    CreateMatrixLayout(&x_desc, mat_type, M, K, trans_x);
+    CreateMatrixLayout(&y_desc, mat_type, K, N, trans_y);
+    CreateMatrixLayout(&out_desc, mat_type, M, N, false);
+
+    // Config batch size and stride.
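+    // With the batch-count and strided-batch-offset attributes set below, a
+    // single cublasLtMatmul call performs batch_size GEMMs: matrix i of an
+    // operand is read at base_ptr + i * stride (counted in elements), and a
+    // stride of 0 makes every batch entry reuse the same matrix.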
+    if (batch_size > 1) {
+      SetBatchAndStride(x_desc, batch_size, stride_x);
+      SetBatchAndStride(y_desc, batch_size, stride_y);
+      SetBatchAndStride(out_desc, batch_size, stride_out);
+    }
+  }
+
+  void CreateMatrixLayout(cublasLtMatrixLayout_t* desc,
+                          cudaDataType type,
+                          uint64_t rows,
+                          uint64_t cols,
+                          bool trans) {
+    if (trans) {
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          dynload::cublasLtMatrixLayoutCreate(desc, type, rows, cols, rows));
+    } else {
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          dynload::cublasLtMatrixLayoutCreate(desc, type, cols, rows, cols));
+    }
+  }
+
+  void SetBatchAndStride(cublasLtMatrixLayout_t desc,
+                         int batch_size,
+                         int64_t stride) {
+    PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutSetAttribute(
+        desc,
+        CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
+        &batch_size,
+        sizeof(batch_size)));
+    PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutSetAttribute(
+        desc,
+        CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
+        &stride,
+        sizeof(stride)));
+  }
+
+  void Release() {
+    PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(y_desc));
+    PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(x_desc));
+    PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatrixLayoutDestroy(out_desc));
+    PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescDestroy(op_desc));
+
+    op_desc = nullptr;
+    x_desc = nullptr;
+    y_desc = nullptr;
+    out_desc = nullptr;
+  }
+};
+
+template <typename T>
+struct MatmulWithCublasLt {
+ public:
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+
+  static void Run(const phi::GPUContext& ctx,
+                  const T* x_data,
+                  const T* y_data,
+                  T* out_data,
+                  const int M,
+                  const int N,
+                  const int K,
+                  const bool trans_x,
+                  const bool trans_y,
+                  phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
+    MatmulDescriptor desc;
+    desc.Create<T>(M, N, K, trans_x, trans_y);
+    RunImpl(ctx, desc, x_data, y_data, out_data, matmul_key);
+    desc.Release();
+  }
+
+  static void RunWithBatch(
+      const phi::GPUContext& ctx,
+      const T* x_data,
+      const T* y_data,
+      T* out_data,
+      const int M,
+      const int N,
+      const int K,
+      bool trans_x,
+      bool trans_y,
+      int batch_size,
+      int64_t stride_x,
+      int64_t stride_y,
+      int64_t stride_out,
+      phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
+    MatmulDescriptor desc;
+    desc.Create<T>(
+        M, N, K, trans_x, trans_y, batch_size, stride_x, stride_y, stride_out);
+    RunImpl(ctx, desc, x_data, y_data, out_data, matmul_key);
+    desc.Release();
+  }
+
+  static void RunWithBatch(
+      const phi::GPUContext& ctx,
+      const T** x_data,
+      const T** y_data,
+      T** out_data,
+      const int M,
+      const int N,
+      const int K,
+      bool trans_x,
+      bool trans_y,
+      int batch_size,
+      phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
+    for (int i = 0; i < batch_size; ++i) {
+      Run(ctx,
+          x_data[i],
+          y_data[i],
+          out_data[i],
+          M,
+          N,
+          K,
+          trans_x,
+          trans_y,
+          matmul_key);
+    }
+  }
+
+ private:
+  static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx,
+                                                    size_t workspace_size) {
+    return paddle::memory::Alloc(
+        ctx.GetPlace(),
+        workspace_size,
+        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
+  }
+
+  static void RunImpl(const phi::GPUContext& ctx,
+                      const MatmulDescriptor& desc,
+                      const T* x_ptr,
+                      const T* y_ptr,
+                      T* out_ptr,
+                      phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
+    MT alpha = static_cast<MT>(1);
+    MT beta = static_cast<MT>(0);
+
+    cublasLtHandle_t cublaslt_handle = ctx.cublaslt_handle();
+    cublasLtMatmulAlgo_t* best_algo = nullptr;
+
+    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
+    phi::Allocator::AllocationPtr workspace =
+        GetWorkspace(ctx, workspace_size);
+    // test_algo must outlive the cublasLtMatmul call below, so it is
+    // declared at function scope rather than inside the else-if branch.
+    cublasLtMatmulAlgo_t test_algo;
+
+    if (matmul_key != nullptr) {
+      auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
+      size_t sub_key = matmul_key->GetSubKey(
+          static_cast<int64_t>(MatmulImplType::kImplWithCublasLt));
+      if (cache.FindSubKey(sub_key)) {
+        best_algo =
+            reinterpret_cast<cublasLtMatmulAlgo_t*>(cache.GetSubKey(sub_key));
+      } else if (phi::autotune::AutoTuneStatus::Instance().UseAutoTune()) {
+        SearchBestAlgo(ctx,
+                       cublaslt_handle,
+                       desc.op_desc,
+                       desc.y_desc,
+                       desc.x_desc,
+                       desc.out_desc,
+                       static_cast<void*>(&alpha),
+                       static_cast<void*>(&beta),
+                       y_ptr,
+                       x_ptr,
+                       out_ptr,
+                       workspace->ptr(),
+                       workspace_size,
+                       &(test_algo));
+        cache.SetSubKey(
+            sub_key,
+            reinterpret_cast<phi::autotune::MatmulHashValueType*>(&test_algo));
+        best_algo = &test_algo;
+      }
+    }
+
+    PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmul(
+        cublaslt_handle,
+        desc.op_desc,
+        static_cast<void*>(&alpha),
+        y_ptr,
+        desc.y_desc,
+        x_ptr,
+        desc.x_desc,
+        static_cast<void*>(&beta),
+        out_ptr,
+        desc.out_desc,
+        out_ptr,
+        desc.out_desc,
+        reinterpret_cast<cublasLtMatmulAlgo_t*>(best_algo),
+        workspace->ptr(),
+        workspace_size,
+        ctx.stream()));
+  }
+
+  static void SearchBestAlgo(const phi::GPUContext& ctx,
+                             const cublasLtHandle_t& lt_handle,
+                             const cublasLtMatmulDesc_t& op_desc,
+                             const cublasLtMatrixLayout_t& y_desc,
+                             const cublasLtMatrixLayout_t& x_desc,
+                             const cublasLtMatrixLayout_t& out_desc,
+                             const void* alpha,
+                             const void* beta,
+                             const void* y_data,
+                             const void* x_data,
+                             void* out_data,
+                             void* workspace_ptr,
+                             size_t workspace_size,
+                             cublasLtMatmulAlgo_t* best_algo) {
+    const auto& stream = ctx.stream();
+    int returned_results = 0;
+    constexpr int requested_algo_count = 10;
+    cublasLtMatmulPreference_t preference;
+
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        dynload::cublasLtMatmulPreferenceCreate(&preference));
+    PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulPreferenceSetAttribute(
+        preference,
+        CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
+        &workspace_size,
+        sizeof(workspace_size)));
+
+    std::vector<cublasLtMatmulHeuristicResult_t> heuristic_results(
+        requested_algo_count);
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        dynload::cublasLtMatmulAlgoGetHeuristic(lt_handle,
+                                                op_desc,
+                                                y_desc,
+                                                x_desc,
+                                                out_desc,
+                                                out_desc,
+                                                preference,
+                                                requested_algo_count,
+                                                heuristic_results.data(),
+                                                &returned_results));
+    PADDLE_ENFORCE_GT(returned_results,
+                      0,
+                      phi::errors::Unavailable("No GEMM algorithm available."));
+
+    phi::GpuTimer timer;
+    int best_algo_idx = -1;
+    constexpr int repeats = 6;
+    float min_time_cost = std::numeric_limits<float>::max();
+    for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) {
+      ctx.Wait();
+      float cur_time = 0.f;
+      for (int i = 0; i < repeats; ++i) {
+        timer.Start(stream);
+        PADDLE_ENFORCE_GPU_SUCCESS(
+            dynload::cublasLtMatmul(lt_handle,
+                                    op_desc,
+                                    alpha,
+                                    y_data,
+                                    y_desc,
+                                    x_data,
+                                    x_desc,
+                                    beta,
+                                    out_data,
+                                    out_desc,
+                                    out_data,
+                                    out_desc,
+                                    &(heuristic_results[algo_idx].algo),
+                                    workspace_ptr,
+                                    workspace_size,
+                                    stream));
+        timer.Stop(stream);
+        auto time = timer.ElapsedTime();
+        if (i > 0) {
+          // The first iteration is treated as a warmup run and discarded.
+          cur_time += time;
+        }
+      }
+      float time_cnt = (cur_time / (repeats - 1));
+      VLOG(4) << "Time cost in MatmulWithCublaslt algo[" << algo_idx << "] is "
+              << time_cnt << " ms";
+
+      if (cur_time < min_time_cost) {
+        best_algo_idx = algo_idx;
+        min_time_cost = cur_time;
+      }
+    }
+    VLOG(4) << "Best_algo_idx in MatmulWithCublaslt is: " << best_algo_idx;
+
+    *best_algo = heuristic_results[best_algo_idx].algo;
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        dynload::cublasLtMatmulPreferenceDestroy(preference));
+  }
+};
+
+}  // namespace funcs
+}  // namespace phi
+
+#endif
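SearchBestAlgo times every heuristic candidate with phi::GpuTimer: six runs per algorithm, the first discarded as a warmup, the rest averaged; the candidate with the smallest averaged time wins (ElapsedTime reports milliseconds, hence the ms unit in the VLOG above). A condensed sketch of that measurement pattern using raw CUDA events, assuming only a generic `launch` callable rather than the Paddle API:

    #include <cuda_runtime.h>

    // Measure GPU work the way SearchBestAlgo does: `repeats` timed runs on
    // `stream`, the first run discarded as warmup, the remainder averaged.
    template <typename Launch>
    float AverageGpuTimeMs(Launch launch, cudaStream_t stream, int repeats = 6) {
      cudaEvent_t start, stop;
      cudaEventCreate(&start);
      cudaEventCreate(&stop);
      float total_ms = 0.f;
      for (int i = 0; i < repeats; ++i) {
        cudaEventRecord(start, stream);
        launch();  // e.g. one cublasLtMatmul call with the candidate algo
        cudaEventRecord(stop, stream);
        cudaEventSynchronize(stop);
        float ms = 0.f;
        cudaEventElapsedTime(&ms, start, stop);  // elapsed time in ms
        if (i > 0) total_ms += ms;               // skip the warmup iteration
      }
      cudaEventDestroy(start);
      cudaEventDestroy(stop);
      return total_ms / (repeats - 1);
    }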
diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h
index 83855f7296b..4cd8b5eceaf 100644
--- a/paddle/phi/kernels/impl/matmul_kernel_impl.h
+++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h
@@ -15,8 +15,13 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/autotune/cache_base.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/complex_functors.h"
+#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
+#include "paddle/phi/kernels/autotune/auto_tune_base.h"
+#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
+#endif
 
 namespace phi {
 
@@ -84,16 +89,19 @@ static void IndexIncreaseFromDims(const int ndim,
   }
 }
 
+// The general implementation with blas.
 template <typename Context, typename T>
-void MatMulFunction(const Context& dev_ctx,
-                    const DenseTensor& X,
-                    const DenseTensor& Y,
-                    const std::vector<std::int64_t>& x_dims,
-                    const std::vector<std::int64_t>& y_dims,
-                    DenseTensor* Out,
-                    bool trans_x,
-                    bool trans_y,
-                    bool flag = false) {
+void MatMulFunctionImplWithBlas(
+    const Context& dev_ctx,
+    const DenseTensor& X,
+    const DenseTensor& Y,
+    const std::vector<std::int64_t>& x_dims,
+    const std::vector<std::int64_t>& y_dims,
+    DenseTensor* Out,
+    bool trans_x,
+    bool trans_y,
+    bool flag = false,
+    phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
   const int x_ndim = x_dims.size();
   const int y_ndim = y_dims.size();
 
@@ -471,22 +479,487 @@ void MatMulFunction(const Context& dev_ctx,
   }
 }
 
+#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
+// This is almost a copy of MatMulFunctionImplWithBlas, so that the cublas
+// and cublasLt kernels can be compared when Matmul autotune is on.
+template <typename Context, typename T>
+void MatMulFunctionImplWithCublasLt(
+    const Context& dev_ctx,
+    const DenseTensor& X,
+    const DenseTensor& Y,
+    const std::vector<std::int64_t>& x_dims,
+    const std::vector<std::int64_t>& y_dims,
+    DenseTensor* Out,
+    bool trans_x,
+    bool trans_y,
+    bool flag = false,
+    phi::autotune::MatmulCacheKey* matmul_key = nullptr) {
+  const int x_ndim = x_dims.size();
+  const int y_ndim = y_dims.size();
+  const T* x_data = X.data<T>();
+  const T* y_data = Y.data<T>();
+  using blaslt = phi::funcs::MatmulWithCublasLt<T>;
+
+  if (x_ndim == 1 && y_ndim == 1) {
+    const int M = X.numel();
+    const int N = Y.numel();
+    PADDLE_ENFORCE_EQ(
+        M,
+        N,
+        phi::errors::InvalidArgument(
+            "X's numel must be equal to Y's numel when X and Y are both 1-D. "
+            "But received X has [%d] elements and Y has [%d] elements.",
+            M,
+            N));
+
+    // MatMul's case 0 => vector * vector
+    Out->Resize({1});
+    dev_ctx.template Alloc<T>(Out);
+    VLOG(3) << "MatMul with blaslt case 1";
+    blaslt::Run(dev_ctx,
+                y_data,
+                x_data,
+                dev_ctx.template Alloc<T>(Out),
+                1,
+                1,
+                M,
+                false,
+                true,
+                matmul_key);
+    return;
+  }
+
+  if (x_ndim == 1) {
+    const int N = X.numel();
+    if (trans_y) {
+      PADDLE_ENFORCE_EQ(
+          y_dims[y_ndim - 1],
+          N,
+          phi::errors::InvalidArgument("Input(Y) has error dim. "
+                                       "Y's dims[%d] must be equal to %d, "
+                                       "but received Y's dims[%d] is %d.",
+                                       y_ndim - 1,
+                                       N,
+                                       y_ndim - 1,
+                                       y_dims[y_ndim - 1]));
+    } else {
+      PADDLE_ENFORCE_EQ(
+          y_dims[y_ndim - 2],
+          N,
+          phi::errors::InvalidArgument("Input(Y) has error dim. "
+                                       "Y's dims[%d] must be equal to %d, "
+                                       "but received Y's dims[%d] is %d.",
+                                       y_ndim - 2,
+                                       N,
+                                       y_ndim - 2,
+                                       y_dims[y_ndim - 2]));
+    }
+    std::vector<std::int64_t> out_dims(y_ndim - 1);
+    if (trans_y) {
+      std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin());
+    } else {
+      std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin());
+      out_dims.back() = y_dims.back();
+    }
+    Out->ResizeAndAllocate(phi::make_ddim(out_dims));
+    dev_ctx.template Alloc<T>(Out);
+    if (trans_y) {
+      const int M = Y.numel() / N;
+      VLOG(3) << "MatMul with blaslt case 2";
+      blaslt::Run(dev_ctx,
+                  y_data,
+                  x_data,
+                  dev_ctx.template Alloc<T>(Out),
+                  M,
+                  1,
+                  N,
+                  false,
+                  false,
+                  matmul_key);
+    } else {
+      const int M = y_dims[y_ndim - 1];
+      const int batch_size = Y.numel() / (M * N);
+      if (batch_size == 1) {
+        VLOG(3) << "MatMul with blaslt case 3";
+        blaslt::Run(dev_ctx,
+                    y_data,
+                    x_data,
+                    dev_ctx.template Alloc<T>(Out),
+                    M,
+                    1,
+                    N,
+                    true,
+                    false,
+                    matmul_key);
+      } else {
+        VLOG(3) << "MatMul with blaslt case 4";
+        blaslt::RunWithBatch(dev_ctx,
+                             y_data,
+                             x_data,
+                             dev_ctx.template Alloc<T>(Out),
+                             M,
+                             1,
+                             N,
+                             true,
+                             false,
+                             batch_size,
+                             M * N,
+                             0,
+                             M,
+                             matmul_key);
+      }
+    }
+    return;
+  }
+
+  if (y_ndim == 1) {
+    const int N = Y.numel();
+    if (trans_x) {
+      PADDLE_ENFORCE_EQ(
+          x_dims[x_ndim - 2],
+          N,
+          phi::errors::InvalidArgument("Input(X) has error dim. "
+                                       "X's dims[%d] must be equal to %d, "
+                                       "but received X's dims[%d] is %d.",
+                                       x_ndim - 2,
+                                       N,
+                                       x_ndim - 2,
+                                       x_dims[x_ndim - 2]));
+    } else {
+      PADDLE_ENFORCE_EQ(
+          x_dims[x_ndim - 1],
+          N,
+          phi::errors::InvalidArgument("Input(X) has error dim. "
+                                       "X's dims[%d] must be equal to %d, "
+                                       "but received X's dims[%d] is %d.",
+                                       x_ndim - 1,
+                                       N,
+                                       x_ndim - 1,
+                                       x_dims[x_ndim - 1]));
+    }
+    std::vector<std::int64_t> out_dims(x_ndim - 1);
+    if (trans_x) {
+      std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin());
+      out_dims.back() = x_dims.back();
+    } else {
+      std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin());
+    }
+    Out->ResizeAndAllocate(phi::make_ddim(out_dims));
+    dev_ctx.template Alloc<T>(Out);
+
+    if (trans_x) {
+      const int M = x_dims[x_ndim - 1];
+      const int batch_size = X.numel() / (M * N);
+      if (batch_size == 1) {
+        VLOG(3) << "MatMul with blaslt case 5";
+        blaslt::Run(dev_ctx,
+                    x_data,
+                    y_data,
+                    dev_ctx.template Alloc<T>(Out),
+                    M,
+                    1,
+                    N,
+                    true,
+                    false,
+                    matmul_key);
+      } else {
+        VLOG(3) << "MatMul with blaslt case 6";
+        blaslt::RunWithBatch(dev_ctx,
+                             x_data,
+                             y_data,
+                             dev_ctx.template Alloc<T>(Out),
+                             M,
+                             1,
+                             N,
+                             true,
+                             false,
+                             batch_size,
+                             M * N,
+                             0,
+                             M,
+                             matmul_key);
+      }
+    } else {
+      const int M = X.numel() / N;
+      VLOG(3) << "MatMul with blaslt case 7";
+      blaslt::Run(dev_ctx,
+                  x_data,
+                  y_data,
+                  dev_ctx.template Alloc<T>(Out),
+                  M,
+                  1,
+                  N,
+                  false,
+                  false,
+                  matmul_key);
+    }
+    return;
+  }
+
+  const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2];
+  const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
+  if (trans_y) {
+    PADDLE_ENFORCE_EQ(
+        y_dims[y_ndim - 1],
+        K,
+        phi::errors::InvalidArgument("Input(Y) has error dim. "
+                                     "Y's dims[%d] must be equal to %d, "
+                                     "but received Y's dims[%d] is %d.",
+                                     y_ndim - 1,
+                                     K,
+                                     y_ndim - 1,
+                                     y_dims[y_ndim - 1]));
+  } else {
+    PADDLE_ENFORCE_EQ(
+        y_dims[y_ndim - 2],
+        K,
+        phi::errors::InvalidArgument("Input(Y) has error dim. "
+                                     "Y's dims[%d] must be equal to %d, "
+                                     "but received Y's dims[%d] is %d.",
+                                     y_ndim - 2,
+                                     K,
+                                     y_ndim - 2,
+                                     y_dims[y_ndim - 2]));
+  }
+  const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1];
+  const int ndim = (std::max)(x_ndim, y_ndim);
+  std::vector<std::int64_t> x_broadcast_dims(ndim);
+  std::vector<std::int64_t> y_broadcast_dims(ndim);
+  std::vector<std::int64_t> out_broadcast_dims(ndim);
+
+  GetBroadcastFromDims(x_ndim - 2,
+                       x_dims.data(),
+                       y_ndim - 2,
+                       y_dims.data(),
+                       x_broadcast_dims.data(),
+                       y_broadcast_dims.data(),
+                       out_broadcast_dims.data());
+  out_broadcast_dims[ndim - 2] = M;
+  out_broadcast_dims[ndim - 1] = N;
+
+  Out->ResizeAndAllocate(phi::make_ddim(out_broadcast_dims));
+  dev_ctx.template Alloc<T>(Out);
+
+  const int batch_dim = ndim - 2;
+  // broadcast message
+  const bool is_broadcast_dims =
+      !std::equal(x_broadcast_dims.cbegin(),
+                  x_broadcast_dims.cbegin() + batch_dim,
+                  y_broadcast_dims.cbegin());
+
+  const std::int64_t x_batch_size =
+      std::accumulate(x_broadcast_dims.cbegin(),
+                      x_broadcast_dims.cbegin() + batch_dim,
+                      1LL,
+                      std::multiplies<std::int64_t>());
+  const std::int64_t y_batch_size =
+      std::accumulate(y_broadcast_dims.cbegin(),
+                      y_broadcast_dims.cbegin() + batch_dim,
+                      1LL,
+                      std::multiplies<std::int64_t>());
+  const std::int64_t out_batch_size =
+      std::accumulate(out_broadcast_dims.cbegin(),
+                      out_broadcast_dims.cbegin() + batch_dim,
+                      1LL,
+                      std::multiplies<std::int64_t>());
+  if (out_batch_size == 0) return;
+  if (x_batch_size == 1 && y_batch_size == 1) {
+    VLOG(3) << "MatMul with blaslt case 8";
+    blaslt::Run(dev_ctx,
+                x_data,
+                y_data,
+                dev_ctx.template Alloc<T>(Out),
+                M,
+                N,
+                K,
+                trans_x,
+                trans_y,
+                matmul_key);
+  } else if (x_batch_size == 1) {
+    if (M == 1 && trans_y) {
+      VLOG(3) << "MatMul with blaslt case 9";
+      blaslt::Run(dev_ctx,
+                  y_data,
+                  x_data,
+                  dev_ctx.template Alloc<T>(Out),
+                  y_batch_size * N,
+                  1,
+                  K,
+                  false,
+                  false,
+                  matmul_key);
+    } else {
+      VLOG(3) << "MatMul with blaslt case 10";
+      blaslt::RunWithBatch(dev_ctx,
+                           x_data,
+                           y_data,
+                           dev_ctx.template Alloc<T>(Out),
+                           M,
+                           N,
+                           K,
+                           trans_x,
+                           trans_y,
+                           out_batch_size,
+                           0,
+                           K * N,
+                           M * N,
+                           matmul_key);
+    }
+  } else if (y_batch_size == 1) {
+    if (!trans_x) {
+      VLOG(3) << "MatMul with blaslt case 11";
+      blaslt::Run(dev_ctx,
+                  x_data,
+                  y_data,
+                  dev_ctx.template Alloc<T>(Out),
+                  x_batch_size * M,
+                  N,
+                  K,
+                  false,
+                  trans_y,
+                  matmul_key);
+    } else {
+      VLOG(3) << "MatMul with blaslt case 12";
+      blaslt::RunWithBatch(dev_ctx,
+                           x_data,
+                           y_data,
+                           dev_ctx.template Alloc<T>(Out),
+                           M,
+                           N,
+                           K,
+                           true,
+                           trans_y,
+                           out_batch_size,
+                           M * K,
+                           0,
+                           M * N,
+                           matmul_key);
+    }
+  } else if (!is_broadcast_dims) {
+    VLOG(3) << "MatMul with blaslt case 13";
+    blaslt::RunWithBatch(dev_ctx,
+                         x_data,
+                         y_data,
+                         dev_ctx.template Alloc<T>(Out),
+                         M,
+                         N,
+                         K,
+                         trans_x,
+                         trans_y,
+                         out_batch_size,
+                         M * K,
+                         K * N,
+                         M * N,
+                         matmul_key);
+  } else {
+    // In this case, strided gemm can not be used.
+    std::vector<const T*> x_ptr(out_batch_size);
+    std::vector<const T*> y_ptr(out_batch_size);
+    std::vector<T*> out_ptr(out_batch_size);
+    std::vector<std::int64_t> index(batch_dim, 0);
+    for (std::int64_t i = 0; i < out_batch_size; ++i) {
+      // Use the index to compute the offset of each operand.
+      const std::int64_t x_index =
+          GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data());
+      const std::int64_t y_index =
+          GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data());
+
+      x_ptr[i] = x_data + x_index * M * K;
+      y_ptr[i] = y_data + y_index * K * N;
+      out_ptr[i] = dev_ctx.template Alloc<T>(Out) + i * M * N;
+      IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data());
+    }
+    VLOG(3) << "MatMul with blaslt case 14";
+    blaslt::RunWithBatch(dev_ctx,
+                         x_ptr.data(),
+                         y_ptr.data(),
+                         out_ptr.data(),
+                         M,
+                         N,
+                         K,
+                         trans_x,
+                         trans_y,
+                         out_batch_size,
+                         matmul_key);
+  }
+}
+#endif
+
+template <typename Context, typename T>
+struct MatMulDispatcher {
+  void operator()(const Context& ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  const std::vector<std::int64_t>& x_dims,
+                  const std::vector<std::int64_t>& y_dims,
+                  DenseTensor* out,
+                  bool trans_x,
+                  bool trans_y,
+                  bool flag = false) {
+    MatMulFunctionImplWithBlas<Context, T>(
+        ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag);
+  }
+};
+
+#ifdef PADDLE_WITH_CUDA
+template <typename T>
+struct MatMulDispatcher<phi::GPUContext, T> {
+  void operator()(const phi::GPUContext& ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  const std::vector<std::int64_t>& x_dims,
+                  const std::vector<std::int64_t>& y_dims,
+                  DenseTensor* out,
+                  bool trans_x,
+                  bool trans_y,
+                  bool flag = false) {
+#if CUDA_VERSION >= 11060
+    auto* tuner = phi::autotune::MakeMatmulTuner<T>(
+        MatMulFunctionImplWithBlas<phi::GPUContext, T>);
+    tuner->AddCallBack(MatMulFunctionImplWithCublasLt<phi::GPUContext, T>);
+    phi::autotune::MatmulCacheKey matmul_cache(
+        x_dims,
+        y_dims,
+        trans_x,
+        trans_y,
+        paddle::experimental::CppTypeToDataType<T>::Type());
+    tuner->Run(ctx,
+               matmul_cache.GetKey(),
+               ctx,
+               x,
+               y,
+               x_dims,
+               y_dims,
+               out,
+               trans_x,
+               trans_y,
+               flag,
+               &matmul_cache);
+#else
+    MatMulFunctionImplWithBlas<phi::GPUContext, T>(
+        ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag);
+#endif
+  }
+};
+#endif  // PADDLE_WITH_CUDA
+
 template <typename Context, typename T>
-void MatMulFunction(const Context& dev_ctx,
-                    const DenseTensor& X,
-                    const DenseTensor& Y,
-                    DenseTensor* Out,
+void MatMulFunction(const Context& ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& y,
+                    const std::vector<std::int64_t>& x_dims,
+                    const std::vector<std::int64_t>& y_dims,
+                    DenseTensor* out,
                     bool trans_x,
                     bool trans_y,
                     bool flag = false) {
-  const std::vector<std::int64_t> x_dims = vectorize(X.dims());
-  const std::vector<std::int64_t> y_dims = vectorize(Y.dims());
-  MatMulFunction<Context, T>(
-      dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag);
+  MatMulDispatcher<Context, T>()(
+      ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag);
 }
 
 template <typename T, typename Context>
-void MatmulKernel(const Context& dev_ctx,
+void MatmulKernel(const Context& ctx,
                   const DenseTensor& x,
                   const DenseTensor& y,
                   bool transpose_x,
@@ -502,7 +975,10 @@ void MatmulKernel(const Context& ctx,
       0,
       phi::errors::InvalidArgument("The Input(Y) dims size must not be equal "
                                    "to 0, but received dims size is 0."));
-  MatMulFunction<Context, T>(dev_ctx, x, y, out, transpose_x, transpose_y);
+  const std::vector<std::int64_t> x_dims = vectorize(x.dims());
+  const std::vector<std::int64_t> y_dims = vectorize(y.dims());
+  MatMulFunction<Context, T>(
+      ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y);
 }
 
 template <typename T, typename Context>
--
GitLab
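Taken together, the GPU specialization of MatMulDispatcher wires both implementations into the tuner: on a cache miss with autotune enabled, MatmulAutoTuner::Run benchmarks every registered callback once for the shape key and caches the winner's index; each later call with the same key dispatches straight to that winner, and with autotune off the first (cublas) kernel is the default. A compact standalone model of this control flow (types and names are illustrative, not Paddle's):

    #include <chrono>
    #include <cstdio>
    #include <functional>
    #include <limits>
    #include <unordered_map>
    #include <vector>

    // Minimal model of MatmulAutoTuner::Run: benchmark all kernels once per
    // key, cache the fastest index, then always dispatch to the cached winner.
    struct Tuner {
      std::vector<std::function<void()>> kernels;   // {blas_impl, blaslt_impl}
      std::unordered_map<size_t, size_t> best_idx;  // shape key -> winner

      void Run(size_t key, bool use_autotune) {
        auto it = best_idx.find(key);
        if (it != best_idx.end()) { kernels[it->second](); return; }
        if (!use_autotune) { kernels[0](); return; }  // default: first kernel
        size_t winner = 0;
        double best_ms = std::numeric_limits<double>::max();
        for (size_t i = 0; i < kernels.size(); ++i) {
          auto t0 = std::chrono::steady_clock::now();
          kernels[i]();  // a real tuner syncs the GPU around the timed call
          double ms = std::chrono::duration<double, std::milli>(
                          std::chrono::steady_clock::now() - t0).count();
          if (ms < best_ms) { best_ms = ms; winner = i; }
        }
        best_idx[key] = winner;
      }
    };

    int main() {
      Tuner tuner;
      tuner.kernels = {[] { std::puts("cublas path"); },
                       [] { std::puts("cublasLt path"); }};
      tuner.Run(/*key*/ 42, /*use_autotune*/ true);  // benchmarks, caches winner
      tuner.Run(42, true);  // cache hit: runs only the cached winner
      return 0;
    }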