未验证 提交 db5204ec 编写于 作者: R Ruibin Cheung 提交者: GitHub

[Fix Bug] fix get_new_shape and get_new_data_from_tensor not support fallback...

[Fix Bug] fix get_new_shape and get_new_data_from_tensor not support fallback to CPU on custom device (#52002)
上级 b811043a
...@@ -94,7 +94,14 @@ inline std::vector<int> get_new_shape( ...@@ -94,7 +94,14 @@ inline std::vector<int> get_new_shape(
"The shape of dimension tensor should be [1] or []," "The shape of dimension tensor should be [1] or [],"
"but received d%.", "but received d%.",
tensor->dims())); tensor->dims()));
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (tensor->place().GetType() == phi::AllocationType::CUSTOM) {
DenseTensor temp;
phi::Copy(*dev_ctx, *tensor, phi::CPUPlace(), true, &temp);
vec_new_shape.push_back(static_cast<int32_t>(*temp.data<int32_t>()));
continue;
}
#endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
if (tensor->place().GetType() == phi::AllocationType::XPU) { if (tensor->place().GetType() == phi::AllocationType::XPU) {
DenseTensor temp; DenseTensor temp;
...@@ -128,6 +135,13 @@ inline std::vector<T> get_new_data_from_tensor( ...@@ -128,6 +135,13 @@ inline std::vector<T> get_new_data_from_tensor(
*dev_ctx, *new_data_tensor, phi::CPUPlace(), true, &cpu_starts_tensor); *dev_ctx, *new_data_tensor, phi::CPUPlace(), true, &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>(); new_data = cpu_starts_tensor.data<T>();
} }
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (new_data_tensor->place().GetType() == phi::AllocationType::CUSTOM) {
phi::Copy(
*dev_ctx, *new_data_tensor, phi::CPUPlace(), true, &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>();
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
if (new_data_tensor->place().GetType() == phi::AllocationType::NPU) { if (new_data_tensor->place().GetType() == phi::AllocationType::NPU) {
phi::Copy( phi::Copy(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册