未验证 提交 7eb121df 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] Take task node from python side (#38083)

上级 f5b1fd7c
......@@ -33,10 +33,22 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
FleetExecutor::~FleetExecutor() { root_scope_->DropKids(); }
void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
framework::Scope* scope,
const platform::Place& place) {
void FleetExecutor::Init(
const framework::ProgramDesc& program_desc, framework::Scope* scope,
const platform::Place& place, const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank) {
if (task_nodes.size() == 0) {
runtime_graph_ = std::make_shared<RuntimeGraph>(program_desc, exe_desc_);
} else {
runtime_graph_ = std::make_shared<RuntimeGraph>();
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
for (auto task_node : task_nodes) {
int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node);
}
runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
}
root_scope_ = scope;
place_ = place;
PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(
......
......@@ -30,6 +30,7 @@ namespace distributed {
class RuntimeGraph;
class Carrier;
class MessageBus;
class TaskNode;
class FleetExecutor final {
public:
......@@ -37,7 +38,9 @@ class FleetExecutor final {
explicit FleetExecutor(const std::string& exe_desc_str);
~FleetExecutor();
void Init(const framework::ProgramDesc& program_desc, framework::Scope* scope,
const platform::Place& place);
const platform::Place& place,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run();
private:
......
......@@ -44,6 +44,14 @@ class RuntimeGraph final {
const std::unordered_map<int64_t, int64_t>& intercepter_id_to_rank() const {
return intercepter_id_to_rank_;
}
void SetInterceptorIdToRank(
const std::unordered_map<int64_t, int64_t>& intercepter_id_to_rank) {
intercepter_id_to_rank_ = intercepter_id_to_rank;
}
void SetInterceptorIdToNode(
const std::unordered_map<int64_t, TaskNode*>& intercepter_id_to_node) {
intercepter_id_to_node_ = intercepter_id_to_node;
}
std::string DebugString() const;
private:
......
......@@ -1979,10 +1979,12 @@ class Executor(object):
fleet_exe_desc.num_micro_batches = fleet_opt["num_micro_batches"]
num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree
assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu."
task_id_to_rank = fleet_opt.get("task_id_to_rank", {})
tasks = fleet_opt.get("tasks", [])
fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
place = core.Place()
place.set_place(self.place)
fleet_exe.init(program.desc, scope, place)
fleet_exe.init(program.desc, scope, place, tasks, task_id_to_rank)
return fleet_exe
def _run_using_fleet_executor(self,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册