未验证 提交 8a9dc5dc 编写于 作者: T TTerror 提交者: GitHub

add get xpu version api (#34594)

上级 cabfb4a7
......@@ -35,7 +35,7 @@ ELSE ()
ENDIF()
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210729")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210804")
SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
......
......@@ -1632,7 +1632,13 @@ All parameter, weight, gradient are variables in Paddle.
.def("__repr__", string::to_string<const platform::XPUPlace &>)
.def("__str__", string::to_string<const platform::XPUPlace &>);
#ifdef PADDLE_WITH_XPU
py::enum_<platform::XPUVersion>(m, "XPUVersion", py::arithmetic())
.value("XPU1", platform::XPUVersion::XPU1)
.value("XPU2", platform::XPUVersion::XPU2)
.export_values();
m.def("get_xpu_device_count", platform::GetXPUDeviceCount);
m.def("get_xpu_device_version",
[](int device_id) { return platform::get_xpu_version(device_id); });
#endif
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace", R"DOC(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册