Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
3589c342
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
3589c342
编写于
4月 29, 2017
作者:
W
willzhang4a58
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine task_node::build_exec
上级
b55817ac
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
417 addition
and
294 deletion
+417
-294
oneflow/graph/boxing_task_node.cpp
oneflow/graph/boxing_task_node.cpp
+33
-19
oneflow/graph/boxing_task_node.h
oneflow/graph/boxing_task_node.h
+5
-3
oneflow/graph/comm_net_task_node.cpp
oneflow/graph/comm_net_task_node.cpp
+22
-8
oneflow/graph/comm_net_task_node.h
oneflow/graph/comm_net_task_node.h
+8
-4
oneflow/graph/comp_task_node.cpp
oneflow/graph/comp_task_node.cpp
+128
-73
oneflow/graph/comp_task_node.h
oneflow/graph/comp_task_node.h
+27
-22
oneflow/graph/copy_hd_task_node.cpp
oneflow/graph/copy_hd_task_node.cpp
+39
-54
oneflow/graph/copy_hd_task_node.h
oneflow/graph/copy_hd_task_node.h
+9
-7
oneflow/graph/exec_graph.h
oneflow/graph/exec_graph.h
+11
-3
oneflow/graph/task_graph.cpp
oneflow/graph/task_graph.cpp
+8
-2
oneflow/graph/task_graph.h
oneflow/graph/task_graph.h
+8
-5
oneflow/graph/task_node.cpp
oneflow/graph/task_node.cpp
+11
-3
oneflow/graph/task_node.h
oneflow/graph/task_node.h
+7
-4
oneflow/operator/comm_net_op.h
oneflow/operator/comm_net_op.h
+1
-1
oneflow/operator/copy_hd_op.cpp
oneflow/operator/copy_hd_op.cpp
+19
-0
oneflow/operator/copy_hd_op.h
oneflow/operator/copy_hd_op.h
+32
-0
oneflow/operator/copy_op.cpp
oneflow/operator/copy_op.cpp
+0
-48
oneflow/operator/op_conf.proto
oneflow/operator/op_conf.proto
+11
-9
oneflow/operator/operator.proto
oneflow/operator/operator.proto
+1
-6
oneflow/register/register_desc.cpp
oneflow/register/register_desc.cpp
+10
-2
oneflow/register/register_desc.h
oneflow/register/register_desc.h
+4
-21
oneflow/register/register_desc_manager.h
oneflow/register/register_desc_manager.h
+23
-0
未找到文件。
oneflow/graph/boxing_task_node.cpp
浏览文件 @
3589c342
...
@@ -8,28 +8,28 @@ namespace oneflow {
...
@@ -8,28 +8,28 @@ namespace oneflow {
namespace
{
namespace
{
void
FwCompleteBoxOpConfDataData
(
BoxingOpConf
*
conf
)
{
void
FwCompleteBoxOpConfDataData
(
BoxingOpConf
*
conf
)
{
conf
->
mutable_concat_box
()
->
set_
axis
(
0
);
conf
->
mutable_concat_box
()
->
set_
type
(
BoxConcatConf
::
kData
);
conf
->
mutable_
split_box
()
->
set_axis
(
0
);
conf
->
mutable_
data_split_box
(
);
}
}
void
FwCompleteBoxOpConfDataModel
(
BoxingOpConf
*
conf
)
{
void
FwCompleteBoxOpConfDataModel
(
BoxingOpConf
*
conf
)
{
conf
->
mutable_concat_box
()
->
set_
axis
(
0
);
conf
->
mutable_concat_box
()
->
set_
type
(
BoxConcatConf
::
kData
);
conf
->
mutable_clone_box
();
conf
->
mutable_clone_box
();
}
}
void
FwCompleteBoxOpConfModelData
(
BoxingOpConf
*
conf
)
{
void
FwCompleteBoxOpConfModelData
(
BoxingOpConf
*
conf
)
{
conf
->
mutable_concat_box
()
->
set_
axis
(
1
);
conf
->
mutable_concat_box
()
->
set_
type
(
BoxConcatConf
::
kModel
);
conf
->
mutable_
split_box
()
->
set_axis
(
0
);
conf
->
mutable_
data_split_box
(
);
}
}
void
FwCompleteBoxOpConfModelModel
(
BoxingOpConf
*
conf
)
{
void
FwCompleteBoxOpConfModelModel
(
BoxingOpConf
*
conf
)
{
conf
->
mutable_concat_box
()
->
set_
axis
(
1
);
conf
->
mutable_concat_box
()
->
set_
type
(
BoxConcatConf
::
kModel
);
conf
->
mutable_clone_box
();
conf
->
mutable_clone_box
();
}
}
}
// namespace
}
// namespace
void
BoxingTaskNode
::
FwBuildExecAnd
Produced
Regsts
(
TaskGraph
*
gph
)
{
void
BoxingTaskNode
::
FwBuildExecAnd
EnrollLbn2
Regsts
(
TaskGraph
*
gph
)
{
EnrollAllRegstAndBindRelatedEdge
();
EnrollAllRegstAndBindRelatedEdge
();
FwVirtualBuild
();
FwVirtualBuild
();
}
}
...
@@ -37,11 +37,11 @@ void BoxingTaskNode::FwBuildExecAndProducedRegsts(TaskGraph* gph) {
...
@@ -37,11 +37,11 @@ void BoxingTaskNode::FwBuildExecAndProducedRegsts(TaskGraph* gph) {
void
BoxingTaskNode
::
EnrollAllRegstAndBindRelatedEdge
()
{
void
BoxingTaskNode
::
EnrollAllRegstAndBindRelatedEdge
()
{
for
(
TaskEdge
*
edge
:
out_edges
())
{
for
(
TaskEdge
*
edge
:
out_edges
())
{
std
::
string
name
=
"boxing_out_"
+
edge
->
edge_id_str
();
std
::
string
name
=
"boxing_out_"
+
edge
->
edge_id_str
();
auto
regst_desc
=
of_make_unique
<
DisContigRegstDesc
>
();
auto
regst_desc
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
BindProducedRegstAndOutEdge
(
regst_desc
.
get
(),
edge
);
BindProducedRegstAndOutEdge
(
regst_desc
.
get
(),
edge
);
EnrollProducedRegstDesc
(
name
,
std
::
move
(
regst_desc
));
EnrollProducedRegstDesc
(
name
,
std
::
move
(
regst_desc
));
}
}
auto
regst_desc
=
of_make_unique
<
DisContigRegstDesc
>
();
auto
regst_desc
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
EnrollProducedRegstDesc
(
"middle"
,
std
::
move
(
regst_desc
));
EnrollProducedRegstDesc
(
"middle"
,
std
::
move
(
regst_desc
));
}
}
...
@@ -133,7 +133,7 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair(
...
@@ -133,7 +133,7 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair(
const
std
::
string
&
ibn
=
node
->
op
()
->
input_bns
().
at
(
i
);
const
std
::
string
&
ibn
=
node
->
op
()
->
input_bns
().
at
(
i
);
std
::
string
lbn
=
node
->
op
()
->
ibn2lbn
(
ibn
);
std
::
string
lbn
=
node
->
op
()
->
ibn2lbn
(
ibn
);
Shape
*
ptr
=
in_regst
->
GetMutShapePtr
(
lbn
);
Shape
*
ptr
=
in_regst
->
GetMutShapePtr
(
lbn
);
node
->
op
()
->
Set
ShapePtr
(
ibn
,
ptr
);
node
->
BindBnInOpAnd
ShapePtr
(
ibn
,
ptr
);
node
->
BindBnInOpAndRegst
(
ibn
,
in_regst
);
node
->
BindBnInOpAndRegst
(
ibn
,
in_regst
);
}
}
// obn
// obn
...
@@ -142,16 +142,22 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair(
...
@@ -142,16 +142,22 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair(
const
std
::
string
&
obn
=
node
->
op
()
->
output_bns
().
at
(
i
);
const
std
::
string
&
obn
=
node
->
op
()
->
output_bns
().
at
(
i
);
std
::
string
lbn
=
node
->
op
()
->
obn2lbn
(
obn
);
std
::
string
lbn
=
node
->
op
()
->
obn2lbn
(
obn
);
Shape
*
ptr
=
out_regst
->
EnrollLbn
(
lbn
);
Shape
*
ptr
=
out_regst
->
EnrollLbn
(
lbn
);
node
->
op
()
->
Set
ShapePtr
(
obn
,
ptr
);
node
->
BindBnInOpAnd
ShapePtr
(
obn
,
ptr
);
node
->
BindBnInOpAndRegst
(
obn
,
out_regst
);
node
->
BindBnInOpAndRegst
(
obn
,
out_regst
);
}
}
// dtbn
// dtbn
for
(
const
std
::
string
&
dtbn
:
node
->
op
()
->
data_tmp_bns
())
{
for
(
const
std
::
string
&
dtbn
:
node
->
op
()
->
data_tmp_bns
())
{
std
::
string
lbn
=
node
->
op
()
->
dtbn2lbn
(
dtbn
);
std
::
string
lbn
=
node
->
op
()
->
dtbn2lbn
(
dtbn
);
Shape
*
ptr
=
middle_regst
->
EnrollLbn
(
lbn
);
Shape
*
ptr
=
middle_regst
->
EnrollLbn
(
lbn
);
node
->
op
()
->
SetShapePtr
(
dtbn
,
ptr
);
node
->
BindBnInOpAndShapePtr
(
dtbn
,
ptr
);
node
->
BindBnInOpAndRegst
(
dtbn
,
middle_regst
);
}
}
node
->
op
()
->
InferShape4ObAndDtbFromIb
();
}
}
void
BoxingTaskNode
::
FwInferShape4LbnInProducedRegsts
(
TaskGraph
*
)
{
for
(
const
auto
&
exec_node
:
exec_gph
().
nodes
())
{
exec_node
->
op
()
->
InferShape4ObAndDtbFromIb
(
exec_node
->
BnInOp2ShapePtr
());
}
}
}
}
...
@@ -165,10 +171,9 @@ inline RegstDesc* GetBpRegstFromFwRegst(RegstDesc* fw_regst) {
...
@@ -165,10 +171,9 @@ inline RegstDesc* GetBpRegstFromFwRegst(RegstDesc* fw_regst) {
}
}
void
BoxingTaskNode
::
BpBuildExecAnd
Produced
Regsts
(
TaskGraph
*
)
{
void
BoxingTaskNode
::
BpBuildExecAnd
EnrollLbn2
Regsts
(
TaskGraph
*
)
{
EnrollAllRegstAndBindRelatedEdge
();
EnrollAllRegstAndBindRelatedEdge
();
const
ExecGraph
&
fw_exec_gph
=
GetFwNode
()
->
exec_gph
();
const
ExecGraph
&
fw_exec_gph
=
GetFwNode
()
->
exec_gph
();
HashMap
<
const
ExecNode
*
,
ExecNode
*>
fw_node2bp_node
;
for
(
const
std
::
unique_ptr
<
ExecNode
>&
fw_node
:
fw_exec_gph
.
nodes
())
{
for
(
const
std
::
unique_ptr
<
ExecNode
>&
fw_node
:
fw_exec_gph
.
nodes
())
{
ExecNode
*
bp_node
=
mut_exec_gph
().
NewNode
();
ExecNode
*
bp_node
=
mut_exec_gph
().
NewNode
();
bp_node
->
mut_op
()
=
fw_node
->
op
();
bp_node
->
mut_op
()
=
fw_node
->
op
();
...
@@ -178,8 +183,7 @@ void BoxingTaskNode::BpBuildExecAndProducedRegsts(TaskGraph*) {
...
@@ -178,8 +183,7 @@ void BoxingTaskNode::BpBuildExecAndProducedRegsts(TaskGraph*) {
std
::
string
lbn
=
fw_node
->
op
()
->
ibn2lbn
(
ibn
);
std
::
string
lbn
=
fw_node
->
op
()
->
ibn2lbn
(
ibn
);
RegstDesc
*
in_regst
=
fw_node
->
GetRegstFromBnInOp
(
ibn
);
RegstDesc
*
in_regst
=
fw_node
->
GetRegstFromBnInOp
(
ibn
);
RegstDesc
*
in_diff_regst
=
GetBpRegstFromFwRegst
(
in_regst
);
RegstDesc
*
in_diff_regst
=
GetBpRegstFromFwRegst
(
in_regst
);
Shape
*
in_diff_shape_ptr
=
in_diff_regst
->
EnrollLbn
(
lbn
);
in_diff_regst
->
EnrollLbn
(
lbn
);
*
in_diff_shape_ptr
=
in_regst
->
GetShape
(
lbn
);
bp_node
->
BindBnInOpAndRegst
(
idbn
,
in_diff_regst
);
bp_node
->
BindBnInOpAndRegst
(
idbn
,
in_diff_regst
);
}
}
// out_diff
// out_diff
...
@@ -195,12 +199,22 @@ void BoxingTaskNode::BpBuildExecAndProducedRegsts(TaskGraph*) {
...
@@ -195,12 +199,22 @@ void BoxingTaskNode::BpBuildExecAndProducedRegsts(TaskGraph*) {
std
::
string
lbn
=
fw_node
->
op
()
->
dtbn2lbn
(
dtbn
);
std
::
string
lbn
=
fw_node
->
op
()
->
dtbn2lbn
(
dtbn
);
RegstDesc
*
fw_middle_regst
=
GetFwNode
()
->
GetProducedRegstDesc
(
"middle"
);
RegstDesc
*
fw_middle_regst
=
GetFwNode
()
->
GetProducedRegstDesc
(
"middle"
);
RegstDesc
*
bp_middle_regst
=
GetProducedRegstDesc
(
"middle"
);
RegstDesc
*
bp_middle_regst
=
GetProducedRegstDesc
(
"middle"
);
Shape
*
ptr
=
bp_middle_regst
->
EnrollLbn
(
lbn
);
bp_middle_regst
->
EnrollLbn
(
lbn
);
*
ptr
=
fw_middle_regst
->
GetShape
(
lbn
);
bp_node
->
BindBnInOpAndRegst
(
dtbn
,
bp_middle_regst
);
bp_node
->
BindBnInOpAndRegst
(
dtbn
,
bp_middle_regst
);
}
}
}
}
mut_exec_gph
().
UpdateSourceAndSink
();
mut_exec_gph
().
UpdateSourceAndSink
();
}
}
void
BoxingTaskNode
::
BpInferShape4LbnInProducedRegsts
(
TaskGraph
*
)
{
for
(
TaskEdge
*
fw_in_edge
:
GetFwNode
()
->
in_edges
())
{
RegstDesc
*
in_regst
=
GetRelatedRegst
(
fw_in_edge
);
RegstDesc
*
in_diff_regst
=
GetBpRegstFromFwRegst
(
in_regst
);
in_diff_regst
->
CopyShapeFrom
(
in_regst
);
}
RegstDesc
*
fw_middle_regst
=
GetFwNode
()
->
GetProducedRegstDesc
(
"middle"
);
RegstDesc
*
bp_middle_regst
=
GetProducedRegstDesc
(
"middle"
);
bp_middle_regst
->
CopyShapeFrom
(
fw_middle_regst
);
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/graph/boxing_task_node.h
浏览文件 @
3589c342
...
@@ -12,7 +12,7 @@ class BoxingTaskNode : public TaskNode {
...
@@ -12,7 +12,7 @@ class BoxingTaskNode : public TaskNode {
virtual
~
BoxingTaskNode
()
=
default
;
virtual
~
BoxingTaskNode
()
=
default
;
std
::
string
VisualStr
()
const
override
{
std
::
string
VisualStr
()
const
override
{
return
TaskNode
::
VisualStr
()
+
"Boxing
_"
+
node_id_str
()
;
return
TaskNode
::
VisualStr
()
+
"Boxing
"
;
}
}
protected:
protected:
...
@@ -39,8 +39,10 @@ class BoxingTaskNode : public TaskNode {
...
@@ -39,8 +39,10 @@ class BoxingTaskNode : public TaskNode {
virtual
void
FwVirtualBuild
()
=
0
;
virtual
void
FwVirtualBuild
()
=
0
;
private:
private:
void
FwBuildExecAndProducedRegsts
(
TaskGraph
*
)
override
;
void
FwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
override
;
void
BpBuildExecAndProducedRegsts
(
TaskGraph
*
)
override
;
void
FwInferShape4LbnInProducedRegsts
(
TaskGraph
*
)
override
;
void
BpBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
override
;
void
BpInferShape4LbnInProducedRegsts
(
TaskGraph
*
)
override
;
void
EnrollAllRegstAndBindRelatedEdge
();
void
EnrollAllRegstAndBindRelatedEdge
();
...
...
oneflow/graph/comm_net_task_node.cpp
浏览文件 @
3589c342
...
@@ -4,11 +4,11 @@
...
@@ -4,11 +4,11 @@
namespace
oneflow
{
namespace
oneflow
{
void
CommNetTaskNode
::
BuildExecAndProducedRegstsForNetCopy
(
TaskGraph
*
gph
)
{
void
CommNetTaskNode
::
CommNetBuildExecAndEnrollLbn2Regsts
()
{
auto
out_regst
=
of_make_unique
<
DisContigRegstDesc
>
();
auto
out_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
BindProducedRegstAndOutEdge
(
out_regst
.
get
(),
SoleOutEdge
());
BindProducedRegstAndOutEdge
(
out_regst
.
get
(),
SoleOutEdge
());
RegstDesc
*
in_regst
=
GetRelatedRegst
(
SoleInEdge
());
RegstDesc
*
in_regst
=
GetRelatedRegst
(
SoleInEdge
());
out_regst
->
CopyLbn
2ShapeMap
(
in_regst
);
out_regst
->
CopyLbn
From
(
in_regst
);
OperatorConf
op_conf
;
OperatorConf
op_conf
;
op_conf
.
set_name
(
"comm_net_"
+
NewUniqueId
());
op_conf
.
set_name
(
"comm_net_"
+
NewUniqueId
());
...
@@ -22,15 +22,29 @@ void CommNetTaskNode::BuildExecAndProducedRegstsForNetCopy(TaskGraph* gph){
...
@@ -22,15 +22,29 @@ void CommNetTaskNode::BuildExecAndProducedRegstsForNetCopy(TaskGraph* gph){
node
->
BindBnInOpAndRegst
(
node
->
op
()
->
SoleObn
(),
out_regst
.
get
());
node
->
BindBnInOpAndRegst
(
node
->
op
()
->
SoleObn
(),
out_regst
.
get
());
mut_exec_gph
().
UpdateSourceAndSink
();
mut_exec_gph
().
UpdateSourceAndSink
();
EnrollProducedRegstDesc
(
"
comm_ne
t"
,
std
::
move
(
out_regst
));
EnrollProducedRegstDesc
(
"
ou
t"
,
std
::
move
(
out_regst
));
}
}
void
CommNetTaskNode
::
FwBuildExecAndProducedRegsts
(
TaskGraph
*
gph
)
{
void
CommNetTaskNode
::
CommNetInferShape4LbnInProducedRegsts
()
{
BuildExecAndProducedRegstsForNetCopy
(
gph
);
RegstDesc
*
in_regst
=
GetRelatedRegst
(
SoleInEdge
());
RegstDesc
*
out_regst
=
GetRelatedRegst
(
SoleOutEdge
());
out_regst
->
CopyShapeFrom
(
in_regst
);
}
void
FwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
override
{
return
CommNetBuildExecAndEnrollLbn2Regsts
();
}
void
FwInferShape4LbnInProducedRegsts
(
TaskGraph
*
)
override
{
return
CommNetInferShape4LbnInProducedRegsts
();
}
void
BpBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
override
{
return
CommNetBuildExecAndEnrollLbn2Regsts
();
}
}
void
CommNetTaskNode
::
BpBuildExecAndProducedRegsts
(
TaskGraph
*
gph
)
{
void
BpInferShape4LbnInProducedRegsts
(
TaskGraph
*
)
override
{
BuildExecAndProducedRegstsForNetCopy
(
gph
);
return
CommNetInferShape4LbnInProducedRegsts
(
);
}
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/graph/comm_net_task_node.h
浏览文件 @
3589c342
...
@@ -29,13 +29,17 @@ class CommNetTaskNode final : public TaskNode {
...
@@ -29,13 +29,17 @@ class CommNetTaskNode final : public TaskNode {
}
}
std
::
string
VisualStr
()
const
override
{
std
::
string
VisualStr
()
const
override
{
return
TaskNode
::
VisualStr
()
+
"CommNet
_"
+
node_id_str
()
;
return
TaskNode
::
VisualStr
()
+
"CommNet
"
;
}
}
private:
private:
void
BuildExecAndProducedRegstsForNetCopy
(
TaskGraph
*
);
void
CommNetBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
);
void
FwBuildExecAndProducedRegsts
(
TaskGraph
*
)
override
;
void
CommNetInferShape4LbnInProducedRegsts
();
void
BpBuildExecAndProducedRegsts
(
TaskGraph
*
)
override
;
void
FwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
override
;
void
FwInferShape4LbnInProducedRegsts
(
TaskGraph
*
)
override
;
void
BpBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
override
;
void
BpInferShape4LbnInProducedRegsts
(
TaskGraph
*
)
override
;
std
::
unique_ptr
<
TaskNode
>
CreateSameTypeNode
()
const
override
{
std
::
unique_ptr
<
TaskNode
>
CreateSameTypeNode
()
const
override
{
return
of_make_unique
<
CommNetTaskNode
>
();
return
of_make_unique
<
CommNetTaskNode
>
();
...
...
oneflow/graph/comp_task_node.cpp
浏览文件 @
3589c342
...
@@ -5,27 +5,52 @@
...
@@ -5,27 +5,52 @@
namespace
oneflow
{
namespace
oneflow
{
void
CompTaskNode
::
FwBuildExecAndProducedRegsts
(
TaskGraph
*
gph
)
{
std
::
string
CompTaskNode
::
VisualStr
()
const
override
{
(
this
->*
(
gph
->
Func4FwBuildExecAndProducedRegsts
()))(
gph
);
std
::
stringstream
ss
;
ss
<<
TaskNode
::
VisualStr
()
<<
"Compute"
<<
":"
<<
stage_node
()
->
machine_id_str
()
<<
":"
<<
thrd_loc_id_str
()
<<
"
\\
n"
<<
chain_node
()
->
VisualStr
();
return
ss
.
str
();
}
}
void
CompTaskNode
::
DataFwBuildExecAnd
Produced
Regsts
(
TaskGraph
*
)
{
void
CompTaskNode
::
DataFwBuildExecAnd
EnrollLbn2
Regsts
(
TaskGraph
*
)
{
Lbn2NodeBnMap
lbn2producer
;
Lbn2NodeBnMap
lbn2producer
;
Lbn2NodeBnMap
extern_in_lbn2consumer
;
Lbn2NodeBnMap
extern_in_lbn2consumer
;
FwBuildFromUserOps
(
&
lbn2producer
,
&
extern_in_lbn2consumer
);
FwBuildFromUserOps
(
&
lbn2producer
,
&
extern_in_lbn2consumer
);
mut_exec_gph
().
UpdateSourceAndSink
();
mut_exec_gph
().
UpdateSourceAndSink
();
// data regst
// produced regsts
auto
data_regst
=
of_make_unique
<
DisContigRegstDesc
>
();
auto
out_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
BindProducedRegstAndOutEdge
(
data_regst
.
get
(),
SoleOutEdge
());
auto
activation_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
EnrollProducedRegstDesc
(
"data"
,
std
::
move
(
data_regst
));
auto
data_tmp_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
FwSetDataRegstDesc
(
lbn2producer
,
extern_in_lbn2consumer
);
auto
model_tmp_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
// model_tmp regst
// Bind Out Edge
auto
model_tmp_regst
=
of_make_unique
<
DisContigRegstDesc
>
();
BindProducedRegstAndOutEdge
(
out_regst
.
get
(),
SoleOutEdge
());
// EnrollProducedRegstDesc
EnrollProducedRegstDesc
(
"out"
,
std
::
move
(
out_regst
));
EnrollProducedRegstDesc
(
"activation"
,
std
::
move
(
activation_regst
));
EnrollProducedRegstDesc
(
"data_tmp"
,
std
::
move
(
data_tmp_regst
));
EnrollProducedRegstDesc
(
"model_tmp"
,
std
::
move
(
model_tmp_regst
));
EnrollProducedRegstDesc
(
"model_tmp"
,
std
::
move
(
model_tmp_regst
));
FwSetModelTmpRegstDesc
();
// Enroll Lbn
FwSetExecNodeFromInRegst
(
extern_in_lbn2consumer
);
FwEnrollLbn2OutRegst
(
lbn2producer
);
FwEnrollLbn2ActivationRegst
();
FwEnrollLbn2TmpRegsts
();
}
}
void
CompTaskNode
::
MdUpdtFwBuildExecAndProducedRegsts
(
TaskGraph
*
gph
)
{
void
CompTaskNode
::
DataFwInferShape4LbnInProducedRegsts
()
{
for
(
const
ExecNode
&
node
:
exec_gph
())
{
node
.
op
()
->
InferShape4ObAndDtbFromIb
(
node
.
BnInOp2ShapePtr
());
node
.
op
()
->
InferShape4ModelTmpBlob
(
node
.
BnInOp2ShapePtr
(),
chain_node
()
->
parallel_desc
()
->
policy
(),
parallel_id
());
}
}
void
CompTaskNode
::
MdUpdtFwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
)
{
TODO
();
/*
if (IsFaker()) {
if (IsFaker()) {
CompTaskNode* mccoy = gph->faker2mccoy().at(this);
CompTaskNode* mccoy = gph->faker2mccoy().at(this);
RegstDesc* regst = mccoy->GetProducedRegstDesc("model_diff");
RegstDesc* regst = mccoy->GetProducedRegstDesc("model_diff");
...
@@ -39,9 +64,16 @@ void CompTaskNode::MdUpdtFwBuildExecAndProducedRegsts(TaskGraph* gph) {
...
@@ -39,9 +64,16 @@ void CompTaskNode::MdUpdtFwBuildExecAndProducedRegsts(TaskGraph* gph) {
mut_exec_gph().UpdateSourceAndSink();
mut_exec_gph().UpdateSourceAndSink();
// PostProcessing in ModelUpdateTaskGraph will complete the work
// PostProcessing in ModelUpdateTaskGraph will complete the work
// which should be implemented in this function
// which should be implemented in this function
*/
}
void
CompTaskNode
::
MdUpdtFwInferShape4LbnInProducedRegsts
(
TaskGraph
*
gph
)
{
TODO
();
}
}
void
CompTaskNode
::
MdLoadFwBuildExecAndProducedRegsts
(
TaskGraph
*
gph
)
{
void
CompTaskNode
::
MdLoadFwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
)
{
TODO
();
/*
if (IsFaker()) {
if (IsFaker()) {
CompTaskNode* update_task = gph->faker2mccoy().at(this);
CompTaskNode* update_task = gph->faker2mccoy().at(this);
ExecNode* exec_node = update_task->exec_gph().SoleNode();
ExecNode* exec_node = update_task->exec_gph().SoleNode();
...
@@ -58,9 +90,16 @@ void CompTaskNode::MdLoadFwBuildExecAndProducedRegsts(TaskGraph* gph) {
...
@@ -58,9 +90,16 @@ void CompTaskNode::MdLoadFwBuildExecAndProducedRegsts(TaskGraph* gph) {
exec_node->op()->SetShapePtr(exec_node->op()->SoleObn(), shape_ptr);
exec_node->op()->SetShapePtr(exec_node->op()->SoleObn(), shape_ptr);
exec_node->op()->InferShape4ObAndDtbFromIb();
exec_node->op()->InferShape4ObAndDtbFromIb();
EnrollProducedRegstDesc("model_regst", std::move(model_regst));
EnrollProducedRegstDesc("model_regst", std::move(model_regst));
*/
}
}
void
CompTaskNode
::
MdSaveFwBuildExecAndProducedRegsts
(
TaskGraph
*
gph
)
{
void
CompTaskNode
::
MdLoadFwInferShape4LbnInProducedRegsts
(
TaskGraph
*
gph
)
{
TODO
();
}
void
CompTaskNode
::
MdSaveFwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
)
{
TODO
();
/*
if (IsFaker()) {
if (IsFaker()) {
CompTaskNode* update_task = gph->faker2mccoy().at(this);
CompTaskNode* update_task = gph->faker2mccoy().at(this);
RegstDesc* model_regst = update_task->GetProducedRegstDesc("model");
RegstDesc* model_regst = update_task->GetProducedRegstDesc("model");
...
@@ -72,6 +111,11 @@ void CompTaskNode::MdSaveFwBuildExecAndProducedRegsts(TaskGraph* gph) {
...
@@ -72,6 +111,11 @@ void CompTaskNode::MdSaveFwBuildExecAndProducedRegsts(TaskGraph* gph) {
mut_exec_gph().UpdateSourceAndSink();
mut_exec_gph().UpdateSourceAndSink();
const std::string& ibn = exec_node->op()->SoleIbn();
const std::string& ibn = exec_node->op()->SoleIbn();
exec_node->BindBnInOpAndRegst(ibn, GetRelatedRegst(SoleInEdge()));
exec_node->BindBnInOpAndRegst(ibn, GetRelatedRegst(SoleInEdge()));
*/
}
void
CompTaskNode
::
MdSaveFwInferShape4LbnInProducedRegsts
(
TaskGraph
*
gph
)
{
TODO
();
}
}
void
CompTaskNode
::
FwBuildFromUserOps
(
void
CompTaskNode
::
FwBuildFromUserOps
(
...
@@ -103,80 +147,97 @@ void CompTaskNode::FwBuildFromUserOps(
...
@@ -103,80 +147,97 @@ void CompTaskNode::FwBuildFromUserOps(
}
}
}
}
void
CompTaskNode
::
FwSetDataRegstDesc
(
void
CompTaskNode
::
FwSetExecNodeFromInRegst
(
const
Lbn2NodeBnMap
&
lbn2producer
,
const
Lbn2NodeBnMap
&
extern_in_lbn2consumer
)
{
const
Lbn2NodeBnMap
&
extern_in_lbn2consumer
)
{
RegstDesc
*
in_regst
=
GetRelatedRegst
(
SoleInEdge
());
RegstDesc
*
in_regst
=
GetRelatedRegst
(
SoleInEdge
());
RegstDesc
*
out_regst
=
GetRelatedRegst
(
SoleOutEdge
());
// blob on exec_edge
for
(
const
std
::
unique_ptr
<
ExecEdge
>&
edge
:
exec_gph
().
edges
())
{
Shape
*
ptr
=
out_regst
->
EnrollLbn
(
edge
->
lbn
());
edge
->
src_node
()
->
BindBnInOpAndRegst
(
edge
->
src_bn
(),
out_regst
);
edge
->
src_node
()
->
op
()
->
SetShapePtr
(
edge
->
src_bn
(),
ptr
);
edge
->
dst_node
()
->
BindBnInOpAndRegst
(
edge
->
dst_bn
(),
out_regst
);
edge
->
dst_node
()
->
op
()
->
SetShapePtr
(
edge
->
dst_bn
(),
ptr
);
}
// extern in blobs
for
(
const
auto
&
pair
:
extern_in_lbn2consumer
)
{
for
(
const
auto
&
pair
:
extern_in_lbn2consumer
)
{
const
std
::
string
&
lbn
=
pair
.
first
;
const
std
::
string
&
lbn
=
pair
.
first
;
Shape
*
ptr
=
in_regst
->
GetMutShapePtr
(
lbn
);
Shape
*
ptr
=
in_regst
->
GetMutShapePtr
(
lbn
);
ExecNode
*
node
=
pair
.
second
.
first
;
ExecNode
*
node
=
pair
.
second
.
first
;
const
std
::
string
&
ibn
=
pair
.
second
.
second
;
const
std
::
string
&
ibn
=
pair
.
second
.
second
;
node
->
op
()
->
Set
ShapePtr
(
ibn
,
ptr
);
node
->
BindBnInOpAnd
ShapePtr
(
ibn
,
ptr
);
node
->
BindBnInOpAndRegst
(
ibn
,
in_regst
);
node
->
BindBnInOpAndRegst
(
ibn
,
in_regst
);
}
}
// extern out blobs
}
void
CompTaskNode
::
FwEnrollLbn2OutRegst
(
const
Lbn2NodeBnMap
&
lbn2producer
)
{
RegstDesc
*
out_regst
=
GetRelatedRegst
(
SoleOutEdge
());
for
(
const
std
::
string
&
lbn
:
chain_node
()
->
output_lbns
())
{
for
(
const
std
::
string
&
lbn
:
chain_node
()
->
output_lbns
())
{
const
std
::
pair
<
ExecNode
*
,
std
::
string
>&
producer
=
lbn2producer
.
at
(
lbn
);
const
std
::
pair
<
ExecNode
*
,
std
::
string
>&
producer
=
lbn2producer
.
at
(
lbn
);
ExecNode
*
node
=
producer
.
first
;
ExecNode
*
node
=
producer
.
first
;
const
std
::
string
&
obn
=
producer
.
second
;
const
std
::
string
&
obn
=
producer
.
second
;
Shape
*
ptr
=
out_regst
->
EnrollLbn
(
lbn
);
Shape
*
ptr
=
out_regst
->
EnrollLbn
(
lbn
);
node
->
op
()
->
Set
ShapePtr
(
obn
,
ptr
);
node
->
BindBnInOpAnd
ShapePtr
(
obn
,
ptr
);
node
->
BindBnInOpAndRegst
(
obn
,
out_regst
);
node
->
BindBnInOpAndRegst
(
obn
,
out_regst
);
}
}
// data tmp blobs
}
for
(
const
std
::
unique_ptr
<
ExecNode
>&
node
:
exec_gph
().
nodes
())
{
for
(
const
std
::
string
&
dtbn
:
node
->
op
()
->
data_tmp_bns
())
{
void
CompTaskNode
::
FwEnrollLbn2ActivationRegst
()
{
std
::
string
lbn
=
node
->
op
()
->
dtbn2lbn
(
dtbn
);
RegstDesc
*
activation_regst
=
GetProducedRegstDesc
(
"activation"
);
Shape
*
ptr
=
out_regst
->
EnrollLbn
(
lbn
);
for
(
const
std
::
unique_ptr
<
ExecEdge
>&
edge
:
exec_gph
().
edges
())
{
node
->
op
()
->
SetShapePtr
(
dtbn
,
ptr
);
Shape
*
ptr
=
activation_regst
->
EnrollLbn
(
edge
->
lbn
());
node
->
BindBnInOpAndRegst
(
dtbn
,
out_regst
);
edge
->
src_node
()
->
BindBnInOpAndRegst
(
edge
->
src_bn
(),
activation_regst
);
}
edge
->
src_node
()
->
BindBnInOpAndShapePtr
(
edge
->
src_bn
(),
ptr
);
}
edge
->
dst_node
()
->
BindBnInOpAndRegst
(
edge
->
dst_bn
(),
activation_regst
);
// Inference Shape
edge
->
dst_node
()
->
BindBnInOpAndShapePtr
(
edge
->
dst_bn
(),
ptr
);
for
(
const
ExecNode
&
node
:
exec_gph
())
{
node
.
op
()
->
InferShape4ObAndDtbFromIb
();
}
}
}
}
void
CompTaskNode
::
FwSetModelTmpRegstDesc
()
{
void
CompTaskNode
::
FwEnrollLbn2TmpRegsts
()
{
RegstDesc
*
data_tmp_regst
=
GetProducedRegstDesc
(
"data_tmp"
);
RegstDesc
*
model_tmp_regst
=
GetProducedRegstDesc
(
"model_tmp"
);
RegstDesc
*
model_tmp_regst
=
GetProducedRegstDesc
(
"model_tmp"
);
for
(
const
std
::
unique_ptr
<
ExecNode
>&
node
:
exec_gph
().
nodes
())
{
for
(
const
std
::
unique_ptr
<
ExecNode
>&
node
:
exec_gph
().
nodes
())
{
for
(
const
std
::
string
&
dtbn
:
node
->
op
()
->
data_tmp_bns
())
{
std
::
string
lbn
=
node
->
op
()
->
dtbn2lbn
(
dtbn
);
Shape
*
ptr
=
data_tmp_regst
->
EnrollLbn
(
lbn
);
node
->
BindBnInOpAndShapePtr
(
dtbn
,
ptr
);
node
->
BindBnInOpAndRegst
(
dtbn
,
out_regst
);
}
for
(
const
std
::
string
&
mtbn
:
node
->
op
()
->
model_tmp_bns
())
{
for
(
const
std
::
string
&
mtbn
:
node
->
op
()
->
model_tmp_bns
())
{
std
::
string
lbn
=
node
->
op
()
->
mtbn2lbn
(
mtbn
);
std
::
string
lbn
=
node
->
op
()
->
mtbn2lbn
(
mtbn
);
Shape
*
ptr
=
model_tmp_regst
->
EnrollLbn
(
lbn
);
Shape
*
ptr
=
model_tmp_regst
->
EnrollLbn
(
lbn
);
node
->
op
()
->
Set
ShapePtr
(
mtbn
,
ptr
);
node
->
BindBnInOpAnd
ShapePtr
(
mtbn
,
ptr
);
node
->
BindBnInOpAndRegst
(
mtbn
,
model_tmp_regst
);
node
->
BindBnInOpAndRegst
(
mtbn
,
model_tmp_regst
);
}
}
node
->
op
()
->
InferShape4ModelTmpBlob
(
chain_node
()
->
parallel_desc
()
->
policy
(),
parallel_id
());
}
}
}
}
void
CompTaskNode
::
BpBuildExecAnd
Produced
Regsts
(
TaskGraph
*
)
{
void
CompTaskNode
::
BpBuildExecAnd
EnrollLbn2
Regsts
(
TaskGraph
*
)
{
const
ExecGraph
&
fw_gph
=
GetFwNode
()
->
exec_gph
();
const
ExecGraph
&
fw_gph
=
GetFwNode
()
->
exec_gph
();
HashMap
<
const
ExecNode
*
,
ExecNode
*>
fw_node2bp_node
;
HashMap
<
const
ExecNode
*
,
ExecNode
*>
fw_node2bp_node
;
HashMap
<
ExecEdge
*
,
const
ExecEdge
*>
bp_edge2fw_edge
;
HashMap
<
ExecEdge
*
,
const
ExecEdge
*>
bp_edge2fw_edge
;
BpBuildExecGraph
(
fw_gph
,
&
fw_node2bp_node
,
&
bp_edge2fw_edge
);
BpBuildExecGraph
(
fw_gph
,
&
fw_node2bp_node
,
&
bp_edge2fw_edge
);
//
// Produced registers
auto
data_diff_regst
=
of_make_unique
<
DisContigRegstDesc
>
();
auto
in_diff_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
BindProducedRegstAndOutEdge
(
data_diff_regst
.
get
(),
SoleOutEdge
());
auto
model_diff_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
EnrollProducedRegstDesc
(
"data_diff"
,
std
::
move
(
data_diff_regst
));
auto
activation_diff_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
BpSetDataDiffRegst
(
fw_node2bp_node
,
bp_edge2fw_edge
);
// Bind out edge
//
BindProducedRegstAndOutEdge
(
in_diff_regst
.
get
(),
SoleOutEdge
());
auto
model_diff_regst
=
of_make_unique
<
ContigRegstDesc
>
();
// Enroll registers
EnrollProducedRegstDesc
(
"in_diff"
,
std
::
move
(
in_diff_regst
));
EnrollProducedRegstDesc
(
"model_diff"
,
std
::
move
(
model_diff_regst
));
EnrollProducedRegstDesc
(
"model_diff"
,
std
::
move
(
model_diff_regst
));
BpSetModelDiffRegst
();
EnrollProducedRegstDesc
(
"activation_diff"
,
std
::
move
(
activation_diff_regst
));
// Enroll Lbn
BpEnrollLbn2ProducedRegst
(
fw_node2bp_node
,
bp_edge2fw_edge
);
}
void
CompTaskNode
::
BpInferShape4LbnInProducedRegsts
(
TaskGraph
*
)
{
// in_diff_regst
RegstDesc
*
in_diff_regst
=
GetRelatedRegst
(
SoleOutEdge
());
RegstDesc
*
in_regst
=
GetRelatedRegst
(
GetFwNode
()
->
SoleInEdge
());
in_diff_regst
->
CopyShapeFrom
(
in_regst
);
// model_diff_regst
RegstDesc
*
model_diff_regst
=
GetProducedRegstDesc
(
"model_diff"
);
for
(
const
std
::
unique_ptr
<
ExecNode
>&
cur_node
:
exec_gph
().
nodes
())
{
cur_node
->
op
()
->
InferShape4ModelDiffBlob
(
cur_node
->
BnInOp2ShapePtr
(),
chain_node
()
->
parallel_desc
()
->
policy
(),
parallel_id
());
}
// activation_diff_regst
RegstDesc
*
activation_diff_regst
=
GetProducedRegstDesc
(
"activation_diff"
);
RegstDesc
*
activation_regst
=
GetFwNode
()
->
GetProducedRegstDesc
(
"activation"
);
activation_diff_regst
->
CopyShapeFrom
(
activation_regst
);
}
}
void
CompTaskNode
::
BpBuildExecGraph
(
void
CompTaskNode
::
BpBuildExecGraph
(
...
@@ -200,21 +261,23 @@ void CompTaskNode::BpBuildExecGraph(
...
@@ -200,21 +261,23 @@ void CompTaskNode::BpBuildExecGraph(
}
}
}
}
void
CompTaskNode
::
Bp
SetDataDiff
Regst
(
void
CompTaskNode
::
Bp
EnrollLbn2Produced
Regst
(
const
HashMap
<
const
ExecNode
*
,
ExecNode
*>&
fw_node2bp_node
,
const
HashMap
<
const
ExecNode
*
,
ExecNode
*>&
fw_node2bp_node
,
const
HashMap
<
ExecEdge
*
,
const
ExecEdge
*>&
bp_edge2fw_edge
)
{
const
HashMap
<
ExecEdge
*
,
const
ExecEdge
*>&
bp_edge2fw_edge
)
{
// Regsts
// Regsts
RegstDesc
*
in_diff_regst
=
GetRelatedRegst
(
SoleOutEdge
());
RegstDesc
*
in_diff_regst
=
GetRelatedRegst
(
SoleOutEdge
());
RegstDesc
*
out_diff_regst
=
GetRelatedRegst
(
SoleInEdge
());
RegstDesc
*
out_diff_regst
=
GetRelatedRegst
(
SoleInEdge
());
RegstDesc
*
in_regst
=
GetRelatedRegst
(
GetFwNode
()
->
SoleInEdge
());
RegstDesc
*
in_regst
=
GetRelatedRegst
(
GetFwNode
()
->
SoleInEdge
());
RegstDesc
*
out_regst
=
GetRelatedRegst
(
GetFwNode
()
->
SoleOutEdge
());
RegstDesc
*
activation_regst
=
GetFwNode
()
->
GetProducedRegstDesc
(
"activation"
);
RegstDesc
*
data_tmp_regst
=
GetFwNode
()
->
GetProducedRegstDesc
(
"data_tmp"
);
RegstDesc
*
model_tmp_regst
=
GetFwNode
()
->
GetProducedRegstDesc
(
"model_tmp"
);
RegstDesc
*
model_tmp_regst
=
GetFwNode
()
->
GetProducedRegstDesc
(
"model_tmp"
);
RegstDesc
*
activation_diff_regst
=
GetProducedRegstDesc
(
"activation_diff"
);
RegstDesc
*
model_diff_regst
=
GetProducedRegstDesc
(
"model_diff"
);
// blobs on edge
// blobs on edge
for
(
const
std
::
unique_ptr
<
ExecEdge
>&
edge
:
exec_gph
().
edges
())
{
for
(
const
std
::
unique_ptr
<
ExecEdge
>&
edge
:
exec_gph
().
edges
())
{
Shape
*
ptr
=
in_diff_regst
->
EnrollLbn
(
edge
->
lbn
());
activation_diff_regst
->
EnrollLbn
(
edge
->
lbn
());
*
ptr
=
out_regst
->
GetShape
(
edge
->
lbn
());
edge
->
src_node
()
->
BindBnInOpAndRegst
(
edge
->
src_bn
(),
activation_diff_regst
);
edge
->
src_node
()
->
BindBnInOpAndRegst
(
edge
->
src_bn
(),
in_diff_regst
);
edge
->
dst_node
()
->
BindBnInOpAndRegst
(
edge
->
dst_bn
(),
activation_diff_regst
);
edge
->
dst_node
()
->
BindBnInOpAndRegst
(
edge
->
dst_bn
(),
in_diff_regst
);
}
}
// extern out_diff blobs
// extern out_diff blobs
for
(
const
std
::
unique_ptr
<
ExecNode
>&
bp_node
:
exec_gph
().
nodes
())
{
for
(
const
std
::
unique_ptr
<
ExecNode
>&
bp_node
:
exec_gph
().
nodes
())
{
...
@@ -237,35 +300,27 @@ void CompTaskNode::BpSetDataDiffRegst(
...
@@ -237,35 +300,27 @@ void CompTaskNode::BpSetDataDiffRegst(
for
(
const
std
::
string
&
idbn
:
bp_node
->
op
()
->
input_diff_bns
())
{
for
(
const
std
::
string
&
idbn
:
bp_node
->
op
()
->
input_diff_bns
())
{
if
(
found_bns
.
find
(
idbn
)
!=
found_bns
.
end
())
{
continue
;
}
if
(
found_bns
.
find
(
idbn
)
!=
found_bns
.
end
())
{
continue
;
}
std
::
string
lbn
=
bp_node
->
op
()
->
idbn2lbn
(
idbn
);
std
::
string
lbn
=
bp_node
->
op
()
->
idbn2lbn
(
idbn
);
Shape
*
ptr
=
in_diff_regst
->
EnrollLbn
(
lbn
);
in_diff_regst
->
EnrollLbn
(
lbn
);
*
ptr
=
in_regst
->
GetShape
(
lbn
);
bp_node
->
BindBnInOpAndRegst
(
idbn
,
in_diff_regst
);
bp_node
->
BindBnInOpAndRegst
(
idbn
,
in_diff_regst
);
bp_node
->
BindBnInOpAndRegst
(
GenUnDiffBn
(
idbn
),
in_regst
);
}
}
}
}
// tmp blobs
// tmp blobs
and model_diff blobs
for
(
const
std
::
unique_ptr
<
ExecNode
>&
node
:
exec_gph
().
nodes
())
{
for
(
const
std
::
unique_ptr
<
ExecNode
>&
node
:
exec_gph
().
nodes
())
{
for
(
const
std
::
string
&
dtbn
:
node
->
op
()
->
data_tmp_bns
())
{
for
(
const
std
::
string
&
dtbn
:
node
->
op
()
->
data_tmp_bns
())
{
std
::
string
lbn
=
node
->
op
()
->
dtbn2lbn
(
dtbn
);
std
::
string
lbn
=
node
->
op
()
->
dtbn2lbn
(
dtbn
);
node
->
BindBnInOpAndRegst
(
dtbn
,
out
_regst
);
node
->
BindBnInOpAndRegst
(
dtbn
,
data_tmp
_regst
);
}
}
for
(
const
std
::
string
&
mtbn
:
node
->
op
()
->
model_tmp_bns
())
{
for
(
const
std
::
string
&
mtbn
:
node
->
op
()
->
model_tmp_bns
())
{
std
::
string
lbn
=
node
->
op
()
->
mtbn2lbn
(
mtbn
);
std
::
string
lbn
=
node
->
op
()
->
mtbn2lbn
(
mtbn
);
node
->
BindBnInOpAndRegst
(
mtbn
,
model_tmp_regst
);
node
->
BindBnInOpAndRegst
(
mtbn
,
model_tmp_regst
);
}
}
}
}
void
CompTaskNode
::
BpSetModelDiffRegst
()
{
RegstDesc
*
model_diff_regst
=
GetProducedRegstDesc
(
"model_diff"
);
for
(
const
std
::
unique_ptr
<
ExecNode
>&
cur_node
:
exec_gph
().
nodes
())
{
for
(
const
std
::
string
&
mdbn
:
cur_node
->
op
()
->
model_diff_bns
())
{
for
(
const
std
::
string
&
mdbn
:
cur_node
->
op
()
->
model_diff_bns
())
{
std
::
string
lbn
=
cur_node
->
op
()
->
mdbn2lbn
(
mdbn
);
std
::
string
lbn
=
cur_node
->
op
()
->
mdbn2lbn
(
mdbn
);
Shape
*
ptr
=
model_diff_regst
->
EnrollLbn
(
lbn
);
Shape
*
ptr
=
model_diff_regst
->
EnrollLbn
(
lbn
);
cur_node
->
op
()
->
Set
ShapePtr
(
mdbn
,
ptr
);
cur_node
->
BindBnInOpAnd
ShapePtr
(
mdbn
,
ptr
);
cur_node
->
BindBnInOpAndRegst
(
mdbn
,
model_diff_regst
);
cur_node
->
BindBnInOpAndRegst
(
mdbn
,
model_diff_regst
);
}
}
cur_node
->
op
()
->
InferShape4ModelDiffBlob
(
chain_node
()
->
parallel_desc
()
->
policy
(),
parallel_id
());
}
}
}
}
...
...
oneflow/graph/comp_task_node.h
浏览文件 @
3589c342
...
@@ -12,55 +12,60 @@ class CompTaskNode : public TaskNode {
...
@@ -12,55 +12,60 @@ class CompTaskNode : public TaskNode {
CompTaskNode
()
=
default
;
CompTaskNode
()
=
default
;
virtual
~
CompTaskNode
()
=
default
;
virtual
~
CompTaskNode
()
=
default
;
// Getters and Setters
uint64_t
parallel_id
()
const
{
return
parallel_id_
;
}
uint64_t
parallel_id
()
const
{
return
parallel_id_
;
}
void
set_parallel_id
(
uint64_t
parallel_id
)
{
parallel_id_
=
parallel_id
;
}
void
set_parallel_id
(
uint64_t
parallel_id
)
{
parallel_id_
=
parallel_id
;
}
bool
IsLossNode
()
const
{
return
chain_node
()
->
IsLossNode
();
}
bool
IsLossNode
()
const
{
return
chain_node
()
->
IsLossNode
();
}
bool
IsFaker
()
const
{
return
chain_node
()
->
IsFaker
();
}
bool
IsFaker
()
const
{
return
chain_node
()
->
IsFaker
();
}
std
::
string
VisualStr
()
const
override
;
void
DataFwBuildExecAndProducedRegsts
(
TaskGraph
*
);
// Build Exec and Set Produced Regsts
void
MdUpdtFwBuildExecAndProducedRegsts
(
TaskGraph
*
);
void
DataFwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
);
void
MdLoadFwBuildExecAndProducedRegsts
(
TaskGraph
*
);
void
DataFwInferShape4LbnInProducedRegsts
(
TaskGraph
*
);
void
MdSaveFwBuildExecAndProducedRegsts
(
TaskGraph
*
);
std
::
string
VisualStr
()
const
override
{
void
MdUpdtFwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
);
std
::
stringstream
ss
;
void
MdUpdtFwInferShape4LbnInProducedRegsts
(
TaskGraph
*
);
ss
<<
TaskNode
::
VisualStr
()
<<
"Compute_"
<<
node_id_str
()
<<
":"
void
MdLoadFwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
);
<<
stage_node
()
->
machine_id_str
()
<<
":"
void
MdLoadFwInferShape4LbnInProducedRegsts
(
TaskGraph
*
);
<<
thrd_loc_id_str
()
<<
"
\\
n"
<<
chain_node
()
->
VisualStr
();
void
MdSaveFwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
);
return
ss
.
str
();
void
MdSaveFwInferShape4LbnInProducedRegsts
(
TaskGraph
*
);
}
protected:
protected:
virtual
void
InitWithFwNode
(
TaskNode
*
fw_node
)
override
{
virtual
void
InitWithFwNode
(
TaskNode
*
fw_node
)
override
{
TaskNode
::
InitWithFwNode
(
fw_node
);
TaskNode
::
InitWithFwNode
(
fw_node
);
parallel_id_
=
of_dynamic_cast
<
CompTaskNode
*>
(
fw_node
)
->
parallel_id_
;
}
}
private:
private:
using
Lbn2NodeBnMap
=
using
Lbn2NodeBnMap
=
HashMap
<
std
::
string
,
std
::
pair
<
ExecNode
*
,
std
::
string
>>
;
HashMap
<
std
::
string
,
std
::
pair
<
ExecNode
*
,
std
::
string
>>
;
void
FwBuildExecAndProducedRegsts
(
TaskGraph
*
)
override
;
void
FwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
gph
)
override
{
(
this
->*
(
gph
->
Func4FwBuildExecAndEnrollLbn2Regsts
()))(
gph
);
}
void
FwInferShape4LbnInProducedRegsts
(
TaskGraph
*
gph
)
override
{
(
this
->*
(
gph
->
Func4FwInferShape4LbnInProducedRegsts
()))(
gph
);
}
void
FwBuildFromUserOps
(
void
FwBuildFromUserOps
(
Lbn2NodeBnMap
*
lbn2producer
,
Lbn2NodeBnMap
*
lbn2producer
,
Lbn2NodeBnMap
*
extern_in_lbn2consumer
);
Lbn2NodeBnMap
*
extern_in_lbn2consumer
);
void
FwSetDataRegstDesc
(
void
FwSetExecNodeFromInRegst
(
const
Lbn2NodeBnMap
&
lbn2producer
,
const
Lbn2NodeBnMap
&
extern_in_lbn2consumer
);
const
Lbn2NodeBnMap
&
extern_in_lbn2consumer
);
void
FwSetModelTmpRegstDesc
();
void
FwEnrollLbn2OutRegst
(
const
Lbn2NodeBnMap
&
lbn2producer
);
void
FwEnrollLbn2ActivationRegst
();
void
FwEnrollLbn2TmpRegsts
();
void
BpBuildExecAndProducedRegsts
(
TaskGraph
*
)
override
;
void
BpBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
override
;
void
BpInferShape4LbnInProducedRegsts
(
TaskGraph
*
)
override
;
void
BpBuildExecGraph
(
void
BpBuildExecGraph
(
const
ExecGraph
&
fw_gph
,
const
ExecGraph
&
fw_gph
,
HashMap
<
const
ExecNode
*
,
ExecNode
*>*
fw_node2bp_node
,
HashMap
<
const
ExecNode
*
,
ExecNode
*>*
fw_node2bp_node
,
HashMap
<
ExecEdge
*
,
const
ExecEdge
*>*
bp_edge2fw_edge
);
HashMap
<
ExecEdge
*
,
const
ExecEdge
*>*
bp_edge2fw_edge
);
void
Bp
SetDataDiff
Regst
(
void
Bp
EnrollLbn2Produced
Regst
(
const
HashMap
<
const
ExecNode
*
,
ExecNode
*>&
fw_node2bp_node
,
const
HashMap
<
const
ExecNode
*
,
ExecNode
*>&
fw_node2bp_node
,
const
HashMap
<
ExecEdge
*
,
const
ExecEdge
*>&
bp_edge2fw_edge
);
const
HashMap
<
ExecEdge
*
,
const
ExecEdge
*>&
bp_edge2fw_edge
);
void
BpSetModelDiffRegst
();
uint64_t
parallel_id_
;
uint64_t
parallel_id_
;
...
...
oneflow/graph/copy_hd_task_node.cpp
浏览文件 @
3589c342
...
@@ -4,52 +4,6 @@
...
@@ -4,52 +4,6 @@
namespace
oneflow
{
namespace
oneflow
{
void
CopyHDTaskNode
::
BuildExecAndProducedRegstsForCopy
(
TaskGraph
*
gph
){
auto
out_regst
=
of_make_unique
<
DisContigRegstDesc
>
();
BindProducedRegstAndOutEdge
(
out_regst
.
get
(),
SoleOutEdge
());
OperatorConf
op_conf
;
op_conf
.
set_name
(
"copy_"
+
NewUniqueId
());
CopyOpConf
*
copy_conf
=
op_conf
.
mutable_copy_conf
();
copy_conf
->
set_copy_type
(
IsH2D
()
?
CopyOpConf
::
H2D
:
CopyOpConf
::
D2H
);
for
(
std
::
string
lbn
:
CopiedLbns
()){
copy_conf
->
add_copied_lbns
(
lbn
);
}
RegstDesc
*
in_regst
=
GetRelatedRegst
(
SoleInEdge
());
bool
get_kalllbn
=
false
;
if
(
copy_conf
->
copied_lbns_size
()
==
1
&&
copy_conf
->
copied_lbns
(
0
)
==
RegstDesc
::
kAllLbn
){
out_regst
->
CopyLbn2ShapeMap
(
in_regst
);
get_kalllbn
=
true
;
}
ExecNode
*
node
=
mut_exec_gph
().
NewNode
();
node
->
mut_op
()
=
OpMgr
::
Singleton
().
ConstructOp
(
op_conf
);
for
(
std
::
string
ibn
:
node
->
op
()
->
input_bns
()){
std
::
string
lbn
=
node
->
op
()
->
ibn2lbn
(
ibn
);
Shape
*
shape_ptr
=
in_regst
->
GetMutShapePtr
(
lbn
);
node
->
op
()
->
SetShapePtr
(
ibn
,
shape_ptr
);
node
->
BindBnInOpAndRegst
(
ibn
,
in_regst
);
}
for
(
std
::
string
obn
:
node
->
op
()
->
output_bns
()){
std
::
string
lbn
=
node
->
op
()
->
obn2lbn
(
obn
);
Shape
*
shape_ptr
=
nullptr
;
if
(
!
get_kalllbn
){
shape_ptr
=
out_regst
->
EnrollLbn
(
lbn
);
}
else
{
shape_ptr
=
out_regst
->
GetMutShapePtr
(
lbn
);
}
node
->
op
()
->
SetShapePtr
(
obn
,
shape_ptr
);
node
->
BindBnInOpAndRegst
(
obn
,
out_regst
.
get
());
}
if
(
!
get_kalllbn
){
node
->
op
()
->
InferShape4ObAndDtbFromIb
();
}
mut_exec_gph
().
UpdateSourceAndSink
();
EnrollProducedRegstDesc
(
"copy"
,
std
::
move
(
out_regst
));
}
void
CopyHDTaskNode
::
SetFwInCopy
()
{
void
CopyHDTaskNode
::
SetFwInCopy
()
{
CHECK
(
IsFwNode
());
CHECK
(
IsFwNode
());
is_fw_in_copy_
=
true
;
is_fw_in_copy_
=
true
;
...
@@ -60,21 +14,52 @@ void CopyHDTaskNode::SetFwOutCopy() {
...
@@ -60,21 +14,52 @@ void CopyHDTaskNode::SetFwOutCopy() {
is_fw_in_copy_
=
false
;
is_fw_in_copy_
=
false
;
}
}
const
std
::
vector
<
std
::
string
>&
CopyHDTaskNode
::
CopiedLbns
()
const
{
return
IsFwInCopy
()
?
chain_node
()
->
input_lbns
()
:
chain_node
()
->
output_lbns
();
}
void
CopyHDTaskNode
::
InitWithFwNode
(
TaskNode
*
fw_node
)
{
void
CopyHDTaskNode
::
InitWithFwNode
(
TaskNode
*
fw_node
)
{
TaskNode
::
InitWithFwNode
(
fw_node
);
TaskNode
::
InitWithFwNode
(
fw_node
);
is_fw_in_copy_
=
of_dynamic_cast
<
CopyHDTaskNode
*>
(
fw_node
)
->
is_fw_in_copy_
;
is_fw_in_copy_
=
of_dynamic_cast
<
CopyHDTaskNode
*>
(
fw_node
)
->
is_fw_in_copy_
;
}
}
void
CopyHDTaskNode
::
FwBuildExecAnd
ProducedRegsts
(
TaskGraph
*
gph
)
{
void
CopyHDTaskNode
::
FwBuildExecAnd
EnrollLbn2Regsts
(
TaskGraph
*
)
{
BuildExecAndProducedRegstsForCopy
(
gph
);
return
CopyHdBuildExecAndEnrollLbn2Regsts
(
);
}
}
void
CopyHDTaskNode
::
BpBuildExecAndProducedRegsts
(
TaskGraph
*
gph
)
{
void
CopyHDTaskNode
::
FwInferShape4LbnInProducedRegsts
(
TaskGraph
*
)
{
BuildExecAndProducedRegstsForCopy
(
gph
);
return
CopyHdInferShape4LbnInProducedRegsts
();
}
void
CopyHDTaskNode
::
BpBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
{
return
CopyHdBuildExecAndEnrollLbn2Regsts
();
}
void
CopyHDTaskNode
::
BpInferShape4LbnInProducedRegsts
(
TaskGraph
*
)
{
return
CopyHdInferShape4LbnInProducedRegsts
();
}
void
CopyHDTaskNode
::
CopyHdBuildExecAndEnrollLbn2Regsts
(){
auto
out_regst
=
RegstDescMgr
::
Singleton
().
CreateRegisterDesc
();
BindProducedRegstAndOutEdge
(
out_regst
.
get
(),
SoleOutEdge
());
RegstDesc
*
in_regst
=
GetRelatedRegst
(
SoleInEdge
());
out_regst
->
CopyLbnFrom
(
in_regst
);
OperatorConf
op_conf
;
op_conf
.
set_name
(
"copy_hd_"
+
NewUniqueId
());
CopyHdOpConf
*
copy_hd_conf
=
op_conf
.
mutable_copy_hd_conf
();
copy_hd_conf
->
set_type
(
IsH2D
()
?
CopyHdOpConf
::
H2D
:
CopyHdOpConf
::
D2H
);
ExecNode
*
node
=
mut_exec_gph
().
NewNode
();
node
->
mut_op
()
=
OpMgr
::
Singleton
().
ConstructOp
(
op_conf
);
node
->
BindBnInOpAndRegst
(
node
->
op
()
->
SoleIbn
(),
in_regst
);
node
->
BindBnInOpAndRegst
(
node
->
op
()
->
SoleObn
(),
out_regst
.
get
());
mut_exec_gph
().
UpdateSourceAndSink
();
EnrollProducedRegstDesc
(
"out"
,
std
::
move
(
out_regst
));
}
void
CopyHDTaskNode
::
CopyHdInferShape4LbnInProducedRegsts
()
{
RegstDesc
*
in_regst
=
GetRelatedRegst
(
SoleInEdge
());
RegstDesc
*
out_regst
=
GetRelatedRegst
(
SoleOutEdge
());
out_regst
->
CopyShapeFrom
(
in_regst
);
}
}
}
// namespace oneflow
}
// namespace oneflow
oneflow/graph/copy_hd_task_node.h
浏览文件 @
3589c342
...
@@ -23,22 +23,24 @@ class CopyHDTaskNode final : public TaskNode {
...
@@ -23,22 +23,24 @@ class CopyHDTaskNode final : public TaskNode {
void
SetFwInCopy
();
void
SetFwInCopy
();
void
SetFwOutCopy
();
void
SetFwOutCopy
();
const
std
::
vector
<
std
::
string
>&
CopiedLbns
()
const
;
std
::
string
VisualStr
()
const
override
{
std
::
string
VisualStr
()
const
override
{
return
TaskNode
::
VisualStr
()
+
"CopyHD
_"
+
node_id_str
()
;
return
TaskNode
::
VisualStr
()
+
"CopyHD
"
;
}
}
private:
private:
void
InitWithFwNode
(
TaskNode
*
fw_node
)
override
;
void
InitWithFwNode
(
TaskNode
*
fw_node
)
override
;
void
BuildExecAndProducedRegstsForCopy
(
TaskGraph
*
);
void
FwBuildExecAndProducedRegsts
(
TaskGraph
*
)
override
;
void
BpBuildExecAndProducedRegsts
(
TaskGraph
*
)
override
;
std
::
unique_ptr
<
TaskNode
>
CreateSameTypeNode
()
const
override
{
std
::
unique_ptr
<
TaskNode
>
CreateSameTypeNode
()
const
override
{
return
of_make_unique
<
CopyHDTaskNode
>
();
return
of_make_unique
<
CopyHDTaskNode
>
();
}
}
void
FwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
override
;
void
FwInferShape4LbnInProducedRegsts
(
TaskGraph
*
)
override
;
void
BpBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
override
;
void
BpInferShape4LbnInProducedRegsts
(
TaskGraph
*
)
override
;
void
CopyHdBuildExecAndEnrollLbn2Regsts
();
void
CopyHdInferShape4LbnInProducedRegsts
();
bool
is_fw_in_copy_
;
bool
is_fw_in_copy_
;
};
};
...
...
oneflow/graph/exec_graph.h
浏览文件 @
3589c342
...
@@ -44,17 +44,25 @@ class ExecNode final : public Node<ExecNode, ExecEdge> {
...
@@ -44,17 +44,25 @@ class ExecNode final : public Node<ExecNode, ExecEdge> {
std
::
shared_ptr
<
const
Operator
>&
mut_op
()
{
return
op_
;
}
std
::
shared_ptr
<
const
Operator
>&
mut_op
()
{
return
op_
;
}
void
BindBnInOpAndRegst
(
const
std
::
string
&
bn_in_op
,
RegstDesc
*
regst
)
{
void
BindBnInOpAndRegst
(
const
std
::
string
&
bn_in_op
,
RegstDesc
*
regst
)
{
CHECK
(
bn_in_op2regst
.
emplace
(
bn_in_op
,
regst
).
second
);
CHECK
(
bn_in_op2regst
_
.
emplace
(
bn_in_op
,
regst
).
second
);
}
}
RegstDesc
*
GetRegstFromBnInOp
(
const
std
::
string
&
bn_in_op
)
{
RegstDesc
*
GetRegstFromBnInOp
(
const
std
::
string
&
bn_in_op
)
{
return
bn_in_op2regst
.
at
(
bn_in_op
);
return
bn_in_op2regst_
.
at
(
bn_in_op
);
}
void
BindBnInOpAndShapePtr
(
const
std
::
string
&
bn_in_op
,
Shape
*
shape_ptr
)
{
CHECK
(
bn_in_op2shape_ptr_
.
emplace
(
bn_in_op
,
shape_ptr
).
second
);
}
const
HashMap
<
std
::
string
,
Shape
*>&
BnInOp2ShapePtr
()
const
{
return
bn_in_op2shape_ptr_
;
}
}
std
::
string
VisualStr
()
const
{
TODO
();
}
std
::
string
VisualStr
()
const
{
TODO
();
}
private:
private:
std
::
shared_ptr
<
const
Operator
>
op_
;
std
::
shared_ptr
<
const
Operator
>
op_
;
HashMap
<
std
::
string
,
RegstDesc
*>
bn_in_op2regst
;
HashMap
<
std
::
string
,
RegstDesc
*>
bn_in_op2regst_
;
HashMap
<
std
::
string
,
Shape
*>
bn_in_op2shape_ptr_
;
};
};
...
...
oneflow/graph/task_graph.cpp
浏览文件 @
3589c342
...
@@ -16,9 +16,15 @@ inline void TaskConnect(TaskNode* src_node,
...
@@ -16,9 +16,15 @@ inline void TaskConnect(TaskNode* src_node,
}
}
void
TaskGraph
::
BuildExecAnd
Produced
Regsts
()
{
void
TaskGraph
::
BuildExecAnd
EnrollLbn2
Regsts
()
{
for
(
TaskNode
&
node
:
*
this
)
{
for
(
TaskNode
&
node
:
*
this
)
{
node
.
BuildExecAndProducedRegsts
(
this
);
node
.
BuildExecAndEnrollLbn2Regsts
(
this
);
}
}
void
TaskGraph
::
InferShape4LbnInProducedRegsts
()
{
for
(
TaskNode
&
node
:
*
this
)
{
node
.
InferShape4LbnInProducedRegsts
(
this
);
}
}
}
}
...
...
oneflow/graph/task_graph.h
浏览文件 @
3589c342
...
@@ -17,18 +17,21 @@ class TaskGraph : public Graph<TaskNode, TaskEdge> {
...
@@ -17,18 +17,21 @@ class TaskGraph : public Graph<TaskNode, TaskEdge> {
OF_DISALLOW_COPY_AND_MOVE
(
TaskGraph
);
OF_DISALLOW_COPY_AND_MOVE
(
TaskGraph
);
virtual
~
TaskGraph
()
=
default
;
virtual
~
TaskGraph
()
=
default
;
// Getters
const
StageGraph
*
stage_gph
()
const
{
return
stage_gph_
.
get
();
}
const
StageGraph
*
stage_gph
()
const
{
return
stage_gph_
.
get
();
}
const
ChainGraph
*
chain_gph
()
const
{
return
stage_gph_
->
chain_gph
();
}
const
ChainGraph
*
chain_gph
()
const
{
return
stage_gph_
->
chain_gph
();
}
const
HashMap
<
CompTaskNode
*
,
CompTaskNode
*>&
faker2mccoy
()
{
const
HashMap
<
CompTaskNode
*
,
CompTaskNode
*>&
faker2mccoy
()
{
return
faker2mccoy_
;
return
faker2mccoy_
;
}
}
std
::
vector
<
CompTaskNode
*>
SortedCompTasksInChain
(
const
ChainNode
*
)
const
;
void
BuildExecAndProducedRegsts
();
// Build Exec And Set Produced Registers
void
BuildExecAndEnrollLbn2Regsts
();
void
InferShape4LbnInProducedRegsts
();
typedef
void
(
CompTaskNode
::*
CompTaskNodeMemFunc
)(
TaskGraph
*
);
using
CompTaskNodeMemFunc
=
void
(
CompTaskNode
::*
)(
TaskGraph
*
);
virtual
CompTaskNodeMemFunc
Func4FwBuildExecAndProducedRegsts
()
const
=
0
;
virtual
CompTaskNodeMemFunc
Func4FwBuildExecAndEnrollLbn2Regsts
()
const
=
0
;
virtual
CompTaskNodeMemFunc
Func4FwInferShape4LbnInProducedRegsts
()
const
=
0
;
std
::
vector
<
CompTaskNode
*>
SortedCompTasksInChain
(
const
ChainNode
*
)
const
;
protected:
protected:
TaskGraph
()
=
default
;
TaskGraph
()
=
default
;
...
...
oneflow/graph/task_node.cpp
浏览文件 @
3589c342
...
@@ -40,11 +40,19 @@ std::unique_ptr<TaskNode> TaskNode::BuildAndConnectBpNode() {
...
@@ -40,11 +40,19 @@ std::unique_ptr<TaskNode> TaskNode::BuildAndConnectBpNode() {
return
bp_node
;
return
bp_node
;
}
}
void
TaskNode
::
BuildExecAnd
Produced
Regsts
(
TaskGraph
*
gph
)
{
void
TaskNode
::
BuildExecAnd
EnrollLbn2
Regsts
(
TaskGraph
*
gph
)
{
if
(
IsFwNode
())
{
if
(
IsFwNode
())
{
FwBuildExecAnd
Produced
Regsts
(
gph
);
FwBuildExecAnd
EnrollLbn2
Regsts
(
gph
);
}
else
{
}
else
{
BpBuildExecAndProducedRegsts
(
gph
);
BpBuildExecAndEnrollLbn2Regsts
(
gph
);
}
}
void
TaskNode
::
InferShape4LbnInProducedRegsts
(
TaskGraph
*
gph
)
{
if
(
IsFwNode
())
{
FwInferShape4LbnInProducedRegsts
();
}
else
{
BpInferShape4LbnInProducedRegsts
();
}
}
}
}
...
...
oneflow/graph/task_node.h
浏览文件 @
3589c342
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
#include "task/task.pb.h"
#include "task/task.pb.h"
#include "graph/stage_graph.h"
#include "graph/stage_graph.h"
#include "graph/exec_graph.h"
#include "graph/exec_graph.h"
#include "register/register_desc.h"
#include "register/register_desc
_manager
.h"
namespace
oneflow
{
namespace
oneflow
{
...
@@ -39,7 +39,8 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
...
@@ -39,7 +39,8 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
std
::
unique_ptr
<
TaskNode
>
BuildAndConnectBpNode
();
std
::
unique_ptr
<
TaskNode
>
BuildAndConnectBpNode
();
//
//
void
BuildExecAndProducedRegsts
(
TaskGraph
*
);
void
BuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
);
void
InferShape4LbnInProducedRegsts
(
TaskGraph
*
);
RegstDesc
*
GetProducedRegstDesc
(
const
std
::
string
&
regst_desc_name
);
RegstDesc
*
GetProducedRegstDesc
(
const
std
::
string
&
regst_desc_name
);
//
//
...
@@ -67,8 +68,10 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
...
@@ -67,8 +68,10 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
void
EnrollProducedRegstDesc
(
const
std
::
string
&
regst_desc_name
,
void
EnrollProducedRegstDesc
(
const
std
::
string
&
regst_desc_name
,
std
::
unique_ptr
<
RegstDesc
>&&
regst_desc
);
std
::
unique_ptr
<
RegstDesc
>&&
regst_desc
);
virtual
void
FwBuildExecAndProducedRegsts
(
TaskGraph
*
)
{
UNEXPECTED_RUN
();
}
virtual
void
FwBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
{
UNEXPECTED_RUN
();
}
virtual
void
BpBuildExecAndProducedRegsts
(
TaskGraph
*
)
{
UNEXPECTED_RUN
();
}
virtual
void
BpBuildExecAndEnrollLbn2Regsts
(
TaskGraph
*
)
{
UNEXPECTED_RUN
();
}
virtual
void
FwInferShape4LbnInProducedRegsts
(
TaskGraph
*
)
{
UNEXPECTED_RUN
();
}
virtual
void
BpInferShape4LbnInProducedRegsts
(
TaskGraph
*
)
{
UNEXPECTED_RUN
();
}
private:
private:
// In task_gph level
// In task_gph level
...
...
oneflow/operator/comm_net_op.h
浏览文件 @
3589c342
...
@@ -13,7 +13,7 @@ class CommNetOp final : public SysOperator {
...
@@ -13,7 +13,7 @@ class CommNetOp final : public SysOperator {
~
CommNetOp
()
=
default
;
~
CommNetOp
()
=
default
;
void
InitFromOpConf
(
const
OperatorConf
&
op_conf
)
override
;
void
InitFromOpConf
(
const
OperatorConf
&
op_conf
)
override
;
void
InferShape4ObAndDtbFromIb
()
const
override
{
TODO
();
}
void
InferShape4ObAndDtbFromIb
()
const
override
{
UNEXPECTED_RUN
();
}
std
::
string
GetValueFromPbOpConf
(
const
std
::
string
&
k
)
const
override
;
std
::
string
GetValueFromPbOpConf
(
const
std
::
string
&
k
)
const
override
;
std
::
string
normal_ibn2lbn
(
const
std
::
string
&
input_bn
)
const
override
;
std
::
string
normal_ibn2lbn
(
const
std
::
string
&
input_bn
)
const
override
;
...
...
oneflow/operator/copy_hd_op.cpp
0 → 100644
浏览文件 @
3589c342
#include "operator/copy_hd_op.h"
#include "operator/operator_manager.h"
namespace
oneflow
{
void
CopyHdOp
::
InitFromOpConf
(
const
OperatorConf
&
op_conf
)
{
CHECK
(
op_conf
.
has_copy_hd_conf
());
mut_op_conf
()
=
op_conf
;
EnrollInputBn
(
"in"
);
EnrollOutputBn
(
"out"
);
}
std
::
string
CopyHdOp
::
GetValueFromPbOpConf
(
const
std
::
string
&
k
)
const
{
return
GetValueFromPbMessage
(
op_conf
().
copy_hd_conf
(),
k
);
}
REGISTER_OP
(
OperatorConf
::
kCopyConf
,
CopyHdOp
);
}
// namespace oneflow
oneflow/operator/copy_op.h
→
oneflow/operator/copy_
hd_
op.h
浏览文件 @
3589c342
#ifndef ONEFLOW_OPERATOR_COPY_OP_H_
#ifndef ONEFLOW_OPERATOR_COPY_
HD_
OP_H_
#define ONEFLOW_OPERATOR_COPY_OP_H_
#define ONEFLOW_OPERATOR_COPY_
HD_
OP_H_
#include "operator/operator.h"
#include "operator/operator.h"
namespace
oneflow
{
namespace
oneflow
{
class
CopyOp
final
:
public
SysOperator
{
class
Copy
Hd
Op
final
:
public
SysOperator
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
CopyOp
);
OF_DISALLOW_COPY_AND_MOVE
(
Copy
Hd
Op
);
CopyOp
()
=
default
;
Copy
Hd
Op
()
=
default
;
~
CopyOp
()
=
default
;
~
Copy
Hd
Op
()
=
default
;
void
InitFromOpConf
(
const
OperatorConf
&
op_conf
)
override
;
void
InitFromOpConf
(
const
OperatorConf
&
op_conf
)
override
;
void
InitFromOperatorProto
(
const
OperatorProto
&
operatorproto
)
override
;
OperatorProto
ToOperatorProto
()
override
;
void
InferShape4ObAndDtbFromIb
()
const
override
;
std
::
string
GetValueFromPbOpConf
(
const
std
::
string
&
k
)
const
override
;
std
::
string
GetValueFromPbOpConf
(
const
std
::
string
&
k
)
const
override
;
void
InferShape4ObAndDtbFromIb
()
const
override
{
UNEXPECTED_RUN
();
}
std
::
string
normal_ibn2lbn
(
const
std
::
string
&
input_bn
)
const
override
{
std
::
string
normal_ibn2lbn
(
const
std
::
string
&
input_bn
)
const
override
{
return
ibn2lbn_
.
at
(
input_bn
)
;
return
RegstDesc
::
kAllLbn
;
}
}
std
::
string
obn2lbn
(
const
std
::
string
&
output_bn
)
const
override
{
std
::
string
obn2lbn
(
const
std
::
string
&
output_bn
)
const
override
{
return
obn2lbn_
.
at
(
output_bn
)
;
return
RegstDesc
::
kAllLbn
;
}
}
private:
private:
HashMap
<
std
::
string
,
std
::
string
>
ibn2lbn_
;
HashMap
<
std
::
string
,
std
::
string
>
obn2lbn_
;
};
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_OPERATOR_COPY_OP_H_
#endif // ONEFLOW_OPERATOR_COPY_
HD_
OP_H_
oneflow/operator/copy_op.cpp
已删除
100644 → 0
浏览文件 @
b55817ac
#include "operator/copy_op.h"
#include "operator/operator_manager.h"
namespace
oneflow
{
void
CopyOp
::
InitFromOpConf
(
const
OperatorConf
&
op_conf
)
{
CHECK
(
op_conf
.
has_copy_conf
());
mut_op_conf
()
=
op_conf
;
for
(
int64_t
i
=
0
;
i
<
op_conf
.
copy_conf
().
copied_lbns_size
();
++
i
)
{
std
::
string
ibn
=
"in_"
+
std
::
to_string
(
i
);
EnrollInputBn
(
ibn
);
CHECK
(
ibn2lbn_
.
emplace
(
ibn
,
op_conf
.
copy_conf
().
copied_lbns
(
i
)).
second
);
std
::
string
obn
=
"out_"
+
std
::
to_string
(
i
);
EnrollOutputBn
(
obn
);
CHECK
(
obn2lbn_
.
emplace
(
obn
,
op_conf
.
copy_conf
().
copied_lbns
(
i
)).
second
);
}
}
std
::
string
CopyOp
::
GetValueFromPbOpConf
(
const
std
::
string
&
k
)
const
{
return
GetValueFromPbMessage
(
op_conf
().
copy_conf
(),
k
);
}
void
CopyOp
::
InitFromOperatorProto
(
const
OperatorProto
&
operatorproto
)
{
CHECK
(
operatorproto
.
has_copy_op
());
Operator
::
InitFromOperatorProto
(
operatorproto
);
ibn2lbn_
=
PbMap2HashMap
(
operatorproto
.
copy_op
().
ibn2lbn
());
obn2lbn_
=
PbMap2HashMap
(
operatorproto
.
copy_op
().
obn2lbn
());
}
OperatorProto
CopyOp
::
ToOperatorProto
()
{
OperatorProto
operatorproto
=
Operator
::
ToOperatorProto
();
CopyOpProto
copyopproto
;
*
(
copyopproto
.
mutable_ibn2lbn
())
=
HashMap2PbMap
(
ibn2lbn_
);
*
(
copyopproto
.
mutable_obn2lbn
())
=
HashMap2PbMap
(
obn2lbn_
);
*
(
operatorproto
.
mutable_copy_op
())
=
copyopproto
;
return
operatorproto
;
}
void
CopyOp
::
InferShape4ObAndDtbFromIb
()
const
{
CHECK_EQ
(
output_bns
().
size
(),
input_bns
().
size
());
for
(
size_t
i
=
0
;
i
<
output_bns
().
size
();
++
i
){
std
::
string
obn
=
output_bns
().
at
(
i
);
std
::
string
ibn
=
input_bns
().
at
(
i
);
*
GetShapePtr
(
obn
)
=
*
GetShapePtr
(
ibn
);
}
}
REGISTER_OP
(
OperatorConf
::
kCopyConf
,
CopyOp
);
}
// namespace oneflow
oneflow/operator/op_conf.proto
浏览文件 @
3589c342
...
@@ -170,13 +170,12 @@ message CommNetOpConf {
...
@@ -170,13 +170,12 @@ message CommNetOpConf {
CommNetType
comm_net_type
=
1
;
CommNetType
comm_net_type
=
1
;
}
}
message
CopyOpConf
{
message
Copy
Hd
OpConf
{
enum
CopyType
{
enum
Copy
Hd
Type
{
H2D
=
0
;
H2D
=
0
;
D2H
=
1
;
D2H
=
1
;
}
}
CopyType
copy_type
=
1
;
CopyHdType
type
=
1
;
repeated
string
copied_lbns
=
2
;
}
}
...
@@ -186,11 +185,14 @@ message CloneOpConf {
...
@@ -186,11 +185,14 @@ message CloneOpConf {
}
}
message
BoxConcatConf
{
message
BoxConcatConf
{
int32
axis
=
1
;
enum
ConcatType
{
kData
=
0
;
kModel
=
1
;
};
ConcatType
type
=
1
;
}
}
message
BoxSplitConf
{
message
BoxDataSplitConf
{
int32
axis
=
1
;
}
}
message
BoxCloneConf
{
message
BoxCloneConf
{
...
@@ -202,7 +204,7 @@ message BoxingOpConf {
...
@@ -202,7 +204,7 @@ message BoxingOpConf {
uint32
out_num
=
3
;
uint32
out_num
=
3
;
BoxConcatConf
concat_box
=
4
;
BoxConcatConf
concat_box
=
4
;
oneof
out_box
{
oneof
out_box
{
Box
SplitConf
split_box
=
5
;
Box
DataSplitConf
data_
split_box
=
5
;
BoxCloneConf
clone_box
=
6
;
BoxCloneConf
clone_box
=
6
;
}
}
}
}
...
@@ -226,7 +228,7 @@ message OperatorConf {
...
@@ -226,7 +228,7 @@ message OperatorConf {
ReluOpConf
relu_conf
=
104
;
ReluOpConf
relu_conf
=
104
;
SoftmaxOpConf
softmax_conf
=
105
;
SoftmaxOpConf
softmax_conf
=
105
;
MultinomialLogisticLossOpConf
multinomial_logistic_loss_conf
=
106
;
MultinomialLogisticLossOpConf
multinomial_logistic_loss_conf
=
106
;
Copy
OpConf
copy
_conf
=
107
;
Copy
HdOpConf
copy_hd
_conf
=
107
;
CloneOpConf
clone_conf
=
108
;
CloneOpConf
clone_conf
=
108
;
BoxingOpConf
boxing_conf
=
109
;
BoxingOpConf
boxing_conf
=
109
;
ModelUpdateOpConf
model_update_conf
=
110
;
ModelUpdateOpConf
model_update_conf
=
110
;
...
...
oneflow/operator/operator.proto
浏览文件 @
3589c342
syntax
=
"proto3"
;
syntax
=
"proto3"
;
package
oneflow
;
package
oneflow
;
import
"operator/op_conf.proto"
;
message
CopyOpProto
{
import
"operator/op_conf.proto"
;
map
<
string
,
string
>
ibn2lbn
=
1
;
map
<
string
,
string
>
obn2lbn
=
2
;
}
message
OperatorProto
{
message
OperatorProto
{
OperatorConf
user_conf
=
1
;
OperatorConf
user_conf
=
1
;
...
@@ -22,6 +18,5 @@ message OperatorProto {
...
@@ -22,6 +18,5 @@ message OperatorProto {
repeated
string
model_tmp_bns
=
11
;
repeated
string
model_tmp_bns
=
11
;
oneof
specified_op_proto
{
oneof
specified_op_proto
{
CopyOpProto
copy_op
=
100
;
}
}
}
}
oneflow/register/register_desc.cpp
浏览文件 @
3589c342
...
@@ -7,14 +7,22 @@ RegstDesc::RegstDesc() {
...
@@ -7,14 +7,22 @@ RegstDesc::RegstDesc() {
producer_
=
nullptr
;
producer_
=
nullptr
;
}
}
void
RegstDesc
::
CopyLbn2ShapeMap
(
const
RegstDesc
*
rhs
)
{
void
RegstDesc
::
CopyLbnFrom
(
const
RegstDesc
*
rhs
)
{
lbn2shape_
.
clear
();
for
(
const
auto
&
pair
:
rhs
->
lbn2shape_
)
{
for
(
const
auto
&
pair
:
rhs
->
lbn2shape_
)
{
const
std
::
string
&
lbn
=
pair
.
first
;
const
std
::
string
&
lbn
=
pair
.
first
;
auto
shape
=
of_make_unique
<
Shape
>
(
*
(
pair
.
second
)
);
auto
shape
=
of_make_unique
<
Shape
>
();
CHECK
(
lbn2shape_
.
emplace
(
lbn
,
std
::
move
(
shape
)).
second
);
CHECK
(
lbn2shape_
.
emplace
(
lbn
,
std
::
move
(
shape
)).
second
);
}
}
}
}
void
RegstDesc
::
CopyShapeFrom
(
const
RegstDesc
*
rhs
)
{
for
(
const
auto
&
pair
:
lbn2shape_
)
{
const
std
::
string
&
lbn
=
pair
.
first
;
*
(
lbn2shape_
.
at
(
lbn
))
=
rhs
->
GetShape
(
lbn
);
}
}
Shape
*
RegstDesc
::
EnrollLbn
(
const
std
::
string
&
lbn
)
{
Shape
*
RegstDesc
::
EnrollLbn
(
const
std
::
string
&
lbn
)
{
Shape
*
raw_ptr
=
new
Shape
;
Shape
*
raw_ptr
=
new
Shape
;
std
::
unique_ptr
<
Shape
>
uptr
(
raw_ptr
);
std
::
unique_ptr
<
Shape
>
uptr
(
raw_ptr
);
...
...
oneflow/register/register_desc.h
浏览文件 @
3589c342
...
@@ -11,11 +11,11 @@ namespace oneflow {
...
@@ -11,11 +11,11 @@ namespace oneflow {
class
TaskNode
;
class
TaskNode
;
class
RegstDesc
{
class
RegstDesc
final
{
public:
public:
OF_DISALLOW_COPY_AND_MOVE
(
RegstDesc
);
OF_DISALLOW_COPY_AND_MOVE
(
RegstDesc
);
RegstDesc
();
RegstDesc
();
virtual
~
RegstDesc
()
=
default
;
~
RegstDesc
()
=
default
;
// regst_desc_id
// regst_desc_id
uint64_t
regst_desc_id
()
const
{
return
regst_desc_id_
;
}
uint64_t
regst_desc_id
()
const
{
return
regst_desc_id_
;
}
...
@@ -25,7 +25,8 @@ class RegstDesc {
...
@@ -25,7 +25,8 @@ class RegstDesc {
void
SetProducer
(
const
TaskNode
*
task_node
)
{
producer_
=
task_node
;
}
void
SetProducer
(
const
TaskNode
*
task_node
)
{
producer_
=
task_node
;
}
// Lbn and Shape
// Lbn and Shape
void
CopyLbn2ShapeMap
(
const
RegstDesc
*
);
void
CopyLbnFrom
(
const
RegstDesc
*
);
void
CopyShapeFrom
(
const
RegstDesc
*
);
Shape
*
EnrollLbn
(
const
std
::
string
&
lbn
);
Shape
*
EnrollLbn
(
const
std
::
string
&
lbn
);
const
Shape
&
GetShape
(
const
std
::
string
&
lbn
);
const
Shape
&
GetShape
(
const
std
::
string
&
lbn
);
Shape
*
GetMutShapePtr
(
const
std
::
string
&
lbn
);
Shape
*
GetMutShapePtr
(
const
std
::
string
&
lbn
);
...
@@ -40,24 +41,6 @@ class RegstDesc {
...
@@ -40,24 +41,6 @@ class RegstDesc {
};
};
class
ContigRegstDesc
final
:
public
RegstDesc
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
ContigRegstDesc
);
ContigRegstDesc
()
=
default
;
~
ContigRegstDesc
()
=
default
;
private:
};
class
DisContigRegstDesc
final
:
public
RegstDesc
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
DisContigRegstDesc
);
DisContigRegstDesc
()
=
default
;
~
DisContigRegstDesc
()
=
default
;
};
}
// namespace oneflow
}
// namespace oneflow
#endif // ONEFLOW_REGISTER_REGISTER_DESC_H_
#endif // ONEFLOW_REGISTER_REGISTER_DESC_H_
oneflow/register/register_desc_manager.h
0 → 100644
浏览文件 @
3589c342
#ifndef ONEFLOW_REGISTER_REGISTER_DESC_MANAGER_H_
#define ONEFLOW_REGISTER_REGISTER_DESC_MANAGER_H_
#include "register/register_desc.h"
namespace
oneflow
{
class
RegstDescMgr
final
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
RegstDescMgr
);
RegstDescMgr
()
=
default
;
~
RegstDescMgr
()
=
default
;
std
::
unique_ptr
<
RegstDesc
>
CreateRegisterDesc
()
{
return
of_make_unique
<
RegstDesc
>
();
}
private:
};
}
// namespace oneflow
#endif // ONEFLOW_REGISTER_REGISTER_DESC_MANAGER_H_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录