未验证 提交 a68709d8 编写于 作者: H houj04 提交者: GitHub

add NPU support for zero_copy_tensor. (#34629)

* add NPU support for zero_copy_tensor.

* revert unnesessary codes.

* revert unnesessary codes.
上级 7a38b769
...@@ -65,10 +65,13 @@ T *Tensor::mutable_data(PlaceType place) { ...@@ -65,10 +65,13 @@ T *Tensor::mutable_data(PlaceType place) {
case static_cast<int>(PlaceType::kXPU): { case static_cast<int>(PlaceType::kXPU): {
return tensor->mutable_data<T>(paddle::platform::XPUPlace(device_)); return tensor->mutable_data<T>(paddle::platform::XPUPlace(device_));
} }
case static_cast<int>(PlaceType::kNPU): {
return tensor->mutable_data<T>(paddle::platform::NPUPlace(device_));
}
default: default:
PADDLE_THROW(paddle::platform::errors::Unavailable( PADDLE_THROW(paddle::platform::errors::Unavailable(
"Only CPU / CUDA / XPU places is supported. The place `%d` is not " "Only CPU / CUDA / XPU / NPU places is supported. The place `%d` is "
"supported.", "not supported.",
static_cast<int>(place))); static_cast<int>(place)));
break; break;
} }
...@@ -86,6 +89,8 @@ T *Tensor::data(PlaceType *place, int *size) const { ...@@ -86,6 +89,8 @@ T *Tensor::data(PlaceType *place, int *size) const {
*place = PlaceType::kGPU; *place = PlaceType::kGPU;
} else if (paddle::platform::is_xpu_place(tensor->place())) { } else if (paddle::platform::is_xpu_place(tensor->place())) {
*place = PlaceType::kXPU; *place = PlaceType::kXPU;
} else if (paddle::platform::is_npu_place(tensor->place())) {
*place = PlaceType::kNPU;
} else { } else {
*place = PlaceType::kUNK; *place = PlaceType::kUNK;
} }
......
...@@ -133,6 +133,10 @@ TEST(Tensor, FillRandomDataAndCheck) { ...@@ -133,6 +133,10 @@ TEST(Tensor, FillRandomDataAndCheck) {
ASSERT_TRUE(FillRandomDataAndCheck(PlaceType::kGPU)); ASSERT_TRUE(FillRandomDataAndCheck(PlaceType::kGPU));
ASSERT_TRUE(SetPlaceAndCheck(PlaceType::kGPU)); ASSERT_TRUE(SetPlaceAndCheck(PlaceType::kGPU));
#endif #endif
#ifdef PADDLE_WITH_ASCEND_CL
ASSERT_TRUE(FillRandomDataAndCheck(PlaceType::kNPU));
ASSERT_TRUE(SetPlaceAndCheck(PlaceType::kNPU));
#endif
} }
} // namespace paddle_infer } // namespace paddle_infer
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册