Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
50d606e7
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
50d606e7
编写于
4月 27, 2017
作者:
W
willzhang4a58
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add checkpoint
上级
8b45c6ea
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
39 addition
and
23 deletion
+39
-23
oneflow/graph/chain_graph.cpp
oneflow/graph/chain_graph.cpp
+3
-2
oneflow/graph/chain_graph.h
oneflow/graph/chain_graph.h
+2
-1
oneflow/graph/data_task_graph.h
oneflow/graph/data_task_graph.h
+3
-4
oneflow/graph/logical_graph.cpp
oneflow/graph/logical_graph.cpp
+3
-1
oneflow/graph/logical_graph.h
oneflow/graph/logical_graph.h
+2
-1
oneflow/graph/model_load_task_graph.cpp
oneflow/graph/model_load_task_graph.cpp
+2
-1
oneflow/graph/model_save_task_graph.cpp
oneflow/graph/model_save_task_graph.cpp
+2
-1
oneflow/graph/model_update_task_graph.cpp
oneflow/graph/model_update_task_graph.cpp
+2
-1
oneflow/graph/stage_graph.cpp
oneflow/graph/stage_graph.cpp
+3
-2
oneflow/graph/stage_graph.h
oneflow/graph/stage_graph.h
+2
-1
oneflow/graph/task_graph.cpp
oneflow/graph/task_graph.cpp
+9
-6
oneflow/graph/task_graph.h
oneflow/graph/task_graph.h
+5
-2
oneflow/graph/task_graph_manager.cpp
oneflow/graph/task_graph_manager.cpp
+1
-0
未找到文件。
oneflow/graph/chain_graph.cpp
浏览文件 @
50d606e7
...
...
@@ -231,7 +231,8 @@ std::string ChainNode::ConcatedOpsName() const {
return
ss
.
str
().
substr
(
2
);
}
ChainGraph
::
ChainGraph
(
const
LogicalGraph
*
logical_gph
)
{
ChainGraph
::
ChainGraph
(
const
LogicalGraph
*
logical_gph
,
const
std
::
string
&
dot_filepath
)
{
LOG
(
INFO
)
<<
"Build ChainGraph..."
;
// Build Chain
std
::
list
<
Chain
>
chain_list
;
...
...
@@ -275,7 +276,7 @@ ChainGraph::ChainGraph(const LogicalGraph* logical_gph) {
// Post processing
UpdateSourceAndSink
();
SetInOutLbn4AllChainNodeInDataTaskGraph
();
ToDotFile
(
LogDir
()
+
"/chain_graph.dot"
);
ToDotFile
(
dot_filepath
);
}
void
ChainGraph
::
SetInOutLbn4AllChainNodeInDataTaskGraph
()
{
...
...
oneflow/graph/chain_graph.h
浏览文件 @
50d606e7
...
...
@@ -81,7 +81,8 @@ class ChainGraph final : public Graph<ChainNode, ChainEdge> {
ChainGraph
()
=
default
;
~
ChainGraph
()
=
default
;
ChainGraph
(
const
LogicalGraph
*
logical_gph
);
ChainGraph
(
const
LogicalGraph
*
logical_gph
,
const
std
::
string
&
dot_filepath
);
private:
void
SetInOutLbn4AllChainNodeInDataTaskGraph
();
...
...
oneflow/graph/data_task_graph.h
浏览文件 @
50d606e7
...
...
@@ -14,10 +14,9 @@ class DataTaskGraph final : public TaskGraph {
DataTaskGraph
(
const
DLNetConf
&
dl_net_conf
,
const
Strategy
&
strategy_conf
,
bool
need_bp
)
{
LogicalGraph
logical_gph
(
dl_net_conf
,
strategy_conf
);
logical_gph
.
ToDotFile
(
LogDir
()
+
"/logical_graph.dot"
);
auto
chain_gph
=
of_make_unique
<
ChainGraph
>
(
&
logical_gph
);
BuildFromChainGph
(
std
::
move
(
chain_gph
),
need_bp
);
LogicalGraph
logical_gph
(
dl_net_conf
,
strategy_conf
,
LogDir
()
+
"/logical_graph.dot"
);
auto
chain_gph
=
of_make_unique
<
ChainGraph
>
(
&
logical_gph
,
LogDir
()
+
"/data_chain_graph.dot"
);
BuildFromChainGph
(
std
::
move
(
chain_gph
),
need_bp
,
LogDir
()
+
"/data_"
);
}
CompTaskNodeMemFunc
Func4FwBuildExecAndProducedRegsts
()
const
override
{
...
...
oneflow/graph/logical_graph.cpp
浏览文件 @
50d606e7
...
...
@@ -6,13 +6,15 @@
namespace
oneflow
{
LogicalGraph
::
LogicalGraph
(
const
DLNetConf
&
dl_net_conf
,
const
Strategy
&
strategy_conf
)
{
const
Strategy
&
strategy_conf
,
const
std
::
string
&
dot_filepath
)
{
LOG
(
INFO
)
<<
"Build LogicalGraph..."
;
HashMap
<
LogicalEdge
*
,
std
::
string
>
edge2lbn
;
HashMap
<
LogicalEdge
*
,
std
::
string
>
edge2ibn
;
NaiveBuildGraphStruct
(
dl_net_conf
,
&
edge2lbn
,
&
edge2ibn
);
FillNodeWithParallelDesc
(
strategy_conf
);
AddCloneNodes
(
edge2lbn
,
edge2ibn
);
ToDotFile
(
dot_filepath
);
}
void
LogicalGraph
::
NaiveBuildGraphStruct
(
...
...
oneflow/graph/logical_graph.h
浏览文件 @
50d606e7
...
...
@@ -58,7 +58,8 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
~
LogicalGraph
()
=
default
;
LogicalGraph
(
const
DLNetConf
&
dl_net_conf
,
const
Strategy
&
strategy_conf
);
const
Strategy
&
strategy_conf
,
const
std
::
string
&
dot_filepath
);
private:
void
NaiveBuildGraphStruct
(
...
...
oneflow/graph/model_load_task_graph.cpp
浏览文件 @
50d606e7
...
...
@@ -40,7 +40,8 @@ void MdLoadTaskGraph::BuildTaskGraph(const ChainNode* update_chain) {
faker_chain
->
mut_input_lbns
()
=
{
RegstDesc
::
kAllLbn
};
Connect
(
load_chain
,
chain_gph
->
NewEdge
(),
faker_chain
);
chain_gph
->
UpdateSourceAndSink
();
BuildFromChainGph
(
std
::
move
(
chain_gph
),
false
);
chain_gph
->
ToDotFile
(
LogDir
()
+
"/model_load_chain_graph.dot"
);
BuildFromChainGph
(
std
::
move
(
chain_gph
),
false
,
LogDir
()
+
"/model_load_"
);
}
void
MdLoadTaskGraph
::
InitFaker2Mccoy
(
...
...
oneflow/graph/model_save_task_graph.cpp
浏览文件 @
50d606e7
...
...
@@ -41,7 +41,8 @@ void MdSaveTaskGraph::BuildTaskGraph(const ChainNode* update_chain) {
// Connect
Connect
(
faker_chain
,
chain_gph
->
NewEdge
(),
save_chain
);
chain_gph
->
UpdateSourceAndSink
();
BuildFromChainGph
(
std
::
move
(
chain_gph
),
false
);
chain_gph
->
ToDotFile
(
LogDir
()
+
"/model_save_chain_graph.dot"
);
BuildFromChainGph
(
std
::
move
(
chain_gph
),
false
,
LogDir
()
+
"/model_save_"
);
}
void
MdSaveTaskGraph
::
InitFaker2Mccoy
(
...
...
oneflow/graph/model_update_task_graph.cpp
浏览文件 @
50d606e7
...
...
@@ -37,7 +37,8 @@ void MdUpdtTaskGraph::BuildTaskGraph(const ChainNode* data_chain) {
Connect
(
faker_chain
,
chain_gph
->
NewEdge
(),
updt_chain
);
}
//
BuildFromChainGph
(
std
::
move
(
chain_gph
),
false
);
chain_gph
->
ToDotFile
(
LogDir
()
+
"/model_update_chain_graph.dot"
);
BuildFromChainGph
(
std
::
move
(
chain_gph
),
false
,
LogDir
()
+
"/model_update_"
);
}
void
MdUpdtTaskGraph
::
InitFaker2MccoyAndParallelId2UpdtMap
(
...
...
oneflow/graph/stage_graph.cpp
浏览文件 @
50d606e7
...
...
@@ -3,7 +3,8 @@
namespace
oneflow
{
StageGraph
::
StageGraph
(
std
::
unique_ptr
<
const
ChainGraph
>&&
chain_gph
)
{
StageGraph
::
StageGraph
(
std
::
unique_ptr
<
const
ChainGraph
>&&
chain_gph
,
const
std
::
string
&
dot_filepath
)
{
LOG
(
INFO
)
<<
"Build StageGraph..."
;
chain_gph_
=
std
::
move
(
chain_gph
);
HashMap
<
const
ChainNode
*
,
std
::
vector
<
StageNode
*>>
chain2stages
;
...
...
@@ -41,7 +42,7 @@ StageGraph::StageGraph(std::unique_ptr<const ChainGraph>&& chain_gph) {
}
// Post processing
UpdateSourceAndSink
();
ToDotFile
(
LogDir
()
+
"/stage_graph.dot"
);
ToDotFile
(
dot_filepath
);
}
}
// namespace oneflow
oneflow/graph/stage_graph.h
浏览文件 @
50d606e7
...
...
@@ -68,7 +68,8 @@ class StageGraph final : public Graph<StageNode, StageEdge> {
StageGraph
()
=
delete
;
~
StageGraph
()
=
default
;
StageGraph
(
std
::
unique_ptr
<
const
ChainGraph
>&&
chain_gph
);
StageGraph
(
std
::
unique_ptr
<
const
ChainGraph
>&&
chain_gph
,
const
std
::
string
&
dot_filepath
);
const
ChainGraph
*
chain_gph
()
const
{
return
chain_gph_
.
get
();
}
...
...
oneflow/graph/task_graph.cpp
浏览文件 @
50d606e7
...
...
@@ -36,21 +36,25 @@ std::vector<CompTaskNode*> TaskGraph::SortedCompTasksInChain(
void
TaskGraph
::
BuildFromChainGph
(
std
::
unique_ptr
<
ChainGraph
>&&
chain_gph
,
bool
need_bp
)
{
stage_gph_
.
reset
(
new
StageGraph
(
std
::
move
(
chain_gph
)));
BuildFromStageGph
(
need_bp
);
bool
need_bp
,
const
std
::
string
&
dot_filepath_prefix
)
{
stage_gph_
.
reset
(
new
StageGraph
(
std
::
move
(
chain_gph
),
dot_filepath_prefix
+
"stage_graph.dot"
));
BuildFromStageGph
(
need_bp
,
dot_filepath_prefix
);
}
void
TaskGraph
::
BuildFromStageGph
(
bool
need_bp
)
{
void
TaskGraph
::
BuildFromStageGph
(
bool
need_bp
,
const
std
::
string
&
dot_filepath_prefix
)
{
LOG
(
INFO
)
<<
"Build FwTaskGraph..."
;
Stage2TaskNodesMap
stage2task_nodes
;
InitCompTaskNodes
(
&
stage2task_nodes
);
InitBoxingTaskNodes
(
&
stage2task_nodes
);
ConnectBoxingTaskNodes
(
&
stage2task_nodes
);
UpdateSourceAndSink
();
ToDotFile
(
LogDir
()
+
"/
fw_task_graph.dot"
);
ToDotFile
(
dot_filepath_prefix
+
"
fw_task_graph.dot"
);
if
(
need_bp
)
{
BuildBpStruct
();
ToDotFile
(
dot_filepath_prefix
+
"bp_task_graph.dot"
);
}
}
...
...
@@ -234,7 +238,6 @@ void TaskGraph::BuildBpStruct() {
GenerateRelatedBpNodes
(
&
loss_node_vec
);
BackwardConnect
(
loss_node_vec
);
UpdateSourceAndSink
();
ToDotFile
(
LogDir
()
+
"/bp_task_graph.dot"
);
}
void
TaskGraph
::
GenerateRelatedBpNodes
(
...
...
oneflow/graph/task_graph.h
浏览文件 @
50d606e7
...
...
@@ -32,13 +32,16 @@ class TaskGraph : public Graph<TaskNode, TaskEdge> {
protected:
TaskGraph
()
=
default
;
void
BuildFromChainGph
(
std
::
unique_ptr
<
ChainGraph
>&&
chain_gph
,
bool
need_bp
);
void
BuildFromChainGph
(
std
::
unique_ptr
<
ChainGraph
>&&
chain_gph
,
bool
need_bp
,
const
std
::
string
&
dot_filepath_prefix
);
void
EnrollFakerMccoy
(
CompTaskNode
*
faker
,
CompTaskNode
*
mccoy
)
{
CHECK
(
faker2mccoy_
.
emplace
(
faker
,
mccoy
).
second
);
}
private:
void
BuildFromStageGph
(
bool
need_bp
);
void
BuildFromStageGph
(
bool
need_bp
,
const
std
::
string
&
dot_filepath_prefix
);
template
<
typename
TaskNodeType
>
TaskNodeType
*
NewTaskNode
()
{
...
...
oneflow/graph/task_graph_manager.cpp
浏览文件 @
50d606e7
...
...
@@ -10,6 +10,7 @@ void TaskGraphMgr::Init() {
JobDesc
::
Singleton
().
train_dlnet_conf
(),
JobDesc
::
Singleton
().
strategy
(),
true
);
LOG
(
FATAL
)
<<
"checkpoint"
;
task_gphs_
.
emplace_back
(
data_task_gph
);
// construct data_chain2sorted_bp_comp_tasks
HashMap
<
const
ChainNode
*
,
std
::
vector
<
CompTaskNode
*>>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录