未验证 提交 1acb845a 作者: TaoTao Li 提交者: GitHub

fix distributed comm context (#52787)

上级 e7652a37
......@@ -148,8 +148,7 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext(
const int stream_priority = op_func_node.stream_priority_;
ContextManager& ctx_manager = ContextManager::Instance();
auto dev_ctx = ctx_manager.Get(op_type, place_, stream_priority).get().get();
SetDeviceCommContext(op.get(), dev_ctx);
DeviceContext* dev_ctx = nullptr;
// Only GPU/NPU need the update; XPU does not, because the XPU memcpy op
// kernel is synchronous.
......@@ -158,22 +157,30 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext(
VLOG(6) << "Parse DeviceContext for " << op_type
<< ", execution stream = " << execution_stream;
if (execution_stream != kDefaultStream) {
return ctx_manager
.Get(std::string(kCustomStream) + "-" + execution_stream,
place_,
stream_priority)
.get()
.get();
dev_ctx = ctx_manager
.Get(std::string(kCustomStream) + "-" + execution_stream,
place_,
stream_priority)
.get()
.get();
SetDeviceCommContext(op.get(), dev_ctx);
return dev_ctx;
}
if (op_type == interpreter::kMemcpyD2H) {
return ctx_manager.Get(std::string(kD2HStream), place_, stream_priority)
.get()
.get();
dev_ctx =
ctx_manager.Get(std::string(kD2HStream), place_, stream_priority)
.get()
.get();
SetDeviceCommContext(op.get(), dev_ctx);
return dev_ctx;
} else if (op_type == interpreter::kMemcpyH2D) {
return ctx_manager.Get(std::string(kH2DStream), place_, stream_priority)
.get()
.get();
dev_ctx =
ctx_manager.Get(std::string(kH2DStream), place_, stream_priority)
.get()
.get();
SetDeviceCommContext(op.get(), dev_ctx);
return dev_ctx;
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
......@@ -195,6 +202,7 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext(
#endif
}
SetDeviceCommContext(op.get(), op_func_node.dev_ctx_);
return op_func_node.dev_ctx_;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册