diff --git a/oneflow/core/primitive/common/permute.h b/oneflow/core/primitive/common/permute.h
index 5dba46223a646d6a0bd162cd3f67b650eeb0de62..ccd526bfec4babce7a901916ad57dd98509e5591 100644
--- a/oneflow/core/primitive/common/permute.h
+++ b/oneflow/core/primitive/common/permute.h
@@ -125,12 +125,10 @@ struct PermuteKernelParams {
 constexpr size_t kMaxMovementSize = 16;
 constexpr size_t kMaxNumDims = 8;
 
-template<size_t num_dims, size_t movement_size, typename IndexType>
-void LaunchKernel(StreamContext* stream_ctx, PermuteKernelParams<num_dims, IndexType> params);
-
-template<size_t num_dims, size_t movement_size, typename IndexType>
-void LaunchKernel(StreamContext* stream_ctx, const int64_t* src_dims, const void* src,
-                  const int* permutation, void* dst, size_t count) {
+template<size_t num_dims, typename IndexType>
+PermuteKernelParams<num_dims, IndexType> MakePermuteParams(const int64_t* src_dims, const void* src,
+                                                           const int* permutation, void* dst,
+                                                           size_t count) {
   PermuteKernelParams<num_dims, IndexType> params;
   params.src_index_helper = NdIndexOffsetHelper<IndexType, num_dims>(src_dims);
   int64_t dst_dims[num_dims];
@@ -140,9 +138,13 @@ void LaunchKernel(StreamContext* stream_ctx, const int64_t* src_dims, const void* src,
   params.src = src;
   params.dst = dst;
   params.count = static_cast<IndexType>(count);
-  LaunchKernel<num_dims, movement_size, IndexType>(stream_ctx, params);
+  return params;
 }
 
+template<size_t num_dims, size_t movement_size, typename IndexType>
+void LaunchKernel(StreamContext* stream_ctx, const int64_t* src_dims, const void* src,
+                  const int* permutation, void* dst, size_t count);
+
 template<size_t num_dims, size_t movement_size>
 void DispatchIndexType(StreamContext* stream_ctx, const int64_t* src_dims, const void* src,
                        const int* permutation, void* dst) {
diff --git a/oneflow/core/primitive/cpu/permute.cpp b/oneflow/core/primitive/cpu/permute.cpp
index 09f7aa032434c206f319bfc8bda69c3847cf3ce0..b5f6ea89606d604aa73fea2d5f45f8d6ad3fdeb4 100644
--- a/oneflow/core/primitive/cpu/permute.cpp
+++ b/oneflow/core/primitive/cpu/permute.cpp
@@ -42,16 +42,19 @@ void PermuteKernel(PermuteKernelParams<num_dims, IndexType> params) {
 }
 
 template<size_t num_dims, size_t movement_size, typename IndexType>
-void LaunchKernel(StreamContext* stream_ctx, PermuteKernelParams<num_dims, IndexType> params) {
+void LaunchKernel(StreamContext* stream_ctx, const int64_t* src_dims, const void* src,
+                  const int* permutation, void* dst, size_t count) {
+  PermuteKernelParams<num_dims, IndexType> params =
+      MakePermuteParams<num_dims, IndexType>(src_dims, src, permutation, dst, count);
   PermuteKernel<num_dims, movement_size, IndexType>(params);
 }
-
 class PermuteImpl : public Permute {
  public:
   OF_DISALLOW_COPY_AND_MOVE(PermuteImpl);
   PermuteImpl() = default;
   ~PermuteImpl() override = default;
 
+  using Permute::Launch;
   void Launch(StreamContext* stream_ctx, DataType data_type, size_t num_dims,
               const int64_t* src_dims, const void* src, const int* permutation,
               void* dst) override {
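The header change above separates building PermuteKernelParams (MakePermuteParams) from launching it, so each backend's LaunchKernel constructs the parameters itself and picks its own launch strategy. As background, here is a minimal standalone sketch (not part of the patch) of the index mapping the permute kernels perform: each flat destination offset is decomposed into an n-d destination index, routed through the permutation, and re-flattened into a source offset. SimplePermute and the plain loops (instead of NdIndexOffsetHelper) are illustrative only.

#include <cstdint>
#include <vector>

// dst_dims[i] = src_dims[permutation[i]]; dst[dst_index] = src[src_index] where
// src_index[permutation[i]] = dst_index[i].
void SimplePermute(const std::vector<int64_t>& src_dims, const std::vector<int>& permutation,
                   const float* src, float* dst) {
  const size_t num_dims = src_dims.size();
  std::vector<int64_t> dst_dims(num_dims);
  for (size_t i = 0; i < num_dims; ++i) { dst_dims[i] = src_dims[permutation[i]]; }
  int64_t count = 1;
  for (int64_t d : dst_dims) { count *= d; }
  std::vector<int64_t> dst_index(num_dims), src_index(num_dims);
  for (int64_t i = 0; i < count; ++i) {
    // Decompose the flat destination offset into an n-d destination index.
    int64_t remaining = i;
    for (size_t d = 0; d < num_dims; ++d) {
      int64_t stride = 1;
      for (size_t k = d + 1; k < num_dims; ++k) { stride *= dst_dims[k]; }
      dst_index[d] = remaining / stride;
      remaining -= dst_index[d] * stride;
    }
    // Route it through the permutation, then re-flatten into a source offset.
    for (size_t d = 0; d < num_dims; ++d) { src_index[permutation[d]] = dst_index[d]; }
    int64_t src_offset = 0;
    for (size_t d = 0; d < num_dims; ++d) { src_offset = src_offset * src_dims[d] + src_index[d]; }
    dst[i] = src[src_offset];
  }
}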
diff --git a/oneflow/core/primitive/cuda/permute.cu b/oneflow/core/primitive/cuda/permute.cu
index fc72c5573e353357cafbd843cdbbb1adaa15e552..29df2c721b9eaf57129811bd03957a14cd624785 100644
--- a/oneflow/core/primitive/cuda/permute.cu
+++ b/oneflow/core/primitive/cuda/permute.cu
@@ -26,6 +26,10 @@ namespace permute_internal {
 
 namespace {
 
+constexpr int32_t kMov4TileSize = 32;
+constexpr int32_t kMov2TileSize = 64;
+constexpr int32_t kBlockRows = 8;
+
 template<size_t num_dims, size_t movement_size, typename IndexType>
 __global__ void PermuteKernel(PermuteKernelParams<num_dims, IndexType> params) {
   using T = typename std::aligned_storage<movement_size, movement_size>::type;
@@ -44,12 +48,236 @@ __global__ void PermuteKernel(PermuteKernelParams<num_dims, IndexType> params) {
   }
 }
 
+// (B, X, Y) -> (B, Y, X)
+// Adapted from https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
+template<size_t movement_size, size_t tile_size, typename IndexType>
+__global__ void BatchTransposeKernel(const void* src_ptr, void* dst_ptr, IndexType H, IndexType W,
+                                     IndexType num_tile_rows, IndexType num_tile_cols,
+                                     int32_t block_nums) {
+  using T = typename std::aligned_storage<movement_size, movement_size>::type;
+  __shared__ T tile[tile_size][tile_size + 1];  // Padded to avoid bank conflicts.
+
+  const T* src = reinterpret_cast<const T*>(src_ptr);
+  T* dst = reinterpret_cast<T*>(dst_ptr);
+
+  IndexType batch_num_tile = num_tile_rows * num_tile_cols;
+  for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) {
+    const IndexType batch_index = i / batch_num_tile;  // The batch index.
+    const IndexType flatten_index =
+        i - batch_index * batch_num_tile;  // Equal to i % (num_tile_rows * num_tile_cols);
+                                           // the flattened index of the tile within a batch.
+
+    const IndexType row_index = flatten_index / num_tile_cols;  // Row index of the tile within a batch.
+    const IndexType col_index =
+        flatten_index
+        - row_index * num_tile_cols;  // Equal to flatten_index % num_tile_cols;
+                                      // the column index of the tile within a batch.
+    const IndexType offset = batch_index * H * W;
+    IndexType x = col_index * tile_size + threadIdx.x;
+    IndexType y = row_index * tile_size + threadIdx.y;
+    if (x < W) {
+      IndexType y_range =
+          ((tile_size - threadIdx.y) < (H - y)) ? (tile_size - threadIdx.y) : (H - y);
+#pragma unroll
+      // Each thread processes 4 elements.
+      // `i < y_range` is equivalent to `threadIdx.y + i < tile_size && y + i < H`.
+      for (int i = 0; i < y_range; i += kBlockRows) {
+        tile[threadIdx.y + i][threadIdx.x] = src[offset + (y + i) * W + x];
+      }
+    }
+    __syncthreads();
+    x = row_index * tile_size + threadIdx.x;
+    y = col_index * tile_size + threadIdx.y;
+    if (x < H) {
+      IndexType x_range =
+          ((tile_size - threadIdx.y) < (W - y)) ? (tile_size - threadIdx.y) : (W - y);
+#pragma unroll
+      // `i < x_range` is equivalent to `threadIdx.y + i < tile_size && y + i < W`.
+      for (int i = 0; i < x_range; i += kBlockRows) {
+        dst[offset + (y + i) * H + x] = tile[threadIdx.x][threadIdx.y + i];
+      }
+    }
+    __syncthreads();
+  }
+}
+
+/*
+This is a MovementSize=2 version of the batch transpose.
+When H and W are both divisible by 2, data can be read with movement_size=4 and written back
+with movement_size=2.
+*/
+template<size_t tile_size, typename IndexType>
+__global__ void BatchTransposeMovement2Kernel(const void* src_ptr, void* dst_ptr, IndexType rows,
+                                              IndexType cols, IndexType num_tile_rows,
+                                              IndexType num_tile_cols, int32_t block_nums) {
+  static_assert(tile_size % 2 == 0, "tile_size must be even");
+  using T_MOV2 = typename std::aligned_storage<2, 2>::type;
+  using T_MOV4 = typename std::aligned_storage<4, 4>::type;
+
+  const T_MOV4* src = reinterpret_cast<const T_MOV4*>(src_ptr);
+  T_MOV4* dst = reinterpret_cast<T_MOV4*>(dst_ptr);
+
+  // Use a union so the same shared-memory tile can be loaded as half2 and stored as half.
+  __shared__ union {
+    T_MOV2 tile_m2[tile_size][tile_size + 2];      // half  [64][66]
+    T_MOV4 tile_m4[tile_size][tile_size / 2 + 1];  // half2 [64][33]
+  } tile_mem;
+
+  IndexType batch_num_tile = num_tile_rows * num_tile_cols;
+  for (int i = blockIdx.x, step = gridDim.x; i < block_nums; i += step) {
+    const IndexType batch_index = i / batch_num_tile;  // The batch index.
+    const IndexType flatten_index =
+        i - batch_index * batch_num_tile;  // Equal to i % (num_tile_rows * num_tile_cols);
+                                           // the flattened index of the tile within a batch.
+
+    const IndexType row_index = flatten_index / num_tile_cols;  // Row index of the tile within a batch.
+    const IndexType col_index =
+        flatten_index
+        - row_index * num_tile_cols;  // Equal to flatten_index % num_tile_cols;
+                                      // the column index of the tile within a batch.
+    const IndexType offset = batch_index * rows * cols;
+    IndexType x =
+        col_index * tile_size + threadIdx.x * 2;  // Because each thread processes one half2
+                                                  // element, threadIdx.x is multiplied by 2.
+    IndexType y = row_index * tile_size + threadIdx.y;
+    if (x < cols) {
+      IndexType y_range =
+          ((tile_size - threadIdx.y) < (rows - y)) ? (tile_size - threadIdx.y) : (rows - y);
+#pragma unroll
+      // `i < y_range` is equivalent to `threadIdx.y + i < tile_size && y + i < rows`.
+      for (int i = 0; i < y_range; i += kBlockRows) {
+        // Each thread loads one half2.
+        tile_mem.tile_m4[threadIdx.y + i][threadIdx.x] = src[(offset + (y + i) * cols + x) / 2];
+      }
+    }
+    __syncthreads();
+    x = row_index * tile_size + threadIdx.x * 2;  // Because each thread processes one half2
+                                                  // element, threadIdx.x is multiplied by 2.
+    y = col_index * tile_size + threadIdx.y;
+    if (x < rows) {
+      IndexType x_range =
+          ((tile_size - threadIdx.y) < (cols - y)) ? (tile_size - threadIdx.y) : (cols - y);
+#pragma unroll
+      // `i < x_range` is equivalent to `threadIdx.y + i < tile_size && y + i < cols`.
+      for (int i = 0; i < x_range; i += kBlockRows) {
+        /*
+        When writing back along a column, the data cannot be stored as a half2 directly,
+        so we split it into 2 half elements and write them back separately.
+        */
+        union {
+          T_MOV4 m4;
+          T_MOV2 m2[2];
+        } tmp_storage;
+        tmp_storage.m2[0] = tile_mem.tile_m2[threadIdx.x * 2][threadIdx.y + i];
+        tmp_storage.m2[1] = tile_mem.tile_m2[threadIdx.x * 2 + 1][threadIdx.y + i];
+        dst[(offset + (y + i) * rows + x) / 2] = tmp_storage.m4;
+      }
+    }
+    __syncthreads();
+  }
+}
+
+template<size_t num_dims, size_t movement_size, size_t tile_size, typename IndexType>
+void LaunchBatchTransposeKernel(cudaStream_t& cuda_stream,
+                                const PermuteKernelParams<num_dims, IndexType>& params,
+                                const IndexType& num_batches, const IndexType& rows,
+                                const IndexType& cols) {
+  IndexType num_tile_rows = (rows + tile_size - 1) / tile_size;
+  IndexType num_tile_cols = (cols + tile_size - 1) / tile_size;
+
+  const int32_t block_nums = num_batches * num_tile_rows * num_tile_cols;
+  int32_t checked_block_nums = std::min(block_nums, kCudaMaxBlocksNum);
+  if (tile_size == kMov2TileSize) {
+    const int32_t half2_thread = tile_size / 2;  // Because each thread processes two half elements.
+    BatchTransposeMovement2Kernel<kMov2TileSize, IndexType>
+        <<<checked_block_nums, dim3(half2_thread, kBlockRows), 0, cuda_stream>>>(
+            params.src, params.dst, rows, cols, num_tile_rows, num_tile_cols,
+            block_nums);  // Launch with (tile_size / 2) x kBlockRows threads; each thread
+                          // moves half2 elements through the shared-memory tile.
+  } else {
+    BatchTransposeKernel<movement_size, tile_size, IndexType>
+        <<<checked_block_nums, dim3(tile_size, kBlockRows), 0, cuda_stream>>>(
+            params.src, params.dst, rows, cols, num_tile_rows, num_tile_cols, block_nums);
+  }
+}
+
+template<size_t tile_size, typename IndexType>
+bool CheckIfGreaterEqualThanTileSize(const IndexType& rows, const IndexType& cols) {
+  if (rows < tile_size || cols < tile_size) { return false; }
+  return true;
+}
+
+template<size_t num_dims, size_t tile_size, typename IndexType>
+bool CheckLaunchBatchTranspose(const int* permutation, const IndexType& num_batches,
+                               const IndexType& rows, const IndexType& cols) {
+  if (CheckIfGreaterEqualThanTileSize<tile_size, IndexType>(rows, cols)) {
+    if (num_batches == 1) {
+      // 2d tensor case: (0, 1) -> (1, 0)
+      return true;
+    } else if (num_dims == 3 && permutation[2] == 1 && permutation[1] == 2) {
+      // 3d tensor case: (0, 1, 2) -> (0, 2, 1)
+      return true;
+    } else {
+      return false;
+    }
+  }
+  return false;
+}
+
+template<size_t movement_size, typename IndexType>
+bool CheckUseMov2(const IndexType& rows, const IndexType& cols, const void* src, void* dst) {
+  auto src_ptr = reinterpret_cast<std::uintptr_t>(src);
+  auto dst_ptr = reinterpret_cast<std::uintptr_t>(dst);
+  return (movement_size == 2) && (rows % 2 == 0) && (cols % 2 == 0) && (src_ptr % 4 == 0)
+         && (dst_ptr % 4 == 0);
+}
+
+template<size_t num_dims, typename IndexType>
+void InferBatchTransposeShape(const int64_t* src_dims, IndexType* num_batches, IndexType* rows,
+                              IndexType* cols) {
+  if (num_dims == 2) {
+    *num_batches = 1;
+    *rows = src_dims[0];
+    *cols = src_dims[1];
+  } else {
+    *num_batches = src_dims[0];
+    *rows = src_dims[1];
+    *cols = src_dims[2];
+  }
+}
+
 template<size_t num_dims, size_t movement_size, typename IndexType>
-void LaunchKernel(StreamContext* stream_ctx, PermuteKernelParams<num_dims, IndexType> params) {
+void LaunchKernel(StreamContext* stream_ctx, const int64_t* src_dims, const void* src,
+                  const int* permutation, void* dst, size_t count) {
+  PermuteKernelParams<num_dims, IndexType> params =
+      MakePermuteParams<num_dims, IndexType>(src_dims, src, permutation, dst, count);
   cudaStream_t cuda_stream =
       CHECK_NOTNULL(dynamic_cast<CudaStreamContext*>(stream_ctx))->cuda_stream();
-  PermuteKernel<num_dims, movement_size, IndexType>
-      <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
+
+  if (num_dims == 2 || num_dims == 3) {
+    IndexType num_batches;
+    IndexType rows;
+    IndexType cols;
+    InferBatchTransposeShape<num_dims, IndexType>(src_dims, &num_batches, &rows, &cols);
+    if (CheckLaunchBatchTranspose<num_dims, kMov4TileSize, IndexType>(params.permutation,
+                                                                      num_batches, rows, cols)) {
+      if (CheckUseMov2<movement_size, IndexType>(rows, cols, src, dst)) {
+        LaunchBatchTransposeKernel<num_dims, 2, kMov2TileSize, IndexType>(cuda_stream, params,
+                                                                          num_batches, rows, cols);
+      } else {
+        LaunchBatchTransposeKernel<num_dims, movement_size, kMov4TileSize, IndexType>(
+            cuda_stream, params, num_batches, rows, cols);
+      }
+    } else {
+      PermuteKernel<num_dims, movement_size, IndexType>
+          <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
+    }
+  } else {
+    PermuteKernel<num_dims, movement_size, IndexType>
+        <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
+  }
 }
 
 class PermuteImpl : public Permute {
@@ -58,6 +286,7 @@ class PermuteImpl : public Permute {
   PermuteImpl() = default;
   ~PermuteImpl() override = default;
 
+  using Permute::Launch;
   void Launch(StreamContext* stream_ctx, DataType data_type, size_t num_dims,
               const int64_t* src_dims, const void* src, const int* permutation,
               void* dst) override {
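The CUDA fast path added above is the classic tiled shared-memory transpose from the NVIDIA blog post referenced in the code. A minimal standalone sketch (not part of the patch) is shown below; the kernel name, the float element type, and the assumption that rows and cols are multiples of the tile size are simplifications, whereas the patched kernels additionally handle partial tiles, batching, other movement sizes, and grid-stride looping over tiles.

#include <cuda_runtime.h>

constexpr int kTile = 32;
constexpr int kRowsPerBlock = 8;

__global__ void TransposeTiled(const float* src, float* dst, int rows, int cols) {
  // One extra column of padding so column-wise reads hit different shared-memory banks.
  __shared__ float tile[kTile][kTile + 1];
  int x = blockIdx.x * kTile + threadIdx.x;
  int y = blockIdx.y * kTile + threadIdx.y;
  // Coalesced load of a 32x32 tile, 8 rows at a time per thread.
  for (int i = 0; i < kTile; i += kRowsPerBlock) {
    tile[threadIdx.y + i][threadIdx.x] = src[(y + i) * cols + x];
  }
  __syncthreads();
  // Swap block coordinates so the store is also coalesced, then write the transposed tile.
  x = blockIdx.y * kTile + threadIdx.x;
  y = blockIdx.x * kTile + threadIdx.y;
  for (int i = 0; i < kTile; i += kRowsPerBlock) {
    dst[(y + i) * rows + x] = tile[threadIdx.x][threadIdx.y + i];
  }
}

// Launch example (rows and cols assumed to be multiples of kTile):
//   dim3 block(kTile, kRowsPerBlock);
//   dim3 grid(cols / kTile, rows / kTile);
//   TransposeTiled<<<grid, block>>>(src, dst, rows, cols);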
diff --git a/oneflow/user/kernels/transpose_kernel.cpp b/oneflow/user/kernels/transpose_kernel.cpp
index e9dc5e4e40c7661615032842698980810d81ba39..4b3743f2e1743e0f93c0adde7adc6479654efa1b 100644
--- a/oneflow/user/kernels/transpose_kernel.cpp
+++ b/oneflow/user/kernels/transpose_kernel.cpp
@@ -13,58 +13,60 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-#include "oneflow/core/common/protobuf.h"
 #include "oneflow/core/framework/framework.h"
-#include "oneflow/core/kernel/new_kernel_util.h"
 #include "oneflow/core/kernel/cuda_graph_support.h"
-
+#include "oneflow/core/primitive/include/permute.h"
 namespace oneflow {
 
 namespace user_op {
 
-template<DeviceType device_type, typename T>
+template<typename Context>
+std::unique_ptr<primitive::Permute> NewPermutePrimitive(Context* ctx) {
+  const int64_t num_dims = ctx->TensorDesc4ArgNameAndIndex("output", 0)->shape().NumAxes();
+  return primitive::NewPrimitive<primitive::PermuteFactory>(ctx->device_type(), num_dims);
+}
+
 class TransposeKernel final : public OpKernel, public user_op::CudaGraphSupport {
  public:
+  OF_DISALLOW_COPY_AND_MOVE(TransposeKernel);
   TransposeKernel() = default;
   ~TransposeKernel() override = default;
 
  private:
   void Compute(KernelComputeContext* ctx) const override {
+    auto primitive = NewPermutePrimitive(ctx);
+    CHECK(primitive);
+
     const Tensor* tensor_in = ctx->Tensor4ArgNameAndIndex("input", 0);
     Tensor* tensor_out = ctx->Tensor4ArgNameAndIndex("output", 0);
     const auto& perm = ctx->Attr<std::vector<int32_t>>("perm");
     const ShapeView& in_shape = tensor_in->shape();
-    const ShapeView& out_shape = tensor_out->shape();
-    NewKernelUtil<device_type>::Transpose(ctx->device_ctx(), in_shape.NumAxes(), in_shape,
-                                          out_shape, perm, in_shape.elem_cnt(),
-                                          tensor_in->dptr<T>(), tensor_out->mut_dptr<T>());
+    DataType dtype = tensor_out->data_type();
+    size_t num_dims = tensor_in->shape().NumAxes();
+    const int64_t* src_dims = in_shape.ptr();
+
+    int64_t elem_cnt = tensor_out->shape().elem_cnt();
+    if (elem_cnt != 0) {
+      primitive->Launch(ctx->stream_ctx(), dtype, num_dims, src_dims, tensor_in->dptr(),
+                        perm.data(), tensor_out->mut_dptr());
+    } else {
+      // 0-size tensor: nothing to launch.
+      return;
+    }
   }
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 };
 
-#define REGISTER_TRANSPOSE_KERNEL(device, dtype)                                          \
-  REGISTER_USER_KERNEL("transpose")                                                       \
-      .SetCreateFn<TransposeKernel<device, dtype>>()                                      \
-      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                                \
-                       & (user_op::HobDataType("input", 0) == GetDataType<dtype>::value)  \
-                       & (user_op::HobDataType("output", 0) == GetDataType<dtype>::value));
-
-REGISTER_TRANSPOSE_KERNEL(DeviceType::kCPU, uint8_t)
-REGISTER_TRANSPOSE_KERNEL(DeviceType::kCPU, int8_t)
-REGISTER_TRANSPOSE_KERNEL(DeviceType::kCPU, int32_t)
-REGISTER_TRANSPOSE_KERNEL(DeviceType::kCPU, int64_t)
-REGISTER_TRANSPOSE_KERNEL(DeviceType::kCPU, float)
-REGISTER_TRANSPOSE_KERNEL(DeviceType::kCPU, double)
+hob::HobContextGetter<user_op::KernelRegContext, bool> PermutePrimitiveExists() {
+  return user_op::HobCtxGetter<bool>("PermutePrimitiveExists",
+                                     [](const user_op::KernelRegContext& ctx) {
+                                       return NewPermutePrimitive(&ctx).operator bool();
+                                     });
+}
 
-#ifdef WITH_CUDA
-REGISTER_TRANSPOSE_KERNEL(DeviceType::kGPU, uint8_t)
-REGISTER_TRANSPOSE_KERNEL(DeviceType::kGPU, int8_t)
-REGISTER_TRANSPOSE_KERNEL(DeviceType::kGPU, int32_t)
-REGISTER_TRANSPOSE_KERNEL(DeviceType::kGPU, int64_t)
-REGISTER_TRANSPOSE_KERNEL(DeviceType::kGPU, float)
-REGISTER_TRANSPOSE_KERNEL(DeviceType::kGPU, double)
-REGISTER_TRANSPOSE_KERNEL(DeviceType::kGPU, float16)
-#endif
+REGISTER_USER_KERNEL("transpose")
+    .SetCreateFn<TransposeKernel>()
+    .SetIsMatchedHob(PermutePrimitiveExists() == true);
 
 }  // namespace user_op
 }  // namespace oneflow
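Taken together, the CUDA LaunchKernel now picks one of three paths at launch time: the half2-packed batch transpose, the plain tiled batch transpose, or the generic index-helper permute. The sketch below (not part of the patch) collapses that decision into a single function; ChoosePath and PermutePath are illustrative names, and the constant 32 mirrors kMov4TileSize.

#include <cstdint>

enum class PermutePath { kBatchTransposeMov2, kBatchTranspose, kGenericPermute };

inline PermutePath ChoosePath(int num_dims, const int* permutation, const int64_t* dims,
                              int movement_size, const void* src, const void* dst) {
  if (num_dims != 2 && num_dims != 3) { return PermutePath::kGenericPermute; }
  const int64_t num_batches = (num_dims == 2) ? 1 : dims[0];
  const int64_t rows = (num_dims == 2) ? dims[0] : dims[1];
  const int64_t cols = (num_dims == 2) ? dims[1] : dims[2];
  // The tiled kernel is only used when both of the last two dims span at least one 32-wide
  // tile and the permutation is a plain 2d transpose or a (0, 2, 1) batch transpose.
  const bool shape_ok = rows >= 32 && cols >= 32;
  const bool perm_ok =
      (num_batches == 1) || (num_dims == 3 && permutation[1] == 2 && permutation[2] == 1);
  if (!shape_ok || !perm_ok) { return PermutePath::kGenericPermute; }
  // The 2-byte path packs two elements into one 4-byte load when both sizes are even and
  // both pointers are 4-byte aligned (e.g. two half values read as one half2).
  const bool mov2_ok = movement_size == 2 && rows % 2 == 0 && cols % 2 == 0
                       && reinterpret_cast<std::uintptr_t>(src) % 4 == 0
                       && reinterpret_cast<std::uintptr_t>(dst) % 4 == 0;
  return mov2_ok ? PermutePath::kBatchTransposeMov2 : PermutePath::kBatchTranspose;
}

On the registration side, the per-device, per-dtype REGISTER_TRANSPOSE_KERNEL macro is gone: the transpose op is registered once and matched whenever a Permute primitive exists for the target device, so new data types and devices are covered as soon as the primitive supports them.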