diff --git a/paddle/fluid/operators/transpose_op.cu.h b/paddle/fluid/operators/transpose_op.cu.h
index 0e1906bedf7b8ae6fde064b449e6d194a3d5bcfc..eb9e8a7bed78450ffcb11ccc2f24774c31e06e0e 100644
--- a/paddle/fluid/operators/transpose_op.cu.h
+++ b/paddle/fluid/operators/transpose_op.cu.h
@@ -18,7 +18,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/transpose_op.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/fluid/platform/fast_divmod.h"
-#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/autotune/auto_tune_base.h"
@@ -832,15 +831,6 @@ class IdxAndOffsetHelper {
     index_helper = IdxHelper<N, T>(dims);
   }
 
-  template <typename U>
-  explicit IdxAndOffsetHelper(const U* dims) {
-    T temp_dims[N];
-    for (int i = 0; i < N; ++i) {
-      temp_dims[i] = static_cast<T>(dims[i]);
-    }
-    index_helper = IdxHelper<N, T>(temp_dims);
-  }
-
   __device__ inline T IndexToOffset(const T* index) const {
     T offset = 0;
 #pragma unroll
@@ -866,15 +856,17 @@ struct PermuteParams {
   IdxAndOffsetHelper<IndexT, Rank> dst_index_helper;
   int perm[Rank]{};
 
-  explicit PermuteParams(const std::vector<size_t>& dims,
+  explicit PermuteParams(const std::vector<int64_t>& dims,
                          const std::vector<int>& perm_) {
-    size_t dst_dims[Rank];
-    for (size_t i = 0; i < Rank; ++i) {
+    IndexT dst_dims[Rank];
+    IndexT src_dims[Rank];
+    for (auto i = 0; i < Rank; ++i) {
+      src_dims[i] = dims[i];
       dst_dims[i] = dims[perm_[i]];
       perm[i] = perm_[i];
     }
     dst_index_helper = IdxAndOffsetHelper<IndexT, Rank>(dst_dims);
-    src_index_helper = IdxAndOffsetHelper<IndexT, Rank>(dims.data());
+    src_index_helper = IdxAndOffsetHelper<IndexT, Rank>(src_dims);
   }
 };
 
@@ -966,21 +958,26 @@ template <typename T, typename IndexT, int VecSize, int Rank>
 inline void LaunchPermuteKernel(const phi::GPUContext& ctx,
                                 const IndexT count,
                                 const PermuteType perm_type,
-                                const std::vector<size_t>& dims,
+                                const std::vector<int64_t>& dims,
                                 const std::vector<int>& perm,
                                 const T* src,
                                 T* dst) {
   size_t main_count = count / VecSize;
-  auto params = PermuteParams<Rank, IndexT>(dims, perm);
   auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, main_count);
 
-  if (perm_type == PermuteType::kNormalPermute) {
+  if (perm_type == PermuteType::kGeneralPermute) {
     size_t tail_count = count - main_count * VecSize;
     size_t offset = count - tail_count;
+    auto params = PermuteParams<Rank, IndexT>(dims, perm);
+
     GeneralPermuteKernel<T, IndexT, VecSize, Rank>
         <<<config.GetGridSize(), config.GetBlockSize(), 0, ctx.stream()>>>(
            params, src, dst, main_count, tail_count, offset);
   } else {
+    std::vector<int64_t> vec_dims(dims);
+    vec_dims[dims.size() - 1] /= VecSize;
+    auto params = PermuteParams<Rank, IndexT>(vec_dims, perm);
+
     VectorizedPermuteKernel<T, IndexT, VecSize, Rank>
         <<<config.GetGridSize(), config.GetBlockSize(), 0, ctx.stream()>>>(
            params, main_count, src, dst);
@@ -991,7 +988,7 @@ template <typename T, typename IndexT, int VecSize>
 inline void LaunchPermuteRankDispatch(const phi::GPUContext& ctx,
                                       const IndexT count,
                                       const PermuteType perm_type,
-                                      const std::vector<size_t>& dims,
+                                      const std::vector<int64_t>& dims,
                                       const std::vector<int>& perm,
                                       const T* src,
                                       T* dst) {
@@ -1016,70 +1013,76 @@ inline void LaunchPermuteRankDispatch(const phi::GPUContext& ctx,
 #undef CALL_DISPATCH_RANK
 }
 
-// Aim at transposing the last 2 dimensions. Reference from
 // https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
 template <typename T, typename IndexT, int VecSize>
 __global__ void BatchTransposeKernel(const T* __restrict__ src_data,
                                      T* dst_data,
                                      IndexT rows,
-                                     IndexT cols) {
+                                     IndexT cols,
+                                     IndexT round_tile_rows,
+                                     IndexT round_tile_cols) {
   using VecT = phi::AlignedVector<T, VecSize>;
-
-  __shared__ VecT tile[kTileSize][kShareCol];
-  T* single_tile = reinterpret_cast<T*>(tile);
-
-  IndexT col_in_matrix = blockIdx.x * kTileSize + threadIdx.x;
-  IndexT offset = blockIdx.z * rows * cols;
+  constexpr int kShareCol = kTileSize + 1;
+  __shared__ VecT v_shared[kTileSize * kShareCol];
+  T* s_shared = reinterpret_cast<T*>(v_shared);
 
   // Vectorized load data from src into shared memory. [rows, cols]
-  const VecT* __restrict__ src =
+  const VecT* __restrict__ vec_src =
       reinterpret_cast<const VecT* __restrict__>(src_data);
 
-  for (IndexT tile_y = threadIdx.y; tile_y < kTileSize; tile_y += kBlockRows) {
-    IndexT row_in_matrix = tile_y + blockIdx.y * kTileSize;
+  IndexT col_in_matrix = blockIdx.x * kTileSize + threadIdx.x;
+  IndexT offset = blockIdx.z * rows * cols;
 
-    if (col_in_matrix < cols && row_in_matrix < rows) {
-      tile[tile_y][threadIdx.x] =
-          src[offset + row_in_matrix * cols + col_in_matrix];
+  if (col_in_matrix < cols) {
+    int row_range = (blockIdx.y < round_tile_rows)
+                        ? kTileSize
+                        : (rows - kTileSize * round_tile_rows);
+#pragma unroll
+    for (int tile_y = threadIdx.y; tile_y < row_range; tile_y += kBlockRows) {
+      IndexT row_in_matrix = tile_y + blockIdx.y * kTileSize;
+      v_shared[tile_y * kShareCol + threadIdx.x] =
+          vec_src[offset + row_in_matrix * cols + col_in_matrix];
     }
   }
 
-  // Singularized load data from shared memory into dst.
-  // and dst_cols = rows, dst_rows = cols, [cols * Vecsize, rows]
+  // Write data from shared memory into dst and
+  // dst_cols = rows, dst_rows = cols * VecSize
   col_in_matrix = blockIdx.y * kTileSize + threadIdx.x;
   offset = offset * VecSize + col_in_matrix;
-  IndexT tile_x_idx = threadIdx.x * (kShareCol * VecSize);
-
   __syncthreads();
-  for (IndexT tile_y = threadIdx.y; tile_y < kTileSize; tile_y += kBlockRows) {
-    IndexT row_in_matrix = tile_y + blockIdx.x * kTileSize;
-    IndexT dst_idx = offset + row_in_matrix * VecSize * rows;
-    IndexT tile_idx = tile_x_idx + tile_y * VecSize;
-    if (col_in_matrix < /*dst_cols=*/rows &&
-        row_in_matrix < /*dst_rows=*/cols) {
+
+  if (col_in_matrix < /*dst_cols=*/rows) {
+    int col_range = (blockIdx.x < round_tile_cols)
+                        ? kTileSize
+                        : (cols - kTileSize * round_tile_cols);
 #pragma unroll
-      for (auto i = 0; i < VecSize; ++i) {
-        dst_data[dst_idx + i * rows] = single_tile[tile_idx + i];
+    for (IndexT tile_y = threadIdx.y; tile_y < col_range;
+         tile_y += kBlockRows) {
+#pragma unroll
+      for (int i = 0; i < VecSize; ++i) {
+        IndexT row_in_matrix = (tile_y + blockIdx.x * kTileSize) * VecSize + i;
+        IndexT shared_idx = (tile_y + threadIdx.x * kShareCol) * VecSize + i;
+        dst_data[offset + row_in_matrix * rows] = s_shared[shared_idx];
       }
     }
   }
 }
 
-// With the byte limitation of shared_memory, the VecSize shall be restricted
-// for the type whose byte-size is less than 8.
+// With the byte limitation of shared_memory, the VecSize shall be
+// restricted for the type whose byte-size is less than 4.
 template <typename T,
           typename IndexT,
           int Size,
-          int VecSize = (sizeof(T) > 8 ? 1 : Size)>
+          int VecSize = (sizeof(T) > 4 ? 1 : Size)>
 inline void LaunchTransposeKernel(const phi::GPUContext& ctx,
-                                  const std::vector<size_t>& dims,
+                                  const std::vector<int64_t>& dims,
                                   const T* src,
                                   T* dst) {
   auto rank = dims.size();
   IndexT num_batches = (rank == 2) ? 1 : dims[0];
   IndexT rows = dims[rank - 2];
-  IndexT cols = dims[rank - 1];
+  IndexT cols = dims[rank - 1] / VecSize;
   IndexT num_tile_rows = (rows + kTileSize - 1) / kTileSize;
   IndexT num_tile_cols = (cols + kTileSize - 1) / kTileSize;
@@ -1087,14 +1090,15 @@ inline void LaunchTransposeKernel(const phi::GPUContext& ctx,
   dim3 threads(kTileSize, kBlockRows, 1);
 
   BatchTransposeKernel<T, IndexT, VecSize>
-      <<<blocks, threads, 0, ctx.stream()>>>(src, dst, rows, cols);
+      <<<blocks, threads, 0, ctx.stream()>>>(
+          src, dst, rows, cols, num_tile_rows - 1, num_tile_cols - 1);
 }
 
 template <typename T, typename IndexT>
 inline void LaunchWithDispatchVecSize(const phi::GPUContext& ctx,
                                       const int vec_size,
                                       const PermuteType perm_type,
-                                      const std::vector<size_t>& dims,
+                                      const std::vector<int64_t>& dims,
                                       const std::vector<int>& perm,
                                       const T* src,
                                       T* dst,
@@ -1123,60 +1127,50 @@ inline void LaunchWithDispatchVecSize(const phi::GPUContext& ctx,
 #undef CALL_DISPATCH_VEC_SIZE
 }
 
-template <typename T>
-inline void LaunchWithDispatchIndex(const phi::GPUContext& ctx,
-                                    const size_t count,
-                                    const int vec_size,
-                                    const PermuteType perm_type,
-                                    const std::vector<size_t>& dims,
-                                    const std::vector<int>& perm,
-                                    const T* src,
-                                    T* dst) {
-  if (count < std::numeric_limits<int>::max()) {
-    LaunchWithDispatchVecSize<T, int>(ctx,
-                                      vec_size,
-                                      perm_type,
-                                      dims,
-                                      perm,
-                                      src,
-                                      dst,
-                                      static_cast<int>(count));
-  } else {
-    int64_t cnt = static_cast<int64_t>(count);
-    LaunchWithDispatchVecSize<T, int64_t>(ctx,
-                                          vec_size,
-                                          perm_type,
-                                          dims,
-                                          perm,
-                                          src,
-                                          dst,
-                                          static_cast<int64_t>(count));
-  }
-}
-
 template <typename T, typename DeviceContext>
-inline void SimplifyThenLaunch(const int rank,
-                               const DeviceContext& ctx,
-                               const phi::DenseTensor& in,
-                               phi::DenseTensor* out,
-                               const std::vector<int32_t>& perm) {
-  int sm_count = ctx.GetSMCount();
-  auto src_dims = phi::vectorize<size_t>(in.dims());
-  auto simplifier = DimsSimplifier<T>(
-      sm_count, rank, perm, src_dims, in.data<T>(), out->data<T>());
-
-  if (simplifier.GetPermType() == PermuteType::kCopy) {
+inline void PermuteAndTranspose(const int rank,
+                                const DeviceContext& ctx,
+                                const phi::DenseTensor& in,
+                                phi::DenseTensor* out,
+                                const std::vector<int32_t>& perm) {
+  const int64_t numel = in.numel();
+  auto classifier =
+      TranposeTypeClassifier<T>(ctx.GetSMCount(),
+                                rank,
+                                numel,
+                                perm,
+                                phi::vectorize<int64_t>(in.dims()),
+                                in.data<T>(),
+                                out->data<T>());
+
+  if (classifier.GetPermType() == PermuteType::kCopy) {
     // If perm is [0,1,2,3], then just operate a DtoD copy.
-    phi::Copy(ctx, in, ctx.GetPlace(), false, out);
+    phi::backends::gpu::GpuMemcpyAsync(out->data<T>(),
+                                       in.data<T>(),
+                                       numel * sizeof(T),
+                                       phi::gpuMemcpyDeviceToDevice,
+                                       ctx.stream());
   } else {
-    LaunchWithDispatchIndex<T>(ctx,
-                               simplifier.GetCount(),
-                               simplifier.GetVecSize(),
-                               simplifier.GetPermType(),
-                               simplifier.GetDims(),
-                               simplifier.GetPerm(),
-                               in.data<T>(),
-                               out->data<T>());
+    if (numel < std::numeric_limits<int>::max()) {
+      LaunchWithDispatchVecSize<T, int>(ctx,
+                                        classifier.GetVecSize(),
+                                        classifier.GetPermType(),
+                                        classifier.GetSrcDims(),
+                                        classifier.GetPerm(),
+                                        in.data<T>(),
+                                        out->data<T>(),
+                                        static_cast<int>(numel));
+    } else {
+      int64_t cnt = static_cast<int64_t>(numel);
+      LaunchWithDispatchVecSize<T, int64_t>(ctx,
+                                            classifier.GetVecSize(),
+                                            classifier.GetPermType(),
+                                            classifier.GetSrcDims(),
+                                            classifier.GetPerm(),
+                                            in.data<T>(),
+                                            out->data<T>(),
+                                            static_cast<int64_t>(numel));
+    }
   }
 }
 
@@ -1196,7 +1190,7 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
   if (!ret) {
     auto* tuner =
        phi::autotune::MakeTransposeTuner<T>(TransCompute<phi::GPUContext, T>);
-    tuner->AddCallBack(SimplifyThenLaunch<T, phi::GPUContext>);
+    tuner->AddCallBack(PermuteAndTranspose<T, phi::GPUContext>);
 
     size_t key = phi::autotune::TransposeKey(
         phi::vectorize(in.dims()),
diff --git a/paddle/fluid/operators/transpose_op.h b/paddle/fluid/operators/transpose_op.h
index a533b17fc175dd610f89e956c3c5ddf493fe33d0..45495505e605995c7f87caa65177ed00baeb4f66 100644
--- a/paddle/fluid/operators/transpose_op.h
+++ b/paddle/fluid/operators/transpose_op.h
@@ -71,69 +71,72 @@ enum PermuteType {
   kCopy = 1,
   kTranspose = 2,
   kVecPermute = 3,
-  kNormalPermute = 4
+  kGeneralPermute = 4
 };
 
 constexpr int kBlockRows = 16;
 constexpr int kTileSize = 32;
-// To avoid bank conflict.
-constexpr int kShareCol = kTileSize + 1;
 
 // Simplify the input dims and permute dims if possible.
 template <typename T>
-class DimsSimplifier {
+class TranposeTypeClassifier {
  public:
-  explicit DimsSimplifier(const int sm_count,
-                          const int rank,
-                          const std::vector<int32_t>& perm,
-                          const std::vector<size_t>& dims,
-                          const T* src,
-                          T* dst)
-      : perm_(rank), dims_(rank) {
+  TranposeTypeClassifier(const int sm_count,
+                         const size_t rank,
+                         const int64_t numel,
+                         const std::vector<int32_t>& perm,
+                         const std::vector<int64_t>& dims,
+                         const T* src,
+                         T* dst)
+      : perm_(rank), src_dims(rank) {
     SimplifyPermAndDims(rank, dims, perm);
-    count_ = std::accumulate(
-        dims.begin(), dims.end(), size_t{1}, std::multiplies<size_t>());
     if (rank_ > 1) {
       vec_size_ = GetPermVecSize(sm_count, src, dst);
-      perm_.resize(rank_);
-      dims_.resize(rank_);
+    }
+    perm_.resize(rank_);
+    src_dims.resize(rank_);
+    dst_dims.resize(rank_);
+
+    for (auto i = 0; i < rank_; ++i) {
+      dst_dims[i] = src_dims[perm_[i]];
     }
   }
 
-  size_t GetCount() const { return count_; }
+  int GetRank() const { return rank_; }
   int GetVecSize() const { return vec_size_; }
   PermuteType GetPermType() const { return type_; }
 
   std::vector<int> GetPerm() const { return perm_; }
-  std::vector<size_t> GetDims() const { return dims_; }
+  std::vector<int64_t> GetSrcDims() const { return src_dims; }
+  std::vector<int64_t> GetDstDims() const { return dst_dims; }
 
  private:
-  size_t rank_{1};
-  size_t count_{0};
+  int rank_{1};
   int vec_size_{1};
   std::vector<int> perm_;
-  std::vector<size_t> dims_;
+  std::vector<int64_t> src_dims;
+  std::vector<int64_t> dst_dims;
  PermuteType type_{kCopy};
 
   void SimplifyPermAndDims(const size_t rank,
-                           const std::vector<size_t>& in_dims,
+                           const std::vector<int64_t>& in_dims,
                            const std::vector<int32_t>& perm) {
-    size_t combined_dims[phi::DDim::kMaxRank];
+    int64_t combined_dims[phi::DDim::kMaxRank];
     int valid_map[phi::DDim::kMaxRank];
 
-    // Merge consecutive dims to the fist one of this these dims,
-    // and leave the origin dim value to be 1. Example below :
+    // Merge consecutive dims to the first one and
+    // leave the original dim to be 1. Example below :
     // perm: [2, 3, 0, 1], origin_dims : [4, 8, 2, 5]
     // new_dims: [4, 8, 2, 5] -> [32, 1, 10, 1]
-    size_t start_perm_idx = 0;
+    int start_perm_idx = 0;
     while (start_perm_idx < rank) {
-      const size_t start_dim_idx = perm[start_perm_idx];
+      const int start_dim_idx = perm[start_perm_idx];
       combined_dims[start_dim_idx] = in_dims[start_dim_idx];
-      size_t end_perm_idx = start_perm_idx + 1;
+      int end_perm_idx = start_perm_idx + 1;
 
       while (end_perm_idx < rank &&
              perm[end_perm_idx] == perm[end_perm_idx - 1] + 1) {
-        const size_t end_dim_idx = perm[end_perm_idx];
+        const int end_dim_idx = perm[end_perm_idx];
         combined_dims[start_dim_idx] *= in_dims[end_dim_idx];
         combined_dims[end_dim_idx] = 1;
         end_perm_idx += 1;
@@ -145,22 +148,22 @@ class DimsSimplifier {
     // for example, if combined dims is [32, 1, 10, 1],
     // valid_map is [0, -1, 1, -1] and generate simplified
     // dims as [32, 10]
-    size_t valid_dim_idx = 0;
+    int valid_dim_idx = 0;
     bool sequential_flag = false;
-    for (size_t i = 0; i < rank; ++i) {
+    for (auto i = 0; i < rank; ++i) {
       const int src_dim = combined_dims[i];
       if (src_dim == 1) {
         valid_map[i] = -1;
       } else {
         sequential_flag = true;
         valid_map[i] = valid_dim_idx;
-        dims_[valid_dim_idx] = src_dim;
+        src_dims[valid_dim_idx] = src_dim;
         valid_dim_idx += 1;
       }
     }
 
     if (valid_dim_idx == 0) {
-      dims_[0] = 1;
+      src_dims[0] = 1;
       perm_[0] = 0;
       return;
     } else if (valid_dim_idx == 1) {
@@ -169,8 +172,8 @@
 
     // Acquire simplified perm with help of combined dims
     // and original perm, finally simplified perm is [1, 0]
-    size_t perm_idx = 0;
-    for (size_t i = 0; i < rank; ++i) {
+    int perm_idx = 0;
+    for (auto i = 0; i < rank; ++i) {
       const int mapped = valid_map[perm[i]];
       if (mapped >= 0) {
         perm_[perm_idx] = mapped;
@@ -183,20 +186,17 @@
   int GetPermVecSize(const int sm_count, const T* src, T* dst) {
     // For gerneal_permute kernel, there is good chance for
    // vectorized write.
+    type_ = PermuteType::kGeneralPermute;
     int vec_size = phi::GetVectorizedSize<T>(dst);
-    type_ = PermuteType::kNormalPermute;
 
     // While the last dim is fixed, there is good chance for
     // both vectorized read and write.
     if (perm_[rank_ - 1] == rank_ - 1) {
       int tmp_size = std::min(vec_size, phi::GetVectorizedSize<T>(src));
-      tmp_size = GetDimVesSize(tmp_size, dims_[rank_ - 1]);
+      tmp_size = GetDimVesSize(tmp_size, src_dims[rank_ - 1]);
       if (tmp_size > 1) {
         type_ = kVecPermute;
         vec_size = tmp_size;
-
-        // For stride calculation of src_data index.
-        dims_[rank_ - 1] /= vec_size;
       }
     }
 
@@ -205,31 +205,11 @@
     if ((rank_ == 2 && perm_[1] == 0 && perm_[0] == 1) ||
         (rank_ == 3 && perm_[2] == 1 && perm_[1] == 2)) {
       type_ = PermuteType::kTranspose;
-
-      // Compared with vectorized load or read, set config to let more
-      // sm work simultaneously affect more according to performance.
-      constexpr int threads = kTileSize * kTileSize;
-      int blocks = count_ / threads;
-      if (blocks < sm_count) {
-        vec_size = 1;
-      } else {
-        int tmp_vec = std::min(vec_size, phi::GetVectorizedSize<T>(src));
-        // With bytes limitation of shared_memory, the VecSize shall be
-        // restricted for the type whose byte-size is less than 8 (double).
-        int type_vec =
-            sizeof(T) > 8 ? 1 : GetDimVesSize(tmp_vec, dims_[rank_ - 1]);
-        for (int i = type_vec; i > 0; i /= 2) {
-          if (blocks / i >= sm_count) {
-            break;
-          }
-          // When blocks is smaller than sm_count, a test shown that decrease
-          // vec_size to make blocks close to sm_count would gain performance.
-          vec_size = i;
-        }
-      }
-
-      dims_[rank_ - 1] /= vec_size;
-      count_ /= vec_size;
+      int tmp_vec = std::min(vec_size, phi::GetVectorizedSize<T>(src));
+      // With bytes limitation of shared_memory, the VecSize shall be
+      // restricted for the type whose byte-size is less than 8 (double).
+      vec_size =
+          sizeof(T) > 8 ? 1 : GetDimVesSize(tmp_vec, src_dims[rank_ - 1]);
     }
     return vec_size;
   }
diff --git a/paddle/phi/kernels/autotune/auto_tune_base.h b/paddle/phi/kernels/autotune/auto_tune_base.h
index ff97b2a1f48f4bf046fe0e8b4728a321d1a62336..d9f96ec2328f445442ffe6246d58214472ae3173 100644
--- a/paddle/phi/kernels/autotune/auto_tune_base.h
+++ b/paddle/phi/kernels/autotune/auto_tune_base.h
@@ -123,7 +123,7 @@ class AutoTuneBase {
   float RunAndMeasureKernel(const Context& ctx, const int idx, Args&&... args) {
     // Regard 1st run as warmup, judge the compare result by the time cost
     // of rest cycles.
-    constexpr int repeats = 3;
+    constexpr int repeats = 4;
     phi::GpuTimer timer;
     float time_cost = 0;
     const auto& stream = ctx.stream();
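
For reference, below are two small standalone sketches of the techniques this patch relies on. They are illustrative only: the names (TiledTranspose, kTile, kBlockRowsEx, SimplifyDims) and the helper shapes are assumptions made for the example, not code from this patch; the real BatchTransposeKernel above additionally vectorizes its loads, batches over blockIdx.z, and trims boundary tiles with round_tile_rows/round_tile_cols.

First, a minimal CUDA sketch of the shared-memory tiled transpose pattern from the NVIDIA blog post cited in the diff, including the +1 column of padding that BatchTransposeKernel also uses to avoid shared-memory bank conflicts:

// Illustrative sketch only, not the kernel from this patch.
constexpr int kTile = 32;        // tile width, same value as kTileSize above
constexpr int kBlockRowsEx = 8;  // rows handled per iteration of the copy loop

template <typename T>
__global__ void TiledTranspose(const T* __restrict__ src,
                               T* __restrict__ dst,
                               int rows,
                               int cols) {
  // +1 column of padding so that reading the tile column-wise does not make
  // all threads of a warp hit the same shared-memory bank.
  __shared__ T tile[kTile][kTile + 1];

  int col = blockIdx.x * kTile + threadIdx.x;
  for (int i = threadIdx.y; i < kTile; i += kBlockRowsEx) {
    int row = blockIdx.y * kTile + i;
    if (row < rows && col < cols) {
      tile[i][threadIdx.x] = src[row * cols + col];
    }
  }
  __syncthreads();

  // Write the transposed tile: the roles of blockIdx.x and blockIdx.y swap.
  col = blockIdx.y * kTile + threadIdx.x;
  for (int i = threadIdx.y; i < kTile; i += kBlockRowsEx) {
    int row = blockIdx.x * kTile + i;
    if (row < cols && col < rows) {
      dst[row * rows + col] = tile[threadIdx.x][i];
    }
  }
}

// Assumed host-side launch for the sketch:
//   dim3 grid((cols + kTile - 1) / kTile, (rows + kTile - 1) / kTile);
//   dim3 block(kTile, kBlockRowsEx);
//   TiledTranspose<float><<<grid, block, 0, stream>>>(src, dst, rows, cols);

Second, a host-side sketch of the dim-merging rule documented in SimplifyPermAndDims (runs of consecutive dims in perm are fused, then size-1 dims are dropped and the permutation is remapped), reproducing the worked example from the comment:

// Illustrative re-implementation of the merging rule, not the class method.
#include <cstdint>
#include <cstdio>
#include <vector>

void SimplifyDims(std::vector<int64_t>* dims, std::vector<int>* perm) {
  const int rank = static_cast<int>(dims->size());
  std::vector<int64_t> combined(*dims);
  // Fuse each run of consecutive source dims into the first dim of the run.
  for (int i = 0; i < rank;) {
    int j = i + 1;
    while (j < rank && (*perm)[j] == (*perm)[j - 1] + 1) {
      combined[(*perm)[i]] *= combined[(*perm)[j]];
      combined[(*perm)[j]] = 1;
      ++j;
    }
    i = j;
  }
  // Drop the dims that became 1 and remap the permutation onto the survivors.
  std::vector<int> valid_map(rank, -1);
  std::vector<int64_t> out_dims;
  for (int i = 0; i < rank; ++i) {
    if (combined[i] != 1) {
      valid_map[i] = static_cast<int>(out_dims.size());
      out_dims.push_back(combined[i]);
    }
  }
  std::vector<int> out_perm;
  for (int i = 0; i < rank; ++i) {
    if (valid_map[(*perm)[i]] >= 0) out_perm.push_back(valid_map[(*perm)[i]]);
  }
  *dims = out_dims;
  *perm = out_perm;
}

int main() {
  // The example from the comment: dims [4, 8, 2, 5] with perm [2, 3, 0, 1]
  // fuse to [32, 1, 10, 1], i.e. dims [32, 10] and perm [1, 0]: a plain 2-D
  // transpose, which is why such a case can be routed to the tiled kernel.
  std::vector<int64_t> dims = {4, 8, 2, 5};
  std::vector<int> perm = {2, 3, 0, 1};
  SimplifyDims(&dims, &perm);
  std::printf("dims: %lld %lld, perm: %d %d\n",
              static_cast<long long>(dims[0]),
              static_cast<long long>(dims[1]),
              perm[0],
              perm[1]);  // prints: dims: 32 10, perm: 1 0
}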