From 5756d3e5df29e32837994ed8d579c22ebadadcd0 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Fri, 28 May 2021 12:40:29 +0800 Subject: [PATCH] modify to complex template types in reduce_sum OP and rewrite it's IdentityFunctor struct (#33164) --- .../fluid/operators/reduce_ops/cub_reduce.h | 13 +++++----- .../operators/reduce_ops/reduce_sum_op.cc | 17 ++++++------ .../operators/reduce_ops/reduce_sum_op.cu | 26 ++++++++++--------- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/cub_reduce.h b/paddle/fluid/operators/reduce_ops/cub_reduce.h index 29e46e091d0..9e1aed5dde4 100644 --- a/paddle/fluid/operators/reduce_ops/cub_reduce.h +++ b/paddle/fluid/operators/reduce_ops/cub_reduce.h @@ -366,33 +366,32 @@ void TensorReduce(const framework::Tensor& x, framework::Tensor* y, #undef CUB_BLOCK_DIM_CASE } -template +template class TransformOp> struct TensorReduceFunctor { const framework::Tensor& x; framework::Tensor* y; std::vector origin_reduce_dims; const double& init; const ReduceOp& reducer; - const TransformOp& transformer; gpuStream_t stream; TensorReduceFunctor(const framework::Tensor& x, framework::Tensor* y, std::vector origin_reduce_dims, const double& init, - const ReduceOp& reducer, const TransformOp& transformer, - gpuStream_t stream) + const ReduceOp& reducer, gpuStream_t stream) : x(x), y(y), origin_reduce_dims(origin_reduce_dims), init(init), reducer(reducer), - transformer(transformer), stream(stream) {} template void apply() const { const Ty& init_cast = static_cast(init); - TensorReduce( - x, y, origin_reduce_dims, init_cast, reducer, transformer, stream); + TensorReduce>( + x, y, origin_reduce_dims, init_cast, reducer, TransformOp(), + stream); } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc index a085e851eea..74e7db649d5 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc @@ -119,9 +119,9 @@ REGISTER_OP_CPU_KERNEL( ops::ReduceKernel, ops::ReduceKernel, + paddle::platform::complex, ops::SumFunctor>, ops::ReduceKernel, ops::SumFunctor>); @@ -130,10 +130,9 @@ using CPUReduceSumGradKernel = ops::ReduceSumGradKernel; -REGISTER_OP_CPU_KERNEL(reduce_sum_grad, CPUReduceSumGradKernel, - CPUReduceSumGradKernel, - CPUReduceSumGradKernel, - CPUReduceSumGradKernel, - CPUReduceSumGradKernel, - CPUReduceSumGradKernel, - CPUReduceSumGradKernel); +REGISTER_OP_CPU_KERNEL( + reduce_sum_grad, CPUReduceSumGradKernel, + CPUReduceSumGradKernel, CPUReduceSumGradKernel, + CPUReduceSumGradKernel, CPUReduceSumGradKernel, + CPUReduceSumGradKernel>, + CPUReduceSumGradKernel>); diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu index dbd020514b2..dd16ca4e393 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu @@ -18,11 +18,13 @@ namespace paddle { namespace operators { -template +template struct IdentityFunctor { HOSTDEVICE explicit inline IdentityFunctor() {} - HOSTDEVICE inline T operator()(const T& x) const { return x; } + HOSTDEVICE inline Ty operator()(const Tx& x) const { + return static_cast(x); + } }; template @@ -56,13 +58,13 @@ class ReduceSumKernel : public framework::OpKernel { if (out_dtype >= 0) { framework::VisitDataTypeSmall( static_cast(out_dtype), - TensorReduceFunctor>( + TensorReduceFunctor( *input, output, reduce_dims, static_cast(0.0), cub::Sum(), - IdentityFunctor(), stream)); + stream)); } else { - TensorReduce>( + TensorReduce>( *input, output, reduce_dims, static_cast(0), cub::Sum(), - IdentityFunctor(), stream); + IdentityFunctor(), stream); } } }; @@ -70,9 +72,9 @@ class ReduceSumKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(reduce_sum, ops::ReduceSumKernel, - ops::ReduceSumKernel, - ops::ReduceSumKernel, ops::ReduceSumKernel, - ops::ReduceSumKernel, - ops::ReduceSumKernel, - ops::ReduceSumKernel); +REGISTER_OP_CUDA_KERNEL( + reduce_sum, ops::ReduceSumKernel, ops::ReduceSumKernel, + ops::ReduceSumKernel, ops::ReduceSumKernel, + ops::ReduceSumKernel, + ops::ReduceSumKernel>, + ops::ReduceSumKernel>); -- GitLab