“5c49efc004518fe3500e63e0c5ef625a62d9ceae”上不存在“...doc_cn/api/v2/git@gitcode.net:paddlepaddle/Paddle.git”
未验证 提交 a0f43889 编写于 作者: L limingshu 提交者: GitHub

Transpose optimization for AlphaFold2 (#45230)

* first commit

* fix bugs according to ci

* add some changes

* change file name into function.cu.h

* remove const_cast
上级 30f4ef7f
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h" #include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/funcs/transpose_functor.cu.h" #include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace paddle { namespace paddle {
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/reduce_function.h" #include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/funcs/transpose_functor.cu.h" #include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace paddle { namespace paddle {
......
...@@ -243,5 +243,106 @@ struct BroadcastDimsSimplifier { ...@@ -243,5 +243,106 @@ struct BroadcastDimsSimplifier {
} }
}; };
// Simplify the input dims and permute dims if possible.
//
// Given the dims of a tensor and the permutation applied to it, this helper
// (1) merges every run of source axes that `perm` keeps consecutive and in
//     order into a single axis, and
// (2) drops every axis whose (merged) extent is 1,
// producing a lower-rank but equivalent permutation. When the simplified
// permutation turns out to be the identity, the problem is collapsed to a
// single axis of `numel` elements.
struct DimsSimplifier {
 public:
  // rank  : number of entries of `perm` / `dims` to consider.
  // numel : total element count; stored as the count and used as the single
  //         extent when the permutation collapses to the identity.
  // perm  : perm[i] is the source axis that becomes destination axis i.
  // dims  : extents of the source tensor.
  explicit DimsSimplifier(const int rank,
                          const int64_t numel,
                          const std::vector<int32_t> &perm,
                          const std::vector<int64_t> &dims)
      // Initializers are listed in member declaration order.
      : count_(numel), perm_(rank), src_dims_(rank) {
    SimplifyPermAndDims(rank, dims, perm);
    // Shrink the scratch-sized vectors down to the simplified rank.
    perm_.resize(rank_);
    src_dims_.resize(rank_);
    dst_dims_.resize(rank_);
    if (!is_seq_perm_) {
      for (int i = 0; i < rank_; ++i) {
        dst_dims_[i] = src_dims_[perm_[i]];
      }
    } else {
      // Identity permutation: rank_ is 1 here, so describe the data as one
      // flat axis of numel elements.
      dst_dims_[0] = numel;
      src_dims_[0] = numel;
    }
  }

  ~DimsSimplifier() = default;

  // Rank after simplification (1 when the permutation is the identity).
  const int &GetRank() const { return rank_; }
  // Element count passed to the constructor.
  const int64_t &GetCount() const { return count_; }
  // Simplified permutation, GetRank() entries.
  const std::vector<int> &GetPerm() const { return perm_; }
  // Simplified source extents, GetRank() entries.
  const std::vector<int64_t> &GetSrcDims() const { return src_dims_; }
  // Simplified destination extents, GetRank() entries.
  const std::vector<int64_t> &GetDstDims() const { return dst_dims_; }

 private:
  int rank_{1};
  int64_t count_{0};
  bool is_seq_perm_{true};
  std::vector<int> perm_;
  std::vector<int64_t> src_dims_;
  std::vector<int64_t> dst_dims_;

  // Fills perm_ / src_dims_ with the simplified permutation and source
  // extents, and sets rank_ and is_seq_perm_ accordingly.
  void SimplifyPermAndDims(const int rank,
                           const std::vector<int64_t> &in_dims,
                           const std::vector<int32_t> &perm) {
    int start_perm_idx = 0;
    int valid_dim_idx = 0;
    int valid_map[phi::DDim::kMaxRank];
    int64_t combined_dims[phi::DDim::kMaxRank];

    // Merge consecutive dims to the first one dim and
    // leave original dim to be 1. Example below :
    // perm: [2, 3, 0, 1], origin_dims : [4, 8, 2, 5]
    // new_dims: [4, 8, 2, 5] -> [32, 1, 10, 1]
    while (start_perm_idx < rank) {
      const int start_dim_idx = perm[start_perm_idx];
      combined_dims[start_dim_idx] = in_dims[start_dim_idx];
      int end_perm_idx = start_perm_idx + 1;
      while (end_perm_idx < rank &&
             perm[end_perm_idx] == perm[end_perm_idx - 1] + 1) {
        const int end_dim_idx = perm[end_perm_idx];
        combined_dims[start_dim_idx] *= in_dims[end_dim_idx];
        combined_dims[end_dim_idx] = 1;
        end_perm_idx += 1;
      }
      start_perm_idx = end_perm_idx;
    }

    // Reorder combined dims and mark useless dim as -1.
    // for example, if combined dims is [32, 1, 10, 1],
    // valid_map is [0, -1, 1, -1] and generate simplified
    // dims as [32, 10]
    for (int i = 0; i < rank; ++i) {
      // Keep the full 64-bit extent: narrowing to int would truncate merged
      // extents above INT_MAX and could mistake them for a size-1 axis.
      const int64_t dim_val = combined_dims[i];
      if (dim_val == 1) {
        valid_map[i] = -1;
      } else {
        valid_map[i] = valid_dim_idx;
        src_dims_[valid_dim_idx] = dim_val;
        valid_dim_idx += 1;
      }
    }

    if (valid_dim_idx == 0) {
      // Every axis has extent 1 (or rank is 0): fall back to a single unit
      // axis. assign() guarantees slot 0 exists even when the vectors were
      // constructed with rank == 0.
      src_dims_.assign(1, 1);
      perm_.assign(1, 0);
      return;
    }

    // Acquire simplified perm with help of combined dims
    // and original perm, finally simplified perm is [1, 0]
    int perm_idx = 0;
    for (int i = 0; i < rank; ++i) {
      const int mapped = valid_map[perm[i]];
      if (mapped >= 0) {
        perm_[perm_idx] = mapped;
        // Stays true only while every kept axis maps onto itself.
        is_seq_perm_ &= (mapped == perm_idx);
        perm_idx += 1;
      }
    }
    rank_ = is_seq_perm_ ? 1 : valid_dim_idx;
  }
};
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_utils.h" #include "paddle/phi/backends/gpu/gpu_utils.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/autotune/auto_tune_base.h" #include "paddle/phi/kernels/autotune/auto_tune_base.h"
#include "paddle/phi/kernels/funcs/dims_simplifier.h"
#include "paddle/phi/kernels/funcs/transpose_functor.h" #include "paddle/phi/kernels/funcs/transpose_functor.h"
#include "paddle/phi/kernels/primitive/datamover_primitives.h" #include "paddle/phi/kernels/primitive/datamover_primitives.h"
...@@ -191,7 +192,6 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input, ...@@ -191,7 +192,6 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
IndexType output_origin_block_flat_index = IndexType output_origin_block_flat_index =
FlatTensorIndex<IndexType>(block_tile_index_in_output, output_dims); FlatTensorIndex<IndexType>(block_tile_index_in_output, output_dims);
constexpr IndexType out_effective_thread_num = NumThreads / TileX * TileX; constexpr IndexType out_effective_thread_num = NumThreads / TileX * TileX;
if (x < out_effective_thread_num) { if (x < out_effective_thread_num) {
...@@ -652,7 +652,7 @@ struct SwapDim0And2InTranspose { ...@@ -652,7 +652,7 @@ struct SwapDim0And2InTranspose {
inline void CombineTransposeDim3(const DDim& shape, inline void CombineTransposeDim3(const DDim& shape,
const std::vector<int>& perm, const std::vector<int>& perm,
std::vector<int>* new_perm, std::vector<int>* new_perm,
DDim* new_dims) { std::vector<int>* new_dims) {
PADDLE_ENFORCE_EQ(shape.size(), PADDLE_ENFORCE_EQ(shape.size(),
perm.size(), perm.size(),
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
...@@ -667,114 +667,111 @@ inline void CombineTransposeDim3(const DDim& shape, ...@@ -667,114 +667,111 @@ inline void CombineTransposeDim3(const DDim& shape,
new_perm->resize(1); new_perm->resize(1);
(*new_perm)[0] = perm[0]; (*new_perm)[0] = perm[0];
dim_vec.push_back(shape[0]); dim_vec.push_back(shape[0]);
*new_dims = phi::make_ddim(dim_vec); } else {
return; int dim_idx = 0;
} std::vector<int> new_dim_pos(shape.size(), -1);
std::vector<int> new_dim_pos(shape.size(), -1); std::vector<int> combined_dims(shape.size(), 0);
std::vector<int64_t> combined_dims(shape.size(), 0);
int cur_head = perm[0]; int cur_head = perm[0];
new_dim_pos[cur_head] = 0; new_dim_pos[cur_head] = 0;
combined_dims[0] = shape[cur_head]; combined_dims[0] = shape[cur_head];
int dim_idx = 0; for (int perm_idx = 1; perm_idx < shape.size(); ++perm_idx) {
for (int perm_idx = 1; perm_idx < shape.size(); ++perm_idx) { // combine consecutive dimensions.
// combine consecutive dimensions. if (cur_head + 1 == perm[perm_idx]) {
if (cur_head + 1 == perm[perm_idx]) { cur_head = perm[perm_idx];
cur_head = perm[perm_idx]; combined_dims[dim_idx] *= shape[cur_head];
combined_dims[dim_idx] *= shape[cur_head]; } else {
} else { // Else start a new dimension.
// Else start a new dimension. cur_head = perm[perm_idx];
cur_head = perm[perm_idx]; dim_idx++;
dim_idx++; new_dim_pos[cur_head] = dim_idx;
new_dim_pos[cur_head] = dim_idx; combined_dims[dim_idx] = shape[cur_head];
combined_dims[dim_idx] = shape[cur_head]; }
} }
} new_perm->resize(dim_idx + 1);
new_perm->resize(dim_idx + 1); dim_idx = 0;
for (int i = 0; i < new_dim_pos.size(); ++i) {
dim_idx = 0; if (new_dim_pos[i] >= 0) {
for (int i = 0; i < new_dim_pos.size(); ++i) { int new_perm_idx = new_dim_pos[i];
if (new_dim_pos[i] >= 0) { (*new_perm)[dim_idx] = new_perm_idx;
int new_perm_idx = new_dim_pos[i]; dim_vec.push_back(combined_dims[new_perm_idx]);
(*new_perm)[dim_idx] = new_perm_idx; dim_idx++;
dim_vec.push_back(combined_dims[new_perm_idx]); }
dim_idx++;
} }
} }
*new_dims = dim_vec;
*new_dims = phi::make_ddim(dim_vec);
} }
template <typename T, typename IndexType = int> template <typename T>
struct TransposeSimple { struct TransposeSimple {
static bool run(const phi::GPUContext& ctx, static bool Impl(const phi::GPUContext& ctx,
const phi::DenseTensor& in,
const std::vector<int32_t> perm,
phi::DenseTensor* out,
const int64_t numel) {
if (numel >= std::numeric_limits<int32_t>::max()) {
return Run<int64_t>(ctx, in, perm, out);
} else {
return Run<int32_t>(ctx, in, perm, out);
}
}
private:
template <typename IndexType = int32_t>
static bool Run(const phi::GPUContext& ctx,
const phi::DenseTensor& in, const phi::DenseTensor& in,
const std::vector<int32_t> perm, const std::vector<int32_t> perm,
phi::DenseTensor* out) { phi::DenseTensor* out) {
// First reduce the dimensions of the input tensor if possible. // First reduce the dimensions of the input tensor if possible.
auto in_data = in.data<T>();
auto out_data = out->data<T>();
std::vector<int> new_perm; std::vector<int> new_perm;
DDim new_dims; std::vector<int> new_dims;
CombineTransposeDim3(in.dims(), perm, &new_perm, &new_dims); CombineTransposeDim3(in.dims(), perm, &new_perm, &new_dims);
if (new_perm.size() < 2 || new_perm.size() > 3) return false;
// Only use tile copy GPU kernel when dimension is 2 or 3.
int dims = new_dims.size();
std::vector<int> new_dim_vec = phi::vectorize<int>(new_dims);
if (dims < 2 || dims > 3) return false;
auto in_data = in.data<T>();
auto out_data = out->data<T>();
// In most cases, dim will not greater than 3 after combine. // In most cases, dim will not greater than 3 after combine.
switch (dims) { if (new_perm.size() == 2 && new_perm[1] == 0) {
case 2: // Add the first dimension size as 1.
if (new_perm[0] == 1 && new_perm[1] == 0) { new_dims.insert(new_dims.begin(), 1);
// Add the first dimension size as 1. SwapDim1And2InTranspose<T, IndexType>()(ctx, in_data, new_dims, out_data);
new_dim_vec.insert(new_dim_vec.begin(), 1); return true;
SwapDim1And2InTranspose<T, IndexType>()( } else if (new_perm == std::vector<int>({0, 2, 1})) {
ctx, in_data, new_dim_vec, out_data); SwapDim1And2InTranspose<T, IndexType>()(ctx, in_data, new_dims, out_data);
return true; return true;
} } else if (new_perm == std::vector<int>({2, 1, 0})) {
break; // Maybe can optimize later, find a way to do coalescing memory copy.
case 3: // But I think it depends on the data size. If span is not large,
// In this case, suppose we can do coalescing read and write in tile. // maybe can do coalescing.
if (new_perm == std::vector<int>({0, 2, 1})) { SwapDim0And2InTranspose<T, IndexType>()(ctx, in_data, new_dims, out_data);
SwapDim1And2InTranspose<T, IndexType>()( return true;
ctx, in_data, new_dim_vec, out_data); } else {
return true; return false;
} else if (new_perm == std::vector<int>({2, 1, 0})) {
// Maybe can optimize later, find a way to do coalescing memory copy.
// But I think it depends on the data size. If span is not large,
// maybe
// can do coalescing.
SwapDim0And2InTranspose<T, IndexType>()(
ctx, in_data, new_dim_vec, out_data);
return true;
} else {
return false;
}
break;
default:
return false;
} }
return false;
} }
}; };
template <int N, typename T> template <typename IndexT, int N>
class IdxHelper { class IdxHelper {
public: public:
IdxHelper() {} IdxHelper() {}
explicit IdxHelper(const T* dims) { explicit IdxHelper(const IndexT* dims) {
for (int i = N - 1; i >= 0; --i) { for (int i = N - 1; i >= 0; --i) {
stride_[i] = i < (N - 1) ? dims[i + 1] * stride_[i + 1] : 1; stride_[i] = i < (N - 1) ? dims[i + 1] * stride_[i + 1] : 1;
} }
} }
__device__ inline T GetStride(int idx) const { return stride_[idx]; } __device__ __forceinline__ IndexT GetStride(int idx) const {
return stride_[idx];
}
__device__ inline void GetIndexFromOffset(T offset, T* index) const { __device__ __forceinline__ void GetIndexFromOffset(IndexT offset,
T remaining = offset; IndexT* index) const {
IndexT remaining = offset;
#pragma unroll #pragma unroll
for (int i = 0; i < N - 1; ++i) { for (int i = 0; i < N - 1; ++i) {
const T idx = remaining / stride_[i]; const IndexT idx = remaining / stride_[i];
remaining -= idx * stride_[i]; remaining -= idx * stride_[i];
index[i] = idx; index[i] = idx;
} }
...@@ -782,11 +779,11 @@ class IdxHelper { ...@@ -782,11 +779,11 @@ class IdxHelper {
} }
private: private:
T stride_[N]; IndexT stride_[N];
}; };
template <int N> template <int N>
class IdxHelper<N, uint32_t> { class IdxHelper<uint32_t, N> {
public: public:
IdxHelper() {} IdxHelper() {}
explicit IdxHelper(const uint32_t* dims) { explicit IdxHelper(const uint32_t* dims) {
...@@ -797,10 +794,12 @@ class IdxHelper<N, uint32_t> { ...@@ -797,10 +794,12 @@ class IdxHelper<N, uint32_t> {
} }
} }
__device__ inline uint32_t GetStride(int idx) const { return stride_[idx]; } __device__ __forceinline__ uint32_t GetStride(int idx) const {
return stride_[idx];
}
__device__ inline void GetIndexFromOffset(uint32_t offset, __device__ __forceinline__ void GetIndexFromOffset(uint32_t offset,
uint32_t* index) const { uint32_t* index) const {
uint32_t remaining = offset; uint32_t remaining = offset;
#pragma unroll #pragma unroll
for (int i = 0; i < N - 1; ++i) { for (int i = 0; i < N - 1; ++i) {
...@@ -817,18 +816,16 @@ class IdxHelper<N, uint32_t> { ...@@ -817,18 +816,16 @@ class IdxHelper<N, uint32_t> {
}; };
// Transform index between memory offset and shape coodinate. // Transform index between memory offset and shape coodinate.
template <typename T, int N> template <typename IndexT, int N>
class IdxAndOffsetHelper { class IdxAndOffsetHelper {
public: public:
IdxAndOffsetHelper() {} IdxAndOffsetHelper() {}
~IdxAndOffsetHelper() = default; explicit IdxAndOffsetHelper(const IndexT* dims) {
index_helper = IdxHelper<IndexT, N>(dims);
explicit IdxAndOffsetHelper(const T* dims) {
index_helper = IdxHelper<N, T>(dims);
} }
__device__ inline T IndexToOffset(const T* index) const { __device__ __forceinline__ IndexT IndexToOffset(const IndexT* index) const {
T offset = 0; IndexT offset = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < N - 1; ++i) { for (int i = 0; i < N - 1; ++i) {
offset += index[i] * index_helper.GetStride(i); offset += index[i] * index_helper.GetStride(i);
...@@ -837,15 +834,16 @@ class IdxAndOffsetHelper { ...@@ -837,15 +834,16 @@ class IdxAndOffsetHelper {
return offset; return offset;
} }
__device__ inline void OffsetToIndex(T offset, T* index) const { __device__ __forceinline__ void OffsetToIndex(IndexT offset,
IndexT* index) const {
index_helper.GetIndexFromOffset(offset, index); index_helper.GetIndexFromOffset(offset, index);
} }
private: private:
IdxHelper<N, T> index_helper; IdxHelper<IndexT, N> index_helper;
}; };
template <size_t Rank, typename IndexT> template <typename IndexT, int Rank>
struct PermuteParams { struct PermuteParams {
public: public:
IdxAndOffsetHelper<IndexT, Rank> src_index_helper; IdxAndOffsetHelper<IndexT, Rank> src_index_helper;
...@@ -868,17 +866,17 @@ struct PermuteParams { ...@@ -868,17 +866,17 @@ struct PermuteParams {
// A special kernel for target case, both vectorized read and write supported. // A special kernel for target case, both vectorized read and write supported.
template <typename T, typename IndexT, int VecSize, int Rank> template <typename T, typename IndexT, int VecSize, int Rank>
__global__ void VectorizedPermuteKernel(PermuteParams<Rank, IndexT> params, __global__ void VectorizedPermuteKernel(PermuteParams<IndexT, Rank> params,
const size_t count, const IndexT count,
const T* __restrict__ src_data, const T* __restrict__ src_data,
T* dst_data) { T* dst_data) {
using VecT = phi::AlignedVector<T, VecSize>; using VecT = phi::AlignedVector<T, VecSize>;
IndexT src_index[Rank]; IndexT src_index[Rank];
IndexT dst_index[Rank]; IndexT dst_index[Rank];
const VecT* __restrict__ src = const VecT* __restrict__ vec_src =
reinterpret_cast<const VecT* __restrict__>(src_data); reinterpret_cast<const VecT* __restrict__>(src_data);
VecT* dst = reinterpret_cast<VecT*>(dst_data); VecT* vec_dst = reinterpret_cast<VecT*>(dst_data);
IndexT tid = blockIdx.x * blockDim.x + threadIdx.x; IndexT tid = blockIdx.x * blockDim.x + threadIdx.x;
for (IndexT i = tid; i < count; i += blockDim.x * gridDim.x) { for (IndexT i = tid; i < count; i += blockDim.x * gridDim.x) {
...@@ -889,31 +887,23 @@ __global__ void VectorizedPermuteKernel(PermuteParams<Rank, IndexT> params, ...@@ -889,31 +887,23 @@ __global__ void VectorizedPermuteKernel(PermuteParams<Rank, IndexT> params,
src_index[params.perm[j]] = dst_index[j]; src_index[params.perm[j]] = dst_index[j];
} }
IndexT src_offset = params.src_index_helper.IndexToOffset(src_index); IndexT src_offset = params.src_index_helper.IndexToOffset(src_index);
dst[i] = src[src_offset]; vec_dst[i] = vec_src[src_offset];
} }
} }
// A general kernel for normal case, only support vectorized write. // A general kernel for normal case, only support vectorized write.
template <typename T, typename IndexT, int VecSize, int Rank> template <typename T, typename IndexT, int VecSize, int Rank>
__global__ void GeneralPermuteKernel(PermuteParams<Rank, IndexT> params, __global__ void GeneralPermuteKernel(PermuteParams<IndexT, Rank> params,
const IndexT main_cnt,
const IndexT tail_cnt,
const IndexT offset,
const T* __restrict__ src, const T* __restrict__ src,
T* dst, T* dst) {
const size_t main_cnt,
const size_t tail_cnt,
const size_t offset) {
using VecT = phi::AlignedVector<T, VecSize>; using VecT = phi::AlignedVector<T, VecSize>;
VecT* vec_dst = reinterpret_cast<VecT*>(dst); VecT* vec_dst = reinterpret_cast<VecT*>(dst);
IndexT src_index[VecSize][Rank]; IndexT src_index[VecSize][Rank];
IndexT dst_index[VecSize][Rank]; IndexT dst_index[VecSize][Rank];
// Avoid read perm data both in 2 load process.
__shared__ int perm[Rank];
if (threadIdx.x < Rank) {
perm[threadIdx.x] = params.perm[threadIdx.x];
}
__syncthreads();
// Vectorized load data. // Vectorized load data.
IndexT tid = blockIdx.x * blockDim.x + threadIdx.x; IndexT tid = blockIdx.x * blockDim.x + threadIdx.x;
for (IndexT idx = tid; idx < main_cnt; idx += blockDim.x * gridDim.x) { for (IndexT idx = tid; idx < main_cnt; idx += blockDim.x * gridDim.x) {
...@@ -926,7 +916,7 @@ __global__ void GeneralPermuteKernel(PermuteParams<Rank, IndexT> params, ...@@ -926,7 +916,7 @@ __global__ void GeneralPermuteKernel(PermuteParams<Rank, IndexT> params,
#pragma unroll #pragma unroll
for (int j = 0; j < Rank; ++j) { for (int j = 0; j < Rank; ++j) {
src_index[i][perm[j]] = dst_index[i][j]; src_index[i][params.perm[j]] = dst_index[i][j];
} }
IndexT src_offset = params.src_index_helper.IndexToOffset(src_index[i]); IndexT src_offset = params.src_index_helper.IndexToOffset(src_index[i]);
vec_data[i] = src[src_offset]; vec_data[i] = src[src_offset];
...@@ -941,235 +931,441 @@ __global__ void GeneralPermuteKernel(PermuteParams<Rank, IndexT> params, ...@@ -941,235 +931,441 @@ __global__ void GeneralPermuteKernel(PermuteParams<Rank, IndexT> params,
#pragma unroll #pragma unroll
for (int j = 0; j < Rank; ++j) { for (int j = 0; j < Rank; ++j) {
src_index[0][perm[j]] = dst_index[0][j]; src_index[0][params.perm[j]] = dst_index[0][j];
} }
IndexT src_offset = params.src_index_helper.IndexToOffset(src_index[0]); IndexT src_offset = params.src_index_helper.IndexToOffset(src_index[0]);
dst[idx] = src[src_offset]; dst[idx] = src[src_offset];
} }
} }
// A Gerneral permute method that drectly find the dst data template <typename T, typename IndexT, int ReadSize, int WriteSize>
// coordinate in the source data. struct TransposeDataWriter {
template <typename T, typename IndexT, int VecSize, int Rank> __device__ __forceinline__ void operator()(T* dst_data,
inline void LaunchPermuteKernel(const phi::GPUContext& ctx, const T* s_data,
const IndexT count, const IndexT rows,
const PermuteType perm_type, const IndexT cols,
const std::vector<int64_t>& dims, const IndexT chs_stride,
const std::vector<int>& perm, const IndexT round_tile_cols,
const T* src, const IndexT col_stride = 1) {
T* dst) { using OutVecT = phi::AlignedVector<T, WriteSize>;
size_t main_count = count / VecSize; OutVecT* vec_dst = reinterpret_cast<OutVecT*>(dst_data);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, main_count);
constexpr int kColTile = kTileSize * ReadSize;
if (perm_type == PermuteType::kGeneralPermute) { constexpr int kColStride = kShareCol * ReadSize;
size_t tail_count = count - main_count * VecSize;
size_t offset = count - tail_count; const IndexT vec_rows = rows / WriteSize;
auto params = PermuteParams<Rank, IndexT>(dims, perm); const IndexT col_in_mat = blockIdx.y * kTileSize + threadIdx.x;
GeneralPermuteKernel<T, IndexT, VecSize, Rank> if (col_in_mat < /*dst_cols=*/vec_rows) {
<<<config.GetGridSize(), config.GetBlockSize(), 0, ctx.stream()>>>( const int cols_range = (blockIdx.x < round_tile_cols)
params, src, dst, main_count, tail_count, offset); ? kTileSize
} else { : (cols - round_tile_cols * kTileSize);
std::vector<int64_t> vec_dims(dims); const int share_tile = threadIdx.x * (WriteSize * kColStride);
vec_dims[dims.size() - 1] /= VecSize; const IndexT write_offset = blockIdx.z * chs_stride + col_in_mat;
auto params = PermuteParams<Rank, IndexT>(vec_dims, perm); #pragma unroll
for (int tile_y = threadIdx.y; tile_y < cols_range;
VectorizedPermuteKernel<T, IndexT, VecSize, Rank> tile_y += kBlockRows) {
<<<config.GetGridSize(), config.GetBlockSize(), 0, ctx.stream()>>>( OutVecT tmp_data[ReadSize];
params, main_count, src, dst); #pragma unroll
for (int i = 0; i < ReadSize; ++i) {
int tile_tail = tile_y * ReadSize + i;
int major_share_idx = share_tile + tile_tail;
IndexT row_in_mat = (blockIdx.x * kColTile + tile_tail) * col_stride;
#pragma unroll
for (int j = 0; j < WriteSize; ++j) {
tmp_data[i].val[j] = s_data[j * kColStride + major_share_idx];
}
vec_dst[write_offset + row_in_mat * vec_rows] = tmp_data[i];
}
}
}
} }
} };
template <typename T, typename IndexT, int VecSize> template <typename T, typename IndexT, int ReadSize>
inline void LaunchPermuteRankDispatch(const phi::GPUContext& ctx, struct TransposeDataWriter<T, IndexT, ReadSize, 1> {
const IndexT count, __device__ __forceinline__ void operator()(T* dst_data,
const PermuteType perm_type, const T* s_data,
const std::vector<int64_t>& dims, const IndexT rows,
const std::vector<int>& perm, const IndexT cols,
const T* src, const IndexT chs_stride,
T* dst) { const IndexT round_tile_cols,
#define CALL_DISPATCH_RANK(rank) \ const IndexT col_stride = 1) {
case rank: { \ const IndexT col_in_mat = blockIdx.y * kTileSize + threadIdx.x;
LaunchPermuteKernel<T, IndexT, VecSize, rank>( \ if (col_in_mat < /*dst_cols=*/rows) {
ctx, count, perm_type, dims, perm, src, dst); \ const int cols_range = (blockIdx.x < round_tile_cols)
break; \ ? kTileSize
: (cols - round_tile_cols * kTileSize);
const IndexT row_tile = blockIdx.x * kTileSize * ReadSize;
const IndexT write_offset = blockIdx.z * chs_stride + col_in_mat;
const int shared_tile = threadIdx.x * kShareCol * ReadSize;
#pragma unroll
for (int tile_y = threadIdx.y; tile_y < cols_range;
tile_y += kBlockRows) {
const int shared_major = shared_tile + tile_y * ReadSize;
const IndexT row_major = (row_tile + tile_y * ReadSize) * col_stride;
#pragma unroll
for (int i = 0; i < ReadSize; ++i) {
const IndexT row_in_mat = row_major + i * col_stride;
dst_data[write_offset + row_in_mat * rows] = s_data[shared_major + i];
}
}
}
} }
};
switch (dims.size()) { template <typename T, typename IndexT, int VecSize, IndexT RowTile>
CALL_DISPATCH_RANK(1); struct TransposeDataReader {
CALL_DISPATCH_RANK(2); __device__ __forceinline__ void operator()(const T* __restrict__ src,
CALL_DISPATCH_RANK(3); T* s_shared,
CALL_DISPATCH_RANK(4); const IndexT cols,
CALL_DISPATCH_RANK(5); const IndexT rows,
CALL_DISPATCH_RANK(6); const IndexT chs_stride,
CALL_DISPATCH_RANK(7); const IndexT cols_thresh,
CALL_DISPATCH_RANK(8); const IndexT round_tile_rows) {
CALL_DISPATCH_RANK(9); using VecT = phi::AlignedVector<T, VecSize>;
const VecT* __restrict__ v_src =
reinterpret_cast<const VecT* __restrict__>(src);
VecT* v_shared = reinterpret_cast<VecT*>(s_shared);
const IndexT col_in_mat = blockIdx.x * kTileSize + threadIdx.x;
if (col_in_mat < cols_thresh) {
const int row_range = (blockIdx.y < round_tile_rows)
? RowTile
: (rows - RowTile * round_tile_rows);
const IndexT src_idx_major = blockIdx.z * chs_stride + col_in_mat;
#pragma unroll
for (int tile_y = threadIdx.y; tile_y < row_range; tile_y += kBlockRows) {
const IndexT row_in_mat = blockIdx.y * RowTile + tile_y;
v_shared[tile_y * kShareCol + threadIdx.x] =
v_src[row_in_mat * cols + src_idx_major];
}
}
__syncthreads();
} }
#undef CALL_DISPATCH_RANK };
}
// Aim at transposing the last 2 dimensions. Reference from // Aim at transposing the last 2 dimensions. Reference from
// https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/ // https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
template <typename T, typename IndexT, int VecSize> template <typename T,
typename IndexT,
bool IsVecWrite,
int ReadSize,
int WriteSize = (IsVecWrite && (sizeof(T) < sizeof(float)))
? sizeof(float) / sizeof(T)
: 1>
__global__ void SwapTransposeKernel(const T* __restrict__ src_data,
T* dst_data,
const IndexT round_tile_rows,
const IndexT round_tile_cols,
const IndexT cols,
const IndexT rows,
const IndexT chs /*=channel*/) {
constexpr int kRowTile = kTileSize * WriteSize;
__shared__ T s_data[kRowTile * kShareCol * ReadSize];
const IndexT chs_stride = chs * cols;
TransposeDataReader<T, IndexT, ReadSize, kRowTile>()(
src_data, s_data, chs_stride, rows, cols, cols, round_tile_rows);
TransposeDataWriter<T, IndexT, ReadSize, WriteSize>()(
dst_data, s_data, rows, cols, rows / WriteSize, round_tile_cols, chs);
}
template <typename T,
typename IndexT,
bool IsVecWrite,
int ReadSize,
int WriteSize = (IsVecWrite && (sizeof(T) < sizeof(float)))
? sizeof(float) / sizeof(T)
: 1>
__global__ void BatchTransposeKernel(const T* __restrict__ src_data, __global__ void BatchTransposeKernel(const T* __restrict__ src_data,
T* dst_data, T* dst_data,
IndexT rows, const IndexT round_tile_rows,
IndexT cols, const IndexT round_tile_cols,
IndexT round_tile_rows, const IndexT cols,
IndexT round_tile_cols) { const IndexT rows) {
using VecT = phi::AlignedVector<T, VecSize>; constexpr int kRowTile = kTileSize * WriteSize;
constexpr int kShareCol = kTileSize + 1; __shared__ T s_data[kRowTile * kShareCol * ReadSize];
__shared__ VecT v_shared[kTileSize * kShareCol];
T* s_shared = reinterpret_cast<T*>(v_shared); const IndexT chs_stride = rows * cols;
TransposeDataReader<T, IndexT, ReadSize, kRowTile>()(
// Vectorized load data from src into shared memory. [rows, cols] src_data, s_data, cols, rows, chs_stride, cols, round_tile_rows);
const VecT* __restrict__ vec_src = TransposeDataWriter<T, IndexT, ReadSize, WriteSize>()(
reinterpret_cast<const VecT* __restrict__>(src_data); dst_data,
s_data,
rows,
cols,
chs_stride * ReadSize / WriteSize,
round_tile_cols);
}
IndexT col_in_matrix = blockIdx.x * kTileSize + threadIdx.x; template <typename T, typename IndexT, int VecSize>
IndexT offset = blockIdx.z * rows * cols; struct PermuteLauncher {
public:
void operator()(const phi::GPUContext& ctx,
const int& rank,
const IndexT& count,
const PermuteType& perm_type,
const std::vector<int64_t>& dims,
const std::vector<int32_t>& perm,
const T* src,
T* dst) {
dims_ = dims;
main_cnt_ = count / VecSize;
#define CALL_PERMUTE_DISPATCH_RANK(rank_) \
case rank_: { \
Run<rank_>(ctx, perm, perm_type, count, src, dst); \
break; \
}
if (col_in_matrix < cols) { switch (rank) {
int row_range = (blockIdx.y < round_tile_rows) CALL_PERMUTE_DISPATCH_RANK(3);
? kTileSize CALL_PERMUTE_DISPATCH_RANK(4);
: (rows - kTileSize * round_tile_rows); CALL_PERMUTE_DISPATCH_RANK(5);
#pragma unroll CALL_PERMUTE_DISPATCH_RANK(6);
for (int tile_y = threadIdx.y; tile_y < row_range; tile_y += kBlockRows) { CALL_PERMUTE_DISPATCH_RANK(7);
IndexT row_in_matrix = tile_y + blockIdx.y * kTileSize; CALL_PERMUTE_DISPATCH_RANK(8);
v_shared[tile_y * kShareCol + threadIdx.x] = CALL_PERMUTE_DISPATCH_RANK(9);
vec_src[offset + row_in_matrix * cols + col_in_matrix];
} }
#undef CALL_PERMUTE_DISPATCH_RANK
} }
// Write data from shared memory into dst and private:
// dst_cols = rows, dst_rows = cols * Vecsize IndexT main_cnt_{0};
col_in_matrix = blockIdx.y * kTileSize + threadIdx.x; std::vector<int64_t> dims_;
offset = offset * VecSize + col_in_matrix;
__syncthreads(); template <int Rank>
void Run(const phi::GPUContext& ctx,
const std::vector<int32_t>& perm,
const PermuteType& perm_type,
const IndexT& count,
const T* src,
T* dst) {
auto cfg = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, main_cnt_);
if (perm_type == PermuteType::kVecPermute) {
dims_[Rank - 1] /= VecSize;
const auto params = PermuteParams<IndexT, Rank>(dims_, perm);
VectorizedPermuteKernel<T, IndexT, VecSize, Rank>
<<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
params, main_cnt_, src, dst);
} else {
IndexT tail_cnt = count - main_cnt_ * VecSize;
IndexT main_offset = count - tail_cnt;
const auto params = PermuteParams<IndexT, Rank>(dims_, perm);
if (col_in_matrix < /*dst_cols=*/rows) { GeneralPermuteKernel<T, IndexT, VecSize, Rank>
int col_range = (blockIdx.x < round_tile_cols) <<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
? kTileSize params, main_cnt_, tail_cnt, main_offset, src, dst);
: (cols - kTileSize * round_tile_cols); }
#pragma unroll }
for (IndexT tile_y = threadIdx.y; tile_y < col_range; };
tile_y += kBlockRows) {
#pragma unroll template <typename T, typename IndexT, int VecSize>
for (int i = 0; i < VecSize; ++i) { struct TransposeLauncher {
IndexT row_in_matrix = (tile_y + blockIdx.x * kTileSize) * VecSize + i; public:
IndexT shared_idx = (tile_y + threadIdx.x * kShareCol) * VecSize + i; void operator()(const phi::GPUContext& ctx,
dst_data[offset + row_in_matrix * rows] = s_shared[shared_idx]; const int& rank,
const PermuteType& perm_type,
const std::vector<int64_t>& dims,
const IndexT& num_rows_tile,
const T* src,
T* dst) {
constexpr int ReadSize = sizeof(T) > sizeof(float) ? 1 : VecSize;
const IndexT cols = dims[rank - 1] / VecSize;
const IndexT n_cols_tile = GETTILESIZE(cols, kTileSize);
if (perm_type == PermuteType::kGeneralTranspose) {
IndexT chs = (rank == 2) ? 1 : dims[0];
IndexT rows = dims[rank - 2];
IndexT n_rows_tile =
FindRowTiles(chs, rows, num_rows_tile, n_cols_tile, ctx.GetSMCount());
dim3 blocks(n_cols_tile, n_rows_tile, chs);
dim3 threads(kTileSize, kBlockRows, 1);
if (is_vec_write) {
BatchTransposeKernel<T, IndexT, true, ReadSize>
<<<blocks, threads, 0, ctx.stream()>>>(
src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows);
} else {
BatchTransposeKernel<T, IndexT, false, ReadSize>
<<<blocks, threads, 0, ctx.stream()>>>(
src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows);
}
} else {
IndexT rows = dims[0];
IndexT chs = dims[rank - 2];
IndexT n_rows_tile =
FindRowTiles(chs, rows, num_rows_tile, n_cols_tile, ctx.GetSMCount());
dim3 blocks(n_cols_tile, n_rows_tile, chs);
dim3 threads(kTileSize, kBlockRows, 1);
if (is_vec_write) {
SwapTransposeKernel<T, IndexT, true, ReadSize>
<<<blocks, threads, 0, ctx.stream()>>>(
src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows, chs);
} else {
SwapTransposeKernel<T, IndexT, false, ReadSize>
<<<blocks, threads, 0, ctx.stream()>>>(
src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows, chs);
} }
} }
} }
}
// With the byte limitation of shared_memory, the VecSize shall be private:
// restricted for the type whose byte-size is less than 4. bool is_vec_write{false};
template <typename T, inline IndexT FindRowTiles(const IndexT& chs,
typename IndexT, const IndexT& rows,
int Size, const IndexT& num_rows_tile,
int VecSize = (sizeof(T) > 4 ? 1 : Size)> const IndexT& num_cols_tile,
inline void LaunchTransposeKernel(const phi::GPUContext& ctx, const int& sm_count) {
const std::vector<int64_t>& dims, constexpr int kVecRow = sizeof(float) / sizeof(T);
const T* src, is_vec_write =
T* dst) { (sizeof(T) < sizeof(float)) ? ((rows % kVecRow) ? false : true) : false;
auto rank = dims.size();
IndexT num_batches = (rank == 2) ? 1 : dims[0]; int vec_write = 1;
IndexT rows = dims[rank - 2]; if (is_vec_write) {
IndexT cols = dims[rank - 1] / VecSize; is_vec_write = (chs * num_cols_tile * num_rows_tile) > sm_count;
IndexT num_tile_rows = (rows + kTileSize - 1) / kTileSize; vec_write = is_vec_write ? kVecRow : 1;
IndexT num_tile_cols = (cols + kTileSize - 1) / kTileSize; }
IndexT n_rows_tile = is_vec_write
dim3 blocks(num_tile_cols, num_tile_rows, num_batches); ? GETTILESIZE(rows, (kTileSize * vec_write))
dim3 threads(kTileSize, kBlockRows, 1); : num_rows_tile;
return n_rows_tile;
BatchTransposeKernel<T, IndexT, VecSize> }
<<<blocks, threads, 0, ctx.stream()>>>( };
src, dst, rows, cols, num_tile_rows - 1, num_tile_cols - 1);
}
template <typename T, typename IndexT> template <typename T, typename IndexT>
inline void LaunchWithDispatchVecSize(const phi::GPUContext& ctx, struct PermuteDispatch {
const int vec_size, public:
const PermuteType perm_type, PermuteDispatch(const phi::GPUContext& ctx,
const std::vector<int64_t>& dims, PermTypeClassifier<T>* cls_ptr,
const std::vector<int>& perm, const std::vector<int64_t>& dims,
const T* src, const std::vector<int32_t>& perm,
T* dst, const IndexT count,
IndexT count) { const T* src,
#define CALL_DISPATCH_VEC_SIZE(vec_size) \ T* dst)
case vec_size: { \ : dims_(dims), cls_(cls_ptr) {
if (perm_type == PermuteType::kTranspose) { \ rank_ = dims_.size();
LaunchTransposeKernel<T, IndexT, vec_size>(ctx, dims, src, dst); \ type_ = cls_->GetPermType();
} else { \ KernelTypeDispatch(ctx, count, perm, src, dst);
LaunchPermuteRankDispatch<T, IndexT, vec_size>( \
ctx, count, perm_type, dims, perm, src, dst); \
} \
break; \
} }
~PermuteDispatch() {}
switch (vec_size) { private:
CALL_DISPATCH_VEC_SIZE(1); int rank_{0};
CALL_DISPATCH_VEC_SIZE(2); std::vector<int64_t> dims_;
CALL_DISPATCH_VEC_SIZE(4); PermTypeClassifier<T>* cls_;
default: { PermuteType type_{kGeneralPermute};
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported vectorized size: %d !", vec_size)); void KernelTypeDispatch(const phi::GPUContext& ctx,
break; const IndexT& count,
const std::vector<int32_t>& perm,
const T* src,
T* dst) {
#define TRANSPOSE_DISPATCH_VEC_SIZE(size) \
case size: { \
TransposeLauncher<T, IndexT, size>()( \
ctx, rank_, type_, dims_, cls_->GetRowsTile(), src, dst); \
break; \
}
#define PERMUTE_DISPATCH_VEC_SIZE(size) \
case size: { \
PermuteLauncher<T, IndexT, size>()( \
ctx, rank_, count, type_, dims_, perm, src, dst); \
break; \
}
switch (type_) {
case kSwapTranspose:
case kGeneralTranspose:
switch (cls_->GetVecSize()) {
TRANSPOSE_DISPATCH_VEC_SIZE(1);
TRANSPOSE_DISPATCH_VEC_SIZE(2);
TRANSPOSE_DISPATCH_VEC_SIZE(4);
}
break;
default:
switch (cls_->GetVecSize()) {
PERMUTE_DISPATCH_VEC_SIZE(1);
PERMUTE_DISPATCH_VEC_SIZE(2);
PERMUTE_DISPATCH_VEC_SIZE(4);
}
break;
} }
#define TRANSPOSE_DISPATCH_VEC_SIZE
#define PERMUTE_DISPATCH_VEC_SIZE
} }
#undef CALL_DISPATCH_VEC_SIZE };
}
template <typename DeviceContext, typename T> template <typename T>
inline void PermuteAndTranspose(const int rank, inline void PermuteAndTranspose(const phi::GPUContext& ctx,
const DeviceContext& ctx, const int& rank,
const phi::DenseTensor& in, const phi::DenseTensor& in,
phi::DenseTensor* out, phi::DenseTensor* out,
const std::vector<int32_t>& perm) { const DimsSimplifier& simplifier) {
const int64_t numel = in.numel(); T* dst_data = out->data<T>();
auto classifier = const T* src_data = in.data<T>();
TranposeTypeClassifier<T>(ctx.GetSMCount(), const auto count = simplifier.GetCount();
rank, auto classifier = PermTypeClassifier<T>(ctx.GetSMCount(),
numel, simplifier.GetRank(),
perm, simplifier.GetPerm(),
phi::vectorize<int64_t>(in.dims()), simplifier.GetSrcDims(),
in.data<T>(), src_data,
out->data<T>()); dst_data);
if (classifier.GetPermType() == PermuteType::kCopy) { if (classifier.GetPermType() == PermuteType::kCopy) {
// If perm is [0,1,2,3], then just operate a DtoD copy. // If perm is [0,1,2,3], then just operate a DtoD copy.
phi::backends::gpu::GpuMemcpyAsync(out->data<T>(), phi::backends::gpu::GpuMemcpyAsync(dst_data,
in.data<T>(), src_data,
numel * sizeof(T), count * sizeof(T),
phi::gpuMemcpyDeviceToDevice, phi::gpuMemcpyDeviceToDevice,
ctx.stream()); ctx.stream());
} else { } else {
if (numel < std::numeric_limits<int>::max()) { if (count < std::numeric_limits<uint32_t>::max()) {
LaunchWithDispatchVecSize<T, int>(ctx, PermuteDispatch<T, uint32_t>(ctx,
classifier.GetVecSize(), &classifier,
classifier.GetPermType(), simplifier.GetSrcDims(),
classifier.GetSrcDims(), simplifier.GetPerm(),
classifier.GetPerm(), static_cast<uint32_t>(count),
in.data<T>(), src_data,
out->data<T>(), dst_data);
static_cast<int>(numel));
} else { } else {
int64_t cnt = static_cast<int64_t>(numel); PermuteDispatch<T, int64_t>(ctx,
LaunchWithDispatchVecSize<T, int64_t>(ctx, &classifier,
classifier.GetVecSize(), simplifier.GetSrcDims(),
classifier.GetPermType(), simplifier.GetPerm(),
classifier.GetSrcDims(), static_cast<int64_t>(count),
classifier.GetPerm(), src_data,
in.data<T>(), dst_data);
out->data<T>(),
static_cast<int64_t>(numel));
} }
} }
} }
template <typename T>
inline void PermuteWithEigen(const phi::GPUContext& ctx,
const int& rank,
const phi::DenseTensor& in,
phi::DenseTensor* out,
const DimsSimplifier& simplifier) {
const bool not_same_dims = simplifier.GetRank() != rank;
if (not_same_dims) {
phi::DDim dst_dims = out->dims();
phi::DenseTensor temp_in;
temp_in.ShareBufferWith(in);
temp_in.Resize(phi::make_ddim(simplifier.GetSrcDims()));
out->Resize(phi::make_ddim(simplifier.GetDstDims()));
TransCompute<phi::GPUContext, T>(
simplifier.GetRank(), ctx, temp_in, out, simplifier.GetPerm());
out->Resize(dst_dims);
} else {
TransCompute<phi::GPUContext, T>(
simplifier.GetRank(), ctx, in, out, simplifier.GetPerm());
}
}
template <typename T> template <typename T>
void TransposeGPUKernelDriver(const phi::GPUContext& ctx, void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
const phi::DenseTensor& in, const phi::DenseTensor& in,
...@@ -1177,30 +1373,26 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx, ...@@ -1177,30 +1373,26 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
phi::DenseTensor* out) { phi::DenseTensor* out) {
const int rank = perm.size(); const int rank = perm.size();
int64_t numel = in.numel(); int64_t numel = in.numel();
bool ret{false}; bool ret = TransposeSimple<T>::Impl(ctx, in, perm, out, numel);
if (numel >= std::numeric_limits<int32_t>::max()) {
ret = TransposeSimple<T, int64_t>::run(ctx, in, perm, out);
} else {
ret = TransposeSimple<T>::run(ctx, in, perm, out);
}
if (!ret) { if (!ret) {
auto* tuner = phi::autotune::MakeTransposeTuner<T>( auto simplifier =
funcs::TransCompute<phi::GPUContext, T>); DimsSimplifier(rank, numel, perm, phi::vectorize<int64_t>(in.dims()));
tuner->AddCallBack(PermuteAndTranspose<phi::GPUContext, T>); auto* tuner = phi::autotune::MakeTransposeTuner<T>(PermuteWithEigen<T>);
tuner->AddCallBack(PermuteAndTranspose<T>);
size_t key = phi::autotune::TransposeKey( size_t key = phi::autotune::TransposeKey(
phi::vectorize(in.dims()), simplifier.GetSrcDims(),
perm, simplifier.GetPerm(),
paddle::experimental::CppTypeToDataType<T>::Type()); paddle::experimental::CppTypeToDataType<T>::Type());
tuner->Run(ctx, tuner->Run(ctx,
phi::autotune::AlgorithmType::kTranspose, phi::autotune::AlgorithmType::kTranspose,
key, key,
rank,
ctx, ctx,
rank,
in, in,
out, out,
perm); simplifier);
} }
} }
......
...@@ -27,161 +27,115 @@ enum { kTransposeMKLDNNFP32 = 1, kTransposeMKLDNNINT8 = 2 }; ...@@ -27,161 +27,115 @@ enum { kTransposeMKLDNNFP32 = 1, kTransposeMKLDNNINT8 = 2 };
enum PermuteType { enum PermuteType {
kCopy = 1, kCopy = 1,
kTranspose = 2, kSwapTranspose = 2,
kVecPermute = 3, kGeneralTranspose = 3,
kGeneralPermute = 4 kVecPermute = 4,
kGeneralPermute = 5
}; };
constexpr int kBlockRows = 16; constexpr int kBlockRows = 16;
constexpr int kTileSize = 32; constexpr int kTileSize = 32;
constexpr int kShareCol = (kTileSize + 1);
#define GETTILESIZE(LEN_, ALIGN_) \
((LEN_ + (ALIGN_ - 1)) & ~(ALIGN_ - 1)) / ALIGN_
// Simplify the input dims and permute dims if possible.
template <typename T> template <typename T>
class TranposeTypeClassifier { struct PermTypeClassifier {
public: public:
TranposeTypeClassifier(const int sm_count, explicit PermTypeClassifier(const int sm_count,
const size_t rank, const int rank,
const int64_t numel, const std::vector<int32_t>& perm,
const std::vector<int32_t>& perm, const std::vector<int64_t>& dims,
const std::vector<int64_t>& dims, const T* src,
const T* src, T* dst) {
T* dst) if (rank == 1) {
: perm_(rank), src_dims(rank) { type_ = PermuteType::kCopy;
SimplifyPermAndDims(rank, dims, perm); } else {
if (rank_ > 1) { constexpr int64_t dim_limitation = 65536;
vec_size_ = GetPermVecSize(sm_count, src, dst); const int dst_vec_size = phi::GetVectorizedSize<T>(dst);
}
perm_.resize(rank_); // While the last dim is fixed, there is chance for vectorized IO.
src_dims.resize(rank_); const int last_idx = rank - 1;
dst_dims.resize(rank_); if (perm[last_idx] == last_idx) {
type_ = PermuteType::kVecPermute;
for (auto i = 0; i < rank_; ++i) { vec_size_ = GetDimVecSize(dst_vec_size, dims[last_idx], src, false);
dst_dims[i] = src_dims[perm_[i]]; return;
}
}
int GetRank() const { return rank_; }
int GetVecSize() const { return vec_size_; }
PermuteType GetPermType() const { return type_; }
std::vector<int> GetPerm() const { return perm_; }
std::vector<int64_t> GetSrcDims() const { return src_dims; }
std::vector<int64_t> GetDstDims() const { return dst_dims; }
private:
int rank_{1};
int vec_size_{1};
std::vector<int> perm_;
std::vector<int64_t> src_dims;
std::vector<int64_t> dst_dims;
PermuteType type_{kCopy};
void SimplifyPermAndDims(const size_t rank,
const std::vector<int64_t>& in_dims,
const std::vector<int32_t>& perm) {
int64_t combined_dims[phi::DDim::kMaxRank];
int valid_map[phi::DDim::kMaxRank];
// Merge consecutive dims to the fist one dim and
// leave original dim to be 1. Example below :
// perm: [2, 3, 0, 1], origin_dims : [4, 8, 2, 5]
// new_dims: [4, 8, 2, 5] -> [32, 1, 10, 1]
int start_perm_idx = 0;
while (start_perm_idx < rank) {
const int start_dim_idx = perm[start_perm_idx];
combined_dims[start_dim_idx] = in_dims[start_dim_idx];
int end_perm_idx = start_perm_idx + 1;
while (end_perm_idx < rank &&
perm[end_perm_idx] == perm[end_perm_idx - 1] + 1) {
const int end_dim_idx = perm[end_perm_idx];
combined_dims[start_dim_idx] *= in_dims[end_dim_idx];
combined_dims[end_dim_idx] = 1;
end_perm_idx += 1;
} }
start_perm_idx = end_perm_idx;
}
// Reorder combined dims and marked useless dim as -1. // Permute at last 2 dims, namely transpose.
// for example, if combined dims is [32, 1, 10, 1], if ((rank == 2 && perm[1] == 0 && perm[0] == 1) ||
// valid_map is [0, -1, 1, -1] and generate simplified (rank == 3 && perm[2] == 1 && perm[1] == 2)) {
// dims as [32, 10] int64_t channel = rank == 2 ? 1 : dims[0];
int valid_dim_idx = 0; // Currently, transpose kernel cannot cover the case that channel
bool sequential_flag = false; // dimension is more than 65536 which is the limitation of dim3 setting.
for (auto i = 0; i < rank; ++i) { // This special case will be covered by extended transpose kernel later.
const int src_dim = combined_dims[i]; if (channel < dim_limitation) {
if (src_dim == 1) { type_ = PermuteType::kGeneralTranspose;
valid_map[i] = -1; num_rows_tile_ = GETTILESIZE(dims[rank - 2], kTileSize);
} else { int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src);
sequential_flag = true; int tile_size =
valid_map[i] = valid_dim_idx; channel * num_rows_tile_ * GETTILESIZE(dims[last_idx], kTileSize);
src_dims[valid_dim_idx] = src_dim; vec_size_ = tile_size < sm_count ? 1 : dim_vec_size;
valid_dim_idx += 1; } else {
type_ = PermuteType::kGeneralPermute;
}
return;
} }
}
if (valid_dim_idx == 0) { // Permute at first dim and third dim.
src_dims[0] = 1; if (rank == 3 && perm[2] == 0 && perm[1] == 1) {
perm_[0] = 0; // Currently, transpose kernel cannot cover the case that channel
return; // dimension is more than 65536 which is the limitation of dim3 setting.
} else if (valid_dim_idx == 1) { // This special case will be covered by extended transpose kernel later.
type_ = PermuteType::kCopy; if (dims[1] < dim_limitation) {
} type_ = PermuteType::kSwapTranspose;
num_rows_tile_ = GETTILESIZE(dims[0], kTileSize);
// Acquire simplified perm with help of combined dims
// and original perm, finally simplified perm is [1, 0] int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src);
int perm_idx = 0; int tile_size =
for (auto i = 0; i < rank; ++i) { dims[1] * num_rows_tile_ * GETTILESIZE(dims[2], kTileSize);
const int mapped = valid_map[perm[i]]; vec_size_ = tile_size < sm_count ? 1 : dim_vec_size;
if (mapped >= 0) { } else {
perm_[perm_idx] = mapped; type_ = PermuteType::kGeneralPermute;
perm_idx += 1; }
return;
} }
vec_size_ = dst_vec_size;
} }
rank_ = valid_dim_idx;
} }
int GetPermVecSize(const int sm_count, const T* src, T* dst) { ~PermTypeClassifier() = default;
// For gerneal_permute kernel, there is good chance for
// vectorized write.
type_ = PermuteType::kGeneralPermute;
int vec_size = phi::GetVectorizedSize<T>(dst);
// While the last dim is fixed, there is good chance for
// both vectorized read and write.
if (perm_[rank_ - 1] == rank_ - 1) {
int tmp_size = std::min(vec_size, phi::GetVectorizedSize<T>(src));
tmp_size = GetDimVesSize(tmp_size, src_dims[rank_ - 1]);
if (tmp_size > 1) {
type_ = kVecPermute;
vec_size = tmp_size;
}
}
// Once only transpose at the last 2 dims, there is good int GetVecSize() const { return vec_size_; }
// chance for vectorized read. int GetRowsTile() const { return num_rows_tile_; }
if ((rank_ == 2 && perm_[1] == 0 && perm_[0] == 1) || PermuteType GetPermType() const { return type_; }
(rank_ == 3 && perm_[2] == 1 && perm_[1] == 2)) {
type_ = PermuteType::kTranspose; private:
int tmp_vec = std::min(vec_size, phi::GetVectorizedSize<T>(src)); int vec_size_{1};
// With bytes limitation of shared_memory, the VecSize shall be int64_t num_rows_tile_{0};
// restricted for the type whose byte-size is less than 8 (double). PermuteType type_{kGeneralPermute};
vec_size =
sizeof(T) > 8 ? 1 : GetDimVesSize(tmp_vec, src_dims[rank_ - 1]);
}
return vec_size;
}
// To find if highest common divisor and make it as vec_size. // To find if highest common divisor and make it as vec_size.
int GetDimVesSize(const int vec_size, const size_t target_dim) { int GetDimVecSize(const int dst_vec_size,
const int64_t target_dim,
const T* src,
bool use_share_mem = true) {
const int vec_size = std::min(dst_vec_size, phi::GetVectorizedSize<T>(src));
int dim_vec_size = 1; int dim_vec_size = 1;
for (auto size = vec_size; size > 0; size /= 2) { for (int size = vec_size; size > 0; size /= 2) {
if (target_dim % size == 0) { if (target_dim % size == 0) {
dim_vec_size = size; dim_vec_size = size;
break; break;
} }
} }
return dim_vec_size;
if (use_share_mem) {
// By bytes limitation of shared_memory.
return (sizeof(T) > sizeof(float) ? 1 : dim_vec_size);
} else {
return dim_vec_size;
}
} }
}; };
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/transpose_functor.cu.h" #include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h" #include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h"
namespace phi { namespace phi {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册