Unverified commit ec7fe888, authored by limingshu and committed by GitHub

Fix bugs in transpose kernel (#47212)

* first commit

* transpose_kernel_optimization

* first accomplishment of transpose op

* second commit

* refine code logic of transpose_kernel

* refine transpose kernel

* first commit

* fix DtoD copy bugs for hip

* refine code according to the PR advice

* change dim to int64_t type.

* fix some type error
Parent 399047d7
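
As a reading aid for the kernel changes below, here is a minimal, self-contained sketch of the classic tiled shared-memory transpose pattern from the NVIDIA blog post referenced in the diff. TILE, BLOCK_ROWS, and TiledTransposeSketch are illustrative names and values chosen here (the diff itself uses kTileSize = 32, kBlockRows = 16, and vectorized loads); this is not the PR's actual kernel.

#include <cuda_runtime.h>

constexpr int TILE = 32;        // illustrative tile width, similar to kTileSize
constexpr int BLOCK_ROWS = 8;   // illustrative row stride per thread block

template <typename T>
__global__ void TiledTransposeSketch(const T* __restrict__ src,
                                     T* dst,
                                     int rows,
                                     int cols) {
  // +1 column of padding keeps threads of a warp on different shared banks.
  __shared__ T tile[TILE][TILE + 1];
  int x = blockIdx.x * TILE + threadIdx.x;  // column index in src
  int y = blockIdx.y * TILE + threadIdx.y;  // row index in src
  // Stage one TILE x TILE tile of src into shared memory, row by row.
  for (int j = 0; j < TILE; j += BLOCK_ROWS) {
    if (x < cols && (y + j) < rows) {
      tile[threadIdx.y + j][threadIdx.x] = src[(y + j) * cols + x];
    }
  }
  __syncthreads();
  x = blockIdx.y * TILE + threadIdx.x;      // column index in dst (= src row)
  y = blockIdx.x * TILE + threadIdx.y;      // row index in dst (= src column)
  // Write the tile back in transposed order; dst has shape [cols, rows].
  for (int j = 0; j < TILE; j += BLOCK_ROWS) {
    if (x < rows && (y + j) < cols) {
      dst[(y + j) * rows + x] = tile[threadIdx.x][threadIdx.y + j];
    }
  }
}

Such a kernel would be launched with dim3 threads(TILE, BLOCK_ROWS) and one block per TILE x TILE tile of the source matrix, which mirrors how LaunchTransposeKernel configures its grid in the diff.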
@@ -18,7 +18,6 @@ limitations under the License. */
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/autotune/auto_tune_base.h"
@@ -832,15 +831,6 @@ class IdxAndOffsetHelper {
index_helper = IdxHelper<N, T>(dims);
}
template <typename U>
explicit IdxAndOffsetHelper(const U* dims) {
T temp_dims[N];
for (int i = 0; i < N; ++i) {
temp_dims[i] = static_cast<T>(dims[i]);
}
index_helper = IdxHelper<N, T>(temp_dims);
}
__device__ inline T IndexToOffset(const T* index) const {
T offset = 0;
#pragma unroll
@@ -866,15 +856,17 @@ struct PermuteParams {
IdxAndOffsetHelper<IndexT, Rank> dst_index_helper;
int perm[Rank]{};
explicit PermuteParams(const std::vector<size_t>& dims,
explicit PermuteParams(const std::vector<int64_t>& dims,
const std::vector<int>& perm_) {
size_t dst_dims[Rank];
for (size_t i = 0; i < Rank; ++i) {
IndexT dst_dims[Rank];
IndexT src_dims[Rank];
for (auto i = 0; i < Rank; ++i) {
src_dims[i] = dims[i];
dst_dims[i] = dims[perm_[i]];
perm[i] = perm_[i];
}
dst_index_helper = IdxAndOffsetHelper<IndexT, Rank>(dst_dims);
src_index_helper = IdxAndOffsetHelper<IndexT, Rank>(dims.data());
src_index_helper = IdxAndOffsetHelper<IndexT, Rank>(src_dims);
}
};
@@ -966,21 +958,26 @@ template <typename T, typename IndexT, int VecSize, int Rank>
inline void LaunchPermuteKernel(const phi::GPUContext& ctx,
const IndexT count,
const PermuteType perm_type,
const std::vector<size_t>& dims,
const std::vector<int64_t>& dims,
const std::vector<int>& perm,
const T* src,
T* dst) {
size_t main_count = count / VecSize;
auto params = PermuteParams<Rank, IndexT>(dims, perm);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, main_count);
if (perm_type == PermuteType::kNormalPermute) {
if (perm_type == PermuteType::kGeneralPermute) {
size_t tail_count = count - main_count * VecSize;
size_t offset = count - tail_count;
auto params = PermuteParams<Rank, IndexT>(dims, perm);
GeneralPermuteKernel<T, IndexT, VecSize, Rank>
<<<config.GetGridSize(), config.GetBlockSize(), 0, ctx.stream()>>>(
params, src, dst, main_count, tail_count, offset);
} else {
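    // Vectorized path: each access moves VecSize contiguous elements, so the
    // innermost dim is divided by VecSize before building the index helpers.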
std::vector<int64_t> vec_dims(dims);
vec_dims[dims.size() - 1] /= VecSize;
auto params = PermuteParams<Rank, IndexT>(vec_dims, perm);
VectorizedPermuteKernel<T, IndexT, VecSize, Rank>
<<<config.GetGridSize(), config.GetBlockSize(), 0, ctx.stream()>>>(
params, main_count, src, dst);
@@ -991,7 +988,7 @@ template <typename T, typename IndexT, int VecSize>
inline void LaunchPermuteRankDispatch(const phi::GPUContext& ctx,
const IndexT count,
const PermuteType perm_type,
const std::vector<size_t>& dims,
const std::vector<int64_t>& dims,
const std::vector<int>& perm,
const T* src,
T* dst) {
@@ -1016,70 +1013,76 @@ inline void LaunchPermuteRankDispatch(const phi::GPUContext& ctx,
#undef CALL_DISPATCH_RANK
}
// Aim at transposing the last 2 dimensions. Refer from
// Aim at transposing the last 2 dimensions. Reference from
// https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
template <typename T, typename IndexT, int VecSize>
__global__ void BatchTransposeKernel(const T* __restrict__ src_data,
T* dst_data,
IndexT rows,
IndexT cols) {
IndexT cols,
IndexT round_tile_rows,
IndexT round_tile_cols) {
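  // round_tile_rows / round_tile_cols are the indices of the last tile along
  // each dimension: blocks before them cover full kTileSize tiles, while the
  // last block covers only the remaining partial tile.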
using VecT = phi::AlignedVector<T, VecSize>;
__shared__ VecT tile[kTileSize][kShareCol];
T* single_tile = reinterpret_cast<T*>(tile);
IndexT col_in_matrix = blockIdx.x * kTileSize + threadIdx.x;
IndexT offset = blockIdx.z * rows * cols;
constexpr int kShareCol = kTileSize + 1;
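  // The extra shared-memory column avoids bank conflicts when the tile is
  // read back column-wise below.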
__shared__ VecT v_shared[kTileSize * kShareCol];
T* s_shared = reinterpret_cast<T*>(v_shared);
// Vectorized load of data from src into shared memory; src layout is [rows, cols].
const VecT* __restrict__ src =
const VecT* __restrict__ vec_src =
reinterpret_cast<const VecT* __restrict__>(src_data);
for (IndexT tile_y = threadIdx.y; tile_y < kTileSize; tile_y += kBlockRows) {
IndexT row_in_matrix = tile_y + blockIdx.y * kTileSize;
IndexT col_in_matrix = blockIdx.x * kTileSize + threadIdx.x;
IndexT offset = blockIdx.z * rows * cols;
if (col_in_matrix < cols && row_in_matrix < rows) {
tile[tile_y][threadIdx.x] =
src[offset + row_in_matrix * cols + col_in_matrix];
if (col_in_matrix < cols) {
int row_range = (blockIdx.y < round_tile_rows)
? kTileSize
: (rows - kTileSize * round_tile_rows);
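    // Full tiles load kTileSize rows; the last tile row loads only the rows
    // that remain in the matrix.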
#pragma unroll
for (int tile_y = threadIdx.y; tile_y < row_range; tile_y += kBlockRows) {
IndexT row_in_matrix = tile_y + blockIdx.y * kTileSize;
v_shared[tile_y * kShareCol + threadIdx.x] =
vec_src[offset + row_in_matrix * cols + col_in_matrix];
}
}
// Singularized load data from shared memory into dst.
// and dst_cols = rows, dst_rows = cols, [cols * Vecsize, rows]
// Write data from shared memory into dst, where
// dst_cols = rows and dst_rows = cols * VecSize.
col_in_matrix = blockIdx.y * kTileSize + threadIdx.x;
offset = offset * VecSize + col_in_matrix;
IndexT tile_x_idx = threadIdx.x * (kShareCol * VecSize);
__syncthreads();
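  // Wait until the whole tile is staged in shared memory before writing it
  // out in transposed order.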
for (IndexT tile_y = threadIdx.y; tile_y < kTileSize; tile_y += kBlockRows) {
IndexT row_in_matrix = tile_y + blockIdx.x * kTileSize;
IndexT dst_idx = offset + row_in_matrix * VecSize * rows;
IndexT tile_idx = tile_x_idx + tile_y * VecSize;
if (col_in_matrix < /*dst_cols=*/rows &&
row_in_matrix < /*dst_rows=*/cols) {
if (col_in_matrix < /*dst_cols=*/rows) {
int col_range = (blockIdx.x < round_tile_cols)
? kTileSize
: (cols - kTileSize * round_tile_cols);
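    // Likewise for the write phase: the last column tile covers only the
    // remaining source columns.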
#pragma unroll
for (auto i = 0; i < VecSize; ++i) {
dst_data[dst_idx + i * rows] = single_tile[tile_idx + i];
for (IndexT tile_y = threadIdx.y; tile_y < col_range;
tile_y += kBlockRows) {
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
IndexT row_in_matrix = (tile_y + blockIdx.x * kTileSize) * VecSize + i;
IndexT shared_idx = (tile_y + threadIdx.x * kShareCol) * VecSize + i;
dst_data[offset + row_in_matrix * rows] = s_shared[shared_idx];
}
}
}
}
// With the byte limitation of shared_memory, the VecSize shall be restricted
// for the type whose byte-size is less than 8.
// With the byte limitation of shared_memory, the VecSize shall be
// restricted for types whose byte-size is less than 4.
template <typename T,
typename IndexT,
int Size,
int VecSize = (sizeof(T) > 8 ? 1 : Size)>
int VecSize = (sizeof(T) > 4 ? 1 : Size)>
inline void LaunchTransposeKernel(const phi::GPUContext& ctx,
const std::vector<size_t>& dims,
const std::vector<int64_t>& dims,
const T* src,
T* dst) {
auto rank = dims.size();
IndexT num_batches = (rank == 2) ? 1 : dims[0];
IndexT rows = dims[rank - 2];
IndexT cols = dims[rank - 1];
IndexT cols = dims[rank - 1] / VecSize;
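  // The innermost dim is traversed in vectors of VecSize elements, so the
  // effective column count shrinks by VecSize.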
IndexT num_tile_rows = (rows + kTileSize - 1) / kTileSize;
IndexT num_tile_cols = (cols + kTileSize - 1) / kTileSize;
@@ -1087,14 +1090,15 @@ inline void LaunchTransposeKernel(const phi::GPUContext& ctx,
dim3 threads(kTileSize, kBlockRows, 1);
BatchTransposeKernel<T, IndexT, VecSize>
<<<blocks, threads, 0, ctx.stream()>>>(src, dst, rows, cols);
<<<blocks, threads, 0, ctx.stream()>>>(
src, dst, rows, cols, num_tile_rows - 1, num_tile_cols - 1);
}
template <typename T, typename IndexT>
inline void LaunchWithDispatchVecSize(const phi::GPUContext& ctx,
const int vec_size,
const PermuteType perm_type,
const std::vector<size_t>& dims,
const std::vector<int64_t>& dims,
const std::vector<int>& perm,
const T* src,
T* dst,
@@ -1123,60 +1127,50 @@ inline void LaunchWithDispatchVecSize(const phi::GPUContext& ctx,
#undef CALL_DISPATCH_VEC_SIZE
}
template <typename T>
inline void LaunchWithDispatchIndex(const phi::GPUContext& ctx,
const size_t count,
const int vec_size,
const PermuteType perm_type,
const std::vector<size_t>& dims,
const std::vector<int>& perm,
const T* src,
T* dst) {
if (count < std::numeric_limits<uint32_t>::max()) {
LaunchWithDispatchVecSize<T, uint32_t>(ctx,
vec_size,
perm_type,
dims,
perm,
src,
dst,
static_cast<uint32_t>(count));
} else {
int64_t cnt = static_cast<int64_t>(count);
LaunchWithDispatchVecSize<T, int64_t>(ctx,
vec_size,
perm_type,
dims,
perm,
src,
dst,
static_cast<int64_t>(count));
}
}
template <typename DeviceContext, typename T>
inline void SimplifyThenLaunch(const int rank,
const DeviceContext& ctx,
const phi::DenseTensor& in,
phi::DenseTensor* out,
const std::vector<int32_t>& perm) {
int sm_count = ctx.GetSMCount();
auto src_dims = phi::vectorize<size_t>(in.dims());
auto simplifier = DimsSimplifier<T>(
sm_count, rank, perm, src_dims, in.data<T>(), out->data<T>());
if (simplifier.GetPermType() == PermuteType::kCopy) {
inline void PermuteAndTranspose(const int rank,
const DeviceContext& ctx,
const phi::DenseTensor& in,
phi::DenseTensor* out,
const std::vector<int32_t>& perm) {
const int64_t numel = in.numel();
auto classifier =
TranposeTypeClassifier<T>(ctx.GetSMCount(),
rank,
numel,
perm,
phi::vectorize<int64_t>(in.dims()),
in.data<T>(),
out->data<T>());
if (classifier.GetPermType() == PermuteType::kCopy) {
// If perm is [0, 1, 2, 3], then just perform a DtoD copy.
phi::Copy(ctx, in, ctx.GetPlace(), false, out);
phi::backends::gpu::GpuMemcpyAsync(out->data<T>(),
in.data<T>(),
numel * sizeof(T),
phi::gpuMemcpyDeviceToDevice,
ctx.stream());
} else {
LaunchWithDispatchIndex<T>(ctx,
simplifier.GetCount(),
simplifier.GetVecSize(),
simplifier.GetPermType(),
simplifier.GetDims(),
simplifier.GetPerm(),
in.data<T>(),
out->data<T>());
if (numel < std::numeric_limits<int>::max()) {
LaunchWithDispatchVecSize<T, int>(ctx,
classifier.GetVecSize(),
classifier.GetPermType(),
classifier.GetSrcDims(),
classifier.GetPerm(),
in.data<T>(),
out->data<T>(),
static_cast<int>(numel));
} else {
int64_t cnt = static_cast<int64_t>(numel);
LaunchWithDispatchVecSize<T, int64_t>(ctx,
classifier.GetVecSize(),
classifier.GetPermType(),
classifier.GetSrcDims(),
classifier.GetPerm(),
in.data<T>(),
out->data<T>(),
static_cast<int64_t>(numel));
}
}
}
@@ -1196,7 +1190,7 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
if (!ret) {
auto* tuner =
phi::autotune::MakeTransposeTuner<T>(TransCompute<phi::GPUContext, T>);
tuner->AddCallBack(SimplifyThenLaunch<phi::GPUContext, T>);
tuner->AddCallBack(PermuteAndTranspose<phi::GPUContext, T>);
size_t key = phi::autotune::TransposeKey(
phi::vectorize(in.dims()),
......
@@ -71,69 +71,72 @@ enum PermuteType {
kCopy = 1,
kTranspose = 2,
kVecPermute = 3,
kNormalPermute = 4
kGeneralPermute = 4
};
constexpr int kBlockRows = 16;
constexpr int kTileSize = 32;
// To avoid bank conflict.
constexpr int kShareCol = kTileSize + 1;
// Simplify the input dims and permute dims if possible.
template <typename T>
class DimsSimplifier {
class TranposeTypeClassifier {
public:
explicit DimsSimplifier(const int sm_count,
const int rank,
const std::vector<int32_t>& perm,
const std::vector<size_t>& dims,
const T* src,
T* dst)
: perm_(rank), dims_(rank) {
TranposeTypeClassifier(const int sm_count,
const size_t rank,
const int64_t numel,
const std::vector<int32_t>& perm,
const std::vector<int64_t>& dims,
const T* src,
T* dst)
: perm_(rank), src_dims(rank) {
SimplifyPermAndDims(rank, dims, perm);
count_ = std::accumulate(
dims.begin(), dims.end(), size_t{1}, std::multiplies<size_t>());
if (rank_ > 1) {
vec_size_ = GetPermVecSize(sm_count, src, dst);
perm_.resize(rank_);
dims_.resize(rank_);
}
perm_.resize(rank_);
src_dims.resize(rank_);
dst_dims.resize(rank_);
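    // dst_dims is the simplified shape after applying perm_ to src_dims.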
for (auto i = 0; i < rank_; ++i) {
dst_dims[i] = src_dims[perm_[i]];
}
}
size_t GetCount() const { return count_; }
int GetRank() const { return rank_; }
int GetVecSize() const { return vec_size_; }
PermuteType GetPermType() const { return type_; }
std::vector<int> GetPerm() const { return perm_; }
std::vector<size_t> GetDims() const { return dims_; }
std::vector<int64_t> GetSrcDims() const { return src_dims; }
std::vector<int64_t> GetDstDims() const { return dst_dims; }
private:
size_t rank_{1};
size_t count_{0};
int rank_{1};
int vec_size_{1};
std::vector<int> perm_;
std::vector<size_t> dims_;
std::vector<int64_t> src_dims;
std::vector<int64_t> dst_dims;
PermuteType type_{kCopy};
void SimplifyPermAndDims(const size_t rank,
const std::vector<size_t>& in_dims,
const std::vector<int64_t>& in_dims,
const std::vector<int32_t>& perm) {
size_t combined_dims[phi::DDim::kMaxRank];
int64_t combined_dims[phi::DDim::kMaxRank];
int valid_map[phi::DDim::kMaxRank];
// Merge consecutive dims to the first one of these dims,
// and leave the origin dim value to be 1. Example below :
// Merge each run of consecutive dims into its first dim and
// set the merged dims to 1. Example below :
// perm: [2, 3, 0, 1], origin_dims : [4, 8, 2, 5]
// new_dims: [4, 8, 2, 5] -> [32, 1, 10, 1]
size_t start_perm_idx = 0;
int start_perm_idx = 0;
while (start_perm_idx < rank) {
const size_t start_dim_idx = perm[start_perm_idx];
const int start_dim_idx = perm[start_perm_idx];
combined_dims[start_dim_idx] = in_dims[start_dim_idx];
size_t end_perm_idx = start_perm_idx + 1;
int end_perm_idx = start_perm_idx + 1;
while (end_perm_idx < rank &&
perm[end_perm_idx] == perm[end_perm_idx - 1] + 1) {
const size_t end_dim_idx = perm[end_perm_idx];
const int end_dim_idx = perm[end_perm_idx];
combined_dims[start_dim_idx] *= in_dims[end_dim_idx];
combined_dims[end_dim_idx] = 1;
end_perm_idx += 1;
@@ -145,22 +148,22 @@ class DimsSimplifier {
// for example, if combined dims is [32, 1, 10, 1],
// valid_map is [0, -1, 1, -1] and the simplified
// dims become [32, 10]
size_t valid_dim_idx = 0;
int valid_dim_idx = 0;
bool sequential_flag = false;
for (size_t i = 0; i < rank; ++i) {
for (auto i = 0; i < rank; ++i) {
const int src_dim = combined_dims[i];
if (src_dim == 1) {
valid_map[i] = -1;
} else {
sequential_flag = true;
valid_map[i] = valid_dim_idx;
dims_[valid_dim_idx] = src_dim;
src_dims[valid_dim_idx] = src_dim;
valid_dim_idx += 1;
}
}
if (valid_dim_idx == 0) {
dims_[0] = 1;
src_dims[0] = 1;
perm_[0] = 0;
return;
} else if (valid_dim_idx == 1) {
@@ -169,8 +172,8 @@ class DimsSimplifier {
// Acquire the simplified perm with the help of combined dims
// and the original perm; the final simplified perm is [1, 0]
size_t perm_idx = 0;
for (size_t i = 0; i < rank; ++i) {
int perm_idx = 0;
for (auto i = 0; i < rank; ++i) {
const int mapped = valid_map[perm[i]];
if (mapped >= 0) {
perm_[perm_idx] = mapped;
@@ -183,20 +186,17 @@ class DimsSimplifier {
int GetPermVecSize(const int sm_count, const T* src, T* dst) {
// For the general_permute kernel, there is a good chance for
// vectorized write.
type_ = PermuteType::kGeneralPermute;
int vec_size = phi::GetVectorizedSize<T>(dst);
type_ = PermuteType::kNormalPermute;
// While the last dim is fixed, there is a good chance for
// both vectorized read and write.
if (perm_[rank_ - 1] == rank_ - 1) {
int tmp_size = std::min(vec_size, phi::GetVectorizedSize<T>(src));
tmp_size = GetDimVesSize(tmp_size, dims_[rank_ - 1]);
tmp_size = GetDimVesSize(tmp_size, src_dims[rank_ - 1]);
if (tmp_size > 1) {
type_ = kVecPermute;
vec_size = tmp_size;
// For stride calculation of src_data index.
dims_[rank_ - 1] /= vec_size;
}
}
@@ -205,31 +205,11 @@ class DimsSimplifier {
if ((rank_ == 2 && perm_[1] == 0 && perm_[0] == 1) ||
(rank_ == 3 && perm_[2] == 1 && perm_[1] == 2)) {
type_ = PermuteType::kTranspose;
// Compared with vectorized load or store, a config that lets more
// SMs work simultaneously matters more for performance.
constexpr int threads = kTileSize * kTileSize;
int blocks = count_ / threads;
if (blocks < sm_count) {
vec_size = 1;
} else {
int tmp_vec = std::min(vec_size, phi::GetVectorizedSize<T>(src));
// With bytes limitation of shared_memory, the VecSize shall be
// restricted for the type whose byte-size is less than 8 (double).
int type_vec =
sizeof(T) > 8 ? 1 : GetDimVesSize(tmp_vec, dims_[rank_ - 1]);
for (int i = type_vec; i > 0; i /= 2) {
if (blocks / i >= sm_count) {
break;
}
// When blocks is smaller than sm_count, a test showed that decreasing
// vec_size to bring blocks close to sm_count would gain performance.
vec_size = i;
}
}
dims_[rank_ - 1] /= vec_size;
count_ /= vec_size;
int tmp_vec = std::min(vec_size, phi::GetVectorizedSize<T>(src));
// With bytes limitation of shared_memory, the VecSize shall be
// restricted for the type whose byte-size is less than 8 (double).
vec_size =
sizeof(T) > 8 ? 1 : GetDimVesSize(tmp_vec, src_dims[rank_ - 1]);
}
return vec_size;
}
......
@@ -123,7 +123,7 @@ class AutoTuneBase {
float RunAndMeasureKernel(const Context& ctx, const int idx, Args&&... args) {
// Regard the 1st run as warmup and judge the comparison result by the
// time cost of the remaining cycles.
constexpr int repeats = 3;
constexpr int repeats = 4;
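    // With repeats = 4, the first launch serves as the warmup and the
    // remaining three launches contribute to the measured time.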
phi::GpuTimer timer;
float time_cost = 0;
const auto& stream = ctx.stream();
......