From 7eb121df47163e06b50bcd7ac81b0d0c05fd1283 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Tue, 14 Dec 2021 12:03:37 +0800 Subject: [PATCH] [fleet_executor] Take task node from python side (#38083) --- .../fleet_executor/fleet_executor.cc | 20 +++++++++++++++---- .../fleet_executor/fleet_executor.h | 5 ++++- .../fleet_executor/runtime_graph.h | 8 ++++++++ python/paddle/fluid/executor.py | 4 +++- 4 files changed, 31 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index 3a823674d8..9bba2624bf 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -33,10 +33,22 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { FleetExecutor::~FleetExecutor() { root_scope_->DropKids(); } -void FleetExecutor::Init(const framework::ProgramDesc& program_desc, - framework::Scope* scope, - const platform::Place& place) { - runtime_graph_ = std::make_shared(program_desc, exe_desc_); +void FleetExecutor::Init( + const framework::ProgramDesc& program_desc, framework::Scope* scope, + const platform::Place& place, const std::vector& task_nodes, + const std::unordered_map& task_id_to_rank) { + if (task_nodes.size() == 0) { + runtime_graph_ = std::make_shared(program_desc, exe_desc_); + } else { + runtime_graph_ = std::make_shared(); + std::unordered_map interceptor_id_to_task; + for (auto task_node : task_nodes) { + int64_t interceptor_id = task_node->task_id(); + interceptor_id_to_task.emplace(interceptor_id, task_node); + } + runtime_graph_->SetInterceptorIdToRank(task_id_to_rank); + runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task); + } root_scope_ = scope; place_ = place; PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument( diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h index ac857fb6c3..9fddeae63f 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h @@ -30,6 +30,7 @@ namespace distributed { class RuntimeGraph; class Carrier; class MessageBus; +class TaskNode; class FleetExecutor final { public: @@ -37,7 +38,9 @@ class FleetExecutor final { explicit FleetExecutor(const std::string& exe_desc_str); ~FleetExecutor(); void Init(const framework::ProgramDesc& program_desc, framework::Scope* scope, - const platform::Place& place); + const platform::Place& place, + const std::vector& task_nodes, + const std::unordered_map& task_id_to_rank); void Run(); private: diff --git a/paddle/fluid/distributed/fleet_executor/runtime_graph.h b/paddle/fluid/distributed/fleet_executor/runtime_graph.h index 26b758767c..9ffc9cc2cc 100644 --- a/paddle/fluid/distributed/fleet_executor/runtime_graph.h +++ b/paddle/fluid/distributed/fleet_executor/runtime_graph.h @@ -44,6 +44,14 @@ class RuntimeGraph final { const std::unordered_map& intercepter_id_to_rank() const { return intercepter_id_to_rank_; } + void SetInterceptorIdToRank( + const std::unordered_map& intercepter_id_to_rank) { + intercepter_id_to_rank_ = intercepter_id_to_rank; + } + void SetInterceptorIdToNode( + const std::unordered_map& intercepter_id_to_node) { + intercepter_id_to_node_ = intercepter_id_to_node; + } std::string DebugString() const; private: diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index c50af065bc..eab00707f0 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1979,10 +1979,12 @@ class Executor(object): fleet_exe_desc.num_micro_batches = fleet_opt["num_micro_batches"] num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu." + task_id_to_rank = fleet_opt.get("task_id_to_rank", {}) + tasks = fleet_opt.get("tasks", []) fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString()) place = core.Place() place.set_place(self.place) - fleet_exe.init(program.desc, scope, place) + fleet_exe.init(program.desc, scope, place, tasks, task_id_to_rank) return fleet_exe def _run_using_fleet_executor(self, -- GitLab