未验证 提交 07b83f2e 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] refine set_constant_with_place by calling full kernel (#55089)

上级 b869e963
...@@ -34,6 +34,10 @@ limitations under the License. */ ...@@ -34,6 +34,10 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function_impl.h" #include "paddle/phi/kernels/funcs/math_function_impl.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/core/kernel_factory.h"
#endif
namespace phi { namespace phi {
namespace funcs { namespace funcs {
...@@ -171,16 +175,27 @@ void set_constant_with_place<phi::IPUPlace>(const phi::DeviceContext& context, ...@@ -171,16 +175,27 @@ void set_constant_with_place<phi::IPUPlace>(const phi::DeviceContext& context,
template <> template <>
void set_constant_with_place<phi::CustomPlace>( void set_constant_with_place<phi::CustomPlace>(
const phi::DeviceContext& context, phi::DenseTensor* tensor, float value) { const phi::DeviceContext& context, phi::DenseTensor* tensor, float value) {
phi::DenseTensor tmp_tensor; #ifdef PADDLE_WITH_CUSTOM_DEVICE
tmp_tensor.Resize(tensor->dims()); auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
context.HostAlloc(&tmp_tensor, tensor->dtype()); "full",
phi::VisitDataType(tmp_tensor.dtype(), {paddle::experimental::ParseBackend(tensor->place()),
TensorSetConstantCPU(&tmp_tensor, value)); phi::DataLayout::ALL_LAYOUT,
phi::memory_utils::Copy(tensor->place(), paddle::experimental::ParseDataType(tensor->dtype())});
tensor->data(), const auto& kernel = kernel_result.kernel;
phi::CPUPlace(), using kernel_signature = void (*)(const phi::DeviceContext&,
tmp_tensor.data(), const phi::IntArray&,
tensor->numel() * phi::SizeOf(tensor->dtype())); const phi::Scalar&,
DataType,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(context,
phi::IntArray(phi::vectorize(tensor->dims())),
phi::Scalar(value),
tensor->dtype(),
tensor);
#else
PADDLE_THROW(phi::errors::Unimplemented("CustomPlace is not supported"));
#endif
} }
template <> template <>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册