Unverified commit 6a9fac14, authored by niuliling123, committed by GitHub

modified reduce_sum_op and reduce_mean_op for higher_performance (#32885)

Parent: bb01b120
@@ -15,7 +15,6 @@
 #include "paddle/fluid/operators/reduce_ops/reduce_all_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
-// reduce_prod
 REGISTER_OP_CUDA_KERNEL(
     reduce_all,
     ops::ReduceCudaKernel<bool, paddle::operators::CustomLogicalAnd>);
@@ -16,7 +16,6 @@
 #include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
-// reduce_prod
 REGISTER_OP_CUDA_KERNEL(
     reduce_any,
     ops::ReduceCudaKernel<bool, paddle::operators::CustomLogicalOr>);
@@ -13,58 +13,11 @@
 // limitations under the License.
 #include <vector>
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
-namespace paddle {
-namespace operators {
-
-template <typename T>
-struct DivideFunctor {
-  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
-  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
-
- private:
-  T n_inv;
-};
-
-template <typename T>
-class ReduceMeanKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    bool reduce_all = context.Attr<bool>("reduce_all");
-    auto* input = context.Input<Tensor>("X");
-    auto* output = context.Output<Tensor>("Out");
-    auto dims = context.Attr<std::vector<int>>("dim");
-    bool keep_dim = context.Attr<bool>("keep_dim");
-
-    std::vector<int> reduce_dims;
-    if (reduce_all) {
-      reduce_dims.resize(input->dims().size());
-      for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i;
-    } else {
-      for (auto e : dims) {
-        reduce_dims.push_back(e >= 0 ? e : e + input->dims().size());
-      }
-    }
-
-    int reduce_num = 1;
-    for (int i = 0; i < reduce_dims.size(); ++i) {
-      reduce_num *= input->dims()[reduce_dims[i]];
-    }
-
-    auto stream = context.cuda_device_context().stream();
-    TensorReduce<T, T, cub::Sum, DivideFunctor<T>>(
-        *input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
-        DivideFunctor<T>(reduce_num), stream);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-REGISTER_OP_CUDA_KERNEL(reduce_mean, ops::ReduceMeanKernel<bool>,
-                        ops::ReduceMeanKernel<float>,
-                        ops::ReduceMeanKernel<double>);
+REGISTER_OP_CUDA_KERNEL(
+    reduce_mean, ops::ReduceCudaKernel<bool, paddle::operators::CustomMean>,
+    ops::ReduceCudaKernel<float, paddle::operators::CustomMean>,
+    ops::ReduceCudaKernel<double, paddle::operators::CustomMean>);
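
Note: CustomMean is defined in reduce_functor_op.h and is not shown in this diff. For reference, the host-side bookkeeping that the removed ReduceMeanKernel performed (normalizing negative axes and computing the element count used by DivideFunctor's 1/n scale) looks like the following standalone sketch; the shapes and attribute values here are made up for illustration:

// Host-side sketch (not the PaddlePaddle API) of the dim-normalization logic
// the removed ReduceMeanKernel ran before launching the reduction: negative
// axes wrap around, and reduce_num is the product of the reduced extents used
// by DivideFunctor(1/n) to turn a sum into a mean.
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> x_dims = {2, 3, 4};  // hypothetical input shape
  std::vector<int> dims = {-1, 1};      // "dim" attribute, may be negative
  bool reduce_all = false;              // "reduce_all" attribute

  std::vector<int> reduce_dims;
  if (reduce_all) {
    reduce_dims.resize(x_dims.size());
    std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
  } else {
    for (int e : dims)
      reduce_dims.push_back(e >= 0 ? e : e + static_cast<int>(x_dims.size()));
  }

  int reduce_num = 1;
  for (int d : reduce_dims) reduce_num *= x_dims[d];

  // The mean kernel then scales the summed value by 1/reduce_num.
  printf("reduce_num = %d, scale = %f\n", reduce_num, 1.0 / reduce_num);
  return 0;
}
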
@@ -33,6 +33,7 @@ namespace cub = hipcub;
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/fast_divmod.h"
@@ -145,7 +146,6 @@ using Tensor = framework::Tensor;
 constexpr int kMaxRank = framework::DDim::kMaxRank;
 enum ReduceType {
-  kReduceAll = 0x00,        // when reduce_rank == x_rank
   kReduceLastDim = 0x01,    // when reduce_dim[0] == x_dim.size() - 1;
   kReduceHigherDim = 0x02,  // ReduceFirstDim or reduceSecondDim
   kReduceAny = 0x03,        // when reduce_dim.size() > 1
@@ -339,15 +339,11 @@ struct ReduceConfig {
   void SetReduceType() {
     int rank = x_dim.size();
     int reduce_rank = reduce_dim.size();
-    bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) ||
-                           (left_num > REDUCE_SPLIT_BOUNDARY);
-    if (rank == reduce_rank) {
-      reduce_type = static_cast<int>(ReduceType::kReduceAll);
-    } else if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
+    bool is_last_dim =
+        (rank == 2) && (reduce_rank == 1) && (reduce_dim[0] == 1);
+    if (rank == reduce_rank || is_last_dim) {
       reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
-    } else if (reduce_rank == 1 &&
-               ((rank == 2 && is_large_enough) || rank != 2)) {
+    } else if (reduce_rank == 1) {
       // ReduceFirstDim and reduceSecondDim
       reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
     } else {
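
The net effect of this hunk: kReduceAll is gone, a full reduction now takes the kReduceLastDim path, any other single-axis reduction takes kReduceHigherDim, and everything else falls to kReduceAny. A standalone sketch of that dispatch (SelectReduceType is a made-up name, not the ReduceConfig API):

// Host-side model of the new SetReduceType() branch selection; enum values
// copied from the hunk above.
#include <cstdio>
#include <vector>

enum ReduceType {
  kReduceLastDim = 0x01,    // full reduce, or rank-2 reduce over the last dim
  kReduceHigherDim = 0x02,  // single-axis reduce over a leading/middle dim
  kReduceAny = 0x03,        // multi-axis partial reduce
};

int SelectReduceType(const std::vector<int>& x_dim,
                     const std::vector<int>& reduce_dim) {
  int rank = static_cast<int>(x_dim.size());
  int reduce_rank = static_cast<int>(reduce_dim.size());
  bool is_last_dim =
      (rank == 2) && (reduce_rank == 1) && (reduce_dim[0] == 1);
  if (rank == reduce_rank || is_last_dim) return kReduceLastDim;
  if (reduce_rank == 1) return kReduceHigherDim;
  return kReduceAny;
}

int main() {
  printf("%d\n", SelectReduceType({32, 128}, {0, 1}));    // full reduce -> 1
  printf("%d\n", SelectReduceType({32, 128}, {1}));       // last dim    -> 1
  printf("%d\n", SelectReduceType({8, 16, 32}, {1}));     // middle dim  -> 2
  printf("%d\n", SelectReduceType({8, 16, 32}, {0, 2}));  // multi-axis  -> 3
  return 0;
}
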
@@ -577,14 +573,15 @@ static __device__ T BlockYReduce(T val, ReduceOp reducer) {
 // eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
 // if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
 // else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp>
 __device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
-                                TransformOp transformer, Ty init,
+                                TransformOp transformer, MPType init,
                                 int reduce_num, int left_num, int block_size) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   int idy = blockIdx.y * block_size;
-  Ty reduce_var = init;
+  MPType reduce_var = init;
   if (idx < left_num) {
     int loop = reduce_num - idy;
@@ -592,24 +589,24 @@ __device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
     for (int iy = 0; iy < loop; iy++) {
       int id = (idy + iy) * left_num + idx + blockIdx.z * reduce_num * left_num;
-      reduce_var = reducer(reduce_var, static_cast<Ty>(transformer(x[id])));
+      reduce_var = reducer(reduce_var, static_cast<MPType>(transformer(x[id])));
     }
     y[idx + blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num] =
-        reduce_var;
+        static_cast<Ty>(reduce_var);
   }
 }
 // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
 // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
 // function will be used
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          typename ReduceIndexCal, typename LeftIndexCal>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp>
 __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
-                          TransformOp transformer, Ty init, int reduce_num,
+                          TransformOp transformer, MPType init, int reduce_num,
                           int left_num, bool reduce_lastdim,
-                          ReduceIndexCal reduce_index_calculator,
-                          LeftIndexCal left_index_calculator) {
+                          const IndexCalculator& reduce_index_calculator,
+                          const IndexCalculator& left_index_calculator) {
   int input_idx, left_idx, stride;
   // the last dim gets involved in reduction
   if (reduce_lastdim) {
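
For reference, a CPU model of the index arithmetic ReduceHigherDim uses for the {nz, ny, nx} example in its comment, with the grid/block coordinates replaced by ordinary loops (the shapes are hypothetical):

// Reducing axis 1 of a row-major {nz, ny, nx} tensor: left_num == nx,
// reduce_num == ny, and each output element (iz, ix) accumulates
// x[(idy + iy) * left_num + idx + blockIdx.z * reduce_num * left_num].
#include <cstdio>
#include <vector>

int main() {
  const int nz = 2, ny = 3, nx = 4;
  std::vector<float> x(nz * ny * nx);
  for (size_t i = 0; i < x.size(); ++i) x[i] = static_cast<float>(i);

  std::vector<float> y(nz * nx, 0.0f);  // output shape {nz, nx}
  for (int iz = 0; iz < nz; ++iz) {     // plays the role of blockIdx.z
    for (int ix = 0; ix < nx; ++ix) {   // plays the role of idx
      float reduce_var = 0.0f;          // init for a sum reduction
      for (int iy = 0; iy < ny; ++iy) { // the iy loop in the kernel (idy == 0)
        int id = iy * nx + ix + iz * ny * nx;
        reduce_var += x[id];
      }
      y[iz * nx + ix] = reduce_var;
    }
  }
  printf("y[0] = %.1f (expected 0 + 4 + 8 = 12)\n", y[0]);
  return 0;
}
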
@@ -622,9 +619,9 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
     stride = gridDim.y * blockDim.y;
   }
   // calculate the offset, means the addr where each thread really start.
-  int input_offset = left_index_calculator(left_idx);
+  int input_offset = left_index_calculator.Get(left_idx);
   const Tx* input = x + input_offset;
-  Ty reduce_var = init;
+  MPType reduce_var = init;
   // 1. reduce for each thread
   if (left_idx < left_num) {
@@ -635,12 +632,13 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
 #pragma unroll
     for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
       int reduce_idx = input_idx + i * stride;
-      int idx_x = reduce_index_calculator(reduce_idx);
+      int idx_x = reduce_index_calculator.Get(reduce_idx);
       input_reg[i] = input[idx_x];
     }
 #pragma unroll
     for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
-      reduce_var = reducer(reduce_var, transformer(input_reg[i]));
+      reduce_var =
+          reducer(reduce_var, static_cast<MPType>(transformer(input_reg[i])));
     }
     input_idx += REDUCE_VEC_SIZE * stride;
   }
@@ -653,7 +651,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
         break;
       }
       int reduce_idx = input_idx;
-      int idx_x = reduce_index_calculator(reduce_idx);
+      int idx_x = reduce_index_calculator.Get(reduce_idx);
       input_reg[i] = input[idx_x];
       input_idx += stride;
     }
@@ -663,7 +661,8 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
       if (input_idx >= reduce_num) {
        break;
       }
-      reduce_var = reducer(reduce_var, transformer(input_reg[i]));
+      reduce_var =
+          reducer(reduce_var, static_cast<MPType>(transformer(input_reg[i])));
      input_idx += stride;
    }
  }
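
The MPType threaded through ReduceHigherDim and ReduceAny is meant to be a wider accumulation type than the stored Ty (see the fp16_type_traits.h include added earlier). A tiny CPU illustration of the precision argument, using float-vs-double purely as a stand-in for float16-vs-float:

// Why the kernels now accumulate in MPType instead of Ty: summing many small
// values in a narrow type drifts once the running sum grows, while a wider
// accumulator stays close to exact.
#include <cstdio>

int main() {
  const int n = 10000000;
  float narrow_acc = 0.0f;  // accumulate in the storage type
  double wide_acc = 0.0;    // accumulate in the wider "MPType"
  for (int i = 0; i < n; ++i) {
    narrow_acc += 0.1f;
    wide_acc += static_cast<double>(0.1f);
  }
  printf("narrow accumulator: %f\n", narrow_acc);  // visibly far from 1e6
  printf("wide accumulator:   %f\n", wide_acc);    // ~1000000.015 (0.1f != 0.1)
  return 0;
}
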
@@ -678,63 +677,56 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
     // 3. reduce in block x
     reduce_var = BlockXReduce(reduce_var, reducer);
     if (left_idx < left_num && threadIdx.x == 0) {
-      y[blockIdx.y * left_num + left_idx] = reduce_var;
+      y[blockIdx.y * left_num + left_idx] = static_cast<Ty>(reduce_var);
     }
   } else {
     if (left_idx < left_num && threadIdx.y == 0) {
-      y[blockIdx.y * left_num + left_idx] = reduce_var;
+      y[blockIdx.y * left_num + left_idx] = static_cast<Ty>(reduce_var);
     }
   }
 }
 // module function designed for global function
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp>
 __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
-                             TransformOp transformer, Ty init, int reduce_num,
-                             int left_num, int blocking_size, int reduce_type,
-                             bool reduce_lastdim,
+                             TransformOp transformer, MPType init,
+                             int reduce_num, int left_num, int blocking_size,
+                             int reduce_type, bool reduce_lastdim,
                              const IndexCalculator& reduce_index_calculator,
                              const IndexCalculator& left_index_calculator) {
-  if (reduce_type == ReduceType::kReduceLastDim) {
-    ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
-        x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
-        [&](int idx) { return idx; },
-        [&](int idx) { return idx * reduce_num; });
+  if (reduce_type == ReduceType::kReduceLastDim ||
+      reduce_type == ReduceType::kReduceAny) {
+    ReduceAny<Tx, Ty, MPType, ReduceOp, TransformOp>(
+        x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
+        reduce_index_calculator, left_index_calculator);
     // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
   } else if (reduce_type == ReduceType::kReduceHigherDim) {
-    ReduceHigherDim<Tx, Ty, ReduceOp, TransformOp>(
+    ReduceHigherDim<Tx, Ty, MPType, ReduceOp, TransformOp>(
         x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);
-    // reduce_rank >= 2
-  } else {
-    ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
-        x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
-        [&](int idx) { return reduce_index_calculator.Get(idx); },
-        [&](int idx) { return left_index_calculator.Get(idx); });
   }
 }

-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp>
 __global__ void ReduceKernelFunction(const Tx* x, Ty* y, ReduceOp reducer,
-                                     TransformOp transformer, Ty init,
+                                     TransformOp transformer, MPType init,
                                      int reduce_num, int left_num,
                                      int blocking_size, int reduce_type,
                                      bool reduce_lastdim,
                                      IndexCalculator reduce_index_calculator,
                                      IndexCalculator left_index_calculator) {
-  ReduceModule<Tx, Ty, ReduceOp, TransformOp>(
+  ReduceModule<Tx, Ty, MPType, ReduceOp, TransformOp>(
      x, y, reducer, transformer, init, reduce_num, left_num, blocking_size,
      reduce_type, reduce_lastdim, reduce_index_calculator,
      left_index_calculator);
 }

-template <typename Tx, typename Ty, typename ReduceOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp>
 static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
-                               const ReduceOp& reducer, Ty init,
+                               const ReduceOp& reducer, MPType init,
                                gpuStream_t stream, ReduceConfig<Ty> config) {
   using TransformOp = typename ReduceOp::Transformer;
   int reduce_rank = config.reduce_strides.size();
   int left_rank = config.left_strides.size();
   auto reduce_index_calculator = IndexCalculator(
@@ -742,7 +734,7 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
   auto left_index_calculator = IndexCalculator(
       left_rank, config.left_dim, config.left_strides, config.x_strides);
-  ReduceKernelFunction<Tx, Ty, ReduceOp,
+  ReduceKernelFunction<Tx, Ty, MPType, ReduceOp,
                        TransformOp><<<config.grid, config.block, 0, stream>>>(
       x_data, config.output_data, reducer, TransformOp(config.reduce_num), init,
      config.reduce_num, config.left_num, config.blocking_size,
@@ -760,10 +752,11 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
     grid = dim3(config.grid.x, 1, config.grid.z);
   }
-  ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<
-                                             Ty>><<<grid, block, 0, stream>>>(
+  ReduceKernelFunction<
+      Ty, Ty, MPType, ReduceOp,
+      detail::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
      config.output_data, y_data, reducer,
-      detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
+      detail::IdentityFunctor<Ty, MPType>(config.grid.y), init, config.grid.y,
      config.left_num, config.grid.y, ReduceType::kReduceHigherDim,
      config.reduce_lastdim, reduce_index_calculator, left_index_calculator);
 }
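
LaunchReduceKernel first writes per-block partial results into config.output_data and, when the grid was split along y, launches a second ReduceKernelFunction with detail::IdentityFunctor to fold those partials into the final output. A minimal CPU model of that two-pass scheme (sum reduction, made-up sizes):

// Pass 1 produces one partial result per "block" (here: chunk); pass 2
// reduces the partials themselves, with an identity transform, mirroring the
// second kernel launch above.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const int reduce_num = 1000, num_chunks = 8;  // num_chunks stands in for grid.y
  std::vector<double> x(reduce_num, 1.0);

  // Pass 1: each chunk reduces its slice into a partial result.
  std::vector<double> partials(num_chunks, 0.0);
  int chunk = (reduce_num + num_chunks - 1) / num_chunks;
  for (int b = 0; b < num_chunks; ++b)
    for (int i = b * chunk; i < std::min(reduce_num, (b + 1) * chunk); ++i)
      partials[b] += x[i];

  // Pass 2: reduce the partials (the transform is the identity).
  double total = 0.0;
  for (double p : partials) total += p;
  printf("total = %.1f (expected %d)\n", total, reduce_num);
  return 0;
}
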
@@ -794,11 +787,12 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   }
   config.SetOutputData(y_data, x.place(), &tmp);
-  using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
-  auto reducer = ReduceOp<Tx, Ty>();
-  // launch CUB::Reduce
-  if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
+  bool use_cub_reduce = (config.left_num == 1) &&
+                        (!std::is_same<Tx, paddle::platform::float16>::value);
+  if (use_cub_reduce) {
+    // launch CUB::Reduce
+    using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
+    auto reducer = ReduceOp<Tx, Ty>();
     cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
         x_data, TransformOp(config.reduce_num));
     size_t temp_storage_bytes = 0;
@@ -816,7 +810,9 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
     return;
   }
-  LaunchReduceKernel<Tx, Ty, ReduceOp<Tx, Ty>>(
+  using MPType = typename details::MPTypeTrait<Ty>::Type;
+  auto reducer = ReduceOp<Tx, MPType>();
+  LaunchReduceKernel<Tx, Ty, MPType, ReduceOp<Tx, MPType>>(
      x_data, y_data, reducer, reducer.initial(), stream, config);
 }
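
The use_cub_reduce branch keeps the plain CUB path for full reductions on non-float16 inputs: a cub::TransformInputIterator feeding cub::DeviceReduce, using the usual two-call idiom in which the first call only sizes the temporary storage. A standalone CUDA sketch of that pattern, with my own buffer names and a trivial transform functor:

// Transform-on-read plus device-wide reduce, independent of Paddle.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cub/cub.cuh>

struct IdentityTransform {
  __host__ __device__ float operator()(const float& x) const { return x; }
};

int main() {
  const int n = 1 << 20;
  std::vector<float> h_in(n, 1.0f);  // expected sum: n

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  // Transform-on-read iterator, as in the use_cub_reduce branch.
  cub::TransformInputIterator<float, IdentityTransform, const float*> trans_in(
      d_in, IdentityTransform());

  // Two-call idiom: query temp-storage size, allocate, then reduce.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceReduce::Reduce(d_temp, temp_bytes, trans_in, d_out, n, cub::Sum(),
                            0.0f);
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceReduce::Reduce(d_temp, temp_bytes, trans_in, d_out, n, cub::Sum(),
                            0.0f);

  float h_out = 0.0f;
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %.0f (expected %d)\n", h_out, n);

  cudaFree(d_temp);
  cudaFree(d_out);
  cudaFree(d_in);
  return 0;
}
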
......
@@ -11,72 +11,18 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
-namespace paddle {
-namespace operators {
-
-template <typename Tout>
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor() {}
-
-  template <typename U>
-  HOSTDEVICE inline Tout operator()(const U& x) const {
-    return static_cast<Tout>(x);
-  }
-};
-
-template <typename T>
-class ReduceSumKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    bool reduce_all = context.Attr<bool>("reduce_all");
-    auto* input = context.Input<Tensor>("X");
-    auto* output = context.Output<Tensor>("Out");
-    auto out_dtype = context.Attr<int>("out_dtype");
-    auto dims = context.Attr<std::vector<int>>("dim");
-    bool keep_dim = context.Attr<bool>("keep_dim");
-
-    std::vector<int> reduce_dims;
-    if (reduce_all) {
-      reduce_dims.resize(input->dims().size());
-      for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i;
-    } else {
-      for (auto e : dims) {
-        reduce_dims.push_back(e >= 0 ? e : e + input->dims().size());
-      }
-    }
-
-    int reduce_num = 1;
-    for (int i = 0; i < reduce_dims.size(); ++i) {
-      reduce_num *= input->dims()[reduce_dims[i]];
-    }
-
-    auto stream = context.cuda_device_context().stream();
-    if (out_dtype >= 0) {
-      framework::VisitDataTypeSmall(
-          static_cast<framework::proto::VarType::Type>(out_dtype),
-          TensorReduceFunctor<T, cub::Sum, IdentityFunctor>(
-              *input, output, reduce_dims, static_cast<double>(0.0), cub::Sum(),
-              stream));
-    } else {
-      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
-          *input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
-          IdentityFunctor<T>(), stream);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
 REGISTER_OP_CUDA_KERNEL(
-    reduce_sum, ops::ReduceSumKernel<bool>, ops::ReduceSumKernel<float>,
-    ops::ReduceSumKernel<double>,
-    ops::ReduceSumKernel<paddle::platform::float16>, ops::ReduceSumKernel<int>,
-    ops::ReduceSumKernel<int64_t>,
-    ops::ReduceSumKernel<paddle::platform::complex<float>>,
-    ops::ReduceSumKernel<paddle::platform::complex<double>>);
+    reduce_sum, ops::ReduceCudaKernel<bool, paddle::operators::CustomSum>,
+    ops::ReduceCudaKernel<float, paddle::operators::CustomSum>,
+    ops::ReduceCudaKernel<double, paddle::operators::CustomSum>,
+    ops::ReduceCudaKernel<paddle::platform::float16,
+                          paddle::operators::CustomSum>,
+    ops::ReduceCudaKernel<int, paddle::operators::CustomSum>,
+    ops::ReduceCudaKernel<int64_t, paddle::operators::CustomSum>,
+    ops::ReduceCudaKernel<paddle::platform::complex<float>,
+                          paddle::operators::CustomSum>,
+    ops::ReduceCudaKernel<paddle::platform::complex<double>,
+                          paddle::operators::CustomSum>);