提交 fbbcedda 编写于 作者: Y Yu Yang

Fix bug

上级 7643c2cb
...@@ -108,14 +108,13 @@ struct OpHandle { ...@@ -108,14 +108,13 @@ struct OpHandle {
} }
virtual void Wait(platform::DeviceContext *waited_dev) { virtual void Wait(platform::DeviceContext *waited_dev) {
if (platform::is_cpu_place(waited_dev->GetPlace()) && events_.empty()) { if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
for (auto &dev_ctx : dev_ctx_) { for (auto &dev_ctx : dev_ctx_) {
dev_ctx.second->Wait(); dev_ctx.second->Wait();
} }
} else { } else {
auto stream = auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream(); static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
for (auto &ev : events_) { for (auto &ev : events_) {
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0)); PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册