未验证 提交 d9a9e638 编写于 作者: Y Yuang Liu 提交者: GitHub

[alphafold] Transpose support large tensors where there numel is bigger than INT32_MAX (#45753)

上级 0ddcf30c
......@@ -82,9 +82,10 @@ struct Index3 : DeviceArray<int, 3, 0> {
};
// Flat index with real dimension
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int FlatTensorIndex(const Index3& index,
const Dim3& dims) {
int flat_index = index[0];
template <typename IDX_T = int>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IDX_T FlatTensorIndex(const Index3& index,
const Dim3& dims) {
IDX_T flat_index = index[0];
for (int i = 1; i < 3; i++) {
flat_index = flat_index * dims[i] + index[i];
}
......@@ -92,12 +93,13 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int FlatTensorIndex(const Index3& index,
}
// Convert index to tensor index with dimension.
template <typename IDX_T = int>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index3
ConvertTensorIndex(int index, const Dim3& dims) {
ConvertTensorIndex(IDX_T index, const Dim3& dims) {
Index3 tensor_index;
for (int i = 2; i >= 0; i--) {
int new_index = index / dims[i];
tensor_index[i] = index - dims[i] * new_index;
IDX_T new_index = index / dims[i];
tensor_index[i] = static_cast<int>(index - dims[i] * new_index);
index = new_index;
}
return tensor_index;
......
......@@ -79,7 +79,11 @@ constexpr bool CheckNonLongTileSize(int tile_long, int tile_short, int size_T) {
// Use SM to do data transfer, load a tile into SM then store out.
// All tile read and write are colascing, so can speedup memory copy
template <typename T, int NumThreads, int TileX, int TileY>
template <typename T,
int NumThreads,
int TileX,
int TileY,
typename IDX_T = int>
__global__ void TilingSwapDim1And2(const T* __restrict__ input,
Dim3 input_dims,
T* __restrict__ output) {
......@@ -116,7 +120,7 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
// Converts block idx to tile index, each block process a tile
Index3 input_block_tile_index =
ConvertTensorIndex(blockIdx.x, tile_aligned_input_dim);
framework::ConvertTensorIndex<IDX_T>(blockIdx.x, tile_aligned_input_dim);
// Compute real index align to tile:0, 32, 64...
Index3 block_tile_index_in_input = {
......@@ -126,11 +130,11 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
};
// Compute block flat index against input dims.
int input_origin_block_flat_index =
FlatTensorIndex(block_tile_index_in_input, input_dims);
IDX_T input_origin_block_flat_index =
framework::FlatTensorIndex<IDX_T>(block_tile_index_in_input, input_dims);
bool full_tile = true;
int tile_width = TileY;
IDX_T tile_width = TileY;
// Last row is not full.
if (input_block_tile_index[2] == tile_aligned_input_dim[2] - 1) {
......@@ -138,21 +142,21 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
full_tile &= false;
}
int tile_height = TileX;
IDX_T tile_height = TileX;
if (input_block_tile_index[1] == tile_aligned_input_dim[1] - 1) {
tile_height = input_dims[1] - (tile_aligned_input_dim[1] - 1) * TileX;
full_tile &= false;
}
constexpr int in_effective_thread_num = NumThreads / TileY * TileY;
constexpr IDX_T in_effective_thread_num = NumThreads / TileY * TileY;
if (x < in_effective_thread_num) {
// Read a tile from input using block.
int x_i = x / TileY;
int x_j = x % TileY;
int input_ind = input_origin_block_flat_index + x_i * input_dims[2] + x_j;
int input_inc = BlockReadRows * input_dims[2];
IDX_T input_ind = input_origin_block_flat_index + x_i * input_dims[2] + x_j;
IDX_T input_inc = BlockReadRows * input_dims[2];
if (full_tile) {
#pragma unroll
......@@ -163,7 +167,7 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
} else {
if (x_j < tile_width) {
#pragma unroll
for (int ind_i = x_i; ind_i < (tile_height); ind_i += BlockReadRows) {
for (IDX_T ind_i = x_i; ind_i < (tile_height); ind_i += BlockReadRows) {
tile_sm[ind_i][x_j] = input[input_ind];
input_ind += input_inc;
}
......@@ -186,17 +190,17 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
output_block_tile_index[2] * TileX,
};
int output_origin_block_flat_index =
FlatTensorIndex(block_tile_index_in_output, output_dims);
IDX_T output_origin_block_flat_index = framework::FlatTensorIndex<IDX_T>(
block_tile_index_in_output, output_dims);
constexpr int out_effective_thread_num = NumThreads / TileX * TileX;
constexpr IDX_T out_effective_thread_num = NumThreads / TileX * TileX;
if (x < out_effective_thread_num) {
int x_i = x / TileX;
int x_j = x % TileX;
int output_ind =
IDX_T output_ind =
output_origin_block_flat_index + x_i * output_dims[2] + x_j;
int output_inc = BlockWriteRows * output_dims[2];
IDX_T output_inc = BlockWriteRows * output_dims[2];
if (full_tile) {
#pragma unroll
......@@ -207,7 +211,7 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
} else {
if (x_j < tile_height) {
#pragma unroll
for (int ind_i = x_i; ind_i < (tile_width); ind_i += BlockWriteRows) {
for (IDX_T ind_i = x_i; ind_i < (tile_width); ind_i += BlockWriteRows) {
output[output_ind] = tile_sm[x_j][ind_i];
output_ind += output_inc;
}
......@@ -272,32 +276,36 @@ struct SystemElemType<16> {
using type = float4;
};
template <typename T, int tile_long, int tile_short>
template <typename T, int tile_long, int tile_short, typename IDX_T = int>
void LaunchNarrowDims2TransposeKernel(const phi::GPUContext& d,
int tile_size_i,
int tile_size_j,
int total_tiles_count,
IDX_T total_tiles_count,
const T* input,
const Dim3& input_dims,
T* output) {
constexpr int NumThreads = tile_long;
if (tile_size_i <= tile_long && tile_size_j <= tile_short) {
TilingSwapDim1And2<T, NumThreads, tile_long, tile_short>
TilingSwapDim1And2<T, NumThreads, tile_long, tile_short, IDX_T>
<<<total_tiles_count, NumThreads, 0, d.stream()>>>(
input, input_dims, output);
} else {
TilingSwapDim1And2<T, NumThreads, tile_short, tile_long>
TilingSwapDim1And2<T, NumThreads, tile_short, tile_long, IDX_T>
<<<total_tiles_count, NumThreads, 0, d.stream()>>>(
input, input_dims, output);
}
}
template <typename T, int tile_long, int tile_short, typename dummy = void>
template <typename T,
int tile_long,
int tile_short,
typename IDX_T = int,
typename dummy = void>
struct NarrowDims2TransposeDispatch {
static void DoTranspose(const phi::GPUContext& d,
int tile_size_i,
int tile_size_j,
int total_tiles_count,
IDX_T total_tiles_count,
const T* input,
const Dim3& input_dims,
T* output) {
......@@ -313,7 +321,7 @@ struct NarrowDims2TransposeDispatch {
std::min(tile_size_i, tile_size_j) <= tile_short;
if (request_satisfied) {
LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short>(
LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IDX_T>(
d,
tile_size_i,
tile_size_j,
......@@ -328,40 +336,41 @@ struct NarrowDims2TransposeDispatch {
std::max(tile_size_i, tile_size_j) > tile_long;
if (long_side_request_not_satisfied) {
NarrowDims2TransposeDispatch<T, tile_long * 2, tile_short>::DoTranspose(
d,
tile_size_i,
tile_size_j,
total_tiles_count,
input,
input_dims,
output);
NarrowDims2TransposeDispatch<T, tile_long * 2, tile_short, IDX_T>::
DoTranspose(d,
tile_size_i,
tile_size_j,
total_tiles_count,
input,
input_dims,
output);
} else {
NarrowDims2TransposeDispatch<T, tile_long, tile_short + 1>::DoTranspose(
d,
tile_size_i,
tile_size_j,
total_tiles_count,
input,
input_dims,
output);
NarrowDims2TransposeDispatch<T, tile_long, tile_short + 1, IDX_T>::
DoTranspose(d,
tile_size_i,
tile_size_j,
total_tiles_count,
input,
input_dims,
output);
}
}
};
// If Not long tile size, goto this function when compile.
template <typename T, int tile_long, int tile_short>
template <typename T, int tile_long, int tile_short, typename IDX_T>
struct NarrowDims2TransposeDispatch<
T,
tile_long,
tile_short,
IDX_T,
typename std::enable_if<CheckNonLongTileSize(
tile_long, tile_short, sizeof(T)),
void>::type> {
static void DoTranspose(const phi::GPUContext& d,
int tile_size_i,
int tile_size_j,
int total_tiles_count,
IDX_T total_tiles_count,
const T* input,
const Dim3& input_dims,
T* output) {
......@@ -377,7 +386,7 @@ struct NarrowDims2TransposeDispatch<
std::min(tile_size_i, tile_size_j) <= tile_short;
if (request_satisfied) {
LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short>(
LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IDX_T>(
d,
tile_size_i,
tile_size_j,
......@@ -388,29 +397,30 @@ struct NarrowDims2TransposeDispatch<
return;
}
NarrowDims2TransposeDispatch<T, tile_long, tile_short + 1>::DoTranspose(
d,
tile_size_i,
tile_size_j,
total_tiles_count,
input,
input_dims,
output);
NarrowDims2TransposeDispatch<T, tile_long, tile_short + 1, IDX_T>::
DoTranspose(d,
tile_size_i,
tile_size_j,
total_tiles_count,
input,
input_dims,
output);
}
};
// If long tile size, goto this function when compile.
template <typename T, int tile_long, int tile_short>
template <typename T, int tile_long, int tile_short, typename IDX_T>
struct NarrowDims2TransposeDispatch<
T,
tile_long,
tile_short,
IDX_T,
typename std::enable_if<CheckLongTileSize(tile_long, tile_short, sizeof(T)),
void>::type> {
static void DoTranspose(const phi::GPUContext& d,
int tile_size_i,
int tile_size_j,
int total_tiles_count,
IDX_T total_tiles_count,
const T* input,
const Dim3& input_dims,
T* output) {
......@@ -422,7 +432,7 @@ struct NarrowDims2TransposeDispatch<
" but received is:%d.",
tile_long));
LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short>(
LaunchNarrowDims2TransposeKernel<T, tile_long, tile_short, IDX_T>(
d,
tile_size_i,
tile_size_j,
......@@ -433,7 +443,7 @@ struct NarrowDims2TransposeDispatch<
}
};
template <typename T, bool conjugate = false>
template <typename T, bool conjugate = false, typename IDX_T = int>
void SwapDim1And2InNarrow(const phi::GPUContext& d,
const T* input,
const Dim3& input_dims,
......@@ -504,13 +514,14 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d,
framework::CeilOrFloor<int, true>(input_dims[2], select_tile_size_j),
};
int total_tiles_count =
input_dims_aligned[0] * input_dims_aligned[1] * input_dims_aligned[2];
IDX_T total_tiles_count = input_dims_aligned[0];
total_tiles_count *= input_dims_aligned[1];
total_tiles_count *= input_dims_aligned[2];
// Suppose T can be replaced by system builtin types
using ElemType = typename SystemElemType<sizeof(T)>::type;
NarrowDims2TransposeDispatch<ElemType, 32, 2>::DoTranspose(
NarrowDims2TransposeDispatch<ElemType, 32, 2, IDX_T>::DoTranspose(
d,
select_tile_size_i,
select_tile_size_j,
......@@ -522,8 +533,8 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d,
// This is for case that cannot do coalescing read and write.
// Or input is too small to split into tiles.
template <typename T, int pos0, int pos1, int pos2>
__global__ void TransposeSimpleKernel(int nthreads,
template <typename T, int pos0, int pos1, int pos2, typename IDX_T = int>
__global__ void TransposeSimpleKernel(IDX_T nthreads,
const T* __restrict__ input,
Dim3 input_dims,
T* __restrict__ output) {
......@@ -532,22 +543,24 @@ __global__ void TransposeSimpleKernel(int nthreads,
output_dims[pos1] = input_dims[1];
output_dims[pos2] = input_dims[2];
CUDA_KERNEL_LOOP(output_index, nthreads) {
Index3 output_tensor_index = ConvertTensorIndex(output_index, output_dims);
CUDA_KERNEL_LOOP_TYPE(output_index, nthreads, IDX_T) {
Index3 output_tensor_index =
framework::ConvertTensorIndex<IDX_T>(output_index, output_dims);
Index3 input_tensor_index;
input_tensor_index[0] = output_tensor_index[pos0];
input_tensor_index[1] = output_tensor_index[pos1];
input_tensor_index[2] = output_tensor_index[pos2];
int input_index = FlatTensorIndex(input_tensor_index, input_dims);
IDX_T input_index =
framework::FlatTensorIndex<IDX_T>(input_tensor_index, input_dims);
output[output_index] = input[input_index];
}
}
// Here suppose convert all tensor to dim3, so just change dim1 and 2.
template <typename T>
template <typename T, typename IDX_T = int>
void SendSwapDim1And2InTranspose(const phi::GPUContext& d,
const T* input,
const Dim3& input_dims,
......@@ -572,10 +585,11 @@ void SendSwapDim1And2InTranspose(const phi::GPUContext& d,
framework::CeilOrFloor<int, true>(input_dims[2], kTileSize),
};
int total_tiles_count =
input_dims_aligned[0] * input_dims_aligned[1] * input_dims_aligned[2];
IDX_T total_tiles_count = input_dims_aligned[0];
total_tiles_count *= input_dims_aligned[1];
total_tiles_count *= input_dims_aligned[2];
TilingSwapDim1And2<T, kNumThreads, kTileSize, kTileSize>
TilingSwapDim1And2<T, kNumThreads, kTileSize, kTileSize, IDX_T>
<<<total_tiles_count, kNumThreads, 0, d.stream()>>>(
input, input_dims, output);
......@@ -583,18 +597,21 @@ void SendSwapDim1And2InTranspose(const phi::GPUContext& d,
// If input shape is like Rect, such as 2X100, use Narrow tile size.
// It makes things complicated, because need to find a tile can coverr
// input and also reach best coalescing.
SwapDim1And2InNarrow<T>(d, input, input_dims, output, kMinTileSize);
SwapDim1And2InNarrow<T, false, IDX_T>(
d, input, input_dims, output, kMinTileSize);
} else {
// If input shape is small, such as 8X8, just do simple copy
int total_elements = input_dims[0] * input_dims[1] * input_dims[2];
IDX_T total_elements = input_dims[0];
total_elements *= input_dims[1];
total_elements *= input_dims[2];
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, total_elements);
TransposeSimpleKernel<T, 0, 2, 1>
TransposeSimpleKernel<T, 0, 2, 1, IDX_T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>(
total_elements, input, input_dims, output);
}
}
template <typename T>
template <typename T, typename IDX_T = int>
struct SwapDim1And2InTranspose {
typedef phi::GPUContext Device;
void operator()(const Device& d,
......@@ -604,11 +621,11 @@ struct SwapDim1And2InTranspose {
Dim3 input_dims = {static_cast<int>(combined_dims[0]),
static_cast<int>(combined_dims[1]),
static_cast<int>(combined_dims[2])};
SendSwapDim1And2InTranspose<T>(d, in, input_dims, out);
SendSwapDim1And2InTranspose<T, IDX_T>(d, in, input_dims, out);
}
};
template <typename T>
template <typename T, typename IDX_T = int>
struct SwapDim0And2InTranspose {
typedef phi::GPUContext Device;
void operator()(const Device& d,
......@@ -619,10 +636,12 @@ struct SwapDim0And2InTranspose {
static_cast<int>(combined_dims[1]),
static_cast<int>(combined_dims[2])};
size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
IDX_T total_size = combined_dims[0];
total_size *= combined_dims[1];
total_size *= combined_dims[2];
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, total_size);
TransposeSimpleKernel<T, 2, 1, 0>
TransposeSimpleKernel<T, 2, 1, 0, IDX_T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>(
total_size, in, input_dims, out);
}
......@@ -652,7 +671,7 @@ inline void CombineTransposeDim3(const framework::DDim& shape,
return;
}
std::vector<int> new_dim_pos(shape.size(), -1);
std::vector<int> combined_dims(shape.size(), 0);
std::vector<int64_t> combined_dims(shape.size(), 0);
int cur_head = perm[0];
new_dim_pos[cur_head] = 0;
combined_dims[0] = shape[cur_head];
......@@ -686,7 +705,7 @@ inline void CombineTransposeDim3(const framework::DDim& shape,
*new_dims = phi::make_ddim(dim_vec);
}
template <typename T>
template <typename T, typename IDX_T = int>
struct TransposeSimple {
static bool run(const phi::GPUContext& ctx,
const Tensor& in,
......@@ -709,21 +728,24 @@ struct TransposeSimple {
if (new_perm[0] == 1 && new_perm[1] == 0) {
// Add the first dimension size as 1.
new_dim_vec.insert(new_dim_vec.begin(), 1);
SwapDim1And2InTranspose<T>()(ctx, in_data, new_dim_vec, out_data);
SwapDim1And2InTranspose<T, IDX_T>()(
ctx, in_data, new_dim_vec, out_data);
return true;
}
break;
case 3:
// In this case, suppose we can do coalescing read and write in tile.
if (new_perm == std::vector<int>({0, 2, 1})) {
SwapDim1And2InTranspose<T>()(ctx, in_data, new_dim_vec, out_data);
SwapDim1And2InTranspose<T, IDX_T>()(
ctx, in_data, new_dim_vec, out_data);
return true;
} else if (new_perm == std::vector<int>({2, 1, 0})) {
// Maybe can optimize later, find a way to do coalescing memory copy.
// But I think it depends on the data size. If span is not large,
// maybe
// can do coalescing.
SwapDim0And2InTranspose<T>()(ctx, in_data, new_dim_vec, out_data);
SwapDim0And2InTranspose<T, IDX_T>()(
ctx, in_data, new_dim_vec, out_data);
return true;
} else {
return false;
......@@ -1159,7 +1181,13 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
const std::vector<int32_t>& perm,
Tensor* out) {
const int rank = perm.size();
auto ret = TransposeSimple<T>::run(ctx, in, perm, out);
int64_t numel = in.numel();
bool ret{false};
if (numel >= INT32_MAX) {
ret = TransposeSimple<T, int64_t>::run(ctx, in, perm, out);
} else {
ret = TransposeSimple<T>::run(ctx, in, perm, out);
}
if (!ret) {
auto* tuner =
phi::autotune::MakeTransposeTuner<T>(TransCompute<phi::GPUContext, T>);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册