未验证 提交 d90e24ac 编写于 作者: Q QingshuChen 提交者: GitHub

update xpu depends (#42365)

* update xpu depends
*test=kunlun

* minor
*test=kunlun
Co-authored-by: Nroot <root@yq01-sys-hic-p40-0091.yq01.baidu.com>
上级 2006b817
......@@ -9,7 +9,7 @@ SET(XPU_RT_LIB_NAME "libxpurt.so")
if(NOT DEFINED XPU_BASE_URL)
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220411")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220425")
else()
SET(XPU_BASE_URL "${XPU_BASE_URL}")
endif()
......
......@@ -54,7 +54,10 @@ std::vector<int> GetXPUSelectedDevices() {
void MemcpySyncH2D(void* dst, const void* src, size_t count,
const platform::XPUPlace& dst_place) {
phi::backends::xpu::MemcpySyncH2D(dst, src, count, dst_place);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.GetByPlace(dst_place);
dev_ctx->Wait();
phi::backends::xpu::MemcpySyncH2D(dst, src, count, dst_place, *dev_ctx);
}
void MemcpySyncD2H(void* dst, const void* src, size_t count,
......
......@@ -140,8 +140,10 @@ std::vector<int> GetXPUSelectedDevices() {
void MemcpySyncH2D(void* dst,
const void* src,
size_t count,
const phi::XPUPlace& dst_place) {
const phi::XPUPlace& dst_place,
const phi::XPUContext& dev_ctx) {
XPUDeviceGuard guard(dst_place.device);
dev_ctx.Wait();
PADDLE_ENFORCE_XPU_SUCCESS(
xpu_memcpy(dst, src, count, XPUMemcpyKind::XPU_HOST_TO_DEVICE));
}
......
......@@ -49,7 +49,8 @@ std::vector<int> GetXPUSelectedDevices();
void MemcpySyncH2D(void *dst,
const void *src,
size_t count,
const phi::XPUPlace &dst_place);
const phi::XPUPlace &dst_place,
const phi::XPUContext &dev_ctx);
void MemcpySyncD2H(void *dst,
const void *src,
size_t count,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册