From d00b7d83a7d4ea36d3eca4d8bc54d97d8f0690a9 Mon Sep 17 00:00:00 2001 From: Ruibiao Chen Date: Wed, 19 Oct 2022 12:07:58 +0800 Subject: [PATCH] Support stream overlap for c_allreduce_sum (#47030) * Support stream overlap for c_allreduce_sum * Test CI * Add notes * Add SingleStreamGuard for BuildOpFuncList --- paddle/fluid/framework/ir/graph_helper.cc | 35 ++++- .../framework/new_executor/CMakeLists.txt | 45 +------ .../new_executor/interpreter/CMakeLists.txt | 42 +++++- .../{ => interpreter}/data_transfer.cc | 2 +- .../{ => interpreter}/data_transfer.h | 0 .../interpreter/dependency_builder.cc | 32 +---- .../{ => interpreter}/event_manager.cc | 2 +- .../{ => interpreter}/event_manager.h | 0 .../interpreter_util.cc} | 59 ++++++++- .../interpreter_util.h} | 2 + .../framework/new_executor/interpretercore.cc | 4 +- .../framework/new_executor/interpretercore.h | 4 +- .../new_executor/standalone_executor.cc | 2 +- .../framework/new_executor/stream_analyzer.cc | 124 +++++++++++------- .../framework/new_executor/stream_analyzer.h | 4 +- 15 files changed, 211 insertions(+), 146 deletions(-) rename paddle/fluid/framework/new_executor/{ => interpreter}/data_transfer.cc (99%) rename paddle/fluid/framework/new_executor/{ => interpreter}/data_transfer.h (100%) rename paddle/fluid/framework/new_executor/{ => interpreter}/event_manager.cc (97%) rename paddle/fluid/framework/new_executor/{ => interpreter}/event_manager.h (100%) rename paddle/fluid/framework/new_executor/{interpretercore_util.cc => interpreter/interpreter_util.cc} (95%) rename paddle/fluid/framework/new_executor/{interpretercore_util.h => interpreter/interpreter_util.h} (98%) diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 3db9814374..f5f28219ec 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -511,8 +511,10 @@ void ReplaceAllReduceOp(const Node &node, std::vector *ops) { #if defined(PADDLE_WITH_NCCL) 
|| defined(PADDLE_WITH_RCCL) bool is_fused = (node.Name() == "fused_all_reduce"); + details::OpHandleBase &op_handle = const_cast(&node)->Wrapper(); + auto &in_var_handles = op_handle.Inputs(); // Even if PADDLE_WITH_NCCL is defined, if the program runs on CPU, // nccl_ctxs_ in NCCLOpHandleBase will be nullptr, and calling the @@ -537,7 +539,6 @@ void ReplaceAllReduceOp(const Node &node, << all_reduce_var_name; // get inputs of check_memory_continue - auto in_var_handles = op_handle.Inputs(); std::vector in_names; for (const auto &in : in_var_handles) { if (dynamic_cast(in) != nullptr) { @@ -555,7 +556,7 @@ void ReplaceAllReduceOp(const Node &node, fuse_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), (static_cast(OpRole::kBackward))); } else { - all_reduce_var_name = op_handle.Inputs()[0]->Name(); + all_reduce_var_name = in_var_handles[0]->Name(); } // add c_allreduce_sum OP @@ -568,7 +569,7 @@ void ReplaceAllReduceOp(const Node &node, int ring_id = platform::NCCLCommContext::Instance().GetRingId( dynamic_cast(&op_handle)->GetComm()); all_reduce_op_desc.SetAttr("ring_id", ring_id); - all_reduce_op_desc.SetAttr("use_calc_stream", true); + all_reduce_op_desc.SetAttr("use_calc_stream", false); all_reduce_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), (static_cast(OpRole::kBackward))); @@ -586,6 +587,34 @@ void ReplaceAllReduceOp(const Node &node, ->GradMergeCondName(); all_reduce_op_desc.SetInput("Cond", {cond_name}); } + + // Add dependency for FusedAllReduce. + // For the following example: + // ### fused_grad = FusedAllReduce(grad0, grad1, grad2, ...) + // ### v0 = op0(grad0) + // ### v1 = op1(grad1) + // It is converted to: + // ### fused_grad = check_memory_continue(grad0, grad1, grad2, ...) 
+ // ### fused_grad = c_allreduce_sum(fused_grad) + // ### v0 = op0(grad0) + // ### v1 = op1(grad1) + // We should add the following dependency to ensure that op0 and op1 both run + // after c_allreduce_sum: + // ### grad0 = depend(grad0, fused_grad) + // ### grad1 = depend(grad1, fused_grad) + if (is_fused) { + for (const auto &in : in_var_handles) { + if (dynamic_cast(in) != nullptr) { + continue; + } + ops->emplace_back(); + OpDesc &depend_op_desc = ops->back(); + depend_op_desc.SetType("depend"); + depend_op_desc.SetInput("X", {in->Name()}); + depend_op_desc.SetInput("Dep", {all_reduce_var_name}); + depend_op_desc.SetOutput("Out", {in->Name()}); + } + } #else PADDLE_THROW( platform::errors::Unimplemented("ReplaceAllReduceOp is only implemented " diff --git a/paddle/fluid/framework/new_executor/CMakeLists.txt b/paddle/fluid/framework/new_executor/CMakeLists.txt index 2e6f273490..ca9bf5f86d 100644 --- a/paddle/fluid/framework/new_executor/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/CMakeLists.txt @@ -2,48 +2,11 @@ add_subdirectory(garbage_collector) add_subdirectory(interpreter) add_subdirectory(workqueue) -set(STANDALONE_EXECUTOR_SRCS - data_transfer.cc - new_executor_defs.cc - interpretercore_util.cc - event_manager.cc - stream_analyzer.cc - interpretercore.cc - standalone_executor.cc) +set(STANDALONE_EXECUTOR_SRCS interpretercore.cc new_executor_defs.cc + stream_analyzer.cc standalone_executor.cc) -set(STANDALONE_EXECUTOR_DEPS - dependency_builder - device_context - execution_config - op_registry - scope - framework_proto - data_feed_proto - ops_extra_info - heter_service_proto - trainer_desc_proto - glog - lod_rank_table - fs - shell - fleet_wrapper - heter_wrapper - ps_gpu_wrapper - box_wrapper - lodtensor_printer - feed_fetch_method - graph_to_program_pass - variable_helper - timer - monitor - nan_inf_utils - enforce - scope - glog - workqueue - interpretercore_garbage_collector - ${DEVICE_EVENT_LIBS} - glog) +set(STANDALONE_EXECUTOR_DEPS 
interpreter interpretercore_garbage_collector + workqueue) cc_library( standalone_executor diff --git a/paddle/fluid/framework/new_executor/interpreter/CMakeLists.txt b/paddle/fluid/framework/new_executor/interpreter/CMakeLists.txt index dc4b2e6407..019e97e689 100644 --- a/paddle/fluid/framework/new_executor/interpreter/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/interpreter/CMakeLists.txt @@ -1,9 +1,37 @@ -cc_library( - dependency_builder - SRCS dependency_builder.cc - DEPS operator) +set(INTERPRETER_SRCS data_transfer.cc dependency_builder.cc event_manager.cc + execution_config.cc interpreter_util.cc) + +set(INTERPRETER_DEPS + device_context + op_registry + scope + framework_proto + data_feed_proto + ops_extra_info + heter_service_proto + trainer_desc_proto + glog + lod_rank_table + fs + shell + fleet_wrapper + heter_wrapper + ps_gpu_wrapper + box_wrapper + lodtensor_printer + feed_fetch_method + graph_to_program_pass + variable_helper + timer + monitor + nan_inf_utils + enforce + scope + glog + ${DEVICE_EVENT_LIBS} + glog) cc_library( - execution_config - SRCS execution_config.cc - DEPS phi_backends) + interpreter + SRCS ${INTERPRETER_SRCS} + DEPS ${INTERPRETER_DEPS}) diff --git a/paddle/fluid/framework/new_executor/data_transfer.cc b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc similarity index 99% rename from paddle/fluid/framework/new_executor/data_transfer.cc rename to paddle/fluid/framework/new_executor/interpreter/data_transfer.cc index 24ec52ca3b..3c6a0740a5 100644 --- a/paddle/fluid/framework/new_executor/data_transfer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/framework/new_executor/data_transfer.h" +#include "paddle/fluid/framework/new_executor/interpreter/data_transfer.h" #include "paddle/fluid/framework/convert_utils.h" #include "paddle/phi/core/kernel_context.h" diff --git a/paddle/fluid/framework/new_executor/data_transfer.h b/paddle/fluid/framework/new_executor/interpreter/data_transfer.h similarity index 100% rename from paddle/fluid/framework/new_executor/data_transfer.h rename to paddle/fluid/framework/new_executor/interpreter/data_transfer.h diff --git a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc index 49e675a473..3b2a2aed7f 100644 --- a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc +++ b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/new_executor/interpreter/dependency_builder.h" #include +#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" namespace paddle { namespace framework { @@ -27,22 +28,6 @@ size_t CountDownstreamMap(const std::map>& downstream_map) { } return count; } - -bool IsCommunicationOp(const std::string& op_name) { - const std::set special_comm_op_set = { - "send", - "recv", - "send_v2", - "recv_v2", - }; - const std::string communication_op_prefix = "c_"; - if (op_name.find(communication_op_prefix) != std::string::npos || - special_comm_op_set.count(op_name)) { - return true; - } - return false; -} - const std::string StringizeDownstreamMap( const std::map>& downstream_map) { std::ostringstream oss; @@ -187,21 +172,6 @@ void DependencyBuilder::AddDependencyForCoalesceTensorOp() { } void DependencyBuilder::AddDependencyForCommunicationOp() { - auto IsCommunicationOp = [](std::string op) -> bool { - const std::set special_comm_op_set = { - "send", - "recv", - "send_v2", - "recv_v2", - }; - const std::string communication_op_prefix = "c_"; - if 
(op.find(communication_op_prefix) != std::string::npos || - special_comm_op_set.count(op)) { - return true; - } - return false; - }; - int dependence_op_idx = -1; for (size_t op_idx = 0; op_idx < op_num_; ++op_idx) { if (IsCommunicationOp(instructions_->at(op_idx).OpBase()->Type())) { diff --git a/paddle/fluid/framework/new_executor/event_manager.cc b/paddle/fluid/framework/new_executor/interpreter/event_manager.cc similarity index 97% rename from paddle/fluid/framework/new_executor/event_manager.cc rename to paddle/fluid/framework/new_executor/interpreter/event_manager.cc index 7135e83705..70a365d795 100644 --- a/paddle/fluid/framework/new_executor/event_manager.cc +++ b/paddle/fluid/framework/new_executor/interpreter/event_manager.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/new_executor/event_manager.h" +#include "paddle/fluid/framework/new_executor/interpreter/event_manager.h" #include "paddle/fluid/platform/profiler/event_tracing.h" diff --git a/paddle/fluid/framework/new_executor/event_manager.h b/paddle/fluid/framework/new_executor/interpreter/event_manager.h similarity index 100% rename from paddle/fluid/framework/new_executor/event_manager.h rename to paddle/fluid/framework/new_executor/interpreter/event_manager.h diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc similarity index 95% rename from paddle/fluid/framework/new_executor/interpretercore_util.cc rename to paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index f41cda93bf..120d00e427 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/framework/new_executor/interpretercore_util.h" +#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include #include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/executor_gc_helper.h" -#include "paddle/fluid/framework/new_executor/data_transfer.h" +#include "paddle/fluid/framework/new_executor/interpreter/data_transfer.h" #include "paddle/fluid/memory/stats.h" #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h" #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h" @@ -50,6 +50,38 @@ namespace interpreter { using VariableIdMap = std::map>; +// NOTE(Ruibiao): SingleStreamGuard makes some multi-stream ops (i.e., +// c_allreduce_sum) run in a single stream. It is dedicated to BuildOpFuncList +// which runs kernels without stream synchronization. +class SingleStreamGuard { + public: + explicit SingleStreamGuard(std::shared_ptr& op) : op_(op) { + if (op_->Type() == "c_allreduce_sum" && + op_->Attr("use_calc_stream") == false) { + VLOG(6) << "Set c_allredce_sum's attr use_calc_stream to true"; + op_->SetAttr("use_calc_stream", true); + is_changed = true; + } + } + + ~SingleStreamGuard() { + if (!is_changed) { + return; + } + + if (op_->Type() == "c_allreduce_sum") { + op_->SetAttr("use_calc_stream", false); + VLOG(6) << "Set c_allredce_sum's attr use_calc_stream to false"; + } + } + + DISABLE_COPY_AND_ASSIGN(SingleStreamGuard); + + private: + bool is_changed{false}; + std::shared_ptr op_; +}; + const std::vector ConstructWorkQueueOptions( size_t host_num_threads, size_t device_num_threads, EventsWaiter* waiter) { std::vector group_options; @@ -471,6 +503,9 @@ void BuildOpFuncList(const platform::Place& place, op_func_node.operator_base_ = ops[i]; op_func_node.input_index = ins_name2id; op_func_node.output_index = outs_name2id; + + SingleStreamGuard single_stream_guard(ops[i]); + VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope); 
#ifdef PADDLE_WITH_ASCEND_CL @@ -514,16 +549,13 @@ void BuildOpFuncList(const platform::Place& place, auto& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); - VLOG(4) << "get dev_ctx"; auto exec_ctx = ExecutionContext( *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context); - VLOG(4) << "get exec_ctx"; auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(exec_ctx); - VLOG(4) << "get expected_kernel_key"; + VLOG(4) << "expected_kernel_key : " << expected_kernel_key; // change device by the device_guard() ApplyDeviceGuard(op, place, &expected_kernel_key); - VLOG(4) << "expected_kernel_key : " << expected_kernel_key; // step 2. select op kernel auto run_phi_kernel = false; @@ -722,6 +754,21 @@ void AddFetch(const std::vector& fetch_names, } } +bool IsCommunicationOp(const std::string& op_name) { + const std::set special_comm_op_set = { + "send", + "recv", + "send_v2", + "recv_v2", + }; + const std::string communication_op_prefix = "c_"; + if (op_name.find(communication_op_prefix) != std::string::npos || + special_comm_op_set.count(op_name)) { + return true; + } + return false; +} + } // namespace interpreter } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.h b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h similarity index 98% rename from paddle/fluid/framework/new_executor/interpretercore_util.h rename to paddle/fluid/framework/new_executor/interpreter/interpreter_util.h index 3e96262407..7bd82a5dd5 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.h +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h @@ -81,6 +81,8 @@ void BuildOpFuncList(const platform::Place& place, void AddFetch(const std::vector& fetch_names, framework::BlockDesc* block); +bool IsCommunicationOp(const std::string& op_name); + } // namespace interpreter } // namespace framework } // namespace paddle diff --git 
a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index cdb5fcc58a..2bc2652688 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -18,7 +18,7 @@ #include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/details/share_tensor_buffer_functor.h" -#include "paddle/fluid/framework/new_executor/interpretercore_util.h" +#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/os_info.h" @@ -101,7 +101,7 @@ inline void SetDeviceId(const platform::Place& place) { } } -// TODO(Ruibia): Pass skip_gc_vars, used_for_jit, and other config messages by +// TODO(Ruibiao): Pass skip_gc_vars, used_for_jit, and other config messages by // constructing an interpreter::ExecutionConfig InterpreterCore::InterpreterCore(const platform::Place& place, const BlockDesc& block, diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index 8e63c970e1..530c8e9d04 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -20,11 +20,11 @@ #include #include "paddle/fluid/framework/details/exception_holder.h" -#include "paddle/fluid/framework/new_executor/event_manager.h" #include "paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h" #include "paddle/fluid/framework/new_executor/interpreter/dependency_builder.h" +#include "paddle/fluid/framework/new_executor/interpreter/event_manager.h" #include "paddle/fluid/framework/new_executor/interpreter/execution_config.h" -#include "paddle/fluid/framework/new_executor/interpretercore_util.h" +#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include 
"paddle/fluid/framework/new_executor/new_executor_defs.h" #include "paddle/fluid/framework/new_executor/profiler.h" #include "paddle/fluid/framework/new_executor/stream_analyzer.h" diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index a2c8ecfac5..2fe686b808 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/new_executor/standalone_executor.h" -#include "paddle/fluid/framework/new_executor/interpretercore_util.h" +#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/platform/profiler/event_tracing.h" namespace paddle { diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.cc b/paddle/fluid/framework/new_executor/stream_analyzer.cc index 3025f01747..09c54a6480 100644 --- a/paddle/fluid/framework/new_executor/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/stream_analyzer.cc @@ -17,44 +17,50 @@ #include #include +#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" +#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { namespace framework { -namespace { -std::map>>* - d2h_ctxs = nullptr; -std::map>>* - h2d_ctxs = nullptr; -std::mutex ctx_mtx; -} // namespace - -StreamAnalyzer::StreamAnalyzer(const platform::Place& place) : place_(place) { - if (platform::is_gpu_place(place) || platform::is_npu_place(place) || - platform::is_custom_place(place)) { - std::lock_guard lk(ctx_mtx); - if (d2h_ctxs == nullptr) { - d2h_ctxs = new std::map< - Place, - std::shared_future>>(); - h2d_ctxs = new std::map< - Place, - std::shared_future>>(); - } - if (d2h_ctxs->find(place) == d2h_ctxs->end()) { - platform::EmplaceDeviceContexts( - d2h_ctxs, - {place}, - 
/*disable_setting_default_stream_for_allocator=*/true); + +// stream types +constexpr const char* kD2HStream = "D2HStream"; +constexpr const char* kH2DStream = "H2DStream"; + +class ContextManager { + public: + using DeviceContextMap = + std::map>>; + + static ContextManager& Instance() { + static ContextManager* ctx_manager = new ContextManager; + return *ctx_manager; + } + + std::shared_future> Get( + const std::string& type, const platform::Place& place) { + std::lock_guard lk(ctx_mtx_); + VLOG(6) << "Get dev_ctx for " << type << " - " << place; + + DeviceContextMap& ctxs = ctx_pool_[type]; + if (ctxs.find(place) == ctxs.end()) { platform::EmplaceDeviceContexts( - h2d_ctxs, + &ctxs, {place}, /*disable_setting_default_stream_for_allocator=*/true); } - d2h_ctx_ = (*d2h_ctxs)[place]; - h2d_ctx_ = (*h2d_ctxs)[place]; + return ctxs[place]; } -} + + private: + ContextManager() {} + DISABLE_COPY_AND_ASSIGN(ContextManager); + + std::mutex ctx_mtx_; + std::unordered_map ctx_pool_; +}; /* * Parse the var_ids that need to be associated with an event. 
@@ -88,23 +94,26 @@ std::vector StreamAnalyzer::GetNeedEventVarIds( return false; }; + bool is_comm = interpreter::IsCommunicationOp(cur_instr.OpBase()->Type()) || + interpreter::IsCommunicationOp(next_instr.OpBase()->Type()); std::vector need_event_var_ids; for (auto& item : next_instr.Inputs()) { for (auto var_id : item.second) { if (unique_var_ids.count(var_id) > 0) { - if (next_instr.NoDataTransformVars().count(var_id)) { - VLOG(4) << "Skip inserting event at variable " << item.first - << " of operator " << next_instr.OpBase()->Type() - << " since it is NoDataTransform"; - continue; + if (!is_comm) { + if (next_instr.NoDataTransformVars().count(var_id)) { + VLOG(4) << "Skip inserting event at variable " << item.first + << " of operator " << next_instr.OpBase()->Type() + << " since it is NoDataTransform"; + continue; + } + if (is_no_need_buffer(item.first)) { + VLOG(4) << "Skip inserting event at variable " << item.first + << " of operator " << next_instr.OpBase()->Type() + << " since it is NoNeedBufferVar"; + continue; + } } - if (is_no_need_buffer(item.first)) { - VLOG(4) << "Skip inserting event at variable " << item.first - << " of operator " << next_instr.OpBase()->Type() - << " since it is NoNeedBufferVar"; - continue; - } - need_event_var_ids.push_back(var_id); } } @@ -175,21 +184,40 @@ void StreamAnalyzer::Schedule(const std::vector& downstream_ops, platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( const OpFuncNode& op_func_node) { - auto& op_type = op_func_node.operator_base_->Type(); - auto* dev_ctx = op_func_node.dev_ctx_; + auto& op = op_func_node.operator_base_; + auto& op_type = op->Type(); + ContextManager& ctx_manager = ContextManager::Instance(); + // only gpu/npu need update. xpu not need, because xpu memcpy op kernel is // synchronous. 
if (platform::is_gpu_place(place_) || platform::is_npu_place(place_) || platform::is_custom_place(place_)) { if (op_type == interpreter::kMemcpyD2H) { - VLOG(3) << "Get dev_ctx from d2h_context_pool_"; - dev_ctx = d2h_ctx_.get().get(); + return ctx_manager.Get(std::string(kD2HStream), place_).get().get(); } else if (op_type == interpreter::kMemcpyH2D) { - VLOG(3) << "Get dev_ctx from h2d_context_pool_"; - dev_ctx = h2d_ctx_.get().get(); + return ctx_manager.Get(std::string(kH2DStream), place_).get().get(); + } + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + // NOTE(Ruibiao): Here supports multi-stream overlap for c_allreduce_sum + // with use_calc_stream==false by returning a device context getting from the + // global NCCLCommContext instance. Because when use_calc_stream==false, in + // OP kernel, the NCCL communication will be launched to the stream directly + // getting from the global NCCLCommContext instance rather than the + // DeviceContext passed from executor (see CAllReduceOpCUDAKernel in + // c_allreduce_op.h). Now it is just a temporary solution for ONLY + // c_allreduce_sum which is used in ResNet50 distributed training. 
+ if (op_type == "c_allreduce_sum" && + op->Attr("use_calc_stream") == false) { + int ring_id = op->Attr("ring_id"); + return platform::NCCLCommContext::Instance() + .Get(ring_id, place_) + ->dev_context(); } +#endif } - return dev_ctx; + + return op_func_node.dev_ctx_; } /* diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.h b/paddle/fluid/framework/new_executor/stream_analyzer.h index 4be8ffe6bb..3cdcfc68c8 100644 --- a/paddle/fluid/framework/new_executor/stream_analyzer.h +++ b/paddle/fluid/framework/new_executor/stream_analyzer.h @@ -29,7 +29,7 @@ class StreamAnalyzer { using Place = platform::Place; using DeviceContext = platform::DeviceContext; - explicit StreamAnalyzer(const Place& place); + explicit StreamAnalyzer(const Place& place) : place_(place) {} ~StreamAnalyzer() {} @@ -54,8 +54,6 @@ class StreamAnalyzer { platform::DeviceType GetWaiterType(const Instruction& instr); const Place place_; - std::shared_future> d2h_ctx_; - std::shared_future> h2d_ctx_; std::map> var_id2event_; }; -- GitLab