From 8a9bef70b23f2a882a66f0a009340d25e2b6b0cc Mon Sep 17 00:00:00 2001
From: Netpunk <69072522+Patrick-Star125@users.noreply.github.com>
Date: Wed, 30 Nov 2022 16:05:42 +0800
Subject: [PATCH] [PHI decoupling] migrate transpose_op.cu.h and gpu_utils.h to
 phi (#48286)

* migrate transpose_op.cu.h and gpu_utils.h

* format code style

* fix some problems

* format code

* reset tranpose_op.cc

* test commit

* recover transpose_op.h

* delete transpose_op.h

* adjust header files order in transpose_op.cc
---
 paddle/fluid/operators/fused/fmha_ref.h      | 12 +--
 .../operators/fused/fused_gate_attention.h   | 27 ++++---
 .../operators/mkldnn/transpose_mkldnn_op.cc  |  2 +-
 paddle/fluid/operators/transpose_op.cc       |  4 +-
 paddle/fluid/operators/transpose_op_mlu.cc   |  2 +-
 paddle/fluid/operators/unique_op.h           |  6 +-
 .../backends/gpu}/gpu_utils.h                | 22 +++---
 .../kernels/funcs/transpose_functor.cu.h}    | 76 +++++++++----------
 .../kernels/funcs/transpose_functor.h}       | 52 ++-----------
 paddle/phi/kernels/gpu/transpose_kernel.cu   |  4 +-
 10 files changed, 83 insertions(+), 124 deletions(-)
 rename paddle/{fluid/framework => phi/backends/gpu}/gpu_utils.h (88%)
 rename paddle/{fluid/operators/transpose_op.cu.h => phi/kernels/funcs/transpose_functor.cu.h} (95%)
 rename paddle/{fluid/operators/transpose_op.h => phi/kernels/funcs/transpose_functor.h} (79%)

diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h
index 66176c9e75..fc5f9cf71d 100644
--- a/paddle/fluid/operators/fused/fmha_ref.h
+++ b/paddle/fluid/operators/fused/fmha_ref.h
@@ -16,12 +16,12 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/dropout_impl.cu.h"
 #include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h"
-#include "paddle/fluid/operators/transpose_op.cu.h"
 #include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 #include "paddle/phi/kernels/funcs/functors.h"
+#include "paddle/phi/kernels/funcs/transpose_functor.cu.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
 
 namespace paddle {
@@ -98,7 +98,7 @@ class FMHARef {
     // transpose with perm [2, 0, 3, 1, 4],
     // output_shape: [3, bs, num_head, seq_len, head_dim]
     std::vector<int> perm_1 = {2, 0, 3, 1, 4};
-    TransposeGPUKernelDriver<T>(
+    phi::funcs::TransposeGPUKernelDriver<T>(
         dev_ctx_, qkv_input_tensor, perm_1, transpose_2_out_tensor);
     T* qkv_data = transpose_2_out_tensor->data<T>();
     T* qk_out_data = qk_out_tensor->data<T>();
@@ -254,7 +254,7 @@ class FMHARef {
     // transpose: [0, 2, 1, 3]
     // output shape: [batch_size, seq_len, num_heads, head_dim]
     std::vector<int> perm_3 = {0, 2, 1, 3};
-    TransposeGPUKernelDriver<T>(
+    phi::funcs::TransposeGPUKernelDriver<T>(
         dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
   }
 
@@ -428,7 +428,7 @@ class FMHARef {
     // transpose: [0, 2, 1, 3]
     // output shape: [batch_size, seq_len, num_heads, head_dim]
     std::vector<int> perm_3 = {0, 2, 1, 3};
-    TransposeGPUKernelDriver<T>(
+    phi::funcs::TransposeGPUKernelDriver<T>(
         dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
   }
 
@@ -470,7 +470,7 @@ class FMHARef {
 
     // transpose bw
     std::vector<int> perm_3 = {0, 2, 1, 3};
-    TransposeGPUKernelDriver<T>(
+    phi::funcs::TransposeGPUKernelDriver<T>(
        dev_ctx_, fmha_out_grad_tensor, perm_3, qktv_out_grad_tensor);
 
     // recall batchedgemm(nn) fw: softmax_out_data(x) * v_ptr(y) =
@@ -648,7 +648,7 @@ class FMHARef {
 
     // transpose bw
     std::vector<int> perm_1 = {1, 3, 0, 2, 4};
-    TransposeGPUKernelDriver<T>(
+    phi::funcs::TransposeGPUKernelDriver<T>(
        dev_ctx_, *transpose_2_out_grad_tensor, perm_1, qkv_input_grad_tensor);
   }
 
diff --git a/paddle/fluid/operators/fused/fused_gate_attention.h b/paddle/fluid/operators/fused/fused_gate_attention.h
index e50cc24d88..1fba366ad2 100644
--- a/paddle/fluid/operators/fused/fused_gate_attention.h
+++ b/paddle/fluid/operators/fused/fused_gate_attention.h
@@ -14,11 +14,11 @@ limitations under the License. */
 
 #pragma once
 
-#include "paddle/fluid/operators/transpose_op.cu.h"
 #include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 #include "paddle/phi/kernels/funcs/reduce_function.h"
+#include "paddle/phi/kernels/funcs/transpose_functor.cu.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
 
 namespace paddle {
@@ -626,9 +626,12 @@ class FMHAGateRef {
                                   phi::DenseTensor* k_transpose_out,
                                   phi::DenseTensor* v_transpose_out) {
     std::vector<int> perm = {0, 1, 3, 2, 4};
-    TransposeGPUKernelDriver<T>(dev_ctx_, q_out, perm, q_transpose_out);
-    TransposeGPUKernelDriver<T>(dev_ctx_, k_out, perm, k_transpose_out);
-    TransposeGPUKernelDriver<T>(dev_ctx_, v_out, perm, v_transpose_out);
+    phi::funcs::TransposeGPUKernelDriver<T>(
+        dev_ctx_, q_out, perm, q_transpose_out);
+    phi::funcs::TransposeGPUKernelDriver<T>(
+        dev_ctx_, k_out, perm, k_transpose_out);
+    phi::funcs::TransposeGPUKernelDriver<T>(
+        dev_ctx_, v_out, perm, v_transpose_out);
   }
 
   void ComputeQKVTransposeBackward(const phi::DenseTensor& q_transpose_out_grad,
@@ -638,11 +641,11 @@ class FMHAGateRef {
                                    phi::DenseTensor* k_out_grad,
                                    phi::DenseTensor* v_out_grad) {
     std::vector<int> perm = {0, 1, 3, 2, 4};
-    TransposeGPUKernelDriver<T>(
+    phi::funcs::TransposeGPUKernelDriver<T>(
        dev_ctx_, q_transpose_out_grad, perm, q_out_grad);
-    TransposeGPUKernelDriver<T>(
+    phi::funcs::TransposeGPUKernelDriver<T>(
        dev_ctx_, k_transpose_out_grad, perm, k_out_grad);
-    TransposeGPUKernelDriver<T>(
+    phi::funcs::TransposeGPUKernelDriver<T>(
        dev_ctx_, v_transpose_out_grad, perm, v_out_grad);
   }
 
@@ -651,14 +654,15 @@ class FMHAGateRef {
   void ComputeQKVTransposeForward(const phi::DenseTensor& qkv_out,
                                   phi::DenseTensor* qkv_transpose_out) {
     std::vector<int> perm = {3, 0, 1, 4, 2, 5};
-    TransposeGPUKernelDriver<T>(dev_ctx_, qkv_out, perm, qkv_transpose_out);
+    phi::funcs::TransposeGPUKernelDriver<T>(
+        dev_ctx_, qkv_out, perm, qkv_transpose_out);
   }
 
   void ComputeQKVTransposeBackward(
       const phi::DenseTensor& qkv_transpose_out_grad,
       phi::DenseTensor* qkv_out_grad) {
     std::vector<int> perm = {1, 2, 4, 0, 3, 5};
-    TransposeGPUKernelDriver<T>(
+    phi::funcs::TransposeGPUKernelDriver<T>(
        dev_ctx_, qkv_transpose_out_grad, perm, qkv_out_grad);
   }
 
@@ -667,13 +671,14 @@ class FMHAGateRef {
   void ComputeQKTVTransposeForward(const phi::DenseTensor& qktv_out,
                                    phi::DenseTensor* fmha_out) {
     std::vector<int> perm = {0, 1, 3, 2, 4};
-    TransposeGPUKernelDriver<T>(dev_ctx_, qktv_out, perm, fmha_out);
+    phi::funcs::TransposeGPUKernelDriver<T>(dev_ctx_, qktv_out, perm, fmha_out);
   }
 
   void ComputeQKTVTransposeBackward(const phi::DenseTensor& fmha_out_grad,
                                     phi::DenseTensor* qktv_out_grad) {
     std::vector<int> perm = {0, 1, 3, 2, 4};
-    TransposeGPUKernelDriver<T>(dev_ctx_, fmha_out_grad, perm, qktv_out_grad);
+    phi::funcs::TransposeGPUKernelDriver<T>(
+        dev_ctx_, fmha_out_grad, perm, qktv_out_grad);
   }
 
   // qk_out = qk_out + nonbatched_bias + src_mask
diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
index f7f7e5f6ad..2c5b269c39 100644
--- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
@@ -15,8 +15,8 @@
 #include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/memory/malloc.h"
-#include "paddle/fluid/operators/transpose_op.h"
 #include "paddle/fluid/platform/mkldnn_reuse.h"
+#include "paddle/phi/kernels/funcs/transpose_functor.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc
index 9ee0196d8c..52a9955acc 100644
--- a/paddle/fluid/operators/transpose_op.cc
+++ b/paddle/fluid/operators/transpose_op.cc
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/transpose_op.h"
-
 #include <memory>
 #include <string>
 #include <vector>
@@ -21,6 +19,8 @@ limitations under the License. */
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/kernels/funcs/transpose_functor.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/transpose_op_mlu.cc b/paddle/fluid/operators/transpose_op_mlu.cc
index 722ad3584f..0ef9fc247a 100644
--- a/paddle/fluid/operators/transpose_op_mlu.cc
+++ b/paddle/fluid/operators/transpose_op_mlu.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/mlu/mlu_baseop.h"
-#include "paddle/fluid/operators/transpose_op.h"
+#include "paddle/phi/kernels/funcs/transpose_functor.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/unique_op.h b/paddle/fluid/operators/unique_op.h
index 45b1e3c435..d1e9afa03c 100644
--- a/paddle/fluid/operators/unique_op.h
+++ b/paddle/fluid/operators/unique_op.h
@@ -23,8 +23,8 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/concat_and_split.h"
-#include "paddle/fluid/operators/transpose_op.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/transpose_functor.h"
 
 namespace paddle {
 namespace operators {
@@ -251,7 +251,7 @@ static void UniqueDim(const framework::ExecutionContext& context,
   in_trans.Resize(in_trans_dims);
   in_trans.mutable_data<InT>(context.GetPlace());
   auto& dev_ctx = context.template device_context<DeviceContext>();
-  TransCompute<DeviceContext, InT>(
+  phi::funcs::TransCompute<DeviceContext, InT>(
      in.dims().size(), dev_ctx, in, &in_trans, permute);
 
   // reshape tensor: eg. [dim1, dim0, dim2] -> [dim1, dim0*dim2]
   framework::DDim in_trans_flat_dims = phi::flatten_to_2d(in_trans_dims, 1);
@@ -315,7 +315,7 @@ static void UniqueDim(const framework::ExecutionContext& context,
   out->Resize(phi::make_ddim(out_trans_dims_vec));
   out->mutable_data<InT>(context.GetPlace());
   concat_functor(dev_ctx, input_unbind, 0, &out_trans);
-  TransCompute<DeviceContext, InT>(
+  phi::funcs::TransCompute<DeviceContext, InT>(
      out_trans.dims().size(), dev_ctx, out_trans, out, permute);
 
   if (return_inverse) {
diff --git a/paddle/fluid/framework/gpu_utils.h b/paddle/phi/backends/gpu/gpu_utils.h
similarity index 88%
rename from paddle/fluid/framework/gpu_utils.h
rename to paddle/phi/backends/gpu/gpu_utils.h
index 68cbc309c2..ea97a086af 100644
--- a/paddle/fluid/framework/gpu_utils.h
+++ b/paddle/phi/backends/gpu/gpu_utils.h
@@ -18,11 +18,11 @@
 
 #include
 
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/core/enforce.h"
 #include "unsupported/Eigen/CXX11/Tensor"
 
-namespace paddle {
-namespace framework {
+namespace phi {
+namespace funcs {
 
 template
 struct DeviceArray {
@@ -110,16 +110,16 @@ IntType CeilOrFloor(IntType x, IntType deviser) {
   PADDLE_ENFORCE_GT(
       deviser,
       0,
-      platform::errors::InvalidArgument("deviser should be greater than 0, "
-                                        "but received is:%d",
-                                        deviser));
+      phi::errors::InvalidArgument("deviser should be greater than 0, "
+                                   "but received is:%d",
+                                   deviser));
 
   PADDLE_ENFORCE_GT(
       x,
       0,
-      platform::errors::InvalidArgument("input should be greater than 0, "
-                                        "but received is:%d",
-                                        x));
+      phi::errors::InvalidArgument("input should be greater than 0, "
+                                   "but received is:%d",
+                                   x));
 
   const IntType round_to_zero = x / deviser;
   const IntType inte_result = round_to_zero * deviser;
@@ -140,5 +140,5 @@ IntType CeilOrFloor(IntType x, IntType deviser) {
   }
 }
 
-}  // namespace framework
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
diff --git a/paddle/fluid/operators/transpose_op.cu.h b/paddle/phi/kernels/funcs/transpose_functor.cu.h
similarity index 95%
rename from paddle/fluid/operators/transpose_op.cu.h
rename to paddle/phi/kernels/funcs/transpose_functor.cu.h
index 4fc610c393..0d24fdebef 100644
--- a/paddle/fluid/operators/transpose_op.cu.h
+++ b/paddle/phi/kernels/funcs/transpose_functor.cu.h
@@ -14,20 +14,18 @@ limitations under the License. */
 
 #pragma once
 
-#include "paddle/fluid/framework/gpu_utils.h"
-#include "paddle/fluid/operators/transpose_op.h"
-#include "paddle/fluid/platform/fast_divmod.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/backends/gpu/gpu_utils.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/autotune/auto_tune_base.h"
+#include "paddle/phi/kernels/funcs/transpose_functor.h"
+#include "paddle/phi/kernels/primitive/datamover_primitives.h"
 
-namespace paddle {
-namespace operators {
+namespace phi {
+namespace funcs {
 
 using Tensor = phi::DenseTensor;
-using Dim3 = framework::Dim3;
-using Index3 = framework::Index3;
 
 struct EqualTo {
   constexpr bool operator()(int a, int b) const { return a == b; }
@@ -118,8 +116,8 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
   };
 
   // Converts block idx to tile index, each block process a tile
-  Index3 input_block_tile_index = framework::ConvertTensorIndex<IndexType>(
-      blockIdx.x, tile_aligned_input_dim);
+  Index3 input_block_tile_index =
+      ConvertTensorIndex<IndexType>(blockIdx.x, tile_aligned_input_dim);
 
   // Compute real index align to tile:0, 32, 64...
   Index3 block_tile_index_in_input = {
@@ -130,8 +128,7 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
 
   // Compute block flat index against input dims.
   IndexType input_origin_block_flat_index =
-      framework::FlatTensorIndex<IndexType>(block_tile_index_in_input,
-                                            input_dims);
+      FlatTensorIndex<IndexType>(block_tile_index_in_input, input_dims);
 
   bool full_tile = true;
   IndexType tile_width = TileY;
@@ -193,8 +190,7 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
   };
 
   IndexType output_origin_block_flat_index =
-      framework::FlatTensorIndex<IndexType>(block_tile_index_in_output,
-                                            output_dims);
+      FlatTensorIndex<IndexType>(block_tile_index_in_output, output_dims);
 
   constexpr IndexType out_effective_thread_num = NumThreads / TileX * TileX;
 
@@ -230,13 +226,13 @@ bool SelectProperTileSize(std::vector<std::pair<int, int>>* tiles) {
 
   PADDLE_ENFORCE_LE(
       TSIZE,
       16,
-      platform::errors::InvalidArgument(
+      phi::errors::InvalidArgument(
          "The tile size should smaller than 16, but received is:%d.", TSIZE));
   PADDLE_ENFORCE_EQ(
       (TSIZE & (TSIZE - 1)),
       0,
-      platform::errors::InvalidArgument(
+      phi::errors::InvalidArgument(
          "Data types should be powers of 2, but reived size is:%d.", TSIZE));
 
   const int kMaxLongSideLen = 1024;
@@ -316,7 +312,7 @@ struct NarrowDims2TransposeDispatch {
     PADDLE_ENFORCE_EQ(
         (tile_long & (tile_long - 1)),
         0,
-        platform::errors::InvalidArgument(
+        phi::errors::InvalidArgument(
            "The length of the longer side of the tile should be power of 2."
            " But received value is:%d.",
            tile_long));
@@ -381,7 +377,7 @@ struct NarrowDims2TransposeDispatch<
     PADDLE_ENFORCE_EQ(
         (tile_long & (tile_long - 1)),
         0,
-        platform::errors::InvalidArgument(
+        phi::errors::InvalidArgument(
            "The length of the longer side of the tile should be power of 2."
            " But received value is:%d.",
            tile_long));
@@ -431,7 +427,7 @@ struct NarrowDims2TransposeDispatch<
     PADDLE_ENFORCE_EQ(
         (tile_long & (tile_long - 1)),
         0,
-        platform::errors::InvalidArgument(
+        phi::errors::InvalidArgument(
            "The length of the longer side of the tile should be power of 2,"
            " but received is:%d.",
            tile_long));
@@ -459,7 +455,7 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d,
   PADDLE_ENFORCE_EQ(
       ret,
       true,
-      platform::errors::InvalidArgument(
+      phi::errors::InvalidArgument(
          "SelectProperTileSize should return true, but return value is:%d.",
          ret));
 
@@ -475,12 +471,12 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d,
 
     // to find least wasted threads, which means we need to find tile
     // can split input properly, in another words: num_wasted_threads=0.
     int num_wasted_threads =
-        input_long_edge - framework::CeilOrFloor<int, false>(
-                              input_long_edge, proposed_tile_long_edge) *
-                              proposed_tile_long_edge;
+        input_long_edge -
+        CeilOrFloor<int, false>(input_long_edge, proposed_tile_long_edge) *
+            proposed_tile_long_edge;
 
-    int num_full_tiles = framework::CeilOrFloor<int, false>(
-        input_long_edge, proposed_tile_long_edge);
+    int num_full_tiles =
+        CeilOrFloor<int, false>(input_long_edge, proposed_tile_long_edge);
 
     float cost = num_wasted_threads;
@@ -514,8 +510,8 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d,
   // Here finally get proper long X short tile size.
   Dim3 input_dims_aligned = {
       input_dims[0],
-      framework::CeilOrFloor<int, true>(input_dims[1], select_tile_size_i),
-      framework::CeilOrFloor<int, true>(input_dims[2], select_tile_size_j),
+      CeilOrFloor<int, true>(input_dims[1], select_tile_size_i),
+      CeilOrFloor<int, true>(input_dims[2], select_tile_size_j),
   };
 
   IndexType total_tiles_count = input_dims_aligned[0];
@@ -549,7 +545,7 @@ __global__ void TransposeSimpleKernel(IndexType nthreads,
 
   CUDA_KERNEL_LOOP_TYPE(output_index, nthreads, IndexType) {
     Index3 output_tensor_index =
-        framework::ConvertTensorIndex<IndexType>(output_index, output_dims);
+        ConvertTensorIndex<IndexType>(output_index, output_dims);
 
     Index3 input_tensor_index;
     input_tensor_index[0] = output_tensor_index[pos0];
@@ -557,7 +553,7 @@ __global__ void TransposeSimpleKernel(IndexType nthreads,
     input_tensor_index[2] = output_tensor_index[pos2];
 
     IndexType input_index =
-        framework::FlatTensorIndex<IndexType>(input_tensor_index, input_dims);
+        FlatTensorIndex<IndexType>(input_tensor_index, input_dims);
 
     output[output_index] = input[input_index];
   }
@@ -585,8 +581,8 @@ void SendSwapDim1And2InTranspose(const phi::GPUContext& d,
 
   Dim3 input_dims_aligned = {
       input_dims[0],
-      framework::CeilOrFloor<int, true>(input_dims[1], kTileSize),
-      framework::CeilOrFloor<int, true>(input_dims[2], kTileSize),
+      CeilOrFloor<int, true>(input_dims[1], kTileSize),
+      CeilOrFloor<int, true>(input_dims[2], kTileSize),
   };
 
   IndexType total_tiles_count = input_dims_aligned[0];
@@ -653,13 +649,13 @@ struct SwapDim0And2InTranspose {
 
 // This function is to combine dimension. fox example:
 // (0, 1, 3, 2) --> (0, 2, 1)
-inline void CombineTransposeDim3(const framework::DDim& shape,
+inline void CombineTransposeDim3(const DDim& shape,
                                  const std::vector<int>& perm,
                                  std::vector<int>* new_perm,
-                                 framework::DDim* new_dims) {
+                                 DDim* new_dims) {
   PADDLE_ENFORCE_EQ(shape.size(),
                     perm.size(),
-                    platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                         " shape should have the save dim with perm, but"
                         " received shape size is:%d, perm size is:%d.",
                         shape.size(),
@@ -717,7 +713,7 @@ struct TransposeSimple {
                   phi::DenseTensor* out) {
     // First reduce the dimensions of the input tensor if possible.
     std::vector<int> new_perm;
-    framework::DDim new_dims;
+    DDim new_dims;
     CombineTransposeDim3(in.dims(), perm, &new_perm, &new_dims);
 
     // Only use tile copy GPU kernel when dimension is 2 or 3.
@@ -796,7 +792,7 @@ class IdxHelper {
   explicit IdxHelper(const uint32_t* dims) {
     for (int i = N - 1; i >= 0; --i) {
       uint32_t value = i < (N - 1) ? dims[i + 1] * stride_[i + 1] : 1;
-      divmoder_[i] = paddle::platform::FastDivMod(value);
+      divmoder_[i] = phi::kps::details::FastDivMod(value);
       stride_[i] = value;
     }
   }
@@ -817,7 +813,7 @@ class IdxHelper {
 
  private:
   uint32_t stride_[N];
-  paddle::platform::FastDivMod divmoder_[N];
+  phi::kps::details::FastDivMod divmoder_[N];
 };
 
 // Transform index between memory offset and shape coodinate.
@@ -1188,8 +1184,8 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
     ret = TransposeSimple<T>::run(ctx, in, perm, out);
   }
   if (!ret) {
-    auto* tuner =
-        phi::autotune::MakeTransposeTuner<T>(TransCompute<phi::GPUContext, T>);
+    auto* tuner = phi::autotune::MakeTransposeTuner<T>(
+        funcs::TransCompute<phi::GPUContext, T>);
     tuner->AddCallBack(PermuteAndTranspose);
 
     size_t key = phi::autotune::TransposeKey(
@@ -1208,5 +1204,5 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
   }
 }
 
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
diff --git a/paddle/fluid/operators/transpose_op.h b/paddle/phi/kernels/funcs/transpose_functor.h
similarity index 79%
rename from paddle/fluid/operators/transpose_op.h
rename to paddle/phi/kernels/funcs/transpose_functor.h
index 45495505e6..d2a72efed0 100644
--- a/paddle/fluid/operators/transpose_op.h
+++ b/paddle/phi/kernels/funcs/transpose_functor.h
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -16,57 +16,15 @@ limitations under the License. */
 
 #include <vector>
 
-#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
-namespace paddle {
-namespace operators {
+namespace phi {
+namespace funcs {
 
 enum { kTransposeMKLDNNFP32 = 1, kTransposeMKLDNNINT8 = 2 };
 
-template <typename DeviceContext, typename T>
-inline void TransCompute(const int dim,
-                         const DeviceContext& dev_ctx,
-                         const phi::DenseTensor& in,
-                         phi::DenseTensor* out,
-                         const std::vector<int>& axis) {
-  switch (dim) {
-    case 0:
-      phi::Copy(dev_ctx, in, dev_ctx.GetPlace(), false, out);
-      break;
-    case 1:
-      phi::funcs::Transpose<DeviceContext, T, 1> trans1;
-      trans1(dev_ctx, in, out, axis);
-      break;
-    case 2:
-      phi::funcs::Transpose<DeviceContext, T, 2> trans2;
-      trans2(dev_ctx, in, out, axis);
-      break;
-    case 3:
-      phi::funcs::Transpose<DeviceContext, T, 3> trans3;
-      trans3(dev_ctx, in, out, axis);
-      break;
-    case 4:
-      phi::funcs::Transpose<DeviceContext, T, 4> trans4;
-      trans4(dev_ctx, in, out, axis);
-      break;
-    case 5:
-      phi::funcs::Transpose<DeviceContext, T, 5> trans5;
-      trans5(dev_ctx, in, out, axis);
-      break;
-    case 6:
-      phi::funcs::Transpose<DeviceContext, T, 6> trans6;
-      trans6(dev_ctx, in, out, axis);
-      break;
-    default:
-      // for dim >= 7 situation
-      phi::funcs::TransposeNormal<DeviceContext, T> trans_normal;
-      trans_normal(dev_ctx, in, out, axis);
-  }
-}
-
 enum PermuteType {
   kCopy = 1,
   kTranspose = 2,
@@ -227,5 +185,5 @@ class TranposeTypeClassifier {
   }
 };
 
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/transpose_kernel.cu b/paddle/phi/kernels/gpu/transpose_kernel.cu
index 36cf3fb8e3..4b7265e2f3 100644
--- a/paddle/phi/kernels/gpu/transpose_kernel.cu
+++ b/paddle/phi/kernels/gpu/transpose_kernel.cu
@@ -16,12 +16,12 @@
 
 #include <vector>
 
-#include "paddle/fluid/operators/transpose_op.cu.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
 #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/transpose_functor.cu.h"
 #include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h"
 
 namespace phi {
@@ -38,7 +38,7 @@ void TransposeKernel(const Context& ctx,
     phi::Copy(ctx, x, ctx.GetPlace(), false, out);
     return;
   }
-  paddle::operators::TransposeGPUKernelDriver<T>(ctx, x, axis, out);
+  phi::funcs::TransposeGPUKernelDriver<T>(ctx, x, axis, out);
 }
 
 }  // namespace phi
-- 
GitLab
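The call-site change is the same in every file this patch touches: replace the fluid header (transpose_op.h / transpose_op.cu.h) with the phi one (transpose_functor.h / transpose_functor.cu.h) and qualify the call with phi::funcs. A minimal sketch of a migrated GPU call site, modelled on the TransposeKernel hunk above; the wrapper name ExampleTranspose, the explicit float instantiation, and placing it in a .cu translation unit are illustrative assumptions, not part of the patch:

    #include <vector>

    #include "paddle/phi/backends/gpu/gpu_context.h"
    #include "paddle/phi/core/dense_tensor.h"
    #include "paddle/phi/kernels/funcs/transpose_functor.cu.h"

    // Transposes `x` into `out` according to `axis`, using the driver that
    // this patch moves into the phi::funcs namespace.
    void ExampleTranspose(const phi::GPUContext& ctx,
                          const phi::DenseTensor& x,
                          const std::vector<int>& axis,
                          phi::DenseTensor* out) {
      phi::funcs::TransposeGPUKernelDriver<float>(ctx, x, axis, out);
    }

CPU-side callers such as unique_op.h follow the same pattern with phi::funcs::TransCompute from the renamed transpose_functor.h.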