Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
cf84a6e8
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
cf84a6e8
编写于
9月 19, 2018
作者:
L
Li Xinqi
提交者:
Niu Chong
9月 19, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
no out_diff then no backward node (#1250)
上级
31693ec1
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
38 addition
and
1 deletion
+38
-1
oneflow/core/graph/logical_graph.cpp
oneflow/core/graph/logical_graph.cpp
+38
-1
未找到文件。
oneflow/core/graph/logical_graph.cpp
浏览文件 @
cf84a6e8
...
...
@@ -7,6 +7,41 @@
namespace
oneflow
{
namespace
{
std
::
function
<
bool
(
const
LogicalNode
*
)
>
MakePredicatorHasActualOutDiff
(
const
LogicalGraph
*
graph
)
{
std
::
list
<
LogicalNode
*>
loss_nodes
;
graph
->
ForEachNode
([
&
](
LogicalNode
*
node
)
{
if
(
dynamic_cast
<
LossLogicalNode
*>
(
node
))
{
loss_nodes
.
push_back
(
node
);
}
});
auto
nodes_have_actual_out_diff_ptr
=
std
::
make_shared
<
HashSet
<
const
LogicalNode
*>>
();
auto
HasBwConnection
=
[](
const
LogicalNode
*
prev
,
const
LogicalNode
*
next
)
{
HashSet
<
LogicalBlobId
>
idbn_lbis
;
for
(
const
auto
&
idbn
:
next
->
SoleOp
()
->
input_diff_bns
())
{
idbn_lbis
.
insert
(
next
->
SoleOp
()
->
BnInOp2Lbi
(
idbn
));
}
for
(
const
auto
&
odbn
:
prev
->
SoleOp
()
->
output_diff_bns
())
{
LogicalBlobId
lbi
=
prev
->
SoleOp
()
->
BnInOp2Lbi
(
odbn
);
if
(
idbn_lbis
.
find
(
lbi
)
!=
idbn_lbis
.
end
())
{
return
true
;
}
}
return
false
;
};
auto
ForEachNext
=
[
&
](
LogicalNode
*
node
,
const
std
::
function
<
void
(
LogicalNode
*
)
>&
Handler
)
{
node
->
ForEachNodeOnInEdge
([
&
](
LogicalNode
*
in_node
)
{
if
(
HasBwConnection
(
in_node
,
node
))
{
Handler
(
in_node
);
}
});
};
graph
->
BfsForEachNode
(
loss_nodes
,
ForEachNext
,
[
nodes_have_actual_out_diff_ptr
](
LogicalNode
*
node
)
{
nodes_have_actual_out_diff_ptr
->
insert
(
node
);
});
return
[
nodes_have_actual_out_diff_ptr
](
const
LogicalNode
*
node
)
{
return
nodes_have_actual_out_diff_ptr
->
find
(
node
)
!=
nodes_have_actual_out_diff_ptr
->
end
();
};
}
}
// namespace
LogicalGraph
::
LogicalGraph
(
bool
is_train
)
{
BuildFwStruct
();
if
(
is_train
)
{
GroupNodesForReduceStruct
();
}
...
...
@@ -166,6 +201,7 @@ void LogicalGraph::BuildBwStruct() {
}
void
LogicalGraph
::
NaiveBuildBwStruct
()
{
auto
HasActualOutDiff
=
MakePredicatorHasActualOutDiff
(
this
);
HashSet
<
LogicalNode
*>
nodes_need_bw
;
TopoForEachNode
([
&
](
LogicalNode
*
logical_node
)
{
auto
fw_node
=
dynamic_cast
<
ForwardLogicalNode
*>
(
logical_node
);
...
...
@@ -175,7 +211,8 @@ void LogicalGraph::NaiveBuildBwStruct() {
return
;
}
for
(
LogicalEdge
*
edge
:
fw_node
->
in_edges
())
{
if
(
nodes_need_bw
.
find
(
edge
->
src_node
())
!=
nodes_need_bw
.
end
())
{
if
(
nodes_need_bw
.
find
(
edge
->
src_node
())
!=
nodes_need_bw
.
end
()
&&
HasActualOutDiff
(
fw_node
))
{
CHECK
(
nodes_need_bw
.
insert
(
fw_node
).
second
);
return
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录