From 2360406d3011aca2c6dd561a9dba3df96d061931 Mon Sep 17 00:00:00 2001
From: chentianyu03
Date: Fri, 10 Dec 2021 13:34:16 +0800
Subject: [PATCH] [PTen]fix pten::Copy use error (#37982)

* fix pten::Copy use error in reduce_impl

* remove in_dtype args in reduce kernel

* fix copy error

* fix copy error
---
 .../hybird/cuda/reduce/reduce_cuda_impl.h     | 27 ++++++++++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)

diff --git a/paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h b/paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h
index 94fab974ac..1f1b8ddd5f 100644
--- a/paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h
+++ b/paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h
@@ -769,6 +769,31 @@ static void LaunchReduceKernel(const Tx* x_data,
   }
 }
 
+// Copies `src` into `dst` through pten::Copy.
+//
+// The device context is taken from the DeviceContextPool instance: dst's
+// place is used when it is a GPU or NPU place, otherwise src's place is used.
+// The third argument to pten::Copy is `false` (presumably the blocking flag --
+// verify against pten::Copy), so do not assume the copy has completed on
+// return.
+//
+// Defined `inline` because this is a non-template function in a header;
+// without it, every translation unit including the header would emit its own
+// definition.
+inline void TensorCopy(const DenseTensor& src, DenseTensor* dst) {
+  paddle::platform::DeviceContextPool& pool =
+      paddle::platform::DeviceContextPool::Instance();
+  // NOTE(review): an NPU place is also cast to CUDADeviceContext below --
+  // confirm that combination is intended.
+  const bool use_dst_place = paddle::platform::is_gpu_place(dst->place()) ||
+                             paddle::platform::is_npu_place(dst->place());
+  const auto* dev_ctx =
+      static_cast<const paddle::platform::CUDADeviceContext*>(
+          pool.Get(use_dst_place ? dst->place() : src.place()));
+
+  pten::Copy(*dev_ctx, src, false, dst);
+}
+
 template <typename Tx,
           typename Ty,
           template <typename, typename> class ReduceOp>
@@ -800,7 +825,7 @@ void TensorReduceFunctorImpl(const pten::DenseTensor& x,
   if (config.reduce_num == 1) {
     auto out_dims = y->dims();
     if (x.dtype() == y->dtype()) {
-      pten::Copy(*dev_ctx, x, true, y);
+      TensorCopy(x, y);
       y->Resize(out_dims);
     } else {
       PD_VISIT_ALL_TYPES(y->dtype(), "CastKernelImpl", ([&] {
-- 
GitLab