未验证 提交 7eb121df 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] Take task node from python side (#38083)

上级 f5b1fd7c
...@@ -33,10 +33,22 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { ...@@ -33,10 +33,22 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
FleetExecutor::~FleetExecutor() { root_scope_->DropKids(); } FleetExecutor::~FleetExecutor() { root_scope_->DropKids(); }
void FleetExecutor::Init(const framework::ProgramDesc& program_desc, void FleetExecutor::Init(
framework::Scope* scope, const framework::ProgramDesc& program_desc, framework::Scope* scope,
const platform::Place& place) { const platform::Place& place, const std::vector<TaskNode*>& task_nodes,
runtime_graph_ = std::make_shared<RuntimeGraph>(program_desc, exe_desc_); const std::unordered_map<int64_t, int64_t>& task_id_to_rank) {
if (task_nodes.size() == 0) {
runtime_graph_ = std::make_shared<RuntimeGraph>(program_desc, exe_desc_);
} else {
runtime_graph_ = std::make_shared<RuntimeGraph>();
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
for (auto task_node : task_nodes) {
int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node);
}
runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
}
root_scope_ = scope; root_scope_ = scope;
place_ = place; place_ = place;
PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument( PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(
......
...@@ -30,6 +30,7 @@ namespace distributed { ...@@ -30,6 +30,7 @@ namespace distributed {
class RuntimeGraph; class RuntimeGraph;
class Carrier; class Carrier;
class MessageBus; class MessageBus;
class TaskNode;
class FleetExecutor final { class FleetExecutor final {
public: public:
...@@ -37,7 +38,9 @@ class FleetExecutor final { ...@@ -37,7 +38,9 @@ class FleetExecutor final {
explicit FleetExecutor(const std::string& exe_desc_str); explicit FleetExecutor(const std::string& exe_desc_str);
~FleetExecutor(); ~FleetExecutor();
void Init(const framework::ProgramDesc& program_desc, framework::Scope* scope, void Init(const framework::ProgramDesc& program_desc, framework::Scope* scope,
const platform::Place& place); const platform::Place& place,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run(); void Run();
private: private:
......
...@@ -44,6 +44,14 @@ class RuntimeGraph final { ...@@ -44,6 +44,14 @@ class RuntimeGraph final {
const std::unordered_map<int64_t, int64_t>& intercepter_id_to_rank() const { const std::unordered_map<int64_t, int64_t>& intercepter_id_to_rank() const {
return intercepter_id_to_rank_; return intercepter_id_to_rank_;
} }
void SetInterceptorIdToRank(
const std::unordered_map<int64_t, int64_t>& intercepter_id_to_rank) {
intercepter_id_to_rank_ = intercepter_id_to_rank;
}
void SetInterceptorIdToNode(
const std::unordered_map<int64_t, TaskNode*>& intercepter_id_to_node) {
intercepter_id_to_node_ = intercepter_id_to_node;
}
std::string DebugString() const; std::string DebugString() const;
private: private:
......
...@@ -1979,10 +1979,12 @@ class Executor(object): ...@@ -1979,10 +1979,12 @@ class Executor(object):
fleet_exe_desc.num_micro_batches = fleet_opt["num_micro_batches"] fleet_exe_desc.num_micro_batches = fleet_opt["num_micro_batches"]
num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree
assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu." assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu."
task_id_to_rank = fleet_opt.get("task_id_to_rank", {})
tasks = fleet_opt.get("tasks", [])
fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString()) fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
place = core.Place() place = core.Place()
place.set_place(self.place) place.set_place(self.place)
fleet_exe.init(program.desc, scope, place) fleet_exe.init(program.desc, scope, place, tasks, task_id_to_rank)
return fleet_exe return fleet_exe
def _run_using_fleet_executor(self, def _run_using_fleet_executor(self,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册