diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index be911eb7eaced944cc1562f61fbce91062f2f6f7..d5ccf1297922f5dfb08993aa37200db194be9a71 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -9,7 +9,7 @@ SET(XPU_RT_LIB_NAME "libxpurt.so") if(NOT DEFINED XPU_BASE_URL) SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") - SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220411") + SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220425") else() SET(XPU_BASE_URL "${XPU_BASE_URL}") endif() diff --git a/paddle/fluid/platform/device/xpu/xpu_info.cc b/paddle/fluid/platform/device/xpu/xpu_info.cc index 6a58f7890f9fa932f21147adeb56f8bd49887d04..2e960c1c0dd9cbb6ecabfdcf98872f73c9c9fd61 100644 --- a/paddle/fluid/platform/device/xpu/xpu_info.cc +++ b/paddle/fluid/platform/device/xpu/xpu_info.cc @@ -54,7 +54,10 @@ std::vector GetXPUSelectedDevices() { void MemcpySyncH2D(void* dst, const void* src, size_t count, const platform::XPUPlace& dst_place) { - phi::backends::xpu::MemcpySyncH2D(dst, src, count, dst_place); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.GetByPlace(dst_place); + dev_ctx->Wait(); + phi::backends::xpu::MemcpySyncH2D(dst, src, count, dst_place, *dev_ctx); } void MemcpySyncD2H(void* dst, const void* src, size_t count, diff --git a/paddle/phi/backends/xpu/xpu_info.cc b/paddle/phi/backends/xpu/xpu_info.cc index d454fc0734c66aca37a55c53ec5a2d9206cfcc5b..4dba0ab94ff2014829260f1e7c078b54ce4dd117 100644 --- a/paddle/phi/backends/xpu/xpu_info.cc +++ b/paddle/phi/backends/xpu/xpu_info.cc @@ -140,8 +140,10 @@ std::vector GetXPUSelectedDevices() { void MemcpySyncH2D(void* dst, const void* src, size_t count, - const phi::XPUPlace& dst_place) { + const phi::XPUPlace& dst_place, + const phi::XPUContext& dev_ctx) { XPUDeviceGuard guard(dst_place.device); + dev_ctx.Wait(); PADDLE_ENFORCE_XPU_SUCCESS( xpu_memcpy(dst, src, count, XPUMemcpyKind::XPU_HOST_TO_DEVICE)); } diff --git a/paddle/phi/backends/xpu/xpu_info.h b/paddle/phi/backends/xpu/xpu_info.h index fa7d1b5c18a7d26c50d772717c32f188ba2b4bf4..b1056cdc4b14bf4abe80984563027f60eda1b283 100644 --- a/paddle/phi/backends/xpu/xpu_info.h +++ b/paddle/phi/backends/xpu/xpu_info.h @@ -49,7 +49,8 @@ std::vector GetXPUSelectedDevices(); void MemcpySyncH2D(void *dst, const void *src, size_t count, - const phi::XPUPlace &dst_place); + const phi::XPUPlace &dst_place, + const phi::XPUContext &dev_ctx); void MemcpySyncD2H(void *dst, const void *src, size_t count,