Unverified commit e43b6f65, authored by Yuang Liu, committed by GitHub

[fleet executor] Init fleet exe and prepare feed&fetch (#39032)

Parent f37a23a7
......@@ -15,6 +15,8 @@
#include <glog/logging.h>
#include "paddle/fluid/distributed/fleet_executor/dist_model.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
......@@ -68,9 +70,15 @@ bool DistModel::Init() {
    program_.reset(config_.program_desc);
    scope_.reset(config_.scope);
  }
  if (!PrepareFeedAndFetch()) {
    return false;
  }
  if (!CommInit()) {
    return false;
  }
  if (!PrepareFleetExe()) {
    return false;
  }
  return true;
}
......@@ -298,6 +306,55 @@ bool DistModel::LoadParameters() {
  return true;
}

bool DistModel::PrepareFleetExe() {
  task_node_.reset(new TaskNode(program_.get(), config_.local_rank));
  if (config_.local_rank - config_.mp_degree >= 0) {
    task_node_->AddUpstreamTask(config_.local_rank - config_.mp_degree);
  }
  if (config_.local_rank + config_.mp_degree < config_.nranks) {
    task_node_->AddDownstreamTask(config_.local_rank + config_.mp_degree);
  }
  task_node_->SetType("Compute");
  task_node_->Init();
  executor_desc_ = FleetExecutorDesc();
  executor_desc_.set_cur_rank(config_.local_rank);
  std::unordered_map<int64_t, int64_t> id_to_rank;
  for (int i = 0; i < config_.nranks; ++i) {
    RankInfo *rank_info = executor_desc_.add_cluster_info();
    rank_info->set_rank(i);
    rank_info->set_ip_port(config_.trainer_endpoints[i]);
    id_to_rank.insert({i, i});
  }
  fleet_exe.reset(new FleetExecutor(executor_desc_));
  fleet_exe->Init("inference", *(program_.get()), scope_.get(), place_, 1,
                  {task_node_.get()}, id_to_rank);
  return true;
}
bool DistModel::PrepareFeedAndFetch() {
  for (auto *op : program_->Block(0).AllOps()) {
    if (op->Type() == "feed") {
      VLOG(3) << "feed op with feed var: " << op->Output("Out")[0];
      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
      if (feeds_.size() <= static_cast<size_t>(idx)) {
        feeds_.resize(idx + 1);
      }
      feeds_[idx] = op;
      feed_names_[op->Output("Out")[0]] = idx;
      idx_to_feeds_[idx] = op->Output("Out")[0];
    } else if (op->Type() == "fetch") {
      VLOG(3) << "fetch op with fetch var: " << op->Input("X")[0];
      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
      if (fetches_.size() <= static_cast<size_t>(idx)) {
        fetches_.resize(idx + 1);
      }
      fetches_[idx] = op;
      id_to_fetches_[idx] = op->Input("X")[0];
    }
  }
  return true;
}

void DistModel::Run(const std::vector<paddle::framework::Tensor> &input_data,
                    std::vector<paddle::framework::Tensor> *output_data) {
  /* TODO(fleet exe dev): implement this funct */
......
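For orientation, a standalone sketch (not part of the patch) of the neighbour-rank rule used by PrepareFleetExe above: with model-parallel degree mp_degree, the ranks exactly mp_degree below and above the current rank sit in the adjacent pipeline stages, so they become the upstream and downstream tasks whenever they fall inside [0, nranks). The values nranks = 8 and mp_degree = 2 below are hypothetical, chosen only for illustration.

#include <cstdint>
#include <iostream>

int main() {
  const int64_t nranks = 8;     // hypothetical world size
  const int64_t mp_degree = 2;  // hypothetical model-parallel degree
  for (int64_t local_rank = 0; local_rank < nranks; ++local_rank) {
    std::cout << "rank " << local_rank << ":";
    // Same bounds checks as PrepareFleetExe in the hunk above.
    if (local_rank - mp_degree >= 0) {
      std::cout << " upstream=" << (local_rank - mp_degree);
    }
    if (local_rank + mp_degree < nranks) {
      std::cout << " downstream=" << (local_rank + mp_degree);
    }
    std::cout << "\n";
  }
  return 0;
}

This prints, for example, "rank 3: upstream=1 downstream=5". Note also that id_to_rank is built as the identity map because each rank hosts exactly one TaskNode whose task_id equals the rank itself (see the new TaskNode(program, rank) constructor later in this diff).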
......@@ -32,6 +32,9 @@ class BlockDesc;
namespace distributed {
class TaskNode;
class FleetExecutor;
struct DistModelConfig {
  std::string model_dir{};
  framework::ProgramDesc* program_desc{nullptr};
......@@ -66,12 +69,21 @@ class DistModel {
  bool LoadParameters();
  bool PreparePlace();
  bool CommInit();
  bool PrepareFeedAndFetch();
  bool PrepareFleetExe();
  void InsertCommOp(std::string tmp_var_name, int nranks, int rank,
                    const std::vector<std::string>& peer_endpoints,
                    framework::BlockDesc* block, int ring_id);

  std::vector<framework::OpDesc*> feeds_;
  std::map<std::string, int64_t> feed_names_;
  std::map<int64_t, std::string> idx_to_feeds_;
  std::vector<framework::OpDesc*> fetches_;
  std::map<int64_t, std::string> id_to_fetches_;
  DistModelConfig config_;
  FleetExecutorDesc executor_desc_;
  std::shared_ptr<FleetExecutor> fleet_exe;
  std::shared_ptr<TaskNode> task_node_;
  std::shared_ptr<framework::Scope> scope_;
  paddle::platform::Place place_;
  std::shared_ptr<framework::ProgramDesc> program_;
......
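The feed/fetch members declared above are populated by the new PrepareFeedAndFetch in dist_model.cc: every feed/fetch op carries a col attribute that indexes into feeds_/fetches_, and the three maps translate between variable names and those column indices. A simplified, self-contained sketch of the same bookkeeping with made-up ops (FakeOp is a hypothetical stand-in for framework::OpDesc, not a Paddle type):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct FakeOp {       // hypothetical stand-in for framework::OpDesc
  std::string type;   // "feed" or "fetch"
  std::string var;    // feed output / fetch input variable name
  int col;            // value of the "col" attribute
};

int main() {
  std::vector<FakeOp> ops = {{"feed", "image", 0},
                             {"feed", "label", 1},
                             {"fetch", "loss", 0}};
  std::vector<const FakeOp*> feeds, fetches;
  std::map<std::string, int64_t> feed_names;     // name -> col, like feed_names_
  std::map<int64_t, std::string> idx_to_feeds;   // col -> name, like idx_to_feeds_
  std::map<int64_t, std::string> id_to_fetches;  // col -> name, like id_to_fetches_
  for (const auto& op : ops) {
    if (op.type == "feed") {
      // Vectors are grown on demand so that the col attribute can be used as the index.
      if (feeds.size() <= static_cast<size_t>(op.col)) feeds.resize(op.col + 1);
      feeds[op.col] = &op;
      feed_names[op.var] = op.col;
      idx_to_feeds[op.col] = op.var;
    } else if (op.type == "fetch") {
      if (fetches.size() <= static_cast<size_t>(op.col)) fetches.resize(op.col + 1);
      fetches[op.col] = &op;
      id_to_fetches[op.col] = op.var;
    }
  }
  std::cout << "feed 'label' goes to column " << feed_names["label"] << "\n";
  std::cout << "fetch column 0 reads variable " << id_to_fetches[0] << "\n";
  return 0;
}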
......@@ -35,6 +35,13 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
  InitMessageBus();
}

FleetExecutor::FleetExecutor(const FleetExecutorDesc& exe_desc)
    : exe_desc_(exe_desc) {
  // Message bus will be created and inited only once
  GlobalVal<MessageBus>::Create();
  InitMessageBus();
}

FleetExecutor::~FleetExecutor() {
  for (const auto& carrier_id : carrier_ids_) {
    GlobalMap<std::string, Carrier>::Get(carrier_id)->Release();
......
......@@ -36,6 +36,7 @@ class FleetExecutor final {
 public:
  FleetExecutor() = delete;
  explicit FleetExecutor(const std::string& exe_desc_str);
  explicit FleetExecutor(const FleetExecutorDesc& exe_desc);
  ~FleetExecutor();
  void Init(const std::string& carrier_id,
            const framework::ProgramDesc& program_desc, framework::Scope* scope,
......
......@@ -38,6 +38,16 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
  task_id_ = task_node_cnt++;
}

TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
    : program_(program), rank_(rank), task_id_(rank) {
  max_run_times_ = 1;
  max_slot_nums_ = 1;
  LOG(INFO)
      << "Constructing TaskNode for DistModelInf. The TaskNode's id is: "
      << rank
      << ". And the TaskNode's max_run_time and max_slot_num will be set to 1.";
}

void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
  program_ = program;
}
......
......@@ -42,6 +42,7 @@ class TaskNode final {
           int64_t max_slot_nums);
  TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
           int64_t max_run_times, int64_t max_slot_nums);
  TaskNode(paddle::framework::ProgramDesc* program, int64_t rank);
  ~TaskNode() = default;
  void SetProgram(paddle::framework::ProgramDesc* program);
......