From 480b284c21ce2ef8fab8d0cc4cf7f87c1ad390e9 Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Tue, 22 Jun 2021 13:47:04 +0800
Subject: [PATCH] modified reduce_max, reduce_min, reduce_prod to
 higher_performance implementation. (#32974)

---
 .../operators/reduce_ops/reduce_functor_op.h  |  84 +++-
 .../operators/reduce_ops/reduce_max_op.cu     |  20 +-
 .../operators/reduce_ops/reduce_min_op.cu     |  20 +-
 .../{reduce_op.cuh => reduce_op.cu.h}         | 374 ++++++++++++------
 .../operators/reduce_ops/reduce_prod_op.cu    |  28 +-
 5 files changed, 349 insertions(+), 177 deletions(-)
 rename paddle/fluid/operators/reduce_ops/{reduce_op.cuh => reduce_op.cu.h} (64%)

diff --git a/paddle/fluid/operators/reduce_ops/reduce_functor_op.h b/paddle/fluid/operators/reduce_ops/reduce_functor_op.h
index f4ea18edb2a..0f02be21cc9 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_functor_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_functor_op.h
@@ -13,46 +13,98 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-#include <cmath>
-#include <string>
-
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
-#include "paddle/fluid/platform/device_context.h"
+#include <cmath>
+#include <limits>
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/platform/hostdevice.h"
-#include "paddle/fluid/platform/macros.h"
+#ifdef __HIPCC__
+#include <hip/hip_runtime.h>
+#endif
 
 namespace paddle {
 namespace operators {
 
-template <typename T>
+template <typename Tx, typename Ty = Tx>
 struct CustomMin {
-  __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+  using Transformer = detail::IdentityFunctor<Tx>;
+
+  inline Ty initial() {
+    return static_cast<Ty>(std::numeric_limits<Ty>::max());
+  }
+
+  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
     return (b < a) ? b : a;
   }
 };
 
-template <typename T>
+template <typename Tx, typename Ty = Tx>
 struct CustomMax {
-  __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+  using Transformer = detail::IdentityFunctor<Tx>;
+
+  inline Ty initial() {
+    return static_cast<Ty>(std::numeric_limits<Ty>::lowest());
+  }
+
+  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
     return (b > a) ? b : a;
   }
 };
 
-template <typename T>
+// for cub::Reduce
+template <typename Tx, typename Ty = Tx>
 struct CustomSum {
-  __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+  using Transformer = detail::IdentityFunctor<Tx, Ty>;
+
+  inline Ty initial() { return static_cast<Ty>(0.0f); }
+
+  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
     return b + a;
   }
 };
 
-template <typename T>
+template <typename Tx, typename Ty = Tx>
+struct CustomMean {
+  using Transformer = detail::DivideFunctor<Tx>;
+
+  inline Ty initial() { return static_cast<Ty>(0.0f); }
+
+  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
+    return b + a;
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
 struct CustomMul {
-  __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+  using Transformer = detail::IdentityFunctor<Tx>;
+
+  inline Ty initial() { return static_cast<Ty>(1.0f); }
+
+  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
     return b * a;
   }
 };
 
+template <typename Tx, typename Ty = Tx>
+struct CustomLogicalOr {
+  using Transformer = detail::IdentityFunctor<Tx>;
+
+  inline Ty initial() { return static_cast<Ty>(false); }
+
+  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
+    return b || a;
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
+struct CustomLogicalAnd {
+  using Transformer = detail::IdentityFunctor<Tx>;
+
+  inline Ty initial() { return static_cast<Ty>(true); }
+
+  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
+    return b && a;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/reduce_ops/reduce_max_op.cu b/paddle/fluid/operators/reduce_ops/reduce_max_op.cu
index 832112ede83..f214fcba199 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_max_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_max_op.cu
@@ -11,15 +11,13 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 
-#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
-
-REGISTER_OP_CUDA_KERNEL(reduce_max,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          float, ops::MaxFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          double, ops::MaxFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int, ops::MaxFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int64_t, ops::MaxFunctor>);
+// reduce_max
+REGISTER_OP_CUDA_KERNEL(
+    reduce_max, ops::ReduceCudaKernel<float, paddle::operators::CustomMax>,
+    ops::ReduceCudaKernel<double, paddle::operators::CustomMax>,
+    ops::ReduceCudaKernel<int, paddle::operators::CustomMax>,
+    ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMax>);
diff --git a/paddle/fluid/operators/reduce_ops/reduce_min_op.cu b/paddle/fluid/operators/reduce_ops/reduce_min_op.cu
index 7b2706866f5..7806df284d8 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_min_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_min_op.cu
@@ -11,15 +11,13 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.h" -#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h" - -REGISTER_OP_CUDA_KERNEL(reduce_min, - ops::ReduceKernel, - ops::ReduceKernel, - ops::ReduceKernel, - ops::ReduceKernel); +// reduce_min +REGISTER_OP_CUDA_KERNEL( + reduce_min, ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel, + ops::ReduceCudaKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cuh b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h similarity index 64% rename from paddle/fluid/operators/reduce_ops/reduce_op.cuh rename to paddle/fluid/operators/reduce_ops/reduce_op.cu.h index 91d7fb7c843..5fad6efdb34 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.cuh +++ b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h @@ -30,32 +30,59 @@ namespace cub = hipcub; #endif #include "paddle/fluid/framework/array.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h" +// Reduce split or not, Whether to use ReduceHigherDim +#define REDUCE_SPLIT_BOUNDARY 512 + namespace paddle { namespace operators { namespace detail { // Post processing function for sum, max, min, prod, any -template +template struct IdentityFunctor { - DEVICE explicit inline IdentityFunctor() {} + HOSTDEVICE explicit inline IdentityFunctor(int n) {} - DEVICE inline T operator()(const T& x) const { return x; } + HOSTDEVICE inline Ty operator()(const Tx& x) const { + return static_cast(x); + } }; // Post processing function for mean template struct DivideFunctor { - DEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {} + HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {} - DEVICE inline T operator()(const T& x) const { return x * n_inv; } + HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; } private: T n_inv; }; +static inline std::vector GetReduceDim(const std::vector& dims, + int dim_size, bool reduce_all) { + std::vector reduce_dims; + if (reduce_all) { + reduce_dims.resize(dim_size); + for (int i = 0; i < reduce_dims.size(); ++i) { + reduce_dims[i] = i; + } + } else { + for (auto e : dims) { + PADDLE_ENFORCE_LT(e, dim_size, + paddle::platform::errors::InvalidArgument( + "ReduceOp: invalid axis, when x_dims is %d, " + "axis[i] should less than x_dims, but got %d.", + dim_size, e)); + reduce_dims.push_back(e >= 0 ? e : e + dim_size); + } + } + return reduce_dims; +} + static inline int GetLastPow2(int n) { n |= (n >> 1); n |= (n >> 2); @@ -65,8 +92,9 @@ static inline int GetLastPow2(int n) { return std::max(1, n - (n >> 1)); } -static inline std::vector GetStrides(const std::vector& dims, - const std::vector& idx) { +// get strides of x_dim, reduce_dim and left_dim for reduceLastDim and reduceAny +static inline std::vector GetDimStrides(const std::vector& dims, + const std::vector& idx) { int n = static_cast(idx.size()); if (n == 0) return std::vector(); std::vector strides(n); @@ -78,18 +106,18 @@ static inline std::vector GetStrides(const std::vector& dims, } #ifdef __HIPCC__ -constexpr int kMaxBlockDim = 256; +constexpr int kMaxThread = 256; #else -constexpr int kMaxBlockDim = 512; +constexpr int kMaxThread = 128; #endif -static inline int GetDesiredBlockDim(int block_dim) { - return block_dim >= kMaxBlockDim - ? 
kMaxBlockDim - : (1 << static_cast(std::log2(block_dim))); +// get blockDim for reduceLastDim and reduceAny +static inline int GetBlockDim(int block_dim) { + return block_dim >= kMaxThread ? kMaxThread : GetLastPow2(block_dim); } -static inline void CheckReduceRankIsValid(int reduce_rank, int rank) { +// check reduce rand is valid +static inline void CheckReduceRank(int reduce_rank, int rank) { if (rank % 2 == 0) { PADDLE_ENFORCE_EQ(reduce_rank, rank / 2, platform::errors::InvalidArgument( @@ -108,8 +136,9 @@ static inline void CheckReduceRankIsValid(int reduce_rank, int rank) { } } +// convert dims from vector to array template -static inline paddle::framework::Array from( +static inline paddle::framework::Array VectorToArray( const VectorLikeType& vec) { PADDLE_ENFORCE_EQ(vec.size(), ElementCount, platform::errors::InvalidArgument( @@ -118,17 +147,21 @@ static inline paddle::framework::Array from( vec.size(), ElementCount)); size_t n = static_cast(vec.size()); paddle::framework::Array ret; - for (size_t i = 0; i < n; ++i) ret[i] = vec[i]; + for (size_t i = 0; i < n; ++i) { + ret[i] = vec[i]; + } return ret; } } // namespace detail +using Tensor = framework::Tensor; + enum ReduceType { - kReduceAll = 0x00, - kReduceLastDim = 0x01, + kReduceAll = 0x00, // when reduce_rank == x_rank + kReduceLastDim = 0x01, // when reduce_dim[0] == x_dim.size() - 1; kReduceHigherDim = 0x02, // ReduceFirstDim or reduceSecondDim - kReduceAny = 0x03, + kReduceAny = 0x03, // when reduce_dim.size() > 1 }; // reduce config @@ -141,21 +174,24 @@ struct ReduceConfig { void Run() { // step1: update the reduce_dim left_dim and x_dim SetReduceDim(); + // step2: get the strides of dim for reduceAny and reduceLastDim SetStrides(); + // step3: get the type of reduce SetReduceType(); + // step4: set the block and grid for launch kernel SetBlockDim(); } // when should_reduce_again is true, we need malloc temp space for temp data void SetOutputData(Ty* y_data, const platform::Place& place, - framework::Tensor& tmp) { + framework::Tensor* tmp) { if (should_reduce_again) { - output_data = tmp.mutable_data( + output_data = tmp->mutable_data( framework::make_ddim( - {static_cast(left_num * grid.y * sizeof(Ty))}), + {static_cast(left_num * grid.z * grid.y * sizeof(Ty))}), place); } else { output_data = y_data; @@ -168,50 +204,70 @@ struct ReduceConfig { // --SetReduceDim--> x_dim = [8,6], reduce_dim = [0], left_dim = [1] void SetReduceDim() { std::set reduce_set; - for (auto e : reduce_dims_origin) { auto pos = e >= 0 ? 
       reduce_set.insert(pos);
     }
+
     std::vector<int> reduce_dim_temp(reduce_set.begin(), reduce_set.end());
     std::sort(reduce_dim_temp.begin(), reduce_dim_temp.end());
-    // get reduce_dim
+
+    // update reduce_dim and x_dim
+    std::vector<int> x_new_dim;
+
+    reduce_dim.push_back(reduce_dim_temp[0]);
+    x_new_dim.push_back(x_dim[0]);
+
+    int idx_reduce = 1;
+    int num = 0;
+
     if (reduce_dim_temp.size() > 1) {
-      int num = 0;  // for update axis
-      reduce_dim.push_back(reduce_dim_temp[0]);
-      for (int idx = 1; idx < reduce_dim_temp.size(); idx++) {
-        // update x_dim
-        if (reduce_dim_temp[idx] - reduce_dim_temp[idx - 1] == 1) {
-          x_dim[reduce_dim_temp[idx - 1]] *= x_dim[reduce_dim_temp[idx]];
-          x_dim.erase(x_dim.begin() + reduce_dim_temp[idx]);
-          num++;
+      for (int i = 1; i < x_dim.size(); i++) {
+        if ((idx_reduce < reduce_dim_temp.size()) &&
+            (i == reduce_dim_temp[idx_reduce])) {
+          int result =
+              reduce_dim_temp[idx_reduce] - reduce_dim[reduce_dim.size() - 1];
+          bool is_equal = ((result - num) == 1);
+          if (is_equal) {
+            x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
+            num++;
+          } else {
+            reduce_dim.push_back(reduce_dim_temp[idx_reduce] - num);
+            x_new_dim.push_back(x_dim[i]);
+          }
+          idx_reduce++;
         } else {
-          reduce_dim.push_back(reduce_dim_temp[idx] - num);
+          x_new_dim.push_back(x_dim[i]);
         }
       }
     } else {
-      reduce_dim = reduce_dim_temp;
+      x_new_dim = x_dim;
     }
 
-    // update new_x_dim and new_reduce_dim
-    std::vector<int> new_x_dim, new_reduce_dim_temp;
+    // update x_dim
+    x_dim = x_new_dim;
+    std::vector<int>().swap(x_new_dim);
+
+    std::vector<int> reduce_dim_new;
     int is_reduced = 0;
     for (auto e : reduce_dim) {
       is_reduced |= 1 << e;
     }
 
+    std::vector<int>().swap(reduce_dim);
+
     for (int i = 0; i < x_dim.size(); i++) {
       if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) {
-        new_x_dim.push_back(x_dim[i]);
+        x_new_dim.push_back(x_dim[i]);
         if ((is_reduced >> i) & 1)
-          new_reduce_dim_temp.push_back(new_x_dim.size() - 1);
+          reduce_dim_new.push_back(x_new_dim.size() - 1);
       } else {
-        new_x_dim[new_x_dim.size() - 1] *= x_dim[i];
+        x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
       }
     }
 
-    x_dim = new_x_dim;
-    reduce_dim = new_reduce_dim_temp;
+    x_dim = x_new_dim;
+    reduce_dim = reduce_dim_new;
 
     int x_rank = static_cast<int>(x_dim.size());
     std::set<int> left_set;
@@ -237,9 +293,9 @@ struct ReduceConfig {
       idx_dim.push_back(i);
     }
 
-    x_strides = detail::GetStrides(x_dim, idx_dim);
-    reduce_strides = detail::GetStrides(x_dim, reduce_dim);
-    left_strides = detail::GetStrides(x_dim, left_dim);
+    x_strides = detail::GetDimStrides(x_dim, idx_dim);
+    reduce_strides = detail::GetDimStrides(x_dim, reduce_dim);
+    left_strides = detail::GetDimStrides(x_dim, left_dim);
     reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];
 
     left_num = 1;
@@ -256,13 +312,17 @@ struct ReduceConfig {
   void SetReduceType() {
     int rank = x_dim.size();
    int reduce_rank = reduce_dim.size();
+    bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) ||
+                           (left_num > REDUCE_SPLIT_BOUNDARY);
 
     if (rank == reduce_rank) {
       reduce_type = static_cast<int>(ReduceType::kReduceAll);
 
     } else if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
       reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
-    } else if (reduce_rank == 1) {
+
+    } else if (reduce_rank == 1 &&
+               ((rank == 2 && is_large_enough) || rank != 2)) {
       // ReduceFirstDim and reduceSecondDim
       reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
 
@@ -277,7 +337,7 @@ struct ReduceConfig {
   //     for others: block(block_num, 1) , grid(left_num, 1)
   void SetBlockDim() {
     // init
-    int block_num = detail::GetDesiredBlockDim(reduce_num);
+    int block_num = detail::GetBlockDim(reduce_num);
     should_reduce_again = false;
 
     dim3 block_dim(block_num, 1);
@@ -302,7 +362,7 @@ struct ReduceConfig {
       // init
       int num_block = (max_threads / left_num);
 
-      if (num_block > 1 && reduce_num >= 512) {
+      if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
         blocking_size = detail::GetLastPow2(reduce_num / num_block);
 
         if (blocking_size <= 1) {
@@ -352,6 +412,9 @@ struct ReduceConfig {
   dim3 grid;
 };
 
+// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, this
+// function will be used
+// blockId.x -> left_num, threadId.x -> reduce_num
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
           int BlockDim>
 __device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
@@ -362,18 +425,25 @@ __device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
   int idx_x = blockIdx.x * reduce_num;
   int idx_y = threadIdx.x;
   Ty reduce_var = init;
-  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim)
-    reduce_var = reducer(reduce_var, static_cast<Ty>(x[idx_x + idx_y]));
+  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim) {
+    reduce_var =
+        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
+  }
 
   __syncthreads();
 
   reduce_var =
       cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
 
   if (threadIdx.x == 0) {
-    y[blockIdx.x] = transformer(reduce_var);
+    y[blockIdx.x] = reduce_var;
   }
 }
 
+// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
+// function will be used
+// eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
+// if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
+// else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
 __device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
                                                 ReduceOp reducer,
@@ -383,25 +453,29 @@ __device__ __forceinline__ void ReduceHigherDim(const Tx* x, Ty* y,
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   int idy = blockIdx.y * block_size;
 
-  Ty temp = init;
   Ty reduce_var = init;
 
   if (idx < left_num) {
     int loop = reduce_num - idy;
     loop = loop > block_size ? block_size : loop;
+
     for (int iy = 0; iy < loop; iy++) {
       int id = (idy + iy) * left_num + idx + blockIdx.z * reduce_num * left_num;
-      reduce_var = reducer(reduce_var, static_cast<Ty>(x[id]));
+      reduce_var = reducer(reduce_var, static_cast<Ty>(transformer(x[id])));
     }
+
     y[idx + blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num] =
-        static_cast<Ty>(transformer(reduce_var));
+        reduce_var;
   }
 }
 
+// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
+// function will be used
+// blockId.x -> left_num, threadId.x -> reduce_num
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
           int BlockDim, int Rank, int ReduceRank>
 __device__ __forceinline__ void ReduceAny(
-    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer, Ty init,
+    const Tx* x, Ty* y, ReduceOp reducer, TransformOp transformer,
     int reduce_num, paddle::framework::Array<int, Rank> x_strides,
     paddle::framework::Array<int, ReduceRank> reduce_dim,
     paddle::framework::Array<int, ReduceRank> reduce_strides,
@@ -423,20 +497,26 @@ __device__ __forceinline__ void ReduceAny(
   }
 
   int idx_x = 0;
-  for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
-  Ty reduce_var = static_cast<Ty>(x[idx_x]);
+  for (int k = 0; k < Rank; ++k) {
+    idx_x += (sub_index[k] * x_strides[k]);
+  }
+  Ty reduce_var = static_cast<Ty>(transformer(x[idx_x]));
 
   for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) {
     int reduce_idx = i;
+
     for (int j = 0; j < ReduceRank; ++j) {
       sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
       reduce_idx %= reduce_strides[j];
     }
 
     int idx_x = 0;
-    for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
-    reduce_var =
-        static_cast<Ty>(reducer(reduce_var, static_cast<Ty>(x[idx_x])));
+    for (int k = 0; k < Rank; ++k) {
+      idx_x += (sub_index[k] * x_strides[k]);
+    }
+
+    reduce_var = static_cast<Ty>(
+        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x]))));
   }
 
   __syncthreads();
@@ -444,10 +524,11 @@ __device__ __forceinline__ void ReduceAny(
   reduce_var =
       cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
 
   if (threadIdx.x == 0) {
-    y[blockIdx.x] = transformer(reduce_var);
+    y[blockIdx.x] = reduce_var;
   }
 }
 
+// module function designed for global function
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
           int BlockDim, int Rank, int ReduceRank, int ReduceType>
 __device__ __forceinline__ void ReduceModule(
@@ -458,17 +539,20 @@ __device__ __forceinline__ void ReduceModule(
     paddle::framework::Array<int, ReduceRank> reduce_strides,
     paddle::framework::Array<int, Rank - ReduceRank> left_dim,
     paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
+  // reduce_rank == 1 && reduce_dim[0] == x_dim.size() - 1
   if (ReduceType == ReduceType::kReduceLastDim) {
     ReduceLastDim<Tx, Ty, ReduceOp, TransformOp, BlockDim>(
        x, y, reducer, transformer, init, reduce_num);
 
+    // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
   } else if (ReduceType == ReduceType::kReduceHigherDim) {
     ReduceHigherDim<Tx, Ty, ReduceOp, TransformOp>(
         x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);
 
+    // reduce_rank >= 2
   } else {
     ReduceAny<Tx, Ty, ReduceOp, TransformOp, BlockDim, Rank, ReduceRank>(
-        x, y, reducer, transformer, init, reduce_num, x_strides, reduce_dim,
+        x, y, reducer, transformer, reduce_num, x_strides, reduce_dim,
         reduce_strides, left_dim, left_strides);
   }
 }
@@ -491,23 +575,22 @@ __global__ void ReduceKernelFunction(
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
           int BlockDim, int kRank, int kReduceRank>
-static void launchKernel(const Tx* x_data, Ty* y_data,
-                         const platform::Place& place, const ReduceOp& reducer,
-                         const TransformOp& transformer, const Ty& init,
+static void LaunchKernel(const Tx* x_data, Ty* y_data, const ReduceOp& reducer,
+                         const TransformOp& transformer, Ty init,
                          gpuStream_t stream, ReduceConfig<Ty> config) {
 
-#define CUB_REDUCE_TYPE_CASE(type)                                     \
-  case type: {                                                         \
-    constexpr auto kReduceType = type;                                 \
-    ReduceKernelFunction<                                              \
-        Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank, kReduceRank,   \
-        kReduceType><<<config.grid, config.block, 0, stream>>>(        \
-        x_data, config.output_data, reducer, transformer, init,        \
-        config.reduce_num, config.left_num, config.blocking_size,      \
-        detail::from<int, kRank>(config.x_strides),                    \
-        detail::from<int, kReduceRank>(config.reduce_dim),             \
-        detail::from<int, kReduceRank>(config.reduce_strides),         \
-        detail::from<int, kRank>(config.left_dim),                     \
-        detail::from<int, kRank>(config.left_strides));                \
+#define CUB_REDUCE_TYPE_CASE(type)                                     \
+  case type: {                                                         \
+    constexpr auto kReduceType = type;                                 \
+    ReduceKernelFunction<                                              \
+        Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank, kReduceRank,   \
+        kReduceType><<<config.grid, config.block, 0, stream>>>(        \
+        x_data, config.output_data, reducer, transformer, init,        \
+        config.reduce_num, config.left_num, config.blocking_size,      \
+        detail::VectorToArray<int, kRank>(config.x_strides),           \
+        detail::VectorToArray<int, kReduceRank>(config.reduce_dim),    \
+        detail::VectorToArray<int, kReduceRank>(config.reduce_strides), \
+        detail::VectorToArray<int, kRank>(config.left_dim),            \
+        detail::VectorToArray<int, kRank>(config.left_strides));       \
   } break
 
   switch (config.reduce_type) {
@@ -523,22 +606,22 @@ static void launchKernel(const Tx* x_data, Ty* y_data,
     ReduceKernelFunction<
         Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, 128, kRank, kReduceRank,
         ReduceType::kReduceHigherDim><<<grid, block, 0, stream>>>(
-        config.output_data, y_data, reducer, detail::IdentityFunctor<Ty>(),
-        init, config.grid.y, config.left_num, config.grid.y,
-        detail::from<int, kRank>(config.x_strides),
-        detail::from<int, kReduceRank>(config.reduce_dim),
-        detail::from<int, kReduceRank>(config.reduce_strides),
-        detail::from<int, kRank>(config.left_dim),
-        detail::from<int, kRank>(config.left_strides));
+        config.output_data, y_data, reducer,
+        detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
+        config.left_num, config.grid.y,
+        detail::VectorToArray<int, kRank>(config.x_strides),
+        detail::VectorToArray<int, kReduceRank>(config.reduce_dim),
+        detail::VectorToArray<int, kReduceRank>(config.reduce_strides),
+        detail::VectorToArray<int, kRank>(config.left_dim),
+        detail::VectorToArray<int, kRank>(config.left_strides));
   }
 }
 
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
           int BlockDim>
-static void launchReduceKernel(const Tx* x_data, Ty* y_data,
-                               const platform::Place& place,
+static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                                const ReduceOp& reducer,
-                               const TransformOp& transformer, const Ty& init,
+                               const TransformOp& transformer, Ty init,
                                gpuStream_t stream, ReduceConfig<Ty> config) {
   int reduce_rank = config.reduce_strides.size();
   int rank = config.x_strides.size();
@@ -552,28 +635,11 @@ static void launchReduceKernel(const Tx* x_data, Ty* y_data,
 #define CUB_REDUCE_RANK_CASE(i, ...)                                  \
   case i: {                                                           \
     constexpr auto kReduceRank = i;                                   \
-    launchKernel<Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank, kReduceRank>( \
-        x_data, y_data, place, reducer, transformer, init, stream, config); \
+    LaunchKernel<Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank, kReduceRank>( \
+        x_data, y_data, reducer, transformer, init, stream, config);  \
   } break
 
-  // launch CUB::Reduce
-  if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
-    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
-        x_data, transformer);
-    size_t temp_storage_bytes = 0;
-    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
-                              config.reduce_num, reducer, init, stream);
-    framework::Tensor tmp;
-    auto* temp_storage = tmp.mutable_data<uint8_t>(
-        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
-        place);
-    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
-                              config.reduce_num, reducer, init, stream);
-
-    return;
-  }
-
-  detail::CheckReduceRankIsValid(reduce_rank, rank);
+  detail::CheckReduceRank(reduce_rank, rank);
 
   switch (rank) {
     CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););
@@ -595,23 +661,25 @@ static void launchReduceKernel(const Tx* x_data, Ty* y_data,
 #undef CUB_REDUCE_RANK_CASE
 #undef CUB_RANK_CASE
 }
 
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
-void TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
-                      std::vector<int> origin_reduce_dims, const Ty& init,
-                      const ReduceOp& reducer, const TransformOp& transformer,
-                      gpuStream_t stream) {
+
+template <typename Tx, typename Ty, template <typename, typename> class ReduceOp>
+void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
+                             std::vector<int> origin_reduce_dims,
+                             gpuStream_t stream) {
   auto x_dim = framework::vectorize<int>(x.dims());
   auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
-  config.Run();
+  config.Run();  // get the parameters of LaunchReduceKernel
 
   auto x_data = x.data<Tx>();
   auto y_data = y->mutable_data<Ty>(x.place());
-  framework::Tensor tmp;
+
+  // after config.run()
   // SetOutputData for ReduceHigherDim when should_reduce_again is true,
   // temp_output should be stored temp_data in output_data space or stored in
   // y_data;
-  config.SetOutputData(y_data, x.place(), tmp);
+  framework::Tensor tmp;
+  config.SetOutputData(y_data, x.place(), &tmp);
 
   if (config.reduce_num == 1) {
     auto out_dims = y->dims();
@@ -619,17 +687,36 @@ void TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
     y->Resize(out_dims);
     return;
   }
+
+  using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
+  auto reducer = ReduceOp<Tx, Ty>();
+  // launch CUB::Reduce
+  if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
+    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
+        x_data, TransformOp(config.reduce_num));
+    size_t temp_storage_bytes = 0;
+    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
+                              config.reduce_num, reducer, reducer.initial(),
+                              stream);
+    framework::Tensor tmp;
+    auto* temp_storage = tmp.mutable_data<uint8_t>(
+        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
+        x.place());
+    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
+                              config.reduce_num, reducer, reducer.initial(),
+                              stream);
 
-#define CUB_BLOCK_DIM_CASE(block_dim)                                 \
-  case block_dim: {                                                   \
-    constexpr auto kBlockDim = block_dim;                             \
-    launchReduceKernel<Tx, Ty, ReduceOp, TransformOp, kBlockDim>(     \
-        x_data, y_data, x.place(), reducer, transformer, init, stream, \
-        config);                                                      \
+    return;
+  }
+
+#define CUB_BLOCK_DIM_CASE(block_dim)                                   \
+  case block_dim: {                                                     \
+    constexpr auto kBlockDim = block_dim;                               \
+    LaunchReduceKernel<Tx, Ty, ReduceOp<Tx, Ty>, TransformOp, kBlockDim>( \
+        x_data, y_data, reducer, TransformOp(config.reduce_num),        \
+        reducer.initial(), stream, config);                             \
   } break
 
-  switch (detail::GetDesiredBlockDim(config.reduce_num)) {
-    CUB_BLOCK_DIM_CASE(512);
+  switch (detail::GetBlockDim(config.reduce_num)) {
     CUB_BLOCK_DIM_CASE(256);
     CUB_BLOCK_DIM_CASE(128);
     CUB_BLOCK_DIM_CASE(64);
@@ -642,5 +729,46 @@ void TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
 #undef CUB_BLOCK_DIM_CASE
 }
 
+template <typename Tx, template <typename, typename> class ReduceOp>
+struct TensorReduceFunc {
+  const framework::Tensor& x;
+  framework::Tensor* y;
+  std::vector<int> origin_reduce_dims;
+  gpuStream_t stream;
+  TensorReduceFunc(const framework::Tensor& x, framework::Tensor* y,
+                   std::vector<int> origin_reduce_dims, gpuStream_t stream)
+      : x(x), y(y), origin_reduce_dims(origin_reduce_dims), stream(stream) {}
+
+  template <typename Ty>
+  void apply() const {
+    TensorReduceFunctorImpl<Tx, Ty, ReduceOp>(x, y, origin_reduce_dims, stream);
+  }
+};
+
+template <typename T, template <typename, typename> class ReduceOp>
+class ReduceCudaKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    bool reduce_all = context.Attr<bool>("reduce_all");
+    const Tensor* input = context.Input<Tensor>("X");
+    Tensor* output = context.Output<Tensor>("Out");
+    auto out_dtype = context.Attr<int>("out_dtype");
+    std::vector<int> dims = context.Attr<std::vector<int>>("dim");
+
+    std::vector<int> reduce_dims =
+        detail::GetReduceDim(dims, input->dims().size(), reduce_all);
+
+    gpuStream_t stream = context.cuda_device_context().stream();
+    if (out_dtype >= 0) {
+      framework::VisitDataTypeSmall(
+          static_cast<framework::proto::VarType::Type>(out_dtype),
+          TensorReduceFunc<T, ReduceOp>(*input, output, reduce_dims, stream));
+    } else {
+      TensorReduceFunctorImpl<T, T, ReduceOp>(*input, output, reduce_dims,
+                                              stream);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cu b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cu
index 44e76c78b1f..4f259e415d2 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cu
@@ -12,26 +12,22 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
 
+// reduce_prod
 #ifdef __HIPCC__
 // Eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h:922
 // do not support double in HIPCC platform (Eigen3 to be fixed)
-REGISTER_OP_CUDA_KERNEL(reduce_prod,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          float, ops::ProdFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int, ops::ProdFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int64_t, ops::ProdFunctor>);
+REGISTER_OP_CUDA_KERNEL(
+    reduce_prod, ops::ReduceCudaKernel<float, paddle::operators::CustomMul>,
+    ops::ReduceCudaKernel<int, paddle::operators::CustomMul>,
+    ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMul>);
 #else
-REGISTER_OP_CUDA_KERNEL(reduce_prod,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          float, ops::ProdFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          double, ops::ProdFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int, ops::ProdFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int64_t, ops::ProdFunctor>);
+REGISTER_OP_CUDA_KERNEL(
+    reduce_prod, ops::ReduceCudaKernel<float, paddle::operators::CustomMul>,
+    ops::ReduceCudaKernel<double, paddle::operators::CustomMul>,
+    ops::ReduceCudaKernel<int, paddle::operators::CustomMul>,
+    ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMul>);
 #endif
-- 
GitLab
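
Usage note (illustrative, not part of the patch): the sketch below shows how the new TensorReduceFunctorImpl / ReduceCudaKernel path introduced above is meant to be driven from caller code. The function name ExampleSumReduce, the choice of axes {0, 2}, the op name my_reduce_sum and the use of CustomSum are assumptions made only for this example; the operators actually touched by the patch (reduce_max, reduce_min, reduce_prod) are wired up exactly as in the registrations shown in the diff.

```cpp
// Sketch only: reduce a float tensor over axes {0, 2} with the new path.
// Assumes the headers introduced by this patch and a valid CUDA stream.
#include <vector>

#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"

namespace ops = paddle::operators;

void ExampleSumReduce(const paddle::framework::Tensor& x,
                      paddle::framework::Tensor* y, gpuStream_t stream) {
  // CustomSum supplies the binary op (a + b), its identity via initial(),
  // and the per-element Transformer; the impl picks the kernel variant
  // (last-dim / higher-dim / any / cub::DeviceReduce) from the shapes.
  std::vector<int> reduce_dims = {0, 2};
  ops::TensorReduceFunctorImpl<float, float, ops::CustomSum>(
      x, y, reduce_dims, stream);
}

// A hypothetical kernel registration would mirror the ones in this patch:
// REGISTER_OP_CUDA_KERNEL(
//     my_reduce_sum,
//     ops::ReduceCudaKernel<float, paddle::operators::CustomSum>,
//     ops::ReduceCudaKernel<double, paddle::operators::CustomSum>);
```

Because ReduceCudaKernel takes the functor as a template template parameter, reduce_max, reduce_min and reduce_prod differ only in which Custom* functor they pass; the functor carries the binary operation, its identity value, and the element-wise transformer.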