diff --git a/paddle/fluid/operators/broadcast_tensors_op.cu b/paddle/fluid/operators/broadcast_tensors_op.cu
index d670e1b333d411daa8e107356fdba62812a38bee..718e7ce3966217bd4b5162d8df7b5d567ac5bf7b 100644
--- a/paddle/fluid/operators/broadcast_tensors_op.cu
+++ b/paddle/fluid/operators/broadcast_tensors_op.cu
@@ -20,7 +20,7 @@ limitations under the License. */
 #include
 #include
 
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 
 namespace paddle {
 namespace operators {
@@ -28,16 +28,6 @@ namespace operators {
 using framework::Tensor;
 using framework::DDim;
 
-template <typename Tout>
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor() {}
-
-  template <typename U>
-  HOSTDEVICE inline Tout operator()(const U& x) const {
-    return static_cast<Tout>(x);
-  }
-};
-
 template <typename T>
 class CUDABroadcastTensorsGradOpKernel : public framework::OpKernel<T> {
  public:
@@ -99,9 +89,9 @@ class CUDABroadcastTensorsGradOpKernel : public framework::OpKernel<T> {
       } else {
         // reduce_sum implementation on CUDA
         auto stream = context.cuda_device_context().stream();
-        TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
-            *input_tensor, output_tensor, reduce_dims_vec, static_cast<T>(0),
-            cub::Sum(), IdentityFunctor<T>(), stream);
+        TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+            *input_tensor, output_tensor, kps::IdentityFunctor<T>(),
+            reduce_dims_vec, stream);
       }
     }
   }
diff --git a/paddle/fluid/operators/controlflow/compare_all_op.cu b/paddle/fluid/operators/controlflow/compare_all_op.cu
index 8e8f3f01104f50b84d6404ad62a819d41f19e7d1..64a96ae9e8ee12bf2fecd2e3605ef3beb149ebf8 100644
--- a/paddle/fluid/operators/controlflow/compare_all_op.cu
+++ b/paddle/fluid/operators/controlflow/compare_all_op.cu
@@ -15,20 +15,16 @@ limitations under the License. */
 #include
 #include "paddle/fluid/operators/controlflow/compare_all_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 
 namespace paddle {
 namespace operators {
 
 template <typename T>
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor() {}
-  HOSTDEVICE inline T operator()(const T& x) const { return x; }
-};
-
 struct BitwiseAdd {
   // Bitwise add operator, returns a + b
-  template <typename T>
+  inline T initial() { return static_cast<T>(true); }
+
   __host__ __device__ __forceinline__ T operator()(const T& a,
                                                    const T& b) const {
     return a & b;
@@ -67,9 +63,9 @@ class CompareReduceOpKernel
       reduce_dims.resize(tmp.dims().size());
       for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i;
       auto stream = context.cuda_device_context().stream();
-      TensorReduce<bool, bool, BitwiseAdd, IdentityFunctor<bool>>(
-          tmp, z, reduce_dims, true, BitwiseAdd(), IdentityFunctor<bool>(),
-          stream);
+      TensorReduceFunctorImpl<bool, bool, BitwiseAdd,
+                              kps::IdentityFunctor<bool>>(
+          tmp, z, kps::IdentityFunctor<bool>(), reduce_dims, stream);
     }
   }
 };
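Review note: the two files above stop defining their own per-file `IdentityFunctor` and instead use the functors shipped with the unified reduce path, and `BitwiseAdd` now carries its own `initial()` value because the new entry point no longer takes an explicit init argument. The sketch below shows the functor contract the rewritten call sites appear to assume; it is illustrative only (the namespace `kps_sketch` is made up here), and the real definitions live under kernel_primitives and may differ in detail.

// Sketch only: assumed shape of the kps functors used by the new call sites.
namespace kps_sketch {

template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
  // Element-wise transform applied to each input before the reduction.
  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(x);
  }
};

template <typename T>
struct AddFunctor {
  // Identity element of the reduction, replacing the old explicit init value.
  inline T initial() { return static_cast<T>(0); }
  // Binary reduction step.
  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
    return a + b;
  }
};

}  // namespace kps_sketch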
diff --git a/paddle/fluid/operators/fused/attn_bias_add.cu.h b/paddle/fluid/operators/fused/attn_bias_add.cu.h
index 990ac8dbc81216e9d1a7d92891676482b7b9ab15..7666ea7aee23cb1351e9af243d6a9ecd870b8568 100644
--- a/paddle/fluid/operators/fused/attn_bias_add.cu.h
+++ b/paddle/fluid/operators/fused/attn_bias_add.cu.h
@@ -33,7 +33,7 @@ namespace cub = hipcub;
 #include "paddle/fluid/operators/elementwise/elementwise_functor.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/platform/fast_divmod.h"
 
 namespace paddle {
@@ -41,8 +41,6 @@ namespace operators {
 
 #define MAX_INPUT_NUM 2
 
-namespace kps = paddle::operators::kernel_primitives;
-
 template <typename T>
 using CudnnDataType = platform::CudnnDataType<T>;
 template <typename T>
diff --git a/paddle/fluid/operators/kron_op.h b/paddle/fluid/operators/kron_op.h
index ea2050fe8e61e7d36c40760e66eb6b3def8d3246..6039d8c624052ec4bde7704237548b98dbdaa276 100644
--- a/paddle/fluid/operators/kron_op.h
+++ b/paddle/fluid/operators/kron_op.h
@@ -19,7 +19,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/for_range.h"
 #if defined(__NVCC__) || defined(__HIPCC__)
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "thrust/device_vector.h"
 #endif
 
@@ -237,15 +237,6 @@ struct KronGradElemFunctor<platform::complex<T>> {
   const int ndims_;
 };
 
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor() {}
-
-  template <typename U>
-  HOSTDEVICE inline U operator()(const U& x) const {
-    return x;
-  }
-};
-
 template <typename DeviceContext, typename T>
 struct KronGradOpFunctor {
   void operator()(const DeviceContext& dev_ctx, const framework::Tensor& dout,
@@ -314,14 +305,12 @@ struct KronGradOpFunctor {
 #if defined(__NVCC__) || defined(__HIPCC__)
     auto stream = dev_ctx.stream();  // it is a cuda device_context
     if (dx) {
-      TensorReduce<T, T, cub::Sum, IdentityFunctor>(
-          dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor(),
-          stream);
+      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+          dout_x, dx, kps::IdentityFunctor<T>(), {1}, stream);
     }
     if (dy) {
-      TensorReduce<T, T, cub::Sum, IdentityFunctor>(
-          dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor(),
-          stream);
+      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+          dout_y, dy, kps::IdentityFunctor<T>(), {1}, stream);
     }
 #else
     auto* place = dev_ctx.eigen_device();
diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h
index ee95881caa9c53b07dd52e101fb924c7566ac6f6..fc0f1416cc13896d24406fe471504d7badad7a61 100644
--- a/paddle/fluid/operators/matmul_v2_op.h
+++ b/paddle/fluid/operators/matmul_v2_op.h
@@ -31,7 +31,7 @@ limitations under the License. */
 #include "paddle/pten/include/linalg.h"
 
 #if defined(__NVCC__) || defined(__HIPCC__)
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #endif
 
 namespace paddle {
@@ -39,24 +39,14 @@ namespace operators {
 
 using framework::Tensor;
 
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor() {}
-
-  template <typename U>
-  HOSTDEVICE inline U operator()(const U& x) const {
-    return x;
-  }
-};
-
 template <typename DeviceContext, typename T>
 void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output,
                             const std::vector<int>& reduce_dims,
                             const paddle::framework::ExecutionContext& ctx) {
 #if defined(__NVCC__) || defined(__HIPCC__)
   auto stream = ctx.cuda_device_context().stream();
-  TensorReduce<T, T, cub::Sum, IdentityFunctor>(*input, output, reduce_dims,
-                                                static_cast<T>(0), cub::Sum(),
-                                                IdentityFunctor(), stream);
+  TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+      *input, output, kps::IdentityFunctor<T>(), reduce_dims, stream);
 #else
   ReduceKernelFunctor<DeviceContext, T, ops::SumFunctor>(
       input, output, reduce_dims, true, false, ctx)
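Review note: every rewritten call site in this patch funnels into the same `TensorReduceFunctorImpl` entry point from reduce_op.cu.h. Its declaration is not part of the diff, so the following is only an inferred sketch of the signature implied by the calls above (template order, transform argument, reduce dims, stream); treat the exact parameter list as an assumption rather than the header's actual interface.

// Inferred sketch, not taken from reduce_op.cu.h:
template <typename Tx, typename Ty, template <typename> class ReduceOp,
          typename TransformOp>
void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
                             const TransformOp& transform,
                             const std::vector<int>& origin_reduce_dims,
                             gpuStream_t stream);

// Typical use, mirroring ReduceSumForMatmulGrad above:
//   TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
//       *input, output, kps::IdentityFunctor<T>(), reduce_dims, stream);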
#include "gtest/gtest.h" -#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" namespace paddle { namespace operators { -namespace detail { +namespace details { TEST(test_reduce_rank_check, all) { using EnforceNotMet = paddle::platform::EnforceNotMet; @@ -39,15 +39,15 @@ TEST(test_reduce_rank_check, all) { } if (is_valid) { - CheckReduceRankIsValid(reduce_rank, rank); + CheckReduceRank(reduce_rank, rank); } else { - ASSERT_THROW(CheckReduceRankIsValid(reduce_rank, rank), + ASSERT_THROW(CheckReduceRank(reduce_rank, rank), paddle::platform::EnforceNotMet); } } } } -} // namespace detail +} // namespace details } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/reduce_ops/cub_reduce.h b/paddle/fluid/operators/reduce_ops/cub_reduce.h deleted file mode 100644 index 0aab680e13dc1e570f39773cea6370a31bf1ccea..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/reduce_ops/cub_reduce.h +++ /dev/null @@ -1,468 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include - -#ifdef __NVCC__ -#include "cub/cub.cuh" // NOLINT -#endif - -#ifdef __HIPCC__ -#include -namespace cub = hipcub; -#endif - -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/amp/fp16_type_traits.h" - -namespace paddle { -namespace operators { - -namespace detail { -template -struct Array { - public: - HOSTDEVICE inline Array() {} - - HOSTDEVICE inline T& operator[](size_t index) { return data_[index]; } - - HOSTDEVICE inline const T& operator[](size_t index) const { - return data_[index]; - } - - HOSTDEVICE constexpr inline size_t size() const { return ElementCount; } - - template - static inline Array From(const VectorLikeType& vec) { - PADDLE_ENFORCE_EQ(vec.size(), ElementCount, - platform::errors::InvalidArgument( - "Cub reduce Array: size not match. 
Received " - "vec.size() %d != ElementCount %d.", - vec.size(), ElementCount)); - size_t n = static_cast(vec.size()); - Array ret; - for (size_t i = 0; i < n; ++i) ret[i] = vec[i]; - return ret; - } - - private: - T data_[ElementCount]; -}; - -// reduce the 1d array to one element -template -__global__ void ReduceKernel1D(const Tx* x, Ty* y, ReduceOp reducer, - TransformOp transformer, MPType init, - int reduce_num) { - int thread_id = blockIdx.x * blockDim.x + threadIdx.x; - - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - MPType local_data = init; - for (int i = thread_id; i < reduce_num; i += gridDim.x * blockDim.x) { - local_data = static_cast( - reducer(local_data, static_cast(transformer(x[i])))); - } - __syncthreads(); - - local_data = BlockReduce(temp_storage).Reduce(local_data, reducer); - - if (threadIdx.x == 0) { - y[blockIdx.x] = static_cast(local_data); - } -} - -// reduce the last axis of 2d array -template -__global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer, - TransformOp transformer, MPType init, - int reduce_num) { - __shared__ - typename cub::BlockReduce::TempStorage temp_storage; - int idx_x = blockIdx.x * reduce_num; - int idx_y = threadIdx.x; - MPType reduce_var = init; - for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim) - reduce_var = - reducer(reduce_var, static_cast(transformer(x[idx_x + idx_y]))); - __syncthreads(); - - reduce_var = cub::BlockReduce(temp_storage) - .Reduce(reduce_var, reducer); - - if (threadIdx.x == 0) { - y[blockIdx.x] = static_cast(reduce_var); - } -} - -template -__global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer, - TransformOp transformer, MPType init, - int reduce_num, Array x_strides, - Array reduce_dim, - Array reduce_strides, - Array left_dim, - Array left_strides) { - __shared__ - typename cub::BlockReduce::TempStorage temp_storage; - Array sub_index; - int left_idx = blockIdx.x; - for (int i = 0; i < Rank - ReduceRank; ++i) { - sub_index[left_dim[i]] = left_idx / left_strides[i]; - left_idx %= left_strides[i]; - } - - int reduce_idx = threadIdx.x; - for (int j = 0; j < ReduceRank; ++j) { - sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j]; - reduce_idx %= reduce_strides[j]; - } - - int idx_x = 0; - for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]); - MPType reduce_var = static_cast(transformer(x[idx_x])); - - for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) { - int reduce_idx = i; - for (int j = 0; j < ReduceRank; ++j) { - sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j]; - reduce_idx %= reduce_strides[j]; - } - - int idx_x = 0; - for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]); - reduce_var = static_cast( - reducer(reduce_var, static_cast(transformer(x[idx_x])))); - } - __syncthreads(); - - reduce_var = cub::BlockReduce(temp_storage) - .Reduce(reduce_var, reducer); - - if (threadIdx.x == 0) { - y[blockIdx.x] = static_cast(reduce_var); - } -} - -static inline std::vector GetStrides(const std::vector& dims) { - int n = static_cast(dims.size()); - if (n == 0) return std::vector(); - std::vector strides(n); - strides.back() = 1; - for (int i = n - 2; i >= 0; --i) { - strides[i] = strides[i + 1] * dims[i + 1]; - } - return strides; -} - -static inline std::vector GetStrides(const std::vector& dims, - const std::vector& idx) { - int n = static_cast(idx.size()); - if (n == 0) return std::vector(); - std::vector strides(n); - strides.back() = 1; - for (int i = n 
-    strides[i] = strides[i + 1] * dims[idx[i + 1]];
-  }
-  return strides;
-}
-
-#ifdef __HIPCC__
-constexpr int kMaxBlockDim = 256;
-#else
-constexpr int kMaxBlockDim = 512;
-#endif
-
-static inline int GetDesiredBlockDim(int block_dim) {
-  return block_dim >= kMaxBlockDim
-             ? kMaxBlockDim
-             : (1 << static_cast<int>(std::log2(block_dim)));
-}
-
-static inline void CheckReduceRankIsValid(int reduce_rank, int rank) {
-  if (rank % 2 == 0) {
-    PADDLE_ENFORCE_EQ(reduce_rank, rank / 2,
-                      platform::errors::InvalidArgument(
-                          "ReduceOp: invalid reduce rank. When rank = %d, "
-                          "reduce_rank must be %d, but got %d.",
-                          rank, rank / 2, reduce_rank));
-  } else {
-    auto lower_rank = (rank - 1) / 2;
-    auto upper_rank = (rank + 1) / 2;
-    PADDLE_ENFORCE_EQ(
-        reduce_rank == lower_rank || reduce_rank == upper_rank, true,
-        platform::errors::InvalidArgument(
-            "ReduceOp: invalid reduce rank. When rank = %d, reduce_rank "
-            "must be %d or %d, but got %d.",
-            rank, lower_rank, upper_rank, reduce_rank));
-  }
-}
-
-template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
-          typename TransformOp, int BlockDim>
-typename std::enable_if<!std::is_same<Tx, platform::float16>::value,
-                        void>::type
-LaunchCubReduceKernel(const Tx* x_data, Ty* y_data,
-                      const platform::Place& place, const ReduceOp& reducer,
-                      const TransformOp& transformer, const MPType& init,
-                      int reduce_num, gpuStream_t stream) {
-  cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
-                                                                  transformer);
-  size_t temp_storage_bytes = 0;
-  cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
-                            reduce_num, reducer, init, stream);
-  framework::Tensor tmp;
-  auto* temp_storage = tmp.mutable_data<uint8_t>(
-      framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}), place);
-  cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
-                            reduce_num, reducer, init, stream);
-}
-
-template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
-          typename TransformOp, int BlockDim>
-typename std::enable_if<std::is_same<Tx, platform::float16>::value,
-                        void>::type
-LaunchCubReduceKernel(const Tx* x_data, Ty* y_data,
-                      const platform::Place& place, const ReduceOp& reducer,
-                      const TransformOp& transformer, const MPType& init,
-                      int reduce_num, gpuStream_t stream) {
-  int element_per_block = BlockDim * 10;
-  int block_per_grid = (reduce_num + element_per_block - 1) / element_per_block;
-
-  framework::Tensor tmp;
-  auto* temp_storage = tmp.mutable_data<MPType>(
-      framework::make_ddim(
-          {static_cast<int64_t>(block_per_grid * sizeof(MPType))}),
-      place);
-
-  // each block reduce number to interim result
-  ReduceKernel1D<Tx, MPType, MPType, ReduceOp, TransformOp,
-                 BlockDim><<<block_per_grid, BlockDim, 0, stream>>>(
-      x_data, temp_storage, reducer, transformer, init, reduce_num);
-  // reduce all number to final result
-  ReduceKernel1D<MPType, MPType, Ty, ReduceOp, TransformOp,
-                 BlockDim><<<1, BlockDim, 0, stream>>>(
-      temp_storage, y_data, reducer, transformer, init, block_per_grid);
-}
-
-template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
-          typename TransformOp>
-static void TensorReduceImpl(
-    const Tx* x_data, Ty* y_data, const platform::Place& place,
-    const ReduceOp& reducer, const TransformOp& transformer, const Ty& init,
-    int left_num, int reduce_num, const std::vector<int>& x_strides,
-    const std::vector<int>& reduce_dim, const std::vector<int>& reduce_strides,
-    const std::vector<int>& left_dim, const std::vector<int>& left_strides,
-    gpuStream_t stream) {
-  using MPType = typename details::MPTypeTrait<Ty>::Type;
-  MPType init_mp = static_cast<MPType>(init);
-
-#define CUB_RANK_CASE(i, ...)                          \
-  case i: {                                            \
-    constexpr auto kRank = i;                          \
-    switch (reduce_rank) { __VA_ARGS__; }              \
-  } break
-
-#define CUB_REDUCE_RANK_CASE(i, ...)                   \
-  case i: {                                                            \
-    constexpr auto kReduceRank = i;                                    \
-    ReduceKernel<Tx, MPType, Ty, ReduceOp, TransformOp, BlockDim, kRank, \
-                 kReduceRank><<<left_num, BlockDim, 0, stream>>>(      \
-        x_data, y_data, reducer, transformer, init_mp, reduce_num,     \
-        Array<int, kRank>::From(x_strides),                            \
-        Array<int, kReduceRank>::From(reduce_dim),                     \
-        Array<int, kReduceRank>::From(reduce_strides),                 \
-        Array<int, kRank - kReduceRank>::From(left_dim),               \
-        Array<int, kRank - kReduceRank>::From(left_strides));          \
-  } break
-
-  int rank = x_strides.size();
-  int reduce_rank = reduce_strides.size();
-  if (rank == reduce_rank) {
-    LaunchCubReduceKernel<Tx, MPType, Ty, ReduceOp, TransformOp, BlockDim>(
-        x_data, y_data, place, reducer, transformer, init_mp, reduce_num,
-        stream);
-    return;
-  }
-  if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
-    ReduceKernel2D<Tx, MPType, Ty, ReduceOp, TransformOp,
-                   BlockDim><<<left_num, BlockDim, 0, stream>>>(
-        x_data, y_data, reducer, transformer, init_mp, reduce_num);
-    return;
-  }
-  /*
-  if (rank == 3 && reduce_rank == 1 && reduce_dim[0] == 1) {
-    // TODO(liangdun): we can optimize 3d case which the 2nd axis is reduced.
-    // Currently, it is handled by code below, but inefficient
-    return;
-  }
-  */
-
-  /**
-   * Since we have combined the adjacent reduce dimensions inside TensorReduce,
-   * The reduce ranks and non-reduce ranks must be interleaving. That is to say,
-   * the rank of Tensor must be `1010...` or `0101...` where 1 represents that
-   * the dimension is about to be reduced.
-   *
-   * Therefore,
-   * If rank is odd, only need to switch-case (rank - 1)/2 and (rank + 1)/2.
-   * If rank is even, only need to switch-case rank/2.
-   *
-   * The total switch-case numbers reduce from 1+2+3+...+8=36 to (1+2)*4=12,
-   * it would speed up compiling and make the binary size lower.
-   */
-  CheckReduceRankIsValid(reduce_rank, rank);
-  switch (rank) {
-    CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););
-
-    CUB_RANK_CASE(3, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2););
-
-    CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(2););
-
-    CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(2); CUB_REDUCE_RANK_CASE(3););
-
-    CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(3););
-
-    CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4););
-
-    CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(4););
-
-    CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(4); CUB_REDUCE_RANK_CASE(5););
-  }
-
-#undef CUB_REDUCE_RANK_CASE
-#undef CUB_RANK_CASE
-}
-
-}  // namespace detail
-
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
-void TensorReduce(const framework::Tensor& x, framework::Tensor* y,
-                  std::vector<int> origin_reduce_dims, const Ty& init,
-                  const ReduceOp& reducer, const TransformOp& transformer,
-                  gpuStream_t stream) {
-  auto x_dim = framework::vectorize<int>(x.dims());
-  std::vector<int> new_x_dim, new_reduce_dims;
-  int is_reduced = 0;
-  for (auto e : origin_reduce_dims) {
-    auto pos = e >= 0 ? e : e + x_dim.size();
-    is_reduced |= 1 << e;
-  }
-  for (int i = 0; i < x_dim.size(); i++) {
-    if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) {
-      new_x_dim.push_back(x_dim[i]);
-      if ((is_reduced >> i) & 1)
-        new_reduce_dims.push_back(new_x_dim.size() - 1);
-    } else {
-      new_x_dim[new_x_dim.size() - 1] *= x_dim[i];
-    }
-  }
-  x_dim = new_x_dim;
-  origin_reduce_dims = new_reduce_dims;
-  int x_rank = static_cast<int>(x_dim.size());
-  std::set<int> left_set, reduce_set;
-  for (int i = 0; i < x_rank; ++i) left_set.insert(i);
-
-  for (auto e : origin_reduce_dims) {
-    left_set.erase(e);
-    reduce_set.insert(e);
-  }
-
-  std::vector<int> reduce_dim(reduce_set.begin(), reduce_set.end());
-  std::vector<int> left_dim(left_set.begin(), left_set.end());
-
-  std::vector<int> x_strides = detail::GetStrides(x_dim);
-  std::vector<int> reduce_strides = detail::GetStrides(x_dim, reduce_dim);
-  std::vector<int> left_strides = detail::GetStrides(x_dim, left_dim);
-  int reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];
-  int left_num = 1;
-  if (left_dim.size()) left_num = left_strides[0] * x_dim[left_dim[0]];
-
-  std::vector<int> y_dim(left_dim.size());
-  for (int i = 0; i < left_dim.size(); ++i) {
-    y_dim[i] = x_dim[left_dim[i]];
-  }
-  auto x_data = x.data<Tx>();
-  auto y_data = y->mutable_data<Ty>(x.place());
-  if (reduce_num == 1) {
-    auto out_dims = y->dims();
-    framework::TensorCopy(x, y->place(), y);
-    y->Resize(out_dims);
-    return;
-  }
-
-#define CUB_BLOCK_DIM_CASE(block_dim)                                    \
-  case block_dim: {                                                      \
-    constexpr auto kBlockDim = block_dim;                                \
-    detail::TensorReduceImpl<Tx, Ty, block_dim, ReduceOp, TransformOp>(  \
-        x_data, y_data, x.place(), reducer, transformer, init, left_num, \
-        reduce_num, x_strides, reduce_dim, reduce_strides, left_dim,     \
-        left_strides, stream);                                           \
-  } break
-
-  switch (detail::GetDesiredBlockDim(reduce_num)) {
-    CUB_BLOCK_DIM_CASE(512);
-    CUB_BLOCK_DIM_CASE(256);
-    CUB_BLOCK_DIM_CASE(128);
-    CUB_BLOCK_DIM_CASE(64);
-    CUB_BLOCK_DIM_CASE(32);
-    CUB_BLOCK_DIM_CASE(16);
-    CUB_BLOCK_DIM_CASE(8);
-    CUB_BLOCK_DIM_CASE(4);
-    CUB_BLOCK_DIM_CASE(2);
-  }
-#undef CUB_BLOCK_DIM_CASE
-}
-
-template <typename Tx, typename ReduceOp,
-          template <typename> class TransformOp>
-struct TensorReduceFunctor {
-  const framework::Tensor& x;
-  framework::Tensor* y;
-  std::vector<int> origin_reduce_dims;
-  const double& init;
-  const ReduceOp& reducer;
-  gpuStream_t stream;
-  TensorReduceFunctor(const framework::Tensor& x, framework::Tensor* y,
-                      std::vector<int> origin_reduce_dims, const double& init,
-                      const ReduceOp& reducer, gpuStream_t stream)
-      : x(x),
-        y(y),
-        origin_reduce_dims(origin_reduce_dims),
-        init(init),
-        reducer(reducer),
-        stream(stream) {}
-
-  template <typename Ty>
-
-  void apply() const {
-    const Ty& init_cast = static_cast<Ty>(init);
-    TensorReduce<Tx, Ty, ReduceOp, TransformOp<Ty>>(x, y, origin_reduce_dims,
-                                                    init_cast, reducer,
-                                                    TransformOp<Ty>(), stream);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" #include "paddle/fluid/operators/reduce_ops/frobenius_norm_op.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" template using CUDAFrobeniusNormKernel = diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu index 419b8ce276526ba225782660b6c096284ae1d416..c629663b19ebd7f42f3a16e69bd4b46784ff67dd 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" template diff --git a/paddle/fluid/operators/solve_op.h b/paddle/fluid/operators/solve_op.h index ec72269f697e87b5cb957682312d0a1fa7a8d506..0acef78484cd32362a4f7f4de9402b6a4800f5ce 100644 --- a/paddle/fluid/operators/solve_op.h +++ b/paddle/fluid/operators/solve_op.h @@ -26,7 +26,7 @@ limitations under the License. */ #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" #include "paddle/fluid/operators/squeeze_op.h" #if defined(__NVCC__) || defined(__HIPCC__) -#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #endif #define MAX_RANK_SUPPORTED 6 @@ -39,24 +39,14 @@ using framework::To32BitIndex; constexpr int kMULMKLDNNINT8 = 1; -struct IdentityFunctor { - HOSTDEVICE explicit inline IdentityFunctor() {} - - template - HOSTDEVICE inline U operator()(const U& x) const { - return x; - } -}; - template void ReduceSumForSolve(const Tensor* input, Tensor* output, const std::vector& reduce_dims, bool keep_dim, const paddle::framework::ExecutionContext& ctx) { #if defined(__NVCC__) || defined(__HIPCC__) auto stream = ctx.cuda_device_context().stream(); - TensorReduce(*input, output, reduce_dims, - static_cast(0), cub::Sum(), - IdentityFunctor(), stream); + TensorReduceFunctorImpl>( + *input, output, kps::IdentityFunctor(), reduce_dims, stream); #else ReduceKernelFunctor( input, output, reduce_dims, keep_dim, false, ctx)