未验证 提交 1acb845a 编写于 作者: TaoTao Li's avatar TaoTao Li 提交者: GitHub

fix distributed comm context (#52787)

上级 e7652a37
...@@ -148,8 +148,7 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext( ...@@ -148,8 +148,7 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext(
const int stream_priority = op_func_node.stream_priority_; const int stream_priority = op_func_node.stream_priority_;
ContextManager& ctx_manager = ContextManager::Instance(); ContextManager& ctx_manager = ContextManager::Instance();
auto dev_ctx = ctx_manager.Get(op_type, place_, stream_priority).get().get(); DeviceContext* dev_ctx = nullptr;
SetDeviceCommContext(op.get(), dev_ctx);
// Only gpu/npu need this update; xpu does not, because the xpu memcpy op kernel is // Only gpu/npu need this update; xpu does not, because the xpu memcpy op kernel is
// synchronous. // synchronous.
...@@ -158,22 +157,30 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext( ...@@ -158,22 +157,30 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext(
VLOG(6) << "Parse DeviceContext for " << op_type VLOG(6) << "Parse DeviceContext for " << op_type
<< ", execution stream = " << execution_stream; << ", execution stream = " << execution_stream;
if (execution_stream != kDefaultStream) { if (execution_stream != kDefaultStream) {
return ctx_manager dev_ctx = ctx_manager
.Get(std::string(kCustomStream) + "-" + execution_stream, .Get(std::string(kCustomStream) + "-" + execution_stream,
place_, place_,
stream_priority) stream_priority)
.get() .get()
.get(); .get();
SetDeviceCommContext(op.get(), dev_ctx);
return dev_ctx;
} }
if (op_type == interpreter::kMemcpyD2H) { if (op_type == interpreter::kMemcpyD2H) {
return ctx_manager.Get(std::string(kD2HStream), place_, stream_priority) dev_ctx =
.get() ctx_manager.Get(std::string(kD2HStream), place_, stream_priority)
.get(); .get()
.get();
SetDeviceCommContext(op.get(), dev_ctx);
return dev_ctx;
} else if (op_type == interpreter::kMemcpyH2D) { } else if (op_type == interpreter::kMemcpyH2D) {
return ctx_manager.Get(std::string(kH2DStream), place_, stream_priority) dev_ctx =
.get() ctx_manager.Get(std::string(kH2DStream), place_, stream_priority)
.get(); .get()
.get();
SetDeviceCommContext(op.get(), dev_ctx);
return dev_ctx;
} }
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
...@@ -195,6 +202,7 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext( ...@@ -195,6 +202,7 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext(
#endif #endif
} }
SetDeviceCommContext(op.get(), op_func_node.dev_ctx_);
return op_func_node.dev_ctx_; return op_func_node.dev_ctx_;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册