提交 c3621c61 编写于 作者: W willzhang4a58

path manager init

上级 553e19f4
......@@ -226,7 +226,7 @@ void DataMergeChains(
} // namespace
void ChainGraph::Init(std::shared_ptr<const LogicalGraph> logical_graph) {
void ChainGraph::Init(const LogicalGraph* logical_graph) {
// Build Chain
std::list<Chain> chain_list;
Logical2ChainItMap logical2chain_it;
......
......@@ -65,7 +65,7 @@ class ChainGraph final : public Graph<ChainNode, ChainEdge> {
ChainGraph() = default;
~ChainGraph() = default;
void Init(std::shared_ptr<const LogicalGraph> logical_graph);
void Init(const LogicalGraph* logical_graph);
private:
void CollectInputAndOutputLbns();
......
......@@ -45,12 +45,12 @@ void ConnectRelatedStages(
}
void StageGraph::Init(std::shared_ptr<const ChainGraph> chain_graph) {
chain_graph_ = chain_graph;
void StageGraph::Init(std::unique_ptr<const ChainGraph>&& chain_graph) {
chain_graph_ = std::move(chain_graph);
// Init Stages
std::unordered_map<const ChainNode*,
std::vector<StageNode*>> chain2stages;
for (const std::unique_ptr<ChainNode>& cur_chain : chain_graph->nodes()) {
for (const std::unique_ptr<ChainNode>& cur_chain : chain_graph_->nodes()) {
chain2stages[cur_chain.get()] = {};
for (MachineId machine_id : cur_chain->parallel_desc()->machines()) {
StageNode* stage_node = NewFinalNode();
......@@ -64,7 +64,7 @@ void StageGraph::Init(std::shared_ptr<const ChainGraph> chain_graph) {
(StageNode* src_node, StageNode* dst_node) {
Connect(src_node, this->NewFinalEdge(), dst_node);
};
for (const std::unique_ptr<ChainNode>& cur_chain : chain_graph->nodes()) {
for (const std::unique_ptr<ChainNode>& cur_chain : chain_graph_->nodes()) {
for (const ChainEdge* edge : cur_chain->out_edges()) {
const std::vector<StageNode*>& cur_stages =
chain2stages.at(cur_chain.get());
......
......@@ -48,11 +48,12 @@ class StageGraph final : public Graph<StageNode, StageEdge> {
StageGraph() = default;
~StageGraph() = default;
void Init(std::shared_ptr<const ChainGraph> chain_graph);
void Init(std::unique_ptr<const ChainGraph>&& chain_graph);
const ChainGraph* chain_graph() const { return chain_graph_.get(); }
private:
// We need to make sure the chain_node is alive
std::shared_ptr<const ChainGraph> chain_graph_;
std::unique_ptr<const ChainGraph> chain_graph_;
};
......
......@@ -16,20 +16,21 @@ void TaskGraph::Init(const DLNetConf& dl_net_conf,
const Strategy& strategy_conf,
const IDMap& id_map,
bool need_bp) {
auto logical_graph = std::make_shared<LogicalGraph>();
std::unique_ptr<LogicalGraph> logical_graph(new LogicalGraph);
logical_graph->Init(dl_net_conf, strategy_conf);
auto chain_graph = std::make_shared<ChainGraph>();
chain_graph->Init(logical_graph);
auto stage_graph = std::make_shared<StageGraph>();
stage_graph->Init(chain_graph);
BuildWithoutTransfm(stage_graph, id_map, need_bp);
std::unique_ptr<ChainGraph> chain_graph(new ChainGraph);
chain_graph->Init(logical_graph.get());
std::unique_ptr<StageGraph> stage_graph(new StageGraph);
stage_graph->Init(std::move(chain_graph));
BuildWithoutTransfm(std::move(stage_graph), id_map, need_bp);
BuildTransfm();
}
void TaskGraph::BuildWithoutTransfm(
std::shared_ptr<const StageGraph> stage_graph,
std::unique_ptr<const StageGraph>&& stage_graph,
const IDMap& id_map,
bool job_need_bp) {
stage_graph_ = std::move(stage_graph);
Stage2TaskNodesMap stage2task_nodes;
InitCompTaskNodes(*stage_graph, id_map, &stage2task_nodes);
InitBoxingTaskNodes(*stage_graph, id_map, &stage2task_nodes);
......
......@@ -20,8 +20,10 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
const IDMap& id_map,
bool need_bp);
const StageGraph* stage_graph() const { return stage_graph_.get(); }
private:
void BuildWithoutTransfm(std::shared_ptr<const StageGraph> stage_graph,
void BuildWithoutTransfm(std::unique_ptr<const StageGraph>&& stage_graph,
const IDMap& id_map,
bool need_bp);
void BuildTransfm();
......@@ -71,8 +73,7 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
void BackwardConnect(const std::vector<TaskNode*>& turning_node_vec);
void BuildBpStruct();
// We need to make sure the StageNode is alive
std::shared_ptr<const StageGraph> stage_graph_;
std::unique_ptr<const StageGraph> stage_graph_;
};
......
#ifndef ONEFLOW_PATH_CPS_DESC_H_
#define ONEFLOW_PATH_CPS_DESC_H_
// Cross Path Subcribe Descriptor
class CpsDesc {
public:
private:
};
#endif // ONEFLOW_PATH_CPS_DESC_H_
......@@ -11,6 +11,11 @@ class ModelLoadPath final : public Path {
ModelLoadPath() = default;
~ModelLoadPath() = default;
void Build(const ChainNode* chain_in_data_path,
std::function<void(const CpsDesc&)> add_cps_desc) {
LOG(FATAL) << "TODO";
}
private:
};
......
......@@ -11,6 +11,11 @@ class ModelSavePath final : public Path {
ModelSavePath() = default;
~ModelSavePath() = default;
void Build(const ChainNode* chain_in_data_path,
std::function<void(const CpsDesc&)> add_cps_desc) {
LOG(FATAL) << "TODO";
}
private:
};
......
......@@ -11,6 +11,11 @@ class ModelUpdatePath final : public Path {
ModelUpdatePath() = default;
~ModelUpdatePath() = default;
void Build(const ChainNode* chain_in_data_path,
std::function<void(const CpsDesc&)> add_cps_desc) {
LOG(FATAL) << "TODO";
}
private:
};
......
......@@ -2,6 +2,7 @@
#define ONEFLOW_PATH_PATH_H_
#include "graph/task_graph.h"
#include "path/cps_desc.h"
namespace oneflow {
......@@ -11,6 +12,13 @@ class Path {
Path() = default;
virtual ~Path() = default;
TaskGraph* task_graph() {
return task_graph_.get();
}
const ChainGraph* chain_graph() const {
return task_graph_->stage_graph()->chain_graph();
}
protected:
std::unique_ptr<TaskGraph>& mut_task_graph() { return task_graph_; }
......
......@@ -11,10 +11,29 @@ void PathManager::Init(const JobSysConf& job_sys_conf) {
data_path->Build(job_sys_conf.train_dlnet_conf(),
job_sys_conf.strategy(),
id_map,
true);
LOG(FATAL) << "TODO";
true); // TODO
paths_.insert(std::make_pair("data", std::move(data_path)));
// build model path
std::vector<CpsDesc> cps_desc_vec;
auto add_cps_desc = [&cps_desc_vec](const CpsDesc& cps_desc) {
cps_desc_vec.push_back(cps_desc);
};
for (const auto& chain : paths_.at("data")->chain_graph()->nodes()) {
std::unique_ptr<ModelUpdatePath> model_update_path(new ModelUpdatePath);
std::unique_ptr<ModelLoadPath> model_load_path(new ModelLoadPath);
std::unique_ptr<ModelSavePath> model_save_path(new ModelSavePath);
model_update_path->Build(chain.get(), add_cps_desc);
model_load_path->Build(chain.get(), add_cps_desc);
model_save_path->Build(chain.get(), add_cps_desc);
// TODO: name
paths_.insert(std::make_pair("", std::move(model_update_path)));
paths_.insert(std::make_pair("", std::move(model_load_path)));
paths_.insert(std::make_pair("", std::move(model_save_path)));
}
// processs cross path subscribe
for (const CpsDesc& cps_desc : cps_desc_vec) {
ProcessCps(cps_desc);
}
}
} // namespace oneflow
......@@ -18,6 +18,10 @@ class PathManager {
void Init(const JobSysConf& job_sys_conf);
private:
void ProcessCps(const CpsDesc& cps_desc) {
LOG(FATAL) << "TODO";
}
std::unordered_map<std::string, std::unique_ptr<Path>> paths_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册