未验证 提交 16562a9d 编写于 作者: J james 提交者: GitHub

use correct xpu stream for synchronization (#48470)

Some legacy code still uses xpu_wait() for stream sync, but xpu_wait() only
syncs the default stream. This PR replaces those calls with dev_ctx.Wait() to
ensure that the correct stream is always used.
上级 7bf7e6e0
...@@ -151,7 +151,7 @@ class BinaryLogicalOpXPUKernel : public framework::OpKernel<T> { ...@@ -151,7 +151,7 @@ class BinaryLogicalOpXPUKernel : public framework::OpKernel<T> {
XpuLogicalType2Str(xpu_type))); XpuLogicalType2Str(xpu_type)));
if (need_broad_cast && dev_ctx.x_context()->xpu_stream != nullptr) { if (need_broad_cast && dev_ctx.x_context()->xpu_stream != nullptr) {
xpu_wait(); dev_ctx.Wait();
} }
} }
}; };
......
...@@ -72,7 +72,7 @@ void MemcpySyncD2H(void* dst, ...@@ -72,7 +72,7 @@ void MemcpySyncD2H(void* dst,
} }
// if src.device == dst.device and you need sync , after call this function, // if src.device == dst.device and you need sync , after call this function,
// need to call xpu_wait() // need to call dev_ctx.Wait()
void MemcpySyncD2D(void* dst, void MemcpySyncD2D(void* dst,
const platform::XPUPlace& dst_place, const platform::XPUPlace& dst_place,
const void* src, const void* src,
......
...@@ -160,7 +160,7 @@ void MemcpySyncD2H(void* dst, ...@@ -160,7 +160,7 @@ void MemcpySyncD2H(void* dst,
} }
// if src.device == dst.device and you need sync , after call this function, // if src.device == dst.device and you need sync , after call this function,
// need to call xpu_wait() // need to call dev_ctx.Wait()
void MemcpySyncD2D(void* dst, void MemcpySyncD2D(void* dst,
const phi::XPUPlace& dst_place, const phi::XPUPlace& dst_place,
const void* src, const void* src,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册