// Patch summary for paddle/phi/kernels/funcs/math_function.h (three hunks).
//
// 1. Include hunk: under PADDLE_WITH_XPU, adds includes for
//    paddle/phi/backends/context_pool.h, paddle/phi/backends/xpu/enforce_xpu.h
//    and paddle/phi/backends/xpu/xpu_header.h.
//
// 2. Existing TensorSetConstantXPU::apply(): the output buffer is no longer
//    obtained with tensor_->mutable_data(place_). The device context is looked
//    up with phi::DeviceContextPool::Instance().Get(place_) and the buffer is
//    obtained with ctx->Alloc(tensor_). The host-side fill of data_cpu that
//    follows is unchanged context.
//
// 3. New explicit specialization (template <>) of TensorSetConstantXPU whose
//    value_ member is a float. Its apply() obtains the buffer the same way as
//    in hunk 2, then takes one of two paths:
//      - If one of the two std::is_same element-type tests holds and
//        place_ == phi::XPUPlace(): calls xpu::constant(...) with
//        dev_ctx->x_context(), the buffer, numel and the converted value_;
//        checks the result with PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
//        then calls dev_ctx->Wait().
//      - Otherwise: allocates a host array of numel elements, fills it with
//        std::fill, and copies numel * sizeof(T) bytes from phi::CPUPlace()
//        to place_ with memory_utils::Copy.
//
// NOTE(review): template argument lists (after template, static_cast,
//   reinterpret_cast, std::is_same, std::unique_ptr, Alloc, XPUTypeTrait,
//   xpu::constant, and on the specialization name itself) appear to be missing
//   from this copy of the patch, presumably lost during extraction. Recover
//   them from the upstream commit before applying or reviewing further.
// NOTE(review): place_ == phi::XPUPlace() compares against a
//   default-constructed place; presumably that matches only the default
//   device. Confirm that other XPU devices are meant to take the host-copy
//   path.
// NOTE(review): int numel = tensor_->numel() may narrow if numel() returns a
//   wider integer type. TODO confirm.
// NOTE(review): the std::is_same test is an ordinary runtime if, so both
//   branches are compiled for every T passed to apply(). Verify that
//   XPUTypeTrait and xpu::constant can be instantiated for all such T.
diff --git a/paddle/phi/kernels/funcs/math_function.h b/paddle/phi/kernels/funcs/math_function.h index b42714e80db2f24a47c1978c71e71dac736ded56..5390d77c876f5fe37768d2a070670d0c34ff61bb 100644 --- a/paddle/phi/kernels/funcs/math_function.h +++ b/paddle/phi/kernels/funcs/math_function.h @@ -21,6 +21,11 @@ limitations under the License. */ #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/utils/data_type.h" +#ifdef PADDLE_WITH_XPU +#include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_header.h" +#endif namespace phi { namespace funcs { @@ -112,7 +117,8 @@ struct TensorSetConstantXPU { : tensor_(tensor), value_(value), place_(place) {} template void apply() const { - auto* begin = tensor_->mutable_data(place_); + auto* ctx = phi::DeviceContextPool::Instance().Get(place_); + auto begin = ctx->Alloc(tensor_); int numel = tensor_->numel(); std::unique_ptr data_cpu(new T[numel]); std::fill(data_cpu.get(), data_cpu.get() + numel, static_cast(value_)); @@ -126,6 +132,41 @@ struct TensorSetConstantXPU { U value_; phi::Place place_; }; + +template <> +struct TensorSetConstantXPU { + TensorSetConstantXPU(phi::DenseTensor* tensor, float value, phi::Place place) + : tensor_(tensor), value_(value), place_(place) {} + template + void apply() const { + auto* ctx = phi::DeviceContextPool::Instance().Get(place_); + auto begin = ctx->Alloc(tensor_); + int numel = tensor_->numel(); + if (((std::is_same::value) || + (std::is_same::value)) && + (place_ == phi::XPUPlace())) { + using XPUType = typename XPUTypeTrait::Type; + auto* dev_ctx = static_cast(ctx); + int r = xpu::constant(dev_ctx->x_context(), + reinterpret_cast(begin), + numel, + static_cast(value_)); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); + dev_ctx->Wait(); + } else { + std::unique_ptr data_cpu(new T[numel]); + std::fill(data_cpu.get(), data_cpu.get() + numel, 
static_cast(value_)); + memory_utils::Copy(place_, + begin, + phi::CPUPlace(), + static_cast(data_cpu.get()), + numel * sizeof(T)); + } + } + phi::DenseTensor* tensor_; + float value_; + phi::Place place_; +}; #endif template