Unverified commit 9b70c556, authored by Yuang Liu, committed by GitHub

rename the template type name for transpose (#45834)

Parent 420d186a
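What follows is the commit's diff: a mechanical rename of the index template parameter from IDX_T to IndexType across the transpose helpers, with a few lines re-wrapped where the longer name no longer fits. As a reminder of what that parameter does, here is a minimal sketch of the flat-index helper, with Index3/Dim3 simplified to plain int arrays (an illustration only, not the actual Paddle signatures):

#include <cstdint>
#include <iostream>

// Sketch: the template parameter only controls the width of the index
// arithmetic, so the same helper serves both int and int64_t indexing.
template <typename IndexType = int>
IndexType FlatTensorIndex(const int index[3], const int dims[3]) {
  IndexType flat_index = index[0];
  for (int i = 1; i < 3; i++) {
    flat_index = flat_index * dims[i] + index[i];
  }
  return flat_index;
}

int main() {
  int index[3] = {1, 2, 3};
  int dims[3] = {4, 5, 6};
  std::cout << FlatTensorIndex(index, dims) << "\n";           // 45, 32-bit math
  std::cout << FlatTensorIndex<int64_t>(index, dims) << "\n";  // 45, 64-bit math
}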
@@ -82,10 +82,10 @@ struct Index3 : DeviceArray<int, 3, 0> {
 };

 // Flat index with real dimension
-template <typename IDX_T = int>
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IDX_T FlatTensorIndex(const Index3& index,
-                                                            const Dim3& dims) {
-  IDX_T flat_index = index[0];
+template <typename IndexType = int>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexType
+FlatTensorIndex(const Index3& index, const Dim3& dims) {
+  IndexType flat_index = index[0];
   for (int i = 1; i < 3; i++) {
     flat_index = flat_index * dims[i] + index[i];
   }
@@ -93,12 +93,12 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IDX_T FlatTensorIndex(const Index3& index,
 }

 // Convert index to tensor index with dimension.
-template <typename IDX_T = int>
+template <typename IndexType = int>
 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index3
-ConvertTensorIndex(IDX_T index, const Dim3& dims) {
+ConvertTensorIndex(IndexType index, const Dim3& dims) {
   Index3 tensor_index;
   for (int i = 2; i >= 0; i--) {
-    IDX_T new_index = index / dims[i];
+    IndexType new_index = index / dims[i];
     tensor_index[i] = static_cast<int>(index - dims[i] * new_index);
     index = new_index;
   }
...
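ConvertTensorIndex, also touched above, is the inverse of FlatTensorIndex: it splits a flat offset back into a 3-D index against the same dims, and the renamed parameter again only sets the width of the intermediate arithmetic. A simplified round-trip sketch (plain arrays instead of Index3/Dim3, for illustration only):

#include <cstdint>
#include <iostream>

// Simplified inverse of FlatTensorIndex: peel off the fastest-varying
// dimension first, as in the hunk above, writing into an int[3] stand-in.
template <typename IndexType = int>
void ConvertTensorIndex(IndexType index, const int dims[3], int tensor_index[3]) {
  for (int i = 2; i >= 0; i--) {
    IndexType new_index = index / dims[i];
    tensor_index[i] = static_cast<int>(index - dims[i] * new_index);
    index = new_index;
  }
}

int main() {
  int dims[3] = {4, 5, 6};
  int out[3];
  ConvertTensorIndex<int64_t>(45, dims, out);                     // inverse of the earlier example
  std::cout << out[0] << " " << out[1] << " " << out[2] << "\n";  // prints "1 2 3"
}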
@@ -83,7 +83,7 @@ template <typename T,
           int NumThreads,
           int TileX,
           int TileY,
-          typename IDX_T = int>
+          typename IndexType = int>
 __global__ void TilingSwapDim1And2(const T* __restrict__ input,
                                    Dim3 input_dims,
                                    T* __restrict__ output) {
@@ -119,8 +119,8 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
   };

   // Converts block idx to tile index, each block process a tile
-  Index3 input_block_tile_index =
-      framework::ConvertTensorIndex<IDX_T>(blockIdx.x, tile_aligned_input_dim);
+  Index3 input_block_tile_index = framework::ConvertTensorIndex<IndexType>(
+      blockIdx.x, tile_aligned_input_dim);

   // Compute real index align to tile:0, 32, 64...
   Index3 block_tile_index_in_input = {
@@ -130,11 +130,12 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
   };

   // Compute block flat index against input dims.
-  IDX_T input_origin_block_flat_index =
-      framework::FlatTensorIndex<IDX_T>(block_tile_index_in_input, input_dims);
+  IndexType input_origin_block_flat_index =
+      framework::FlatTensorIndex<IndexType>(block_tile_index_in_input,
+                                            input_dims);

   bool full_tile = true;
-  IDX_T tile_width = TileY;
+  IndexType tile_width = TileY;

   // Last row is not full.
   if (input_block_tile_index[2] == tile_aligned_input_dim[2] - 1) {
@@ -142,21 +143,22 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
     full_tile &= false;
   }

-  IDX_T tile_height = TileX;
+  IndexType tile_height = TileX;

   if (input_block_tile_index[1] == tile_aligned_input_dim[1] - 1) {
     tile_height = input_dims[1] - (tile_aligned_input_dim[1] - 1) * TileX;
     full_tile &= false;
   }

-  constexpr IDX_T in_effective_thread_num = NumThreads / TileY * TileY;
+  constexpr IndexType in_effective_thread_num = NumThreads / TileY * TileY;

   if (x < in_effective_thread_num) {
     // Read a tile from input using block.
     int x_i = x / TileY;
     int x_j = x % TileY;
-    IDX_T input_ind = input_origin_block_flat_index + x_i * input_dims[2] + x_j;
-    IDX_T input_inc = BlockReadRows * input_dims[2];
+    IndexType input_ind =
+        input_origin_block_flat_index + x_i * input_dims[2] + x_j;
+    IndexType input_inc = BlockReadRows * input_dims[2];

     if (full_tile) {
 #pragma unroll
@@ -167,7 +169,8 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
     } else {
       if (x_j < tile_width) {
 #pragma unroll
-        for (IDX_T ind_i = x_i; ind_i < (tile_height); ind_i += BlockReadRows) {
+        for (IndexType ind_i = x_i; ind_i < (tile_height);
+             ind_i += BlockReadRows) {
           tile_sm[ind_i][x_j] = input[input_ind];
           input_ind += input_inc;
         }
@@ -190,17 +193,18 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
       output_block_tile_index[2] * TileX,
   };

-  IDX_T output_origin_block_flat_index = framework::FlatTensorIndex<IDX_T>(
-      block_tile_index_in_output, output_dims);
+  IndexType output_origin_block_flat_index =
+      framework::FlatTensorIndex<IndexType>(block_tile_index_in_output,
+                                            output_dims);

-  constexpr IDX_T out_effective_thread_num = NumThreads / TileX * TileX;
+  constexpr IndexType out_effective_thread_num = NumThreads / TileX * TileX;

   if (x < out_effective_thread_num) {
     int x_i = x / TileX;
     int x_j = x % TileX;
-    IDX_T output_ind =
+    IndexType output_ind =
         output_origin_block_flat_index + x_i * output_dims[2] + x_j;
-    IDX_T output_inc = BlockWriteRows * output_dims[2];
+    IndexType output_inc = BlockWriteRows * output_dims[2];

     if (full_tile) {
 #pragma unroll
@@ -211,7 +215,8 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
     } else {
       if (x_j < tile_height) {
 #pragma unroll
-        for (IDX_T ind_i = x_i; ind_i < (tile_width); ind_i += BlockWriteRows) {
+        for (IndexType ind_i = x_i; ind_i < (tile_width);
+             ind_i += BlockWriteRows) {
           output[output_ind] = tile_sm[x_j][ind_i];
           output_ind += output_inc;
         }
@@ -276,21 +281,21 @@ struct SystemElemType<16> {
   using type = float4;
 };

-template <typename T, int tile_long, int tile_short, typename IDX_T = int>
+template <typename T, int tile_long, int tile_short, typename IndexType = int>
 void LaunchNarrowDims2TransposeKernel(const phi::GPUContext& d,
                                       int tile_size_i,
                                       int tile_size_j,
-                                      IDX_T total_tiles_count,
+                                      IndexType total_tiles_count,
                                       const T* input,
                                       const Dim3& input_dims,
                                       T* output) {
   constexpr int NumThreads = tile_long;
   if (tile_size_i <= tile_long && tile_size_j <= tile_short) {
-    TilingSwapDim1And2<T, NumThreads, tile_long, tile_short, IDX_T>
+    TilingSwapDim1And2<T, NumThreads, tile_long, tile_short, IndexType>
         <<<total_tiles_count, NumThreads, 0, d.stream()>>>(
             input, input_dims, output);
   } else {
-    TilingSwapDim1And2<T, NumThreads, tile_short, tile_long, IDX_T>
+    TilingSwapDim1And2<T, NumThreads, tile_short, tile_long, IndexType>
         <<<total_tiles_count, NumThreads, 0, d.stream()>>>(
             input, input_dims, output);
   }
@@ -299,13 +304,13 @@ void LaunchNarrowDims2TransposeKernel(const phi::GPUContext& d,
 template <typename T,
           int tile_long,
           int tile_short,
-          typename IDX_T = int,
+          typename IndexType = int,
           typename dummy = void>
 struct NarrowDims2TransposeDispatch {
   static void DoTranspose(const phi::GPUContext& d,
                           int tile_size_i,
                           int tile_size_j,
-                          IDX_T total_tiles_count,
+                          IndexType total_tiles_count,
                           const T* input,
                           const Dim3& input_dims,
                           T* output) {
@@ -321,7 +326,7 @@ struct NarrowDims2TransposeDispatch {
         std::min(tile_size_i, tile_size_j) <= tile_short;

     if (request_satisfied) {
-      LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IDX_T>(
+      LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IndexType>(
           d,
           tile_size_i,
           tile_size_j,
@@ -336,7 +341,7 @@ struct NarrowDims2TransposeDispatch {
         std::max(tile_size_i, tile_size_j) > tile_long;

     if (long_side_request_not_satisfied) {
-      NarrowDims2TransposeDispatch<T, tile_long * 2, tile_short, IDX_T>::
+      NarrowDims2TransposeDispatch<T, tile_long * 2, tile_short, IndexType>::
           DoTranspose(d,
                       tile_size_i,
                       tile_size_j,
@@ -345,7 +350,7 @@ struct NarrowDims2TransposeDispatch {
                       input_dims,
                       output);
     } else {
-      NarrowDims2TransposeDispatch<T, tile_long, tile_short + 1, IDX_T>::
+      NarrowDims2TransposeDispatch<T, tile_long, tile_short + 1, IndexType>::
           DoTranspose(d,
                       tile_size_i,
                       tile_size_j,
@@ -358,19 +363,19 @@ struct NarrowDims2TransposeDispatch {
 };

 // If Not long tile size, goto this function when compile.
-template <typename T, int tile_long, int tile_short, typename IDX_T>
+template <typename T, int tile_long, int tile_short, typename IndexType>
 struct NarrowDims2TransposeDispatch<
     T,
     tile_long,
     tile_short,
-    IDX_T,
+    IndexType,
     typename std::enable_if<CheckNonLongTileSize(
                                 tile_long, tile_short, sizeof(T)),
                             void>::type> {
   static void DoTranspose(const phi::GPUContext& d,
                           int tile_size_i,
                           int tile_size_j,
-                          IDX_T total_tiles_count,
+                          IndexType total_tiles_count,
                           const T* input,
                           const Dim3& input_dims,
                           T* output) {
@@ -386,7 +391,7 @@ struct NarrowDims2TransposeDispatch<
         std::min(tile_size_i, tile_size_j) <= tile_short;

     if (request_satisfied) {
-      LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IDX_T>(
+      LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IndexType>(
           d,
           tile_size_i,
           tile_size_j,
@@ -397,7 +402,7 @@ struct NarrowDims2TransposeDispatch<
       return;
     }

-    NarrowDims2TransposeDispatch<T, tile_long, tile_short + 1, IDX_T>::
+    NarrowDims2TransposeDispatch<T, tile_long, tile_short + 1, IndexType>::
         DoTranspose(d,
                     tile_size_i,
                     tile_size_j,
@@ -409,18 +414,18 @@ struct NarrowDims2TransposeDispatch<
 };

 // If long tile size, goto this function when compile.
-template <typename T, int tile_long, int tile_short, typename IDX_T>
+template <typename T, int tile_long, int tile_short, typename IndexType>
 struct NarrowDims2TransposeDispatch<
     T,
     tile_long,
     tile_short,
-    IDX_T,
+    IndexType,
     typename std::enable_if<CheckLongTileSize(tile_long, tile_short, sizeof(T)),
                             void>::type> {
   static void DoTranspose(const phi::GPUContext& d,
                           int tile_size_i,
                           int tile_size_j,
-                          IDX_T total_tiles_count,
+                          IndexType total_tiles_count,
                           const T* input,
                           const Dim3& input_dims,
                           T* output) {
@@ -432,7 +437,7 @@ struct NarrowDims2TransposeDispatch<
             " but received is:%d.",
             tile_long));

-    LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IDX_T>(
+    LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IndexType>(
         d,
         tile_size_i,
         tile_size_j,
@@ -443,7 +448,7 @@ struct NarrowDims2TransposeDispatch<
   }
 };

-template <typename T, bool conjugate = false, typename IDX_T = int>
+template <typename T, bool conjugate = false, typename IndexType = int>
 void SwapDim1And2InNarrow(const phi::GPUContext& d,
                           const T* input,
                           const Dim3& input_dims,
@@ -514,14 +519,14 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d,
       framework::CeilOrFloor<int, true>(input_dims[2], select_tile_size_j),
   };

-  IDX_T total_tiles_count = input_dims_aligned[0];
+  IndexType total_tiles_count = input_dims_aligned[0];
   total_tiles_count *= input_dims_aligned[1];
   total_tiles_count *= input_dims_aligned[2];

   // Suppose T can be replaced by system builtin types
   using ElemType = typename SystemElemType<sizeof(T)>::type;

-  NarrowDims2TransposeDispatch<ElemType, 32, 2, IDX_T>::DoTranspose(
+  NarrowDims2TransposeDispatch<ElemType, 32, 2, IndexType>::DoTranspose(
       d,
       select_tile_size_i,
       select_tile_size_j,
@@ -533,8 +538,8 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d,

 // This is for case that cannot do coalescing read and write.
 // Or input is too small to split into tiles.
-template <typename T, int pos0, int pos1, int pos2, typename IDX_T = int>
-__global__ void TransposeSimpleKernel(IDX_T nthreads,
+template <typename T, int pos0, int pos1, int pos2, typename IndexType = int>
+__global__ void TransposeSimpleKernel(IndexType nthreads,
                                       const T* __restrict__ input,
                                       Dim3 input_dims,
                                       T* __restrict__ output) {
@@ -543,24 +548,24 @@ __global__ void TransposeSimpleKernel(IDX_T nthreads,
   output_dims[pos1] = input_dims[1];
   output_dims[pos2] = input_dims[2];

-  CUDA_KERNEL_LOOP_TYPE(output_index, nthreads, IDX_T) {
+  CUDA_KERNEL_LOOP_TYPE(output_index, nthreads, IndexType) {
     Index3 output_tensor_index =
-        framework::ConvertTensorIndex<IDX_T>(output_index, output_dims);
+        framework::ConvertTensorIndex<IndexType>(output_index, output_dims);
     Index3 input_tensor_index;
     input_tensor_index[0] = output_tensor_index[pos0];
     input_tensor_index[1] = output_tensor_index[pos1];
     input_tensor_index[2] = output_tensor_index[pos2];

-    IDX_T input_index =
-        framework::FlatTensorIndex<IDX_T>(input_tensor_index, input_dims);
+    IndexType input_index =
+        framework::FlatTensorIndex<IndexType>(input_tensor_index, input_dims);

     output[output_index] = input[input_index];
   }
 }

 // Here suppose convert all tensor to dim3, so just change dim1 and 2.
-template <typename T, typename IDX_T = int>
+template <typename T, typename IndexType = int>
 void SendSwapDim1And2InTranspose(const phi::GPUContext& d,
                                  const T* input,
                                  const Dim3& input_dims,
@@ -585,11 +590,11 @@ void SendSwapDim1And2InTranspose(const phi::GPUContext& d,
         framework::CeilOrFloor<int, true>(input_dims[2], kTileSize),
     };

-    IDX_T total_tiles_count = input_dims_aligned[0];
+    IndexType total_tiles_count = input_dims_aligned[0];
     total_tiles_count *= input_dims_aligned[1];
     total_tiles_count *= input_dims_aligned[2];

-    TilingSwapDim1And2<T, kNumThreads, kTileSize, kTileSize, IDX_T>
+    TilingSwapDim1And2<T, kNumThreads, kTileSize, kTileSize, IndexType>
         <<<total_tiles_count, kNumThreads, 0, d.stream()>>>(
             input, input_dims, output);
@@ -597,21 +602,21 @@ void SendSwapDim1And2InTranspose(const phi::GPUContext& d,
     // If input shape is like Rect, such as 2X100, use Narrow tile size.
     // It makes things complicated, because need to find a tile can coverr
     // input and also reach best coalescing.
-    SwapDim1And2InNarrow<T, false, IDX_T>(
+    SwapDim1And2InNarrow<T, false, IndexType>(
         d, input, input_dims, output, kMinTileSize);
   } else {
     // If input shape is small, such as 8X8, just do simple copy
-    IDX_T total_elements = input_dims[0];
+    IndexType total_elements = input_dims[0];
     total_elements *= input_dims[1];
     total_elements *= input_dims[2];

     auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, total_elements);
-    TransposeSimpleKernel<T, 0, 2, 1, IDX_T>
+    TransposeSimpleKernel<T, 0, 2, 1, IndexType>
         <<<config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>(
             total_elements, input, input_dims, output);
   }
 }

-template <typename T, typename IDX_T = int>
+template <typename T, typename IndexType = int>
 struct SwapDim1And2InTranspose {
   typedef phi::GPUContext Device;
   void operator()(const Device& d,
@@ -621,11 +626,11 @@ struct SwapDim1And2InTranspose {
     Dim3 input_dims = {static_cast<int>(combined_dims[0]),
                        static_cast<int>(combined_dims[1]),
                        static_cast<int>(combined_dims[2])};
-    SendSwapDim1And2InTranspose<T, IDX_T>(d, in, input_dims, out);
+    SendSwapDim1And2InTranspose<T, IndexType>(d, in, input_dims, out);
   }
 };

-template <typename T, typename IDX_T = int>
+template <typename T, typename IndexType = int>
 struct SwapDim0And2InTranspose {
   typedef phi::GPUContext Device;
   void operator()(const Device& d,
@@ -636,12 +641,12 @@ struct SwapDim0And2InTranspose {
                        static_cast<int>(combined_dims[1]),
                        static_cast<int>(combined_dims[2])};

-    IDX_T total_size = combined_dims[0];
+    IndexType total_size = combined_dims[0];
     total_size *= combined_dims[1];
     total_size *= combined_dims[2];

     auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, total_size);
-    TransposeSimpleKernel<T, 2, 1, 0, IDX_T>
+    TransposeSimpleKernel<T, 2, 1, 0, IndexType>
         <<<config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>(
             total_size, in, input_dims, out);
   }
@@ -705,7 +710,7 @@ inline void CombineTransposeDim3(const framework::DDim& shape,
   *new_dims = phi::make_ddim(dim_vec);
 }

-template <typename T, typename IDX_T = int>
+template <typename T, typename IndexType = int>
 struct TransposeSimple {
   static bool run(const phi::GPUContext& ctx,
                   const Tensor& in,
@@ -728,7 +733,7 @@ struct TransposeSimple {
       if (new_perm[0] == 1 && new_perm[1] == 0) {
         // Add the first dimension size as 1.
         new_dim_vec.insert(new_dim_vec.begin(), 1);
-        SwapDim1And2InTranspose<T, IDX_T>()(
+        SwapDim1And2InTranspose<T, IndexType>()(
             ctx, in_data, new_dim_vec, out_data);
         return true;
       }
@@ -736,7 +741,7 @@ struct TransposeSimple {
     case 3:
       // In this case, suppose we can do coalescing read and write in tile.
       if (new_perm == std::vector<int>({0, 2, 1})) {
-        SwapDim1And2InTranspose<T, IDX_T>()(
+        SwapDim1And2InTranspose<T, IndexType>()(
            ctx, in_data, new_dim_vec, out_data);
        return true;
      } else if (new_perm == std::vector<int>({2, 1, 0})) {
@@ -744,7 +749,7 @@ struct TransposeSimple {
        // But I think it depends on the data size. If span is not large,
        // maybe
        // can do coalescing.
-        SwapDim0And2InTranspose<T, IDX_T>()(
+        SwapDim0And2InTranspose<T, IndexType>()(
            ctx, in_data, new_dim_vec, out_data);
        return true;
      } else {
@@ -1183,7 +1188,7 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
   const int rank = perm.size();
   int64_t numel = in.numel();
   bool ret{false};
-  if (numel >= INT32_MAX) {
+  if (numel >= std::numeric_limits<int32_t>::max()) {
     ret = TransposeSimple<T, int64_t>::run(ctx, in, perm, out);
   } else {
     ret = TransposeSimple<T>::run(ctx, in, perm, out);
...
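The last hunk is where the renamed parameter matters at runtime: TransposeGPUKernelDriver instantiates the whole transpose stack with int64_t indices only when the element count no longer fits in int32, keeping the cheaper 32-bit arithmetic for the common case. A minimal sketch of that dispatch pattern (RunTranspose here is a hypothetical stand-in for TransposeSimple<T, IndexType>::run):

#include <cstdint>
#include <iostream>
#include <limits>

// Hypothetical stand-in for TransposeSimple<T, IndexType>::run; only the
// chosen index width matters for this sketch.
template <typename T, typename IndexType = int>
bool RunTranspose(int64_t numel) {
  std::cout << "transpose of " << numel << " elements with "
            << sizeof(IndexType) * 8 << "-bit indices\n";
  return true;
}

// Mirrors the dispatch in the hunk above: fall back to 64-bit indexing only
// for tensors with at least INT32_MAX elements.
template <typename T>
bool TransposeDriver(int64_t numel) {
  if (numel >= std::numeric_limits<int32_t>::max()) {
    return RunTranspose<T, int64_t>(numel);
  }
  return RunTranspose<T>(numel);
}

int main() {
  TransposeDriver<float>(int64_t{1} << 20);  // 32-bit index path
  TransposeDriver<float>(int64_t{1} << 32);  // 64-bit index path
}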