From 07b83f2ec5b1bbe43f55672385eaa462ac31be24 Mon Sep 17 00:00:00 2001 From: ronnywang Date: Tue, 4 Jul 2023 10:40:41 +0800 Subject: [PATCH] [CustomDevice] refine set_constant_with_place by calling full kernel (#55089) --- paddle/phi/kernels/funcs/math_function.cc | 35 ++++++++++++++++------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/paddle/phi/kernels/funcs/math_function.cc b/paddle/phi/kernels/funcs/math_function.cc index b339a0ee5ac..10d18cc958a 100644 --- a/paddle/phi/kernels/funcs/math_function.cc +++ b/paddle/phi/kernels/funcs/math_function.cc @@ -34,6 +34,10 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/math_function_impl.h" #include "unsupported/Eigen/CXX11/Tensor" +#ifdef PADDLE_WITH_CUSTOM_DEVICE +#include "paddle/phi/api/lib/kernel_dispatch.h" +#include "paddle/phi/core/kernel_factory.h" +#endif namespace phi { namespace funcs { @@ -171,16 +175,27 @@ void set_constant_with_place(const phi::DeviceContext& context, template <> void set_constant_with_place( const phi::DeviceContext& context, phi::DenseTensor* tensor, float value) { - phi::DenseTensor tmp_tensor; - tmp_tensor.Resize(tensor->dims()); - context.HostAlloc(&tmp_tensor, tensor->dtype()); - phi::VisitDataType(tmp_tensor.dtype(), - TensorSetConstantCPU(&tmp_tensor, value)); - phi::memory_utils::Copy(tensor->place(), - tensor->data(), - phi::CPUPlace(), - tmp_tensor.data(), - tensor->numel() * phi::SizeOf(tensor->dtype())); +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( + "full", + {paddle::experimental::ParseBackend(tensor->place()), + phi::DataLayout::ALL_LAYOUT, + paddle::experimental::ParseDataType(tensor->dtype())}); + const auto& kernel = kernel_result.kernel; + using kernel_signature = void (*)(const phi::DeviceContext&, + const phi::IntArray&, + const phi::Scalar&, + DataType, + phi::DenseTensor*); + auto* kernel_fn = kernel.GetVariadicKernelFn(); + (*kernel_fn)(context, + phi::IntArray(phi::vectorize(tensor->dims())), + phi::Scalar(value), + tensor->dtype(), + tensor); +#else + PADDLE_THROW(phi::errors::Unimplemented("CustomPlace is not supported")); +#endif } template <> -- GitLab