未验证 提交 f22b9666 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] fix set_constant (#52360)

上级 4e23af72
...@@ -184,7 +184,16 @@ void set_constant_with_place<phi::IPUPlace>(const phi::DeviceContext& context, ...@@ -184,7 +184,16 @@ void set_constant_with_place<phi::IPUPlace>(const phi::DeviceContext& context,
template <> template <>
void set_constant_with_place<phi::CustomPlace>( void set_constant_with_place<phi::CustomPlace>(
const phi::DeviceContext& context, phi::DenseTensor* tensor, float value) { const phi::DeviceContext& context, phi::DenseTensor* tensor, float value) {
PADDLE_THROW(phi::errors::Unimplemented("CustomPlace is not supported")); phi::DenseTensor tmp_tensor;
tmp_tensor.Resize(tensor->dims());
context.HostAlloc(&tmp_tensor, tensor->dtype());
phi::VisitDataType(tmp_tensor.dtype(),
TensorSetConstantCPU(&tmp_tensor, value));
phi::memory_utils::Copy(tensor->place(),
tensor->data(),
phi::CPUPlace(),
tmp_tensor.data(),
tensor->numel() * phi::SizeOf(tensor->dtype()));
} }
template <> template <>
...@@ -230,7 +239,7 @@ void set_constant(const phi::DeviceContext& context, ...@@ -230,7 +239,7 @@ void set_constant(const phi::DeviceContext& context,
TensorSetConstantWithPlace func(context, tensor, value); TensorSetConstantWithPlace func(context, tensor, value);
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
if (context.GetPlace().GetType() == phi::AllocationType::CUSTOM) { if (context.GetPlace().GetType() == phi::AllocationType::CUSTOM) {
func(phi::CPUPlace()); func(phi::CustomPlace());
return; return;
} }
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册