未验证 提交 809a6452 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] pass the env from carrier to interceptor (#37691)

上级 3f2a665a
......@@ -26,7 +26,7 @@ USE_INTERCEPTOR(Compute);
void Carrier::Init(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* minibatch_scope,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
......@@ -35,6 +35,8 @@ void Carrier::Init(
minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes;
place_ = place;
root_scope_ = root_scope;
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
CreateInterceptors();
is_init_ = true;
}
......@@ -105,6 +107,7 @@ void Carrier::Start() {
}
std::unique_lock<std::mutex> lock(running_mutex_);
cond_var_.wait(lock);
dev_ctx_->Wait();
}
std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
......@@ -164,6 +167,10 @@ void Carrier::CreateInterceptors() {
// TODO(wangxi): use node_type to select different Interceptor
auto interceptor =
std::make_unique<Interceptor>(interceptor_id, task_node);
interceptor->SetPlace(place_);
interceptor->SetMiniBatchScope(minibatch_scope_);
interceptor->SetMicroBatchScope(microbatch_scopes_);
interceptor->SetRootScope(root_scope_);
SetInterceptor(interceptor_id, std::move(interceptor));
VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
<< ".";
......
......@@ -23,6 +23,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
......@@ -48,7 +49,7 @@ class Carrier final {
void Init(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* minibatch_scope,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place);
......@@ -98,8 +99,10 @@ class Carrier final {
std::mutex running_mutex_;
std::condition_variable cond_var_;
std::vector<framework::Scope*> microbatch_scopes_;
framework::Scope* root_scope_;
framework::Scope* minibatch_scope_;
paddle::platform::Place place_;
paddle::platform::DeviceContext* dev_ctx_ = nullptr;
};
} // namespace distributed
......
......@@ -58,7 +58,7 @@ void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
void FleetExecutor::InitCarrier() {
Carrier& carrier_instance = Carrier::Instance();
if (!carrier_instance.IsInit()) {
carrier_instance.Init(runtime_graph_->intercepter_id_to_node(),
carrier_instance.Init(runtime_graph_->intercepter_id_to_node(), root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册