diff --git a/paddle/fluid/framework/gpu_utils.h b/paddle/fluid/framework/gpu_utils.h index 7d937b1ded74233d0c26b336abb2f7f3c91e2756..4eaf5fb7d79ca0bc0e05513180f020abc32b9669 100644 --- a/paddle/fluid/framework/gpu_utils.h +++ b/paddle/fluid/framework/gpu_utils.h @@ -82,9 +82,10 @@ struct Index3 : DeviceArray { }; // Flat index with real dimension -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int FlatTensorIndex(const Index3& index, - const Dim3& dims) { - int flat_index = index[0]; +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IDX_T FlatTensorIndex(const Index3& index, + const Dim3& dims) { + IDX_T flat_index = index[0]; for (int i = 1; i < 3; i++) { flat_index = flat_index * dims[i] + index[i]; } @@ -92,12 +93,13 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int FlatTensorIndex(const Index3& index, } // Convert index to tensor index with dimension. +template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index3 -ConvertTensorIndex(int index, const Dim3& dims) { +ConvertTensorIndex(IDX_T index, const Dim3& dims) { Index3 tensor_index; for (int i = 2; i >= 0; i--) { - int new_index = index / dims[i]; - tensor_index[i] = index - dims[i] * new_index; + IDX_T new_index = index / dims[i]; + tensor_index[i] = static_cast(index - dims[i] * new_index); index = new_index; } return tensor_index; diff --git a/paddle/fluid/operators/transpose_op.cu.h b/paddle/fluid/operators/transpose_op.cu.h index 0ae020c0dfd3c08cbc60009f9b383e32f9f892e5..b8e41270ab3df9c554c49f5b62ad2c49fe811203 100644 --- a/paddle/fluid/operators/transpose_op.cu.h +++ b/paddle/fluid/operators/transpose_op.cu.h @@ -79,7 +79,11 @@ constexpr bool CheckNonLongTileSize(int tile_long, int tile_short, int size_T) { // Use SM to do data transfer, load a tile into SM then store out. // All tile read and write are colascing, so can speedup memory copy -template +template __global__ void TilingSwapDim1And2(const T* __restrict__ input, Dim3 input_dims, T* __restrict__ output) { @@ -116,7 +120,7 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input, // Converts block idx to tile index, each block process a tile Index3 input_block_tile_index = - ConvertTensorIndex(blockIdx.x, tile_aligned_input_dim); + framework::ConvertTensorIndex(blockIdx.x, tile_aligned_input_dim); // Compute real index align to tile:0, 32, 64... Index3 block_tile_index_in_input = { @@ -126,11 +130,11 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input, }; // Compute block flat index against input dims. - int input_origin_block_flat_index = - FlatTensorIndex(block_tile_index_in_input, input_dims); + IDX_T input_origin_block_flat_index = + framework::FlatTensorIndex(block_tile_index_in_input, input_dims); bool full_tile = true; - int tile_width = TileY; + IDX_T tile_width = TileY; // Last row is not full. if (input_block_tile_index[2] == tile_aligned_input_dim[2] - 1) { @@ -138,21 +142,21 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input, full_tile &= false; } - int tile_height = TileX; + IDX_T tile_height = TileX; if (input_block_tile_index[1] == tile_aligned_input_dim[1] - 1) { tile_height = input_dims[1] - (tile_aligned_input_dim[1] - 1) * TileX; full_tile &= false; } - constexpr int in_effective_thread_num = NumThreads / TileY * TileY; + constexpr IDX_T in_effective_thread_num = NumThreads / TileY * TileY; if (x < in_effective_thread_num) { // Read a tile from input using block. int x_i = x / TileY; int x_j = x % TileY; - int input_ind = input_origin_block_flat_index + x_i * input_dims[2] + x_j; - int input_inc = BlockReadRows * input_dims[2]; + IDX_T input_ind = input_origin_block_flat_index + x_i * input_dims[2] + x_j; + IDX_T input_inc = BlockReadRows * input_dims[2]; if (full_tile) { #pragma unroll @@ -163,7 +167,7 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input, } else { if (x_j < tile_width) { #pragma unroll - for (int ind_i = x_i; ind_i < (tile_height); ind_i += BlockReadRows) { + for (IDX_T ind_i = x_i; ind_i < (tile_height); ind_i += BlockReadRows) { tile_sm[ind_i][x_j] = input[input_ind]; input_ind += input_inc; } @@ -186,17 +190,17 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input, output_block_tile_index[2] * TileX, }; - int output_origin_block_flat_index = - FlatTensorIndex(block_tile_index_in_output, output_dims); + IDX_T output_origin_block_flat_index = framework::FlatTensorIndex( + block_tile_index_in_output, output_dims); - constexpr int out_effective_thread_num = NumThreads / TileX * TileX; + constexpr IDX_T out_effective_thread_num = NumThreads / TileX * TileX; if (x < out_effective_thread_num) { int x_i = x / TileX; int x_j = x % TileX; - int output_ind = + IDX_T output_ind = output_origin_block_flat_index + x_i * output_dims[2] + x_j; - int output_inc = BlockWriteRows * output_dims[2]; + IDX_T output_inc = BlockWriteRows * output_dims[2]; if (full_tile) { #pragma unroll @@ -207,7 +211,7 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input, } else { if (x_j < tile_height) { #pragma unroll - for (int ind_i = x_i; ind_i < (tile_width); ind_i += BlockWriteRows) { + for (IDX_T ind_i = x_i; ind_i < (tile_width); ind_i += BlockWriteRows) { output[output_ind] = tile_sm[x_j][ind_i]; output_ind += output_inc; } @@ -272,32 +276,36 @@ struct SystemElemType<16> { using type = float4; }; -template +template void LaunchNarrowDims2TransposeKernel(const phi::GPUContext& d, int tile_size_i, int tile_size_j, - int total_tiles_count, + IDX_T total_tiles_count, const T* input, const Dim3& input_dims, T* output) { constexpr int NumThreads = tile_long; if (tile_size_i <= tile_long && tile_size_j <= tile_short) { - TilingSwapDim1And2 + TilingSwapDim1And2 <<>>( input, input_dims, output); } else { - TilingSwapDim1And2 + TilingSwapDim1And2 <<>>( input, input_dims, output); } } -template +template struct NarrowDims2TransposeDispatch { static void DoTranspose(const phi::GPUContext& d, int tile_size_i, int tile_size_j, - int total_tiles_count, + IDX_T total_tiles_count, const T* input, const Dim3& input_dims, T* output) { @@ -313,7 +321,7 @@ struct NarrowDims2TransposeDispatch { std::min(tile_size_i, tile_size_j) <= tile_short; if (request_satisfied) { - LaunchNarrowDims2TransposeKernel( + LaunchNarrowDims2TransposeKernel( d, tile_size_i, tile_size_j, @@ -328,40 +336,41 @@ struct NarrowDims2TransposeDispatch { std::max(tile_size_i, tile_size_j) > tile_long; if (long_side_request_not_satisfied) { - NarrowDims2TransposeDispatch::DoTranspose( - d, - tile_size_i, - tile_size_j, - total_tiles_count, - input, - input_dims, - output); + NarrowDims2TransposeDispatch:: + DoTranspose(d, + tile_size_i, + tile_size_j, + total_tiles_count, + input, + input_dims, + output); } else { - NarrowDims2TransposeDispatch::DoTranspose( - d, - tile_size_i, - tile_size_j, - total_tiles_count, - input, - input_dims, - output); + NarrowDims2TransposeDispatch:: + DoTranspose(d, + tile_size_i, + tile_size_j, + total_tiles_count, + input, + input_dims, + output); } } }; // If Not long tile size, goto this function when compile. -template +template struct NarrowDims2TransposeDispatch< T, tile_long, tile_short, + IDX_T, typename std::enable_if::type> { static void DoTranspose(const phi::GPUContext& d, int tile_size_i, int tile_size_j, - int total_tiles_count, + IDX_T total_tiles_count, const T* input, const Dim3& input_dims, T* output) { @@ -377,7 +386,7 @@ struct NarrowDims2TransposeDispatch< std::min(tile_size_i, tile_size_j) <= tile_short; if (request_satisfied) { - LaunchNarrowDims2TransposeKernel( + LaunchNarrowDims2TransposeKernel( d, tile_size_i, tile_size_j, @@ -388,29 +397,30 @@ struct NarrowDims2TransposeDispatch< return; } - NarrowDims2TransposeDispatch::DoTranspose( - d, - tile_size_i, - tile_size_j, - total_tiles_count, - input, - input_dims, - output); + NarrowDims2TransposeDispatch:: + DoTranspose(d, + tile_size_i, + tile_size_j, + total_tiles_count, + input, + input_dims, + output); } }; // If long tile size, goto this function when compile. -template +template struct NarrowDims2TransposeDispatch< T, tile_long, tile_short, + IDX_T, typename std::enable_if::type> { static void DoTranspose(const phi::GPUContext& d, int tile_size_i, int tile_size_j, - int total_tiles_count, + IDX_T total_tiles_count, const T* input, const Dim3& input_dims, T* output) { @@ -422,7 +432,7 @@ struct NarrowDims2TransposeDispatch< " but received is:%d.", tile_long)); - LaunchNarrowDims2TransposeKernel( + LaunchNarrowDims2TransposeKernel( d, tile_size_i, tile_size_j, @@ -433,7 +443,7 @@ struct NarrowDims2TransposeDispatch< } }; -template +template void SwapDim1And2InNarrow(const phi::GPUContext& d, const T* input, const Dim3& input_dims, @@ -504,13 +514,14 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d, framework::CeilOrFloor(input_dims[2], select_tile_size_j), }; - int total_tiles_count = - input_dims_aligned[0] * input_dims_aligned[1] * input_dims_aligned[2]; + IDX_T total_tiles_count = input_dims_aligned[0]; + total_tiles_count *= input_dims_aligned[1]; + total_tiles_count *= input_dims_aligned[2]; // Suppose T can be replaced by system builtin types using ElemType = typename SystemElemType::type; - NarrowDims2TransposeDispatch::DoTranspose( + NarrowDims2TransposeDispatch::DoTranspose( d, select_tile_size_i, select_tile_size_j, @@ -522,8 +533,8 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d, // This is for case that cannot do coalescing read and write. // Or input is too small to split into tiles. -template -__global__ void TransposeSimpleKernel(int nthreads, +template +__global__ void TransposeSimpleKernel(IDX_T nthreads, const T* __restrict__ input, Dim3 input_dims, T* __restrict__ output) { @@ -532,22 +543,24 @@ __global__ void TransposeSimpleKernel(int nthreads, output_dims[pos1] = input_dims[1]; output_dims[pos2] = input_dims[2]; - CUDA_KERNEL_LOOP(output_index, nthreads) { - Index3 output_tensor_index = ConvertTensorIndex(output_index, output_dims); + CUDA_KERNEL_LOOP_TYPE(output_index, nthreads, IDX_T) { + Index3 output_tensor_index = + framework::ConvertTensorIndex(output_index, output_dims); Index3 input_tensor_index; input_tensor_index[0] = output_tensor_index[pos0]; input_tensor_index[1] = output_tensor_index[pos1]; input_tensor_index[2] = output_tensor_index[pos2]; - int input_index = FlatTensorIndex(input_tensor_index, input_dims); + IDX_T input_index = + framework::FlatTensorIndex(input_tensor_index, input_dims); output[output_index] = input[input_index]; } } // Here suppose convert all tensor to dim3, so just change dim1 and 2. -template +template void SendSwapDim1And2InTranspose(const phi::GPUContext& d, const T* input, const Dim3& input_dims, @@ -572,10 +585,11 @@ void SendSwapDim1And2InTranspose(const phi::GPUContext& d, framework::CeilOrFloor(input_dims[2], kTileSize), }; - int total_tiles_count = - input_dims_aligned[0] * input_dims_aligned[1] * input_dims_aligned[2]; + IDX_T total_tiles_count = input_dims_aligned[0]; + total_tiles_count *= input_dims_aligned[1]; + total_tiles_count *= input_dims_aligned[2]; - TilingSwapDim1And2 + TilingSwapDim1And2 <<>>( input, input_dims, output); @@ -583,18 +597,21 @@ void SendSwapDim1And2InTranspose(const phi::GPUContext& d, // If input shape is like Rect, such as 2X100, use Narrow tile size. // It makes things complicated, because need to find a tile can coverr // input and also reach best coalescing. - SwapDim1And2InNarrow(d, input, input_dims, output, kMinTileSize); + SwapDim1And2InNarrow( + d, input, input_dims, output, kMinTileSize); } else { // If input shape is small, such as 8X8, just do simple copy - int total_elements = input_dims[0] * input_dims[1] * input_dims[2]; + IDX_T total_elements = input_dims[0]; + total_elements *= input_dims[1]; + total_elements *= input_dims[2]; auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, total_elements); - TransposeSimpleKernel + TransposeSimpleKernel <<>>( total_elements, input, input_dims, output); } } -template +template struct SwapDim1And2InTranspose { typedef phi::GPUContext Device; void operator()(const Device& d, @@ -604,11 +621,11 @@ struct SwapDim1And2InTranspose { Dim3 input_dims = {static_cast(combined_dims[0]), static_cast(combined_dims[1]), static_cast(combined_dims[2])}; - SendSwapDim1And2InTranspose(d, in, input_dims, out); + SendSwapDim1And2InTranspose(d, in, input_dims, out); } }; -template +template struct SwapDim0And2InTranspose { typedef phi::GPUContext Device; void operator()(const Device& d, @@ -619,10 +636,12 @@ struct SwapDim0And2InTranspose { static_cast(combined_dims[1]), static_cast(combined_dims[2])}; - size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2]; + IDX_T total_size = combined_dims[0]; + total_size *= combined_dims[1]; + total_size *= combined_dims[2]; auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, total_size); - TransposeSimpleKernel + TransposeSimpleKernel <<>>( total_size, in, input_dims, out); } @@ -652,7 +671,7 @@ inline void CombineTransposeDim3(const framework::DDim& shape, return; } std::vector new_dim_pos(shape.size(), -1); - std::vector combined_dims(shape.size(), 0); + std::vector combined_dims(shape.size(), 0); int cur_head = perm[0]; new_dim_pos[cur_head] = 0; combined_dims[0] = shape[cur_head]; @@ -686,7 +705,7 @@ inline void CombineTransposeDim3(const framework::DDim& shape, *new_dims = phi::make_ddim(dim_vec); } -template +template struct TransposeSimple { static bool run(const phi::GPUContext& ctx, const Tensor& in, @@ -709,21 +728,24 @@ struct TransposeSimple { if (new_perm[0] == 1 && new_perm[1] == 0) { // Add the first dimension size as 1. new_dim_vec.insert(new_dim_vec.begin(), 1); - SwapDim1And2InTranspose()(ctx, in_data, new_dim_vec, out_data); + SwapDim1And2InTranspose()( + ctx, in_data, new_dim_vec, out_data); return true; } break; case 3: // In this case, suppose we can do coalescing read and write in tile. if (new_perm == std::vector({0, 2, 1})) { - SwapDim1And2InTranspose()(ctx, in_data, new_dim_vec, out_data); + SwapDim1And2InTranspose()( + ctx, in_data, new_dim_vec, out_data); return true; } else if (new_perm == std::vector({2, 1, 0})) { // Maybe can optimize later, find a way to do coalescing memory copy. // But I think it depends on the data size. If span is not large, // maybe // can do coalescing. - SwapDim0And2InTranspose()(ctx, in_data, new_dim_vec, out_data); + SwapDim0And2InTranspose()( + ctx, in_data, new_dim_vec, out_data); return true; } else { return false; @@ -1159,7 +1181,13 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx, const std::vector& perm, Tensor* out) { const int rank = perm.size(); - auto ret = TransposeSimple::run(ctx, in, perm, out); + int64_t numel = in.numel(); + bool ret{false}; + if (numel >= INT32_MAX) { + ret = TransposeSimple::run(ctx, in, perm, out); + } else { + ret = TransposeSimple::run(ctx, in, perm, out); + } if (!ret) { auto* tuner = phi::autotune::MakeTransposeTuner(TransCompute);