diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc index 077107dca68f332e866c1ba6aca3e5d8b8e7bc28..aa8c84879139ac09688f878868285328f58436b4 100644 --- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc @@ -16,7 +16,6 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/platform/mkldnn_reuse.h" -#include "paddle/phi/kernels/funcs/transpose_functor.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 52a9955acc18d74cdb06b66a0bf0e00a9db3bd82..e81c619db4348e38c4d63ebf489cc8cfa05e032e 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -20,7 +20,6 @@ limitations under the License. */ #include "paddle/fluid/platform/mkldnn_helper.h" #endif #include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/transpose_functor.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/transpose_op_mlu.cc b/paddle/fluid/operators/transpose_op_mlu.cc index 0ef9fc247ab0126ac2c465c6be37b0489db92f83..ba9997cf0f77a9fa0bef13721fbb5e08af366939 100644 --- a/paddle/fluid/operators/transpose_op_mlu.cc +++ b/paddle/fluid/operators/transpose_op_mlu.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/mlu/mlu_baseop.h" -#include "paddle/phi/kernels/funcs/transpose_functor.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/unique_op.h b/paddle/fluid/operators/unique_op.h index d1e9afa03ccee634906621349636426eedb89411..4d9b39d2dd262e131899a5282097582d2fd8265a 100644 --- a/paddle/fluid/operators/unique_op.h +++ b/paddle/fluid/operators/unique_op.h @@ -24,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/funcs/transpose_functor.h" namespace paddle { namespace operators { diff --git a/paddle/phi/kernels/autotune/auto_tune_base.h b/paddle/phi/kernels/autotune/auto_tune_base.h index 2a10314f844be77a36178b5bfd680d94c44f9a4a..4916890c461e7f13cb3cf3002039ccc7537685aa 100644 --- a/paddle/phi/kernels/autotune/auto_tune_base.h +++ b/paddle/phi/kernels/autotune/auto_tune_base.h @@ -123,7 +123,7 @@ class AutoTuneBase { float RunAndMeasureKernel(const Context& ctx, const int idx, Args&&... args) { // Regard 1st run as warmup, judge the compare result by the time cost // of rest cycles. 
- constexpr int repeats = 4; + constexpr int repeats = 6; phi::GpuTimer timer; float time_cost = 0; const auto& stream = ctx.stream(); diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index d2c30c8fa361146717674f4e80f86462aa4459a2..49020337e08d8c5033338fdf0174de90ee98d1bf 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -29,12 +29,23 @@ namespace funcs { #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) -template -int GetVecsize(const std::vector &ins, - std::vector *outs) { - int in_vec_size = 4; - int out_vec_size = 4; - if (outs->size() > 1) { +enum BroadcastLoadType { kMixed = 1, kBroadcast = 2, kElementwise = 3 }; + +template +struct LoaderTypeClassifier { + public: + int64_t numel{0}; + int vec_size{1}; + int broadcast_num{0}; + bool all_elementwise{true}; + phi::Array use_broadcast; + phi::Array ins_data; + + LoaderTypeClassifier() {} + LoaderTypeClassifier(const std::vector &ins, + std::vector *outs) { + int out_vec_size = + std::min(4, phi::GetVectorizedSize((*outs)[0]->data())); for (auto i = 1; i < outs->size(); ++i) { PADDLE_ENFORCE_EQ( (*outs)[i]->dims(), @@ -46,25 +57,33 @@ int GetVecsize(const std::vector &ins, out_vec_size = std::min( phi::GetVectorizedSize((*outs)[i]->data()), out_vec_size); } - } else { - out_vec_size = phi::GetVectorizedSize((*outs)[0]->data()); - } - for (auto *in : ins) { - auto temp_size = phi::GetVectorizedSize(in->data()); - in_vec_size = in->dims() == (*outs)[0]->dims() - ? std::min(temp_size, in_vec_size) - : in_vec_size; + numel = (*outs)[0]->numel(); + for (int i = 0; i < Arity; ++i) { + auto in_data = ins[i]->data(); + ins_data[i] = (const _ptr_ InT *)(in_data); + + bool is_same_dim = ins[i]->numel() == numel; + if (is_same_dim) { + use_broadcast[i] = false; + auto temp_size = phi::GetVectorizedSize(in_data); + in_vec_size = std::min(temp_size, in_vec_size); + } else { + use_broadcast[i] = true; + broadcast_num++; + } + all_elementwise &= is_same_dim; + } + vec_size = std::min(out_vec_size, in_vec_size); } - return std::min(out_vec_size, in_vec_size); -} + + private: + int in_vec_size{4}; +}; #ifndef PADDLE_WITH_XPU_KP -template +// Common broadcast/elementwise Loader. +template struct BroadcastDataLoader { __device__ __forceinline__ void operator()( T args[Arity][VecSize], @@ -88,8 +107,63 @@ struct BroadcastDataLoader { } }; +// Scalar elementwise Loader with consideration of IsBoundary. +template +struct BroadcastDataLoader { + __device__ __forceinline__ void operator()( + T args[Arity][VecSize], + const phi::Array &ins, + const phi::Array &configs, + const phi::Array &use_broadcast, + const int block_offset, + const int num, + const uint32_t numel) { + int thread_offset = threadIdx.x * VecSize + block_offset; +#pragma unroll + for (int i = 0; i < Arity; ++i) { +#pragma unroll + for (int idx = 0; idx < VecSize; ++idx) { + args[i][idx] = static_cast(1); + int index = thread_offset + idx; + if (index < numel) { + args[i][idx] = ins[i][index]; + } + } + } + } +}; + +// Vectorized elementwise Loader without consideration of IsBoundary. 
+template +struct BroadcastDataLoader { + __device__ __forceinline__ void operator()( + T args[Arity][VecSize], + const phi::Array &ins, + const phi::Array &configs, + const phi::Array &use_broadcast, + const int block_offset, + const int num, + const uint32_t numel) { + using VecType = phi::kps::details::VectorType; + VecType vec_temp[Arity]; + + int thread_offset = threadIdx.x + blockIdx.x * blockDim.x; +#pragma unroll + for (int i = 0; i < Arity; ++i) { + const VecType *__restrict__ vec_input = + reinterpret_cast(ins[i]); + vec_temp[i] = vec_input[thread_offset]; +#pragma unroll + for (int idx = 0; idx < VecSize; ++idx) { + args[i][idx] = vec_temp[i].val[idx]; + } + } + } +}; + +// Common broadcast data loader. template -struct BroadcastDataLoader { +struct BroadcastDataLoader { __device__ __forceinline__ void operator()( T args[Arity][VecSize], const phi::Array &ins, @@ -146,7 +220,7 @@ template + int LoadType> __device__ void VectorizedBroadcastKernelImpl( const phi::Array &ins, phi::Array<_ptr_ OutT *, NumOuts> outs, @@ -172,7 +246,7 @@ __device__ void VectorizedBroadcastKernelImpl( } } #else - BroadcastDataLoader()( + BroadcastDataLoader()( args, ins, configs, use_broadcast, block_offset, num, numel); #endif @@ -196,7 +270,7 @@ template + int LoadType> __global__ void VectorizedBroadcastKernel( phi::Array ins, phi::Array<_ptr_ OutT *, NumOuts> outs, @@ -218,15 +292,15 @@ __global__ void VectorizedBroadcastKernel( NumOuts, VecSize, false, - IsAllBroadcast>(ins, - outs, - use_broadcast, - numel, - configs, - BLOCK_NUM_X * read_lens, - block_offset, - read_lens, - func); + LoadType>(ins, + outs, + use_broadcast, + numel, + configs, + BLOCK_NUM_X * read_lens, + block_offset, + read_lens, + func); } int num = numel - block_offset; if (num > 0) { @@ -237,15 +311,15 @@ __global__ void VectorizedBroadcastKernel( NumOuts, VecSize, true, - IsAllBroadcast>(ins, - outs, - use_broadcast, - numel, - configs, - num, - block_offset, - read_lens, - func); + LoadType>(ins, + outs, + use_broadcast, + numel, + configs, + num, + block_offset, + read_lens, + func); } #else int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize; @@ -257,15 +331,15 @@ __global__ void VectorizedBroadcastKernel( NumOuts, VecSize, false, - IsAllBroadcast>(ins, - outs, - use_broadcast, - numel, - configs, - BLOCK_NUM_X * VecSize, - block_offset, - read_lens, - func); + LoadType>(ins, + outs, + use_broadcast, + numel, + configs, + BLOCK_NUM_X * VecSize, + block_offset, + read_lens, + func); } else { VectorizedBroadcastKernelImpl(ins, - outs, - use_broadcast, - numel, - configs, - tail_tid, - block_offset, - read_lens, - func); + LoadType>(ins, + outs, + use_broadcast, + numel, + configs, + tail_tid, + block_offset, + read_lens, + func); } #endif } template @@ -297,29 +371,16 @@ void LaunchBroadcastKernel( const KPDevice &ctx, const std::vector &ins, std::vector *outs, - Functor func, - const phi::Array &configs) { - int broadcast_num = 0; - int numel = (*outs)[0]->numel(); - phi::Array use_broadcast; - phi::Array ins_data; + Func func, + const phi::Array &configs, + const LoaderTypeClassifier &loader_classifier) { phi::Array<_ptr_ OutT *, NumOuts> outs_data; - for (int i = 0; i < NumOuts; ++i) { outs_data[i] = (_ptr_ OutT *)(ctx.Alloc((*outs)[i])); } - for (int i = 0; i < Arity; ++i) { - if (ins[i]->numel() != numel) { - broadcast_num++; - use_broadcast[i] = true; - } else { - use_broadcast[i] = false; - } - ins_data[i] = (const _ptr_ InT *)(ins[i]->data()); - } - #ifdef PADDLE_WITH_XPU_KP + int numel = (*outs)[0]->numel(); 
const int threads = 64; const int blocks = 8; int read_lens = configs[0].buf_len; @@ -327,10 +388,10 @@ void LaunchBroadcastKernel( int main_offset = (numel / (read_lens * threads)) * read_lens * threads; int tail_tid = numel % (read_lens * threads); - VectorizedBroadcastKernel - <<>>(ins_data, + VectorizedBroadcastKernel + <<>>(loader_classifier.ins_data, outs_data, - use_broadcast, + loader_classifier.use_broadcast, numel, configs, main_offset, @@ -338,49 +399,54 @@ void LaunchBroadcastKernel( read_lens, func); #else + const auto &numel = loader_classifier.numel; auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize); - int read_lens = VecSize; auto stream = ctx.stream(); - auto threads = gpu_config.thread_per_block; + auto threads = gpu_config.GetBlockSize(); auto blocks = gpu_config.block_per_grid; - int main_offset = (numel / (read_lens * gpu_config.GetBlockSize())) * - read_lens * gpu_config.GetBlockSize(); - int tail_tid = numel % (read_lens * gpu_config.GetBlockSize()); + int main_offset = (numel / (VecSize * threads)) * VecSize * threads; + int tail_tid = numel % (VecSize * threads); - if (broadcast_num > (Arity >> 1)) { - VectorizedBroadcastKernel 1)> - <<>>(ins_data, + kElementwise> + <<>>(loader_classifier.ins_data, + outs_data, + loader_classifier.use_broadcast, + numel, + configs, + main_offset, + tail_tid, + VecSize, + func); + } else if (loader_classifier.broadcast_num > (Arity >> 1)) { + constexpr BroadcastLoadType type_ = (Arity > 1) ? kBroadcast : kMixed; + VectorizedBroadcastKernel + <<>>(loader_classifier.ins_data, outs_data, - use_broadcast, + loader_classifier.use_broadcast, numel, configs, main_offset, tail_tid, - read_lens, + VecSize, func); } else { - VectorizedBroadcastKernel - <<>>(ins_data, + VectorizedBroadcastKernel + <<>>(loader_classifier.ins_data, outs_data, - use_broadcast, + loader_classifier.use_broadcast, numel, configs, main_offset, tail_tid, - read_lens, + VecSize, func); } #endif @@ -777,21 +843,6 @@ struct LaunchBroadcastKernelWithInt64IndexHelper -static std::string ReversedVectorToString(const std::vector &reversed_v) { - std::stringstream ss; - bool is_last = true; - for (int i = reversed_v.size() - 1; i >= 0; --i) { - if (is_last) { - ss << reversed_v[i]; - is_last = false; - } else { - ss << ", " << reversed_v[i]; - } - } - return ss.str(); -} - template numel() >= std::numeric_limits::max(); if (use_int64_index_kernel) { - int vec_size = GetVecsize(ins, outs); - switch (vec_size) { + auto loader_classifier = LoaderTypeClassifier(ins, outs); + switch (loader_classifier.vec_size) { case VecSizeL: { LaunchBroadcastKernelWithInt64IndexHelperdims(), axis); - if (VLOG_IS_ON(6)) { - for (size_t i = 0; i < ins.size(); ++i) { - VLOG(6) << "input i=" << i << ": origin_dims={" << ins[i]->dims() - << "}, simplied_dims={" - << ReversedVectorToString(dims_simplifier.in_dims[i]) - << "}"; - } - VLOG(6) << "output: origin_dims={" << (*outs)[0]->dims() - << "}, simplied_dims={" - << ReversedVectorToString(dims_simplifier.out_dims) << "}"; - } - phi::Array configs; - -// get vec_size #ifdef PADDLE_WITH_XPU_KP PADDLE_ENFORCE_EQ( ins.size(), 2, phi::errors::InvalidArgument( "XPU only support inputs is 2, but received %d", ins.size())); + + auto loader_classifier = LoaderTypeClassifier(); + const auto dims_simplifier = + BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis); + if (VLOG_IS_ON(6)) { + DimsSimplifiedLogger::Log( + ins, outs, dims_simplifier, "XPU Broadcast"); + } configs[0] = 
kps::details::BroadcastConfig(dims_simplifier.out_dims, dims_simplifier.in_dims[0], dims_simplifier.in_dims[1], @@ -926,39 +968,46 @@ void BroadcastKernelForDifferentVecSize( bool is_optimize = configs[0].cmp_type != type; int vec_size = is_optimize ? VecSizeL : VecSizeM; #else - for (int i = 0; i < kArity; ++i) { - // get the broadcast config, - // if data shape is[m, n], then you should set data_dim = {n, m} - // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3} - // if (ins[i]->numel() != (*outs)[0]->numel()) { - if (ins[i]->numel()) { - configs[i] = kps::details::BroadcastConfig(dims_simplifier.out_dims, - dims_simplifier.in_dims[i], - dims_simplifier.rank); + auto loader_classifier = LoaderTypeClassifier(ins, outs); + if (!loader_classifier.all_elementwise) { + const auto dims_simplifier = + BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis); + + if (VLOG_IS_ON(6)) { + DimsSimplifiedLogger::Log( + ins, outs, dims_simplifier, "GPU Broadcast"); + } + for (int i = 0; i < kArity; ++i) { + // if data shape is[m, n], then you should set data_dim = {n, m} + // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3} + // if (ins[i]->numel() != (*outs)[0]->numel()) { + if (ins[i]->numel()) { + configs[i] = kps::details::BroadcastConfig(dims_simplifier.out_dims, + dims_simplifier.in_dims[i], + dims_simplifier.rank); + } } } - int vec_size = GetVecsize(ins, outs); #endif - - switch (vec_size) { + switch (loader_classifier.vec_size) { case VecSizeL: { LaunchBroadcastKernel( - ctx, ins, outs, func, configs); + ctx, ins, outs, func, configs, loader_classifier); break; } case VecSizeM: { LaunchBroadcastKernel( - ctx, ins, outs, func, configs); + ctx, ins, outs, func, configs, loader_classifier); break; } case VecSizeS: { LaunchBroadcastKernel( - ctx, ins, outs, func, configs); + ctx, ins, outs, func, configs, loader_classifier); break; } default: { PADDLE_THROW(phi::errors::Unimplemented( - "Unsupported vectorized size: %d!", vec_size)); + "Unsupported vectorized size: %d!", loader_classifier.vec_size)); break; } } @@ -1037,7 +1086,6 @@ void DefaultElementwiseOperator(const DeviceContext &dev_ctx, dev_ctx, x, y, axis, InverseFunctor(), z); } } - #endif } // namespace funcs diff --git a/paddle/phi/kernels/funcs/dims_simplifier.h b/paddle/phi/kernels/funcs/dims_simplifier.h index 0ef0f6ac5b8f6ff7f2f5b3ebc69974e65226cb7a..3912357546944734dc7101122cda298d23a93b17 100644 --- a/paddle/phi/kernels/funcs/dims_simplifier.h +++ b/paddle/phi/kernels/funcs/dims_simplifier.h @@ -34,18 +34,6 @@ struct BroadcastDimsSimplifier { BroadcastDimsSimplifier(const std::vector &ins, const phi::DDim &dims, int axis) { - if (!NeedBroadcast(ins, dims)) { - int64_t numel = phi::product(dims); - rank = 1; - N = ins.size(); - out_dims = DimVector{numel}; - in_dims.resize(N); - for (int64_t i = 0; i < N; ++i) { - in_dims[i] = DimVector{numel}; - } - return; - } - N = std::max(static_cast(ins.size()), 2); in_dims.resize(N); rank = dims.size(); @@ -112,18 +100,6 @@ struct BroadcastDimsSimplifier { } private: - bool NeedBroadcast(const std::vector &ins, - const phi::DDim &dims) { - bool no_broadcast_flag = true; - for (auto *in : ins) { - no_broadcast_flag &= ins[0]->dims() == in->dims(); - } - if (ins.size() > 0) { - no_broadcast_flag &= dims == ins[0]->dims(); - } - return !no_broadcast_flag; - } - // To compensate the lackage of input_tensors' dimension with axis. 
void ExtendInputDimensions(int N, int axis) { for (auto &in_dim : in_dims) { @@ -244,18 +220,18 @@ struct BroadcastDimsSimplifier { }; // Simplify the input dims and permute dims if possible. -struct DimsSimplifier { +struct PermuteDimsSimplifier { public: - explicit DimsSimplifier(const int rank, - const int64_t numel, - const std::vector &perm, - const std::vector &dims) + PermuteDimsSimplifier(const int rank, + const int64_t numel, + const std::vector &perm, + const std::vector &dims) : perm_(rank), src_dims_(rank), count_(numel) { SimplifyPermAndDims(rank, dims, perm); perm_.resize(rank_); src_dims_.resize(rank_); dst_dims_.resize(rank_); - if (!is_seq_perm_) { + if (!is_sequential_perm_) { for (auto i = 0; i < rank_; ++i) { dst_dims_[i] = src_dims_[perm_[i]]; } @@ -265,7 +241,7 @@ struct DimsSimplifier { } } - ~DimsSimplifier() = default; + ~PermuteDimsSimplifier() = default; const int &GetRank() const { return rank_; } const int64_t &GetCount() const { return count_; } @@ -276,8 +252,8 @@ struct DimsSimplifier { private: int rank_{1}; int64_t count_{0}; - bool is_seq_perm_{true}; std::vector perm_; + bool is_sequential_perm_{true}; std::vector src_dims_; std::vector dst_dims_; @@ -336,11 +312,44 @@ struct DimsSimplifier { const int mapped = valid_map[perm[i]]; if (mapped >= 0) { perm_[perm_idx] = mapped; - is_seq_perm_ &= (mapped == perm_idx); + is_sequential_perm_ &= (mapped == perm_idx); perm_idx += 1; } } - rank_ = is_seq_perm_ ? 1 : valid_dim_idx; + rank_ = is_sequential_perm_ ? 1 : valid_dim_idx; + } +}; + +template +struct DimsSimplifiedLogger { + public: + static void Log(const std::vector &ins, + std::vector *outs, + const BroadcastDimsSimplifier &dims_simplifier, + const std::string &op_name) { + VLOG(6) << op_name << "`s dims after simplification is below :"; + for (size_t i = 0; i < ins.size(); ++i) { + VLOG(6) << "input i=" << i << ": origin_dims={" << ins[i]->dims() + << "}, simplied_dims={" + << ReversedVectorToString(dims_simplifier.in_dims[i]) << "}"; + } + VLOG(6) << "output: origin_dims={" << (*outs)[0]->dims() + << "}, simplied_dims={" + << ReversedVectorToString(dims_simplifier.out_dims) << "}"; + } + + static std::string ReversedVectorToString(const std::vector &reversed_v) { + std::stringstream ss; + bool is_last = true; + for (int i = reversed_v.size() - 1; i >= 0; --i) { + if (is_last) { + ss << reversed_v[i]; + is_last = false; + } else { + ss << ", " << reversed_v[i]; + } + } + return ss.str(); } }; diff --git a/paddle/phi/kernels/funcs/transpose_function.cu.h b/paddle/phi/kernels/funcs/transpose_function.cu.h index d4e36745a47d4bf2d43ff3ebd4ae150a8a333b4c..9f746349a67a3969e8f52d57980c33efc2f26265 100644 --- a/paddle/phi/kernels/funcs/transpose_function.cu.h +++ b/paddle/phi/kernels/funcs/transpose_function.cu.h @@ -19,8 +19,9 @@ limitations under the License. 
*/ #include "paddle/phi/backends/gpu/gpu_utils.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/autotune/auto_tune_base.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/funcs/dims_simplifier.h" -#include "paddle/phi/kernels/funcs/transpose_functor.h" +#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/primitive/datamover_primitives.h" namespace phi { @@ -705,24 +706,24 @@ inline void CombineTransposeDim3(const DDim& shape, template struct TransposeSimple { - static bool Impl(const phi::GPUContext& ctx, - const phi::DenseTensor& in, - const std::vector perm, - phi::DenseTensor* out, - const int64_t numel) { + static bool Run(const phi::GPUContext& ctx, + const phi::DenseTensor& in, + const std::vector& perm, + phi::DenseTensor* out, + const int64_t numel) { if (numel >= std::numeric_limits::max()) { - return Run(ctx, in, perm, out); + return RunImpl(ctx, in, perm, out); } else { - return Run(ctx, in, perm, out); + return RunImpl(ctx, in, perm, out); } } private: template - static bool Run(const phi::GPUContext& ctx, - const phi::DenseTensor& in, - const std::vector perm, - phi::DenseTensor* out) { + static bool RunImpl(const phi::GPUContext& ctx, + const phi::DenseTensor& in, + const std::vector& perm, + phi::DenseTensor* out) { // First reduce the dimensions of the input tensor if possible. auto in_data = in.data(); auto out_data = out->data(); @@ -752,13 +753,128 @@ struct TransposeSimple { } }; -template +enum PermuteType { + kCopy = 1, + kSwapTranspose = 2, + kGeneralTranspose = 3, + kVecPermute = 4, + kGeneralPermute = 5 +}; + +constexpr int kBlockRows = 16; +constexpr int kTileSize = 32; +constexpr int kShareCol = (kTileSize + 1); + +#define GET_TILE_SIZE(LEN_, ALIGN_) \ + ((LEN_ + (ALIGN_ - 1)) & ~(ALIGN_ - 1)) / ALIGN_ + +template +struct PermTypeClassifier { + public: + PermTypeClassifier(const int sm_count, + const int rank, + const std::vector& perm, + const std::vector& dims, + const T* src, + T* dst) { + if (rank == 1) { + type_ = PermuteType::kCopy; + } else { + // Limitation of the setting in one dimension of cuda grid. + constexpr int64_t dim_limitation = 65536; + int dst_vec_size = phi::GetVectorizedSize(dst); + + // While the last dim is fixed, there is chance for vectorized IO. + const int last_idx = rank - 1; + if (perm[last_idx] == last_idx) { + type_ = PermuteType::kVecPermute; + vec_size_ = GetDimVecSize(dst_vec_size, dims[last_idx], src, false); + return; + } + + // Permute at last 2 dims, namely transpose. + if ((rank == 2 && perm[1] == 0) || + (rank == 3 && perm[2] == 1 && perm[1] == 2)) { + int64_t channel = rank == 2 ? 1 : dims[0]; + // Currently, transpose kernel cannot cover the case that channel + // dimension is more than 65536 which is the limitation of dim3 setting. + // This special case will be covered by extended transpose kernel later. + if (channel < dim_limitation) { + type_ = PermuteType::kGeneralTranspose; + num_rows_tile_ = GET_TILE_SIZE(dims[rank - 2], kTileSize); + int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src); + int tile_size = channel * num_rows_tile_ * + GET_TILE_SIZE(dims[last_idx], kTileSize); + vec_size_ = tile_size < sm_count ? 1 : dim_vec_size; + } else { + type_ = PermuteType::kGeneralPermute; + } + return; + } + + // Permute at first dim and third dim. 
+ if (rank == 3 && perm[2] == 0 && perm[1] == 1) { + // Currently, transpose kernel cannot cover the case that channel + // dimension is more than 65536 which is the limitation of dim3 setting. + // This special case will be covered by extended transpose kernel later. + if (dims[1] < dim_limitation) { + type_ = PermuteType::kSwapTranspose; + num_rows_tile_ = GET_TILE_SIZE(dims[0], kTileSize); + + int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src); + int tile_size = + dims[1] * num_rows_tile_ * GET_TILE_SIZE(dims[2], kTileSize); + vec_size_ = tile_size < sm_count ? 1 : dim_vec_size; + } else { + type_ = PermuteType::kGeneralPermute; + } + return; + } + vec_size_ = dst_vec_size; + } + } + + ~PermTypeClassifier() = default; + + int GetVecSize() const { return vec_size_; } + int GetRowsTile() const { return num_rows_tile_; } + PermuteType GetPermType() const { return type_; } + + private: + int vec_size_{1}; + int64_t num_rows_tile_{0}; + PermuteType type_{kGeneralPermute}; + + // To find if highest common divisor and make it as vec_size. + int GetDimVecSize(const int dst_vec_size, + const int64_t target_dim, + const T* src, + bool use_share_mem = true) { + int vec_size = std::min(dst_vec_size, phi::GetVectorizedSize(src)); + int dim_vec_size = 1; + for (int size = vec_size; size > 0; size /= 2) { + if (target_dim % size == 0) { + dim_vec_size = size; + break; + } + } + + if (use_share_mem) { + // By bytes limitation of shared_memory. + return (sizeof(T) > sizeof(float) ? 1 : dim_vec_size); + } else { + return dim_vec_size; + } + } +}; + +template class IdxHelper { public: IdxHelper() {} explicit IdxHelper(const IndexT* dims) { - for (int i = N - 1; i >= 0; --i) { - stride_[i] = i < (N - 1) ? dims[i + 1] * stride_[i + 1] : 1; + for (int i = Rank - 1; i >= 0; --i) { + stride_[i] = i < (Rank - 1) ? dims[i + 1] * stride_[i + 1] : 1; } } @@ -770,25 +886,25 @@ class IdxHelper { IndexT* index) const { IndexT remaining = offset; #pragma unroll - for (int i = 0; i < N - 1; ++i) { + for (int i = 0; i < Rank - 1; ++i) { const IndexT idx = remaining / stride_[i]; remaining -= idx * stride_[i]; index[i] = idx; } - index[N - 1] = remaining; + index[Rank - 1] = remaining; } private: - IndexT stride_[N]; + IndexT stride_[Rank]; }; -template -class IdxHelper { +template +class IdxHelper { public: IdxHelper() {} explicit IdxHelper(const uint32_t* dims) { - for (int i = N - 1; i >= 0; --i) { - uint32_t value = i < (N - 1) ? dims[i + 1] * stride_[i + 1] : 1; + for (int i = Rank - 1; i >= 0; --i) { + uint32_t value = i < (Rank - 1) ? dims[i + 1] * stride_[i + 1] : 1; divmoder_[i] = phi::kps::details::FastDivMod(value); stride_[i] = value; } @@ -802,35 +918,35 @@ class IdxHelper { uint32_t* index) const { uint32_t remaining = offset; #pragma unroll - for (int i = 0; i < N - 1; ++i) { + for (int i = 0; i < Rank - 1; ++i) { uint32_t idx = divmoder_[i].Div(remaining); index[i] = idx; remaining -= idx * stride_[i]; } - index[N - 1] = remaining; + index[Rank - 1] = remaining; } private: - uint32_t stride_[N]; - phi::kps::details::FastDivMod divmoder_[N]; + uint32_t stride_[Rank]; + phi::kps::details::FastDivMod divmoder_[Rank]; }; // Transform index between memory offset and shape coodinate. 
-template +template class IdxAndOffsetHelper { public: IdxAndOffsetHelper() {} explicit IdxAndOffsetHelper(const IndexT* dims) { - index_helper = IdxHelper(dims); + index_helper = IdxHelper(dims); } __device__ __forceinline__ IndexT IndexToOffset(const IndexT* index) const { IndexT offset = 0; #pragma unroll - for (int i = 0; i < N - 1; ++i) { + for (int i = 0; i < Rank - 1; ++i) { offset += index[i] * index_helper.GetStride(i); } - offset += index[N - 1]; + offset += index[Rank - 1]; return offset; } @@ -840,7 +956,7 @@ class IdxAndOffsetHelper { } private: - IdxHelper index_helper; + IdxHelper index_helper; }; template @@ -1173,7 +1289,7 @@ struct TransposeLauncher { T* dst) { constexpr int ReadSize = sizeof(T) > sizeof(float) ? 1 : VecSize; const IndexT cols = dims[rank - 1] / VecSize; - const IndexT n_cols_tile = GETTILESIZE(cols, kTileSize); + const IndexT n_cols_tile = GET_TILE_SIZE(cols, kTileSize); if (perm_type == PermuteType::kGeneralTranspose) { IndexT chs = (rank == 2) ? 1 : dims[0]; @@ -1229,82 +1345,65 @@ struct TransposeLauncher { vec_write = is_vec_write ? kVecRow : 1; } IndexT n_rows_tile = is_vec_write - ? GETTILESIZE(rows, (kTileSize * vec_write)) + ? GET_TILE_SIZE(rows, (kTileSize * vec_write)) : num_rows_tile; return n_rows_tile; } }; template -struct PermuteDispatch { - public: - PermuteDispatch(const phi::GPUContext& ctx, - PermTypeClassifier* cls_ptr, - const std::vector& dims, - const std::vector& perm, - const IndexT count, - const T* src, - T* dst) - : dims_(dims), cls_(cls_ptr) { - rank_ = dims_.size(); - type_ = cls_->GetPermType(); - KernelTypeDispatch(ctx, count, perm, src, dst); - } - ~PermuteDispatch() {} - - private: - int rank_{0}; - std::vector dims_; - PermTypeClassifier* cls_; - PermuteType type_{kGeneralPermute}; +inline void PermuteDispatch(const phi::GPUContext& ctx, + const IndexT& count, + PermTypeClassifier* cls_ptr, + const std::vector& dims, + const std::vector& perm, + const T* src, + T* dst) { + int rank = dims.size(); + PermuteType type = cls_ptr->GetPermType(); - void KernelTypeDispatch(const phi::GPUContext& ctx, - const IndexT& count, - const std::vector& perm, - const T* src, - T* dst) { #define TRANSPOSE_DISPATCH_VEC_SIZE(size) \ case size: { \ TransposeLauncher()( \ - ctx, rank_, type_, dims_, cls_->GetRowsTile(), src, dst); \ + ctx, rank, type, dims, cls_ptr->GetRowsTile(), src, dst); \ break; \ } -#define PERMUTE_DISPATCH_VEC_SIZE(size) \ - case size: { \ - PermuteLauncher()( \ - ctx, rank_, count, type_, dims_, perm, src, dst); \ - break; \ +#define PERMUTE_DISPATCH_VEC_SIZE(size) \ + case size: { \ + PermuteLauncher()( \ + ctx, rank, count, type, dims, perm, src, dst); \ + break; \ } - switch (type_) { - case kSwapTranspose: - case kGeneralTranspose: - switch (cls_->GetVecSize()) { - TRANSPOSE_DISPATCH_VEC_SIZE(1); - TRANSPOSE_DISPATCH_VEC_SIZE(2); - TRANSPOSE_DISPATCH_VEC_SIZE(4); - } - break; - default: - switch (cls_->GetVecSize()) { - PERMUTE_DISPATCH_VEC_SIZE(1); - PERMUTE_DISPATCH_VEC_SIZE(2); - PERMUTE_DISPATCH_VEC_SIZE(4); - } - break; - } + switch (type) { + case kSwapTranspose: + case kGeneralTranspose: + switch (cls_ptr->GetVecSize()) { + TRANSPOSE_DISPATCH_VEC_SIZE(1); + TRANSPOSE_DISPATCH_VEC_SIZE(2); + TRANSPOSE_DISPATCH_VEC_SIZE(4); + } + break; + default: + switch (cls_ptr->GetVecSize()) { + PERMUTE_DISPATCH_VEC_SIZE(1); + PERMUTE_DISPATCH_VEC_SIZE(2); + PERMUTE_DISPATCH_VEC_SIZE(4); + } + break; + } #define TRANSPOSE_DISPATCH_VEC_SIZE #define PERMUTE_DISPATCH_VEC_SIZE - } -}; +} template -inline void 
PermuteAndTranspose(const phi::GPUContext& ctx, - const int& rank, - const phi::DenseTensor& in, - phi::DenseTensor* out, - const DimsSimplifier& simplifier) { +inline void PermuteAndTranspose( + const phi::GPUContext& ctx, + const int& rank, + const phi::DenseTensor& in, + phi::DenseTensor* out, + const phi::funcs::PermuteDimsSimplifier& simplifier) { T* dst_data = out->data(); const T* src_data = in.data(); const auto count = simplifier.GetCount(); @@ -1324,18 +1423,18 @@ inline void PermuteAndTranspose(const phi::GPUContext& ctx, } else { if (count < std::numeric_limits::max()) { PermuteDispatch(ctx, + static_cast(count), &classifier, simplifier.GetSrcDims(), simplifier.GetPerm(), - static_cast(count), src_data, dst_data); } else { PermuteDispatch(ctx, + static_cast(count), &classifier, simplifier.GetSrcDims(), simplifier.GetPerm(), - static_cast(count), src_data, dst_data); } @@ -1343,12 +1442,13 @@ inline void PermuteAndTranspose(const phi::GPUContext& ctx, } template -inline void PermuteWithEigen(const phi::GPUContext& ctx, - const int& rank, - const phi::DenseTensor& in, - phi::DenseTensor* out, - const DimsSimplifier& simplifier) { - const bool not_same_dims = simplifier.GetRank() != rank; +inline void PermuteWithEigen( + const phi::GPUContext& ctx, + const int& rank, + const phi::DenseTensor& in, + phi::DenseTensor* out, + const phi::funcs::PermuteDimsSimplifier& simplifier) { + bool not_same_dims = simplifier.GetRank() != rank; if (not_same_dims) { phi::DDim dst_dims = out->dims(); phi::DenseTensor temp_in; @@ -1373,10 +1473,10 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx, phi::DenseTensor* out) { const int rank = perm.size(); int64_t numel = in.numel(); - bool ret = TransposeSimple::Impl(ctx, in, perm, out, numel); + bool ret = TransposeSimple::Run(ctx, in, perm, out, numel); if (!ret) { - auto simplifier = - DimsSimplifier(rank, numel, perm, phi::vectorize(in.dims())); + auto simplifier = phi::funcs::PermuteDimsSimplifier( + rank, numel, perm, phi::vectorize(in.dims())); auto* tuner = phi::autotune::MakeTransposeTuner(PermuteWithEigen); tuner->AddCallBack(PermuteAndTranspose); diff --git a/paddle/phi/kernels/funcs/transpose_functor.h b/paddle/phi/kernels/funcs/transpose_functor.h deleted file mode 100644 index c3904b9c1ade6831d639898a3615529493d44f51..0000000000000000000000000000000000000000 --- a/paddle/phi/kernels/funcs/transpose_functor.h +++ /dev/null @@ -1,143 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once - -#include - -#include "paddle/phi/core/tensor_utils.h" -#include "paddle/phi/kernels/funcs/aligned_vector.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace phi { -namespace funcs { - -enum { kTransposeMKLDNNFP32 = 1, kTransposeMKLDNNINT8 = 2 }; - -enum PermuteType { - kCopy = 1, - kSwapTranspose = 2, - kGeneralTranspose = 3, - kVecPermute = 4, - kGeneralPermute = 5 -}; - -constexpr int kBlockRows = 16; -constexpr int kTileSize = 32; -constexpr int kShareCol = (kTileSize + 1); - -#define GETTILESIZE(LEN_, ALIGN_) \ - ((LEN_ + (ALIGN_ - 1)) & ~(ALIGN_ - 1)) / ALIGN_ - -template -struct PermTypeClassifier { - public: - explicit PermTypeClassifier(const int sm_count, - const int rank, - const std::vector& perm, - const std::vector& dims, - const T* src, - T* dst) { - if (rank == 1) { - type_ = PermuteType::kCopy; - } else { - constexpr int64_t dim_limitation = 65536; - const int dst_vec_size = phi::GetVectorizedSize(dst); - - // While the last dim is fixed, there is chance for vectorized IO. - const int last_idx = rank - 1; - if (perm[last_idx] == last_idx) { - type_ = PermuteType::kVecPermute; - vec_size_ = GetDimVecSize(dst_vec_size, dims[last_idx], src, false); - return; - } - - // Permute at last 2 dims, namely transpose. - if ((rank == 2 && perm[1] == 0 && perm[0] == 1) || - (rank == 3 && perm[2] == 1 && perm[1] == 2)) { - int64_t channel = rank == 2 ? 1 : dims[0]; - // Currently, transpose kernel cannot cover the case that channel - // dimension is more than 65536 which is the limitation of dim3 setting. - // This special case will be covered by extended transpose kernel later. - if (channel < dim_limitation) { - type_ = PermuteType::kGeneralTranspose; - num_rows_tile_ = GETTILESIZE(dims[rank - 2], kTileSize); - int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src); - int tile_size = - channel * num_rows_tile_ * GETTILESIZE(dims[last_idx], kTileSize); - vec_size_ = tile_size < sm_count ? 1 : dim_vec_size; - } else { - type_ = PermuteType::kGeneralPermute; - } - return; - } - - // Permute at first dim and third dim. - if (rank == 3 && perm[2] == 0 && perm[1] == 1) { - // Currently, transpose kernel cannot cover the case that channel - // dimension is more than 65536 which is the limitation of dim3 setting. - // This special case will be covered by extended transpose kernel later. - if (dims[1] < dim_limitation) { - type_ = PermuteType::kSwapTranspose; - num_rows_tile_ = GETTILESIZE(dims[0], kTileSize); - - int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src); - int tile_size = - dims[1] * num_rows_tile_ * GETTILESIZE(dims[2], kTileSize); - vec_size_ = tile_size < sm_count ? 1 : dim_vec_size; - } else { - type_ = PermuteType::kGeneralPermute; - } - return; - } - vec_size_ = dst_vec_size; - } - } - - ~PermTypeClassifier() = default; - - int GetVecSize() const { return vec_size_; } - int GetRowsTile() const { return num_rows_tile_; } - PermuteType GetPermType() const { return type_; } - - private: - int vec_size_{1}; - int64_t num_rows_tile_{0}; - PermuteType type_{kGeneralPermute}; - - // To find if highest common divisor and make it as vec_size. 
- int GetDimVecSize(const int dst_vec_size, - const int64_t target_dim, - const T* src, - bool use_share_mem = true) { - const int vec_size = std::min(dst_vec_size, phi::GetVectorizedSize(src)); - int dim_vec_size = 1; - for (int size = vec_size; size > 0; size /= 2) { - if (target_dim % size == 0) { - dim_vec_size = size; - break; - } - } - - if (use_share_mem) { - // By bytes limitation of shared_memory. - return (sizeof(T) > sizeof(float) ? 1 : dim_vec_size); - } else { - return dim_vec_size; - } - } -}; - -} // namespace funcs -} // namespace phi
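
The load-path selection added in paddle/phi/kernels/funcs/broadcast_function.h is easiest to read in isolation: LoaderTypeClassifier records, per input, whether its element count matches the output, and LaunchBroadcastKernel then instantiates VectorizedBroadcastKernel with kElementwise when every input matches, kBroadcast when more than half of the inputs broadcast (and arity is greater than one), and kMixed otherwise. Below is a minimal standalone sketch of that rule only, with tensors reduced to plain element counts; ClassifyLoaders and LoadPlan are illustrative names rather than Paddle API, and vectorization width is left out.

// Standalone sketch of the load-type selection introduced in
// broadcast_function.h. Tensor shapes are reduced to element counts and the
// vectorization width is ignored; the names below are illustrative only.
#include <cassert>
#include <cstdint>
#include <vector>

enum BroadcastLoadType { kMixed = 1, kBroadcast = 2, kElementwise = 3 };

struct LoadPlan {
  std::vector<bool> use_broadcast;  // per-input: needs index remapping?
  BroadcastLoadType type{kMixed};
};

// Mirrors LoaderTypeClassifier plus the dispatch in LaunchBroadcastKernel:
// all inputs sized like the output -> kElementwise (fast vectorized path),
// more than half of the inputs broadcasting -> kBroadcast, otherwise kMixed.
LoadPlan ClassifyLoaders(const std::vector<int64_t>& in_numels,
                         int64_t out_numel) {
  LoadPlan plan;
  int broadcast_num = 0;
  for (int64_t n : in_numels) {
    bool needs_broadcast = (n != out_numel);
    plan.use_broadcast.push_back(needs_broadcast);
    broadcast_num += needs_broadcast ? 1 : 0;
  }
  const int arity = static_cast<int>(in_numels.size());
  if (broadcast_num == 0) {
    plan.type = kElementwise;
  } else if (arity > 1 && broadcast_num > (arity >> 1)) {
    plan.type = kBroadcast;
  } else {
    plan.type = kMixed;
  }
  return plan;
}

int main() {
  // x:[3, 45, 1] broadcast against y:[3, 45, 8] -> output numel 1080.
  assert(ClassifyLoaders({135, 1080}, 1080).type == kMixed);
  // Both inputs already match the output: pure elementwise load path.
  assert(ClassifyLoaders({1080, 1080}, 1080).type == kElementwise);
  // Both inputs broadcast: dedicated broadcast loader.
  assert(ClassifyLoaders({135, 8}, 1080).type == kBroadcast);
  return 0;
}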
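
The renamed GET_TILE_SIZE macro and PermTypeClassifier::GetDimVecSize, now inlined into transpose_function.cu.h, carry the tile and vector-width arithmetic for the transpose kernels: tile counts come from rounding a dimension up to a multiple of kTileSize (a power of two) before dividing, and the vector width is the largest power of two that evenly divides the innermost dimension while staying within the alignment-derived bound. A small self-contained sketch under those definitions, with the pointer-alignment query replaced by a plain integer bound (DimVecSize is an illustrative stand-in, not the Paddle function):

// Sketch of the tile/vector-size arithmetic used by PermTypeClassifier in
// transpose_function.cu.h. kTileSize and the macro mirror the patch; the
// alignment query on src/dst pointers is replaced by the max_vec parameter.
#include <cassert>
#include <cstdint>

constexpr int kTileSize = 32;  // one shared-memory tile covers 32 rows/columns

// Rounds LEN_ up to a multiple of ALIGN_ (a power of two) and divides,
// i.e. the number of ALIGN_-wide tiles needed to cover LEN_.
#define GET_TILE_SIZE(LEN_, ALIGN_) \
  ((LEN_ + (ALIGN_ - 1)) & ~(ALIGN_ - 1)) / ALIGN_

// Largest power-of-two vector width (<= max_vec) that evenly divides the
// innermost dimension, mirroring PermTypeClassifier::GetDimVecSize.
int DimVecSize(int max_vec, int64_t last_dim) {
  for (int size = max_vec; size > 0; size /= 2) {
    if (last_dim % size == 0) return size;
  }
  return 1;
}

int main() {
  // 45 columns need ceil(45 / 32) = 2 tiles of width 32.
  assert(GET_TILE_SIZE(45, kTileSize) == 2);
  assert(GET_TILE_SIZE(64, kTileSize) == 2);
  // last_dim = 12 with 4-wide loads allowed -> vec size 4; 45 -> scalar loads.
  assert(DimVecSize(4, 12) == 4);
  assert(DimVecSize(4, 45) == 1);
  return 0;
}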