提交 50d606e7 编写于 作者: W willzhang4a58

add checkpoint

上级 8b45c6ea
......@@ -231,7 +231,8 @@ std::string ChainNode::ConcatedOpsName() const {
return ss.str().substr(2);
}
ChainGraph::ChainGraph(const LogicalGraph* logical_gph) {
ChainGraph::ChainGraph(const LogicalGraph* logical_gph,
const std::string& dot_filepath) {
LOG(INFO) << "Build ChainGraph...";
// Build Chain
std::list<Chain> chain_list;
......@@ -275,7 +276,7 @@ ChainGraph::ChainGraph(const LogicalGraph* logical_gph) {
// Post processing
UpdateSourceAndSink();
SetInOutLbn4AllChainNodeInDataTaskGraph();
ToDotFile(LogDir() + "/chain_graph.dot");
ToDotFile(dot_filepath);
}
void ChainGraph::SetInOutLbn4AllChainNodeInDataTaskGraph() {
......
......@@ -81,7 +81,8 @@ class ChainGraph final : public Graph<ChainNode, ChainEdge> {
ChainGraph() = default;
~ChainGraph() = default;
ChainGraph(const LogicalGraph* logical_gph);
ChainGraph(const LogicalGraph* logical_gph,
const std::string& dot_filepath);
private:
void SetInOutLbn4AllChainNodeInDataTaskGraph();
......
......@@ -14,10 +14,9 @@ class DataTaskGraph final : public TaskGraph {
DataTaskGraph(const DLNetConf& dl_net_conf,
const Strategy& strategy_conf,
bool need_bp) {
LogicalGraph logical_gph(dl_net_conf, strategy_conf);
logical_gph.ToDotFile(LogDir() + "/logical_graph.dot");
auto chain_gph = of_make_unique<ChainGraph> (&logical_gph);
BuildFromChainGph(std::move(chain_gph), need_bp);
LogicalGraph logical_gph(dl_net_conf, strategy_conf, LogDir() + "/logical_graph.dot");
auto chain_gph = of_make_unique<ChainGraph> (&logical_gph, LogDir() + "/data_chain_graph.dot");
BuildFromChainGph(std::move(chain_gph), need_bp, LogDir() + "/data_");
}
CompTaskNodeMemFunc Func4FwBuildExecAndProducedRegsts() const override {
......
......@@ -6,13 +6,15 @@
namespace oneflow {
LogicalGraph::LogicalGraph(const DLNetConf& dl_net_conf,
const Strategy& strategy_conf) {
const Strategy& strategy_conf,
const std::string& dot_filepath) {
LOG(INFO) << "Build LogicalGraph...";
HashMap<LogicalEdge*, std::string> edge2lbn;
HashMap<LogicalEdge*, std::string> edge2ibn;
NaiveBuildGraphStruct(dl_net_conf, &edge2lbn, &edge2ibn);
FillNodeWithParallelDesc(strategy_conf);
AddCloneNodes(edge2lbn, edge2ibn);
ToDotFile(dot_filepath);
}
void LogicalGraph::NaiveBuildGraphStruct(
......
......@@ -58,7 +58,8 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
~LogicalGraph() = default;
LogicalGraph(const DLNetConf& dl_net_conf,
const Strategy& strategy_conf);
const Strategy& strategy_conf,
const std::string& dot_filepath);
private:
void NaiveBuildGraphStruct(
......
......@@ -40,7 +40,8 @@ void MdLoadTaskGraph::BuildTaskGraph(const ChainNode* update_chain) {
faker_chain->mut_input_lbns() = {RegstDesc::kAllLbn};
Connect(load_chain, chain_gph->NewEdge(), faker_chain);
chain_gph->UpdateSourceAndSink();
BuildFromChainGph(std::move(chain_gph), false);
chain_gph->ToDotFile(LogDir() + "/model_load_chain_graph.dot");
BuildFromChainGph(std::move(chain_gph), false, LogDir() + "/model_load_");
}
void MdLoadTaskGraph::InitFaker2Mccoy(
......
......@@ -41,7 +41,8 @@ void MdSaveTaskGraph::BuildTaskGraph(const ChainNode* update_chain) {
// Connect
Connect(faker_chain, chain_gph->NewEdge(), save_chain);
chain_gph->UpdateSourceAndSink();
BuildFromChainGph(std::move(chain_gph), false);
chain_gph->ToDotFile(LogDir() + "/model_save_chain_graph.dot");
BuildFromChainGph(std::move(chain_gph), false, LogDir() + "/model_save_");
}
void MdSaveTaskGraph::InitFaker2Mccoy(
......
......@@ -37,7 +37,8 @@ void MdUpdtTaskGraph::BuildTaskGraph(const ChainNode* data_chain) {
Connect(faker_chain, chain_gph->NewEdge(), updt_chain);
}
//
BuildFromChainGph(std::move(chain_gph), false);
chain_gph->ToDotFile(LogDir() + "/model_update_chain_graph.dot");
BuildFromChainGph(std::move(chain_gph), false, LogDir() + "/model_update_");
}
void MdUpdtTaskGraph::InitFaker2MccoyAndParallelId2UpdtMap(
......
......@@ -3,7 +3,8 @@
namespace oneflow {
StageGraph::StageGraph(std::unique_ptr<const ChainGraph>&& chain_gph) {
StageGraph::StageGraph(std::unique_ptr<const ChainGraph>&& chain_gph,
const std::string& dot_filepath) {
LOG(INFO) << "Build StageGraph...";
chain_gph_ = std::move(chain_gph);
HashMap<const ChainNode*, std::vector<StageNode*>> chain2stages;
......@@ -41,7 +42,7 @@ StageGraph::StageGraph(std::unique_ptr<const ChainGraph>&& chain_gph) {
}
// Post processing
UpdateSourceAndSink();
ToDotFile(LogDir() + "/stage_graph.dot");
ToDotFile(dot_filepath);
}
} // namespace oneflow
......@@ -68,7 +68,8 @@ class StageGraph final : public Graph<StageNode, StageEdge> {
StageGraph() = delete;
~StageGraph() = default;
StageGraph(std::unique_ptr<const ChainGraph>&& chain_gph);
StageGraph(std::unique_ptr<const ChainGraph>&& chain_gph,
const std::string& dot_filepath);
// Read-only accessor: returns the raw pointer obtained from chain_gph_.get().
// chain_gph_ is the member that the constructor fills by moving in its
// std::unique_ptr<const ChainGraph> argument, so the returned pointer is
// non-owning (presumably valid for as long as this StageGraph holds it).
const ChainGraph* chain_gph() const { return chain_gph_.get(); }
......
......@@ -36,21 +36,25 @@ std::vector<CompTaskNode*> TaskGraph::SortedCompTasksInChain(
void TaskGraph::BuildFromChainGph(
std::unique_ptr<ChainGraph>&& chain_gph,
bool need_bp) {
stage_gph_.reset(new StageGraph(std::move(chain_gph)));
BuildFromStageGph(need_bp);
bool need_bp,
const std::string& dot_filepath_prefix) {
stage_gph_.reset(new StageGraph(std::move(chain_gph),
dot_filepath_prefix + "stage_graph.dot"));
BuildFromStageGph(need_bp, dot_filepath_prefix);
}
void TaskGraph::BuildFromStageGph(bool need_bp) {
void TaskGraph::BuildFromStageGph(bool need_bp,
const std::string& dot_filepath_prefix) {
LOG(INFO) << "Build FwTaskGraph...";
Stage2TaskNodesMap stage2task_nodes;
InitCompTaskNodes(&stage2task_nodes);
InitBoxingTaskNodes(&stage2task_nodes);
ConnectBoxingTaskNodes(&stage2task_nodes);
UpdateSourceAndSink();
ToDotFile(LogDir() + "/fw_task_graph.dot");
ToDotFile(dot_filepath_prefix + "fw_task_graph.dot");
if (need_bp) {
BuildBpStruct();
ToDotFile(dot_filepath_prefix + "bp_task_graph.dot");
}
}
......@@ -234,7 +238,6 @@ void TaskGraph::BuildBpStruct() {
GenerateRelatedBpNodes(&loss_node_vec);
BackwardConnect(loss_node_vec);
UpdateSourceAndSink();
ToDotFile(LogDir() + "/bp_task_graph.dot");
}
void TaskGraph::GenerateRelatedBpNodes(
......
......@@ -32,13 +32,16 @@ class TaskGraph : public Graph<TaskNode, TaskEdge> {
protected:
TaskGraph() = default;
void BuildFromChainGph(std::unique_ptr<ChainGraph>&& chain_gph, bool need_bp);
void BuildFromChainGph(std::unique_ptr<ChainGraph>&& chain_gph,
bool need_bp,
const std::string& dot_filepath_prefix);
// Records the (faker, mccoy) pair in faker2mccoy_ via emplace().
// The CHECK asserts on the `.second` flag of emplace's result, i.e. that a new
// entry was actually inserted; presumably this means registering the same
// faker twice is treated as a fatal programming error -- confirm against the
// container type of faker2mccoy_, which is declared outside this view.
void EnrollFakerMccoy(CompTaskNode* faker, CompTaskNode* mccoy) {
CHECK(faker2mccoy_.emplace(faker, mccoy).second);
}
private:
void BuildFromStageGph(bool need_bp);
void BuildFromStageGph(bool need_bp,
const std::string& dot_filepath_prefix);
template<typename TaskNodeType>
TaskNodeType* NewTaskNode() {
......
......@@ -10,6 +10,7 @@ void TaskGraphMgr::Init() {
JobDesc::Singleton().train_dlnet_conf(),
JobDesc::Singleton().strategy(),
true);
LOG(FATAL) << "checkpoint";
task_gphs_.emplace_back(data_task_gph);
// construct data_chain2sorted_bp_comp_tasks
HashMap<const ChainNode*, std::vector<CompTaskNode*>>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册