未验证 提交 07b83f2e 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] refine set_constant_with_place by calling full kernel (#55089)

上级 b869e963
......@@ -34,6 +34,10 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function_impl.h"
#include "unsupported/Eigen/CXX11/Tensor"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/core/kernel_factory.h"
#endif
namespace phi {
namespace funcs {
......@@ -171,16 +175,27 @@ void set_constant_with_place<phi::IPUPlace>(const phi::DeviceContext& context,
template <>
void set_constant_with_place<phi::CustomPlace>(
const phi::DeviceContext& context, phi::DenseTensor* tensor, float value) {
phi::DenseTensor tmp_tensor;
tmp_tensor.Resize(tensor->dims());
context.HostAlloc(&tmp_tensor, tensor->dtype());
phi::VisitDataType(tmp_tensor.dtype(),
TensorSetConstantCPU(&tmp_tensor, value));
phi::memory_utils::Copy(tensor->place(),
tensor->data(),
phi::CPUPlace(),
tmp_tensor.data(),
tensor->numel() * phi::SizeOf(tensor->dtype()));
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"full",
{paddle::experimental::ParseBackend(tensor->place()),
phi::DataLayout::ALL_LAYOUT,
paddle::experimental::ParseDataType(tensor->dtype())});
const auto& kernel = kernel_result.kernel;
using kernel_signature = void (*)(const phi::DeviceContext&,
const phi::IntArray&,
const phi::Scalar&,
DataType,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(context,
phi::IntArray(phi::vectorize(tensor->dims())),
phi::Scalar(value),
tensor->dtype(),
tensor);
#else
PADDLE_THROW(phi::errors::Unimplemented("CustomPlace is not supported"));
#endif
}
template <>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册