Unverified · Commit 87e65a99 · Authored by LiYuRio · Committed by GitHub

[Fleet_Executor] Passing runtime scope and place (#37603)

Parent 0156669e
......
@@ -17,6 +17,7 @@
 #include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
 #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
 #include "paddle/fluid/distributed/fleet_executor/task_node.h"
+#include "paddle/fluid/framework/scope.h"

 namespace paddle {
 namespace distributed {
......
@@ -24,10 +25,16 @@ namespace distributed {
 USE_INTERCEPTOR(Compute);

 void Carrier::Init(
-    const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
+    const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
+    framework::Scope* minibatch_scope,
+    const std::vector<framework::Scope*>& microbatch_scopes,
+    const platform::Place& place) {
   PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
                                          "Carrier is already init."));
   interceptor_id_to_node_ = interceptor_id_to_node;
+  minibatch_scope_ = minibatch_scope;
+  microbatch_scopes_ = microbatch_scopes;
+  place_ = place;
   CreateInterceptors();
   is_init_ = true;
 }
......
......
@@ -26,8 +26,13 @@
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/errors.h"
 #include "paddle/fluid/platform/macros.h"
+#include "paddle/fluid/platform/place.h"

 namespace paddle {
+namespace framework {
+class Scope;
+}
+
 namespace distributed {

 class TaskNode;
......
@@ -42,7 +47,10 @@ class Carrier final {
   }

   void Init(
-      const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node);
+      const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
+      framework::Scope* minibatch_scope,
+      const std::vector<framework::Scope*>& microbatch_scopes,
+      const platform::Place& place);

   ~Carrier();
......
@@ -89,6 +97,9 @@ class Carrier final {
   std::mutex running_mutex_;
   std::condition_variable cond_var_;
+  std::vector<framework::Scope*> microbatch_scopes_;
+  framework::Scope* minibatch_scope_;
+  paddle::platform::Place place_;
 };

 }  // namespace distributed
......
......
@@ -19,6 +19,8 @@
 #include "paddle/fluid/distributed/fleet_executor/task_node.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/variable_helper.h"

 namespace paddle {
 namespace distributed {
......
@@ -33,8 +35,21 @@ FleetExecutor::~FleetExecutor() {
   // Destroy Executor
 }

-void FleetExecutor::Init(const paddle::framework::ProgramDesc& program_desc) {
+void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
+                         framework::Scope* scope,
+                         const platform::Place& place) {
   runtime_graph_ = std::make_unique<RuntimeGraph>(program_desc, exe_desc_);
+  root_scope_ = scope;
+  place_ = place;
+  PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(
+                                           "root_scope_ can not be nullptr"));
+  minibatch_scope_ = &root_scope_->NewScope();
+  int64_t num_micro_batches = exe_desc_.num_micro_batches();
+  microbatch_scopes_.resize(num_micro_batches);
+  for (int i = 0; i < num_micro_batches; ++i) {
+    microbatch_scopes_[i] = &minibatch_scope_->NewScope();
+    CopyParameters(i, program_desc);
+  }
   VLOG(5) << runtime_graph_->DebugString();
   InitCarrier();
   InitMessageBus();
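
Init now derives a three-level scope hierarchy from the single scope handed over by the Python side. A minimal sketch of just the nesting, assuming only the Scope API this diff already uses (BuildScopes is an illustrative name, not part of the patch):

    #include <cstdint>
    #include <vector>

    #include "paddle/fluid/framework/scope.h"

    // root (from executor.py) -> one minibatch scope -> one scope per micro batch.
    void BuildScopes(paddle::framework::Scope* root, int64_t num_micro_batches) {
      paddle::framework::Scope* minibatch = &root->NewScope();
      std::vector<paddle::framework::Scope*> microbatch(num_micro_batches);
      for (int64_t i = 0; i < num_micro_batches; ++i) {
        // Each micro batch gets its own child scope, so per-batch
        // intermediates never collide across pipeline stages.
        microbatch[i] = &minibatch->NewScope();
      }
      // A Scope owns its children; root->DropKids() (see Release below)
      // frees the whole subtree in one call.
    }
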
......
@@ -43,7 +58,8 @@ void FleetExecutor::Init(const paddle::framework::ProgramDesc& program_desc) {
 void FleetExecutor::InitCarrier() {
   Carrier& carrier_instance = Carrier::Instance();
   if (!carrier_instance.IsInit()) {
-    carrier_instance.Init(runtime_graph_->intercepter_id_to_node());
+    carrier_instance.Init(runtime_graph_->intercepter_id_to_node(),
+                          minibatch_scope_, microbatch_scopes_, place_);
   }
 }
......
@@ -97,8 +113,25 @@ void FleetExecutor::Run() {
   carrier_instance.Start();
 }

-void FleetExecutor::Release() {
-  // Release
-}
+void FleetExecutor::Release() { root_scope_->DropKids(); }
+
+void FleetExecutor::CopyParameters(int microbatch_id,
+                                   const framework::ProgramDesc& program) {
+  auto& global_block = program.Block(0);
+  for (auto& var : global_block.AllVars()) {
+    if (var->Persistable() && microbatch_id == 0) {
+      auto* ptr = root_scope_->Var(var->Name());
+      InitializeVariable(ptr, var->GetType());
+      VLOG(5) << "Create persistable var: " << var->Name()
+              << ", which pointer is " << ptr;
+    } else if (!var->Persistable()) {
+      auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
+      VLOG(5) << "Create variable " << var->Name() << " for microbatch "
+              << microbatch_id << ", which pointer is " << ptr << ".";
+      InitializeVariable(ptr, var->GetType());
+    }
+  }
+}

 }  // namespace distributed
......
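
Reading the branch conditions in CopyParameters: persistable variables (parameters) are created exactly once, in root_scope_, which is what the microbatch_id == 0 guard enforces, so all micro batches share a single copy; every non-persistable variable gets a fresh instance in each microbatch scope. A distilled sketch of that placement rule, where ScopeForVar is a hypothetical helper and not code from the patch:

    #include "paddle/fluid/framework/scope.h"
    #include "paddle/fluid/framework/var_desc.h"

    // Returns the scope a variable should be created in, or nullptr when the
    // call should create nothing (persistable vars after micro batch 0).
    paddle::framework::Scope* ScopeForVar(const paddle::framework::VarDesc& var,
                                          int microbatch_id,
                                          paddle::framework::Scope* root,
                                          paddle::framework::Scope* microbatch) {
      if (var.Persistable()) {
        // Parameters are shared state: create them once, in the root scope.
        return microbatch_id == 0 ? root : nullptr;
      }
      // Everything else is per-batch state owned by that batch's scope.
      return microbatch;
    }
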
......
@@ -18,10 +18,12 @@
 #include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
 #include "paddle/fluid/platform/macros.h"
+#include "paddle/fluid/platform/place.h"

 namespace paddle {
 namespace framework {
 class ProgramDesc;
+class Scope;
 }

 namespace distributed {
......
@@ -34,16 +36,22 @@ class FleetExecutor final {
   FleetExecutor() = delete;
   explicit FleetExecutor(const std::string& exe_desc_str);
   ~FleetExecutor();
-  void Init(const paddle::framework::ProgramDesc& program_desc);
+  void Init(const framework::ProgramDesc& program_desc, framework::Scope* scope,
+            const platform::Place& place);
   void Run();
   void Release();

  private:
   DISABLE_COPY_AND_ASSIGN(FleetExecutor);
-  FleetExecutorDesc exe_desc_;
-  std::unique_ptr<RuntimeGraph> runtime_graph_;
   void InitMessageBus();
   void InitCarrier();
+  void CopyParameters(int microbatch_id, const framework::ProgramDesc& program);
+  FleetExecutorDesc exe_desc_;
+  std::unique_ptr<RuntimeGraph> runtime_graph_;
+  framework::Scope* root_scope_;
+  framework::Scope* minibatch_scope_;
+  platform::Place place_;
+  std::vector<framework::Scope*> microbatch_scopes_;
 };

 }  // namespace distributed
......
......
@@ -21,7 +21,7 @@ enum MessageType {
   STOP = 1;             // STOP an Interceptor
   DATA_IS_READY = 2;    // upstream data is ready
   DATE_IS_USELESS = 3;  // downstream has used the data
-  ERROR = 4;            // current Interceptor encounters error
+  ERR = 4;              // current Interceptor encounters error
   RESET = 5;            // reset the status
 }
......
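
The ERROR to ERR rename reads as macro hygiene rather than a behavior change: ERROR is defined as a macro by Windows headers (wingdi.h) and is also taken by glog's abbreviated severities, so a protobuf-generated enumerator named ERROR can break compilation on those toolchains. This rationale is inferred from the rename; the commit message does not state it.
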
......
@@ -17,6 +17,8 @@
 #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
 #include "paddle/fluid/distributed/fleet_executor/task_node.h"
 #include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/platform/place.h"

 namespace py = pybind11;
......
@@ -30,7 +32,8 @@ void BindFleetExecutor(py::module* m) {
   py::class_<FleetExecutor>(*m, "FleetExecutor")
       .def(py::init<const std::string&>())
       .def("init", &FleetExecutor::Init)
-      .def("run", &FleetExecutor::Run);
+      .def("run", &FleetExecutor::Run)
+      .def("release", &FleetExecutor::Release);
   py::class_<TaskNode>(*m, "TaskNode")
       .def(py::init<const framework::ProgramDesc&, int64_t, int64_t, int64_t>())
......
......
@@ -1997,8 +1997,12 @@ class Executor(object):
         num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree
         assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu."
         fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
-        fleet_exe.init(program._pipeline_opt["section_program"].desc)
+        place = core.Place()
+        place.set_place(self.place)
+        fleet_exe.init(program._pipeline_opt["section_program"].desc, scope,
+                       place)
         fleet_exe.run()
+        fleet_exe.release()
         return None

     def _run_pipeline(self,
......
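
On the Python side, self.place (a concrete place such as CPUPlace or CUDAPlace) is wrapped into the generic core.Place via set_place(), which is how the binding receives the platform::Place parameter that Init now takes. The new fleet_exe.release() call mirrors init(): it drops the minibatch and microbatch child scopes that init() created under scope, so repeated pipeline runs do not accumulate stale scopes.
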