diff --git a/paddle/pten/kernels/gpu/reduce.h b/paddle/pten/kernels/gpu/reduce.h
index 0704b76a2f069d724bdd04c3f1571c7efddb5770..5a736ef0e6e72988c4ed756b57cd7d4bb41c15e5 100644
--- a/paddle/pten/kernels/gpu/reduce.h
+++ b/paddle/pten/kernels/gpu/reduce.h
@@ -45,8 +45,7 @@ namespace cub = hipcub;
 #include "paddle/pten/api/ext/dispatch.h"
 #include "paddle/pten/backends/gpu/gpu_context.h"
 #include "paddle/pten/core/dense_tensor.h"
-#include "paddle/pten/kernels/cast_kernel.h"
-#include "paddle/pten/kernels/copy_kernel.h"
+#include "paddle/pten/kernels/gpu/elementwise.h"
 
 // Reduce split or not, Whether to use ReduceHigherDim
 #define REDUCE_SPLIT_BOUNDARY 512
@@ -1062,23 +1061,6 @@ static
           "Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
 }
 
-static void AsyncCopy(const pten::DenseTensor& src, pten::DenseTensor* dst) {
-  paddle::platform::DeviceContextPool& pool =
-      paddle::platform::DeviceContextPool::Instance();
-  const paddle::platform::CUDADeviceContext* dev_ctx;
-  if (paddle::platform::is_gpu_place(dst->place()) ||
-      paddle::platform::is_npu_place(dst->place())) {
-    dev_ctx = static_cast<paddle::platform::CUDADeviceContext*>(
-        pool.Get(dst->place()));
-
-  } else {
-    dev_ctx = static_cast<paddle::platform::CUDADeviceContext*>(
-        pool.Get(src.place()));
-  }
-
-  pten::Copy(*dev_ctx, src, false, dst);
-}
-
 template <typename Tx,
           typename Ty,
           template <typename> class ReduceOp,
@@ -1111,13 +1093,10 @@ void TensorReduceFunctorImpl(const pten::DenseTensor& x,
   auto* dev_ctx = static_cast<paddle::platform::CUDADeviceContext*>(
       paddle::platform::DeviceContextPool::Instance().Get(x.place()));
   if (config.reduce_num == 1) {
-    auto out_dims = y->dims();
-    if (x.dtype() == y->dtype()) {
-      AsyncCopy(x, y);
-      y->Resize(out_dims);
-    } else {
-      pten::CastKernel<Tx>(*dev_ctx, x, y->dtype(), y);
-    }
+    std::vector<const pten::DenseTensor*> inputs = {&x};
+    std::vector<pten::DenseTensor*> outputs = {y};
+    pten::LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, Tx, Ty>(
+        *dev_ctx, inputs, &outputs, transform);
     return;
   }
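
Note (not part of the patch): the removed copy/cast special-casing and the new
elementwise launch do equivalent work for the reduce_num == 1 fast path. Below is
a minimal self-contained CUDA sketch of the underlying idea, not Paddle's actual
LaunchSameDimsElementwiseCudaKernel; the names UnaryTransformKernel and
IdentityTransform are hypothetical. A single unary-transform launch degenerates
to a copy when the input and output types match, and to a cast when they differ,
which is why one code path can replace both removed branches.

// Illustrative sketch only; all names here are hypothetical.
#include <cstdio>
#include <cuda_runtime.h>

// Unary transform functor. With InT == OutT it is a plain copy; with
// InT != OutT the static_cast performs the dtype conversion, so one
// kernel covers both branches the diff removes.
template <typename InT, typename OutT>
struct IdentityTransform {
  __device__ OutT operator()(InT x) const { return static_cast<OutT>(x); }
};

// Same-dims elementwise launch: every thread applies the transform to
// one element, with no broadcasting or reduction involved.
template <typename InT, typename OutT, typename TransformOp>
__global__ void UnaryTransformKernel(const InT* in, OutT* out, int n,
                                     TransformOp transform) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = transform(in[i]);
}

int main() {
  const int n = 8;
  float h_in[n] = {0, 1, 2, 3, 4, 5, 6, 7};
  double h_out[n];
  float* d_in;
  double* d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(double));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  // One launch replaces both the AsyncCopy (same dtype) and the
  // CastKernel (different dtype) paths; float -> double exercises
  // the cast case here.
  UnaryTransformKernel<<<1, 256>>>(d_in, d_out, n,
                                   IdentityTransform<float, double>());
  cudaMemcpy(h_out, d_out, n * sizeof(double), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("%g ", h_out[i]);
  printf("\n");
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}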