Unverified Commit e43b6f65, authored by Yuang Liu, committed by GitHub

[fleet executor] Init fleet exe and prepare feed&fetch (#39032)

Parent f37a23a7
@@ -15,6 +15,8 @@
#include <glog/logging.h>
#include "paddle/fluid/distributed/fleet_executor/dist_model.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
@@ -68,9 +70,15 @@ bool DistModel::Init() {
    program_.reset(config_.program_desc);
    scope_.reset(config_.scope);
  }
  if (!PrepareFeedAndFetch()) {
    return false;
  }
  if (!CommInit()) {
    return false;
  }
  if (!PrepareFleetExe()) {
    return false;
  }
  return true;
}
@@ -298,6 +306,55 @@ bool DistModel::LoadParameters() {
  return true;
}

bool DistModel::PrepareFleetExe() {
  task_node_.reset(new TaskNode(program_.get(), config_.local_rank));
  if (config_.local_rank - config_.mp_degree >= 0) {
    task_node_->AddUpstreamTask(config_.local_rank - config_.mp_degree);
  }
  if (config_.local_rank + config_.mp_degree < config_.nranks) {
    task_node_->AddDownstreamTask(config_.local_rank + config_.mp_degree);
  }
  task_node_->SetType("Compute");
  task_node_->Init();
  executor_desc_ = FleetExecutorDesc();
  executor_desc_.set_cur_rank(config_.local_rank);
  std::unordered_map<int64_t, int64_t> id_to_rank;
  for (int i = 0; i < config_.nranks; ++i) {
    RankInfo *rank_info = executor_desc_.add_cluster_info();
    rank_info->set_rank(i);
    rank_info->set_ip_port(config_.trainer_endpoints[i]);
    id_to_rank.insert({i, i});
  }
  fleet_exe.reset(new FleetExecutor(executor_desc_));
  fleet_exe->Init("inference", *(program_.get()), scope_.get(), place_, 1,
                  {task_node_.get()}, id_to_rank);
  return true;
}
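PrepareFleetExe above links pipeline neighbors by treating ranks that differ by exactly mp_degree as adjacent stages, so ranks sharing the same model-parallel slot form one chain. A minimal standalone sketch of that arithmetic, using hypothetical values for nranks and mp_degree (not taken from the commit):

#include <cstdio>

int main() {
  const int nranks = 4;     // hypothetical world size
  const int mp_degree = 2;  // hypothetical model-parallel degree
  for (int local_rank = 0; local_rank < nranks; ++local_rank) {
    if (local_rank - mp_degree >= 0)
      std::printf("rank %d upstream: %d\n", local_rank, local_rank - mp_degree);
    if (local_rank + mp_degree < nranks)
      std::printf("rank %d downstream: %d\n", local_rank, local_rank + mp_degree);
  }
  return 0;
}

With these sample values, ranks 0→2 and 1→3 form the two pipeline chains, and each TaskNode is given at most one upstream and one downstream task.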

bool DistModel::PrepareFeedAndFetch() {
  for (auto *op : program_->Block(0).AllOps()) {
    if (op->Type() == "feed") {
      VLOG(3) << "feed op with feed var: " << op->Output("Out")[0];
      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
      if (feeds_.size() <= static_cast<size_t>(idx)) {
        feeds_.resize(idx + 1);
      }
      feeds_[idx] = op;
      feed_names_[op->Output("Out")[0]] = idx;
      idx_to_feeds_[idx] = op->Output("Out")[0];
    } else if (op->Type() == "fetch") {
      VLOG(3) << "fetch op with fetch var: " << op->Input("X")[0];
      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
      if (fetches_.size() <= static_cast<size_t>(idx)) {
        fetches_.resize(idx + 1);
      }
      fetches_[idx] = op;
      id_to_fetches_[idx] = op->Input("X")[0];
    }
  }
  return true;
}
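Each feed/fetch op stores its slot in a "col" attribute, so the loop above resizes feeds_/fetches_ on demand and can encounter the ops in any order. A toy sketch of that indexing pattern, with plain ints standing in for the OpDesc pointers (illustrative only):

#include <cstdio>
#include <vector>

int main() {
  const int cols[] = {2, 0, 1};  // pretend fetch ops arrive with these col attrs
  std::vector<int> fetches;      // stands in for std::vector<framework::OpDesc*>
  for (int col : cols) {
    if (fetches.size() <= static_cast<size_t>(col)) {
      fetches.resize(col + 1);
    }
    fetches[col] = col * 10;  // placeholder payload instead of the op pointer
  }
  for (size_t i = 0; i < fetches.size(); ++i) {
    std::printf("slot %zu filled with %d\n", i, fetches[i]);
  }
  return 0;
}

The name/index maps (feed_names_, idx_to_feeds_, id_to_fetches_) record the same slots keyed by variable name, so later code can translate between user-visible names and column positions.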

void DistModel::Run(const std::vector<paddle::framework::Tensor> &input_data,
                    std::vector<paddle::framework::Tensor> *output_data) {
  /* TODO(fleet exe dev): implement this funct */
......
@@ -32,6 +32,9 @@ class BlockDesc;
namespace distributed {

class TaskNode;
class FleetExecutor;

struct DistModelConfig {
  std::string model_dir{};
  framework::ProgramDesc* program_desc{nullptr};
@@ -66,12 +69,21 @@ class DistModel {
  bool LoadParameters();
  bool PreparePlace();
  bool CommInit();
  bool PrepareFeedAndFetch();
  bool PrepareFleetExe();
  void InsertCommOp(std::string tmp_var_name, int nranks, int rank,
                    const std::vector<std::string>& peer_endpoints,
                    framework::BlockDesc* block, int ring_id);

  std::vector<framework::OpDesc*> feeds_;
  std::map<std::string, int64_t> feed_names_;
  std::map<int64_t, std::string> idx_to_feeds_;
  std::vector<framework::OpDesc*> fetches_;
  std::map<int64_t, std::string> id_to_fetches_;
  DistModelConfig config_;
  FleetExecutorDesc executor_desc_;
  std::shared_ptr<FleetExecutor> fleet_exe;
  std::shared_ptr<TaskNode> task_node_;
  std::shared_ptr<framework::Scope> scope_;
  paddle::platform::Place place_;
  std::shared_ptr<framework::ProgramDesc> program_;
......
@@ -35,6 +35,13 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
  InitMessageBus();
}

FleetExecutor::FleetExecutor(const FleetExecutorDesc& exe_desc)
    : exe_desc_(exe_desc) {
  // Message bus will be created and inited only once
  GlobalVal<MessageBus>::Create();
  InitMessageBus();
}
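This overload takes an already-built FleetExecutorDesc directly, which is what DistModel::PrepareFleetExe above relies on. Assuming FleetExecutorDesc is a protobuf message (the set_cur_rank/add_cluster_info calls earlier in the diff suggest so), the string-based constructor would presumably receive the same information in serialized form; a rough sketch of both paths, with a hypothetical endpoint:

// Sketch only, not part of the commit.
FleetExecutorDesc desc;
desc.set_cur_rank(0);
RankInfo* info = desc.add_cluster_info();
info->set_rank(0);
info->set_ip_port("127.0.0.1:6170");             // hypothetical endpoint

FleetExecutor by_desc(desc);                      // new overload: desc passed as-is
FleetExecutor by_str(desc.SerializeAsString());   // existing path, assuming it
                                                  // parses a serialized desc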

FleetExecutor::~FleetExecutor() {
  for (const auto& carrier_id : carrier_ids_) {
    GlobalMap<std::string, Carrier>::Get(carrier_id)->Release();
......
@@ -36,6 +36,7 @@ class FleetExecutor final {
 public:
  FleetExecutor() = delete;
  explicit FleetExecutor(const std::string& exe_desc_str);
  explicit FleetExecutor(const FleetExecutorDesc& exe_desc);
  ~FleetExecutor();
  void Init(const std::string& carrier_id,
            const framework::ProgramDesc& program_desc, framework::Scope* scope,
......
@@ -38,6 +38,16 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
  task_id_ = task_node_cnt++;
}

TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
    : program_(program), rank_(rank), task_id_(rank) {
  max_run_times_ = 1;
  max_slot_nums_ = 1;
  LOG(INFO)
      << "Constructing TaskNode for DistModelInf. The TaskNode's id is: "
      << rank
      << ". And the TaskNode's max_run_time and max_slot_num will be set to 1.";
}
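Unlike the counter-based constructor above (task_id_ = task_node_cnt++), this one reuses the rank as the task id, which is sufficient here because each inference rank builds exactly one TaskNode, as PrepareFleetExe shows. A hypothetical construction, with a stand-in ProgramDesc:

paddle::framework::ProgramDesc program;   // stand-in program, illustrative only
TaskNode node(&program, /*rank=*/2);      // task id becomes 2
// max_run_times_ and max_slot_nums_ are both fixed to 1, as the log message states.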

void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
  program_ = program;
}
......
@@ -42,6 +42,7 @@ class TaskNode final {
           int64_t max_slot_nums);
  TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
           int64_t max_run_times, int64_t max_slot_nums);
  TaskNode(paddle::framework::ProgramDesc* program, int64_t rank);
  ~TaskNode() = default;
  void SetProgram(paddle::framework::ProgramDesc* program);
......