Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
6a139c48
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6a139c48
编写于
8月 30, 2018
作者:
J
Jinhui Yuan
提交者:
GitHub
8月 30, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rm duplicate ReduceTaskNodes caused by ReduceConcat&Split (#1179)
Former-commit-id:
40c299bc
上级
0252bca8
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
29 addition
and
18 deletion
+29
-18
oneflow/core/graph/task_graph.cpp
oneflow/core/graph/task_graph.cpp
+8
-16
oneflow/core/graph/task_graph.h
oneflow/core/graph/task_graph.h
+21
-2
未找到文件。
oneflow/core/graph/task_graph.cpp
浏览文件 @
6a139c48
...
...
@@ -133,26 +133,17 @@ void TaskGraph::BuildCtrlRegstDescInSameChain() {
}
}
struct
ReduceTaskNodes
{
CompTaskNode
*
concat
=
nullptr
;
CompTaskNode
*
scatter
=
nullptr
;
CompTaskNode
*
local_add
=
nullptr
;
CompTaskNode
*
global_add
=
nullptr
;
CompTaskNode
*
gather
=
nullptr
;
CompTaskNode
*
split
=
nullptr
;
};
void
TaskGraph
::
EnableMemSharingInReduceStruct
()
{
HashMap
<
CompTaskNode
*
,
ReduceTaskNodes
>
bw2
reduce_tasks
;
CollectReduceTaskNodes
(
&
bw2
reduce_tasks
);
for
(
auto
&
pair
:
bw2
reduce_tasks
)
{
EnableMemSharingInOneReduce
(
pair
.
second
);
AddCtrlEdge4MemSharingInOneReduce
(
pair
.
second
);
std
::
unordered_set
<
ReduceTaskNodes
,
ReduceTaskNodesHasher
>
reduce_tasks
;
CollectReduceTaskNodes
(
&
reduce_tasks
);
for
(
auto
&
reduce_task
:
reduce_tasks
)
{
EnableMemSharingInOneReduce
(
reduce_task
);
AddCtrlEdge4MemSharingInOneReduce
(
reduce_task
);
}
}
void
TaskGraph
::
CollectReduceTaskNodes
(
HashMap
<
CompTaskNode
*
,
ReduceTaskNodes
>*
bw2
reduce_tasks
)
const
{
std
::
unordered_set
<
ReduceTaskNodes
,
ReduceTaskNodesHasher
>*
reduce_tasks
)
const
{
auto
FindSuccReduceTaskNode
=
[](
CompTaskNode
*
task_node
,
TaskType
type
)
->
CompTaskNode
*
{
for
(
TaskEdge
*
out_edge
:
task_node
->
out_edges
())
{
TaskNode
*
dst_node
=
out_edge
->
dst_node
();
...
...
@@ -186,7 +177,7 @@ void TaskGraph::CollectReduceTaskNodes(
return
;
}
ReduceTaskNodes
&
reduce_task_nodes
=
(
*
bw2reduce_tasks
)[
bw_task_node
]
;
ReduceTaskNodes
reduce_task_nodes
;
CompTaskNode
*
diff_acc_task_node
=
FindSuccReduceTaskNode
(
bw_task_node
,
TaskType
::
kMdDiffAcc
);
if
(
diff_acc_task_node
!=
nullptr
)
{
FindConcatAndScatter
(
diff_acc_task_node
,
&
reduce_task_nodes
);
...
...
@@ -212,6 +203,7 @@ void TaskGraph::CollectReduceTaskNodes(
CHECK
(
reduce_task_nodes
.
scatter
!=
nullptr
);
CHECK
(
reduce_task_nodes
.
global_add
!=
nullptr
);
CHECK
(
reduce_task_nodes
.
gather
!=
nullptr
);
reduce_tasks
->
insert
(
reduce_task_nodes
);
});
}
...
...
oneflow/core/graph/task_graph.h
浏览文件 @
6a139c48
...
...
@@ -9,7 +9,26 @@
namespace
oneflow
{
class
ReduceTaskNodes
;
struct
ReduceTaskNodes
{
CompTaskNode
*
concat
=
nullptr
;
CompTaskNode
*
scatter
=
nullptr
;
CompTaskNode
*
local_add
=
nullptr
;
CompTaskNode
*
global_add
=
nullptr
;
CompTaskNode
*
gather
=
nullptr
;
CompTaskNode
*
split
=
nullptr
;
bool
operator
==
(
const
ReduceTaskNodes
&
rhs
)
const
{
return
this
->
concat
==
rhs
.
concat
&&
this
->
scatter
==
rhs
.
scatter
&&
this
->
local_add
==
rhs
.
local_add
&&
this
->
global_add
==
rhs
.
global_add
&&
this
->
gather
==
rhs
.
gather
&&
this
->
split
==
rhs
.
split
;
}
};
struct
ReduceTaskNodesHasher
{
std
::
size_t
operator
()(
const
ReduceTaskNodes
&
key
)
const
{
return
(
size_t
)(
key
.
concat
)
^
(
size_t
)(
key
.
scatter
)
^
(
size_t
)(
key
.
local_add
)
^
(
size_t
)(
key
.
global_add
)
^
(
size_t
)(
key
.
gather
)
^
(
size_t
)(
key
.
split
);
}
};
class
TaskGraph
final
:
public
Graph
<
TaskNode
,
TaskEdge
>
{
public:
...
...
@@ -24,7 +43,7 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
void
AddOrderingCtrlEdgeInSameChain
();
void
EnableMemSharingInReduceStruct
();
void
CollectReduceTaskNodes
(
HashMap
<
CompTaskNode
*
,
ReduceTaskNodes
>*
)
const
;
void
CollectReduceTaskNodes
(
std
::
unordered_set
<
ReduceTaskNodes
,
ReduceTaskNodesHasher
>*
)
const
;
void
EnableMemSharingInOneReduce
(
const
ReduceTaskNodes
&
);
void
AddCtrlEdge4MemSharingInOneReduce
(
const
ReduceTaskNodes
&
);
void
BuildCtrlRegstBetweenReduceCopyNodes
(
const
CompTaskNode
*
src_reduce
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录