Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
fc4900c6
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
fc4900c6
编写于
5月 10, 2017
作者:
W
willzhang4a58
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove empty register
上级
1c9894a7
变更
23
隐藏空白更改
内联
并排
Showing
23 changed files
with
225 additions
and
276 deletions
+225
-276
oneflow/graph/boxing_task_node.cpp
oneflow/graph/boxing_task_node.cpp
+2
-2
oneflow/graph/chain_graph.cpp
oneflow/graph/chain_graph.cpp
+9
-0
oneflow/graph/chain_graph.h
oneflow/graph/chain_graph.h
+3
-0
oneflow/graph/comp_task_node.cpp
oneflow/graph/comp_task_node.cpp
+17
-15
oneflow/graph/copy_task_node.cpp
oneflow/graph/copy_task_node.cpp
+1
-1
oneflow/graph/exec_graph.cpp
oneflow/graph/exec_graph.cpp
+2
-2
oneflow/graph/exec_graph.proto
oneflow/graph/exec_graph.proto
+2
-2
oneflow/graph/logical_graph.cpp
oneflow/graph/logical_graph.cpp
+2
-2
oneflow/graph/task_graph_manager.cpp
oneflow/graph/task_graph_manager.cpp
+0
-85
oneflow/graph/task_graph_manager.h
oneflow/graph/task_graph_manager.h
+0
-35
oneflow/graph/task_node.cpp
oneflow/graph/task_node.cpp
+19
-15
oneflow/graph/task_node.h
oneflow/graph/task_node.h
+1
-1
oneflow/job/compiler.cpp
oneflow/job/compiler.cpp
+109
-19
oneflow/job/dlnet_conf.proto
oneflow/job/dlnet_conf.proto
+1
-1
oneflow/job/job_conf.proto
oneflow/job/job_conf.proto
+1
-0
oneflow/job/job_desc.cpp
oneflow/job/job_desc.cpp
+1
-1
oneflow/job/ofelf.proto
oneflow/job/ofelf.proto
+2
-3
oneflow/operator/operator.cpp
oneflow/operator/operator.cpp
+16
-16
oneflow/operator/operator.proto
oneflow/operator/operator.proto
+8
-8
oneflow/register/register_desc.cpp
oneflow/register/register_desc.cpp
+1
-1
oneflow/register/register_desc_manager.h
oneflow/register/register_desc_manager.h
+0
-39
oneflow/task/task.proto
oneflow/task/task.proto
+2
-2
prototxt/google_net.prototxt
prototxt/google_net.prototxt
+26
-26
未找到文件。
oneflow/graph/boxing_task_node.cpp
浏览文件 @
fc4900c6
...
...
@@ -38,11 +38,11 @@ void BoxingTaskNode::FwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
void
BoxingTaskNode
::
EnrollAllRegstAndBindRelatedEdge
()
{
for
(
TaskEdge
*
edge
:
out_edges
())
{
std
::
string
name
=
"boxing_out_"
+
edge
->
edge_id_str
();
auto
regst_desc
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
auto
regst_desc
=
of_make_unique
<
RegstDesc
>
();
BindProducedRegstAndOutEdge
(
regst_desc
.
get
(),
edge
);
EnrollProducedRegstDesc
(
name
,
std
::
move
(
regst_desc
));
}
auto
regst_desc
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
auto
regst_desc
=
of_make_unique
<
RegstDesc
>
();
EnrollProducedRegstDesc
(
"middle"
,
std
::
move
(
regst_desc
));
}
...
...
oneflow/graph/chain_graph.cpp
浏览文件 @
fc4900c6
...
...
@@ -235,6 +235,15 @@ std::string ChainNode::ConcatedOpsName() const {
}
}
// Returns true when at least one operator held in op_vec_ reports a
// non-empty model_bns() or a non-empty model_tmp_bns(); false otherwise
// (including when op_vec_ is empty).
bool ChainNode::HasOpWithModelOrModelTmpBlob() const {
  // Bind by const reference: copying a std::shared_ptr on every iteration
  // costs an atomic ref-count increment/decrement and buys nothing here.
  for (const std::shared_ptr<const Operator>& op : op_vec_) {
    if (!op->model_bns().empty() || !op->model_tmp_bns().empty()) {
      return true;
    }
  }
  return false;
}
ChainGraph
::
ChainGraph
(
const
LogicalGraph
*
logical_gph
,
const
std
::
string
&
dot_filepath
)
{
LOG
(
INFO
)
<<
"Build ChainGraph..."
;
...
...
oneflow/graph/chain_graph.h
浏览文件 @
fc4900c6
...
...
@@ -55,6 +55,8 @@ class ChainNode final : public Node<ChainNode, ChainEdge> {
}
std
::
string
VisualStr
()
const
{
return
ConcatedOpsName
();
}
bool
HasOpWithModelOrModelTmpBlob
()
const
;
private:
std
::
vector
<
std
::
shared_ptr
<
const
Operator
>>
op_vec_
;
...
...
@@ -64,6 +66,7 @@ class ChainNode final : public Node<ChainNode, ChainEdge> {
};
class
ChainEdge
final
:
public
Edge
<
ChainNode
,
ChainEdge
>
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
ChainEdge
);
...
...
oneflow/graph/comp_task_node.cpp
浏览文件 @
fc4900c6
...
...
@@ -24,16 +24,16 @@ void CompTaskNode::DataFwBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
mut_exec_gph
().
UpdateSourceAndSink
();
// out regst
if
(
!
out_edges
().
empty
())
{
auto
out_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
auto
out_regst
=
of_make_unique
<
RegstDesc
>
();
BindProducedRegstAndOutEdge
(
out_regst
.
get
(),
SoleOutEdge
());
EnrollProducedRegstDesc
(
"out"
,
std
::
move
(
out_regst
));
}
// the other produced regsts
auto
activation_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
auto
data_tmp_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
auto
model_tmp_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
auto
model_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
auto
log_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
auto
activation_regst
=
of_make_unique
<
RegstDesc
>
();
auto
data_tmp_regst
=
of_make_unique
<
RegstDesc
>
();
auto
model_tmp_regst
=
of_make_unique
<
RegstDesc
>
();
auto
model_regst
=
of_make_unique
<
RegstDesc
>
();
auto
log_regst
=
of_make_unique
<
RegstDesc
>
();
// EnrollProducedRegstDesc
EnrollProducedRegstDesc
(
"activation"
,
std
::
move
(
activation_regst
));
EnrollProducedRegstDesc
(
"data_tmp"
,
std
::
move
(
data_tmp_regst
));
...
...
@@ -98,7 +98,7 @@ void CompTaskNode::MdLoadFwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
exec_node
->
mut_op
()
=
chain_node
()
->
SoleOp
();
mut_exec_gph
().
UpdateSourceAndSink
();
auto
model_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
auto
model_regst
=
of_make_unique
<
RegstDesc
>
();
exec_node
->
BindBnInOpAndRegst
(
exec_node
->
op
()
->
SoleObn
(),
model_regst
.
get
());
BindProducedRegstAndOutEdge
(
model_regst
.
get
(),
SoleOutEdge
());
CompTaskNode
*
update_0
=
md_load_gph
->
parallel_id2updt_task
().
at
(
0
);
...
...
@@ -271,9 +271,9 @@ void CompTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
HashMap
<
ExecEdge
*
,
const
ExecEdge
*>
bp_edge2fw_edge
;
BpBuildExecGraph
(
fw_gph
,
&
fw_node2bp_node
,
&
bp_edge2fw_edge
);
// Produced registers
auto
in_diff_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
auto
model_diff_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
auto
activation_diff_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
auto
in_diff_regst
=
of_make_unique
<
RegstDesc
>
();
auto
model_diff_regst
=
of_make_unique
<
RegstDesc
>
();
auto
activation_diff_regst
=
of_make_unique
<
RegstDesc
>
();
// Bind out edge
if
(
!
out_edges
().
empty
())
{
BindProducedRegstAndOutEdge
(
in_diff_regst
.
get
(),
SoleOutEdge
());
...
...
@@ -292,12 +292,14 @@ void CompTaskNode::BpInferShapeOfBlobsInProducedRegsts(TaskGraph*) {
RegstDesc
*
in_regst
=
GetRelatedRegst
(
GetFwNode
()
->
SoleInEdge
());
in_diff_regst
->
CopyShapeFrom
(
in_regst
);
// model_diff_regst
RegstDesc
*
model_diff_regst
=
GetProducedRegstDesc
(
"model_diff"
);
model_diff_regst
->
CopyShapeFrom
(
GetFwNode
()
->
exec_gph
().
RelatedModelRegst
());
if
(
RegstDesc
*
md_diff_regst
=
GetProducedRegstDesc
(
"model_diff"
))
{
md_diff_regst
->
CopyShapeFrom
(
GetFwNode
()
->
exec_gph
().
RelatedModelRegst
());
}
// activation_diff_regst
RegstDesc
*
activation_diff_regst
=
GetProducedRegstDesc
(
"activation_diff"
);
RegstDesc
*
activation_regst
=
GetFwNode
()
->
GetProducedRegstDesc
(
"activation"
);
activation_diff_regst
->
CopyShapeFrom
(
activation_regst
);
if
(
RegstDesc
*
acti_diff_regst
=
GetProducedRegstDesc
(
"activation_diff"
))
{
RegstDesc
*
acti_regst
=
GetFwNode
()
->
GetProducedRegstDesc
(
"activation"
);
acti_diff_regst
->
CopyShapeFrom
(
acti_regst
);
}
}
void
CompTaskNode
::
BpBuildExecGraph
(
...
...
oneflow/graph/copy_task_node.cpp
浏览文件 @
fc4900c6
...
...
@@ -6,7 +6,7 @@
namespace
oneflow
{
void
CopyTaskNode
::
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
){
auto
out_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
auto
out_regst
=
of_make_unique
<
RegstDesc
>
();
BindProducedRegstAndOutEdge
(
out_regst
.
get
(),
SoleOutEdge
());
RegstDesc
*
in_regst
=
GetRelatedRegst
(
SoleInEdge
());
out_regst
->
CopyLbnFrom
(
in_regst
);
...
...
oneflow/graph/exec_graph.cpp
浏览文件 @
fc4900c6
...
...
@@ -14,7 +14,7 @@ void ExecNode::ToProto(ExecNodeProto* ret) const {
bn_regst
.
first
,
bn_regst
.
second
->
regst_desc_id
()});
}
for
(
ExecEdge
*
edge
:
in_edges
())
{
ret
->
add_predecessor_id
s
(
edge
->
src_node
()
->
node_id
());
ret
->
add_predecessor_id
(
edge
->
src_node
()
->
node_id
());
}
}
...
...
@@ -29,7 +29,7 @@ RegstDesc* ExecGraph::RelatedModelRegst() const {
void
ExecGraph
::
ToProto
(
ExecGraphProto
*
ret
)
const
{
for
(
const
std
::
unique_ptr
<
ExecNode
>&
node
:
nodes
())
{
node
->
ToProto
(
ret
->
add_exec_node
s
());
node
->
ToProto
(
ret
->
add_exec_node
());
}
}
...
...
oneflow/graph/exec_graph.proto
浏览文件 @
fc4900c6
...
...
@@ -5,9 +5,9 @@ message ExecNodeProto {
uint64
id
=
1
;
string
op_name
=
2
;
map
<
string
,
uint64
>
bn_in_op2regst_desc_id
=
3
;
repeated
uint64
predecessor_id
s
=
4
;
repeated
uint64
predecessor_id
=
4
;
}
message
ExecGraphProto
{
repeated
ExecNodeProto
exec_node
s
=
1
;
repeated
ExecNodeProto
exec_node
=
1
;
}
oneflow/graph/logical_graph.cpp
浏览文件 @
fc4900c6
...
...
@@ -23,8 +23,8 @@ void LogicalGraph::NaiveBuildGraphStruct(
HashMap
<
LogicalEdge
*
,
std
::
string
>*
edge2ibn
)
{
HashMap
<
std
::
string
,
LogicalNode
*>
lbn2producer
;
// Process Op
for
(
int
op_i
=
0
;
op_i
<
dl_net_conf
.
op_
conf_
size
();
++
op_i
)
{
const
OperatorConf
&
cur_op_conf
=
dl_net_conf
.
op
_conf
(
op_i
);
for
(
int
op_i
=
0
;
op_i
<
dl_net_conf
.
op_size
();
++
op_i
)
{
const
OperatorConf
&
cur_op_conf
=
dl_net_conf
.
op
(
op_i
);
// Construct cur node
LogicalNode
*
cur_node
=
NewNode
();
cur_node
->
mut_op
()
=
OpMgr
::
Singleton
().
ConstructOp
(
cur_op_conf
);
...
...
oneflow/graph/task_graph_manager.cpp
已删除
100644 → 0
浏览文件 @
1c9894a7
#include "graph/task_graph_manager.h"
namespace
oneflow
{
void
TaskGraphMgr
::
BuildGraphs
()
{
ordered_task_gphs_
.
clear
();
// data graph
LOG
(
INFO
)
<<
"Build DataTaskGraph..."
;
auto
data_task_gph
=
new
DataTaskGraph
(
"data"
,
JobDesc
::
Singleton
().
train_dlnet_conf
(),
JobDesc
::
Singleton
().
strategy
(),
true
);
ordered_task_gphs_
.
emplace_back
(
data_task_gph
);
// construct data_chain2sorted_bp_comp_tasks
HashMap
<
const
ChainNode
*
,
std
::
vector
<
CompTaskNode
*>>
data_chain2sorted_bp_comp_tasks
;
for
(
const
auto
&
node
:
data_task_gph
->
nodes
())
{
auto
bp_node
=
dynamic_cast
<
CompTaskNode
*>
(
node
.
get
());
if
(
bp_node
==
nullptr
||
bp_node
->
IsFwNode
())
{
continue
;
}
data_chain2sorted_bp_comp_tasks
[
bp_node
->
chain_node
()].
push_back
(
bp_node
);
}
for
(
auto
&
pair
:
data_chain2sorted_bp_comp_tasks
)
{
SortByParallelId
(
&
(
pair
.
second
));
}
// model graph
for
(
const
auto
&
pair
:
data_chain2sorted_bp_comp_tasks
)
{
std
::
string
chain_tag
=
pair
.
first
->
op_vec
().
front
()
->
op_name
();
str_replace
(
&
chain_tag
,
'/'
,
'_'
);
const
std
::
string
dot_path_prefix
=
DotDir
()
+
"/model/"
+
chain_tag
+
"_"
;
ParallelPolicy
policy
=
pair
.
first
->
parallel_desc
()
->
policy
();
// model update
LOG
(
INFO
)
<<
"Build MdUpdtTaskGraph... for "
<<
chain_tag
;
auto
updt_gph
=
new
MdUpdtTaskGraph
(
"md_updt_"
+
chain_tag
,
pair
.
first
,
pair
.
second
,
dot_path_prefix
+
"model_update_"
);
ChainNode
*
updt_chain
=
updt_gph
->
chain_gph
()
->
SoleSinkNode
();
auto
sorted_updt_tasks
=
updt_gph
->
SortedCompTasksInChain
(
updt_chain
);
HashMap
<
uint64_t
,
CompTaskNode
*>
parallel_id2updt_task
;
for
(
CompTaskNode
*
update_task
:
sorted_updt_tasks
)
{
CHECK
(
parallel_id2updt_task
.
emplace
(
update_task
->
parallel_id
(),
update_task
).
second
);
}
// model load save
LOG
(
INFO
)
<<
"Build MdLoadTaskGraph... for "
<<
chain_tag
;
auto
load_gph
=
new
MdLoadTaskGraph
(
"md_load_"
+
chain_tag
,
updt_chain
,
parallel_id2updt_task
,
policy
,
dot_path_prefix
+
"model_load_"
);
LOG
(
INFO
)
<<
"Build MdSaveTaskGraph... for "
<<
chain_tag
;
auto
save_gph
=
new
MdSaveTaskGraph
(
"md_save_"
+
chain_tag
,
updt_chain
,
parallel_id2updt_task
,
policy
,
dot_path_prefix
+
"model_save_"
);
ordered_task_gphs_
.
emplace_back
(
updt_gph
);
ordered_task_gphs_
.
emplace_back
(
load_gph
);
ordered_task_gphs_
.
emplace_back
(
save_gph
);
}
// all exec_graph 2 dot
for
(
const
auto
&
task_gph
:
ordered_task_gphs_
)
{
for
(
const
auto
&
task_node
:
task_gph
->
nodes
())
{
std
::
string
file_path
=
DotDir
()
+
"/exec/"
;
file_path
+=
task_node
->
node_id_str
()
+
".dot"
;
task_node
->
exec_gph
().
ToDotFile
(
file_path
);
}
}
}
void
TaskGraphMgr
::
InferShape4Regsts
()
{
for
(
auto
&
task_gph
:
ordered_task_gphs_
)
{
LOG
(
INFO
)
<<
"InferShape... for "
<<
task_gph
->
name
();
task_gph
->
InferShapeOfBlobsInProducedRegsts
();
}
}
// Clears *ret and then appends one TaskProto per task node, visiting the
// graphs in ordered_task_gphs_ order and the nodes in each graph's nodes()
// order.
void TaskGraphMgr::AllTaskNodesToProto(PbRpf<TaskProto>* ret) {
  ret->Clear();
  for (const auto& gph : ordered_task_gphs_) {
    for (const auto& node : gph->nodes()) {
      // Add() yields the freshly appended element that the node fills in.
      TaskProto* slot = ret->Add();
      node->ToProto(slot);
    }
  }
}
}
// namespace oneflow
oneflow/graph/task_graph_manager.h
已删除
100644 → 0
浏览文件 @
1c9894a7
#ifndef ONEFLOW_GRAPH_TASK_GRAPH_MANAGER_H_
#define ONEFLOW_GRAPH_TASK_GRAPH_MANAGER_H_
#include "job/job_desc.h"
#include "graph/data_task_graph.h"
#include "graph/model_load_task_graph.h"
#include "graph/model_save_task_graph.h"
#include "graph/model_update_task_graph.h"
namespace
oneflow
{
// Process-wide holder of the task graphs, reached through Singleton().
// The graphs are kept in the order in which BuildGraphs() created them.
class TaskGraphMgr {
 public:
  OF_DISALLOW_COPY_AND_MOVE(TaskGraphMgr);
  ~TaskGraphMgr() = default;

  // Returns the single shared instance (a function-local static).
  static TaskGraphMgr& Singleton() {
    static TaskGraphMgr instance;
    return instance;
  }

  // Discards any previously held graphs and builds a fresh set.
  void BuildGraphs();

  // Runs shape inference on every held graph.
  void InferShape4Regsts();

  // Serializes every task node of every held graph into the given field.
  void AllTaskNodesToProto(PbRpf<TaskProto>*);

 private:
  // Construction is restricted to Singleton().
  TaskGraphMgr() = default;

  // Owned task graphs, in build order.
  std::vector<std::unique_ptr<TaskGraph>> ordered_task_gphs_;
};
}
// namespace oneflow
#endif // ONEFLOW_GRAPH_TASK_GRAPH_MANAGER_H_
oneflow/graph/task_node.cpp
浏览文件 @
fc4900c6
...
...
@@ -41,7 +41,12 @@ std::unique_ptr<TaskNode> TaskNode::BuildAndConnectBpNode() {
}
RegstDesc
*
TaskNode
::
GetProducedRegstDesc
(
const
std
::
string
&
regst_desc_name
)
{
return
produced_regst_descs_
.
at
(
regst_desc_name
).
get
();
auto
it
=
produced_regst_descs_
.
find
(
regst_desc_name
);
if
(
it
==
produced_regst_descs_
.
end
())
{
return
nullptr
;
}
else
{
return
it
->
second
.
get
();
}
}
void
TaskNode
::
TakeOverRegstDesc
(
TaskNode
*
rhs
,
...
...
@@ -57,6 +62,18 @@ void TaskNode::TakeOverRegstDesc(TaskNode* rhs,
std
::
move
(
this_regst
)).
second
);
}
// Erases from produced_regst_descs_ every entry whose register descriptor
// reports an empty lbn2shape(); all other entries are left untouched.
// NOTE(review): produced_regst2out_edge_ is keyed by RegstDesc* and is not
// updated here -- confirm no stale key can be looked up after this runs.
void TaskNode::RemoveRegstsWithoutBlob() {
  auto it = produced_regst_descs_.begin();
  while (it != produced_regst_descs_.end()) {
    const bool keep = !it->second->lbn2shape().empty();
    if (keep) {
      ++it;
      continue;
    }
    // Advance first, then erase the element the iterator just left, so the
    // loop iterator never refers to an erased entry.
    produced_regst_descs_.erase(it++);
  }
}
// Looks up the out edge recorded for `regst` in produced_regst2out_edge_.
// The lookup uses at(), so `regst` is expected to be present in the map
// (presumably at() raises std::out_of_range otherwise -- depends on the map
// type, which is declared outside this view).
const TaskEdge* TaskNode::GetOutEdge4ProducedRegst(RegstDesc* regst) const {
  const TaskEdge* bound_edge = produced_regst2out_edge_.at(regst);
  return bound_edge;
}
...
...
@@ -94,20 +111,7 @@ void TaskNode::ToProto(TaskProto* ret) const {
ret
->
set_is_forward
(
is_fw_node_
);
exec_gph_
.
ToProto
(
ret
->
mutable_exec_graph
());
for
(
const
auto
&
pair
:
produced_regst_descs_
)
{
ret
->
mutable_produced_regst_desc_ids
()
->
Add
(
pair
.
second
->
regst_desc_id
());
}
// subscribed_regsts
std
::
unordered_set
<
RegstDesc
*>
subscribed_regsts
;
for
(
const
auto
&
exec_node
:
exec_gph
().
nodes
())
{
for
(
const
auto
&
pair
:
exec_node
->
bn_in_op2regst
())
{
RegstDesc
*
related_regst
=
pair
.
second
;
if
(
related_regst
->
GetProducer
()
==
this
)
{
continue
;
}
subscribed_regsts
.
insert
(
related_regst
);
}
}
for
(
RegstDesc
*
regst
:
subscribed_regsts
)
{
ret
->
mutable_subscribed_regst_desc_ids
()
->
Add
(
regst
->
regst_desc_id
());
pair
.
second
->
ToProto
(
ret
->
mutable_produced_regst_desc
()
->
Add
());
}
}
...
...
oneflow/graph/task_node.h
浏览文件 @
fc4900c6
...
...
@@ -4,7 +4,6 @@
#include "task/task.pb.h"
#include "graph/stage_graph.h"
#include "graph/exec_graph.h"
#include "register/register_desc_manager.h"
namespace
oneflow
{
...
...
@@ -55,6 +54,7 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
//
RegstDesc
*
GetProducedRegstDesc
(
const
std
::
string
&
regst_desc_name
);
void
TakeOverRegstDesc
(
TaskNode
*
rhs
,
const
std
::
string
&
regst_desc_name
);
void
RemoveRegstsWithoutBlob
();
//
const
TaskEdge
*
GetOutEdge4ProducedRegst
(
RegstDesc
*
)
const
;
...
...
oneflow/job/compiler.cpp
浏览文件 @
fc4900c6
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "job/id_manager.h"
#include "graph/task_graph_manager.h"
#include "common/proto_io.h"
#include "graph/model_load_task_graph.h"
#include "graph/model_save_task_graph.h"
#include "graph/model_update_task_graph.h"
#include "graph/data_task_graph.h"
#include "job/job_conf.pb.h"
#include "job/ofelf.pb.h"
DEFINE_string
(
job_conf_filepath
,
""
,
""
);
DEFINE_string
(
elf_filepath
,
""
,
""
);
namespace
oneflow
{
class
Compiler
final
{
...
...
@@ -20,35 +21,124 @@ class Compiler final {
return
obj
;
}
void
Compile
(
const
JobConf
&
job_conf
)
{
JobDesc
::
Singleton
().
InitFromJobConf
(
job_conf
);
IDMgr
::
Singleton
().
InitFromResource
(
JobDesc
::
Singleton
().
resource
());
TaskGraphMgr
::
Singleton
().
BuildGraphs
();
JobDesc
::
Singleton
().
set_piece_size
(
50
);
// TODO: set appropriate piece_size
TaskGraphMgr
::
Singleton
().
InferShape4Regsts
();
// To Proto
OfElf
elf
;
TaskGraphMgr
::
Singleton
().
AllTaskNodesToProto
(
elf
.
mutable_tasks
());
RegstDescMgr
::
Singleton
().
AllRegstsToProto
(
elf
.
mutable_regst_descs
());
OpMgr
::
Singleton
().
AllOpToProto
(
elf
.
mutable_operators
());
JobDesc
::
Singleton
().
ToProto
(
elf
.
mutable_job_desc
());
PrintProtoToTextFile
(
elf
,
FLAGS_elf_filepath
);
}
void
Compile
(
const
JobConf
&
job_conf
,
const
std
::
string
&
elf_filepath
);
private:
Compiler
()
=
default
;
void
RunFunc4EachTaskNode
(
std
::
function
<
void
(
TaskNode
*
)
>
func
);
void
BuildGraphs
();
void
RemoveRegstsWithoutBlob
();
void
InferShape4Regsts
();
std
::
vector
<
std
::
unique_ptr
<
TaskGraph
>>
ordered_task_gphs_
;
};
// Invokes `func` once for every task node, visiting the graphs in
// ordered_task_gphs_ order and the nodes in each graph's nodes() order.
// `func` receives a non-owning pointer to the node.
void Compiler::RunFunc4EachTaskNode(std::function<void(TaskNode*)> func) {
  for (const auto& gph : ordered_task_gphs_) {
    for (const auto& node_owner : gph->nodes()) {
      TaskNode* node = node_owner.get();
      func(node);
    }
  }
}
// TODO: inference "piece_size" and "register_num for each register_desc"
//
// Compiles the job described by `job_conf` and writes the result, as a
// text-format OfElf proto, to `elf_filepath`. Steps, in order:
//   1. initialise JobDesc from the job conf and IDMgr from JobDesc's resource;
//   2. build all task graphs;
//   3. let every task node drop its produced registers that have no blob;
//   4. infer shapes for the remaining registers;
//   5. serialise tasks, operators and the job description into the OfElf.
void Compiler::Compile(const JobConf& job_conf,
                       const std::string& elf_filepath) {
  JobDesc::Singleton().InitFromJobConf(job_conf);
  IDMgr::Singleton().InitFromResource(JobDesc::Singleton().resource());

  BuildGraphs();

  auto drop_empty_regsts = [](TaskNode* node) {
    node->RemoveRegstsWithoutBlob();
  };
  RunFunc4EachTaskNode(drop_empty_regsts);

  InferShape4Regsts();

  OfElf elf;
  // Each node fills a freshly appended entry of elf's task list.
  auto append_task_proto = [&elf](TaskNode* node) {
    node->ToProto(elf.mutable_task()->Add());
  };
  RunFunc4EachTaskNode(append_task_proto);

  OpMgr::Singleton().AllOpToProto(elf.mutable_op());
  JobDesc::Singleton().ToProto(elf.mutable_job_desc());
  PrintProtoToTextFile(elf, elf_filepath);
}
void
Compiler
::
BuildGraphs
()
{
ordered_task_gphs_
.
clear
();
// data graph
LOG
(
INFO
)
<<
"Build DataTaskGraph..."
;
auto
data_task_gph
=
new
DataTaskGraph
(
"data"
,
JobDesc
::
Singleton
().
train_dlnet_conf
(),
JobDesc
::
Singleton
().
strategy
(),
true
);
ordered_task_gphs_
.
emplace_back
(
data_task_gph
);
// construct data_chain2sorted_bp_comp_tasks
HashMap
<
const
ChainNode
*
,
std
::
vector
<
CompTaskNode
*>>
data_chain2sorted_bp_comp_tasks
;
for
(
const
auto
&
node
:
data_task_gph
->
nodes
())
{
auto
bp_node
=
dynamic_cast
<
CompTaskNode
*>
(
node
.
get
());
if
(
bp_node
==
nullptr
||
bp_node
->
IsFwNode
())
{
continue
;
}
data_chain2sorted_bp_comp_tasks
[
bp_node
->
chain_node
()].
push_back
(
bp_node
);
}
for
(
auto
&
pair
:
data_chain2sorted_bp_comp_tasks
)
{
SortByParallelId
(
&
(
pair
.
second
));
}
// model graph
for
(
const
auto
&
pair
:
data_chain2sorted_bp_comp_tasks
)
{
if
(
pair
.
first
->
HasOpWithModelOrModelTmpBlob
()
==
false
)
{
continue
;
}
std
::
string
chain_tag
=
pair
.
first
->
op_vec
().
front
()
->
op_name
();
str_replace
(
&
chain_tag
,
'/'
,
'_'
);
const
std
::
string
dot_path_prefix
=
DotDir
()
+
"/model/"
+
chain_tag
+
"_"
;
ParallelPolicy
policy
=
pair
.
first
->
parallel_desc
()
->
policy
();
LOG
(
INFO
)
<<
"Build MdUpdtTaskGraph... for "
<<
chain_tag
;
auto
updt_gph
=
new
MdUpdtTaskGraph
(
"md_updt_"
+
chain_tag
,
pair
.
first
,
pair
.
second
,
dot_path_prefix
+
"model_update_"
);
ChainNode
*
updt_chain
=
updt_gph
->
chain_gph
()
->
SoleSinkNode
();
auto
sorted_updt_tasks
=
updt_gph
->
SortedCompTasksInChain
(
updt_chain
);
HashMap
<
uint64_t
,
CompTaskNode
*>
parallel_id2updt_task
;
for
(
CompTaskNode
*
update_task
:
sorted_updt_tasks
)
{
CHECK
(
parallel_id2updt_task
.
emplace
(
update_task
->
parallel_id
(),
update_task
).
second
);
}
LOG
(
INFO
)
<<
"Build MdLoadTaskGraph... for "
<<
chain_tag
;
auto
load_gph
=
new
MdLoadTaskGraph
(
"md_load_"
+
chain_tag
,
updt_chain
,
parallel_id2updt_task
,
policy
,
dot_path_prefix
+
"model_load_"
);
LOG
(
INFO
)
<<
"Build MdSaveTaskGraph... for "
<<
chain_tag
;
auto
save_gph
=
new
MdSaveTaskGraph
(
"md_save_"
+
chain_tag
,
updt_chain
,
parallel_id2updt_task
,
policy
,
dot_path_prefix
+
"model_save_"
);
ordered_task_gphs_
.
emplace_back
(
updt_gph
);
ordered_task_gphs_
.
emplace_back
(
load_gph
);
ordered_task_gphs_
.
emplace_back
(
save_gph
);
}
// all exec_graph 2 dot
RunFunc4EachTaskNode
([](
TaskNode
*
node
)
{
std
::
string
file_path
=
DotDir
()
+
"/exec/"
+
node
->
node_id_str
()
+
".dot"
;
node
->
exec_gph
().
ToDotFile
(
file_path
);
});
}
void
Compiler
::
InferShape4Regsts
()
{
for
(
auto
&
task_gph
:
ordered_task_gphs_
)
{
LOG
(
INFO
)
<<
"InferShape... for "
<<
task_gph
->
name
();
task_gph
->
InferShapeOfBlobsInProducedRegsts
();
}
}
}
// namespace oneflow
DEFINE_string
(
job_conf_filepath
,
""
,
""
);
DEFINE_string
(
elf_filepath
,
""
,
""
);
int
main
(
int
argc
,
char
**
argv
)
{
google
::
InitGoogleLogging
(
argv
[
0
]);
google
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
LOG
(
INFO
)
<<
"Compiler Starting Up..."
;
oneflow
::
JobConf
job_conf
;
oneflow
::
ParseProtoFromTextFile
(
FLAGS_job_conf_filepath
,
&
job_conf
);
oneflow
::
Compiler
::
Singleton
().
Compile
(
job_conf
);
oneflow
::
Compiler
::
Singleton
().
Compile
(
job_conf
,
FLAGS_elf_filepath
);
LOG
(
INFO
)
<<
"Compiler Shutting Down..."
;
return
0
;
}
oneflow/job/dlnet_conf.proto
浏览文件 @
fc4900c6
...
...
@@ -5,5 +5,5 @@ import "operator/op_conf.proto";
message
DLNetConf
{
string
name
=
1
;
repeated
OperatorConf
op
_conf
=
100
;
repeated
OperatorConf
op
=
100
;
}
oneflow/job/job_conf.proto
浏览文件 @
fc4900c6
...
...
@@ -8,4 +8,5 @@ message JobConf {
string
model_load_machine
=
4
;
string
model_save_machine
=
5
;
uint32
batch_size
=
6
;
uint32
piece_size
=
7
;
}
oneflow/job/job_desc.cpp
浏览文件 @
fc4900c6
...
...
@@ -11,7 +11,7 @@ void JobDesc::InitFromJobConf(const JobConf& conf) {
md_load_machine_
=
conf
.
model_load_machine
();
md_save_machine_
=
conf
.
model_save_machine
();
batch_size_
=
conf
.
batch_size
();
piece_size_
=
0
;
// TODO
piece_size_
=
conf
.
piece_size
();
}
void
JobDesc
::
InitFromProto
(
const
JobDescProto
&
proto
)
{
...
...
oneflow/job/ofelf.proto
浏览文件 @
fc4900c6
...
...
@@ -7,8 +7,7 @@ import "task/task.proto";
import
"job/job_desc.proto"
;
message
OfElf
{
repeated
TaskProto
tasks
=
1
;
repeated
RegstDescProto
regst_descs
=
2
;
repeated
OperatorProto
operators
=
3
;
repeated
TaskProto
task
=
1
;
repeated
OperatorProto
op
=
3
;
JobDescProto
job_desc
=
4
;
}
oneflow/operator/operator.cpp
浏览文件 @
fc4900c6
...
...
@@ -5,27 +5,27 @@ namespace oneflow {
void
Operator
::
InitFromProto
(
const
OperatorProto
&
op_proto
)
{
op_conf_
=
op_proto
.
op_conf
();
bn_in_op2lbn_
=
PbMap2HashMap
(
op_proto
.
bn_in_op2lbn
());
data_tmp_bns_
=
PbVec2StdVec
(
op_proto
.
data_tmp_bn
s
());
input_bns_
=
PbVec2StdVec
(
op_proto
.
input_bn
s
());
input_diff_bns_
=
PbVec2StdVec
(
op_proto
.
input_diff_bn
s
());
output_bns_
=
PbVec2StdVec
(
op_proto
.
output_bn
s
());
output_diff_bns_
=
PbVec2StdVec
(
op_proto
.
output_diff_bn
s
());
model_bns_
=
PbVec2StdVec
(
op_proto
.
model_bn
s
());
model_diff_bns_
=
PbVec2StdVec
(
op_proto
.
model_diff_bn
s
());
model_tmp_bns_
=
PbVec2StdVec
(
op_proto
.
model_tmp_bn
s
());
data_tmp_bns_
=
PbVec2StdVec
(
op_proto
.
data_tmp_bn
());
input_bns_
=
PbVec2StdVec
(
op_proto
.
input_bn
());
input_diff_bns_
=
PbVec2StdVec
(
op_proto
.
input_diff_bn
());
output_bns_
=
PbVec2StdVec
(
op_proto
.
output_bn
());
output_diff_bns_
=
PbVec2StdVec
(
op_proto
.
output_diff_bn
());
model_bns_
=
PbVec2StdVec
(
op_proto
.
model_bn
());
model_diff_bns_
=
PbVec2StdVec
(
op_proto
.
model_diff_bn
());
model_tmp_bns_
=
PbVec2StdVec
(
op_proto
.
model_tmp_bn
());
}
void
Operator
::
ToProto
(
OperatorProto
*
ret
)
const
{
*
(
ret
->
mutable_op_conf
())
=
op_conf_
;
*
(
ret
->
mutable_bn_in_op2lbn
())
=
HashMap2PbMap
(
bn_in_op2lbn_
);
*
(
ret
->
mutable_data_tmp_bn
s
())
=
StdVec2PbVec
(
data_tmp_bns_
);
*
(
ret
->
mutable_input_bn
s
())
=
StdVec2PbVec
(
input_bns_
);
*
(
ret
->
mutable_input_diff_bn
s
())
=
StdVec2PbVec
(
input_diff_bns_
);
*
(
ret
->
mutable_output_bn
s
())
=
StdVec2PbVec
(
output_bns_
);
*
(
ret
->
mutable_output_diff_bn
s
())
=
StdVec2PbVec
(
output_diff_bns_
);
*
(
ret
->
mutable_model_bn
s
())
=
StdVec2PbVec
(
model_bns_
);
*
(
ret
->
mutable_model_diff_bn
s
())
=
StdVec2PbVec
(
model_diff_bns_
);
*
(
ret
->
mutable_model_tmp_bn
s
())
=
StdVec2PbVec
(
model_tmp_bns_
);
*
(
ret
->
mutable_data_tmp_bn
())
=
StdVec2PbVec
(
data_tmp_bns_
);
*
(
ret
->
mutable_input_bn
())
=
StdVec2PbVec
(
input_bns_
);
*
(
ret
->
mutable_input_diff_bn
())
=
StdVec2PbVec
(
input_diff_bns_
);
*
(
ret
->
mutable_output_bn
())
=
StdVec2PbVec
(
output_bns_
);
*
(
ret
->
mutable_output_diff_bn
())
=
StdVec2PbVec
(
output_diff_bns_
);
*
(
ret
->
mutable_model_bn
())
=
StdVec2PbVec
(
model_bns_
);
*
(
ret
->
mutable_model_diff_bn
())
=
StdVec2PbVec
(
model_diff_bns_
);
*
(
ret
->
mutable_model_tmp_bn
())
=
StdVec2PbVec
(
model_tmp_bns_
);
}
const
std
::
string
&
Operator
::
Lbn4BnInOp
(
const
std
::
string
&
bn_in_op
)
const
{
...
...
oneflow/operator/operator.proto
浏览文件 @
fc4900c6
...
...
@@ -7,13 +7,13 @@ message OperatorProto {
OperatorConf
op_conf
=
1
;
map
<
string
,
string
>
bn_in_op2lbn
=
3
;
repeated
string
data_tmp_bn
s
=
4
;
repeated
string
input_bn
s
=
5
;
repeated
string
input_diff_bn
s
=
6
;
repeated
string
output_bn
s
=
7
;
repeated
string
output_diff_bn
s
=
8
;
repeated
string
data_tmp_bn
=
4
;
repeated
string
input_bn
=
5
;
repeated
string
input_diff_bn
=
6
;
repeated
string
output_bn
=
7
;
repeated
string
output_diff_bn
=
8
;
repeated
string
model_bn
s
=
9
;
repeated
string
model_diff_bn
s
=
10
;
repeated
string
model_tmp_bn
s
=
11
;
repeated
string
model_bn
=
9
;
repeated
string
model_diff_bn
=
10
;
repeated
string
model_tmp_bn
=
11
;
}
oneflow/register/register_desc.cpp
浏览文件 @
fc4900c6
...
...
@@ -7,7 +7,7 @@ namespace oneflow {
RegstDesc
::
RegstDesc
()
{
producer_
=
nullptr
;
register_num_
=
0
;
// TODO
register_num_
=
5
;
// TODO
}
void
RegstDesc
::
CopyLbnFrom
(
const
RegstDesc
*
rhs
)
{
...
...
oneflow/register/register_desc_manager.h
已删除
100644 → 0
浏览文件 @
1c9894a7
#ifndef ONEFLOW_REGISTER_REGISTER_DESC_MANAGER_H_
#define ONEFLOW_REGISTER_REGISTER_DESC_MANAGER_H_
#include "register/register_desc.h"
namespace
oneflow
{
class
RegstDescMgr
final
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
RegstDescMgr
);
~
RegstDescMgr
()
=
default
;
static
RegstDescMgr
&
Singleton
()
{
static
RegstDescMgr
obj
;
return
obj
;
}
std
::
unique_ptr
<
RegstDesc
>
CreateRegisterDesc
()
{
auto
ret
=
of_make_unique
<
RegstDesc
>
();
regst_descs_
.
push_back
(
ret
.
get
());
return
ret
;
}
void
AllRegstsToProto
(
PbRpf
<
RegstDescProto
>*
ret
)
{
ret
->
Clear
();
for
(
RegstDesc
*
regst
:
regst_descs_
)
{
regst
->
ToProto
(
ret
->
Add
());
}
}
private:
RegstDescMgr
()
=
default
;
std
::
list
<
RegstDesc
*>
regst_descs_
;
};
}
// namespace oneflow
#endif // ONEFLOW_REGISTER_REGISTER_DESC_MANAGER_H_
oneflow/task/task.proto
浏览文件 @
fc4900c6
...
...
@@ -2,6 +2,7 @@ syntax = "proto3";
package
oneflow
;
import
"graph/exec_graph.proto"
;
import
"register/register_desc.proto"
;
enum
TaskType
{
HostCompTask
=
0
;
...
...
@@ -19,8 +20,7 @@ message TaskProto {
uint64
thrd_local_id
=
4
;
bool
is_forward
=
5
;
ExecGraphProto
exec_graph
=
6
;
repeated
uint64
produced_regst_desc_ids
=
7
;
repeated
uint64
subscribed_regst_desc_ids
=
8
;
repeated
RegstDescProto
produced_regst_desc
=
7
;
// for CompTask
uint64
parallel_id
=
1000
;
}
prototxt/google_net.prototxt
浏览文件 @
fc4900c6
name: "GoogleNet"
op
_conf
{
op {
name: "mnist"
data_loader_conf {
feature: "feature"
...
...
@@ -13,7 +13,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "conv1"
convolution_conf {
in: "mnist/feature"
...
...
@@ -28,7 +28,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "pool1"
pooling_conf {
in: "conv1/out"
...
...
@@ -43,7 +43,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "conv2_1x1"
convolution_conf {
in: "pool1/out"
...
...
@@ -58,7 +58,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "conv2_1x3"
convolution_conf {
in: "pool1/out"
...
...
@@ -73,7 +73,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "conv2_1x5"
convolution_conf {
in: "pool1/out"
...
...
@@ -88,7 +88,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "conv2_3x3"
convolution_conf {
in: "conv2_1x3/out"
...
...
@@ -103,7 +103,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "conv2_5x5"
convolution_conf {
in: "conv2_1x5/out"
...
...
@@ -118,7 +118,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "concat2"
concat_conf {
in: "conv2_1x1/out"
...
...
@@ -129,7 +129,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "ip1"
innerproduct_conf {
in: "concat2/out"
...
...
@@ -138,7 +138,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "relu1"
relu_conf {
in: "ip1/out"
...
...
@@ -146,7 +146,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "ip2"
innerproduct_conf {
in: "relu1/out"
...
...
@@ -155,7 +155,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "softmax1"
softmax_conf {
in: "ip2/out"
...
...
@@ -163,7 +163,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "loss1"
multinomial_logistic_loss_conf {
prediction: "softmax1/out"
...
...
@@ -172,7 +172,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "pool2"
pooling_conf {
in: "concat2/out"
...
...
@@ -187,7 +187,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "conv3_1x1"
convolution_conf {
in: "pool2/out"
...
...
@@ -202,7 +202,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "conv3_1x3"
convolution_conf {
in: "pool2/out"
...
...
@@ -217,7 +217,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "conv3_1x5"
convolution_conf {
in: "pool2/out"
...
...
@@ -232,7 +232,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "conv3_3x3"
convolution_conf {
in: "conv3_1x3/out"
...
...
@@ -247,7 +247,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "conv3_5x5"
convolution_conf {
in: "conv3_1x5/out"
...
...
@@ -262,7 +262,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "concat3"
concat_conf {
in: "conv3_1x1/out"
...
...
@@ -273,7 +273,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "ip3"
innerproduct_conf {
in: "concat3/out"
...
...
@@ -282,7 +282,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "relu3"
relu_conf {
in: "ip3/out"
...
...
@@ -290,7 +290,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "ip4"
innerproduct_conf {
in: "relu3/out"
...
...
@@ -299,7 +299,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "softmax2"
softmax_conf {
in: "ip4/out"
...
...
@@ -307,7 +307,7 @@ op_conf {
}
}
op
_conf
{
op {
name: "loss2"
multinomial_logistic_loss_conf {
prediction: "softmax2/out"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录