diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
index 3a823674d842c5a8e76d10d36b0e44dbeef90148..9bba2624bf2de293674be0733fda378cd5b03170 100644
--- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
@@ -33,10 +33,22 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
 
 FleetExecutor::~FleetExecutor() { root_scope_->DropKids(); }
 
-void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
-                         framework::Scope* scope,
-                         const platform::Place& place) {
-  runtime_graph_ = std::make_shared<RuntimeGraph>(program_desc, exe_desc_);
+void FleetExecutor::Init(
+    const framework::ProgramDesc& program_desc, framework::Scope* scope,
+    const platform::Place& place, const std::vector<TaskNode*>& task_nodes,
+    const std::unordered_map<int64_t, int64_t>& task_id_to_rank) {
+  if (task_nodes.size() == 0) {
+    runtime_graph_ = std::make_shared<RuntimeGraph>(program_desc, exe_desc_);
+  } else {
+    runtime_graph_ = std::make_shared<RuntimeGraph>();
+    std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
+    for (auto task_node : task_nodes) {
+      int64_t interceptor_id = task_node->task_id();
+      interceptor_id_to_task.emplace(interceptor_id, task_node);
+    }
+    runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
+    runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
+  }
   root_scope_ = scope;
   place_ = place;
   PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(
diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h
index ac857fb6c38a2109d91d149598d2c1f74f615ddf..9fddeae63f6765f1e7bb403611e89c4a5a02d185 100644
--- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h
+++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h
@@ -30,6 +30,7 @@ namespace distributed {
 class RuntimeGraph;
 class Carrier;
 class MessageBus;
+class TaskNode;
 
 class FleetExecutor final {
  public:
@@ -37,7 +38,9 @@
   explicit FleetExecutor(const std::string& exe_desc_str);
   ~FleetExecutor();
   void Init(const framework::ProgramDesc& program_desc, framework::Scope* scope,
-            const platform::Place& place);
+            const platform::Place& place,
+            const std::vector<TaskNode*>& task_nodes,
+            const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
   void Run();
 
  private:
diff --git a/paddle/fluid/distributed/fleet_executor/runtime_graph.h b/paddle/fluid/distributed/fleet_executor/runtime_graph.h
index 26b758767c07fcceb163897defca218fb648a985..9ffc9cc2cc137e8846957ba2d9feab5e8e7e4f43 100644
--- a/paddle/fluid/distributed/fleet_executor/runtime_graph.h
+++ b/paddle/fluid/distributed/fleet_executor/runtime_graph.h
@@ -44,6 +44,14 @@ class RuntimeGraph final {
   const std::unordered_map<int64_t, int64_t>& intercepter_id_to_rank() const {
     return intercepter_id_to_rank_;
   }
+  void SetInterceptorIdToRank(
+      const std::unordered_map<int64_t, int64_t>& intercepter_id_to_rank) {
+    intercepter_id_to_rank_ = intercepter_id_to_rank;
+  }
+  void SetInterceptorIdToNode(
+      const std::unordered_map<int64_t, TaskNode*>& intercepter_id_to_node) {
+    intercepter_id_to_node_ = intercepter_id_to_node;
+  }
   std::string DebugString() const;
 
  private:
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index c50af065bc468e291945512600d3b4a4596482c7..eab00707f012ab7b58b4e543579396b55cb21f21 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -1979,10 +1979,12 @@ class Executor(object):
         fleet_exe_desc.num_micro_batches = fleet_opt["num_micro_batches"]
         num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree
         assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu."
+        task_id_to_rank = fleet_opt.get("task_id_to_rank", {})
+        tasks = fleet_opt.get("tasks", [])
         fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
         place = core.Place()
         place.set_place(self.place)
-        fleet_exe.init(program.desc, scope, place)
+        fleet_exe.init(program.desc, scope, place, tasks, task_id_to_rank)
         return fleet_exe
 
     def _run_using_fleet_executor(self,
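
For reviewers, a minimal sketch of how a caller could exercise the new code path through `fleet_opt`. Only the two key names (`"tasks"`, `"task_id_to_rank"`) and the `fleet_exe.init(...)` signature come from this diff; the `build_task_nodes` helper, the single-rank placement, and a Python-visible `task_id()` binding on `TaskNode` are assumptions for illustration:

```python
# Hypothetical usage sketch; assumes TaskNode objects are already
# exposed to Python and that the caller has decided the rank placement.
task_nodes = build_task_nodes()  # hypothetical helper, returns [TaskNode, ...]

fleet_opt = {
    "num_micro_batches": 4,
    # New optional key: pre-built task nodes. Each node's task_id()
    # becomes the interceptor id on the C++ side.
    "tasks": task_nodes,
    # New optional key: map each task id to the rank that runs it
    # (everything on rank 0 here, purely for illustration).
    "task_id_to_rank": {node.task_id(): 0 for node in task_nodes},
}
```

Because both keys are read with `fleet_opt.get(...)` and `FleetExecutor::Init` falls back to building the `RuntimeGraph` from `program_desc` when `task_nodes` is empty, existing callers that do not set these keys keep the old behavior.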