diff --git a/paddle/fluid/operators/reduce_ops/reduce_all_op.cu b/paddle/fluid/operators/reduce_ops/reduce_all_op.cu
index 89f3345fcbe42deb572700cb12827d79cb22d3d3..99a5caaad6ab802facaec6a3b5c4c5e2384945d4 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_all_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_all_op.cu
@@ -13,7 +13,9 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/reduce_ops/reduce_all_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 
+// reduce_prod
 REGISTER_OP_CUDA_KERNEL(
-    reduce_all, ops::BoolReduceKernel);
+    reduce_all,
+    ops::ReduceCudaKernel);
diff --git a/paddle/fluid/operators/reduce_ops/reduce_any_op.cu b/paddle/fluid/operators/reduce_ops/reduce_any_op.cu
index c0f94098a351ea9042e44b8550b305bb0f9d74c6..c7eafa2ac8760a3edde56a9f2411c6faaac454f1 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_any_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_any_op.cu
@@ -13,7 +13,10 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/reduce_ops/reduce_any_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 
+// reduce_prod
 REGISTER_OP_CUDA_KERNEL(
-    reduce_any, ops::BoolReduceKernel);
+    reduce_any,
+    ops::ReduceCudaKernel);
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
index 5fad6efdb34961f99215fc93ec7947ed3e452889..45279a224ac8dc6895f35427496de37ec6279198 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
@@ -62,27 +62,6 @@ struct DivideFunctor {
   T n_inv;
 };
 
-static inline std::vector GetReduceDim(const std::vector& dims,
-                                       int dim_size, bool reduce_all) {
-  std::vector reduce_dims;
-  if (reduce_all) {
-    reduce_dims.resize(dim_size);
-    for (int i = 0; i < reduce_dims.size(); ++i) {
-      reduce_dims[i] = i;
-    }
-  } else {
-    for (auto e : dims) {
-      PADDLE_ENFORCE_LT(e, dim_size,
-                        paddle::platform::errors::InvalidArgument(
-                            "ReduceOp: invalid axis, when x_dims is %d, "
-                            "axis[i] should less than x_dims, but got %d.",
-                            dim_size, e));
-      reduce_dims.push_back(e >= 0 ? e : e + dim_size);
-    }
-  }
-  return reduce_dims;
-}
-
 static inline int GetLastPow2(int n) {
   n |= (n >> 1);
   n |= (n >> 2);
@@ -167,8 +146,9 @@ enum ReduceType {
 // reduce config
 template
 struct ReduceConfig {
-  ReduceConfig(std::vector origin_reduce_dims, std::vector x_dim)
-      : reduce_dims_origin(origin_reduce_dims), x_dim(x_dim) {}
+  ReduceConfig(const std::vector& origin_reduce_dims,
+               const std::vector& origin_x_dim)
+      : reduce_dims_origin(origin_reduce_dims), x_dim(origin_x_dim) {}
 
   // get the parameters of reduceKernel
   void Run() {
@@ -530,22 +510,22 @@ __device__ __forceinline__ void ReduceAny(
 
 // module function designed for global function
 template
+          int BlockDim, int Rank, int ReduceRank>
 __device__ __forceinline__ void ReduceModule(
     const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
-    int reduce_num, int left_num, int blocking_size,
+    int reduce_num, int left_num, int blocking_size, int reduce_type,
     paddle::framework::Array x_strides,
     paddle::framework::Array reduce_dim,
     paddle::framework::Array reduce_strides,
     paddle::framework::Array left_dim,
     paddle::framework::Array left_strides) {
   // reduce_rank == 1 && reduce_dim[0] == x_dim.size() - 1
-  if (ReduceType == ReduceType::kReduceLastDim) {
+  if (reduce_type == ReduceType::kReduceLastDim) {
     ReduceLastDim(
         x, y, reducer, transformer, init, reduce_num);
 
   // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
-  } else if (ReduceType == ReduceType::kReduceHigherDim) {
+  } else if (reduce_type == ReduceType::kReduceHigherDim) {
     ReduceHigherDim(
         x, y, reducer, transformer, init, reduce_num, left_num,
         blocking_size);
@@ -558,57 +538,47 @@ __device__ __forceinline__ void ReduceModule(
 }
 
 template
+          int BlockDim, int Rank, int ReduceRank>
 __global__ void ReduceKernelFunction(
     const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
-    int reduce_num, int left_num, int block_size,
+    int reduce_num, int left_num, int block_size, int reduce_type,
     paddle::framework::Array x_strides,
     paddle::framework::Array reduce_dim,
     paddle::framework::Array reduce_strides,
     paddle::framework::Array left_dim,
     paddle::framework::Array left_strides) {
-  ReduceModule(x, y, reducer, transformer, init, reduce_num,
-               left_num, block_size, x_strides, reduce_dim,
-               reduce_strides, left_dim, left_strides);
+  ReduceModule(
+      x, y, reducer, transformer, init, reduce_num, left_num, block_size,
+      reduce_type, x_strides, reduce_dim, reduce_strides, left_dim,
+      left_strides);
 }
 
-template
-static void LaunchKernel(const Tx* x_data, Ty* y_data, const ReduceOp& reducer,
-                         const TransformOp& transformer, Ty init,
-                         gpuStream_t stream, ReduceConfig config) {
-#define CUB_REDUCE_TYPE_CASE(type) \
-  case type: { \
-    constexpr auto kReduceType = type; \
-    ReduceKernelFunction< \
-        Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank, kReduceRank, \
-        kReduceType><<>>( \
-        x_data, config.output_data, reducer, transformer, init, \
-        config.reduce_num, config.left_num, config.blocking_size, \
-        detail::VectorToArray(config.x_strides), \
-        detail::VectorToArray(config.reduce_dim), \
-        detail::VectorToArray(config.reduce_strides), \
-        detail::VectorToArray(config.left_dim), \
-        detail::VectorToArray(config.left_strides)); \
-  } break
-
-  switch (config.reduce_type) {
-    CUB_REDUCE_TYPE_CASE(1);  // reduceLastDim
-    CUB_REDUCE_TYPE_CASE(2);  // ReduceHigherDim
-    CUB_REDUCE_TYPE_CASE(3);  // reduceAny
-  }
+template
+static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
+                               const ReduceOp& reducer, Ty init,
+                               gpuStream_t stream, ReduceConfig config) {
+  using TransformOp = typename ReduceOp::Transformer;
+
+  ReduceKernelFunction<<>>(
+      x_data, config.output_data, reducer, TransformOp(config.reduce_num), init,
+      config.reduce_num, config.left_num, config.blocking_size,
+      config.reduce_type, detail::VectorToArray(config.x_strides),
+      detail::VectorToArray(config.reduce_dim),
+      detail::VectorToArray(config.reduce_strides),
+      detail::VectorToArray(config.left_dim),
+      detail::VectorToArray(config.left_strides));
 
   if (config.should_reduce_again) {
     dim3 block(config.block.x, 1, 1);
     dim3 grid(config.grid.x, 1, config.grid.z);
-    ReduceKernelFunction<
-        Ty, Ty, ReduceOp, detail::IdentityFunctor, 128, kRank, kReduceRank,
-        ReduceType::kReduceHigherDim><<>>(
+    ReduceKernelFunction, 128,
+        kRank, kReduceRank><<>>(
         config.output_data, y_data, reducer,
         detail::IdentityFunctor(config.grid.y), init, config.grid.y,
-        config.left_num, config.grid.y,
+        config.left_num, config.grid.y, ReduceType::kReduceHigherDim,
         detail::VectorToArray(config.x_strides),
        detail::VectorToArray(config.reduce_dim),
        detail::VectorToArray(config.reduce_strides),
@@ -617,12 +587,10 @@ static void LaunchKernel(const Tx* x_data, Ty* y_data, const ReduceOp& reducer,
   }
 }
 
-template
-static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
-                               const ReduceOp& reducer,
-                               const TransformOp& transformer, Ty init,
-                               gpuStream_t stream, ReduceConfig config) {
+template
+static void ReduceKernelImpl(const Tx* x_data, Ty* y_data,
+                             const ReduceOp& reducer, Ty init,
+                             gpuStream_t stream, ReduceConfig config) {
   int reduce_rank = config.reduce_strides.size();
   int rank = config.x_strides.size();
 
@@ -632,11 +600,11 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
     switch (reduce_rank) { __VA_ARGS__; } \
   } break
 
-#define CUB_REDUCE_RANK_CASE(i, ...) \
-  case i: { \
-    constexpr auto kReduceRank = i; \
-    LaunchKernel( \
-        x_data, y_data, reducer, transformer, init, stream, config); \
+#define CUB_REDUCE_RANK_CASE(i, ...) \
+  case i: { \
+    constexpr auto kReduceRank = i; \
+    LaunchReduceKernel( \
+        x_data, y_data, reducer, init, stream, config); \
   } break
 
   detail::CheckReduceRank(reduce_rank, rank);
@@ -671,15 +639,13 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   auto config = ReduceConfig(origin_reduce_dims, x_dim);
   config.Run();  // get the parameters of LaunchReduceKernel
 
-  auto x_data = x.data();
-  auto y_data = y->mutable_data(x.place());
-
   // after config.run()
   // SetOutputData for ReduceHigherDim when should_reduce_again is true,
   // temp_output should be stored temp_data in output_data space or stored in
   // y_data;
   framework::Tensor tmp;
-  config.SetOutputData(y_data, x.place(), &tmp);
+  auto x_data = x.data();
+  auto y_data = y->mutable_data(x.place());
 
   if (config.reduce_num == 1) {
     auto out_dims = y->dims();
@@ -687,6 +653,9 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
     y->Resize(out_dims);
     return;
   }
+
+  config.SetOutputData(y_data, x.place(), &tmp);
+
   using TransformOp = typename ReduceOp::Transformer;
   auto reducer = ReduceOp();
   // launch CUB::Reduce
@@ -708,12 +677,11 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
     return;
   }
 
-#define CUB_BLOCK_DIM_CASE(block_dim) \
-  case block_dim: { \
-    constexpr auto kBlockDim = block_dim; \
-    LaunchReduceKernel, TransformOp>( \
-        x_data, y_data, reducer, TransformOp(config.reduce_num), \
-        reducer.initial(), stream, config); \
+#define CUB_BLOCK_DIM_CASE(block_dim) \
+  case block_dim: { \
+    constexpr auto kBlockDim = block_dim; \
+    ReduceKernelImpl>( \
+        x_data, y_data, reducer, reducer.initial(), stream, config); \
   } break
 
   switch (detail::GetBlockDim(config.reduce_num)) {
@@ -745,30 +713,5 @@ struct TensorReduceFunc {
   }
 };
 
-template class ReduceOp>
-class ReduceCudaKernel : public framework::OpKernel {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    bool reduce_all = context.Attr("reduce_all");
-    const Tensor* input = context.Input("X");
-    Tensor* output = context.Output("Out");
-    auto out_dtype = context.Attr("out_dtype");
-    std::vector dims = context.Attr>("dim");
-
-    std::vector reduce_dims =
-        detail::GetReduceDim(dims, input->dims().size(), reduce_all);
-
-    gpuStream_t stream = context.cuda_device_context().stream();
-    if (out_dtype >= 0) {
-      framework::VisitDataTypeSmall(
-          static_cast(out_dtype),
-          TensorReduceFunc(*input, output, reduce_dims, stream));
-    } else {
-      TensorReduceFunctorImpl(*input, output, reduce_dims,
-                              stream);
-    }
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h
index 390c4d9709a60f1400273062d5da52155e100853..368fedececf53336edc7b67f932408d74994d760 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -23,6 +23,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/cast_op.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
+#if defined(__HIPCC__) || defined(__NVCC__)
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -60,6 +63,27 @@ inline void GetShuffledDim(const DDim& src_dims, DDim* dst_dims,
   }
 }
 
+static inline std::vector GetReduceDim(const std::vector& dims,
+                                       int dim_size, bool reduce_all) {
+  std::vector reduce_dims;
+  if (reduce_all) {
+    reduce_dims.resize(dim_size);
+    int reduce_size = reduce_dims.size();
+    for (int i = 0; i < reduce_size; ++i) {
+      reduce_dims[i] = i;
+    }
+  } else {
+    for (auto e : dims) {
+      PADDLE_ENFORCE_LT(e, dim_size,
+                        paddle::platform::errors::InvalidArgument(
+                            "ReduceOp: invalid axis, when x_dims is %d, "
+                            "axis[i] should less than x_dims, but got %d.",
+                            dim_size, e));
+      reduce_dims.push_back(e >= 0 ? e : e + dim_size);
+    }
+  }
+  return reduce_dims;
+}
 template
 void GetShuffledInput(const framework::ExecutionContext& context,
                       const Tensor* input, Tensor* shuffled_input,
@@ -308,6 +332,7 @@ class BoolReduceKernel : public framework::OpKernel {
     }
   }
 };
+
 template
 class ReduceGradKernel : public framework::OpKernel {
@@ -636,6 +661,33 @@ If reduce_all is true, just reduce along all dimensions and output a scalar.
   virtual std::string GetOpType() const = 0;
 };
 
+#if defined(__HIPCC__) || defined(__NVCC__)
+template class ReduceOp>
+class ReduceCudaKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    bool reduce_all = context.Attr("reduce_all");
+    const Tensor* input = context.Input("X");
+    Tensor* output = context.Output("Out");
+    auto out_dtype = context.Attr("out_dtype");
+    std::vector dims = context.Attr>("dim");
+
+    std::vector reduce_dims =
+        GetReduceDim(dims, input->dims().size(), reduce_all);
+
+    gpuStream_t stream = context.cuda_device_context().stream();
+    if (out_dtype >= 0) {
+      framework::VisitDataTypeSmall(
+          static_cast(out_dtype),
+          TensorReduceFunc(*input, output, reduce_dims, stream));
+    } else {
+      TensorReduceFunctorImpl(*input, output, reduce_dims,
+                              stream);
+    }
+  }
+};
+#endif
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cu b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cu
index 4f259e415d2220aa5a0598a4f18ca5fbfd7cf85b..317a6e1d93c2e8981bd7a54b6e4d64ccd53b9928 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cu
@@ -16,18 +16,8 @@
 #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
 
-// reduce_prod
-#ifdef __HIPCC__
-// Eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h:922
-// do not support double in HIPCC platform (Eigen3 to be fixed)
-REGISTER_OP_CUDA_KERNEL(
-    reduce_prod, ops::ReduceCudaKernel,
-    ops::ReduceCudaKernel,
-    ops::ReduceCudaKernel);
-#else
 REGISTER_OP_CUDA_KERNEL(
     reduce_prod, ops::ReduceCudaKernel,
     ops::ReduceCudaKernel,
     ops::ReduceCudaKernel,
     ops::ReduceCudaKernel);
-#endif
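
For readers tracing the refactor above, the GetReduceDim helper that moves from reduce_op.cu.h into reduce_op.h (so that ReduceCudaKernel can call it) simply normalizes the "dim" attribute into non-negative axis indices, or expands to all axes when reduce_all is set. The sketch below is not part of the patch: NormalizeReduceDims and the main driver are hypothetical stand-ins that mirror the same logic with a plain assert in place of Paddle's PADDLE_ENFORCE_LT machinery.

#include <cassert>
#include <iostream>
#include <vector>

// Standalone mirror of the GetReduceDim logic from the patch:
// - reduce_all == true  -> reduce over every axis in [0, dim_size)
// - reduce_all == false -> keep the requested axes, mapping negative
//   axes (e.g. -1 for the last dimension) into [0, dim_size)
std::vector<int> NormalizeReduceDims(const std::vector<int>& dims,
                                     int dim_size, bool reduce_all) {
  std::vector<int> reduce_dims;
  if (reduce_all) {
    reduce_dims.resize(dim_size);
    for (int i = 0; i < dim_size; ++i) {
      reduce_dims[i] = i;
    }
  } else {
    for (int e : dims) {
      // The patch raises InvalidArgument here; a bare assert stands in for it.
      assert(e < dim_size && "axis must be smaller than the tensor rank");
      reduce_dims.push_back(e >= 0 ? e : e + dim_size);
    }
  }
  return reduce_dims;
}

int main() {
  // For a 4-D tensor, reducing over {-1, 1} normalizes to {3, 1}.
  for (int d : NormalizeReduceDims({-1, 1}, 4, /*reduce_all=*/false)) {
    std::cout << d << ' ';
  }
  std::cout << '\n';

  // reduce_all covers every axis: 0 1 2 3.
  for (int d : NormalizeReduceDims({}, 4, /*reduce_all=*/true)) {
    std::cout << d << ' ';
  }
  std::cout << '\n';
  return 0;
}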