diff --git a/paddle/fluid/operators/fused/attn_bias_add.cu.h b/paddle/fluid/operators/fused/attn_bias_add.cu.h
index 7666ea7aee23cb1351e9af243d6a9ecd870b8568..f2f6b6bfe01d1ed279d37bfdb9236461a3c19911 100644
--- a/paddle/fluid/operators/fused/attn_bias_add.cu.h
+++ b/paddle/fluid/operators/fused/attn_bias_add.cu.h
@@ -191,9 +191,10 @@ void SetConfigForColumnReduce(const int max_threads, const int reduce_num,
   int num_block = (max_threads / left_num);
   if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
-    *blocking_size = details::GetLastPow2(reduce_num / num_block);
+    *blocking_size =
+        pten::kernels::details::GetLastPow2(reduce_num / num_block);
     if (*blocking_size <= 1) {
-      *blocking_size = details::GetLastPow2(sqrt(reduce_num));
+      *blocking_size = pten::kernels::details::GetLastPow2(sqrt(reduce_num));
     } else if (*blocking_size * 2 < reduce_num) {
       *blocking_size *= 2;
     }
diff --git a/paddle/fluid/operators/margin_cross_entropy_op.cu b/paddle/fluid/operators/margin_cross_entropy_op.cu
index 35035704b7e076d78594c27ab08bc49841e8697f..e4fb4150f841ba9f03ec1de990001a5bbc5f6c05 100644
--- a/paddle/fluid/operators/margin_cross_entropy_op.cu
+++ b/paddle/fluid/operators/margin_cross_entropy_op.cu
@@ -24,6 +24,7 @@ namespace cub = hipcub;
 #include "paddle/fluid/operators/margin_cross_entropy_op.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/softmax_impl.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 #include "paddle/fluid/string/string_helper.h"
diff --git a/paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu b/paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu
index 4464b9571255776d69589fd72bf99f589ca7aad4..63d42790205ab30380ae12e1c0bb42ccb9b04099 100644
--- a/paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu
+++ b/paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 #include "gtest/gtest.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
+#include "paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h"
 
 namespace paddle {
 namespace operators {
@@ -39,9 +39,9 @@ TEST(test_reduce_rank_check, all) {
     }
 
     if (is_valid) {
-      CheckReduceRank(reduce_rank, rank);
+      pten::kernels::details::CheckReduceRank(reduce_rank, rank);
     } else {
-      ASSERT_THROW(CheckReduceRank(reduce_rank, rank),
+      ASSERT_THROW(pten::kernels::details::CheckReduceRank(reduce_rank, rank),
                    paddle::platform::EnforceNotMet);
     }
   }
diff --git a/paddle/fluid/operators/reduce_ops/reduce_functor_op.h b/paddle/fluid/operators/reduce_ops/reduce_functor_op.h
deleted file mode 100644
index 72d21d7074e88873fc9712f346011e7e5a1dc99c..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/reduce_ops/reduce_functor_op.h
+++ /dev/null
@@ -1,123 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
*/ - -#pragma once -#include -#include -#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" -#include "paddle/fluid/platform/hostdevice.h" -#ifdef __HIPCC__ -#include -#endif - -namespace paddle { -namespace operators { - -namespace kps = paddle::operators::kernel_primitives; - -template -struct CustomMin { - using Transformer = kps::IdentityFunctor; - - inline Ty initial() { - return static_cast(std::numeric_limits::max()); - } - - __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const { - return (b < a) ? b : a; - } -}; - -template -struct CustomMax { - using Transformer = kps::IdentityFunctor; - - inline Ty initial() { - return static_cast(std::numeric_limits::lowest()); - } - - __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const { - return (b > a) ? b : a; - } -}; - -// for cub::Reduce -template -struct CustomSum { - using Transformer = kps::IdentityFunctor; - - inline Ty initial() { return static_cast(0.0f); } - - __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const { - return b + a; - } -}; - -template -struct CustomSub { - using Transformer = kps::InverseFunctor; - - inline Ty initial() { return static_cast(0.0f); } - - __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const { - return b + a; - } -}; - -template -struct CustomMean { - using Transformer = kps::DivideFunctor; - - inline Ty initial() { return static_cast(0.0f); } - - __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const { - return b + a; - } -}; - -template -struct CustomMul { - using Transformer = kps::IdentityFunctor; - - inline Ty initial() { return static_cast(1.0f); } - - __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const { - return b * a; - } -}; - -template -struct CustomLogicalOr { - using Transformer = kps::IdentityFunctor; - - inline Ty initial() { return static_cast(false); } - - __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const { - return b || a; - } -}; - -template -struct CustomLogicalAnd { - using Transformer = kps::IdentityFunctor; - - inline Ty initial() { return static_cast(true); } - - __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const { - return b && a; - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h index 5a82176a9c9809372076eb1dbd5a6663cb0e0f71..e779da641b963a0339e127f82ee35a63c73418c0 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h @@ -29,952 +29,28 @@ namespace cub = hipcub; #endif -#include "paddle/fluid/framework/array.h" -#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/amp/fp16_type_traits.h" -#include "paddle/fluid/operators/cast_op.h" -#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" -#include "paddle/fluid/platform/device/gpu/gpu_info.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/fast_divmod.h" -#include "paddle/fluid/string/string_helper.h" -// Reduce split or not, Whether to use ReduceHigherDim -#define REDUCE_SPLIT_BOUNDARY 512 -#define REDUCE_VEC_SIZE 4 - -namespace kps = paddle::operators::kernel_primitives; +#include "paddle/pten/core/dense_tensor.h" +#include 
"paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h" namespace paddle { namespace operators { -namespace details { - -static inline int GetLastPow2(int n) { - n |= (n >> 1); - n |= (n >> 2); - n |= (n >> 4); - n |= (n >> 8); - n |= (n >> 16); - return std::max(1, n - (n >> 1)); -} - -static inline int64_t AlignUp(int64_t a, int64_t b) { return (a + b - 1) / b; } - -// get strides of x_dim, reduce_dim and left_dim for reduceLastDim and reduceAny -static inline std::vector GetDimStrides(const std::vector& dims, - const std::vector& idx) { - int n = static_cast(idx.size()); - if (n == 0) return std::vector(); - std::vector strides(n); - strides.back() = 1; - for (int i = n - 2; i >= 0; --i) { - strides[i] = strides[i + 1] * dims[idx[i + 1]]; - } - return strides; -} - -// get blockDim for reduceLastDim and reduceAny -static inline int GetBlockDim(int block_dim) { - return block_dim >= kps::details::kReduceMaxThread - ? kps::details::kReduceMaxThread - : GetLastPow2(block_dim); -} - -// check reduce rand is valid -static inline void CheckReduceRank(int reduce_rank, int rank) { - if (rank % 2 == 0) { - PADDLE_ENFORCE_EQ(reduce_rank, rank / 2, - platform::errors::InvalidArgument( - "ReduceOp: invalid reduce rank. When rank = %d, " - "reduce_rank must be %d, but got %d.", - rank, rank / 2, reduce_rank)); - } else { - auto lower_rank = (rank - 1) / 2; - auto upper_rank = (rank + 1) / 2; - PADDLE_ENFORCE_EQ( - reduce_rank == lower_rank || reduce_rank == upper_rank, true, - platform::errors::InvalidArgument( - "ReduceOp: invalid reduce rank. When rank = %d, reduce_rank " - "must be %d or %d, but got %d.", - rank, lower_rank, upper_rank, reduce_rank)); - } -} - -// convert dims from vector to array -template -static inline paddle::framework::Array VectorToArray( - const VectorLikeType& vec) { - PADDLE_ENFORCE_LE(vec.size(), ElementCount, - platform::errors::InvalidArgument( - "Cub reduce Array: size not match. 
Received " - "vec.size() %d > ElementCount %d.", - vec.size(), ElementCount)); - size_t n = static_cast(vec.size()); - paddle::framework::Array ret; - for (size_t i = 0; i < n; ++i) { - ret[i] = vec[i]; - } - return ret; -} - -} // namespace details - -using Tensor = framework::Tensor; -constexpr int kMaxRank = framework::DDim::kMaxRank; - -enum ReduceType { - kReduceLastDim = 0x01, // when reduce_dim[0] == x_dim.size() - 1; - kReduceHigherDim = 0x02, // ReduceFirstDim or reduceSecondDim - kReduceAny = 0x03, // when reduce_dim.size() > 1 -}; - -struct IndexCalculator { - IndexCalculator(int dim, const std::vector& cal_dims, - const std::vector& cal_strides, - const std::vector& full_strides) - : dim(dim) { - dims = details::VectorToArray(cal_dims); - strides = details::VectorToArray(full_strides); - std::vector cal_divmoders; - // fast divmod - for (auto i : cal_strides) { - cal_divmoders.push_back(platform::FastDivMod(i)); - } - divmoders = - details::VectorToArray(cal_divmoders); - } - - __device__ inline int operator()(int offset) const { - int index = 0; -#pragma unroll - for (int i = 0; i < kMaxRank; ++i) { - if (i == dim) { - break; - } - auto divmod = divmoders[i].Divmod(offset); - index += (divmod.val[0] * strides[dims[i]]); - offset = divmod.val[1]; - } - return index; - } - - int dim; - framework::Array dims; - framework::Array strides; - framework::Array divmoders; -}; - -template -struct ReduceIndexMapping { - const kps::DimConfig dim; - HOSTDEVICE explicit ReduceIndexMapping(const kps::DimConfig& dims) - : dim(dims) {} - - __device__ __forceinline__ int BlockIdX() { -#ifdef PADDLE_WITH_XPU2 - if (ReduceLastDim) { - return (cluster_id() / dim.split_num_x % dim.split_num_y); - } else { - return cluster_id() % dim.split_num_x; - } -#else - return blockIdx.x; -#endif - } - - __device__ __forceinline__ int BlockIdY() { -#ifdef PADDLE_WITH_XPU2 - if (ReduceLastDim) { - return (cluster_id() % dim.split_num_x); - } else { - return (cluster_id() / dim.split_num_x % dim.split_num_y); - } -#else - return blockIdx.y; -#endif - } - - __device__ __forceinline__ int BlockDimX() { -#ifdef PADDLE_WITH_XPU2 - return dim.deal_size_x; -#else - return blockDim.x; -#endif - } - - __device__ __forceinline__ int BlockDimY() { -#ifdef PADDLE_WITH_XPU2 - return dim.deal_size_y; -#else - return blockDim.y; -#endif - } - - __device__ __forceinline__ int GridDimX() { -#ifdef PADDLE_WITH_XPU2 - if (ReduceLastDim) { - return dim.split_num_y; - } else { - return dim.split_num_x; - } -#else - return gridDim.x; -#endif - } - - __device__ __forceinline__ int GridDimY() { -#ifdef PADDLE_WITH_XPU2 - if (ReduceLastDim) { - return dim.split_num_x; - } else { - return dim.split_num_y; - } -#else - return gridDim.y; -#endif - } - - __device__ __forceinline__ int GetLoopSize() { -#ifdef PADDLE_WITH_XPU2 - if (ReduceLastDim) { - return dim.deal_size_y; - } else { - return dim.deal_size_x; - } -#else - return 1; -#endif - } -}; - -// when reduce_type == kReduceLastDim this struct will be used -// for higher performance -struct OneDimIndexCal { - explicit OneDimIndexCal(int num) : stride(num) {} - - __device__ inline int operator()(int index) const { return index * stride; } - int stride; -}; - -// reduce config -template -struct ReduceConfig { - ReduceConfig(const std::vector& origin_reduce_dims, - const std::vector& origin_x_dim) - : reduce_dims_origin(origin_reduce_dims), x_dim(origin_x_dim) {} - - // get the parameters of reduceKernel - void Run() { - // step1: update the reduce_dim left_dim and x_dim - 
SetReduceDim(); - - // step2: get the strides of dim for reduceAny and reduceLastDim - SetStrides(); - - // step3: get the type of reduce - SetReduceType(); - - // step4: set the block and grid for launch kernel - SetBlockDim(); - } - - // when should_reduce_again is true, we need malloc temp space for temp data - void SetOutputData(Ty* y_data, const platform::Place& place, - framework::Tensor* tmp) { - if (should_reduce_again) { - output_data = tmp->mutable_data( - framework::make_ddim( - {static_cast(left_num * grid.z * grid.y * sizeof(Ty))}), - place); - } else { - output_data = y_data; - } - } - - private: - // set reduce_dim, left_dim and update x_dim - // eg: x_dim = [2, 4, 6] origin_reduce_dims = [0, 1] - // --SetReduceDim--> x_dim = [8,6], reduce_dim = [0], left_dim = [1] - void SetReduceDim() { - std::set reduce_set; - for (auto e : reduce_dims_origin) { - auto pos = e >= 0 ? e : e + x_dim.size(); - reduce_set.insert(pos); - } - - std::vector reduce_dim_temp(reduce_set.begin(), reduce_set.end()); - std::sort(reduce_dim_temp.begin(), reduce_dim_temp.end()); - - // update reduce_dim and x_dim - std::vector x_new_dim; - - reduce_dim.push_back(reduce_dim_temp[0]); - x_new_dim.push_back(x_dim[0]); - - int idx_reduce = 1; - int num = 0; - - if (reduce_dim_temp.size() > 1) { - for (int i = 1; i < x_dim.size(); i++) { - if ((idx_reduce < reduce_dim_temp.size()) && - (i == reduce_dim_temp[idx_reduce])) { - int result = - reduce_dim_temp[idx_reduce] - reduce_dim[reduce_dim.size() - 1]; - bool is_equal = ((result - num) == 1); - if (is_equal) { - x_new_dim[x_new_dim.size() - 1] *= x_dim[i]; - num++; - } else { - reduce_dim.push_back(reduce_dim_temp[idx_reduce] - num); - x_new_dim.push_back(x_dim[i]); - } - idx_reduce++; - } else { - x_new_dim.push_back(x_dim[i]); - } - } - } else { - x_new_dim = x_dim; - } - - // update x_dim - x_dim = x_new_dim; - std::vector().swap(x_new_dim); - - std::vector reduce_dim_new; - int is_reduced = 0; - for (auto e : reduce_dim) { - is_reduced |= 1 << e; - } - - std::vector().swap(reduce_dim); - - for (int i = 0; i < x_dim.size(); i++) { - if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) { - x_new_dim.push_back(x_dim[i]); - if ((is_reduced >> i) & 1) - reduce_dim_new.push_back(x_new_dim.size() - 1); - } else { - x_new_dim[x_new_dim.size() - 1] *= x_dim[i]; - } - } - - x_dim = x_new_dim; - reduce_dim = reduce_dim_new; - - int x_rank = static_cast(x_dim.size()); - std::set left_set; - - for (int i = 0; i < x_rank; ++i) { - left_set.insert(i); - } - - for (auto e : reduce_dim) { - left_set.erase(e); - } - - left_dim.assign(left_set.begin(), left_set.end()); - - // if the last dim gets involved in reduction - reduce_last_dim = (reduce_dim.back() == x_dim.size() - 1); - } - - // set x_strides, reduce_strides, left_strides for reduceLastDim and reduceAny - // eg: x_dim = [8, 6], reduce_dim = [0], left_dim = [1] - // --SetStrides--> x_strides= [6,1], reduce_strides = [1], - // left_strides = [1] - void SetStrides() { - std::vector idx_dim; - for (int i = 0; i < x_dim.size(); i++) { - idx_dim.push_back(i); - } - - x_strides = details::GetDimStrides(x_dim, idx_dim); - reduce_strides = details::GetDimStrides(x_dim, reduce_dim); - left_strides = details::GetDimStrides(x_dim, left_dim); - reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]]; - - left_num = 1; - if (left_dim.size()) { - left_num = left_strides[0] * x_dim[left_dim[0]]; - } - } - - // get the reduceType - // eg: x_dim = [8, 6] reduce_dim = [0] --> ReduceHigherDim -->reduceFirstDim - // 
x_dim = [8, 6] reduce_dim = [1] --> reduceLastDim - // x_dim = [8] reduce_dim = [0] --> reduceAll - // x_dim = [8, 6, 4, 2] reduce_dim = [0, 2] --> reduceAny - void SetReduceType() { - int rank = x_dim.size(); - int reduce_rank = reduce_dim.size(); - bool is_last_dim = - (rank == 2) && (reduce_rank == 1) && (reduce_dim[0] == 1); - if (rank == reduce_rank || is_last_dim) { - reduce_type = static_cast(ReduceType::kReduceLastDim); - } else if (reduce_rank == 1) { -// ReduceFirstDim and reduceSecondDim -#ifdef PADDLE_WITH_XPU2 - if (reduce_dim[0] == 0) { - reduce_type = static_cast(ReduceType::kReduceHigherDim); - } else { - reduce_type = static_cast(ReduceType::kReduceAny); - } -#else - reduce_type = static_cast(ReduceType::kReduceHigherDim); -#endif - } else { - reduce_type = static_cast(ReduceType::kReduceAny); - } - } - - void SetBlockDimForReduceAny(dim3* block_dim, dim3* grid_dim) { - constexpr int min_reduce_num_per_thread = 16; - constexpr int max_reduce_num_per_thread = 256; - constexpr int max_num_threads = kps::details::kReduceMaxThread; - - // set block size. - // 1. If reduce_last_dim == true, all the threads whose threadIdx.y are same - // will process the reduction for one output. - // The number of output for one block is blockDim.y; - // 2. If reduce_last_dim == false, different threadIdx.x will process - // different reduction and gets the output separately. If it is - // necessary, it should reduce in block y. - // The number of output for one block is blockDim.x; - int block_x, block_y; - int grid_num, reduce_num_per_thread; - if (reduce_last_dim) { - block_x = details::GetBlockDim(reduce_num); - block_y = details::GetBlockDim(left_num); - block_dim->x = block_x; - block_dim->y = - std::min(block_y, static_cast(max_num_threads / block_dim->x)); - grid_num = details::AlignUp(left_num, block_dim->y); - reduce_num_per_thread = details::AlignUp(reduce_num, block_dim->x); - } else { - block_x = details::GetBlockDim(left_num); - block_y = details::GetBlockDim(reduce_num); - block_dim->x = std::min(block_x, 32); - block_dim->y = - std::min(block_y, static_cast(max_num_threads / block_dim->x)); - block_dim->x = - std::min(block_x, static_cast(max_num_threads / block_dim->y)); - grid_num = details::AlignUp(left_num, block_dim->x); - reduce_num_per_thread = details::AlignUp(reduce_num, block_dim->y); - } - int device_id = platform::GetCurrentDeviceId(); - int max_mp = platform::GetGPUMultiProcessors(device_id); - int max_threads_per_mp = - platform::GetGPUMaxThreadsPerMultiProcessor(device_id); - int max_threads = max_threads_per_mp * max_mp; - int num_threads = block_dim->x * block_dim->y; - int max_num_blocks = max_threads / num_threads; - - // set grid size. - // Whether to set grid.y larger than 1, there are 3 following rules: - // 1. The number that each thread process should no less than - // min_reduce_num_per_threadbut no more than max_reduce_num_per_thread; - // 2. It should maximize the utilization of SM. - // So we choose the minimum between input_split_num_1 and input_split_num_3 - // to make each thread process as mush data as possible. Meanwhile, - // the number cannot be larger than max_reduce_num_per_thread, so we - // choose the maximum between the result above and input_split_num_2. 
- int input_split_num_1 = - details::AlignUp(reduce_num_per_thread, min_reduce_num_per_thread); - int input_split_num_2 = - details::AlignUp(reduce_num_per_thread, max_reduce_num_per_thread); - int input_split_num_3 = details::AlignUp(max_num_blocks, grid_num); - - grid_dim->x = grid_num; - grid_dim->y = std::max(std::min(input_split_num_1, input_split_num_3), - input_split_num_2); - // if grid.y > 1, we need launch reduce kernel again. - if (grid_dim->y > 1) { - should_reduce_again = true; - } - } - - // set block and grid for launch kernel - // for ReduceHigherDim: if block is enough -> splite reduce_num - // else init block(32, 1) grid(block_num, 1) - // for others: block(block_num, 1) , grid(left_num, 1) - void SetBlockDimForHigher(dim3* block_dim, dim3* grid_dim) { - int last_dim_num = x_dim.back(); - // update left_num - int grid_z = left_num / last_dim_num; - left_num = last_dim_num; - grid_dim->z = grid_z; - int device_id = platform::GetCurrentDeviceId(); - int max_mp = platform::GetGPUMultiProcessors(device_id); - int max_threads_per_mp = - platform::GetGPUMaxThreadsPerMultiProcessor(device_id); - int max_threads = max_threads_per_mp * max_mp; - // init - int num_block = (max_threads / left_num); - block_dim->x = details::GetBlockDim(left_num); - grid_dim->x = details::AlignUp(left_num, block_dim->x); - blocking_size = reduce_num; - - if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) { - blocking_size = details::GetLastPow2(reduce_num / num_block); - if (blocking_size <= 1) { - blocking_size = details::GetLastPow2(sqrt(reduce_num)); - } else if (blocking_size * 2 < reduce_num) { - blocking_size *= 2; - } - should_reduce_again = true; - grid_dim->y = details::AlignUp(reduce_num, blocking_size); - } - } - - void SetBlockDim() { - // init - int block_num = details::GetBlockDim(reduce_num); - should_reduce_again = false; - dim3 block_dim(block_num, 1, 1); - dim3 grid_dim(left_num, 1, 1); - blocking_size = reduce_num; -#ifdef PADDLE_WITH_XPU2 - if (reduce_last_dim) { - block_dim.x = 128; - block_dim.y = reduce_num; - grid_dim.x = 8; - grid_dim.y = 1; - } else { - block_dim.x = 128; - block_dim.y = left_num; - grid_dim.x = 8; - grid_dim.y = 1; - } -#else - if (reduce_type == ReduceType::kReduceHigherDim) { - SetBlockDimForHigher(&block_dim, &grid_dim); - } else { - SetBlockDimForReduceAny(&block_dim, &grid_dim); - } -#endif - - block = block_dim; - grid = grid_dim; - } - - public: - std::vector reduce_dims_origin; - std::vector reduce_dim; - std::vector x_dim; - std::vector left_dim; - std::vector x_strides; - std::vector left_strides; - std::vector reduce_strides; - - int reduce_type; - int reduce_num; - int left_num; - int blocking_size; - bool should_reduce_again; - bool reduce_last_dim; - - Ty* output_data; - - dim3 block; - dim3 grid; -}; - -// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or -// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this -// function will be used -template -__global__ void ReduceAnyKernel(const Tx* x, Ty* y, ReduceOp reducer, - TransformOp transformer, MPType init, - int reduce_num, int left_num, - bool reduce_last_dim, - const Calculator reduce_index_calculator, - const Calculator left_index_calculator, - const kps::DimConfig dim) { - int input_idx, left_idx, stride; - int block_size = 0; - bool need_store = true; - int loop_left = 0; - int tid = 0; - // the last dim gets involved in reduction - int store_offset = 0; - int stride_left = 0; - if (reduce_last_dim) { - auto block = ReduceIndexMapping(dim); - 
input_idx = block.BlockIdY() * block.BlockDimX(); - left_idx = block.BlockIdX() * block.BlockDimY() + THREAD_ID_Y; - stride = block.GridDimY() * block.BlockDimX(); - block_size = block.BlockDimX(); - need_store = (THREAD_ID_X == 0) && (left_idx < left_num); - store_offset = block.BlockIdY() * left_num + left_idx; - loop_left = min(block.GetLoopSize(), left_num - left_idx); - stride_left = 1; - tid = threadIdx.x; - } else { - auto block = ReduceIndexMapping(dim); - input_idx = block.BlockIdY() * block.BlockDimY(); - left_idx = block.BlockIdX() * block.BlockDimX() + THREAD_ID_X; - stride = block.GridDimY() * block.BlockDimY(); - block_size = block.BlockDimY(); - need_store = (THREAD_ID_Y == 0) && (left_idx < left_num); - loop_left = min(block.GetLoopSize(), left_num - left_idx); - stride_left = block.BlockDimX() * block.GridDimX(); - store_offset = block.BlockIdY() * left_num + left_idx; - tid = threadIdx.y; - } - // calculate the offset, means the addr where each thread really start. - // 1. reduce for each thread - MPType input_compute[REDUCE_VEC_SIZE]; - Tx input_reg[REDUCE_VEC_SIZE]; - for (int i = 0; i < loop_left; i += stride_left) { - int input_offset = left_index_calculator(left_idx + i); - const Tx* input = x + input_offset; - MPType reduce_var = init; - // load REDUCE_VEC_SIZE data once, and then compute - int bound = reduce_num - (REDUCE_VEC_SIZE - 1) * stride; - for (; input_idx + block_size < bound; - input_idx += REDUCE_VEC_SIZE * stride) { - kps::ReadDataReduce, false>( - &input_reg[0], input, input_idx, reduce_index_calculator, 1, - reduce_num, 1, stride, kps::IdentityFunctor(), reduce_last_dim); - kps::ElementwiseUnary( - &input_compute[0], &input_reg[0], transformer); - kps::Reduce( - &reduce_var, &input_compute[0], reducer, reduce_last_dim); - } - - kps::Init(&input_compute[0], init); - kps::ReadDataReduce( - &input_compute[0], input, input_idx, reduce_index_calculator, 1, - reduce_num - input_idx, 1, stride, transformer, reduce_last_dim); - kps::Reduce( - &reduce_var, &input_compute[0], reducer, reduce_last_dim); - - kps::Reduce( - &reduce_var, &reduce_var, reducer, reduce_last_dim); - if (need_store) { - y[store_offset + i] = static_cast(reduce_var); - } - } -} - -template -__global__ void ReduceHigherDimKernel(const Tx* x, Ty* y, ReduceOp reducer, - TransformOp transformer, MPType init, - int reduce_num, int left_num, - int blocking_size, - const kps::DimConfig dim) { - // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this - // function will be used - auto block = ReduceIndexMapping(dim); - int idy = block.BlockIdY() * blocking_size; - int idx = block.BlockIdX() * block.BlockDimX(); - int idz = BLOCK_ID_Z * left_num; - int stride = dim.split_num_x * dim.deal_size_x; - int size = left_num - dim.rem_x; - int loop_size = min(reduce_num - idy, blocking_size); - int store_offset = block.BlockIdY() * left_num + idz * block.GridDimY(); - int block_offset = idy * left_num + idz * reduce_num; - const Tx* input = x + block_offset; - Tx reduce_input; - for (; idx < size; idx += stride) { - MPType reduce_var = init; - MPType reduce_compute = init; - for (int loop_idx = 0; loop_idx < loop_size; ++loop_idx) { - kps::ReadData(&reduce_input, - input + loop_idx * left_num + idx, - block.BlockDimX(), 1, 1, left_num); - kps::ElementwiseUnary( - &reduce_compute, &reduce_input, transformer); - kps::Reduce( - &reduce_var, &reduce_compute, reducer, false); - } - Ty result = static_cast(reduce_var); - kps::WriteData(y + store_offset + idx, &result, - block.BlockDimX()); - 
} - - if (idx < left_num) { - MPType reduce_var = init; - MPType reduce_compute = init; - for (int loop_idx = 0; loop_idx < loop_size; ++loop_idx) { - kps::ReadData(&reduce_input, - input + loop_idx * left_num + idx, - dim.rem_x, 1, 1, left_num); - kps::ElementwiseUnary( - &reduce_compute, &reduce_input, transformer); - kps::Reduce( - &reduce_var, &reduce_compute, reducer, false); - } - Ty result = static_cast(reduce_var); - kps::WriteData(y + store_offset + idx, &result, - dim.rem_x); - } -} - -template -static void LaunchReduceKernel(const Tx* x_data, Ty* y_data, - const ReduceOp& reducer, - const TransformOp& transform, MPType init, - gpuStream_t stream, ReduceConfig config) { - if (config.reduce_type == kReduceLastDim) { - int stride_reduce = 1; - int stride_left = config.reduce_num; - // for higher performance - auto reduce_index_calculator = OneDimIndexCal(stride_reduce); - auto left_index_calculator = OneDimIndexCal(stride_left); - - kps::DimConfig dim = - kps::DimConfig(config.grid.x, config.grid.y, config.grid.z, - config.block.x, config.block.y, 0); - dim.SetRem(config.reduce_num % config.block.x, 0, 0); - -#ifdef PADDLE_WITH_XPU2 - ReduceAnyKernel<<<8, 128, stream>>>( - x_data, config.output_data, reducer, transform, init, config.reduce_num, - config.left_num, config.reduce_last_dim, reduce_index_calculator, - left_index_calculator, dim); -#else - ReduceAnyKernel<<>>( - x_data, config.output_data, reducer, transform, init, config.reduce_num, - config.left_num, config.reduce_last_dim, reduce_index_calculator, - left_index_calculator, dim); -#endif - - } else { - int reduce_rank = config.reduce_strides.size(); - int left_rank = config.left_strides.size(); - auto reduce_index_calculator = - IndexCalculator(reduce_rank, config.reduce_dim, config.reduce_strides, - config.x_strides); - auto left_index_calculator = IndexCalculator( - left_rank, config.left_dim, config.left_strides, config.x_strides); - - kps::DimConfig dim = - kps::DimConfig(config.grid.x, config.grid.y, config.grid.z, - config.block.x, config.block.y, 0); - dim.SetRem(config.reduce_num % config.block.x, 0, 0); - -#ifdef PADDLE_WITH_XPU2 - ReduceAnyKernel<<<8, 128, stream>>>( - x_data, config.output_data, reducer, transform, init, config.reduce_num, - config.left_num, config.reduce_last_dim, reduce_index_calculator, - left_index_calculator, dim); -#else - ReduceAnyKernel<<>>( - x_data, config.output_data, reducer, transform, init, config.reduce_num, - config.left_num, config.reduce_last_dim, reduce_index_calculator, - left_index_calculator, dim); -#endif - } - - if (config.should_reduce_again) { - dim3 block; - dim3 grid; - if (config.reduce_last_dim) { - block = dim3(32, 1, 1); - grid = dim3(details::AlignUp(config.left_num, 32), 1, 1); - } else { - block = dim3(config.block.x, 1, 1); - grid = dim3(config.grid.x, 1, config.grid.z); - } - - auto last_index = OneDimIndexCal(1); - auto first_index = OneDimIndexCal(config.left_num); - kps::DimConfig dim = - kps::DimConfig(grid.x, grid.y, grid.z, block.x, config.grid.y, 0); - dim.SetRem(config.left_num % block.x, 0, 0); -#ifdef PADDLE_WITH_XPU2 - ReduceHigherDimKernel><<<8, 128, stream>>>( - config.output_data, y_data, reducer, kps::IdentityFunctor(), - init, config.grid.y, config.left_num, config.grid.y, dim); -#else - ReduceHigherDimKernel< - Ty, Ty, MPType, ReduceOp, - kps::IdentityFunctor><<>>( - config.output_data, y_data, reducer, kps::IdentityFunctor(), - init, config.grid.y, config.left_num, config.grid.y, dim); -#endif - } -} - -template class ReduceOp, - typename 
TransformOp> -static typename std::enable_if::value, - void>::type -CubTensorReduceFunctorImpl(const Tx* x_data, Ty* y_data, - const TransformOp& transform, int reduce_num, - const platform::Place& place, gpuStream_t stream) { - auto reducer = ReduceOp(); - cub::TransformInputIterator trans_x(x_data, - transform); - size_t temp_storage_bytes = 0; - cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data, - reduce_num, reducer, reducer.initial(), stream); - framework::Tensor tmp; - auto* temp_storage = tmp.mutable_data( - framework::make_ddim({static_cast(temp_storage_bytes)}), place); - cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data, - reduce_num, reducer, reducer.initial(), stream); -} - -template class ReduceOp, - typename TransformOp> -static typename std::enable_if::value, - void>::type -CubTensorReduceFunctorImpl(const Tx* x_data, Ty* y_data, - const TransformOp& transform, int reduce_num, - const platform::Place& place, gpuStream_t stream) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Tx should not be float16 when using cub::DeviceReduce::Reduce().")); -} - template class ReduceOp, typename TransformOp> void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y, const TransformOp& transform, const std::vector& origin_reduce_dims, gpuStream_t stream) { - auto x_dim = framework::vectorize(x.dims()); - auto config = ReduceConfig(origin_reduce_dims, x_dim); - config.Run(); - int numel = x.numel(); - // after config.run() - // SetOutputData for ReduceHigherDim when should_reduce_again is true, - // temp_output should be stored temp_data in output_data space or stored in - // y_data; - framework::Tensor tmp; - auto x_data = x.data(); - auto y_data = y->mutable_data(x.place()); - - if (config.reduce_num == 1) { - auto out_dims = y->dims(); - if (x.type() == y->type()) { - framework::TensorCopy(x, y->place(), y); - y->Resize(out_dims); - } else { - auto* dev_ctx = static_cast( - paddle::platform::DeviceContextPool::Instance().Get(x.place())); - framework::VisitDataType( - static_cast(y->type()), - CastOpFunctor(&x, y, *dev_ctx)); - } - return; - } - - config.SetOutputData(y_data, x.place(), &tmp); - constexpr bool kIsTxFP16 = std::is_same::value; - bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16; - if (use_cub_reduce) { - CubTensorReduceFunctorImpl( - x_data, y_data, transform, config.reduce_num, x.place(), stream); - return; - } + y->mutable_data(x.place()); - using MPType = typename details::MPTypeTrait::Type; - auto reducer = ReduceOp(); - // launch ReduceHigherDimKernel - // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this - // function will be used - // eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1 - // if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / - // 32 - // else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32 - if (config.reduce_type == ReduceType::kReduceHigherDim) { - kps::DimConfig dim = - kps::DimConfig(config.grid.x, config.grid.y, config.grid.z, - config.block.x, config.blocking_size, 0); - dim.SetRem(config.left_num % config.block.x, - config.reduce_num % config.blocking_size, 0); + auto pt_x = paddle::experimental::MakePtenDenseTensor(x); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); -#ifdef PADDLE_WITH_XPU2 - ReduceHigherDimKernel, - TransformOp><<<8, 128, stream>>>( - x_data, config.output_data, reducer, transform, reducer.initial(), - config.reduce_num, config.left_num, config.blocking_size, dim); -#else - 
ReduceHigherDimKernel< - Tx, Ty, MPType, ReduceOp, - TransformOp><<>>( - x_data, config.output_data, reducer, transform, reducer.initial(), - config.reduce_num, config.left_num, config.blocking_size, dim); -#endif - - if (config.should_reduce_again) { - dim3 block = dim3(config.block.x, 1, 1); - dim3 grid = dim3(config.grid.x, 1, config.grid.z); - kps::DimConfig dim2 = - kps::DimConfig(grid.x, grid.y, grid.z, block.x, config.grid.y, 0); - dim2.SetRem(config.left_num % config.block.x, 0, 0); - -#ifdef PADDLE_WITH_XPU2 - ReduceHigherDimKernel< - Ty, Ty, MPType, ReduceOp, - kps::IdentityFunctor><<<8, 128, stream>>>( - config.output_data, y_data, reducer, - kps::IdentityFunctor(config.grid.y), reducer.initial(), - config.grid.y, config.left_num, config.grid.y, dim2); -#else - ReduceHigherDimKernel< - Ty, Ty, MPType, ReduceOp, - kps::IdentityFunctor><<>>( - config.output_data, y_data, reducer, - kps::IdentityFunctor(config.grid.y), reducer.initial(), - config.grid.y, config.left_num, config.grid.y, dim2); -#endif - } - return; - } - - // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or - // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this - // function will be used - LaunchReduceKernel, TransformOp>( - x_data, y_data, reducer, transform, reducer.initial(), stream, config); + pten::kernels::TensorReduceFunctorImpl( + *pt_x.get(), pt_y.get(), transform, origin_reduce_dims, stream); } -template class ReduceOp, - template class TransformOp> -struct TensorReduceFunc { - const framework::Tensor& x; - framework::Tensor* y; - std::vector origin_reduce_dims; - gpuStream_t stream; - int reduce_num; - TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y, - std::vector origin_reduce_dims, int num_reduce, - gpuStream_t stream) - : x(x), - y(y), - origin_reduce_dims(origin_reduce_dims), - reduce_num(num_reduce), - stream(stream) {} - - template - void apply() const { - using MPType = typename details::MPTypeTrait::Type; - TensorReduceFunctorImpl>( - x, y, TransformOp(reduce_num), origin_reduce_dims, stream); - } -}; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index d3b938272e692b8f462f376cb28715e3c7edc3c7..bd09a7951aa2c9e7c94a91bfbd4df38fc286f28b 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -31,7 +31,7 @@ limitations under the License. 
*/ #include "paddle/pten/kernels/hybird/general/reduce_impl.h" #if defined(__HIPCC__) || defined(__NVCC__) -#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" +#include "paddle/pten/kernels/hybird/cuda/reduce/reduce.h" #endif namespace paddle { @@ -700,24 +700,28 @@ class ReduceCudaKernel : public framework::OpKernel { auto out_dtype = context.Attr("out_dtype"); std::vector dims = context.Attr>("dim"); - std::vector reduce_dims = - GetReduceDim(dims, input->dims().size(), reduce_all); - int reduce_num = 1; - for (auto i : reduce_dims) { - reduce_num *= (input->dims())[i]; - } - gpuStream_t stream = context.cuda_device_context().stream(); + auto& dev_ctx = context.cuda_device_context(); + if (out_dtype >= 0) { - framework::VisitDataTypeSmall( - static_cast(out_dtype), - TensorReduceFunc( - *input, output, reduce_dims, reduce_num, stream)); + output->mutable_data( + dev_ctx.GetPlace(), + static_cast(out_dtype)); } else { - using MPType = typename details::MPTypeTrait::Type; - TensorReduceFunctorImpl>( - *input, output, TransformOp(reduce_num), reduce_dims, - stream); + output->mutable_data( + dev_ctx.GetPlace(), + static_cast(input->type())); } + + auto pt_x = paddle::experimental::MakePtenDenseTensor(*input); + auto pt_out = paddle::experimental::MakePtenDenseTensor(*output); + std::vector dims_int64{dims.begin(), dims.end()}; + + auto pt_out_dtype = pten::TransToPtenDataType( + static_cast(out_dtype)); + + pten::Reduce(dev_ctx, *pt_x.get(), reduce_all, + dims_int64, false, pt_out_dtype, + pt_out.get()); } }; #endif diff --git a/paddle/pten/api/ext/dispatch.h b/paddle/pten/api/ext/dispatch.h index 3b40a39af5300df546000b41d9228d89a66ba4a6..07d29ef3e140befe5f28fac92694457921b398a5 100644 --- a/paddle/pten/api/ext/dispatch.h +++ b/paddle/pten/api/ext/dispatch.h @@ -159,6 +159,73 @@ namespace paddle { } \ }() +///////// Floating and Complex and other type Dispatch Marco /////////// + +#define PD_DISPATCH_FLOATING_AND_COMPLEX_AND_1_TYPES( \ + SPECIFIED_TYPE, TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE( \ + NAME, \ + SPECIFIED_TYPE, \ + ::paddle::experimental::DataTypeToCppType::type, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::COMPLEX64, \ + ::paddle::complex64, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::COMPLEX128, \ + ::paddle::complex128, \ + __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + +///////// Floating and Complex and 2 other type Dispatch Marco /////////// + +#define PD_DISPATCH_FLOATING_AND_COMPLEX_AND_2_TYPES( \ + SPECIFIED_TYPE1, SPECIFIED_TYPE2, TYPE, NAME, ...) 
\ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE( \ + NAME, \ + SPECIFIED_TYPE1, \ + ::paddle::experimental::DataTypeToCppType::type, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, \ + SPECIFIED_TYPE2, \ + ::paddle::experimental::DataTypeToCppType::type, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::COMPLEX64, \ + ::paddle::complex64, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::COMPLEX128, \ + ::paddle::complex128, \ + __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + ///////// Floating, Integral and Complex Dispatch Marco /////////// #define PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ diff --git a/paddle/pten/kernels/gpu/math.cu b/paddle/pten/kernels/gpu/math.cu index 4fc89a18da6a4559e8376d215ee80bc89f1b4263..e02403ac426f2c9544ac7300aec4099d0d010417 100644 --- a/paddle/pten/kernels/gpu/math.cu +++ b/paddle/pten/kernels/gpu/math.cu @@ -14,7 +14,6 @@ limitations under the License. */ #include "paddle/pten/kernels/gpu/math.h" -#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" #include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h" #include "paddle/pten/kernels/hybird/cuda/reduce/reduce.h" #include "paddle/pten/kernels/hybird/general/elementwise_functor.h" @@ -35,6 +34,8 @@ namespace cub = hipcub; #include "paddle/pten/core/convert_utils.h" #include "paddle/pten/core/kernel_registry.h" +namespace kps = paddle::operators::kernel_primitives; + namespace pten { /** @@ -64,7 +65,7 @@ void Mean(const GPUContext& dev_ctx, bool reduce_all, DenseTensor* out) { auto out_dtype = x.dtype(); - pten::Reduce( + pten::Reduce( dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); } @@ -85,7 +86,7 @@ void Sum(const GPUContext& dev_ctx, bool reduce_all, DataType out_dtype, DenseTensor* out) { - pten::Reduce( + pten::Reduce( dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); } @@ -95,7 +96,8 @@ using float16 = paddle::platform::float16; using complex64 = ::paddle::platform::complex; using complex128 = ::paddle::platform::complex; -PT_REGISTER_KERNEL(mean, GPU, ALL_LAYOUT, pten::Mean, float, double, bool) {} +PT_REGISTER_KERNEL( + mean, GPU, ALL_LAYOUT, pten::Mean, float, double, bool, float16) {} PT_REGISTER_KERNEL(add, GPU, ALL_LAYOUT, diff --git a/paddle/pten/kernels/hybird/cuda/reduce/reduce.h b/paddle/pten/kernels/hybird/cuda/reduce/reduce.h index 793e8505ec606b875d4bb2594d90c3755f521881..2281cd5ef78ea57de2c9c7c65a5ca455d67cc54c 100644 --- a/paddle/pten/kernels/hybird/cuda/reduce/reduce.h +++ b/paddle/pten/kernels/hybird/cuda/reduce/reduce.h @@ -17,38 +17,16 @@ // CUDA and HIP use same api #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include "paddle/pten/api/ext/dispatch.h" #include "paddle/pten/backends/gpu/gpu_context.h" #include "paddle/pten/common/scalar.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h" - namespace pten { -static inline std::vector GetReduceDim( - const std::vector& dims, int dim_size, bool reduce_all) { - std::vector reduce_dims; - if (reduce_all) { - reduce_dims.resize(dim_size); - int reduce_size = reduce_dims.size(); - for (int i = 0; i < reduce_size; ++i) { - reduce_dims[i] = i; - } - 
} else { - for (auto e : dims) { - PADDLE_ENFORCE_LT(e, - dim_size, - paddle::platform::errors::InvalidArgument( - "ReduceOp: invalid axis, when x_dims is %d, " - "axis[i] should less than x_dims, but got %d.", - dim_size, - e)); - reduce_dims.push_back(e >= 0 ? e : e + dim_size); - } - } - return reduce_dims; -} - -template class ReduceFunctor> +template class ReduceOp, + template class TransformOp> void Reduce(const GPUContext& dev_ctx, const DenseTensor& x, bool reduce_all, @@ -56,20 +34,35 @@ void Reduce(const GPUContext& dev_ctx, bool keep_dim, DataType out_dtype, DenseTensor* out) { - std::vector reduce_dims = - GetReduceDim(dims, x.dims().size(), reduce_all); + std::vector reduce_dims = + pten::kernels::details::GetReduceDim(dims, x.dims().size(), reduce_all); + + int reduce_num = 1; + for (auto i : reduce_dims) { + reduce_num *= (x.dims())[i]; + } gpuStream_t stream = dev_ctx.stream(); if (out_dtype != pten::DataType::UNDEFINED && out_dtype != x.dtype()) { - PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES( - out_dtype, "TensorReduceFunctorImpl", ([&] { - pten::detail::TensorReduceFunctorImpl( - x, out, reduce_dims, stream); + PD_DISPATCH_FLOATING_AND_COMPLEX_AND_2_TYPES( + pten::DataType::INT32, + pten::DataType::INT64, + out_dtype, + "TensorReduceFunctorImpl", + ([&] { + using MPType = typename kps::details::MPTypeTrait::Type; + pten::kernels::TensorReduceFunctorImpl>( + x, out, TransformOp(reduce_num), reduce_dims, stream); })); } else { - pten::detail::TensorReduceFunctorImpl( - x, out, reduce_dims, stream); + using MPType = typename kps::details::MPTypeTrait::Type; + pten::kernels:: + TensorReduceFunctorImpl>( + x, out, TransformOp(reduce_num), reduce_dims, stream); } } diff --git a/paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h b/paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h index bdb883c1df8714fbf18d775a44985b11a0349637..8c2213ca9b3ce994f975bfd8d98c2b877535260e 100644 --- a/paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h +++ b/paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h @@ -31,17 +31,16 @@ namespace cub = hipcub; #include "paddle/fluid/framework/array.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" -#include "paddle/fluid/operators/cast_op.h" #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/fast_divmod.h" +#include "paddle/fluid/string/string_helper.h" -#include "paddle/fluid/operators/kernel_primitives/compute_primitives.h" #include "paddle/pten/api/ext/dispatch.h" -#include "paddle/pten/api/include/tensor.h" +#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/kernels/gpu/utils.h" #include "paddle/pten/kernels/hybird/math/cast_func.h" @@ -49,11 +48,11 @@ namespace cub = hipcub; #define REDUCE_SPLIT_BOUNDARY 512 #define REDUCE_VEC_SIZE 4 -namespace pten { -namespace detail { - namespace kps = paddle::operators::kernel_primitives; +namespace pten { +namespace kernels { + namespace details { static inline int GetLastPow2(int n) { @@ -68,11 +67,11 @@ static inline int GetLastPow2(int n) { static inline int64_t AlignUp(int64_t a, int64_t b) { return (a + b - 1) / b; } // get strides of x_dim, reduce_dim and left_dim for reduceLastDim and reduceAny 
-static inline std::vector GetDimStrides( - const std::vector& dims, const std::vector& idx) { +static inline std::vector GetDimStrides(const std::vector& dims, + const std::vector& idx) { int n = static_cast(idx.size()); - if (n == 0) return std::vector(); - std::vector strides(n); + if (n == 0) return std::vector(); + std::vector strides(n); strides.back() = 1; for (int i = n - 2; i >= 0; --i) { strides[i] = strides[i + 1] * dims[idx[i + 1]]; @@ -133,9 +132,34 @@ static inline paddle::framework::Array VectorToArray( return ret; } +static inline std::vector GetReduceDim(const std::vector& dims, + int dim_size, + bool reduce_all) { + std::vector reduce_dims; + if (reduce_all) { + reduce_dims.resize(dim_size); + int reduce_size = reduce_dims.size(); + for (int i = 0; i < reduce_size; ++i) { + reduce_dims[i] = i; + } + } else { + for (auto e : dims) { + PADDLE_ENFORCE_LT(e, + dim_size, + paddle::platform::errors::InvalidArgument( + "ReduceOp: invalid axis, when x_dims is %d, " + "axis[i] should less than x_dims, but got %d.", + dim_size, + e)); + reduce_dims.push_back(e >= 0 ? e : e + dim_size); + } + } + return reduce_dims; +} + } // namespace details -constexpr int kMaxRank = pten::DDim::kMaxRank; +constexpr int kMaxRank = paddle::framework::DDim::kMaxRank; enum ReduceType { kReduceLastDim = 0x01, // when reduce_dim[0] == x_dim.size() - 1; @@ -145,9 +169,9 @@ enum ReduceType { struct IndexCalculator { IndexCalculator(int dim, - const std::vector& cal_dims, - const std::vector& cal_strides, - const std::vector& full_strides) + const std::vector& cal_dims, + const std::vector& cal_strides, + const std::vector& full_strides) : dim(dim) { dims = details::VectorToArray(cal_dims); strides = details::VectorToArray(full_strides); @@ -275,8 +299,8 @@ struct OneDimIndexCal { // reduce config template struct ReduceConfig { - ReduceConfig(const std::vector& origin_reduce_dims, - const std::vector& origin_x_dim) + ReduceConfig(const std::vector& origin_reduce_dims, + const std::vector& origin_x_dim) : reduce_dims_origin(origin_reduce_dims), x_dim(origin_x_dim) {} // get the parameters of reduceKernel @@ -312,17 +336,17 @@ struct ReduceConfig { // eg: x_dim = [2, 4, 6] origin_reduce_dims = [0, 1] // --SetReduceDim--> x_dim = [8,6], reduce_dim = [0], left_dim = [1] void SetReduceDim() { - std::set reduce_set; + std::set reduce_set; for (auto e : reduce_dims_origin) { auto pos = e >= 0 ? 
e : e + x_dim.size(); reduce_set.insert(pos); } - std::vector reduce_dim_temp(reduce_set.begin(), reduce_set.end()); + std::vector reduce_dim_temp(reduce_set.begin(), reduce_set.end()); std::sort(reduce_dim_temp.begin(), reduce_dim_temp.end()); // update reduce_dim and x_dim - std::vector x_new_dim; + std::vector x_new_dim; reduce_dim.push_back(reduce_dim_temp[0]); x_new_dim.push_back(x_dim[0]); @@ -355,15 +379,15 @@ struct ReduceConfig { // update x_dim x_dim = x_new_dim; - std::vector().swap(x_new_dim); + std::vector().swap(x_new_dim); - std::vector reduce_dim_new; + std::vector reduce_dim_new; int is_reduced = 0; for (auto e : reduce_dim) { is_reduced |= 1 << e; } - std::vector().swap(reduce_dim); + std::vector().swap(reduce_dim); for (int i = 0; i < x_dim.size(); i++) { if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) { @@ -400,7 +424,7 @@ struct ReduceConfig { // --SetStrides--> x_strides= [6,1], reduce_strides = [1], // left_strides = [1] void SetStrides() { - std::vector idx_dim; + std::vector idx_dim; for (int i = 0; i < x_dim.size(); i++) { idx_dim.push_back(i); } @@ -575,13 +599,13 @@ struct ReduceConfig { } public: - std::vector reduce_dims_origin; - std::vector reduce_dim; - std::vector x_dim; - std::vector left_dim; - std::vector x_strides; - std::vector left_strides; - std::vector reduce_strides; + std::vector reduce_dims_origin; + std::vector reduce_dim; + std::vector x_dim; + std::vector left_dim; + std::vector x_strides; + std::vector left_strides; + std::vector reduce_strides; int reduce_type; int reduce_num; @@ -596,15 +620,223 @@ struct ReduceConfig { dim3 grid; }; -template +// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or +// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this +// function will be used +template +__global__ void ReduceAnyKernel(const Tx* x, + Ty* y, + ReduceOp reducer, + TransformOp transformer, + MPType init, + int reduce_num, + int left_num, + bool reduce_last_dim, + const Calculator reduce_index_calculator, + const Calculator left_index_calculator, + const kps::DimConfig dim) { + int input_idx, left_idx, stride; + int block_size = 0; + bool need_store = true; + int loop_left = 0; + int tid = 0; + // the last dim gets involved in reduction + int store_offset = 0; + int stride_left = 0; + if (reduce_last_dim) { + auto block = ReduceIndexMapping(dim); + input_idx = block.BlockIdY() * block.BlockDimX(); + left_idx = block.BlockIdX() * block.BlockDimY() + THREAD_ID_Y; + stride = block.GridDimY() * block.BlockDimX(); + block_size = block.BlockDimX(); + need_store = (THREAD_ID_X == 0) && (left_idx < left_num); + store_offset = block.BlockIdY() * left_num + left_idx; + loop_left = min(block.GetLoopSize(), left_num - left_idx); + stride_left = 1; + tid = threadIdx.x; + } else { + auto block = ReduceIndexMapping(dim); + input_idx = block.BlockIdY() * block.BlockDimY(); + left_idx = block.BlockIdX() * block.BlockDimX() + THREAD_ID_X; + stride = block.GridDimY() * block.BlockDimY(); + block_size = block.BlockDimY(); + need_store = (THREAD_ID_Y == 0) && (left_idx < left_num); + loop_left = min(block.GetLoopSize(), left_num - left_idx); + stride_left = block.BlockDimX() * block.GridDimX(); + store_offset = block.BlockIdY() * left_num + left_idx; + tid = threadIdx.y; + } + // calculate the offset, means the addr where each thread really start. + // 1. 
reduce for each thread + MPType input_compute[REDUCE_VEC_SIZE]; + Tx input_reg[REDUCE_VEC_SIZE]; + for (int i = 0; i < loop_left; i += stride_left) { + int input_offset = left_index_calculator(left_idx + i); + const Tx* input = x + input_offset; + MPType reduce_var = init; + // load REDUCE_VEC_SIZE data once, and then compute + int bound = reduce_num - (REDUCE_VEC_SIZE - 1) * stride; + for (; input_idx + block_size < bound; + input_idx += REDUCE_VEC_SIZE * stride) { + kps::ReadDataReduce, + false>(&input_reg[0], + input, + input_idx, + reduce_index_calculator, + 1, + reduce_num, + 1, + stride, + kps::IdentityFunctor(), + reduce_last_dim); + kps::ElementwiseUnary( + &input_compute[0], &input_reg[0], transformer); + kps::Reduce( + &reduce_var, &input_compute[0], reducer, reduce_last_dim); + } + + kps::Init(&input_compute[0], init); + kps::ReadDataReduce(&input_compute[0], + input, + input_idx, + reduce_index_calculator, + 1, + reduce_num - input_idx, + 1, + stride, + transformer, + reduce_last_dim); + kps::Reduce( + &reduce_var, &input_compute[0], reducer, reduce_last_dim); + + kps::Reduce( + &reduce_var, &reduce_var, reducer, reduce_last_dim); + if (need_store) { + y[store_offset + i] = static_cast(reduce_var); + } + } +} + +template +__global__ void ReduceHigherDimKernel(const Tx* x, + Ty* y, + ReduceOp reducer, + TransformOp transformer, + MPType init, + int reduce_num, + int left_num, + int blocking_size, + const kps::DimConfig dim) { + // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this + // function will be used + auto block = ReduceIndexMapping(dim); + int idy = block.BlockIdY() * blocking_size; + int idx = block.BlockIdX() * block.BlockDimX(); + int idz = BLOCK_ID_Z * left_num; + int stride = dim.split_num_x * dim.deal_size_x; + int size = left_num - dim.rem_x; + int loop_size = min(reduce_num - idy, blocking_size); + int store_offset = block.BlockIdY() * left_num + idz * block.GridDimY(); + int block_offset = idy * left_num + idz * reduce_num; + const Tx* input = x + block_offset; + Tx reduce_input; + for (; idx < size; idx += stride) { + MPType reduce_var = init; + MPType reduce_compute = init; + for (int loop_idx = 0; loop_idx < loop_size; ++loop_idx) { + kps::ReadData(&reduce_input, + input + loop_idx * left_num + idx, + block.BlockDimX(), + 1, + 1, + left_num); + kps::ElementwiseUnary( + &reduce_compute, &reduce_input, transformer); + kps::Reduce( + &reduce_var, &reduce_compute, reducer, false); + } + Ty result = static_cast(reduce_var); + kps::WriteData( + y + store_offset + idx, &result, block.BlockDimX()); + } + + if (idx < left_num) { + MPType reduce_var = init; + MPType reduce_compute = init; + for (int loop_idx = 0; loop_idx < loop_size; ++loop_idx) { + kps::ReadData(&reduce_input, + input + loop_idx * left_num + idx, + dim.rem_x, + 1, + 1, + left_num); + kps::ElementwiseUnary( + &reduce_compute, &reduce_input, transformer); + kps::Reduce( + &reduce_var, &reduce_compute, reducer, false); + } + Ty result = static_cast(reduce_var); + kps::WriteData( + y + store_offset + idx, &result, dim.rem_x); + } +} + +template static void LaunchReduceKernel(const Tx* x_data, Ty* y_data, const ReduceOp& reducer, + const TransformOp& transform, MPType init, gpuStream_t stream, ReduceConfig config) { - using TransformOp = typename ReduceOp::Transformer; - if (config.reduce_type == kReduceLastDim) { int stride_reduce = 1; int stride_left = config.reduce_num; @@ -621,35 +853,33 @@ static void LaunchReduceKernel(const Tx* x_data, dim.SetRem(config.reduce_num % 
config.block.x, 0, 0); #ifdef PADDLE_WITH_XPU2 - paddle::operators::ReduceAnyKernel<<<8, 128, stream>>>( - x_data, - config.output_data, - reducer, - TransformOp(config.reduce_num), - init, - config.reduce_num, - config.left_num, - config.reduce_last_dim, - reduce_index_calculator, - left_index_calculator, - dim); + ReduceAnyKernel<<<8, 128, stream>>>(x_data, + config.output_data, + reducer, + transform, + init, + config.reduce_num, + config.left_num, + config.reduce_last_dim, + reduce_index_calculator, + left_index_calculator, + dim); #else - paddle::operators::ReduceAnyKernel< - Tx, - Ty, - MPType, - ReduceOp, - TransformOp, - OneDimIndexCal><<>>( + ReduceAnyKernel<<>>( x_data, config.output_data, reducer, - TransformOp(config.reduce_num), + transform, init, config.reduce_num, config.left_num, @@ -678,16 +908,16 @@ static void LaunchReduceKernel(const Tx* x_data, dim.SetRem(config.reduce_num % config.block.x, 0, 0); #ifdef PADDLE_WITH_XPU2 - paddle::operators::ReduceAnyKernel<<<8, 128, stream>>>( + ReduceAnyKernel<<<8, 128, stream>>>( x_data, config.output_data, reducer, - TransformOp(config.reduce_num), + transform, init, config.reduce_num, config.left_num, @@ -696,17 +926,16 @@ static void LaunchReduceKernel(const Tx* x_data, left_index_calculator, dim); #else - paddle::operators::ReduceAnyKernel< - Tx, - Ty, - MPType, - ReduceOp, - TransformOp, - IndexCalculator><<>>( + ReduceAnyKernel<<>>( x_data, config.output_data, reducer, - TransformOp(config.reduce_num), + transform, init, config.reduce_num, config.left_num, @@ -734,23 +963,22 @@ static void LaunchReduceKernel(const Tx* x_data, kps::DimConfig(grid.x, grid.y, grid.z, block.x, config.grid.y, 0); dim.SetRem(config.left_num % block.x, 0, 0); #ifdef PADDLE_WITH_XPU2 - paddle::operators::ReduceHigherDimKernel< - Ty, - Ty, - MPType, - ReduceOp, - kps::IdentityFunctor><<<8, 128, stream>>>( + ReduceHigherDimKernel><<<8, 128, stream>>>( config.output_data, y_data, reducer, - kps::IdentityFunctor(config.grid.y), + kps::IdentityFunctor(), init, config.grid.y, config.left_num, config.grid.y, dim); #else - paddle::operators::ReduceHigherDimKernel< + ReduceHigherDimKernel< Ty, Ty, MPType, @@ -759,7 +987,7 @@ static void LaunchReduceKernel(const Tx* x_data, config.output_data, y_data, reducer, - kps::IdentityFunctor(config.grid.y), + kps::IdentityFunctor(), init, config.grid.y, config.left_num, @@ -769,7 +997,68 @@ static void LaunchReduceKernel(const Tx* x_data, } } -static void AsyncCopy(const DenseTensor& src, DenseTensor* dst) { +template class ReduceOp, + typename TransformOp> +static + typename std::enable_if::value, + void>::type + CubTensorReduceFunctorImpl(const Tx* x_data, + Ty* y_data, + const TransformOp& transform, + int reduce_num, + const paddle::platform::Place& place, + gpuStream_t stream) { + auto reducer = ReduceOp(); + cub::TransformInputIterator trans_x(x_data, + transform); + size_t temp_storage_bytes = 0; + cub::DeviceReduce::Reduce(nullptr, + temp_storage_bytes, + trans_x, + y_data, + reduce_num, + reducer, + reducer.initial(), + stream); + + pten::DenseTensor tmp = pten::DenseTensor( + pten::make_intrusive(place), + pten::DenseTensorMeta(pten::DataType::UINT8, + paddle::framework::make_ddim( + {static_cast(temp_storage_bytes)}))); + + auto* temp_storage = tmp.mutable_data(); + + cub::DeviceReduce::Reduce(temp_storage, + temp_storage_bytes, + trans_x, + y_data, + reduce_num, + reducer, + reducer.initial(), + stream); +} + +template class ReduceOp, + typename TransformOp> +static + typename std::enable_if::value, + 
+        void>::type
+    CubTensorReduceFunctorImpl(const Tx* x_data,
+                               Ty* y_data,
+                               const TransformOp& transform,
+                               int reduce_num,
+                               const paddle::platform::Place& place,
+                               gpuStream_t stream) {
+  auto reducer = ReduceOp();
+  cub::TransformInputIterator trans_x(x_data,
+                                      transform);
+  size_t temp_storage_bytes = 0;
+  cub::DeviceReduce::Reduce(nullptr,
+                            temp_storage_bytes,
+                            trans_x,
+                            y_data,
+                            reduce_num,
+                            reducer,
+                            reducer.initial(),
+                            stream);
+
+  pten::DenseTensor tmp = pten::DenseTensor(
+      pten::make_intrusive(place),
+      pten::DenseTensorMeta(pten::DataType::UINT8,
+                            paddle::framework::make_ddim(
+                                {static_cast(temp_storage_bytes)})));
+
+  auto* temp_storage = tmp.mutable_data();
+
+  cub::DeviceReduce::Reduce(temp_storage,
+                            temp_storage_bytes,
+                            trans_x,
+                            y_data,
+                            reduce_num,
+                            reducer,
+                            reducer.initial(),
+                            stream);
+}
+
+template class ReduceOp,
+          typename TransformOp>
+static
+    typename std::enable_if::value,
+        void>::type
+    CubTensorReduceFunctorImpl(const Tx* x_data,
+                               Ty* y_data,
+                               const TransformOp& transform,
+                               int reduce_num,
+                               const paddle::platform::Place& place,
+                               gpuStream_t stream) {
+  PADDLE_THROW(paddle::platform::errors::InvalidArgument(
+      "Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
+}
+
+static void AsyncCopy(const pten::DenseTensor& src, pten::DenseTensor* dst) {
   paddle::platform::DeviceContextPool& pool =
       paddle::platform::DeviceContextPool::Instance();
   const paddle::platform::CUDADeviceContext* dev_ctx;
@@ -788,21 +1077,25 @@ static void AsyncCopy(const DenseTensor& src, DenseTensor* dst) {
 template class ReduceOp>
+          template class ReduceOp,
+          typename TransformOp>
 void TensorReduceFunctorImpl(const pten::DenseTensor& x,
                              pten::DenseTensor* y,
-                             std::vector origin_reduce_dims,
+                             const TransformOp& transform,
+                             const std::vector& origin_reduce_dims,
                              gpuStream_t stream) {
   // Allocate memory
   y->mutable_data();
-  auto x_dim = paddle::framework::vectorize(x.dims());
+
+  auto x_dim = paddle::framework::vectorize(x.dims());
   auto config = ReduceConfig(origin_reduce_dims, x_dim);
   config.Run();
-  int64_t numel = x.numel();
+  int numel = x.numel();
   // after config.run()
   // SetOutputData for ReduceHigherDim when should_reduce_again is true,
   // temp_output should be stored temp_data in output_data space or stored in
   // y_data;
+  pten::DDim tmp_ddim;
   pten::DenseTensor tmp = pten::DenseTensor(
       pten::make_intrusive(y->place()),
@@ -819,56 +1112,27 @@ void TensorReduceFunctorImpl(const pten::DenseTensor& x,
       AsyncCopy(x, y);
       y->Resize(out_dims);
     } else {
-      PD_VISIT_ALL_TYPES(y->dtype(), "CastKernelImpl", ([&] {
-                           pten::math::CastKernelImpl(
-                               *dev_ctx, x, y);
-                         }));
+      PD_VISIT_ALL_TYPES(
+          y->dtype(), "CastKernelImpl", ([&] {
+            pten::math::CastKernelImpl(*dev_ctx, x, y);
+          }));
     }
     return;
   }
   config.SetOutputData(y_data, x.place(), &tmp);
-  bool use_cub_reduce = (config.reduce_num == numel) &&
-                        (!std::is_same::value);
+  constexpr bool kIsTxFP16 = std::is_same::value;
+  bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16;
   if (use_cub_reduce) {
-    // launch CUB::Reduce
-    using TransformOp = typename ReduceOp::Transformer;
-    auto reducer = ReduceOp();
-    cub::TransformInputIterator trans_x(
-        x_data, TransformOp(config.reduce_num));
-    size_t temp_storage_bytes = 0;
-    cub::DeviceReduce::Reduce(nullptr,
-                              temp_storage_bytes,
-                              trans_x,
-                              y_data,
-                              config.reduce_num,
-                              reducer,
-                              reducer.initial(),
-                              stream);
-    // framework::Tensor tmp;
-    pten::DenseTensor tmp = pten::DenseTensor(
-        pten::make_intrusive(x.place()),
-        pten::DenseTensorMeta(pten::DataType::UINT8,
-                              paddle::framework::make_ddim(
-                                  {static_cast(temp_storage_bytes)}),
-                              x.layout()));
-    auto* temp_storage = tmp.mutable_data();
-    cub::DeviceReduce::Reduce(temp_storage,
-                              temp_storage_bytes,
-                              trans_x,
-                              y_data,
-                              config.reduce_num,
-                              reducer,
-                              reducer.initial(),
-                              stream);
-
+    CubTensorReduceFunctorImpl(
+        x_data, y_data, transform, config.reduce_num, x.place(), stream);
     return;
   }
 
-  using MPType =
-      typename paddle::operators::kernel_primitives::details::MPTypeTrait<
-          Ty>::Type;
-  auto reducer = ReduceOp();
+  using MPType = typename kps::details::MPTypeTrait::Type;
+  auto reducer = ReduceOp();
   // launch ReduceHigherDimKernel
   // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
   // function will be used
@@ -877,7 +1141,6 @@ void TensorReduceFunctorImpl(const pten::DenseTensor& x,
   // 32
   // else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
   if (config.reduce_type == ReduceType::kReduceHigherDim) {
-    using TransformOp = typename ReduceOp::Transformer;
     kps::DimConfig dim = kps::DimConfig(config.grid.x,
                                         config.grid.y,
                                         config.grid.z,
@@ -889,31 +1152,30 @@ void TensorReduceFunctorImpl(const pten::DenseTensor& x,
         0);
 #ifdef PADDLE_WITH_XPU2
-    paddle::operators::ReduceHigherDimKernel,
-        TransformOp><<<8, 128, stream>>>(
-        x_data,
-        config.output_data,
-        reducer,
-        TransformOp(config.reduce_num),
-        reducer.initial(),
-        config.reduce_num,
-        config.left_num,
-        config.blocking_size,
-        dim);
+    ReduceHigherDimKernel,
+        TransformOp><<<8, 128, stream>>>(x_data,
+                                         config.output_data,
+                                         reducer,
+                                         transform,
+                                         reducer.initial(),
+                                         config.reduce_num,
+                                         config.left_num,
+                                         config.blocking_size,
+                                         dim);
 #else
-    paddle::operators::ReduceHigherDimKernel<
+    ReduceHigherDimKernel<
         Tx,
         Ty,
        MPType,
-        ReduceOp,
+        ReduceOp,
         TransformOp><<>>(
         x_data,
         config.output_data,
         reducer,
-        TransformOp(config.reduce_num),
+        transform,
         reducer.initial(),
         config.reduce_num,
         config.left_num,
@@ -929,11 +1191,11 @@ void TensorReduceFunctorImpl(const pten::DenseTensor& x,
     dim2.SetRem(config.left_num % config.block.x, 0, 0);
 #ifdef PADDLE_WITH_XPU2
-    paddle::operators::ReduceHigherDimKernel<
+    ReduceHigherDimKernel<
         Ty,
         Ty,
         MPType,
-        ReduceOp,
+        ReduceOp,
         kps::IdentityFunctor><<<8, 128, stream>>>(
         config.output_data,
         y_data,
@@ -945,11 +1207,11 @@ void TensorReduceFunctorImpl(const pten::DenseTensor& x,
         config.grid.y,
         dim2);
 #else
-    paddle::operators::ReduceHigherDimKernel<
+    ReduceHigherDimKernel<
         Ty,
         Ty,
         MPType,
-        ReduceOp,
+        ReduceOp,
         kps::IdentityFunctor><<>>(
         config.output_data,
         y_data,
@@ -968,9 +1230,9 @@ void TensorReduceFunctorImpl(const pten::DenseTensor& x,
   // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
   // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
   // function will be used
-  LaunchReduceKernel>(
-      x_data, y_data, reducer, reducer.initial(), stream, config);
+  LaunchReduceKernel, TransformOp>(
+      x_data, y_data, reducer, transform, reducer.initial(), stream, config);
 }
-}  // namespace detail
+}  // namespace kernels
 }  // namespace pten
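
Note for reviewers: the new CubTensorReduceFunctorImpl keeps cub's usual two-pass workflow, where the first cub::DeviceReduce::Reduce call with a null workspace pointer only reports temp_storage_bytes and the second call performs the actual reduction, and a cub::TransformInputIterator fuses the element-wise TransformOp into the load. The standalone sketch below shows that same pattern outside of Paddle; SquareFunctor, the buffer names, and the problem size are illustrative assumptions rather than code from this patch, and error checking is omitted for brevity.

// Minimal standalone sketch of the two-pass cub::DeviceReduce pattern wrapped
// by CubTensorReduceFunctorImpl. Everything here (SquareFunctor, sizes, names)
// is illustrative and not part of the patch.
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Element-wise transform fused into the reduction via TransformInputIterator,
// playing the role of the TransformOp argument in the patch.
struct SquareFunctor {
  __host__ __device__ float operator()(const float& x) const { return x * x; }
};

int main() {
  const int n = 1 << 20;
  std::vector<float> h_in(n, 1.0f);
  float* d_in = nullptr;
  float* d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  cub::TransformInputIterator<float, SquareFunctor, const float*> trans_in(
      d_in, SquareFunctor());

  // Pass 1: a null workspace pointer only queries temp_storage_bytes.
  void* d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  cub::DeviceReduce::Reduce(d_temp_storage, temp_storage_bytes, trans_in,
                            d_out, n, cub::Sum(), 0.0f);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);

  // Pass 2: the same call with a real workspace performs the reduction.
  cub::DeviceReduce::Reduce(d_temp_storage, temp_storage_bytes, trans_in,
                            d_out, n, cub::Sum(), 0.0f);

  float result = 0.0f;
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum of squares = %f\n", result);  // expect 1048576.0

  cudaFree(d_temp_storage);
  cudaFree(d_out);
  cudaFree(d_in);
  return 0;
}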