Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
8c895ee9
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 接近 3 年
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
8c895ee9
编写于
8月 02, 2018
作者:
C
cheng cheng
提交者:
Jinhui Yuan
8月 02, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rm IsBwClone (#1078)
上级
8d2daef3
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
17 addition
and
42 deletion
+17
-42
oneflow/core/graph/logical_graph.cpp
oneflow/core/graph/logical_graph.cpp
+0
-4
oneflow/core/graph/logical_graph.h
oneflow/core/graph/logical_graph.h
+0
-1
oneflow/core/graph/normal_backward_compute_task_node.cpp
oneflow/core/graph/normal_backward_compute_task_node.cpp
+17
-36
oneflow/core/graph/normal_backward_compute_task_node.h
oneflow/core/graph/normal_backward_compute_task_node.h
+0
-1
未找到文件。
oneflow/core/graph/logical_graph.cpp
浏览文件 @
8c895ee9
...
...
@@ -295,7 +295,6 @@ void LogicalGraph::AddOneBackwardClone(const BackwardCloneInfo& clone_info) {
LogicalNode
*
clone_node
=
NewNode
<
NormalBackwardLogicalNode
>
();
clone_node
->
mut_op_vec
()
=
{
clone_op
};
clone_node
->
mut_parallel_desc
()
=
clone_info
.
succ_node
->
parallel_desc
();
CHECK
(
bw_clone2fw_producer_
.
emplace
(
clone_node
,
nullptr
).
second
);
*
(
clone_op
->
MutBnInOp2Lbi
(
clone_op
->
SoleIbn
()))
=
clone_info
.
lbi
;
*
(
clone_op
->
MutBnInOp2Lbi
(
clone_op
->
SoleIdbn
()))
=
clone_info
.
lbi
;
...
...
@@ -567,9 +566,6 @@ void LogicalGraph::ConnectFwToBw() {
if
(
bw_node
->
fw_node
()
==
nullptr
)
{
return
;
}
Connect
<
LogicalNode
>
(
bw_node
->
fw_node
(),
NewEdge
(),
bw_node
);
});
for
(
auto
&
pair
:
bw_clone2fw_producer_
)
{
if
(
pair
.
second
)
{
Connect
<
LogicalNode
>
(
pair
.
second
,
NewEdge
(),
pair
.
first
);
}
}
}
void
LogicalGraph
::
UpdateEdge2Ibn
(
const
LogicalEdge
*
edge
,
const
std
::
string
&
ibn
)
{
...
...
oneflow/core/graph/logical_graph.h
浏览文件 @
8c895ee9
...
...
@@ -69,7 +69,6 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
HashMap
<
const
LogicalEdge
*
,
std
::
string
>
edge2ibn_
;
HashMap
<
const
LogicalEdge
*
,
std
::
string
>
edge2obn_
;
HashMap
<
LogicalNode
*
,
LogicalNode
*>
bw_clone2fw_producer_
;
};
}
// namespace oneflow
...
...
oneflow/core/graph/normal_backward_compute_task_node.cpp
浏览文件 @
8c895ee9
...
...
@@ -37,7 +37,7 @@ void NormalBackwardCompTaskNode::ConsumeAllRegsts() {
}
}
CompTaskNode
*
fw_task
=
GetRelatedFwTaskNode
();
if
(
fw_task
&&
!
IsBwClone
()
)
{
if
(
fw_task
)
{
const
std
::
list
<
std
::
weak_ptr
<
RegstDesc
>>&
in_regst
=
fw_task
->
GetConsumedRegst
(
"in"
);
for
(
std
::
weak_ptr
<
RegstDesc
>
regst
:
in_regst
)
{
ConsumeRegst
(
"in"
,
regst
.
lock
());
}
}
...
...
@@ -78,7 +78,7 @@ void NormalBackwardCompTaskNode::BuildExecGphAndBindOutDiffRegst() {
}
});
CompTaskNode
*
fw_task
=
GetRelatedFwTaskNode
();
if
(
fw_task
&&
!
IsBwClone
()
)
{
if
(
fw_task
)
{
const
HashSet
<
LogicalBlobId
>&
lbi_boxing
=
fw_task
->
logical_node
()
->
lbi_boxing
();
const
HashSet
<
LogicalBlobId
>&
lbi_121
=
fw_task
->
logical_node
()
->
lbi_121
();
std
::
shared_ptr
<
RegstDesc
>
out_regst_boxing
=
GetSoleConsumedRegst
(
"boxing_out"
);
...
...
@@ -106,7 +106,6 @@ void NormalBackwardCompTaskNode::BuildExecGphAndBindOutDiffRegst() {
void
NormalBackwardCompTaskNode
::
LinkFwExecNode
()
{
CompTaskNode
*
fw_task
=
GetRelatedFwTaskNode
();
if
(
fw_task
==
nullptr
)
{
return
;
}
if
(
IsBwClone
())
{
return
;
}
HashMap
<
std
::
string
,
ExecNode
*>
op_name2fw_exec
;
fw_task
->
exec_gph
().
ForEachNode
([
&
](
ExecNode
*
fw_exec
)
{
CHECK
(
op_name2fw_exec
.
emplace
(
fw_exec
->
op
()
->
op_name
(),
fw_exec
).
second
);
...
...
@@ -139,14 +138,7 @@ void NormalBackwardCompTaskNode::BuildInDiffRegst() {
const
LogicalBlobId
&
lbi
=
cur_node
->
op
()
->
BnInOp2Lbi
(
idbn
);
CompTaskNode
*
fw_task
=
GetRelatedFwTaskNode
();
if
(
fw_task
)
{
if
(
IsBwClone
())
{
std
::
list
<
std
::
weak_ptr
<
RegstDesc
>>
out_regsts
;
out_regsts
.
push_back
(
GetSoleConsumedRegst
(
"boxing_out"
));
out_regsts
.
push_back
(
GetSoleConsumedRegst
(
"121_out"
));
cur_node
->
BindBnWithOneOfTheRegsts
(
GenUnDiffBn
(
idbn
),
out_regsts
);
}
else
{
cur_node
->
BindBnWithOneOfTheRegsts
(
GenUnDiffBn
(
idbn
),
GetConsumedRegst
(
"in"
));
}
cur_node
->
BindBnWithOneOfTheRegsts
(
GenUnDiffBn
(
idbn
),
GetConsumedRegst
(
"in"
));
}
if
(
TryAddLbiToB121RegstAndBindIt
(
cur_node
,
idbn
,
"in_diff"
)
==
false
)
{
CHECK
(
found_lbis
.
empty
()
||
found_lbis
.
find
(
lbi
)
!=
found_lbis
.
end
());
...
...
@@ -174,27 +166,23 @@ void NormalBackwardCompTaskNode::BindModelDiffRegst() {
void
NormalBackwardCompTaskNode
::
InferBlobDescsInProducedRegsts
()
{
if
(
GetRelatedFwTaskNode
())
{
if
(
IsBwClone
())
{
mut_exec_gph
().
SoleNode
()
->
InferDiffBlobDescsWithoutFwNode
(
parallel_ctx
());
}
else
{
std
::
shared_ptr
<
RegstDesc
>
in_diff_regst_boxing
=
GetProducedRegst
(
"boxing_in_diff"
);
for
(
std
::
weak_ptr
<
RegstDesc
>
regst
:
GetConsumedRegst
(
"in"
))
{
in_diff_regst_boxing
->
CopyBlobDescWithoutAddLbi
(
regst
.
lock
().
get
());
}
std
::
shared_ptr
<
RegstDesc
>
in_diff_regst_boxing
=
GetProducedRegst
(
"boxing_in_diff"
);
for
(
std
::
weak_ptr
<
RegstDesc
>
regst
:
GetConsumedRegst
(
"in"
))
{
in_diff_regst_boxing
->
CopyBlobDescWithoutAddLbi
(
regst
.
lock
().
get
());
}
std
::
shared_ptr
<
RegstDesc
>
in_diff_regst_121
=
GetProducedRegst
(
"121_in_diff"
);
for
(
std
::
weak_ptr
<
RegstDesc
>
regst
:
GetConsumedRegst
(
"in"
))
{
in_diff_regst_121
->
CopyBlobDescWithoutAddLbi
(
regst
.
lock
().
get
());
}
std
::
shared_ptr
<
RegstDesc
>
in_diff_regst_121
=
GetProducedRegst
(
"121_in_diff"
);
for
(
std
::
weak_ptr
<
RegstDesc
>
regst
:
GetConsumedRegst
(
"in"
))
{
in_diff_regst_121
->
CopyBlobDescWithoutAddLbi
(
regst
.
lock
().
get
());
}
std
::
shared_ptr
<
RegstDesc
>
md_diff_regst
=
GetProducedRegst
(
"model_diff"
);
if
(
md_diff_regst
)
{
md_diff_regst
->
CopyBlobDescFrom
(
GetSoleConsumedRegst
(
"model"
).
get
());
}
std
::
shared_ptr
<
RegstDesc
>
md_diff_regst
=
GetProducedRegst
(
"model_diff"
);
if
(
md_diff_regst
)
{
md_diff_regst
->
CopyBlobDescFrom
(
GetSoleConsumedRegst
(
"model"
).
get
());
}
std
::
shared_ptr
<
RegstDesc
>
activation_diff_regst
=
GetProducedRegst
(
"activation_diff"
);
activation_diff_regst
->
CopyBlobDescWithoutAddLbi
(
GetSoleConsumedRegst
(
"activation"
).
get
());
activation_diff_regst
->
CopyBlobDescWithoutAddLbi
(
GetSoleConsumedRegst
(
"boxing_out"
).
get
());
activation_diff_regst
->
CopyBlobDescWithoutAddLbi
(
GetSoleConsumedRegst
(
"121_out"
).
get
());
}
std
::
shared_ptr
<
RegstDesc
>
activation_diff_regst
=
GetProducedRegst
(
"activation_diff"
);
activation_diff_regst
->
CopyBlobDescWithoutAddLbi
(
GetSoleConsumedRegst
(
"activation"
).
get
());
activation_diff_regst
->
CopyBlobDescWithoutAddLbi
(
GetSoleConsumedRegst
(
"boxing_out"
).
get
());
activation_diff_regst
->
CopyBlobDescWithoutAddLbi
(
GetSoleConsumedRegst
(
"121_out"
).
get
());
}
else
{
mut_exec_gph
().
SoleNode
()
->
InferDiffBlobDescsWithoutFwNode
(
parallel_ctx
());
}
...
...
@@ -210,11 +198,4 @@ CompTaskNode* NormalBackwardCompTaskNode::GetRelatedFwTaskNode() {
return
nullptr
;
}
bool
NormalBackwardCompTaskNode
::
IsBwClone
()
const
{
const
BackwardLogicalNode
*
bw_logical_node
=
dynamic_cast
<
const
BackwardLogicalNode
*>
(
logical_node
());
CHECK_NOTNULL
(
bw_logical_node
);
return
bw_logical_node
->
fw_node
()
==
nullptr
;
}
}
// namespace oneflow
oneflow/core/graph/normal_backward_compute_task_node.h
浏览文件 @
8c895ee9
...
...
@@ -15,7 +15,6 @@ class NormalBackwardCompTaskNode final : public CompTaskNode {
void
ConsumeAllRegsts
()
override
;
void
BuildExecGphAndRegst
()
override
;
TaskType
GetTaskType
()
const
override
{
return
TaskType
::
kNormalBackward
;
}
bool
IsBwClone
()
const
;
protected:
void
BuildExecGphAndBindOutDiffRegst
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录