From 54970444ce9baf3154a3775c81280ca97d59c83d Mon Sep 17 00:00:00 2001
From: zhaoyuchen2018 <45989343+zhaoyuchen2018@users.noreply.github.com>
Date: Tue, 11 Feb 2020 12:01:31 +0800
Subject: [PATCH] Improve transpose performance with tile sm copy, test=develop (#22311)

* Refine code, fix select tile error,test=develop
* Refine element type and some comments, test=develop
* Refine comments and gpu utils, test=develop
* Remove some useless condition
* Refine floor and ceil, test=develop
* refine for loop. test=develop

Signed-off-by: zhaoyuchen
---
 paddle/fluid/framework/gpu_utils.h        | 137 ++++
 paddle/fluid/operators/transpose_op.cu    | 743 ++++++++++++++++++
 paddle/fluid/operators/transpose_op.cu.cc |  45 --
 .../tests/unittests/test_transpose_op.py  |  18 +
 4 files changed, 898 insertions(+), 45 deletions(-)
 create mode 100644 paddle/fluid/framework/gpu_utils.h
 create mode 100644 paddle/fluid/operators/transpose_op.cu
 delete mode 100644 paddle/fluid/operators/transpose_op.cu.cc

diff --git a/paddle/fluid/framework/gpu_utils.h b/paddle/fluid/framework/gpu_utils.h
new file mode 100644
index 0000000000..37c9852a1a
--- /dev/null
+++ b/paddle/fluid/framework/gpu_utils.h
@@ -0,0 +1,137 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
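+//
+// Overview (explanatory summary of the helpers defined below):
+//   * DeviceArray / Dim3 / Index3: small fixed-size arrays usable in device code.
+//   * FlatTensorIndex / ConvertTensorIndex: convert between a flat offset and a
+//     3-D index. For example, with dims = {2, 3, 4}, index (1, 2, 3) maps to
+//     (1 * 3 + 2) * 4 + 3 = 23, and ConvertTensorIndex(23, dims) recovers (1, 2, 3).
+//   * CeilOrFloor: integer ceiling/floor division, selected by the boolean
+//     template parameter; e.g. the ceiling of 10 / 4 is 3 and the floor is 2.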
+ +#pragma once + +#define EIGEN_USE_GPU + +#include +#include "paddle/fluid/platform/enforce.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { +namespace framework { + +template +struct DeviceArray { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator[](int index) const { + return data[index]; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T& operator[](int index) { + return data[index]; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DeviceArray() { + for (int i = 0; i < Size; i++) { + data[i] = DefaultValue; + } + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DeviceArray(T a0) { + data[0] = a0; + for (int i = 1; i < Size; i++) { + data[i] = DefaultValue; + } + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DeviceArray(T a0, T a1) { + data[0] = a0; + data[1] = a1; + for (int i = 2; i < Size; i++) { + data[i] = DefaultValue; + } + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DeviceArray(T a0, T a1, T a2) { + data[0] = a0; + data[1] = a1; + data[2] = a2; + for (int i = 3; i < Size; i++) { + data[i] = DefaultValue; + } + } + EIGEN_STRONG_INLINE DeviceArray(const std::array& sa) { + for (int i = 0; i < Size; i++) { + data[i] = sa[i]; + } + } + T data[Size]; +}; + +struct Dim3 : DeviceArray { + typedef DeviceArray Base; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dim3() : Base() {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dim3(int a0, int a1, int a2) + : Base(a0, a1, a2) {} + EIGEN_STRONG_INLINE Dim3(const std::array& array) : Base(array) {} +}; + +struct Index3 : DeviceArray { + typedef DeviceArray Base; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index3() : Base() {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index3(int a0, int a1, int a2) + : Base(a0, a1, a2) {} +}; + +// Flat index with real dimension +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int FlatTensorIndex(const Index3& index, + const Dim3& dims) { + int flat_index = index[0]; + for (int i = 1; i < 3; i++) { + flat_index = flat_index * dims[i] + index[i]; + } + return flat_index; +} + +// Convert index to tensor index with dimension. +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index3 +ConvertTensorIndex(int index, const Dim3& dims) { + Index3 tensor_index; + for (int i = 2; i >= 0; i--) { + int new_index = index / dims[i]; + tensor_index[i] = index - dims[i] * new_index; + index = new_index; + } + return tensor_index; +} + +template +IntType CeilOrFloor(IntType x, IntType deviser) { + PADDLE_ENFORCE_GT(deviser, 0, platform::errors::InvalidArgument( + "deviser should be greater than 0, " + "but received is:%d", + deviser)); + + PADDLE_ENFORCE_GT( + x, 0, platform::errors::InvalidArgument("input should be greater than 0, " + "but received is:%d", + x)); + + const IntType round_to_zero = x / deviser; + const IntType inte_result = round_to_zero * deviser; + + if (ceil) { + const bool do_adjustment = + (round_to_zero >= 0) && (deviser > 0 && x > inte_result); + const IntType adjustment = static_cast(do_adjustment); + const IntType ceil_val = round_to_zero + adjustment; + return ceil_val; + } else { + const bool do_adjustment = + (round_to_zero <= 0) && (deviser > 0 && x < inte_result); + + const IntType adjustment = static_cast(do_adjustment); + const IntType floor_val = round_to_zero - adjustment; + return floor_val; + } +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/operators/transpose_op.cu b/paddle/fluid/operators/transpose_op.cu new file mode 100644 index 0000000000..3152c902b0 --- /dev/null +++ b/paddle/fluid/operators/transpose_op.cu @@ -0,0 +1,743 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "paddle/fluid/framework/gpu_utils.h" +#include "paddle/fluid/operators/transpose_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/gpu_launch_param_config.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using Dim3 = framework::Dim3; +using Index3 = framework::Index3; + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +struct EqualTo { + constexpr bool operator()(int a, int b) const { return a == b; } +}; + +struct GreaterThan { + constexpr bool operator()(int a, int b) const { return a > b; } +}; + +// Value can be decided in compile time. +template +constexpr bool CheckProperTileSize(int tile_long, int tile_short, int size_T, + FUN op) { + return (size_T == 16 && ((tile_long == INT_32 && op(tile_short, 4)) || + (tile_long == 2 * INT_32 && op(tile_short, 4)) || + (tile_long == 4 * INT_32 && op(tile_short, 4)) || + (tile_long == 8 * INT_32 && op(tile_short, 2)))) || + (size_T == 8 && ((tile_long == INT_32 && op(tile_short, 15)) || + (tile_long == 2 * INT_32 && op(tile_short, 15)) || + (tile_long == 4 * INT_32 && op(tile_short, 8)) || + (tile_long == 8 * INT_32 && op(tile_short, 4)) || + (tile_long == 16 * INT_32 && op(tile_short, 2)))) || + ((size_T == 4 || size_T == 2 || size_T == 1) && + ((tile_long == INT_32 && op(tile_short, 15)) || + (tile_long == 2 * INT_32 && op(tile_short, 15)) || + (tile_long == 4 * INT_32 && op(tile_short, 8)) || + (tile_long == 8 * INT_32 && op(tile_short, 4)) || + (tile_long == 16 * INT_32 && op(tile_short, 2)) || + (tile_long == 16 * INT_32 && op(tile_short, 2)))); +} + +constexpr bool CheckLongTileSize(int tile_long, int tile_short, int size_T) { + return CheckProperTileSize(tile_long, tile_short, size_T, EqualTo()); +} + +constexpr bool CheckOutsideTileSize(int tile_long, int tile_short, int size_T) { + return CheckProperTileSize(tile_long, tile_short, size_T, GreaterThan()); +} + +constexpr bool CheckNonLongTileSize(int tile_long, int tile_short, int size_T) { + return !CheckOutsideTileSize(tile_long, tile_short, size_T) && + (CheckOutsideTileSize(tile_long * 2, tile_short, size_T) || + CheckOutsideTileSize(tile_long, tile_short + 1, size_T)) && + !CheckLongTileSize(tile_long, tile_short, size_T); +} + +// Use SM to do data transfer, load a tile into SM then store out. 
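+// Sketch of the scheme (explanatory, matching the kernel below): the input is
+// viewed as a 3-D tensor [d0, d1, d2] and each CUDA block handles one
+// TileX x TileY tile of the d1/d2 plane. The tile is staged in shared memory
+// with one extra element per row (TileY + 1) so the transposed read-back does
+// not hit shared-memory bank conflicts. Global reads walk rows of the input
+// and global writes walk rows of the output, so both sides stay coalesced.
+// For example, input dims {2, 64, 96} with 32 x 32 tiles launch
+// 2 * (64 / 32) * (96 / 32) = 12 blocks, one per tile.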
+// All tile read and write are colascing, so can speedup memory copy +template +__global__ void TilingSwapDim1And2(const T* __restrict__ input, Dim3 input_dims, + T* __restrict__ output) { + assert(blockDim.x == NumThreads); + assert(blockDim.y == 1); + assert(blockDim.z == 1); + assert(gridDim.y == 1); + assert(gridDim.z == 1); + + constexpr int BlockReadRows = NumThreads / TileY; + constexpr int BlockWriteRows = NumThreads / TileX; + + // One extra line in the inner dimension to avoid share memory bank conflict. + __shared__ __align__( + alignof(T)) char share_mem_ptr[TileX * (TileY + 1) * sizeof(T)]; + typedef T(*ShareMemory)[TileY + 1]; + + ShareMemory tile_sm = reinterpret_cast(share_mem_ptr); + + int x = threadIdx.x; + + Dim3 output_dims = { + input_dims[0], input_dims[2], input_dims[1], + }; + + // Align dim to Tiles + Dim3 tile_aligned_input_dim = { + input_dims[0], (input_dims[1] + TileX - 1) / TileX, + (input_dims[2] + TileY - 1) / TileY, + }; + + // Converts block idx to tile index, each block process a tile + Index3 input_block_tile_index = + ConvertTensorIndex(blockIdx.x, tile_aligned_input_dim); + + // Compute real index align to tile:0, 32, 64... + Index3 block_tile_index_in_input = { + input_block_tile_index[0], input_block_tile_index[1] * TileX, + input_block_tile_index[2] * TileY, + }; + + // Compute block flat index against input dims. + int input_origin_block_flat_index = + FlatTensorIndex(block_tile_index_in_input, input_dims); + + bool full_tile = true; + int tile_width = TileY; + + // Last row is not full. + if (input_block_tile_index[2] == tile_aligned_input_dim[2] - 1) { + tile_width = input_dims[2] - (tile_aligned_input_dim[2] - 1) * TileY; + full_tile &= false; + } + + int tile_height = TileX; + + if (input_block_tile_index[1] == tile_aligned_input_dim[1] - 1) { + tile_height = input_dims[1] - (tile_aligned_input_dim[1] - 1) * TileX; + full_tile &= false; + } + + constexpr int in_effective_thread_num = NumThreads / TileY * TileY; + + if (x < in_effective_thread_num) { + // Read a tile from input using block. 
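+    // Read phase (explanatory note): thread x covers column x_j = x % TileY of
+    // tile row x_i = x / TileY. Only NumThreads / TileY rows of threads read at
+    // a time, so each thread strides down the tile by BlockReadRows rows,
+    // advancing its flat input index by BlockReadRows * input_dims[2] per step.
+    // E.g. with NumThreads = 256 and a 32 x 32 tile, 256 / 32 = 8 rows are read
+    // per iteration and the loop below runs 32 / 8 = 4 times per thread.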
+ int x_i = x / TileY; + int x_j = x % TileY; + int input_ind = input_origin_block_flat_index + x_i * input_dims[2] + x_j; + int input_inc = BlockReadRows * input_dims[2]; + + if (full_tile) { +#pragma unroll + for (int ind_i = x_i; ind_i < (TileX); ind_i += BlockReadRows) { + tile_sm[ind_i][x_j] = input[input_ind]; + input_ind += input_inc; + } + } else { + if (x_j < tile_width) { +#pragma unroll + for (int ind_i = x_i; ind_i < (tile_height); ind_i += BlockReadRows) { + tile_sm[ind_i][x_j] = input[input_ind]; + input_ind += input_inc; + } + } + } + } + + __syncthreads(); + + // Store sm value back to out + Index3 output_block_tile_index = { + input_block_tile_index[0], input_block_tile_index[2], + input_block_tile_index[1], + }; + + Index3 block_tile_index_in_output = { + output_block_tile_index[0], output_block_tile_index[1] * TileY, + output_block_tile_index[2] * TileX, + }; + + int output_origin_block_flat_index = + FlatTensorIndex(block_tile_index_in_output, output_dims); + + constexpr int out_effective_thread_num = NumThreads / TileX * TileX; + + if (x < out_effective_thread_num) { + int x_i = x / TileX; + int x_j = x % TileX; + int output_ind = + output_origin_block_flat_index + x_i * output_dims[2] + x_j; + int output_inc = BlockWriteRows * output_dims[2]; + + if (full_tile) { +#pragma unroll + for (int ind_i = x_i; ind_i < (TileY); ind_i += BlockWriteRows) { + output[output_ind] = tile_sm[x_j][ind_i]; + output_ind += output_inc; + } + } else { + if (x_j < tile_height) { +#pragma unroll + for (int ind_i = x_i; ind_i < (tile_width); ind_i += BlockWriteRows) { + output[output_ind] = tile_sm[x_j][ind_i]; + output_ind += output_inc; + } + } + } + } +} + +// This function will find combination of long_side X short_side in backups +template +bool SelectProperTileSize(std::vector>* tiles) { + PADDLE_ENFORCE_LE( + TSIZE, 16, + platform::errors::InvalidArgument( + "The tile size should smaller than 16, but received is:%d.", TSIZE)); + + PADDLE_ENFORCE_EQ( + (TSIZE & (TSIZE - 1)), 0, + platform::errors::InvalidArgument( + "Data types should be powers of 2, but reived size is:%d.", TSIZE)); + + const int kMaxLongSideLen = 1024; + const int kMaxShortSideLen = 15; + + for (int long_side = 32; long_side <= kMaxLongSideLen; long_side *= 2) { + for (int short_side = 2; short_side <= kMaxShortSideLen; short_side += 1) { + if (CheckLongTileSize(long_side, short_side, TSIZE)) { + tiles->push_back(std::make_pair(long_side, short_side)); + + if (short_side == 2) return true; + + break; + } + } + } + return false; +} + +// Use system built in type +template +struct SystemElemType; +template <> +struct SystemElemType<1> { + using type = uint8_t; +}; +template <> +struct SystemElemType<2> { + using type = uint16_t; +}; +template <> +struct SystemElemType<4> { + using type = uint32_t; +}; +template <> +struct SystemElemType<8> { + using type = uint64_t; +}; +template <> +struct SystemElemType<16> { + using type = float4; +}; + +template +void LaunchNarrowDims2TransposeKernel(const platform::CUDADeviceContext& d, + int tile_size_i, int tile_size_j, + int total_tiles_count, const T* input, + const Dim3& input_dims, T* output) { + constexpr int NumThreads = tile_long; + if (tile_size_i <= tile_long && tile_size_j <= tile_short) { + TilingSwapDim1And2< + T, NumThreads, tile_long, + tile_short><<>>( + input, input_dims, output); + } else { + TilingSwapDim1And2< + T, NumThreads, tile_short, + tile_long><<>>( + input, input_dims, output); + } +} + +template +struct NarrowDims2TransposeDispatch { + static void 
DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i, + int tile_size_j, int total_tiles_count, + const T* input, const Dim3& input_dims, T* output) { + PADDLE_ENFORCE_EQ( + (tile_long & (tile_long - 1)), 0, + platform::errors::InvalidArgument( + "The length of the longer side of the tile should be power of 2." + " But received value is:%d.", + tile_long)); + + bool request_satisfied = std::max(tile_size_i, tile_size_j) <= tile_long && + std::min(tile_size_i, tile_size_j) <= tile_short; + + if (request_satisfied) { + LaunchNarrowDims2TransposeKernel( + d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + output); + return; + } + + const bool long_side_request_not_satisfied = + std::max(tile_size_i, tile_size_j) > tile_long; + + if (long_side_request_not_satisfied) { + NarrowDims2TransposeDispatch::DoTranspose( + d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + output); + } else { + NarrowDims2TransposeDispatch::DoTranspose( + d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + output); + } + } +}; + +// If Not long tile size, goto this function when compile. +template +struct NarrowDims2TransposeDispatch< + T, tile_long, tile_short, + typename std::enable_if< + CheckNonLongTileSize(tile_long, tile_short, sizeof(T)), void>::type> { + static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i, + int tile_size_j, int total_tiles_count, + const T* input, const Dim3& input_dims, T* output) { + PADDLE_ENFORCE_EQ( + (tile_long & (tile_long - 1)), 0, + platform::errors::InvalidArgument( + "The length of the longer side of the tile should be power of 2." + " But received value is:%d.", + tile_long)); + + bool request_satisfied = std::max(tile_size_i, tile_size_j) <= tile_long && + std::min(tile_size_i, tile_size_j) <= tile_short; + + if (request_satisfied) { + LaunchNarrowDims2TransposeKernel( + d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + output); + return; + } + + NarrowDims2TransposeDispatch::DoTranspose( + d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + output); + } +}; + +// If long tile size, goto this function when compile. +template +struct NarrowDims2TransposeDispatch< + T, tile_long, tile_short, + typename std::enable_if::type> { + static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i, + int tile_size_j, int total_tiles_count, + const T* input, const Dim3& input_dims, T* output) { + PADDLE_ENFORCE_EQ( + (tile_long & (tile_long - 1)), 0, + platform::errors::InvalidArgument( + "The length of the longer side of the tile should be power of 2," + " but received is:%d.", + tile_long)); + + LaunchNarrowDims2TransposeKernel( + d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + output); + } +}; + +template +void SwapDim1And2InNarrow(const platform::CUDADeviceContext& d, const T* input, + const Dim3& input_dims, T* output, + const int kMinTileSize) { + // First get available tile sizes for the data type requested as backups + std::vector> tile_sele; + auto ret = SelectProperTileSize(&tile_sele); + PADDLE_ENFORCE_EQ( + ret, true, + platform::errors::InvalidArgument( + "SelectProperTileSize should return true, but return value is:%d.", + ret)); + + int tile_long_edge = 0; + int tile_short_edge = 0; + float lowest_cost = std::numeric_limits::max(); + int input_long_edge = std::max(input_dims[1], input_dims[2]); + + // Find the tile size that best suit in inputs. 
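+  // Selection note (explanatory): the loop below scores each candidate long
+  // edge by how well it tiles input_long_edge; the score is 0 exactly when the
+  // candidate divides the long edge, in which case the search stops early.
+  // Among imperfect fits, the candidate with the lowest recorded cost is kept
+  // (ties go to the later, i.e. longer, candidate because of the <= comparison)
+  // together with its paired short edge.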
+ for (auto tile_size_pair : tile_sele) { + int proposed_tile_long_edge = tile_size_pair.first; + // data may not aligned to tile, so some threads wasted, we need + // to find least wasted threads, which means we need to find tile + // can split input properly, in another words: num_wasted_threads=0. + int num_wasted_threads = input_long_edge - + framework::CeilOrFloor( + input_long_edge, proposed_tile_long_edge) * + proposed_tile_long_edge; + + int num_full_tiles = framework::CeilOrFloor( + input_long_edge, proposed_tile_long_edge); + + float cost = num_wasted_threads; + + if (cost <= lowest_cost) { + tile_long_edge = proposed_tile_long_edge; + tile_short_edge = tile_size_pair.second; + lowest_cost = cost; + } + // break as we already find best tile size. + if (cost == 0) break; + } + + // The tile size we select should be match with input dim, long side to long + // short side to short. + // First set long side as i if dim1 > Tile min size, then set dim2 as j. + int select_tile_size_i = + input_dims[1] >= kMinTileSize ? tile_long_edge : input_dims[1]; + int select_tile_size_j = + input_dims[1] >= kMinTileSize ? input_dims[2] : tile_long_edge; + + // Check if i is long edge, if not set i as short. + select_tile_size_i = select_tile_size_i == tile_long_edge + ? tile_long_edge + : std::min(select_tile_size_i, tile_short_edge); + + // Check if j is long edge, if not set j as short. + select_tile_size_j = select_tile_size_j == tile_long_edge + ? tile_long_edge + : std::min(select_tile_size_j, tile_short_edge); + + // Here finally get proper long X short tile size. + Dim3 input_dims_aligned = { + input_dims[0], + framework::CeilOrFloor(input_dims[1], select_tile_size_i), + framework::CeilOrFloor(input_dims[2], select_tile_size_j), + }; + + int total_tiles_count = + input_dims_aligned[0] * input_dims_aligned[1] * input_dims_aligned[2]; + + // Suppose T can be replaced by system builtin types + using ElemType = typename SystemElemType::type; + + NarrowDims2TransposeDispatch::DoTranspose( + d, select_tile_size_i, select_tile_size_j, total_tiles_count, + reinterpret_cast(input), input_dims, + reinterpret_cast(output)); +} + +// This is for case that cannot do coalescing read and write. +// Or input is too small to split into tiles. +template +__global__ void TransposeSimpleKernel(int nthreads, const T* __restrict__ input, + Dim3 input_dims, T* __restrict__ output) { + Dim3 output_dims; + output_dims[pos0] = input_dims[0]; + output_dims[pos1] = input_dims[1]; + output_dims[pos2] = input_dims[2]; + + CUDA_1D_KERNEL_LOOP(output_index, nthreads) { + Index3 output_tensor_index = ConvertTensorIndex(output_index, output_dims); + + Index3 input_tensor_index; + input_tensor_index[0] = output_tensor_index[pos0]; + input_tensor_index[1] = output_tensor_index[pos1]; + input_tensor_index[2] = output_tensor_index[pos2]; + + int input_index = FlatTensorIndex(input_tensor_index, input_dims); + + output[output_index] = input[input_index]; + } +} + +// Here suppose convert all tensor to dim3, so just change dim1 and 2. 
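+// Dispatch sketch (explanatory, matching the thresholds below): with the
+// tensor viewed as [d0, d1, d2],
+//   * if d1 >= 16 and d2 >= 16, a 32 x 32 shared-memory tile kernel with 256
+//     threads per block is launched, one block per tile;
+//   * else if d1 >= 96 or d2 >= 96, the narrow-tile path picks a long x short
+//     tile that fits the skewed shape;
+//   * otherwise the tensor is small and the simple element-wise kernel is used.
+// For example, dims {8, 2, 100} fail the first test (2 < 16) but pass the
+// second (100 >= 96), so they take the narrow-tile path.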
+template +void SendSwapDim1And2InTranspose(const platform::CUDADeviceContext& d, + const T* input, const Dim3& input_dims, + T* output) { + // Suppose tile size > 16 + static const int kMinTileSize = 16; + static const int kMinNarrowTileSize = 96; + + bool large_tile = + input_dims[1] >= kMinTileSize && input_dims[2] >= kMinTileSize; + bool narrow_tile = input_dims[1] >= kMinNarrowTileSize || + input_dims[2] >= kMinNarrowTileSize; + if (large_tile) { + // If input is large square, such as 32X32, use SM to do copy. + // suppose 32 X 32 gives best performance, and 8 warp in block. + constexpr int kTileSize = 32; + constexpr int kNumThreads = 256; + + Dim3 input_dims_aligned = { + input_dims[0], + framework::CeilOrFloor(input_dims[1], kTileSize), + framework::CeilOrFloor(input_dims[2], kTileSize), + }; + + int total_tiles_count = + input_dims_aligned[0] * input_dims_aligned[1] * input_dims_aligned[2]; + + TilingSwapDim1And2< + T, kNumThreads, kTileSize, + kTileSize><<>>( + input, input_dims, output); + + } else if (narrow_tile) { + // If input shape is like Rect, such as 2X100, use Narrow tile size. + // It makes things complicated, because need to find a tile can coverr + // input and also reach best coalescing. + SwapDim1And2InNarrow(d, input, input_dims, output, kMinTileSize); + } else { + // If input shape is small, such as 8X8, just do simple copy + int total_elements = input_dims[0] * input_dims[1] * input_dims[2]; + auto config = GetGpuLaunchConfig1D(d, total_elements); + TransposeSimpleKernel<<< + config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>( + total_elements, input, input_dims, output); + } +} + +template +struct SwapDim1And2InTranspose { + typedef platform::CUDADeviceContext Device; + void operator()(const Device& d, const T* in, + const std::vector& combined_dims, T* out) { + Dim3 input_dims = {static_cast(combined_dims[0]), + static_cast(combined_dims[1]), + static_cast(combined_dims[2])}; + SendSwapDim1And2InTranspose(d, in, input_dims, out); + } +}; + +template +struct SwapDim0And2InTranspose { + typedef platform::CUDADeviceContext Device; + void operator()(const Device& d, const T* in, + const std::vector& combined_dims, T* out) { + Dim3 input_dims = {static_cast(combined_dims[0]), + static_cast(combined_dims[1]), + static_cast(combined_dims[2])}; + + size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2]; + auto config = GetGpuLaunchConfig1D(d, total_size); + + TransposeSimpleKernel<<< + config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>( + total_size, in, input_dims, out); + } +}; + +// This function is to combine dimension. fox example: +// (0, 1, 3, 2) --> (0, 2, 1) +inline void CombineTransposeDim3(const framework::DDim& shape, + const std::vector& perm, + std::vector* new_perm, + framework::DDim* new_dims) { + PADDLE_ENFORCE_EQ(shape.size(), perm.size(), + platform::errors::InvalidArgument( + " shape should have the save dim with perm, but" + " received shape size is:%d, perm size is:%d.", + shape.size(), perm.size())); + + std::vector dim_vec; + if (shape.size() == 1) { + // If input dimension is already 1, no need to combine dim. 
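+    // Nothing to merge for a 1-D tensor: the only permutation is the identity,
+    // so the combined shape and perm are copies of the originals. In the
+    // general case handled after this branch, runs of dimensions that stay
+    // adjacent and in order under perm are merged; e.g. shape (4, 8, 2, 3)
+    // with perm (0, 1, 3, 2) becomes shape (32, 2, 3) with perm (0, 2, 1).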
+ new_perm->resize(1); + (*new_perm)[0] = perm[0]; + dim_vec.push_back(shape[0]); + *new_dims = framework::make_ddim(dim_vec); + return; + } + std::vector new_dim_pos(shape.size(), -1); + std::vector combined_dims(shape.size(), 0); + int cur_head = perm[0]; + new_dim_pos[cur_head] = 0; + combined_dims[0] = shape[cur_head]; + int dim_idx = 0; + for (int perm_idx = 1; perm_idx < shape.size(); ++perm_idx) { + // combine consecutive dimensions. + if (cur_head + 1 == perm[perm_idx]) { + cur_head = perm[perm_idx]; + combined_dims[dim_idx] *= shape[cur_head]; + } else { + // Else start a new dimension. + cur_head = perm[perm_idx]; + dim_idx++; + new_dim_pos[cur_head] = dim_idx; + combined_dims[dim_idx] = shape[cur_head]; + } + } + + new_perm->resize(dim_idx + 1); + + dim_idx = 0; + for (int i = 0; i < new_dim_pos.size(); ++i) { + if (new_dim_pos[i] >= 0) { + int new_perm_idx = new_dim_pos[i]; + (*new_perm)[dim_idx] = new_perm_idx; + dim_vec.push_back(combined_dims[new_perm_idx]); + dim_idx++; + } + } + + *new_dims = framework::make_ddim(dim_vec); +} + +template +struct TransposeSimple { + static bool run(const platform::CUDADeviceContext& ctx, const Tensor& in, + const std::vector perm, Tensor* out) { + // First reduce the dimensions of the input tensor if possible. + std::vector new_perm; + framework::DDim new_dims; + CombineTransposeDim3(in.dims(), perm, &new_perm, &new_dims); + + // Only use tile copy GPU kernel when dimension is 2 or 3. + int dims = new_dims.size(); + std::vector new_dim_vec = framework::vectorize(new_dims); + if (dims < 2 || dims > 3) return false; + auto in_data = in.data(); + auto out_data = out->data(); + // In most cases, dim will not greater than 3 after combine. + switch (dims) { + case 2: + if (new_perm[0] == 1 && new_perm[1] == 0) { + // Add the first dimension size as 1. + new_dim_vec.insert(new_dim_vec.begin(), 1); + SwapDim1And2InTranspose()(ctx, in_data, new_dim_vec, out_data); + return true; + } + break; + case 3: + // In this case, suppose we can do coalescing read and write in tile. + if (new_perm == std::vector({0, 2, 1})) { + SwapDim1And2InTranspose()(ctx, in_data, new_dim_vec, out_data); + return true; + } else if (new_perm == std::vector({2, 1, 0})) { + // Maybe can optimize later, find a way to do coalescing memory copy. + // But I think it depends on the data size. If span is not large, + // maybe + // can do coalescing. 
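+          // Explanatory note: for perm (2, 1, 0) the innermost and outermost
+          // dims swap, so consecutive output elements come from input
+          // locations that are dim1 * dim2 elements apart. The tiled path in
+          // this file is only wired up for swapping dims 1 and 2, hence the
+          // fallback to the simple element-wise kernel here.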
+ SwapDim0And2InTranspose()(ctx, in_data, new_dim_vec, out_data); + return true; + } else { + return false; + } + break; + default: + return false; + } + return false; + } +}; + +template +class TransposeGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + out->mutable_data(context.GetPlace()); + + std::vector axis = context.Attr>("axis"); + int ndims = axis.size(); + const auto& dev_ctx = context.template device_context(); + auto ret = TransposeSimple::run(dev_ctx, *x, axis, out); + if (!ret) { + TransCompute(ndims, dev_ctx, *x, out, axis); + } + } +}; +template +class TransposeGradGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* out_grad = + context.Input(framework::GradVarName("Out")); + auto* x_grad = + context.Output(framework::GradVarName("X")); + if (!x_grad) return; + + x_grad->mutable_data(context.GetPlace()); + std::vector axis = context.Attr>("axis"); + std::vector reversed_axis(axis); + + for (size_t i = 0; i < axis.size(); i++) { + reversed_axis[axis[i]] = i; + } + + int ndims = axis.size(); + const auto& dev_ctx = context.template device_context(); + auto ret = + TransposeSimple::run(dev_ctx, *out_grad, reversed_axis, x_grad); + if (!ret) { + TransCompute(ndims, dev_ctx, *out_grad, x_grad, + reversed_axis); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + transpose, + ops::TransposeGPUKernel, + ops::TransposeGPUKernel, + ops::TransposeGPUKernel); +REGISTER_OP_CUDA_KERNEL( + transpose_grad, + ops::TransposeGradGPUKernel, + ops::TransposeGradGPUKernel, + ops::TransposeGradGPUKernel); + +REGISTER_OP_CUDA_KERNEL( + transpose2, + ops::TransposeGPUKernel, + ops::TransposeGPUKernel, + ops::TransposeGPUKernel, + ops::TransposeGPUKernel, + ops::TransposeGPUKernel); +REGISTER_OP_CUDA_KERNEL( + transpose2_grad, + ops::TransposeGradGPUKernel, + ops::TransposeGradGPUKernel, + ops::TransposeGradGPUKernel, + ops::TransposeGradGPUKernel, + ops::TransposeGradGPUKernel); diff --git a/paddle/fluid/operators/transpose_op.cu.cc b/paddle/fluid/operators/transpose_op.cu.cc deleted file mode 100644 index debf9bce55..0000000000 --- a/paddle/fluid/operators/transpose_op.cu.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/operators/transpose_op.h" - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL( - transpose, ops::TransposeKernel, - ops::TransposeKernel, - ops::TransposeKernel); -REGISTER_OP_CUDA_KERNEL( - transpose_grad, - ops::TransposeGradKernel, - ops::TransposeGradKernel, - ops::TransposeGradKernel); - -REGISTER_OP_CUDA_KERNEL( - transpose2, - ops::TransposeKernel, - ops::TransposeKernel, - ops::TransposeKernel, - ops::TransposeKernel, - ops::TransposeKernel); -REGISTER_OP_CUDA_KERNEL( - transpose2_grad, - ops::TransposeGradKernel, - ops::TransposeGradKernel, - ops::TransposeGradKernel, - ops::TransposeGradKernel, - ops::TransposeGradKernel); diff --git a/python/paddle/fluid/tests/unittests/test_transpose_op.py b/python/paddle/fluid/tests/unittests/test_transpose_op.py index 273d474b3b..c8e3a1f2c2 100644 --- a/python/paddle/fluid/tests/unittests/test_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_transpose_op.py @@ -80,6 +80,24 @@ class TestCase4(TestTransposeOp): self.axis = (4, 2, 3, 1, 0, 5) +class TestCase5(TestTransposeOp): + def initTestCase(self): + self.shape = (2, 16, 96) + self.axis = (0, 2, 1) + + +class TestCase6(TestTransposeOp): + def initTestCase(self): + self.shape = (2, 10, 12, 16) + self.axis = (3, 1, 2, 0) + + +class TestCase7(TestTransposeOp): + def initTestCase(self): + self.shape = (2, 10, 2, 16) + self.axis = (0, 1, 3, 2) + + class TestTransposeOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): -- GitLab
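For reference, the index arithmetic the new kernels rely on can be checked on the host with a small standalone sketch. The snippet below mirrors FlatTensorIndex, ConvertTensorIndex, and the ceil/floor division helper; std::array<int, 3> stands in for the framework's Dim3/Index3, and the CeilDiv/FloorDiv helpers and main() driver are illustrative only, not APIs from the patch.

// Host-side sketch of the 3-D index helpers used by the tiled transpose
// kernels. Plain ints and std::array stand in for Dim3 / Index3; the names
// below are stand-ins, not the framework's own types.
#include <array>
#include <cassert>
#include <cstdio>

using Dim3 = std::array<int, 3>;

// Row-major flat offset of a 3-D index.
int FlatTensorIndex(const Dim3& index, const Dim3& dims) {
  int flat = index[0];
  for (int i = 1; i < 3; ++i) flat = flat * dims[i] + index[i];
  return flat;
}

// Inverse of FlatTensorIndex: recover the 3-D index from a flat offset.
Dim3 ConvertTensorIndex(int index, const Dim3& dims) {
  Dim3 out{};
  for (int i = 2; i >= 0; --i) {
    out[i] = index % dims[i];
    index /= dims[i];
  }
  return out;
}

// Integer ceil/floor division for positive operands, as used when aligning
// tensor extents to tile boundaries.
int CeilDiv(int x, int d) { return (x + d - 1) / d; }
int FloorDiv(int x, int d) { return x / d; }

int main() {
  Dim3 dims = {2, 64, 96};

  // Round trip: flatten an index and convert it back.
  Dim3 idx = {1, 2, 3};
  int flat = FlatTensorIndex(idx, dims);
  Dim3 back = ConvertTensorIndex(flat, dims);
  assert(back == idx);
  std::printf("flat(%d,%d,%d) = %d\n", idx[0], idx[1], idx[2], flat);

  // Tile counting as in the large-tile launch path (32 x 32 tiles).
  int tile = 32;
  int tiles = dims[0] * CeilDiv(dims[1], tile) * CeilDiv(dims[2], tile);
  std::printf("blocks for %dx%dx%d with %dx%d tiles: %d\n", dims[0], dims[1],
              dims[2], tile, tile, tiles);
  return 0;
}

Compiled with any C++11 compiler, this prints flat(1,2,3) = 6339 for dims {2, 64, 96} and a 12-block tile count, matching the worked example in the kernel comments above.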