未验证 提交 f22b9666 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] fix set_constant (#52360)

上级 4e23af72
......@@ -184,7 +184,16 @@ void set_constant_with_place<phi::IPUPlace>(const phi::DeviceContext& context,
template <>
void set_constant_with_place<phi::CustomPlace>(
const phi::DeviceContext& context, phi::DenseTensor* tensor, float value) {
PADDLE_THROW(phi::errors::Unimplemented("CustomPlace is not supported"));
phi::DenseTensor tmp_tensor;
tmp_tensor.Resize(tensor->dims());
context.HostAlloc(&tmp_tensor, tensor->dtype());
phi::VisitDataType(tmp_tensor.dtype(),
TensorSetConstantCPU(&tmp_tensor, value));
phi::memory_utils::Copy(tensor->place(),
tensor->data(),
phi::CPUPlace(),
tmp_tensor.data(),
tensor->numel() * phi::SizeOf(tensor->dtype()));
}
template <>
......@@ -230,7 +239,7 @@ void set_constant(const phi::DeviceContext& context,
TensorSetConstantWithPlace func(context, tensor, value);
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (context.GetPlace().GetType() == phi::AllocationType::CUSTOM) {
func(phi::CPUPlace());
func(phi::CustomPlace());
return;
}
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册