Patch summary (free text; patch tools skip everything before "diff --git"):
  * Includes kernel_dispatch.h and kernel_factory.h, only when
    PADDLE_WITH_CUSTOM_DEVICE is defined.
  * set_constant_with_place<phi::CustomPlace>: instead of filling a host-side
    temporary tensor and copying it to the tensor's place, selects the "full"
    kernel for the tensor's backend and dtype and invokes it on the tensor.
    Builds without PADDLE_WITH_CUSTOM_DEVICE throw Unimplemented instead.

diff --git a/paddle/phi/kernels/funcs/math_function.cc b/paddle/phi/kernels/funcs/math_function.cc
index b339a0ee5ac0a06c4803db77712b363ebcfd91ea..10d18cc958ae365ee3f16a978732961542409349 100644
--- a/paddle/phi/kernels/funcs/math_function.cc
+++ b/paddle/phi/kernels/funcs/math_function.cc
@@ -34,6 +34,10 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/math_function_impl.h"
 #include "unsupported/Eigen/CXX11/Tensor"
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+#include "paddle/phi/api/lib/kernel_dispatch.h"
+#include "paddle/phi/core/kernel_factory.h"
+#endif
 
 namespace phi {
 namespace funcs {
@@ -171,16 +175,27 @@ void set_constant_with_place(const phi::DeviceContext& context,
 template <>
 void set_constant_with_place<phi::CustomPlace>(
     const phi::DeviceContext& context, phi::DenseTensor* tensor, float value) {
-  phi::DenseTensor tmp_tensor;
-  tmp_tensor.Resize(tensor->dims());
-  context.HostAlloc(&tmp_tensor, tensor->dtype());
-  phi::VisitDataType(tmp_tensor.dtype(),
-                     TensorSetConstantCPU(&tmp_tensor, value));
-  phi::memory_utils::Copy(tensor->place(),
-                          tensor->data(),
-                          phi::CPUPlace(),
-                          tmp_tensor.data(),
-                          tensor->numel() * phi::SizeOf(tensor->dtype()));
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      "full",
+      {paddle::experimental::ParseBackend(tensor->place()),
+       phi::DataLayout::ALL_LAYOUT,
+       paddle::experimental::ParseDataType(tensor->dtype())});
+  const auto& kernel = kernel_result.kernel;
+  using kernel_signature = void (*)(const phi::DeviceContext&,
+                                    const phi::IntArray&,
+                                    const phi::Scalar&,
+                                    DataType,
+                                    phi::DenseTensor*);
+  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+  (*kernel_fn)(context,
+               phi::IntArray(phi::vectorize(tensor->dims())),
+               phi::Scalar(value),
+               tensor->dtype(),
+               tensor);
+#else
+  PADDLE_THROW(phi::errors::Unimplemented("CustomPlace is not supported"));
+#endif
 }
 
 template <>