From a0f4388946535070d8cde2f6dfbf17ef7bf61c4c Mon Sep 17 00:00:00 2001 From: limingshu <61349199+JamesLim-sy@users.noreply.github.com> Date: Mon, 5 Dec 2022 21:28:56 +0800 Subject: [PATCH] Transpose optimization for AlphaFold2 (#45230) * first commit * fix bugs according to ci * add some changes * change file name into function.cu.h * remove const_cast --- paddle/fluid/operators/fused/fmha_ref.h | 2 +- .../operators/fused/fused_gate_attention.h | 2 +- paddle/phi/kernels/funcs/dims_simplifier.h | 101 +++ ...e_functor.cu.h => transpose_function.cu.h} | 816 +++++++++++------- paddle/phi/kernels/funcs/transpose_functor.h | 216 ++--- paddle/phi/kernels/gpu/transpose_kernel.cu | 2 +- 6 files changed, 693 insertions(+), 446 deletions(-) rename paddle/phi/kernels/funcs/{transpose_functor.cu.h => transpose_function.cu.h} (59%) diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index fc5f9cf71d..11939a454b 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -21,7 +21,7 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/functors.h" -#include "paddle/phi/kernels/funcs/transpose_functor.cu.h" +#include "paddle/phi/kernels/funcs/transpose_function.cu.h" #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" namespace paddle { diff --git a/paddle/fluid/operators/fused/fused_gate_attention.h b/paddle/fluid/operators/fused/fused_gate_attention.h index 1fba366ad2..d55d047009 100644 --- a/paddle/fluid/operators/fused/fused_gate_attention.h +++ b/paddle/fluid/operators/fused/fused_gate_attention.h @@ -18,7 +18,7 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/reduce_function.h" -#include "paddle/phi/kernels/funcs/transpose_functor.cu.h" +#include "paddle/phi/kernels/funcs/transpose_function.cu.h" #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" namespace paddle { diff --git a/paddle/phi/kernels/funcs/dims_simplifier.h b/paddle/phi/kernels/funcs/dims_simplifier.h index 21f14bdba7..0ef0f6ac5b 100644 --- a/paddle/phi/kernels/funcs/dims_simplifier.h +++ b/paddle/phi/kernels/funcs/dims_simplifier.h @@ -243,5 +243,106 @@ struct BroadcastDimsSimplifier { } }; +// Simplify the input dims and permute dims if possible. 
+struct DimsSimplifier { + public: + explicit DimsSimplifier(const int rank, + const int64_t numel, + const std::vector &perm, + const std::vector &dims) + : perm_(rank), src_dims_(rank), count_(numel) { + SimplifyPermAndDims(rank, dims, perm); + perm_.resize(rank_); + src_dims_.resize(rank_); + dst_dims_.resize(rank_); + if (!is_seq_perm_) { + for (auto i = 0; i < rank_; ++i) { + dst_dims_[i] = src_dims_[perm_[i]]; + } + } else { + dst_dims_[0] = numel; + src_dims_[0] = numel; + } + } + + ~DimsSimplifier() = default; + + const int &GetRank() const { return rank_; } + const int64_t &GetCount() const { return count_; } + const std::vector &GetPerm() const { return perm_; } + const std::vector &GetSrcDims() const { return src_dims_; } + const std::vector &GetDstDims() const { return dst_dims_; } + + private: + int rank_{1}; + int64_t count_{0}; + bool is_seq_perm_{true}; + std::vector perm_; + std::vector src_dims_; + std::vector dst_dims_; + + void SimplifyPermAndDims(const int rank, + const std::vector &in_dims, + const std::vector &perm) { + int start_perm_idx = 0; + int valid_dim_idx = 0; + int valid_map[phi::DDim::kMaxRank]; + int64_t combined_dims[phi::DDim::kMaxRank]; + + // Merge consecutive dims to the fist one dim and + // leave original dim to be 1. Example below : + // perm: [2, 3, 0, 1], origin_dims : [4, 8, 2, 5] + // new_dims: [4, 8, 2, 5] -> [32, 1, 10, 1] + while (start_perm_idx < rank) { + const int start_dim_idx = perm[start_perm_idx]; + combined_dims[start_dim_idx] = in_dims[start_dim_idx]; + int end_perm_idx = start_perm_idx + 1; + + while (end_perm_idx < rank && + perm[end_perm_idx] == perm[end_perm_idx - 1] + 1) { + const int end_dim_idx = perm[end_perm_idx]; + combined_dims[start_dim_idx] *= in_dims[end_dim_idx]; + combined_dims[end_dim_idx] = 1; + end_perm_idx += 1; + } + start_perm_idx = end_perm_idx; + } + + // Reorder combined dims and marked useless dim as -1. + // for example, if combined dims is [32, 1, 10, 1], + // valid_map is [0, -1, 1, -1] and generate simplified + // dims as [32, 10] + for (auto i = 0; i < rank; ++i) { + const int dim_val = combined_dims[i]; + if (dim_val == 1) { + valid_map[i] = -1; + } else { + valid_map[i] = valid_dim_idx; + src_dims_[valid_dim_idx] = dim_val; + valid_dim_idx += 1; + } + } + + if (valid_dim_idx == 0) { + src_dims_[0] = 1; + perm_[0] = 0; + return; + } + + // Acquire simplified perm with help of combined dims + // and original perm, finally simplified perm is [1, 0] + int perm_idx = 0; + for (auto i = 0; i < rank; ++i) { + const int mapped = valid_map[perm[i]]; + if (mapped >= 0) { + perm_[perm_idx] = mapped; + is_seq_perm_ &= (mapped == perm_idx); + perm_idx += 1; + } + } + rank_ = is_seq_perm_ ? 1 : valid_dim_idx; + } +}; + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/transpose_functor.cu.h b/paddle/phi/kernels/funcs/transpose_function.cu.h similarity index 59% rename from paddle/phi/kernels/funcs/transpose_functor.cu.h rename to paddle/phi/kernels/funcs/transpose_function.cu.h index 8dae6ab60e..d4e36745a4 100644 --- a/paddle/phi/kernels/funcs/transpose_functor.cu.h +++ b/paddle/phi/kernels/funcs/transpose_function.cu.h @@ -19,6 +19,7 @@ limitations under the License. 
*/ #include "paddle/phi/backends/gpu/gpu_utils.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/autotune/auto_tune_base.h" +#include "paddle/phi/kernels/funcs/dims_simplifier.h" #include "paddle/phi/kernels/funcs/transpose_functor.h" #include "paddle/phi/kernels/primitive/datamover_primitives.h" @@ -191,7 +192,6 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input, IndexType output_origin_block_flat_index = FlatTensorIndex(block_tile_index_in_output, output_dims); - constexpr IndexType out_effective_thread_num = NumThreads / TileX * TileX; if (x < out_effective_thread_num) { @@ -652,7 +652,7 @@ struct SwapDim0And2InTranspose { inline void CombineTransposeDim3(const DDim& shape, const std::vector& perm, std::vector* new_perm, - DDim* new_dims) { + std::vector* new_dims) { PADDLE_ENFORCE_EQ(shape.size(), perm.size(), phi::errors::InvalidArgument( @@ -667,114 +667,111 @@ inline void CombineTransposeDim3(const DDim& shape, new_perm->resize(1); (*new_perm)[0] = perm[0]; dim_vec.push_back(shape[0]); - *new_dims = phi::make_ddim(dim_vec); - return; - } - std::vector new_dim_pos(shape.size(), -1); - std::vector combined_dims(shape.size(), 0); - int cur_head = perm[0]; - new_dim_pos[cur_head] = 0; - combined_dims[0] = shape[cur_head]; - int dim_idx = 0; - for (int perm_idx = 1; perm_idx < shape.size(); ++perm_idx) { - // combine consecutive dimensions. - if (cur_head + 1 == perm[perm_idx]) { - cur_head = perm[perm_idx]; - combined_dims[dim_idx] *= shape[cur_head]; - } else { - // Else start a new dimension. - cur_head = perm[perm_idx]; - dim_idx++; - new_dim_pos[cur_head] = dim_idx; - combined_dims[dim_idx] = shape[cur_head]; + } else { + int dim_idx = 0; + std::vector new_dim_pos(shape.size(), -1); + std::vector combined_dims(shape.size(), 0); + + int cur_head = perm[0]; + new_dim_pos[cur_head] = 0; + combined_dims[0] = shape[cur_head]; + for (int perm_idx = 1; perm_idx < shape.size(); ++perm_idx) { + // combine consecutive dimensions. + if (cur_head + 1 == perm[perm_idx]) { + cur_head = perm[perm_idx]; + combined_dims[dim_idx] *= shape[cur_head]; + } else { + // Else start a new dimension. + cur_head = perm[perm_idx]; + dim_idx++; + new_dim_pos[cur_head] = dim_idx; + combined_dims[dim_idx] = shape[cur_head]; + } } - } - - new_perm->resize(dim_idx + 1); - - dim_idx = 0; - for (int i = 0; i < new_dim_pos.size(); ++i) { - if (new_dim_pos[i] >= 0) { - int new_perm_idx = new_dim_pos[i]; - (*new_perm)[dim_idx] = new_perm_idx; - dim_vec.push_back(combined_dims[new_perm_idx]); - dim_idx++; + new_perm->resize(dim_idx + 1); + + dim_idx = 0; + for (int i = 0; i < new_dim_pos.size(); ++i) { + if (new_dim_pos[i] >= 0) { + int new_perm_idx = new_dim_pos[i]; + (*new_perm)[dim_idx] = new_perm_idx; + dim_vec.push_back(combined_dims[new_perm_idx]); + dim_idx++; + } } } - - *new_dims = phi::make_ddim(dim_vec); + *new_dims = dim_vec; } -template +template struct TransposeSimple { - static bool run(const phi::GPUContext& ctx, + static bool Impl(const phi::GPUContext& ctx, + const phi::DenseTensor& in, + const std::vector perm, + phi::DenseTensor* out, + const int64_t numel) { + if (numel >= std::numeric_limits::max()) { + return Run(ctx, in, perm, out); + } else { + return Run(ctx, in, perm, out); + } + } + + private: + template + static bool Run(const phi::GPUContext& ctx, const phi::DenseTensor& in, const std::vector perm, phi::DenseTensor* out) { // First reduce the dimensions of the input tensor if possible. 
+ auto in_data = in.data(); + auto out_data = out->data(); std::vector new_perm; - DDim new_dims; + std::vector new_dims; CombineTransposeDim3(in.dims(), perm, &new_perm, &new_dims); + if (new_perm.size() < 2 || new_perm.size() > 3) return false; - // Only use tile copy GPU kernel when dimension is 2 or 3. - int dims = new_dims.size(); - std::vector new_dim_vec = phi::vectorize(new_dims); - if (dims < 2 || dims > 3) return false; - auto in_data = in.data(); - auto out_data = out->data(); // In most cases, dim will not greater than 3 after combine. - switch (dims) { - case 2: - if (new_perm[0] == 1 && new_perm[1] == 0) { - // Add the first dimension size as 1. - new_dim_vec.insert(new_dim_vec.begin(), 1); - SwapDim1And2InTranspose()( - ctx, in_data, new_dim_vec, out_data); - return true; - } - break; - case 3: - // In this case, suppose we can do coalescing read and write in tile. - if (new_perm == std::vector({0, 2, 1})) { - SwapDim1And2InTranspose()( - ctx, in_data, new_dim_vec, out_data); - return true; - } else if (new_perm == std::vector({2, 1, 0})) { - // Maybe can optimize later, find a way to do coalescing memory copy. - // But I think it depends on the data size. If span is not large, - // maybe - // can do coalescing. - SwapDim0And2InTranspose()( - ctx, in_data, new_dim_vec, out_data); - return true; - } else { - return false; - } - break; - default: - return false; + if (new_perm.size() == 2 && new_perm[1] == 0) { + // Add the first dimension size as 1. + new_dims.insert(new_dims.begin(), 1); + SwapDim1And2InTranspose()(ctx, in_data, new_dims, out_data); + return true; + } else if (new_perm == std::vector({0, 2, 1})) { + SwapDim1And2InTranspose()(ctx, in_data, new_dims, out_data); + return true; + } else if (new_perm == std::vector({2, 1, 0})) { + // Maybe can optimize later, find a way to do coalescing memory copy. + // But I think it depends on the data size. If span is not large, + // maybe can do coalescing. + SwapDim0And2InTranspose()(ctx, in_data, new_dims, out_data); + return true; + } else { + return false; } - return false; } }; -template +template class IdxHelper { public: IdxHelper() {} - explicit IdxHelper(const T* dims) { + explicit IdxHelper(const IndexT* dims) { for (int i = N - 1; i >= 0; --i) { stride_[i] = i < (N - 1) ? 
dims[i + 1] * stride_[i + 1] : 1; } } - __device__ inline T GetStride(int idx) const { return stride_[idx]; } + __device__ __forceinline__ IndexT GetStride(int idx) const { + return stride_[idx]; + } - __device__ inline void GetIndexFromOffset(T offset, T* index) const { - T remaining = offset; + __device__ __forceinline__ void GetIndexFromOffset(IndexT offset, + IndexT* index) const { + IndexT remaining = offset; #pragma unroll for (int i = 0; i < N - 1; ++i) { - const T idx = remaining / stride_[i]; + const IndexT idx = remaining / stride_[i]; remaining -= idx * stride_[i]; index[i] = idx; } @@ -782,11 +779,11 @@ class IdxHelper { } private: - T stride_[N]; + IndexT stride_[N]; }; template -class IdxHelper { +class IdxHelper { public: IdxHelper() {} explicit IdxHelper(const uint32_t* dims) { @@ -797,10 +794,12 @@ class IdxHelper { } } - __device__ inline uint32_t GetStride(int idx) const { return stride_[idx]; } + __device__ __forceinline__ uint32_t GetStride(int idx) const { + return stride_[idx]; + } - __device__ inline void GetIndexFromOffset(uint32_t offset, - uint32_t* index) const { + __device__ __forceinline__ void GetIndexFromOffset(uint32_t offset, + uint32_t* index) const { uint32_t remaining = offset; #pragma unroll for (int i = 0; i < N - 1; ++i) { @@ -817,18 +816,16 @@ class IdxHelper { }; // Transform index between memory offset and shape coodinate. -template +template class IdxAndOffsetHelper { public: IdxAndOffsetHelper() {} - ~IdxAndOffsetHelper() = default; - - explicit IdxAndOffsetHelper(const T* dims) { - index_helper = IdxHelper(dims); + explicit IdxAndOffsetHelper(const IndexT* dims) { + index_helper = IdxHelper(dims); } - __device__ inline T IndexToOffset(const T* index) const { - T offset = 0; + __device__ __forceinline__ IndexT IndexToOffset(const IndexT* index) const { + IndexT offset = 0; #pragma unroll for (int i = 0; i < N - 1; ++i) { offset += index[i] * index_helper.GetStride(i); @@ -837,15 +834,16 @@ class IdxAndOffsetHelper { return offset; } - __device__ inline void OffsetToIndex(T offset, T* index) const { + __device__ __forceinline__ void OffsetToIndex(IndexT offset, + IndexT* index) const { index_helper.GetIndexFromOffset(offset, index); } private: - IdxHelper index_helper; + IdxHelper index_helper; }; -template +template struct PermuteParams { public: IdxAndOffsetHelper src_index_helper; @@ -868,17 +866,17 @@ struct PermuteParams { // A special kernel for target case, both vectorized read and write supported. template -__global__ void VectorizedPermuteKernel(PermuteParams params, - const size_t count, +__global__ void VectorizedPermuteKernel(PermuteParams params, + const IndexT count, const T* __restrict__ src_data, T* dst_data) { using VecT = phi::AlignedVector; IndexT src_index[Rank]; IndexT dst_index[Rank]; - const VecT* __restrict__ src = + const VecT* __restrict__ vec_src = reinterpret_cast(src_data); - VecT* dst = reinterpret_cast(dst_data); + VecT* vec_dst = reinterpret_cast(dst_data); IndexT tid = blockIdx.x * blockDim.x + threadIdx.x; for (IndexT i = tid; i < count; i += blockDim.x * gridDim.x) { @@ -889,31 +887,23 @@ __global__ void VectorizedPermuteKernel(PermuteParams params, src_index[params.perm[j]] = dst_index[j]; } IndexT src_offset = params.src_index_helper.IndexToOffset(src_index); - dst[i] = src[src_offset]; + vec_dst[i] = vec_src[src_offset]; } } // A general kernel for normal case, only support vectorized write. 
template -__global__ void GeneralPermuteKernel(PermuteParams params, +__global__ void GeneralPermuteKernel(PermuteParams params, + const IndexT main_cnt, + const IndexT tail_cnt, + const IndexT offset, const T* __restrict__ src, - T* dst, - const size_t main_cnt, - const size_t tail_cnt, - const size_t offset) { + T* dst) { using VecT = phi::AlignedVector; VecT* vec_dst = reinterpret_cast(dst); - IndexT src_index[VecSize][Rank]; IndexT dst_index[VecSize][Rank]; - // Avoid read perm data both in 2 load process. - __shared__ int perm[Rank]; - if (threadIdx.x < Rank) { - perm[threadIdx.x] = params.perm[threadIdx.x]; - } - __syncthreads(); - // Vectorized load data. IndexT tid = blockIdx.x * blockDim.x + threadIdx.x; for (IndexT idx = tid; idx < main_cnt; idx += blockDim.x * gridDim.x) { @@ -926,7 +916,7 @@ __global__ void GeneralPermuteKernel(PermuteParams params, #pragma unroll for (int j = 0; j < Rank; ++j) { - src_index[i][perm[j]] = dst_index[i][j]; + src_index[i][params.perm[j]] = dst_index[i][j]; } IndexT src_offset = params.src_index_helper.IndexToOffset(src_index[i]); vec_data[i] = src[src_offset]; @@ -941,235 +931,441 @@ __global__ void GeneralPermuteKernel(PermuteParams params, #pragma unroll for (int j = 0; j < Rank; ++j) { - src_index[0][perm[j]] = dst_index[0][j]; + src_index[0][params.perm[j]] = dst_index[0][j]; } IndexT src_offset = params.src_index_helper.IndexToOffset(src_index[0]); dst[idx] = src[src_offset]; } } -// A Gerneral permute method that drectly find the dst data -// coordinate in the source data. -template -inline void LaunchPermuteKernel(const phi::GPUContext& ctx, - const IndexT count, - const PermuteType perm_type, - const std::vector& dims, - const std::vector& perm, - const T* src, - T* dst) { - size_t main_count = count / VecSize; - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, main_count); - - if (perm_type == PermuteType::kGeneralPermute) { - size_t tail_count = count - main_count * VecSize; - size_t offset = count - tail_count; - auto params = PermuteParams(dims, perm); - - GeneralPermuteKernel - <<>>( - params, src, dst, main_count, tail_count, offset); - } else { - std::vector vec_dims(dims); - vec_dims[dims.size() - 1] /= VecSize; - auto params = PermuteParams(vec_dims, perm); - - VectorizedPermuteKernel - <<>>( - params, main_count, src, dst); +template +struct TransposeDataWriter { + __device__ __forceinline__ void operator()(T* dst_data, + const T* s_data, + const IndexT rows, + const IndexT cols, + const IndexT chs_stride, + const IndexT round_tile_cols, + const IndexT col_stride = 1) { + using OutVecT = phi::AlignedVector; + OutVecT* vec_dst = reinterpret_cast(dst_data); + + constexpr int kColTile = kTileSize * ReadSize; + constexpr int kColStride = kShareCol * ReadSize; + + const IndexT vec_rows = rows / WriteSize; + const IndexT col_in_mat = blockIdx.y * kTileSize + threadIdx.x; + + if (col_in_mat < /*dst_cols=*/vec_rows) { + const int cols_range = (blockIdx.x < round_tile_cols) + ? 
kTileSize + : (cols - round_tile_cols * kTileSize); + const int share_tile = threadIdx.x * (WriteSize * kColStride); + const IndexT write_offset = blockIdx.z * chs_stride + col_in_mat; +#pragma unroll + for (int tile_y = threadIdx.y; tile_y < cols_range; + tile_y += kBlockRows) { + OutVecT tmp_data[ReadSize]; +#pragma unroll + for (int i = 0; i < ReadSize; ++i) { + int tile_tail = tile_y * ReadSize + i; + int major_share_idx = share_tile + tile_tail; + IndexT row_in_mat = (blockIdx.x * kColTile + tile_tail) * col_stride; +#pragma unroll + for (int j = 0; j < WriteSize; ++j) { + tmp_data[i].val[j] = s_data[j * kColStride + major_share_idx]; + } + vec_dst[write_offset + row_in_mat * vec_rows] = tmp_data[i]; + } + } + } } -} +}; -template -inline void LaunchPermuteRankDispatch(const phi::GPUContext& ctx, - const IndexT count, - const PermuteType perm_type, - const std::vector& dims, - const std::vector& perm, - const T* src, - T* dst) { -#define CALL_DISPATCH_RANK(rank) \ - case rank: { \ - LaunchPermuteKernel( \ - ctx, count, perm_type, dims, perm, src, dst); \ - break; \ +template +struct TransposeDataWriter { + __device__ __forceinline__ void operator()(T* dst_data, + const T* s_data, + const IndexT rows, + const IndexT cols, + const IndexT chs_stride, + const IndexT round_tile_cols, + const IndexT col_stride = 1) { + const IndexT col_in_mat = blockIdx.y * kTileSize + threadIdx.x; + if (col_in_mat < /*dst_cols=*/rows) { + const int cols_range = (blockIdx.x < round_tile_cols) + ? kTileSize + : (cols - round_tile_cols * kTileSize); + const IndexT row_tile = blockIdx.x * kTileSize * ReadSize; + const IndexT write_offset = blockIdx.z * chs_stride + col_in_mat; + const int shared_tile = threadIdx.x * kShareCol * ReadSize; +#pragma unroll + for (int tile_y = threadIdx.y; tile_y < cols_range; + tile_y += kBlockRows) { + const int shared_major = shared_tile + tile_y * ReadSize; + const IndexT row_major = (row_tile + tile_y * ReadSize) * col_stride; +#pragma unroll + for (int i = 0; i < ReadSize; ++i) { + const IndexT row_in_mat = row_major + i * col_stride; + dst_data[write_offset + row_in_mat * rows] = s_data[shared_major + i]; + } + } + } } +}; - switch (dims.size()) { - CALL_DISPATCH_RANK(1); - CALL_DISPATCH_RANK(2); - CALL_DISPATCH_RANK(3); - CALL_DISPATCH_RANK(4); - CALL_DISPATCH_RANK(5); - CALL_DISPATCH_RANK(6); - CALL_DISPATCH_RANK(7); - CALL_DISPATCH_RANK(8); - CALL_DISPATCH_RANK(9); +template +struct TransposeDataReader { + __device__ __forceinline__ void operator()(const T* __restrict__ src, + T* s_shared, + const IndexT cols, + const IndexT rows, + const IndexT chs_stride, + const IndexT cols_thresh, + const IndexT round_tile_rows) { + using VecT = phi::AlignedVector; + const VecT* __restrict__ v_src = + reinterpret_cast(src); + VecT* v_shared = reinterpret_cast(s_shared); + + const IndexT col_in_mat = blockIdx.x * kTileSize + threadIdx.x; + if (col_in_mat < cols_thresh) { + const int row_range = (blockIdx.y < round_tile_rows) + ? RowTile + : (rows - RowTile * round_tile_rows); + const IndexT src_idx_major = blockIdx.z * chs_stride + col_in_mat; +#pragma unroll + for (int tile_y = threadIdx.y; tile_y < row_range; tile_y += kBlockRows) { + const IndexT row_in_mat = blockIdx.y * RowTile + tile_y; + v_shared[tile_y * kShareCol + threadIdx.x] = + v_src[row_in_mat * cols + src_idx_major]; + } + } + __syncthreads(); } -#undef CALL_DISPATCH_RANK -} +}; // Aim at transposing the last 2 dimensions. 
Reference from // https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/ -template +template +__global__ void SwapTransposeKernel(const T* __restrict__ src_data, + T* dst_data, + const IndexT round_tile_rows, + const IndexT round_tile_cols, + const IndexT cols, + const IndexT rows, + const IndexT chs /*=channel*/) { + constexpr int kRowTile = kTileSize * WriteSize; + __shared__ T s_data[kRowTile * kShareCol * ReadSize]; + + const IndexT chs_stride = chs * cols; + TransposeDataReader()( + src_data, s_data, chs_stride, rows, cols, cols, round_tile_rows); + TransposeDataWriter()( + dst_data, s_data, rows, cols, rows / WriteSize, round_tile_cols, chs); +} + +template __global__ void BatchTransposeKernel(const T* __restrict__ src_data, T* dst_data, - IndexT rows, - IndexT cols, - IndexT round_tile_rows, - IndexT round_tile_cols) { - using VecT = phi::AlignedVector; - constexpr int kShareCol = kTileSize + 1; - __shared__ VecT v_shared[kTileSize * kShareCol]; - T* s_shared = reinterpret_cast(v_shared); - - // Vectorized load data from src into shared memory. [rows, cols] - const VecT* __restrict__ vec_src = - reinterpret_cast(src_data); + const IndexT round_tile_rows, + const IndexT round_tile_cols, + const IndexT cols, + const IndexT rows) { + constexpr int kRowTile = kTileSize * WriteSize; + __shared__ T s_data[kRowTile * kShareCol * ReadSize]; + + const IndexT chs_stride = rows * cols; + TransposeDataReader()( + src_data, s_data, cols, rows, chs_stride, cols, round_tile_rows); + TransposeDataWriter()( + dst_data, + s_data, + rows, + cols, + chs_stride * ReadSize / WriteSize, + round_tile_cols); +} - IndexT col_in_matrix = blockIdx.x * kTileSize + threadIdx.x; - IndexT offset = blockIdx.z * rows * cols; +template +struct PermuteLauncher { + public: + void operator()(const phi::GPUContext& ctx, + const int& rank, + const IndexT& count, + const PermuteType& perm_type, + const std::vector& dims, + const std::vector& perm, + const T* src, + T* dst) { + dims_ = dims; + main_cnt_ = count / VecSize; +#define CALL_PERMUTE_DISPATCH_RANK(rank_) \ + case rank_: { \ + Run(ctx, perm, perm_type, count, src, dst); \ + break; \ + } - if (col_in_matrix < cols) { - int row_range = (blockIdx.y < round_tile_rows) - ? 
kTileSize - : (rows - kTileSize * round_tile_rows); -#pragma unroll - for (int tile_y = threadIdx.y; tile_y < row_range; tile_y += kBlockRows) { - IndexT row_in_matrix = tile_y + blockIdx.y * kTileSize; - v_shared[tile_y * kShareCol + threadIdx.x] = - vec_src[offset + row_in_matrix * cols + col_in_matrix]; + switch (rank) { + CALL_PERMUTE_DISPATCH_RANK(3); + CALL_PERMUTE_DISPATCH_RANK(4); + CALL_PERMUTE_DISPATCH_RANK(5); + CALL_PERMUTE_DISPATCH_RANK(6); + CALL_PERMUTE_DISPATCH_RANK(7); + CALL_PERMUTE_DISPATCH_RANK(8); + CALL_PERMUTE_DISPATCH_RANK(9); } +#undef CALL_PERMUTE_DISPATCH_RANK } - // Write data from shared memory into dst and - // dst_cols = rows, dst_rows = cols * Vecsize - col_in_matrix = blockIdx.y * kTileSize + threadIdx.x; - offset = offset * VecSize + col_in_matrix; - __syncthreads(); + private: + IndexT main_cnt_{0}; + std::vector dims_; + + template + void Run(const phi::GPUContext& ctx, + const std::vector& perm, + const PermuteType& perm_type, + const IndexT& count, + const T* src, + T* dst) { + auto cfg = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, main_cnt_); + if (perm_type == PermuteType::kVecPermute) { + dims_[Rank - 1] /= VecSize; + const auto params = PermuteParams(dims_, perm); + + VectorizedPermuteKernel + <<>>( + params, main_cnt_, src, dst); + } else { + IndexT tail_cnt = count - main_cnt_ * VecSize; + IndexT main_offset = count - tail_cnt; + const auto params = PermuteParams(dims_, perm); - if (col_in_matrix < /*dst_cols=*/rows) { - int col_range = (blockIdx.x < round_tile_cols) - ? kTileSize - : (cols - kTileSize * round_tile_cols); -#pragma unroll - for (IndexT tile_y = threadIdx.y; tile_y < col_range; - tile_y += kBlockRows) { -#pragma unroll - for (int i = 0; i < VecSize; ++i) { - IndexT row_in_matrix = (tile_y + blockIdx.x * kTileSize) * VecSize + i; - IndexT shared_idx = (tile_y + threadIdx.x * kShareCol) * VecSize + i; - dst_data[offset + row_in_matrix * rows] = s_shared[shared_idx]; + GeneralPermuteKernel + <<>>( + params, main_cnt_, tail_cnt, main_offset, src, dst); + } + } +}; + +template +struct TransposeLauncher { + public: + void operator()(const phi::GPUContext& ctx, + const int& rank, + const PermuteType& perm_type, + const std::vector& dims, + const IndexT& num_rows_tile, + const T* src, + T* dst) { + constexpr int ReadSize = sizeof(T) > sizeof(float) ? 1 : VecSize; + const IndexT cols = dims[rank - 1] / VecSize; + const IndexT n_cols_tile = GETTILESIZE(cols, kTileSize); + + if (perm_type == PermuteType::kGeneralTranspose) { + IndexT chs = (rank == 2) ? 
1 : dims[0]; + IndexT rows = dims[rank - 2]; + IndexT n_rows_tile = + FindRowTiles(chs, rows, num_rows_tile, n_cols_tile, ctx.GetSMCount()); + dim3 blocks(n_cols_tile, n_rows_tile, chs); + dim3 threads(kTileSize, kBlockRows, 1); + + if (is_vec_write) { + BatchTransposeKernel + <<>>( + src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows); + } else { + BatchTransposeKernel + <<>>( + src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows); + } + } else { + IndexT rows = dims[0]; + IndexT chs = dims[rank - 2]; + IndexT n_rows_tile = + FindRowTiles(chs, rows, num_rows_tile, n_cols_tile, ctx.GetSMCount()); + dim3 blocks(n_cols_tile, n_rows_tile, chs); + dim3 threads(kTileSize, kBlockRows, 1); + + if (is_vec_write) { + SwapTransposeKernel + <<>>( + src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows, chs); + } else { + SwapTransposeKernel + <<>>( + src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows, chs); } } } -} -// With the byte limitation of shared_memory, the VecSize shall be -// restricted for the type whose byte-size is less than 4. -template 4 ? 1 : Size)> -inline void LaunchTransposeKernel(const phi::GPUContext& ctx, - const std::vector& dims, - const T* src, - T* dst) { - auto rank = dims.size(); - IndexT num_batches = (rank == 2) ? 1 : dims[0]; - IndexT rows = dims[rank - 2]; - IndexT cols = dims[rank - 1] / VecSize; - IndexT num_tile_rows = (rows + kTileSize - 1) / kTileSize; - IndexT num_tile_cols = (cols + kTileSize - 1) / kTileSize; - - dim3 blocks(num_tile_cols, num_tile_rows, num_batches); - dim3 threads(kTileSize, kBlockRows, 1); - - BatchTransposeKernel - <<>>( - src, dst, rows, cols, num_tile_rows - 1, num_tile_cols - 1); -} + private: + bool is_vec_write{false}; + inline IndexT FindRowTiles(const IndexT& chs, + const IndexT& rows, + const IndexT& num_rows_tile, + const IndexT& num_cols_tile, + const int& sm_count) { + constexpr int kVecRow = sizeof(float) / sizeof(T); + is_vec_write = + (sizeof(T) < sizeof(float)) ? ((rows % kVecRow) ? false : true) : false; + + int vec_write = 1; + if (is_vec_write) { + is_vec_write = (chs * num_cols_tile * num_rows_tile) > sm_count; + vec_write = is_vec_write ? kVecRow : 1; + } + IndexT n_rows_tile = is_vec_write + ? 
GETTILESIZE(rows, (kTileSize * vec_write)) + : num_rows_tile; + return n_rows_tile; + } +}; template -inline void LaunchWithDispatchVecSize(const phi::GPUContext& ctx, - const int vec_size, - const PermuteType perm_type, - const std::vector& dims, - const std::vector& perm, - const T* src, - T* dst, - IndexT count) { -#define CALL_DISPATCH_VEC_SIZE(vec_size) \ - case vec_size: { \ - if (perm_type == PermuteType::kTranspose) { \ - LaunchTransposeKernel(ctx, dims, src, dst); \ - } else { \ - LaunchPermuteRankDispatch( \ - ctx, count, perm_type, dims, perm, src, dst); \ - } \ - break; \ +struct PermuteDispatch { + public: + PermuteDispatch(const phi::GPUContext& ctx, + PermTypeClassifier* cls_ptr, + const std::vector& dims, + const std::vector& perm, + const IndexT count, + const T* src, + T* dst) + : dims_(dims), cls_(cls_ptr) { + rank_ = dims_.size(); + type_ = cls_->GetPermType(); + KernelTypeDispatch(ctx, count, perm, src, dst); } + ~PermuteDispatch() {} - switch (vec_size) { - CALL_DISPATCH_VEC_SIZE(1); - CALL_DISPATCH_VEC_SIZE(2); - CALL_DISPATCH_VEC_SIZE(4); - default: { - PADDLE_THROW(phi::errors::Unimplemented( - "Unsupported vectorized size: %d !", vec_size)); - break; + private: + int rank_{0}; + std::vector dims_; + PermTypeClassifier* cls_; + PermuteType type_{kGeneralPermute}; + + void KernelTypeDispatch(const phi::GPUContext& ctx, + const IndexT& count, + const std::vector& perm, + const T* src, + T* dst) { +#define TRANSPOSE_DISPATCH_VEC_SIZE(size) \ + case size: { \ + TransposeLauncher()( \ + ctx, rank_, type_, dims_, cls_->GetRowsTile(), src, dst); \ + break; \ + } + +#define PERMUTE_DISPATCH_VEC_SIZE(size) \ + case size: { \ + PermuteLauncher()( \ + ctx, rank_, count, type_, dims_, perm, src, dst); \ + break; \ + } + + switch (type_) { + case kSwapTranspose: + case kGeneralTranspose: + switch (cls_->GetVecSize()) { + TRANSPOSE_DISPATCH_VEC_SIZE(1); + TRANSPOSE_DISPATCH_VEC_SIZE(2); + TRANSPOSE_DISPATCH_VEC_SIZE(4); + } + break; + default: + switch (cls_->GetVecSize()) { + PERMUTE_DISPATCH_VEC_SIZE(1); + PERMUTE_DISPATCH_VEC_SIZE(2); + PERMUTE_DISPATCH_VEC_SIZE(4); + } + break; } +#define TRANSPOSE_DISPATCH_VEC_SIZE +#define PERMUTE_DISPATCH_VEC_SIZE } -#undef CALL_DISPATCH_VEC_SIZE -} +}; -template -inline void PermuteAndTranspose(const int rank, - const DeviceContext& ctx, +template +inline void PermuteAndTranspose(const phi::GPUContext& ctx, + const int& rank, const phi::DenseTensor& in, phi::DenseTensor* out, - const std::vector& perm) { - const int64_t numel = in.numel(); - auto classifier = - TranposeTypeClassifier(ctx.GetSMCount(), - rank, - numel, - perm, - phi::vectorize(in.dims()), - in.data(), - out->data()); - + const DimsSimplifier& simplifier) { + T* dst_data = out->data(); + const T* src_data = in.data(); + const auto count = simplifier.GetCount(); + auto classifier = PermTypeClassifier(ctx.GetSMCount(), + simplifier.GetRank(), + simplifier.GetPerm(), + simplifier.GetSrcDims(), + src_data, + dst_data); if (classifier.GetPermType() == PermuteType::kCopy) { // If perm is [0,1,2,3], then just operate a DtoD copy. 
- phi::backends::gpu::GpuMemcpyAsync(out->data(), - in.data(), - numel * sizeof(T), + phi::backends::gpu::GpuMemcpyAsync(dst_data, + src_data, + count * sizeof(T), phi::gpuMemcpyDeviceToDevice, ctx.stream()); } else { - if (numel < std::numeric_limits::max()) { - LaunchWithDispatchVecSize(ctx, - classifier.GetVecSize(), - classifier.GetPermType(), - classifier.GetSrcDims(), - classifier.GetPerm(), - in.data(), - out->data(), - static_cast(numel)); + if (count < std::numeric_limits::max()) { + PermuteDispatch(ctx, + &classifier, + simplifier.GetSrcDims(), + simplifier.GetPerm(), + static_cast(count), + src_data, + dst_data); } else { - int64_t cnt = static_cast(numel); - LaunchWithDispatchVecSize(ctx, - classifier.GetVecSize(), - classifier.GetPermType(), - classifier.GetSrcDims(), - classifier.GetPerm(), - in.data(), - out->data(), - static_cast(numel)); + PermuteDispatch(ctx, + &classifier, + simplifier.GetSrcDims(), + simplifier.GetPerm(), + static_cast(count), + src_data, + dst_data); } } } +template +inline void PermuteWithEigen(const phi::GPUContext& ctx, + const int& rank, + const phi::DenseTensor& in, + phi::DenseTensor* out, + const DimsSimplifier& simplifier) { + const bool not_same_dims = simplifier.GetRank() != rank; + if (not_same_dims) { + phi::DDim dst_dims = out->dims(); + phi::DenseTensor temp_in; + + temp_in.ShareBufferWith(in); + temp_in.Resize(phi::make_ddim(simplifier.GetSrcDims())); + out->Resize(phi::make_ddim(simplifier.GetDstDims())); + + TransCompute( + simplifier.GetRank(), ctx, temp_in, out, simplifier.GetPerm()); + out->Resize(dst_dims); + } else { + TransCompute( + simplifier.GetRank(), ctx, in, out, simplifier.GetPerm()); + } +} + template void TransposeGPUKernelDriver(const phi::GPUContext& ctx, const phi::DenseTensor& in, @@ -1177,30 +1373,26 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx, phi::DenseTensor* out) { const int rank = perm.size(); int64_t numel = in.numel(); - bool ret{false}; - if (numel >= std::numeric_limits::max()) { - ret = TransposeSimple::run(ctx, in, perm, out); - } else { - ret = TransposeSimple::run(ctx, in, perm, out); - } + bool ret = TransposeSimple::Impl(ctx, in, perm, out, numel); if (!ret) { - auto* tuner = phi::autotune::MakeTransposeTuner( - funcs::TransCompute); - tuner->AddCallBack(PermuteAndTranspose); + auto simplifier = + DimsSimplifier(rank, numel, perm, phi::vectorize(in.dims())); + auto* tuner = phi::autotune::MakeTransposeTuner(PermuteWithEigen); + tuner->AddCallBack(PermuteAndTranspose); size_t key = phi::autotune::TransposeKey( - phi::vectorize(in.dims()), - perm, + simplifier.GetSrcDims(), + simplifier.GetPerm(), paddle::experimental::CppTypeToDataType::Type()); tuner->Run(ctx, phi::autotune::AlgorithmType::kTranspose, key, - rank, ctx, + rank, in, out, - perm); + simplifier); } } diff --git a/paddle/phi/kernels/funcs/transpose_functor.h b/paddle/phi/kernels/funcs/transpose_functor.h index d2a72efed0..c3904b9c1a 100644 --- a/paddle/phi/kernels/funcs/transpose_functor.h +++ b/paddle/phi/kernels/funcs/transpose_functor.h @@ -27,161 +27,115 @@ enum { kTransposeMKLDNNFP32 = 1, kTransposeMKLDNNINT8 = 2 }; enum PermuteType { kCopy = 1, - kTranspose = 2, - kVecPermute = 3, - kGeneralPermute = 4 + kSwapTranspose = 2, + kGeneralTranspose = 3, + kVecPermute = 4, + kGeneralPermute = 5 }; constexpr int kBlockRows = 16; constexpr int kTileSize = 32; +constexpr int kShareCol = (kTileSize + 1); + +#define GETTILESIZE(LEN_, ALIGN_) \ + ((LEN_ + (ALIGN_ - 1)) & ~(ALIGN_ - 1)) / ALIGN_ -// Simplify the input dims and 
permute dims if possible. template -class TranposeTypeClassifier { +struct PermTypeClassifier { public: - TranposeTypeClassifier(const int sm_count, - const size_t rank, - const int64_t numel, - const std::vector& perm, - const std::vector& dims, - const T* src, - T* dst) - : perm_(rank), src_dims(rank) { - SimplifyPermAndDims(rank, dims, perm); - if (rank_ > 1) { - vec_size_ = GetPermVecSize(sm_count, src, dst); - } - perm_.resize(rank_); - src_dims.resize(rank_); - dst_dims.resize(rank_); - - for (auto i = 0; i < rank_; ++i) { - dst_dims[i] = src_dims[perm_[i]]; - } - } - - int GetRank() const { return rank_; } - int GetVecSize() const { return vec_size_; } - PermuteType GetPermType() const { return type_; } - - std::vector GetPerm() const { return perm_; } - std::vector GetSrcDims() const { return src_dims; } - std::vector GetDstDims() const { return dst_dims; } - - private: - int rank_{1}; - int vec_size_{1}; - std::vector perm_; - std::vector src_dims; - std::vector dst_dims; - PermuteType type_{kCopy}; - - void SimplifyPermAndDims(const size_t rank, - const std::vector& in_dims, - const std::vector& perm) { - int64_t combined_dims[phi::DDim::kMaxRank]; - int valid_map[phi::DDim::kMaxRank]; - - // Merge consecutive dims to the fist one dim and - // leave original dim to be 1. Example below : - // perm: [2, 3, 0, 1], origin_dims : [4, 8, 2, 5] - // new_dims: [4, 8, 2, 5] -> [32, 1, 10, 1] - int start_perm_idx = 0; - while (start_perm_idx < rank) { - const int start_dim_idx = perm[start_perm_idx]; - combined_dims[start_dim_idx] = in_dims[start_dim_idx]; - int end_perm_idx = start_perm_idx + 1; - - while (end_perm_idx < rank && - perm[end_perm_idx] == perm[end_perm_idx - 1] + 1) { - const int end_dim_idx = perm[end_perm_idx]; - combined_dims[start_dim_idx] *= in_dims[end_dim_idx]; - combined_dims[end_dim_idx] = 1; - end_perm_idx += 1; + explicit PermTypeClassifier(const int sm_count, + const int rank, + const std::vector& perm, + const std::vector& dims, + const T* src, + T* dst) { + if (rank == 1) { + type_ = PermuteType::kCopy; + } else { + constexpr int64_t dim_limitation = 65536; + const int dst_vec_size = phi::GetVectorizedSize(dst); + + // While the last dim is fixed, there is chance for vectorized IO. + const int last_idx = rank - 1; + if (perm[last_idx] == last_idx) { + type_ = PermuteType::kVecPermute; + vec_size_ = GetDimVecSize(dst_vec_size, dims[last_idx], src, false); + return; } - start_perm_idx = end_perm_idx; - } - // Reorder combined dims and marked useless dim as -1. - // for example, if combined dims is [32, 1, 10, 1], - // valid_map is [0, -1, 1, -1] and generate simplified - // dims as [32, 10] - int valid_dim_idx = 0; - bool sequential_flag = false; - for (auto i = 0; i < rank; ++i) { - const int src_dim = combined_dims[i]; - if (src_dim == 1) { - valid_map[i] = -1; - } else { - sequential_flag = true; - valid_map[i] = valid_dim_idx; - src_dims[valid_dim_idx] = src_dim; - valid_dim_idx += 1; + // Permute at last 2 dims, namely transpose. + if ((rank == 2 && perm[1] == 0 && perm[0] == 1) || + (rank == 3 && perm[2] == 1 && perm[1] == 2)) { + int64_t channel = rank == 2 ? 1 : dims[0]; + // Currently, transpose kernel cannot cover the case that channel + // dimension is more than 65536 which is the limitation of dim3 setting. + // This special case will be covered by extended transpose kernel later. 
+ if (channel < dim_limitation) { + type_ = PermuteType::kGeneralTranspose; + num_rows_tile_ = GETTILESIZE(dims[rank - 2], kTileSize); + int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src); + int tile_size = + channel * num_rows_tile_ * GETTILESIZE(dims[last_idx], kTileSize); + vec_size_ = tile_size < sm_count ? 1 : dim_vec_size; + } else { + type_ = PermuteType::kGeneralPermute; + } + return; } - } - if (valid_dim_idx == 0) { - src_dims[0] = 1; - perm_[0] = 0; - return; - } else if (valid_dim_idx == 1) { - type_ = PermuteType::kCopy; - } - - // Acquire simplified perm with help of combined dims - // and original perm, finally simplified perm is [1, 0] - int perm_idx = 0; - for (auto i = 0; i < rank; ++i) { - const int mapped = valid_map[perm[i]]; - if (mapped >= 0) { - perm_[perm_idx] = mapped; - perm_idx += 1; + // Permute at first dim and third dim. + if (rank == 3 && perm[2] == 0 && perm[1] == 1) { + // Currently, transpose kernel cannot cover the case that channel + // dimension is more than 65536 which is the limitation of dim3 setting. + // This special case will be covered by extended transpose kernel later. + if (dims[1] < dim_limitation) { + type_ = PermuteType::kSwapTranspose; + num_rows_tile_ = GETTILESIZE(dims[0], kTileSize); + + int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src); + int tile_size = + dims[1] * num_rows_tile_ * GETTILESIZE(dims[2], kTileSize); + vec_size_ = tile_size < sm_count ? 1 : dim_vec_size; + } else { + type_ = PermuteType::kGeneralPermute; + } + return; } + vec_size_ = dst_vec_size; } - rank_ = valid_dim_idx; } - int GetPermVecSize(const int sm_count, const T* src, T* dst) { - // For gerneal_permute kernel, there is good chance for - // vectorized write. - type_ = PermuteType::kGeneralPermute; - int vec_size = phi::GetVectorizedSize(dst); - - // While the last dim is fixed, there is good chance for - // both vectorized read and write. - if (perm_[rank_ - 1] == rank_ - 1) { - int tmp_size = std::min(vec_size, phi::GetVectorizedSize(src)); - tmp_size = GetDimVesSize(tmp_size, src_dims[rank_ - 1]); - if (tmp_size > 1) { - type_ = kVecPermute; - vec_size = tmp_size; - } - } + ~PermTypeClassifier() = default; - // Once only transpose at the last 2 dims, there is good - // chance for vectorized read. - if ((rank_ == 2 && perm_[1] == 0 && perm_[0] == 1) || - (rank_ == 3 && perm_[2] == 1 && perm_[1] == 2)) { - type_ = PermuteType::kTranspose; - int tmp_vec = std::min(vec_size, phi::GetVectorizedSize(src)); - // With bytes limitation of shared_memory, the VecSize shall be - // restricted for the type whose byte-size is less than 8 (double). - vec_size = - sizeof(T) > 8 ? 1 : GetDimVesSize(tmp_vec, src_dims[rank_ - 1]); - } - return vec_size; - } + int GetVecSize() const { return vec_size_; } + int GetRowsTile() const { return num_rows_tile_; } + PermuteType GetPermType() const { return type_; } + + private: + int vec_size_{1}; + int64_t num_rows_tile_{0}; + PermuteType type_{kGeneralPermute}; // To find if highest common divisor and make it as vec_size. 
- int GetDimVesSize(const int vec_size, const size_t target_dim) { + int GetDimVecSize(const int dst_vec_size, + const int64_t target_dim, + const T* src, + bool use_share_mem = true) { + const int vec_size = std::min(dst_vec_size, phi::GetVectorizedSize(src)); int dim_vec_size = 1; - for (auto size = vec_size; size > 0; size /= 2) { + for (int size = vec_size; size > 0; size /= 2) { if (target_dim % size == 0) { dim_vec_size = size; break; } } - return dim_vec_size; + + if (use_share_mem) { + // By bytes limitation of shared_memory. + return (sizeof(T) > sizeof(float) ? 1 : dim_vec_size); + } else { + return dim_vec_size; + } } }; diff --git a/paddle/phi/kernels/gpu/transpose_kernel.cu b/paddle/phi/kernels/gpu/transpose_kernel.cu index 4b7265e2f3..8ae9c6ed3d 100644 --- a/paddle/phi/kernels/gpu/transpose_kernel.cu +++ b/paddle/phi/kernels/gpu/transpose_kernel.cu @@ -21,7 +21,7 @@ #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/transpose_functor.cu.h" +#include "paddle/phi/kernels/funcs/transpose_function.cu.h" #include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h" namespace phi { -- GitLab
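
The following standalone C++ sketch is an illustration placed after the patch and is not part of the commit itself. It mirrors the dim/perm merging rule that the new DimsSimplifier::SimplifyPermAndDims performs, using the example from the in-code comments (perm [2, 3, 0, 1], dims [4, 8, 2, 5] simplify to dims [32, 10], perm [1, 0]). The helper and variable names below are illustrative only, and edge cases the real class handles (an all-ones shape, or an identity result collapsing to rank 1 via is_seq_perm_) are omitted.

// Illustrative sketch only -- NOT part of the patch above.
// Rule: source dims that stay adjacent and in order under `perm` are folded
// into one dim, size-1 dims are dropped, and the perm is remapped.
#include <cstdio>
#include <vector>

static void SimplifyPermAndDims(const std::vector<long long>& dims,
                                const std::vector<int>& perm,
                                std::vector<long long>* out_dims,
                                std::vector<int>* out_perm) {
  const int rank = static_cast<int>(dims.size());
  std::vector<long long> combined(dims);
  // 1. Fold runs of consecutive dims (in perm order) into the first one,
  //    leaving the folded positions as 1. [4, 8, 2, 5] -> [32, 1, 10, 1].
  for (int p = 0; p < rank;) {
    const int start = perm[p];
    int q = p + 1;
    while (q < rank && perm[q] == perm[q - 1] + 1) {
      combined[start] *= dims[perm[q]];
      combined[perm[q]] = 1;
      ++q;
    }
    p = q;
  }
  // 2. Drop size-1 dims and record where each surviving dim moved.
  std::vector<int> valid_map(rank, -1);
  for (int i = 0; i < rank; ++i) {
    if (combined[i] != 1) {
      valid_map[i] = static_cast<int>(out_dims->size());
      out_dims->push_back(combined[i]);
    }
  }
  // 3. Remap the permutation onto the surviving dims: [2, 3, 0, 1] -> [1, 0].
  for (int i = 0; i < rank; ++i) {
    if (valid_map[perm[i]] >= 0) out_perm->push_back(valid_map[perm[i]]);
  }
}

int main() {
  const std::vector<long long> dims = {4, 8, 2, 5};
  const std::vector<int> perm = {2, 3, 0, 1};
  std::vector<long long> new_dims;
  std::vector<int> new_perm;
  SimplifyPermAndDims(dims, perm, &new_dims, &new_perm);
  // Prints: dims [32, 10], perm [1, 0]
  std::printf("dims [%lld, %lld], perm [%d, %d]\n",
              new_dims[0], new_dims[1], new_perm[0], new_perm[1]);
  return 0;
}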