diff --git a/paddle/fluid/distributed/fleet_executor/dist_model.cc b/paddle/fluid/distributed/fleet_executor/dist_model.cc
index 0fdd38ac7a329ee3ba5befa47e305e6c7cf8ad36..310c809de717287bb4c695240349d6fc5b7d9f60 100644
--- a/paddle/fluid/distributed/fleet_executor/dist_model.cc
+++ b/paddle/fluid/distributed/fleet_executor/dist_model.cc
@@ -15,6 +15,8 @@
 #include <glog/logging.h>
 
 #include "paddle/fluid/distributed/fleet_executor/dist_model.h"
+#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
+#include "paddle/fluid/distributed/fleet_executor/task_node.h"
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/naive_executor.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
@@ -68,9 +70,15 @@ bool DistModel::Init() {
     program_.reset(config_.program_desc);
     scope_.reset(config_.scope);
   }
+  if (!PrepareFeedAndFetch()) {
+    return false;
+  }
   if (!CommInit()) {
     return false;
   }
+  if (!PrepareFleetExe()) {
+    return false;
+  }
   return true;
 }
 
@@ -298,6 +306,55 @@ bool DistModel::LoadParameters() {
   return true;
 }
 
+bool DistModel::PrepareFleetExe() {
+  task_node_.reset(new TaskNode(program_.get(), config_.local_rank));
+  if (config_.local_rank - config_.mp_degree >= 0) {
+    task_node_->AddUpstreamTask(config_.local_rank - config_.mp_degree);
+  }
+  if (config_.local_rank + config_.mp_degree < config_.nranks) {
+    task_node_->AddDownstreamTask(config_.local_rank + config_.mp_degree);
+  }
+  task_node_->SetType("Compute");
+  task_node_->Init();
+  executor_desc_ = FleetExecutorDesc();
+  executor_desc_.set_cur_rank(config_.local_rank);
+  std::unordered_map<int64_t, int64_t> id_to_rank;
+  for (int i = 0; i < config_.nranks; ++i) {
+    RankInfo *rank_info = executor_desc_.add_cluster_info();
+    rank_info->set_rank(i);
+    rank_info->set_ip_port(config_.trainer_endpoints[i]);
+    id_to_rank.insert({i, i});
+  }
+  fleet_exe.reset(new FleetExecutor(executor_desc_));
+  fleet_exe->Init("inference", *(program_.get()), scope_.get(), place_, 1,
+                  {task_node_.get()}, id_to_rank);
+  return true;
+}
+
+bool DistModel::PrepareFeedAndFetch() {
+  for (auto *op : program_->Block(0).AllOps()) {
+    if (op->Type() == "feed") {
+      VLOG(3) << "feed op with feed var: " << op->Output("Out")[0];
+      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
+      if (feeds_.size() <= static_cast<size_t>(idx)) {
+        feeds_.resize(idx + 1);
+      }
+      feeds_[idx] = op;
+      feed_names_[op->Output("Out")[0]] = idx;
+      idx_to_feeds_[idx] = op->Output("Out")[0];
+    } else if (op->Type() == "fetch") {
+      VLOG(3) << "fetch op with fetch var: " << op->Input("X")[0];
+      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
+      if (fetches_.size() <= static_cast<size_t>(idx)) {
+        fetches_.resize(idx + 1);
+      }
+      fetches_[idx] = op;
+      id_to_fetches_[idx] = op->Input("X")[0];
+    }
+  }
+  return true;
+}
+
 void DistModel::Run(const std::vector<paddle::framework::Tensor> &input_data,
                     std::vector<paddle::framework::Tensor> *output_data) {
   /* TODO(fleet exe dev): implement this funct */
diff --git a/paddle/fluid/distributed/fleet_executor/dist_model.h b/paddle/fluid/distributed/fleet_executor/dist_model.h
index 6ab0d2cd4dedd95963af323d153f98ac23bcbb7e..d6dc554f158d340c19391fc12b32cfd19cdae66e 100644
--- a/paddle/fluid/distributed/fleet_executor/dist_model.h
+++ b/paddle/fluid/distributed/fleet_executor/dist_model.h
@@ -32,6 +32,9 @@ class BlockDesc;
 
 namespace distributed {
 
+class TaskNode;
+class FleetExecutor;
+
 struct DistModelConfig {
   std::string model_dir{};
   framework::ProgramDesc* program_desc{nullptr};
@@ -66,12 +69,21 @@ class DistModel {
   bool LoadParameters();
   bool PreparePlace();
   bool CommInit();
+  bool PrepareFeedAndFetch();
+  bool PrepareFleetExe();
   void InsertCommOp(std::string tmp_var_name, int nranks, int rank,
                     const std::vector<std::string>& peer_endpoints,
                     framework::BlockDesc* block, int ring_id);
 
+  std::vector<framework::OpDesc*> feeds_;
+  std::map<std::string, int64_t> feed_names_;
+  std::map<int64_t, std::string> idx_to_feeds_;
+  std::vector<framework::OpDesc*> fetches_;
+  std::map<int64_t, std::string> id_to_fetches_;
   DistModelConfig config_;
   FleetExecutorDesc executor_desc_;
+  std::shared_ptr<FleetExecutor> fleet_exe;
+  std::shared_ptr<TaskNode> task_node_;
   std::shared_ptr<framework::Scope> scope_;
   paddle::platform::Place place_;
   std::shared_ptr<framework::ProgramDesc> program_;
diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
index 19c44fa521b1bffbe4baab27a4262e2c3bded6f1..457549a27b4b7ed6305b107cfd319ecae026a53b 100644
--- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
@@ -35,6 +35,13 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
   InitMessageBus();
 }
 
+FleetExecutor::FleetExecutor(const FleetExecutorDesc& exe_desc)
+    : exe_desc_(exe_desc) {
+  // Message bus will be created and inited only once
+  GlobalVal<MessageBus>::Create();
+  InitMessageBus();
+}
+
 FleetExecutor::~FleetExecutor() {
   for (const auto& carrier_id : carrier_ids_) {
     GlobalMap<std::string, Carrier>::Get(carrier_id)->Release();
diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h
index b2af3e4e457c7573389d491a313a63672b71b627..fa65309127bec50869c52d2f3c85477910ccb37b 100644
--- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h
+++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h
@@ -36,6 +36,7 @@ class FleetExecutor final {
  public:
   FleetExecutor() = delete;
   explicit FleetExecutor(const std::string& exe_desc_str);
+  explicit FleetExecutor(const FleetExecutorDesc& exe_desc);
   ~FleetExecutor();
   void Init(const std::string& carrier_id,
             const framework::ProgramDesc& program_desc, framework::Scope* scope,
diff --git a/paddle/fluid/distributed/fleet_executor/task_node.cc b/paddle/fluid/distributed/fleet_executor/task_node.cc
index 656cfc431cd7034ac5eb3bfe22caa9fb4bbfefcf..6de7038b3231f2fb302dd970273c565c5a718b73 100644
--- a/paddle/fluid/distributed/fleet_executor/task_node.cc
+++ b/paddle/fluid/distributed/fleet_executor/task_node.cc
@@ -38,6 +38,16 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
   task_id_ = task_node_cnt++;
 }
 
+TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
+    : program_(program), rank_(rank), task_id_(rank) {
+  max_run_times_ = 1;
+  max_slot_nums_ = 1;
+  LOG(INFO)
+      << "Constructing TaskNode for DistModelInf. The TaskNode's id is: "
+      << rank
+      << ". And the TaskNode's max_run_time and max_slot_num will be set to 1.";
+}
+
 void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
   program_ = program;
 }
diff --git a/paddle/fluid/distributed/fleet_executor/task_node.h b/paddle/fluid/distributed/fleet_executor/task_node.h
index b9c1361dc9909c89cd4b4869ca6140b458906e0c..b655d140d37a5bdf547a278eec3355ef4638539f 100644
--- a/paddle/fluid/distributed/fleet_executor/task_node.h
+++ b/paddle/fluid/distributed/fleet_executor/task_node.h
@@ -42,6 +42,7 @@ class TaskNode final {
            int64_t max_slot_nums);
   TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
            int64_t max_run_times, int64_t max_slot_nums);
+  TaskNode(paddle::framework::ProgramDesc* program, int64_t rank);
   ~TaskNode() = default;
 
   void SetProgram(paddle::framework::ProgramDesc* program);
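
Note on the pipeline wiring: PrepareFleetExe() treats ranks that differ by exactly mp_degree as adjacent pipeline stages, so rank r reads from upstream rank r - mp_degree (when non-negative) and feeds downstream rank r + mp_degree (when below nranks). The standalone sketch below is not part of the patch; the TopoConfig struct and the sample values are invented purely to illustrate that neighbor computation.

#include <cstdint>
#include <iostream>

// Hypothetical stand-in for the relevant DistModelConfig fields.
struct TopoConfig {
  int64_t local_rank;
  int64_t mp_degree;  // model-parallel degree
  int64_t nranks;     // total number of ranks
};

int main() {
  // Example: 8 ranks with mp_degree 2 => 4 pipeline stages of 2 ranks each.
  for (int64_t r = 0; r < 8; ++r) {
    TopoConfig cfg{r, 2, 8};
    std::cout << "rank " << r;
    if (cfg.local_rank - cfg.mp_degree >= 0) {          // mirrors AddUpstreamTask
      std::cout << " <- upstream " << (cfg.local_rank - cfg.mp_degree);
    }
    if (cfg.local_rank + cfg.mp_degree < cfg.nranks) {  // mirrors AddDownstreamTask
      std::cout << " -> downstream " << (cfg.local_rank + cfg.mp_degree);
    }
    std::cout << "\n";
  }
  return 0;
}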
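Note on the feed/fetch bookkeeping: PrepareFeedAndFetch() keys everything on the "col" attribute of each feed/fetch op, i.e. the op's slot index, and grows the vectors on demand so ops may be visited in any order while walking block 0. A minimal, Paddle-free illustration of that resize-and-index pattern follows; all names and data in it are hypothetical.

#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Hypothetical (col, var name) pairs, as if read off the feed ops of
  // block 0; deliberately out of order to exercise resize-on-demand.
  std::vector<std::pair<int, std::string>> feed_ops = {
      {2, "input_c"}, {0, "input_a"}, {1, "input_b"}};

  std::vector<std::string> feeds;             // slot index -> var name
  std::map<std::string, int64_t> feed_names;  // var name   -> slot index
  for (const auto& op : feed_ops) {
    int idx = op.first;  // stands in for BOOST_GET_CONST(int, op->GetAttr("col"))
    if (feeds.size() <= static_cast<std::size_t>(idx)) {
      feeds.resize(idx + 1);
    }
    feeds[idx] = op.second;
    feed_names[op.second] = idx;
  }
  // feeds is now {"input_a", "input_b", "input_c"}; feed_names inverts it.
  return 0;
}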