未验证 提交 809a6452 编写于 作者: Yuang Liu 提交者: GitHub

[fleet_executor] pass the env from carrier to interceptor (#37691)

上级 3f2a665a
...@@ -26,7 +26,7 @@ USE_INTERCEPTOR(Compute); ...@@ -26,7 +26,7 @@ USE_INTERCEPTOR(Compute);
void Carrier::Init( void Carrier::Init(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node, const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* minibatch_scope, framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes, const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) { const platform::Place& place) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists( PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
...@@ -35,6 +35,8 @@ void Carrier::Init( ...@@ -35,6 +35,8 @@ void Carrier::Init(
minibatch_scope_ = minibatch_scope; minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes; microbatch_scopes_ = microbatch_scopes;
place_ = place; place_ = place;
root_scope_ = root_scope;
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
CreateInterceptors(); CreateInterceptors();
is_init_ = true; is_init_ = true;
} }
...@@ -105,6 +107,7 @@ void Carrier::Start() { ...@@ -105,6 +107,7 @@ void Carrier::Start() {
} }
std::unique_lock<std::mutex> lock(running_mutex_); std::unique_lock<std::mutex> lock(running_mutex_);
cond_var_.wait(lock); cond_var_.wait(lock);
dev_ctx_->Wait();
} }
std::condition_variable& Carrier::GetCondVar() { return cond_var_; } std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
...@@ -164,6 +167,10 @@ void Carrier::CreateInterceptors() { ...@@ -164,6 +167,10 @@ void Carrier::CreateInterceptors() {
// TODO(wangxi): use node_type to select different Interceptor // TODO(wangxi): use node_type to select different Interceptor
auto interceptor = auto interceptor =
std::make_unique<Interceptor>(interceptor_id, task_node); std::make_unique<Interceptor>(interceptor_id, task_node);
interceptor->SetPlace(place_);
interceptor->SetMiniBatchScope(minibatch_scope_);
interceptor->SetMicroBatchScope(microbatch_scopes_);
interceptor->SetRootScope(root_scope_);
SetInterceptor(interceptor_id, std::move(interceptor)); SetInterceptor(interceptor_id, std::move(interceptor));
VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
<< "."; << ".";
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
...@@ -48,7 +49,7 @@ class Carrier final { ...@@ -48,7 +49,7 @@ class Carrier final {
void Init( void Init(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node, const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* minibatch_scope, framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes, const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place); const platform::Place& place);
...@@ -98,8 +99,10 @@ class Carrier final { ...@@ -98,8 +99,10 @@ class Carrier final {
std::mutex running_mutex_; std::mutex running_mutex_;
std::condition_variable cond_var_; std::condition_variable cond_var_;
std::vector<framework::Scope*> microbatch_scopes_; std::vector<framework::Scope*> microbatch_scopes_;
framework::Scope* root_scope_;
framework::Scope* minibatch_scope_; framework::Scope* minibatch_scope_;
paddle::platform::Place place_; paddle::platform::Place place_;
paddle::platform::DeviceContext* dev_ctx_ = nullptr;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -58,7 +58,7 @@ void FleetExecutor::Init(const framework::ProgramDesc& program_desc, ...@@ -58,7 +58,7 @@ void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
void FleetExecutor::InitCarrier() { void FleetExecutor::InitCarrier() {
Carrier& carrier_instance = Carrier::Instance(); Carrier& carrier_instance = Carrier::Instance();
if (!carrier_instance.IsInit()) { if (!carrier_instance.IsInit()) {
carrier_instance.Init(runtime_graph_->intercepter_id_to_node(), carrier_instance.Init(runtime_graph_->intercepter_id_to_node(), root_scope_,
minibatch_scope_, microbatch_scopes_, place_); minibatch_scope_, microbatch_scopes_, place_);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册