Unverified commit a0f43889 authored by limingshu and committed by GitHub

Transpose optimization for AlphaFold2 (#45230)

* first commit

* fix bugs reported by CI

* add some changes

* rename the file to transpose_function.cu.h

* remove const_cast
Parent 30f4ef7f
......@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/funcs/transpose_functor.cu.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace paddle {
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/funcs/transpose_functor.cu.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace paddle {
......
......@@ -243,5 +243,106 @@ struct BroadcastDimsSimplifier {
}
};
// Simplify the input dims and permute dims if possible.
struct DimsSimplifier {
public:
explicit DimsSimplifier(const int rank,
const int64_t numel,
const std::vector<int32_t> &perm,
const std::vector<int64_t> &dims)
: perm_(rank), src_dims_(rank), count_(numel) {
SimplifyPermAndDims(rank, dims, perm);
perm_.resize(rank_);
src_dims_.resize(rank_);
dst_dims_.resize(rank_);
if (!is_seq_perm_) {
for (auto i = 0; i < rank_; ++i) {
dst_dims_[i] = src_dims_[perm_[i]];
}
} else {
dst_dims_[0] = numel;
src_dims_[0] = numel;
}
}
~DimsSimplifier() = default;
const int &GetRank() const { return rank_; }
const int64_t &GetCount() const { return count_; }
const std::vector<int> &GetPerm() const { return perm_; }
const std::vector<int64_t> &GetSrcDims() const { return src_dims_; }
const std::vector<int64_t> &GetDstDims() const { return dst_dims_; }
private:
int rank_{1};
int64_t count_{0};
bool is_seq_perm_{true};
std::vector<int> perm_;
std::vector<int64_t> src_dims_;
std::vector<int64_t> dst_dims_;
void SimplifyPermAndDims(const int rank,
const std::vector<int64_t> &in_dims,
const std::vector<int32_t> &perm) {
int start_perm_idx = 0;
int valid_dim_idx = 0;
int valid_map[phi::DDim::kMaxRank];
int64_t combined_dims[phi::DDim::kMaxRank];
// Merge each run of consecutive dims into the first dim of the run and
// set the merged-away dims to 1. Example below:
// perm: [2, 3, 0, 1], origin_dims : [4, 8, 2, 5]
// new_dims: [4, 8, 2, 5] -> [32, 1, 10, 1]
while (start_perm_idx < rank) {
const int start_dim_idx = perm[start_perm_idx];
combined_dims[start_dim_idx] = in_dims[start_dim_idx];
int end_perm_idx = start_perm_idx + 1;
while (end_perm_idx < rank &&
perm[end_perm_idx] == perm[end_perm_idx - 1] + 1) {
const int end_dim_idx = perm[end_perm_idx];
combined_dims[start_dim_idx] *= in_dims[end_dim_idx];
combined_dims[end_dim_idx] = 1;
end_perm_idx += 1;
}
start_perm_idx = end_perm_idx;
}
// Reorder the combined dims and mark useless dims as -1.
// For example, if the combined dims are [32, 1, 10, 1],
// valid_map is [0, -1, 1, -1] and the simplified
// dims are [32, 10].
for (auto i = 0; i < rank; ++i) {
const int64_t dim_val = combined_dims[i];
if (dim_val == 1) {
valid_map[i] = -1;
} else {
valid_map[i] = valid_dim_idx;
src_dims_[valid_dim_idx] = dim_val;
valid_dim_idx += 1;
}
}
if (valid_dim_idx == 0) {
src_dims_[0] = 1;
perm_[0] = 0;
return;
}
// Acquire the simplified perm with the help of the combined dims
// and the original perm; the final simplified perm is [1, 0].
int perm_idx = 0;
for (auto i = 0; i < rank; ++i) {
const int mapped = valid_map[perm[i]];
if (mapped >= 0) {
perm_[perm_idx] = mapped;
is_seq_perm_ &= (mapped == perm_idx);
perm_idx += 1;
}
}
rank_ = is_seq_perm_ ? 1 : valid_dim_idx;
}
};
} // namespace funcs
} // namespace phi
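For illustration, here is a minimal standalone sketch of the merge rule described in the comments above (a hypothetical helper, not part of this commit; the degenerate all-ones shape is ignored). For perm = {2, 3, 0, 1} and dims = {4, 8, 2, 5}, the runs {2, 3} and {0, 1} each collapse into one axis, giving src_dims = {32, 10} and perm = {1, 0}, i.e. a plain 2-D transpose.

#include <cstdint>
#include <vector>

struct SimplifiedLayout {
  std::vector<int64_t> src_dims;
  std::vector<int> perm;
};

inline SimplifiedLayout SimplifyForIllustration(const std::vector<int64_t>& dims,
                                                const std::vector<int>& perm) {
  SimplifiedLayout out;
  std::vector<int64_t> merged(dims.size(), 1);  // folded-away dims stay 1
  std::vector<int> new_pos(dims.size(), -1);    // old dim index -> new index
  size_t i = 0;
  while (i < perm.size()) {
    const int head = perm[i];
    merged[head] = dims[head];
    size_t j = i + 1;
    // Fold source dims that perm references consecutively into the run head.
    while (j < perm.size() && perm[j] == perm[j - 1] + 1) {
      merged[head] *= dims[perm[j]];
      ++j;
    }
    i = j;
  }
  int idx = 0;
  for (size_t k = 0; k < merged.size(); ++k) {
    if (merged[k] != 1) {  // drop size-1 dims
      new_pos[k] = idx++;
      out.src_dims.push_back(merged[k]);
    }
  }
  for (const int p : perm) {
    if (new_pos[p] >= 0) out.perm.push_back(new_pos[p]);
  }
  return out;
}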
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_utils.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/autotune/auto_tune_base.h"
#include "paddle/phi/kernels/funcs/dims_simplifier.h"
#include "paddle/phi/kernels/funcs/transpose_functor.h"
#include "paddle/phi/kernels/primitive/datamover_primitives.h"
......@@ -191,7 +192,6 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
IndexType output_origin_block_flat_index =
FlatTensorIndex<IndexType>(block_tile_index_in_output, output_dims);
constexpr IndexType out_effective_thread_num = NumThreads / TileX * TileX;
if (x < out_effective_thread_num) {
......@@ -652,7 +652,7 @@ struct SwapDim0And2InTranspose {
inline void CombineTransposeDim3(const DDim& shape,
const std::vector<int>& perm,
std::vector<int>* new_perm,
DDim* new_dims) {
std::vector<int>* new_dims) {
PADDLE_ENFORCE_EQ(shape.size(),
perm.size(),
phi::errors::InvalidArgument(
......@@ -667,114 +667,111 @@ inline void CombineTransposeDim3(const DDim& shape,
new_perm->resize(1);
(*new_perm)[0] = perm[0];
dim_vec.push_back(shape[0]);
*new_dims = phi::make_ddim(dim_vec);
return;
}
std::vector<int> new_dim_pos(shape.size(), -1);
std::vector<int64_t> combined_dims(shape.size(), 0);
int cur_head = perm[0];
new_dim_pos[cur_head] = 0;
combined_dims[0] = shape[cur_head];
int dim_idx = 0;
for (int perm_idx = 1; perm_idx < shape.size(); ++perm_idx) {
// combine consecutive dimensions.
if (cur_head + 1 == perm[perm_idx]) {
cur_head = perm[perm_idx];
combined_dims[dim_idx] *= shape[cur_head];
} else {
// Else start a new dimension.
cur_head = perm[perm_idx];
dim_idx++;
new_dim_pos[cur_head] = dim_idx;
combined_dims[dim_idx] = shape[cur_head];
} else {
int dim_idx = 0;
std::vector<int> new_dim_pos(shape.size(), -1);
std::vector<int> combined_dims(shape.size(), 0);
int cur_head = perm[0];
new_dim_pos[cur_head] = 0;
combined_dims[0] = shape[cur_head];
for (int perm_idx = 1; perm_idx < shape.size(); ++perm_idx) {
// combine consecutive dimensions.
if (cur_head + 1 == perm[perm_idx]) {
cur_head = perm[perm_idx];
combined_dims[dim_idx] *= shape[cur_head];
} else {
// Else start a new dimension.
cur_head = perm[perm_idx];
dim_idx++;
new_dim_pos[cur_head] = dim_idx;
combined_dims[dim_idx] = shape[cur_head];
}
}
}
new_perm->resize(dim_idx + 1);
dim_idx = 0;
for (int i = 0; i < new_dim_pos.size(); ++i) {
if (new_dim_pos[i] >= 0) {
int new_perm_idx = new_dim_pos[i];
(*new_perm)[dim_idx] = new_perm_idx;
dim_vec.push_back(combined_dims[new_perm_idx]);
dim_idx++;
new_perm->resize(dim_idx + 1);
dim_idx = 0;
for (int i = 0; i < new_dim_pos.size(); ++i) {
if (new_dim_pos[i] >= 0) {
int new_perm_idx = new_dim_pos[i];
(*new_perm)[dim_idx] = new_perm_idx;
dim_vec.push_back(combined_dims[new_perm_idx]);
dim_idx++;
}
}
}
*new_dims = phi::make_ddim(dim_vec);
*new_dims = dim_vec;
}
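As a hedged illustration of the combine rule (the wrapper function is hypothetical; the resulting values follow from the merge logic above):

inline void CombineTransposeDim3Example() {
  std::vector<int> new_perm;
  std::vector<int> new_dims;
  // shape = [4, 8, 2, 5], perm = [2, 3, 0, 1]
  CombineTransposeDim3(
      phi::make_ddim({4, 8, 2, 5}), {2, 3, 0, 1}, &new_perm, &new_dims);
  // Afterwards: new_dims == {32, 10} (the merged source dims) and
  // new_perm == {1, 0}, i.e. the 4-D permutation collapses to a 2-D transpose.
}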
template <typename T, typename IndexType = int>
template <typename T>
struct TransposeSimple {
static bool run(const phi::GPUContext& ctx,
static bool Impl(const phi::GPUContext& ctx,
const phi::DenseTensor& in,
const std::vector<int32_t> perm,
phi::DenseTensor* out,
const int64_t numel) {
if (numel >= std::numeric_limits<int32_t>::max()) {
return Run<int64_t>(ctx, in, perm, out);
} else {
return Run<int32_t>(ctx, in, perm, out);
}
}
private:
template <typename IndexType = int32_t>
static bool Run(const phi::GPUContext& ctx,
const phi::DenseTensor& in,
const std::vector<int32_t> perm,
phi::DenseTensor* out) {
// First reduce the dimensions of the input tensor if possible.
auto in_data = in.data<T>();
auto out_data = out->data<T>();
std::vector<int> new_perm;
DDim new_dims;
std::vector<int> new_dims;
CombineTransposeDim3(in.dims(), perm, &new_perm, &new_dims);
if (new_perm.size() < 2 || new_perm.size() > 3) return false;
// Only use tile copy GPU kernel when dimension is 2 or 3.
int dims = new_dims.size();
std::vector<int> new_dim_vec = phi::vectorize<int>(new_dims);
if (dims < 2 || dims > 3) return false;
auto in_data = in.data<T>();
auto out_data = out->data<T>();
// In most cases, dim will not be greater than 3 after combining.
switch (dims) {
case 2:
if (new_perm[0] == 1 && new_perm[1] == 0) {
// Add the first dimension size as 1.
new_dim_vec.insert(new_dim_vec.begin(), 1);
SwapDim1And2InTranspose<T, IndexType>()(
ctx, in_data, new_dim_vec, out_data);
return true;
}
break;
case 3:
// In this case, suppose we can do coalescing read and write in tile.
if (new_perm == std::vector<int>({0, 2, 1})) {
SwapDim1And2InTranspose<T, IndexType>()(
ctx, in_data, new_dim_vec, out_data);
return true;
} else if (new_perm == std::vector<int>({2, 1, 0})) {
// Maybe can optimize later, find a way to do coalescing memory copy.
// But I think it depends on the data size. If span is not large,
// maybe
// can do coalescing.
SwapDim0And2InTranspose<T, IndexType>()(
ctx, in_data, new_dim_vec, out_data);
return true;
} else {
return false;
}
break;
default:
return false;
if (new_perm.size() == 2 && new_perm[1] == 0) {
// Add the first dimension size as 1.
new_dims.insert(new_dims.begin(), 1);
SwapDim1And2InTranspose<T, IndexType>()(ctx, in_data, new_dims, out_data);
return true;
} else if (new_perm == std::vector<int>({0, 2, 1})) {
SwapDim1And2InTranspose<T, IndexType>()(ctx, in_data, new_dims, out_data);
return true;
} else if (new_perm == std::vector<int>({2, 1, 0})) {
// This may be optimized later by finding a way to do coalesced memory
// copies, but that likely depends on the data size: if the span is not
// large, coalescing may still be possible.
SwapDim0And2InTranspose<T, IndexType>()(ctx, in_data, new_dims, out_data);
return true;
} else {
return false;
}
return false;
}
};
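A brief usage sketch of this fast path (the wrapper name is hypothetical; ctx, in, and out are assumed to be prepared by the caller):

inline bool TryFastTranspose(const phi::GPUContext& ctx,
                             const phi::DenseTensor& in,
                             phi::DenseTensor* out) {
  // Returns true only when the combined permutation reduces to one of the
  // 2-D/3-D tile-copy cases handled above; otherwise the caller falls back
  // to the general permute path defined later in this file.
  return TransposeSimple<float>::Impl(ctx, in, /*perm=*/{0, 2, 1}, out, in.numel());
}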
template <int N, typename T>
template <typename IndexT, int N>
class IdxHelper {
public:
IdxHelper() {}
explicit IdxHelper(const T* dims) {
explicit IdxHelper(const IndexT* dims) {
for (int i = N - 1; i >= 0; --i) {
stride_[i] = i < (N - 1) ? dims[i + 1] * stride_[i + 1] : 1;
}
}
__device__ inline T GetStride(int idx) const { return stride_[idx]; }
__device__ __forceinline__ IndexT GetStride(int idx) const {
return stride_[idx];
}
__device__ inline void GetIndexFromOffset(T offset, T* index) const {
T remaining = offset;
__device__ __forceinline__ void GetIndexFromOffset(IndexT offset,
IndexT* index) const {
IndexT remaining = offset;
#pragma unroll
for (int i = 0; i < N - 1; ++i) {
const T idx = remaining / stride_[i];
const IndexT idx = remaining / stride_[i];
remaining -= idx * stride_[i];
index[i] = idx;
}
......@@ -782,11 +779,11 @@ class IdxHelper {
}
private:
T stride_[N];
IndexT stride_[N];
};
template <int N>
class IdxHelper<N, uint32_t> {
class IdxHelper<uint32_t, N> {
public:
IdxHelper() {}
explicit IdxHelper(const uint32_t* dims) {
......@@ -797,10 +794,12 @@ class IdxHelper<N, uint32_t> {
}
}
__device__ inline uint32_t GetStride(int idx) const { return stride_[idx]; }
__device__ __forceinline__ uint32_t GetStride(int idx) const {
return stride_[idx];
}
__device__ inline void GetIndexFromOffset(uint32_t offset,
uint32_t* index) const {
__device__ __forceinline__ void GetIndexFromOffset(uint32_t offset,
uint32_t* index) const {
uint32_t remaining = offset;
#pragma unroll
for (int i = 0; i < N - 1; ++i) {
......@@ -817,18 +816,16 @@ class IdxHelper<N, uint32_t> {
};
// Transform indices between memory offsets and shape coordinates.
template <typename T, int N>
template <typename IndexT, int N>
class IdxAndOffsetHelper {
public:
IdxAndOffsetHelper() {}
~IdxAndOffsetHelper() = default;
explicit IdxAndOffsetHelper(const T* dims) {
index_helper = IdxHelper<N, T>(dims);
explicit IdxAndOffsetHelper(const IndexT* dims) {
index_helper = IdxHelper<IndexT, N>(dims);
}
__device__ inline T IndexToOffset(const T* index) const {
T offset = 0;
__device__ __forceinline__ IndexT IndexToOffset(const IndexT* index) const {
IndexT offset = 0;
#pragma unroll
for (int i = 0; i < N - 1; ++i) {
offset += index[i] * index_helper.GetStride(i);
......@@ -837,15 +834,16 @@ class IdxAndOffsetHelper {
return offset;
}
__device__ inline void OffsetToIndex(T offset, T* index) const {
__device__ __forceinline__ void OffsetToIndex(IndexT offset,
IndexT* index) const {
index_helper.GetIndexFromOffset(offset, index);
}
private:
IdxHelper<N, T> index_helper;
IdxHelper<IndexT, N> index_helper;
};
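These helpers implement standard row-major stride arithmetic; for reference, a host-side sketch of the same math for a rank-3 shape (illustrative only, not used by the kernels):

// For dims = {d0, d1, d2} the strides are {d1 * d2, d2, 1}, so
//   offset = i0 * d1 * d2 + i1 * d2 + i2,
// and the inverse decomposition is:
inline void OffsetToIndex3(int64_t offset, const int64_t dims[3], int64_t idx[3]) {
  const int64_t stride0 = dims[1] * dims[2];
  idx[0] = offset / stride0;
  offset -= idx[0] * stride0;
  idx[1] = offset / dims[2];
  idx[2] = offset - idx[1] * dims[2];
}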
template <size_t Rank, typename IndexT>
template <typename IndexT, int Rank>
struct PermuteParams {
public:
IdxAndOffsetHelper<IndexT, Rank> src_index_helper;
......@@ -868,17 +866,17 @@ struct PermuteParams {
// A special kernel for target case, both vectorized read and write supported.
template <typename T, typename IndexT, int VecSize, int Rank>
__global__ void VectorizedPermuteKernel(PermuteParams<Rank, IndexT> params,
const size_t count,
__global__ void VectorizedPermuteKernel(PermuteParams<IndexT, Rank> params,
const IndexT count,
const T* __restrict__ src_data,
T* dst_data) {
using VecT = phi::AlignedVector<T, VecSize>;
IndexT src_index[Rank];
IndexT dst_index[Rank];
const VecT* __restrict__ src =
const VecT* __restrict__ vec_src =
reinterpret_cast<const VecT* __restrict__>(src_data);
VecT* dst = reinterpret_cast<VecT*>(dst_data);
VecT* vec_dst = reinterpret_cast<VecT*>(dst_data);
IndexT tid = blockIdx.x * blockDim.x + threadIdx.x;
for (IndexT i = tid; i < count; i += blockDim.x * gridDim.x) {
......@@ -889,31 +887,23 @@ __global__ void VectorizedPermuteKernel(PermuteParams<Rank, IndexT> params,
src_index[params.perm[j]] = dst_index[j];
}
IndexT src_offset = params.src_index_helper.IndexToOffset(src_index);
dst[i] = src[src_offset];
vec_dst[i] = vec_src[src_offset];
}
}
// A general kernel for normal case, only support vectorized write.
template <typename T, typename IndexT, int VecSize, int Rank>
__global__ void GeneralPermuteKernel(PermuteParams<Rank, IndexT> params,
__global__ void GeneralPermuteKernel(PermuteParams<IndexT, Rank> params,
const IndexT main_cnt,
const IndexT tail_cnt,
const IndexT offset,
const T* __restrict__ src,
T* dst,
const size_t main_cnt,
const size_t tail_cnt,
const size_t offset) {
T* dst) {
using VecT = phi::AlignedVector<T, VecSize>;
VecT* vec_dst = reinterpret_cast<VecT*>(dst);
IndexT src_index[VecSize][Rank];
IndexT dst_index[VecSize][Rank];
// Avoid read perm data both in 2 load process.
__shared__ int perm[Rank];
if (threadIdx.x < Rank) {
perm[threadIdx.x] = params.perm[threadIdx.x];
}
__syncthreads();
// Vectorized load data.
IndexT tid = blockIdx.x * blockDim.x + threadIdx.x;
for (IndexT idx = tid; idx < main_cnt; idx += blockDim.x * gridDim.x) {
......@@ -926,7 +916,7 @@ __global__ void GeneralPermuteKernel(PermuteParams<Rank, IndexT> params,
#pragma unroll
for (int j = 0; j < Rank; ++j) {
src_index[i][perm[j]] = dst_index[i][j];
src_index[i][params.perm[j]] = dst_index[i][j];
}
IndexT src_offset = params.src_index_helper.IndexToOffset(src_index[i]);
vec_data[i] = src[src_offset];
......@@ -941,235 +931,441 @@ __global__ void GeneralPermuteKernel(PermuteParams<Rank, IndexT> params,
#pragma unroll
for (int j = 0; j < Rank; ++j) {
src_index[0][perm[j]] = dst_index[0][j];
src_index[0][params.perm[j]] = dst_index[0][j];
}
IndexT src_offset = params.src_index_helper.IndexToOffset(src_index[0]);
dst[idx] = src[src_offset];
}
}
// A general permute method that directly finds the dst data
// coordinate in the source data.
template <typename T, typename IndexT, int VecSize, int Rank>
inline void LaunchPermuteKernel(const phi::GPUContext& ctx,
const IndexT count,
const PermuteType perm_type,
const std::vector<int64_t>& dims,
const std::vector<int>& perm,
const T* src,
T* dst) {
size_t main_count = count / VecSize;
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, main_count);
if (perm_type == PermuteType::kGeneralPermute) {
size_t tail_count = count - main_count * VecSize;
size_t offset = count - tail_count;
auto params = PermuteParams<Rank, IndexT>(dims, perm);
GeneralPermuteKernel<T, IndexT, VecSize, Rank>
<<<config.GetGridSize(), config.GetBlockSize(), 0, ctx.stream()>>>(
params, src, dst, main_count, tail_count, offset);
} else {
std::vector<int64_t> vec_dims(dims);
vec_dims[dims.size() - 1] /= VecSize;
auto params = PermuteParams<Rank, IndexT>(vec_dims, perm);
VectorizedPermuteKernel<T, IndexT, VecSize, Rank>
<<<config.GetGridSize(), config.GetBlockSize(), 0, ctx.stream()>>>(
params, main_count, src, dst);
template <typename T, typename IndexT, int ReadSize, int WriteSize>
struct TransposeDataWriter {
__device__ __forceinline__ void operator()(T* dst_data,
const T* s_data,
const IndexT rows,
const IndexT cols,
const IndexT chs_stride,
const IndexT round_tile_cols,
const IndexT col_stride = 1) {
using OutVecT = phi::AlignedVector<T, WriteSize>;
OutVecT* vec_dst = reinterpret_cast<OutVecT*>(dst_data);
constexpr int kColTile = kTileSize * ReadSize;
constexpr int kColStride = kShareCol * ReadSize;
const IndexT vec_rows = rows / WriteSize;
const IndexT col_in_mat = blockIdx.y * kTileSize + threadIdx.x;
if (col_in_mat < /*dst_cols=*/vec_rows) {
const int cols_range = (blockIdx.x < round_tile_cols)
? kTileSize
: (cols - round_tile_cols * kTileSize);
const int share_tile = threadIdx.x * (WriteSize * kColStride);
const IndexT write_offset = blockIdx.z * chs_stride + col_in_mat;
#pragma unroll
for (int tile_y = threadIdx.y; tile_y < cols_range;
tile_y += kBlockRows) {
OutVecT tmp_data[ReadSize];
#pragma unroll
for (int i = 0; i < ReadSize; ++i) {
int tile_tail = tile_y * ReadSize + i;
int major_share_idx = share_tile + tile_tail;
IndexT row_in_mat = (blockIdx.x * kColTile + tile_tail) * col_stride;
#pragma unroll
for (int j = 0; j < WriteSize; ++j) {
tmp_data[i].val[j] = s_data[j * kColStride + major_share_idx];
}
vec_dst[write_offset + row_in_mat * vec_rows] = tmp_data[i];
}
}
}
}
}
};
template <typename T, typename IndexT, int VecSize>
inline void LaunchPermuteRankDispatch(const phi::GPUContext& ctx,
const IndexT count,
const PermuteType perm_type,
const std::vector<int64_t>& dims,
const std::vector<int>& perm,
const T* src,
T* dst) {
#define CALL_DISPATCH_RANK(rank) \
case rank: { \
LaunchPermuteKernel<T, IndexT, VecSize, rank>( \
ctx, count, perm_type, dims, perm, src, dst); \
break; \
template <typename T, typename IndexT, int ReadSize>
struct TransposeDataWriter<T, IndexT, ReadSize, 1> {
__device__ __forceinline__ void operator()(T* dst_data,
const T* s_data,
const IndexT rows,
const IndexT cols,
const IndexT chs_stride,
const IndexT round_tile_cols,
const IndexT col_stride = 1) {
const IndexT col_in_mat = blockIdx.y * kTileSize + threadIdx.x;
if (col_in_mat < /*dst_cols=*/rows) {
const int cols_range = (blockIdx.x < round_tile_cols)
? kTileSize
: (cols - round_tile_cols * kTileSize);
const IndexT row_tile = blockIdx.x * kTileSize * ReadSize;
const IndexT write_offset = blockIdx.z * chs_stride + col_in_mat;
const int shared_tile = threadIdx.x * kShareCol * ReadSize;
#pragma unroll
for (int tile_y = threadIdx.y; tile_y < cols_range;
tile_y += kBlockRows) {
const int shared_major = shared_tile + tile_y * ReadSize;
const IndexT row_major = (row_tile + tile_y * ReadSize) * col_stride;
#pragma unroll
for (int i = 0; i < ReadSize; ++i) {
const IndexT row_in_mat = row_major + i * col_stride;
dst_data[write_offset + row_in_mat * rows] = s_data[shared_major + i];
}
}
}
}
};
switch (dims.size()) {
CALL_DISPATCH_RANK(1);
CALL_DISPATCH_RANK(2);
CALL_DISPATCH_RANK(3);
CALL_DISPATCH_RANK(4);
CALL_DISPATCH_RANK(5);
CALL_DISPATCH_RANK(6);
CALL_DISPATCH_RANK(7);
CALL_DISPATCH_RANK(8);
CALL_DISPATCH_RANK(9);
template <typename T, typename IndexT, int VecSize, IndexT RowTile>
struct TransposeDataReader {
__device__ __forceinline__ void operator()(const T* __restrict__ src,
T* s_shared,
const IndexT cols,
const IndexT rows,
const IndexT chs_stride,
const IndexT cols_thresh,
const IndexT round_tile_rows) {
using VecT = phi::AlignedVector<T, VecSize>;
const VecT* __restrict__ v_src =
reinterpret_cast<const VecT* __restrict__>(src);
VecT* v_shared = reinterpret_cast<VecT*>(s_shared);
const IndexT col_in_mat = blockIdx.x * kTileSize + threadIdx.x;
if (col_in_mat < cols_thresh) {
const int row_range = (blockIdx.y < round_tile_rows)
? RowTile
: (rows - RowTile * round_tile_rows);
const IndexT src_idx_major = blockIdx.z * chs_stride + col_in_mat;
#pragma unroll
for (int tile_y = threadIdx.y; tile_y < row_range; tile_y += kBlockRows) {
const IndexT row_in_mat = blockIdx.y * RowTile + tile_y;
v_shared[tile_y * kShareCol + threadIdx.x] =
v_src[row_in_mat * cols + src_idx_major];
}
}
__syncthreads();
}
#undef CALL_DISPATCH_RANK
}
};
// Aims at transposing the last 2 dimensions. Reference:
// https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
template <typename T, typename IndexT, int VecSize>
template <typename T,
typename IndexT,
bool IsVecWrite,
int ReadSize,
int WriteSize = (IsVecWrite && (sizeof(T) < sizeof(float)))
? sizeof(float) / sizeof(T)
: 1>
__global__ void SwapTransposeKernel(const T* __restrict__ src_data,
T* dst_data,
const IndexT round_tile_rows,
const IndexT round_tile_cols,
const IndexT cols,
const IndexT rows,
const IndexT chs /*=channel*/) {
constexpr int kRowTile = kTileSize * WriteSize;
__shared__ T s_data[kRowTile * kShareCol * ReadSize];
const IndexT chs_stride = chs * cols;
TransposeDataReader<T, IndexT, ReadSize, kRowTile>()(
src_data, s_data, chs_stride, rows, cols, cols, round_tile_rows);
TransposeDataWriter<T, IndexT, ReadSize, WriteSize>()(
dst_data, s_data, rows, cols, rows / WriteSize, round_tile_cols, chs);
}
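The transpose kernels in this file build on the classic tiled shared-memory transpose from the NVIDIA post referenced above. For orientation, a simplified, non-vectorized, non-batched sketch of that base pattern (hypothetical kernel, assuming a (kTileSize, kBlockRows) thread block as used by the launchers below):

// Each block stages a kTileSize x kTileSize tile in shared memory, padded by
// one extra column (the same trick as kShareCol = kTileSize + 1) to avoid
// shared-memory bank conflicts, then writes the tile back transposed.
template <typename T>
__global__ void NaiveTiledTranspose(const T* __restrict__ src,
                                    T* dst,
                                    int rows,
                                    int cols) {
  __shared__ T tile[kTileSize][kTileSize + 1];
  int x = blockIdx.x * kTileSize + threadIdx.x;
  int y = blockIdx.y * kTileSize + threadIdx.y;
  for (int j = 0; j < kTileSize; j += kBlockRows) {
    if (x < cols && (y + j) < rows) {
      tile[threadIdx.y + j][threadIdx.x] = src[(y + j) * cols + x];
    }
  }
  __syncthreads();
  // Swap the block offsets so writes to dst (a cols x rows matrix) stay
  // coalesced along threadIdx.x.
  x = blockIdx.y * kTileSize + threadIdx.x;
  y = blockIdx.x * kTileSize + threadIdx.y;
  for (int j = 0; j < kTileSize; j += kBlockRows) {
    if (x < rows && (y + j) < cols) {
      dst[(y + j) * rows + x] = tile[threadIdx.x][threadIdx.y + j];
    }
  }
}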
template <typename T,
typename IndexT,
bool IsVecWrite,
int ReadSize,
int WriteSize = (IsVecWrite && (sizeof(T) < sizeof(float)))
? sizeof(float) / sizeof(T)
: 1>
__global__ void BatchTransposeKernel(const T* __restrict__ src_data,
T* dst_data,
IndexT rows,
IndexT cols,
IndexT round_tile_rows,
IndexT round_tile_cols) {
using VecT = phi::AlignedVector<T, VecSize>;
constexpr int kShareCol = kTileSize + 1;
__shared__ VecT v_shared[kTileSize * kShareCol];
T* s_shared = reinterpret_cast<T*>(v_shared);
// Vectorized load data from src into shared memory. [rows, cols]
const VecT* __restrict__ vec_src =
reinterpret_cast<const VecT* __restrict__>(src_data);
const IndexT round_tile_rows,
const IndexT round_tile_cols,
const IndexT cols,
const IndexT rows) {
constexpr int kRowTile = kTileSize * WriteSize;
__shared__ T s_data[kRowTile * kShareCol * ReadSize];
const IndexT chs_stride = rows * cols;
TransposeDataReader<T, IndexT, ReadSize, kRowTile>()(
src_data, s_data, cols, rows, chs_stride, cols, round_tile_rows);
TransposeDataWriter<T, IndexT, ReadSize, WriteSize>()(
dst_data,
s_data,
rows,
cols,
chs_stride * ReadSize / WriteSize,
round_tile_cols);
}
IndexT col_in_matrix = blockIdx.x * kTileSize + threadIdx.x;
IndexT offset = blockIdx.z * rows * cols;
template <typename T, typename IndexT, int VecSize>
struct PermuteLauncher {
public:
void operator()(const phi::GPUContext& ctx,
const int& rank,
const IndexT& count,
const PermuteType& perm_type,
const std::vector<int64_t>& dims,
const std::vector<int32_t>& perm,
const T* src,
T* dst) {
dims_ = dims;
main_cnt_ = count / VecSize;
#define CALL_PERMUTE_DISPATCH_RANK(rank_) \
case rank_: { \
Run<rank_>(ctx, perm, perm_type, count, src, dst); \
break; \
}
if (col_in_matrix < cols) {
int row_range = (blockIdx.y < round_tile_rows)
? kTileSize
: (rows - kTileSize * round_tile_rows);
#pragma unroll
for (int tile_y = threadIdx.y; tile_y < row_range; tile_y += kBlockRows) {
IndexT row_in_matrix = tile_y + blockIdx.y * kTileSize;
v_shared[tile_y * kShareCol + threadIdx.x] =
vec_src[offset + row_in_matrix * cols + col_in_matrix];
switch (rank) {
CALL_PERMUTE_DISPATCH_RANK(3);
CALL_PERMUTE_DISPATCH_RANK(4);
CALL_PERMUTE_DISPATCH_RANK(5);
CALL_PERMUTE_DISPATCH_RANK(6);
CALL_PERMUTE_DISPATCH_RANK(7);
CALL_PERMUTE_DISPATCH_RANK(8);
CALL_PERMUTE_DISPATCH_RANK(9);
}
#undef CALL_PERMUTE_DISPATCH_RANK
}
// Write data from shared memory into dst and
// dst_cols = rows, dst_rows = cols * Vecsize
col_in_matrix = blockIdx.y * kTileSize + threadIdx.x;
offset = offset * VecSize + col_in_matrix;
__syncthreads();
private:
IndexT main_cnt_{0};
std::vector<int64_t> dims_;
template <int Rank>
void Run(const phi::GPUContext& ctx,
const std::vector<int32_t>& perm,
const PermuteType& perm_type,
const IndexT& count,
const T* src,
T* dst) {
auto cfg = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, main_cnt_);
if (perm_type == PermuteType::kVecPermute) {
dims_[Rank - 1] /= VecSize;
const auto params = PermuteParams<IndexT, Rank>(dims_, perm);
VectorizedPermuteKernel<T, IndexT, VecSize, Rank>
<<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
params, main_cnt_, src, dst);
} else {
IndexT tail_cnt = count - main_cnt_ * VecSize;
IndexT main_offset = count - tail_cnt;
const auto params = PermuteParams<IndexT, Rank>(dims_, perm);
if (col_in_matrix < /*dst_cols=*/rows) {
int col_range = (blockIdx.x < round_tile_cols)
? kTileSize
: (cols - kTileSize * round_tile_cols);
#pragma unroll
for (IndexT tile_y = threadIdx.y; tile_y < col_range;
tile_y += kBlockRows) {
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
IndexT row_in_matrix = (tile_y + blockIdx.x * kTileSize) * VecSize + i;
IndexT shared_idx = (tile_y + threadIdx.x * kShareCol) * VecSize + i;
dst_data[offset + row_in_matrix * rows] = s_shared[shared_idx];
GeneralPermuteKernel<T, IndexT, VecSize, Rank>
<<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
params, main_cnt_, tail_cnt, main_offset, src, dst);
}
}
};
template <typename T, typename IndexT, int VecSize>
struct TransposeLauncher {
public:
void operator()(const phi::GPUContext& ctx,
const int& rank,
const PermuteType& perm_type,
const std::vector<int64_t>& dims,
const IndexT& num_rows_tile,
const T* src,
T* dst) {
constexpr int ReadSize = sizeof(T) > sizeof(float) ? 1 : VecSize;
const IndexT cols = dims[rank - 1] / VecSize;
const IndexT n_cols_tile = GETTILESIZE(cols, kTileSize);
if (perm_type == PermuteType::kGeneralTranspose) {
IndexT chs = (rank == 2) ? 1 : dims[0];
IndexT rows = dims[rank - 2];
IndexT n_rows_tile =
FindRowTiles(chs, rows, num_rows_tile, n_cols_tile, ctx.GetSMCount());
dim3 blocks(n_cols_tile, n_rows_tile, chs);
dim3 threads(kTileSize, kBlockRows, 1);
if (is_vec_write) {
BatchTransposeKernel<T, IndexT, true, ReadSize>
<<<blocks, threads, 0, ctx.stream()>>>(
src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows);
} else {
BatchTransposeKernel<T, IndexT, false, ReadSize>
<<<blocks, threads, 0, ctx.stream()>>>(
src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows);
}
} else {
IndexT rows = dims[0];
IndexT chs = dims[rank - 2];
IndexT n_rows_tile =
FindRowTiles(chs, rows, num_rows_tile, n_cols_tile, ctx.GetSMCount());
dim3 blocks(n_cols_tile, n_rows_tile, chs);
dim3 threads(kTileSize, kBlockRows, 1);
if (is_vec_write) {
SwapTransposeKernel<T, IndexT, true, ReadSize>
<<<blocks, threads, 0, ctx.stream()>>>(
src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows, chs);
} else {
SwapTransposeKernel<T, IndexT, false, ReadSize>
<<<blocks, threads, 0, ctx.stream()>>>(
src, dst, n_rows_tile - 1, n_cols_tile - 1, cols, rows, chs);
}
}
}
}
// With the byte limitation of shared_memory, the VecSize shall be
// restricted for the type whose byte-size is less than 4.
template <typename T,
typename IndexT,
int Size,
int VecSize = (sizeof(T) > 4 ? 1 : Size)>
inline void LaunchTransposeKernel(const phi::GPUContext& ctx,
const std::vector<int64_t>& dims,
const T* src,
T* dst) {
auto rank = dims.size();
IndexT num_batches = (rank == 2) ? 1 : dims[0];
IndexT rows = dims[rank - 2];
IndexT cols = dims[rank - 1] / VecSize;
IndexT num_tile_rows = (rows + kTileSize - 1) / kTileSize;
IndexT num_tile_cols = (cols + kTileSize - 1) / kTileSize;
dim3 blocks(num_tile_cols, num_tile_rows, num_batches);
dim3 threads(kTileSize, kBlockRows, 1);
BatchTransposeKernel<T, IndexT, VecSize>
<<<blocks, threads, 0, ctx.stream()>>>(
src, dst, rows, cols, num_tile_rows - 1, num_tile_cols - 1);
}
private:
bool is_vec_write{false};
inline IndexT FindRowTiles(const IndexT& chs,
const IndexT& rows,
const IndexT& num_rows_tile,
const IndexT& num_cols_tile,
const int& sm_count) {
constexpr int kVecRow = sizeof(float) / sizeof(T);
is_vec_write =
(sizeof(T) < sizeof(float)) ? ((rows % kVecRow) ? false : true) : false;
int vec_write = 1;
if (is_vec_write) {
is_vec_write = (chs * num_cols_tile * num_rows_tile) > sm_count;
vec_write = is_vec_write ? kVecRow : 1;
}
IndexT n_rows_tile = is_vec_write
? GETTILESIZE(rows, (kTileSize * vec_write))
: num_rows_tile;
return n_rows_tile;
}
};
template <typename T, typename IndexT>
inline void LaunchWithDispatchVecSize(const phi::GPUContext& ctx,
const int vec_size,
const PermuteType perm_type,
const std::vector<int64_t>& dims,
const std::vector<int>& perm,
const T* src,
T* dst,
IndexT count) {
#define CALL_DISPATCH_VEC_SIZE(vec_size) \
case vec_size: { \
if (perm_type == PermuteType::kTranspose) { \
LaunchTransposeKernel<T, IndexT, vec_size>(ctx, dims, src, dst); \
} else { \
LaunchPermuteRankDispatch<T, IndexT, vec_size>( \
ctx, count, perm_type, dims, perm, src, dst); \
} \
break; \
struct PermuteDispatch {
public:
PermuteDispatch(const phi::GPUContext& ctx,
PermTypeClassifier<T>* cls_ptr,
const std::vector<int64_t>& dims,
const std::vector<int32_t>& perm,
const IndexT count,
const T* src,
T* dst)
: dims_(dims), cls_(cls_ptr) {
rank_ = dims_.size();
type_ = cls_->GetPermType();
KernelTypeDispatch(ctx, count, perm, src, dst);
}
~PermuteDispatch() {}
switch (vec_size) {
CALL_DISPATCH_VEC_SIZE(1);
CALL_DISPATCH_VEC_SIZE(2);
CALL_DISPATCH_VEC_SIZE(4);
default: {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported vectorized size: %d !", vec_size));
break;
private:
int rank_{0};
std::vector<int64_t> dims_;
PermTypeClassifier<T>* cls_;
PermuteType type_{kGeneralPermute};
void KernelTypeDispatch(const phi::GPUContext& ctx,
const IndexT& count,
const std::vector<int32_t>& perm,
const T* src,
T* dst) {
#define TRANSPOSE_DISPATCH_VEC_SIZE(size) \
case size: { \
TransposeLauncher<T, IndexT, size>()( \
ctx, rank_, type_, dims_, cls_->GetRowsTile(), src, dst); \
break; \
}
#define PERMUTE_DISPATCH_VEC_SIZE(size) \
case size: { \
PermuteLauncher<T, IndexT, size>()( \
ctx, rank_, count, type_, dims_, perm, src, dst); \
break; \
}
switch (type_) {
case kSwapTranspose:
case kGeneralTranspose:
switch (cls_->GetVecSize()) {
TRANSPOSE_DISPATCH_VEC_SIZE(1);
TRANSPOSE_DISPATCH_VEC_SIZE(2);
TRANSPOSE_DISPATCH_VEC_SIZE(4);
}
break;
default:
switch (cls_->GetVecSize()) {
PERMUTE_DISPATCH_VEC_SIZE(1);
PERMUTE_DISPATCH_VEC_SIZE(2);
PERMUTE_DISPATCH_VEC_SIZE(4);
}
break;
}
#undef TRANSPOSE_DISPATCH_VEC_SIZE
#undef PERMUTE_DISPATCH_VEC_SIZE
}
#undef CALL_DISPATCH_VEC_SIZE
}
};
template <typename DeviceContext, typename T>
inline void PermuteAndTranspose(const int rank,
const DeviceContext& ctx,
template <typename T>
inline void PermuteAndTranspose(const phi::GPUContext& ctx,
const int& rank,
const phi::DenseTensor& in,
phi::DenseTensor* out,
const std::vector<int32_t>& perm) {
const int64_t numel = in.numel();
auto classifier =
TranposeTypeClassifier<T>(ctx.GetSMCount(),
rank,
numel,
perm,
phi::vectorize<int64_t>(in.dims()),
in.data<T>(),
out->data<T>());
const DimsSimplifier& simplifier) {
T* dst_data = out->data<T>();
const T* src_data = in.data<T>();
const auto count = simplifier.GetCount();
auto classifier = PermTypeClassifier<T>(ctx.GetSMCount(),
simplifier.GetRank(),
simplifier.GetPerm(),
simplifier.GetSrcDims(),
src_data,
dst_data);
if (classifier.GetPermType() == PermuteType::kCopy) {
// If perm is the identity (e.g. [0, 1, 2, 3]), just perform a DtoD copy.
phi::backends::gpu::GpuMemcpyAsync(out->data<T>(),
in.data<T>(),
numel * sizeof(T),
phi::backends::gpu::GpuMemcpyAsync(dst_data,
src_data,
count * sizeof(T),
phi::gpuMemcpyDeviceToDevice,
ctx.stream());
} else {
if (numel < std::numeric_limits<int>::max()) {
LaunchWithDispatchVecSize<T, int>(ctx,
classifier.GetVecSize(),
classifier.GetPermType(),
classifier.GetSrcDims(),
classifier.GetPerm(),
in.data<T>(),
out->data<T>(),
static_cast<int>(numel));
if (count < std::numeric_limits<uint32_t>::max()) {
PermuteDispatch<T, uint32_t>(ctx,
&classifier,
simplifier.GetSrcDims(),
simplifier.GetPerm(),
static_cast<uint32_t>(count),
src_data,
dst_data);
} else {
int64_t cnt = static_cast<int64_t>(numel);
LaunchWithDispatchVecSize<T, int64_t>(ctx,
classifier.GetVecSize(),
classifier.GetPermType(),
classifier.GetSrcDims(),
classifier.GetPerm(),
in.data<T>(),
out->data<T>(),
static_cast<int64_t>(numel));
PermuteDispatch<T, int64_t>(ctx,
&classifier,
simplifier.GetSrcDims(),
simplifier.GetPerm(),
static_cast<int64_t>(count),
src_data,
dst_data);
}
}
}
template <typename T>
inline void PermuteWithEigen(const phi::GPUContext& ctx,
const int& rank,
const phi::DenseTensor& in,
phi::DenseTensor* out,
const DimsSimplifier& simplifier) {
const bool not_same_dims = simplifier.GetRank() != rank;
if (not_same_dims) {
phi::DDim dst_dims = out->dims();
phi::DenseTensor temp_in;
temp_in.ShareBufferWith(in);
temp_in.Resize(phi::make_ddim(simplifier.GetSrcDims()));
out->Resize(phi::make_ddim(simplifier.GetDstDims()));
TransCompute<phi::GPUContext, T>(
simplifier.GetRank(), ctx, temp_in, out, simplifier.GetPerm());
out->Resize(dst_dims);
} else {
TransCompute<phi::GPUContext, T>(
simplifier.GetRank(), ctx, in, out, simplifier.GetPerm());
}
}
template <typename T>
void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
const phi::DenseTensor& in,
......@@ -1177,30 +1373,26 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
phi::DenseTensor* out) {
const int rank = perm.size();
int64_t numel = in.numel();
bool ret{false};
if (numel >= std::numeric_limits<int32_t>::max()) {
ret = TransposeSimple<T, int64_t>::run(ctx, in, perm, out);
} else {
ret = TransposeSimple<T>::run(ctx, in, perm, out);
}
bool ret = TransposeSimple<T>::Impl(ctx, in, perm, out, numel);
if (!ret) {
auto* tuner = phi::autotune::MakeTransposeTuner<T>(
funcs::TransCompute<phi::GPUContext, T>);
tuner->AddCallBack(PermuteAndTranspose<phi::GPUContext, T>);
auto simplifier =
DimsSimplifier(rank, numel, perm, phi::vectorize<int64_t>(in.dims()));
auto* tuner = phi::autotune::MakeTransposeTuner<T>(PermuteWithEigen<T>);
tuner->AddCallBack(PermuteAndTranspose<T>);
size_t key = phi::autotune::TransposeKey(
phi::vectorize(in.dims()),
perm,
simplifier.GetSrcDims(),
simplifier.GetPerm(),
paddle::experimental::CppTypeToDataType<T>::Type());
tuner->Run(ctx,
phi::autotune::AlgorithmType::kTranspose,
key,
rank,
ctx,
rank,
in,
out,
perm);
simplifier);
}
}
......
......@@ -27,161 +27,115 @@ enum { kTransposeMKLDNNFP32 = 1, kTransposeMKLDNNINT8 = 2 };
enum PermuteType {
kCopy = 1,
kTranspose = 2,
kVecPermute = 3,
kGeneralPermute = 4
kSwapTranspose = 2,
kGeneralTranspose = 3,
kVecPermute = 4,
kGeneralPermute = 5
};
constexpr int kBlockRows = 16;
constexpr int kTileSize = 32;
constexpr int kShareCol = (kTileSize + 1);
#define GETTILESIZE(LEN_, ALIGN_) \
((LEN_ + (ALIGN_ - 1)) & ~(ALIGN_ - 1)) / ALIGN_
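GETTILESIZE is a ceiling division: it rounds LEN_ up to the next multiple of ALIGN_ (assuming ALIGN_ is a power of two, as kTileSize is) and then divides. For example:

static_assert(GETTILESIZE(70, 32) == 3, "70 elements need 3 tiles of 32");
static_assert(GETTILESIZE(64, 32) == 2, "exact multiples are unchanged");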
// Simplify the input dims and permute dims if possible.
template <typename T>
class TranposeTypeClassifier {
struct PermTypeClassifier {
public:
TranposeTypeClassifier(const int sm_count,
const size_t rank,
const int64_t numel,
const std::vector<int32_t>& perm,
const std::vector<int64_t>& dims,
const T* src,
T* dst)
: perm_(rank), src_dims(rank) {
SimplifyPermAndDims(rank, dims, perm);
if (rank_ > 1) {
vec_size_ = GetPermVecSize(sm_count, src, dst);
}
perm_.resize(rank_);
src_dims.resize(rank_);
dst_dims.resize(rank_);
for (auto i = 0; i < rank_; ++i) {
dst_dims[i] = src_dims[perm_[i]];
}
}
int GetRank() const { return rank_; }
int GetVecSize() const { return vec_size_; }
PermuteType GetPermType() const { return type_; }
std::vector<int> GetPerm() const { return perm_; }
std::vector<int64_t> GetSrcDims() const { return src_dims; }
std::vector<int64_t> GetDstDims() const { return dst_dims; }
private:
int rank_{1};
int vec_size_{1};
std::vector<int> perm_;
std::vector<int64_t> src_dims;
std::vector<int64_t> dst_dims;
PermuteType type_{kCopy};
void SimplifyPermAndDims(const size_t rank,
const std::vector<int64_t>& in_dims,
const std::vector<int32_t>& perm) {
int64_t combined_dims[phi::DDim::kMaxRank];
int valid_map[phi::DDim::kMaxRank];
// Merge consecutive dims to the fist one dim and
// leave original dim to be 1. Example below :
// perm: [2, 3, 0, 1], origin_dims : [4, 8, 2, 5]
// new_dims: [4, 8, 2, 5] -> [32, 1, 10, 1]
int start_perm_idx = 0;
while (start_perm_idx < rank) {
const int start_dim_idx = perm[start_perm_idx];
combined_dims[start_dim_idx] = in_dims[start_dim_idx];
int end_perm_idx = start_perm_idx + 1;
while (end_perm_idx < rank &&
perm[end_perm_idx] == perm[end_perm_idx - 1] + 1) {
const int end_dim_idx = perm[end_perm_idx];
combined_dims[start_dim_idx] *= in_dims[end_dim_idx];
combined_dims[end_dim_idx] = 1;
end_perm_idx += 1;
explicit PermTypeClassifier(const int sm_count,
const int rank,
const std::vector<int32_t>& perm,
const std::vector<int64_t>& dims,
const T* src,
T* dst) {
if (rank == 1) {
type_ = PermuteType::kCopy;
} else {
constexpr int64_t dim_limitation = 65536;
const int dst_vec_size = phi::GetVectorizedSize<T>(dst);
// While the last dim is fixed, there is a chance for vectorized IO.
const int last_idx = rank - 1;
if (perm[last_idx] == last_idx) {
type_ = PermuteType::kVecPermute;
vec_size_ = GetDimVecSize(dst_vec_size, dims[last_idx], src, false);
return;
}
start_perm_idx = end_perm_idx;
}
// Reorder combined dims and marked useless dim as -1.
// for example, if combined dims is [32, 1, 10, 1],
// valid_map is [0, -1, 1, -1] and generate simplified
// dims as [32, 10]
int valid_dim_idx = 0;
bool sequential_flag = false;
for (auto i = 0; i < rank; ++i) {
const int src_dim = combined_dims[i];
if (src_dim == 1) {
valid_map[i] = -1;
} else {
sequential_flag = true;
valid_map[i] = valid_dim_idx;
src_dims[valid_dim_idx] = src_dim;
valid_dim_idx += 1;
// Permuting only the last 2 dims is a plain transpose.
if ((rank == 2 && perm[1] == 0 && perm[0] == 1) ||
(rank == 3 && perm[2] == 1 && perm[1] == 2)) {
int64_t channel = rank == 2 ? 1 : dims[0];
// Currently, the transpose kernel cannot cover the case where the channel
// dimension exceeds 65536, which is the limit of the dim3 grid setting.
// This special case will be covered by an extended transpose kernel later.
if (channel < dim_limitation) {
type_ = PermuteType::kGeneralTranspose;
num_rows_tile_ = GETTILESIZE(dims[rank - 2], kTileSize);
int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src);
int tile_size =
channel * num_rows_tile_ * GETTILESIZE(dims[last_idx], kTileSize);
vec_size_ = tile_size < sm_count ? 1 : dim_vec_size;
} else {
type_ = PermuteType::kGeneralPermute;
}
return;
}
}
if (valid_dim_idx == 0) {
src_dims[0] = 1;
perm_[0] = 0;
return;
} else if (valid_dim_idx == 1) {
type_ = PermuteType::kCopy;
}
// Acquire simplified perm with help of combined dims
// and original perm, finally simplified perm is [1, 0]
int perm_idx = 0;
for (auto i = 0; i < rank; ++i) {
const int mapped = valid_map[perm[i]];
if (mapped >= 0) {
perm_[perm_idx] = mapped;
perm_idx += 1;
// Permute the first and the third dims, i.e. a swap transpose.
if (rank == 3 && perm[2] == 0 && perm[1] == 1) {
// Currently, the transpose kernel cannot cover the case where the channel
// dimension exceeds 65536, which is the limit of the dim3 grid setting.
// This special case will be covered by an extended transpose kernel later.
if (dims[1] < dim_limitation) {
type_ = PermuteType::kSwapTranspose;
num_rows_tile_ = GETTILESIZE(dims[0], kTileSize);
int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src);
int tile_size =
dims[1] * num_rows_tile_ * GETTILESIZE(dims[2], kTileSize);
vec_size_ = tile_size < sm_count ? 1 : dim_vec_size;
} else {
type_ = PermuteType::kGeneralPermute;
}
return;
}
vec_size_ = dst_vec_size;
}
rank_ = valid_dim_idx;
}
int GetPermVecSize(const int sm_count, const T* src, T* dst) {
// For gerneal_permute kernel, there is good chance for
// vectorized write.
type_ = PermuteType::kGeneralPermute;
int vec_size = phi::GetVectorizedSize<T>(dst);
// While the last dim is fixed, there is good chance for
// both vectorized read and write.
if (perm_[rank_ - 1] == rank_ - 1) {
int tmp_size = std::min(vec_size, phi::GetVectorizedSize<T>(src));
tmp_size = GetDimVesSize(tmp_size, src_dims[rank_ - 1]);
if (tmp_size > 1) {
type_ = kVecPermute;
vec_size = tmp_size;
}
}
~PermTypeClassifier() = default;
// Once only transpose at the last 2 dims, there is good
// chance for vectorized read.
if ((rank_ == 2 && perm_[1] == 0 && perm_[0] == 1) ||
(rank_ == 3 && perm_[2] == 1 && perm_[1] == 2)) {
type_ = PermuteType::kTranspose;
int tmp_vec = std::min(vec_size, phi::GetVectorizedSize<T>(src));
// With bytes limitation of shared_memory, the VecSize shall be
// restricted for the type whose byte-size is less than 8 (double).
vec_size =
sizeof(T) > 8 ? 1 : GetDimVesSize(tmp_vec, src_dims[rank_ - 1]);
}
return vec_size;
}
int GetVecSize() const { return vec_size_; }
int GetRowsTile() const { return num_rows_tile_; }
PermuteType GetPermType() const { return type_; }
private:
int vec_size_{1};
int64_t num_rows_tile_{0};
PermuteType type_{kGeneralPermute};
// Find the largest power-of-two vec size that divides target_dim.
int GetDimVesSize(const int vec_size, const size_t target_dim) {
int GetDimVecSize(const int dst_vec_size,
const int64_t target_dim,
const T* src,
bool use_share_mem = true) {
const int vec_size = std::min(dst_vec_size, phi::GetVectorizedSize<T>(src));
int dim_vec_size = 1;
for (auto size = vec_size; size > 0; size /= 2) {
for (int size = vec_size; size > 0; size /= 2) {
if (target_dim % size == 0) {
dim_vec_size = size;
break;
}
}
return dim_vec_size;
if (use_share_mem) {
// Restricted by the byte-size limit of shared memory.
return (sizeof(T) > sizeof(float) ? 1 : dim_vec_size);
} else {
return dim_vec_size;
}
}
};
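For illustration, how the classification above behaves on a few already-simplified layouts (the shapes, SM count, and wrapper are made up; null buffers are used only because the classifier inspects pointer alignment rather than data):

inline void PermTypeClassifierExamples() {
  float* src = nullptr;
  float* dst = nullptr;
  PermTypeClassifier<float> a(/*sm_count=*/80, 3, {0, 2, 1}, {8, 128, 64}, src, dst);
  // a.GetPermType() == kGeneralTranspose : only the last two dims swap.
  PermTypeClassifier<float> b(/*sm_count=*/80, 3, {2, 1, 0}, {8, 128, 64}, src, dst);
  // b.GetPermType() == kSwapTranspose    : dim 0 and dim 2 exchange.
  PermTypeClassifier<float> c(/*sm_count=*/80, 3, {1, 0, 2}, {8, 128, 64}, src, dst);
  // c.GetPermType() == kVecPermute       : last dim untouched, vectorized IO.
}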
......
......@@ -21,7 +21,7 @@
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/transpose_functor.cu.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h"
namespace phi {
......