From 71a63f0a9be78d371a648a7cc97456857cadf718 Mon Sep 17 00:00:00 2001 From: limingshu <61349199+JamesLim-sy@users.noreply.github.com> Date: Tue, 7 Jun 2022 11:28:49 +0800 Subject: [PATCH] Transpose optimization with assitant of Chengdu Supercomputing Center and auto_tune operation (#42704) --- paddle/fluid/operators/transpose_op.cu.h | 432 +++++++++++++++++- paddle/fluid/operators/transpose_op.h | 178 ++++++++ paddle/fluid/platform/fast_divmod.h | 2 +- paddle/phi/kernels/autotune/auto_tune_base.h | 114 +++-- paddle/phi/kernels/autotune/auto_tune_test.cu | 22 +- paddle/phi/kernels/autotune/cache.h | 5 +- 6 files changed, 709 insertions(+), 44 deletions(-) diff --git a/paddle/fluid/operators/transpose_op.cu.h b/paddle/fluid/operators/transpose_op.cu.h index 40a967b11f7..f9d91fec4c3 100644 --- a/paddle/fluid/operators/transpose_op.cu.h +++ b/paddle/fluid/operators/transpose_op.cu.h @@ -17,8 +17,12 @@ limitations under the License. */ #include "paddle/fluid/framework/gpu_utils.h" #include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/fluid/platform/fast_divmod.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/kernels/autotune/auto_tune_base.h" +#include "paddle/phi/kernels/autotune/cache.h" +#include "paddle/phi/kernels/copy_kernel.h" namespace paddle { namespace operators { @@ -656,13 +660,437 @@ struct TransposeSimple { } }; +template +class IdxHelper { + public: + IdxHelper() {} + explicit IdxHelper(const T* dims) { + for (int i = N - 1; i >= 0; --i) { + stride_[i] = i < (N - 1) ? dims[i + 1] * stride_[i + 1] : 1; + } + } + + __device__ inline T GetStride(int idx) const { return stride_[idx]; } + + __device__ inline void GetIndexFromOffset(T offset, T* index) const { + T remaining = offset; +#pragma unroll + for (int i = 0; i < N - 1; ++i) { + const T idx = remaining / stride_[i]; + remaining -= idx * stride_[i]; + index[i] = idx; + } + index[N - 1] = remaining; + } + + private: + T stride_[N]; +}; + +template +class IdxHelper { + public: + IdxHelper() {} + explicit IdxHelper(const uint32_t* dims) { + for (int i = N - 1; i >= 0; --i) { + uint32_t value = i < (N - 1) ? dims[i + 1] * stride_[i + 1] : 1; + divmoder_[i] = paddle::platform::FastDivMod(value); + stride_[i] = value; + } + } + + __device__ inline uint32_t GetStride(int idx) const { return stride_[idx]; } + + __device__ inline void GetIndexFromOffset(uint32_t offset, + uint32_t* index) const { + uint32_t remaining = offset; +#pragma unroll + for (int i = 0; i < N - 1; ++i) { + uint32_t idx = divmoder_[i].Div(remaining); + index[i] = idx; + remaining -= idx * stride_[i]; + } + index[N - 1] = remaining; + } + + private: + uint32_t stride_[N]; + paddle::platform::FastDivMod divmoder_[N]; +}; + +// Transform index between memory offset and shape coodinate. +template +class IdxAndOffsetHelper { + public: + IdxAndOffsetHelper() {} + ~IdxAndOffsetHelper() = default; + + explicit IdxAndOffsetHelper(const T* dims) { + index_helper = IdxHelper(dims); + } + + template + explicit IdxAndOffsetHelper(const U* dims) { + T temp_dims[N]; + for (int i = 0; i < N; ++i) { + temp_dims[i] = static_cast(dims[i]); + } + index_helper = IdxHelper(temp_dims); + } + + __device__ inline T IndexToOffset(const T* index) const { + T offset = 0; +#pragma unroll + for (int i = 0; i < N - 1; ++i) { + offset += index[i] * index_helper.GetStride(i); + } + offset += index[N - 1]; + return offset; + } + + __device__ inline void OffsetToIndex(T offset, T* index) const { + index_helper.GetIndexFromOffset(offset, index); + } + + private: + IdxHelper index_helper; +}; + +template +struct PermuteParams { + public: + IdxAndOffsetHelper src_index_helper; + IdxAndOffsetHelper dst_index_helper; + int perm[Rank]{}; + + explicit PermuteParams(const std::vector& dims, + const std::vector& perm_) { + size_t dst_dims[Rank]; + for (size_t i = 0; i < Rank; ++i) { + dst_dims[i] = dims[perm_[i]]; + perm[i] = perm_[i]; + } + dst_index_helper = IdxAndOffsetHelper(dst_dims); + src_index_helper = IdxAndOffsetHelper(dims.data()); + } +}; + +// A special kernel for target case, both vectorized read and write supported. +template +__global__ void VectorizedPermuteKernel(PermuteParams params, + const size_t count, + const T* __restrict__ src_data, + T* dst_data) { + using VecT = phi::AlignedVector; + IndexT src_index[Rank]; + IndexT dst_index[Rank]; + + const VecT* __restrict__ src = + reinterpret_cast(src_data); + VecT* dst = reinterpret_cast(dst_data); + + IndexT tid = blockIdx.x * blockDim.x + threadIdx.x; + for (IndexT i = tid; i < count; i += blockDim.x * gridDim.x) { + params.dst_index_helper.OffsetToIndex(i, dst_index); + +#pragma unroll + for (int j = 0; j < Rank; ++j) { + src_index[params.perm[j]] = dst_index[j]; + } + IndexT src_offset = params.src_index_helper.IndexToOffset(src_index); + dst[i] = src[src_offset]; + } +} + +// A general kernel for normal case, only support vectorized write. +template +__global__ void GeneralPermuteKernel(PermuteParams params, + const T* __restrict__ src, T* dst, + const size_t main_cnt, + const size_t tail_cnt, + const size_t offset) { + using VecT = phi::AlignedVector; + VecT* vec_dst = reinterpret_cast(dst); + + IndexT src_index[VecSize][Rank]; + IndexT dst_index[VecSize][Rank]; + + // Avoid read perm data both in 2 load process. + __shared__ int perm[Rank]; + if (threadIdx.x < Rank) { + perm[threadIdx.x] = params.perm[threadIdx.x]; + } + __syncthreads(); + + // Vectorized load data. + IndexT tid = blockIdx.x * blockDim.x + threadIdx.x; + for (IndexT idx = tid; idx < main_cnt; idx += blockDim.x * gridDim.x) { + VecT vec_data; + IndexT vec_idx = idx * VecSize; + +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + params.dst_index_helper.OffsetToIndex(vec_idx + i, dst_index[i]); + +#pragma unroll + for (int j = 0; j < Rank; ++j) { + src_index[i][perm[j]] = dst_index[i][j]; + } + IndexT src_offset = params.src_index_helper.IndexToOffset(src_index[i]); + vec_data[i] = src[src_offset]; + } + vec_dst[idx] = vec_data; + } + + // Singularized load data. + if (tid < tail_cnt) { + IndexT idx = tid + offset; + params.dst_index_helper.OffsetToIndex(idx, dst_index[0]); + +#pragma unroll + for (int j = 0; j < Rank; ++j) { + src_index[0][perm[j]] = dst_index[0][j]; + } + IndexT src_offset = params.src_index_helper.IndexToOffset(src_index[0]); + dst[idx] = src[src_offset]; + } +} + +// A Gerneral permute method that drectly find the dst data +// coordinate in the source data. +template +inline void LaunchPermuteKernel(const phi::GPUContext& ctx, const IndexT count, + const PermuteType perm_type, + const std::vector& dims, + const std::vector& perm, const T* src, + T* dst) { + size_t main_count = count / VecSize; + auto params = PermuteParams(dims, perm); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, main_count); + + if (perm_type == PermuteType::kNormalPermute) { + size_t tail_count = count - main_count * VecSize; + size_t offset = count - tail_count; + GeneralPermuteKernel< + T, IndexT, VecSize, + Rank><<>>( + params, src, dst, main_count, tail_count, offset); + } else { + VectorizedPermuteKernel< + T, IndexT, VecSize, + Rank><<>>( + params, main_count, src, dst); + } +} + +template +inline void LaunchPermuteRankDispatch(const phi::GPUContext& ctx, + const IndexT count, + const PermuteType perm_type, + const std::vector& dims, + const std::vector& perm, + const T* src, T* dst) { +#define CALL_DISPATCH_RANK(rank) \ + case rank: { \ + LaunchPermuteKernel(ctx, count, perm_type, dims, \ + perm, src, dst); \ + break; \ + } + + switch (dims.size()) { + CALL_DISPATCH_RANK(1); + CALL_DISPATCH_RANK(2); + CALL_DISPATCH_RANK(3); + CALL_DISPATCH_RANK(4); + CALL_DISPATCH_RANK(5); + CALL_DISPATCH_RANK(6); + CALL_DISPATCH_RANK(7); + CALL_DISPATCH_RANK(8); + CALL_DISPATCH_RANK(9); + } +#undef CALL_DISPATCH_RANK +} + +// Aim at transposing the last 2 dimensions. Refer from +// https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/ +template +__global__ void BatchTransposeKernel(const T* __restrict__ src_data, + T* dst_data, IndexT rows, IndexT cols) { + using VecT = phi::AlignedVector; + + __shared__ VecT tile[kTileSize][kShareCol]; + T* single_tile = reinterpret_cast(tile); + + IndexT col_in_matrix = blockIdx.x * kTileSize + threadIdx.x; + IndexT offset = blockIdx.z * rows * cols; + + // Vectorized load data from src into shared memory. [rows, cols] + const VecT* __restrict__ src = + reinterpret_cast(src_data); + + for (IndexT tile_y = threadIdx.y; tile_y < kTileSize; tile_y += kBlockRows) { + IndexT row_in_matrix = tile_y + blockIdx.y * kTileSize; + + if (col_in_matrix < cols && row_in_matrix < rows) { + tile[tile_y][threadIdx.x] = + src[offset + row_in_matrix * cols + col_in_matrix]; + } + } + + // Singularized load data from shared memory into dst. + // and dst_cols = rows, dst_rows = cols, [cols * Vecsize, rows] + col_in_matrix = blockIdx.y * kTileSize + threadIdx.x; + offset = offset * VecSize + col_in_matrix; + IndexT tile_x_idx = threadIdx.x * (kShareCol * VecSize); + + __syncthreads(); + + for (IndexT tile_y = threadIdx.y; tile_y < kTileSize; tile_y += kBlockRows) { + IndexT row_in_matrix = tile_y + blockIdx.x * kTileSize; + IndexT dst_idx = offset + row_in_matrix * VecSize * rows; + IndexT tile_idx = tile_x_idx + tile_y * VecSize; + if (col_in_matrix < /*dst_cols=*/rows && + row_in_matrix < /*dst_rows=*/cols) { +#pragma unroll + for (auto i = 0; i < VecSize; ++i) { + dst_data[dst_idx + i * rows] = single_tile[tile_idx + i]; + } + } + } +} + +// With the byte limitation of shared_memory, the VecSize shall be restricted +// for the type whose byte-size is less than 8. +template 8 ? 1 : Size)> +inline void LaunchTransposeKernel(const phi::GPUContext& ctx, + const std::vector& dims, const T* src, + T* dst) { + auto rank = dims.size(); + IndexT num_batches = (rank == 2) ? 1 : dims[0]; + IndexT rows = dims[rank - 2]; + IndexT cols = dims[rank - 1]; + IndexT num_tile_rows = (rows + kTileSize - 1) / kTileSize; + IndexT num_tile_cols = (cols + kTileSize - 1) / kTileSize; + + dim3 blocks(num_tile_cols, num_tile_rows, num_batches); + dim3 threads(kTileSize, kBlockRows, 1); + + BatchTransposeKernel<<>>( + src, dst, rows, cols); +} + +template +inline void LaunchWithDispatchVecSize(const phi::GPUContext& ctx, + const int vec_size, + const PermuteType perm_type, + const std::vector& dims, + const std::vector& perm, + const T* src, T* dst, IndexT count) { +#define CALL_DISPATCH_VEC_SIZE(vec_size) \ + case vec_size: { \ + if (perm_type == PermuteType::kTranspose) { \ + LaunchTransposeKernel(ctx, dims, src, dst); \ + } else { \ + LaunchPermuteRankDispatch(ctx, count, perm_type, \ + dims, perm, src, dst); \ + } \ + break; \ + } + + switch (vec_size) { + CALL_DISPATCH_VEC_SIZE(1); + CALL_DISPATCH_VEC_SIZE(2); + CALL_DISPATCH_VEC_SIZE(4); + default: { + PADDLE_THROW(phi::errors::Unimplemented( + "Unsupported vectorized size: %d !", vec_size)); + break; + } + } +#undef CALL_DISPATCH_VEC_SIZE +} + +template +inline void LaunchWithDispatchIndex(const phi::GPUContext& ctx, + const size_t count, const int vec_size, + const PermuteType perm_type, + const std::vector& dims, + const std::vector& perm, const T* src, + T* dst) { + if (count < std::numeric_limits::max()) { + LaunchWithDispatchVecSize(ctx, vec_size, perm_type, dims, perm, + src, dst, + static_cast(count)); + } else { + int64_t cnt = static_cast(count); + LaunchWithDispatchVecSize(ctx, vec_size, perm_type, dims, perm, + src, dst, + static_cast(count)); + } +} + +template +inline void SimplifyThenLaunch(const int rank, const DeviceContext& ctx, + const Tensor& in, Tensor* out, + const std::vector& perm) { + int sm_count = ctx.GetSMCount(); + auto src_dims = phi::vectorize(in.dims()); + auto simplifier = DimsSimplifier(sm_count, rank, perm, src_dims, + in.data(), out->data()); + + if (simplifier.GetPermType() == PermuteType::kCopy) { + // If perm is [0,1,2,3], then just operate a DtoD copy. + phi::Copy(ctx, in, ctx.GetPlace(), false, out); + } else { + LaunchWithDispatchIndex( + ctx, simplifier.GetCount(), simplifier.GetVecSize(), + simplifier.GetPermType(), simplifier.GetDims(), simplifier.GetPerm(), + in.data(), out->data()); + } +} + +template +size_t GetTransposeKey(const int rank, const Tensor& in, + const std::vector& perm) { + auto in_shape = phi::vectorize(in.dims()); + return phi::autotune::GetKey( + in_shape, perm, rank, paddle::experimental::CppTypeToDataType::Type()); +} + template -void TransposeGPUKernelDriver(const phi::GPUContext& dev_ctx, const int ndims, +void TransposeGPUKernelDriver(const phi::GPUContext& dev_ctx, const int rank, const Tensor& in, const std::vector& perm, Tensor* out) { + PADDLE_ENFORCE_LT( + rank, phi::DDim::kMaxRank, + platform::errors::OutOfRange( + "The maximum dimension rank of " + "tensor is expected to be less than %d, but here is %d.", + phi::DDim::kMaxRank, rank)); + auto ret = TransposeSimple::run(dev_ctx, in, perm, out); if (!ret) { - TransCompute(ndims, dev_ctx, in, out, perm); + auto* tuner = phi::autotune::MakeTransposeTuner( + SimplifyThenLaunch); + if (!tuner->IsInit()) { + tuner->AddCallBack( + phi::autotune::MakeCallback(TransCompute)); + tuner->Finalize(); + } + + auto key = GetTransposeKey(rank, in, perm); + auto& cache = phi::autotune::AutoTuneCache::Instance().GetTranspose(); + if (cache.Find(key)) { + auto index = cache.Get(key); + tuner->RunBestKernel(index, rank, dev_ctx, in, out, perm); + } else { + // All avaliable kernels have ran while picking the best kernel, so + // there may be no need for another RunBestKernel. + auto index = tuner->PickBestKernel(dev_ctx, rank, dev_ctx, in, out, perm); + cache.Set(key, index); + } } } diff --git a/paddle/fluid/operators/transpose_op.h b/paddle/fluid/operators/transpose_op.h index 891aa312f69..ca57687ea5f 100644 --- a/paddle/fluid/operators/transpose_op.h +++ b/paddle/fluid/operators/transpose_op.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { @@ -60,5 +61,182 @@ inline void TransCompute(const int dim, const DeviceContext& dev_ctx, } } +enum PermuteType { + kCopy = 1, + kTranspose = 2, + kVecPermute = 3, + kNormalPermute = 4 +}; + +constexpr int kBlockRows = 16; +constexpr int kTileSize = 32; +// To avoid bank conflict. +constexpr int kShareCol = kTileSize + 1; + +// Simplify the input dims and permute dims if possible. +template +class DimsSimplifier { + public: + explicit DimsSimplifier(const int sm_count, const int rank, + const std::vector& perm, + const std::vector& dims, const T* src, T* dst) + : perm_(rank), dims_(rank) { + SimplifyPermAndDims(rank, dims, perm); + count_ = std::accumulate(dims.begin(), dims.end(), size_t{1}, + std::multiplies()); + if (rank_ > 1) { + vec_size_ = GetPermVecSize(sm_count, src, dst); + perm_.resize(rank_); + dims_.resize(rank_); + } + } + + size_t GetCount() const { return count_; } + int GetVecSize() const { return vec_size_; } + PermuteType GetPermType() const { return type_; } + + std::vector GetPerm() const { return perm_; } + std::vector GetDims() const { return dims_; } + + private: + size_t rank_{1}; + size_t count_{0}; + int vec_size_{1}; + std::vector perm_; + std::vector dims_; + PermuteType type_{kCopy}; + + void SimplifyPermAndDims(const size_t rank, + const std::vector& in_dims, + const std::vector& perm) { + size_t combined_dims[phi::DDim::kMaxRank]; + int valid_map[phi::DDim::kMaxRank]; + + // Merge consecutive dims to the fist one of this these dims, + // and leave the origin dim value to be 1. Example below : + // perm: [2, 3, 0, 1], origin_dims : [4, 8, 2, 5] + // new_dims: [4, 8, 2, 5] -> [32, 1, 10, 1] + size_t start_perm_idx = 0; + while (start_perm_idx < rank) { + const size_t start_dim_idx = perm[start_perm_idx]; + combined_dims[start_dim_idx] = in_dims[start_dim_idx]; + size_t end_perm_idx = start_perm_idx + 1; + + while (end_perm_idx < rank && + perm[end_perm_idx] == perm[end_perm_idx - 1] + 1) { + const size_t end_dim_idx = perm[end_perm_idx]; + combined_dims[start_dim_idx] *= in_dims[end_dim_idx]; + combined_dims[end_dim_idx] = 1; + end_perm_idx += 1; + } + start_perm_idx = end_perm_idx; + } + + // Reorder combined dims and marked useless dim as -1. + // for example, if combined dims is [32, 1, 10, 1], + // valid_map is [0, -1, 1, -1] and generate simplified + // dims as [32, 10] + size_t valid_dim_idx = 0; + bool sequential_flag = false; + for (size_t i = 0; i < rank; ++i) { + const int src_dim = combined_dims[i]; + if (src_dim == 1) { + valid_map[i] = -1; + } else { + sequential_flag = true; + valid_map[i] = valid_dim_idx; + dims_[valid_dim_idx] = src_dim; + valid_dim_idx += 1; + } + } + + if (valid_dim_idx == 0) { + dims_[0] = 1; + perm_[0] = 0; + return; + } else if (valid_dim_idx == 1) { + type_ = PermuteType::kCopy; + } + + // Acquire simplified perm with help of combined dims + // and original perm, finally simplified perm is [1, 0] + size_t perm_idx = 0; + for (size_t i = 0; i < rank; ++i) { + const int mapped = valid_map[perm[i]]; + if (mapped >= 0) { + perm_[perm_idx] = mapped; + perm_idx += 1; + } + } + rank_ = valid_dim_idx; + } + + int GetPermVecSize(const int sm_count, const T* src, T* dst) { + // For gerneal_permute kernel, there is good chance for + // vectorized write. + int vec_size = phi::GetVectorizedSize(dst); + type_ = PermuteType::kNormalPermute; + + // While the last dim is fixed, there is good chance for + // both vectorized read and write. + if (perm_[rank_ - 1] == rank_ - 1) { + int tmp_size = std::min(vec_size, phi::GetVectorizedSize(src)); + tmp_size = GetDimVesSize(tmp_size, dims_[rank_ - 1]); + if (tmp_size > 1) { + type_ = kVecPermute; + vec_size = tmp_size; + + // For stride calculation of src_data index. + dims_[rank_ - 1] /= vec_size; + } + } + + // Once only transpose at the last 2 dims, there is good + // chance for vectorized read. + if ((rank_ == 2 && perm_[1] == 0 && perm_[0] == 1) || + (rank_ == 3 && perm_[2] == 1 && perm_[1] == 2)) { + type_ = PermuteType::kTranspose; + + // Compared with vectorized load or read, set config to let more + // sm work simultaneously affect more according to performance. + constexpr int threads = kTileSize * kTileSize; + int blocks = count_ / threads; + if (blocks < sm_count) { + vec_size = 1; + } else { + int tmp_vec = std::min(vec_size, phi::GetVectorizedSize(src)); + // With bytes limitation of shared_memory, the VecSize shall be + // restricted for the type whose byte-size is less than 8 (double). + int type_vec = + sizeof(T) > 8 ? 1 : GetDimVesSize(tmp_vec, dims_[rank_ - 1]); + for (int i = type_vec; i > 0; i /= 2) { + if (blocks / i >= sm_count) { + break; + } + // When blocks is smaller than sm_count, a test shown that decrease + // vec_size to make blocks close to sm_count would gain performance. + vec_size = i; + } + } + + dims_[rank_ - 1] /= vec_size; + count_ /= vec_size; + } + return vec_size; + } + + // To find if highest common divisor and make it as vec_size. + int GetDimVesSize(const int vec_size, const size_t target_dim) { + int dim_vec_size = 1; + for (auto size = vec_size; size > 0; size /= 2) { + if (target_dim % size == 0) { + dim_vec_size = size; + break; + } + } + return dim_vec_size; + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/platform/fast_divmod.h b/paddle/fluid/platform/fast_divmod.h index f2a150c3012..892c5b29aae 100644 --- a/paddle/fluid/platform/fast_divmod.h +++ b/paddle/fluid/platform/fast_divmod.h @@ -59,8 +59,8 @@ struct FastDivMod { return result; } - int32_t divisor; int32_t shift_val; + uint32_t divisor; uint32_t multiplier; }; diff --git a/paddle/phi/kernels/autotune/auto_tune_base.h b/paddle/phi/kernels/autotune/auto_tune_base.h index e18b854cf34..95afa7f697b 100644 --- a/paddle/phi/kernels/autotune/auto_tune_base.h +++ b/paddle/phi/kernels/autotune/auto_tune_base.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include "glog/logging.h" @@ -23,7 +24,7 @@ namespace phi { namespace autotune { -template +template class KernelCallback { public: using ReturnT = RetureType; @@ -33,71 +34,126 @@ class KernelCallback { explicit KernelCallback(FuncType func_) : func(func_) {} virtual ~KernelCallback() {} - RetureType Call(Args... args) { return func(args...); } + RetureType Run(Args... args) { return func(args...); } private: FuncType func; }; -template -static KernelCallback MakeCallback( +template +static KernelCallback MakeCallback( RetureType (*cb)(Args...)) { - return KernelCallback(cb); + return KernelCallback(cb); } -template +template class AutoTuneBase { public: AutoTuneBase() {} virtual ~AutoTuneBase() {} - explicit AutoTuneBase(KernelType kernel) : default_kernel_(kernel) { + explicit AutoTuneBase(KernelType kernel) { kernels_.push_back(kernel); } + + template + void AddCallBack(Type kernel) { + static_assert(std::is_same::value, + "Type must be the same"); kernels_.push_back(kernel); } - template - void AddCallBack(T kernel) { - static_assert(std::is_same::value, "Type must be the same"); - kernels_.push_back(kernel); + template + void RunBestKernel(const int idx, Args&&... args) { + kernels_[idx].Run(args...); + } + + template + void RunDefaultKernel(Args&&... args) { + kernels_[0].Run(args...); } template - KernelType PickBestKernel(const Context& ctx, Args&&... args) { + int PickBestKernel(const Context& ctx, Args&&... args) { PADDLE_ENFORCE_GT( kernels_.size(), 0, paddle::platform::errors::InvalidArgument( "kernel num must be greater than 0, now is %d", kernels_.size())); - int idx = 0; - phi::GpuTimer timer; + int best_idx = 0; float min_time = std::numeric_limits::max(); + // Time cost test estabulished in default stream. for (int i = 0; i < kernels_.size(); ++i) { - ctx.Wait(); - timer.Start(0); - kernels_[i].Call(args...); - timer.Stop(0); - auto time = timer.ElapsedTime(); - VLOG(3) << "kernel[" << i << "]: time cost is " << time; - + auto time = RunAndMeasureKernel(ctx, i, args...); if (time < min_time) { min_time = time; - idx = i; + best_idx = i; } } - VLOG(3) << "best kernel idx is " << idx; - return kernels_[idx]; + VLOG(3) << "best kernel idx is " << best_idx; + return best_idx; } + bool IsInit() { return is_init_; } + void Finalize() { is_init_ = true; } + private: - KernelType default_kernel_; + bool is_init_{false}; std::vector kernels_; + + template + float RunAndMeasureKernel(const Context& ctx, const int idx, Args&&... args) { + phi::GpuTimer timer; + float time_cost = 0; + const auto& stream = ctx.stream(); + + // Treat 1st run as warm up. Judge the result with + // the sum of 2nd and 3rd run. + constexpr int repeats = 3; + + ctx.Wait(); + for (int i = 0; i < repeats; ++i) { + timer.Start(stream); + kernels_[idx].Run(args...); + timer.Stop(stream); + auto time = timer.ElapsedTime(); + if (i > 0) { + time_cost += time; + } + VLOG(3) << "kernel[" << idx << "][" << i << "th time cost is " << time; + } + return time_cost; + } }; -template -static AutoTuneBase> MakeAutoTuner( +template +static AutoTuneBase> MakeAutoTuner( RetureType (*func)(Args...)) { - auto obj = MakeCallback(func); - return AutoTuneBase(obj); + auto obj = MakeCallback(func); + return AutoTuneBase(obj); +} + +template +class TransposeAutoTuner : public AutoTuneBase { + public: + static AutoTuneBase* Instance(KernelType kernel) { + static std::unique_ptr> instance_; + std::call_once(init_flag_, [&] { + instance_.reset(new AutoTuneBase(kernel)); + }); + return instance_.get(); + } + + private: + static std::once_flag init_flag_; +}; + +template +std::once_flag TransposeAutoTuner::init_flag_; + +template +static AutoTuneBase>* + MakeTransposeTuner(RetureType (*func)(Args...)) { + auto obj = MakeCallback(func); + return TransposeAutoTuner::Instance(obj); } } // namespace autotune diff --git a/paddle/phi/kernels/autotune/auto_tune_test.cu b/paddle/phi/kernels/autotune/auto_tune_test.cu index c3918b8ebe5..8701a0572fc 100644 --- a/paddle/phi/kernels/autotune/auto_tune_test.cu +++ b/paddle/phi/kernels/autotune/auto_tune_test.cu @@ -74,7 +74,7 @@ float Algo(const phi::GPUContext& ctx, } TEST(AutoTune, sum) { - int64_t N = 1 << 22; + int64_t N = 1 << 20; size_t blocks = 512; size_t threads = 256; size_t size = sizeof(float) * N; @@ -119,35 +119,35 @@ TEST(AutoTune, sum) { // 1. Test call_back. VLOG(3) << ">>> [CallBack]: Test case."; - auto callback1 = tune::MakeCallback(Algo<4>); - auto callback2 = tune::MakeCallback(Algo<2>); - auto callback3 = tune::MakeCallback(Algo<1>); + auto callback1 = tune::MakeCallback(Algo<4>); + auto callback2 = tune::MakeCallback(Algo<2>); + auto callback3 = tune::MakeCallback(Algo<1>); std::vector callbacks{callback1, callback2, callback3}; for (int i = 0; i < callbacks.size(); ++i) { dev_ctx->Wait(); phi::GpuTimer timer; timer.Start(0); - callbacks[i].Call(*dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks); + callbacks[i].Run(*dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks); timer.Stop(0); VLOG(3) << "kernel[" << i << "]: time cost is " << timer.ElapsedTime(); } // 2. Test call_back tune. VLOG(3) << ">>> [AutoTune]: Test case."; - auto tuner = tune::MakeAutoTuner(Algo<4>); - tuner.AddCallBack(tune::MakeCallback(Algo<2>)); - tuner.AddCallBack(tune::MakeCallback(Algo<1>)); + auto tuner = tune::MakeAutoTuner(Algo<4>); + tuner.AddCallBack(tune::MakeCallback(Algo<2>)); + tuner.AddCallBack(tune::MakeCallback(Algo<1>)); /* The 1st ctx works for ctx.Wait(), the 2nd is just the param of call_back. */ - auto best_call_back = tuner.PickBestKernel( + auto best_index = tuner.PickBestKernel( *dev_ctx, *dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks); - best_call_back.Call(*dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks); dev_ctx->Wait(); phi::GpuTimer timer; timer.Start(0); - best_call_back.Call(*dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks); + tuner.RunBestKernel( + best_index, *dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks); timer.Stop(0); VLOG(3) << "Best CallBackKernel time cost is " << timer.ElapsedTime(); #endif diff --git a/paddle/phi/kernels/autotune/cache.h b/paddle/phi/kernels/autotune/cache.h index 9d7f57e96e3..8de0695ede4 100644 --- a/paddle/phi/kernels/autotune/cache.h +++ b/paddle/phi/kernels/autotune/cache.h @@ -134,7 +134,8 @@ enum class AlgorithmType { kConvForward = 1, kConvBackwardData = 2, kConvBackwardFilter = 3, - kAlgorithmCount = 4 + kTranspose = 4, + kAlgorithmCount = 5 }; // AlgorithmsConfigKey -> AlgorithmsID @@ -165,6 +166,8 @@ class AutoTuneCache { return Get(AlgorithmType::kConvBackwardFilter); } + AlgorithmsCacheMap& GetTranspose() { return Get(AlgorithmType::kTranspose); } + void Clean() { for (auto& v : auto_tune_map_) { v.second.Clean(); -- GitLab