未验证 提交 d90e24ac 编写于 作者: Q QingshuChen 提交者: GitHub

update xpu depends (#42365)

* update xpu depends
*test=kunlun

* minor
*test=kunlun
Co-authored-by: Nroot <root@yq01-sys-hic-p40-0091.yq01.baidu.com>
上级 2006b817
...@@ -9,7 +9,7 @@ SET(XPU_RT_LIB_NAME "libxpurt.so") ...@@ -9,7 +9,7 @@ SET(XPU_RT_LIB_NAME "libxpurt.so")
if(NOT DEFINED XPU_BASE_URL) if(NOT DEFINED XPU_BASE_URL)
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220411") SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220425")
else() else()
SET(XPU_BASE_URL "${XPU_BASE_URL}") SET(XPU_BASE_URL "${XPU_BASE_URL}")
endif() endif()
......
...@@ -54,7 +54,10 @@ std::vector<int> GetXPUSelectedDevices() { ...@@ -54,7 +54,10 @@ std::vector<int> GetXPUSelectedDevices() {
void MemcpySyncH2D(void* dst, const void* src, size_t count, void MemcpySyncH2D(void* dst, const void* src, size_t count,
const platform::XPUPlace& dst_place) { const platform::XPUPlace& dst_place) {
phi::backends::xpu::MemcpySyncH2D(dst, src, count, dst_place); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.GetByPlace(dst_place);
dev_ctx->Wait();
phi::backends::xpu::MemcpySyncH2D(dst, src, count, dst_place, *dev_ctx);
} }
void MemcpySyncD2H(void* dst, const void* src, size_t count, void MemcpySyncD2H(void* dst, const void* src, size_t count,
......
...@@ -140,8 +140,10 @@ std::vector<int> GetXPUSelectedDevices() { ...@@ -140,8 +140,10 @@ std::vector<int> GetXPUSelectedDevices() {
void MemcpySyncH2D(void* dst, void MemcpySyncH2D(void* dst,
const void* src, const void* src,
size_t count, size_t count,
const phi::XPUPlace& dst_place) { const phi::XPUPlace& dst_place,
const phi::XPUContext& dev_ctx) {
XPUDeviceGuard guard(dst_place.device); XPUDeviceGuard guard(dst_place.device);
dev_ctx.Wait();
PADDLE_ENFORCE_XPU_SUCCESS( PADDLE_ENFORCE_XPU_SUCCESS(
xpu_memcpy(dst, src, count, XPUMemcpyKind::XPU_HOST_TO_DEVICE)); xpu_memcpy(dst, src, count, XPUMemcpyKind::XPU_HOST_TO_DEVICE));
} }
......
...@@ -49,7 +49,8 @@ std::vector<int> GetXPUSelectedDevices(); ...@@ -49,7 +49,8 @@ std::vector<int> GetXPUSelectedDevices();
void MemcpySyncH2D(void *dst, void MemcpySyncH2D(void *dst,
const void *src, const void *src,
size_t count, size_t count,
const phi::XPUPlace &dst_place); const phi::XPUPlace &dst_place,
const phi::XPUContext &dev_ctx);
void MemcpySyncD2H(void *dst, void MemcpySyncD2H(void *dst,
const void *src, const void *src,
size_t count, size_t count,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册