From 809a64522292304efcc92d034162d801374f59b8 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Tue, 30 Nov 2021 14:08:30 +0800 Subject: [PATCH] [fleet_executor] pass the env from carrier to interceptor (#37691) --- paddle/fluid/distributed/fleet_executor/carrier.cc | 9 ++++++++- paddle/fluid/distributed/fleet_executor/carrier.h | 5 ++++- .../fluid/distributed/fleet_executor/fleet_executor.cc | 2 +- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 6b28089141..108a21b92f 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -26,7 +26,7 @@ USE_INTERCEPTOR(Compute); void Carrier::Init( const std::unordered_map& interceptor_id_to_node, - framework::Scope* minibatch_scope, + framework::Scope* root_scope, framework::Scope* minibatch_scope, const std::vector& microbatch_scopes, const platform::Place& place) { PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists( @@ -35,6 +35,8 @@ void Carrier::Init( minibatch_scope_ = minibatch_scope; microbatch_scopes_ = microbatch_scopes; place_ = place; + root_scope_ = root_scope; + dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_); CreateInterceptors(); is_init_ = true; } @@ -105,6 +107,7 @@ void Carrier::Start() { } std::unique_lock lock(running_mutex_); cond_var_.wait(lock); + dev_ctx_->Wait(); } std::condition_variable& Carrier::GetCondVar() { return cond_var_; } @@ -164,6 +167,10 @@ void Carrier::CreateInterceptors() { // TODO(wangxi): use node_type to select different Interceptor auto interceptor = std::make_unique(interceptor_id, task_node); + interceptor->SetPlace(place_); + interceptor->SetMiniBatchScope(minibatch_scope_); + interceptor->SetMicroBatchScope(microbatch_scopes_); + interceptor->SetRootScope(root_scope_); SetInterceptor(interceptor_id, std::move(interceptor)); VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id << "."; diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index ee6d3158bf..c4c6a41846 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -23,6 +23,7 @@ #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" +#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/macros.h" @@ -48,7 +49,7 @@ class Carrier final { void Init( const std::unordered_map& interceptor_id_to_node, - framework::Scope* minibatch_scope, + framework::Scope* root_scope, framework::Scope* minibatch_scope, const std::vector& microbatch_scopes, const platform::Place& place); @@ -98,8 +99,10 @@ class Carrier final { std::mutex running_mutex_; std::condition_variable cond_var_; std::vector microbatch_scopes_; + framework::Scope* root_scope_; framework::Scope* minibatch_scope_; paddle::platform::Place place_; + paddle::platform::DeviceContext* dev_ctx_ = nullptr; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index 2483b4a545..ec60ec5fd5 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -58,7 +58,7 @@ void FleetExecutor::Init(const framework::ProgramDesc& program_desc, void FleetExecutor::InitCarrier() { Carrier& carrier_instance = Carrier::Instance(); if (!carrier_instance.IsInit()) { - carrier_instance.Init(runtime_graph_->intercepter_id_to_node(), + carrier_instance.Init(runtime_graph_->intercepter_id_to_node(), root_scope_, minibatch_scope_, microbatch_scopes_, place_); } } -- GitLab