Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
a3dd5a34
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
a3dd5a34
编写于
1月 19, 2019
作者:
L
Li Xinqi
提交者:
GitHub
1月 19, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
faster improver (#1628)
Former-commit-id: 2550030088fb6b15f1784a2bd1cfb78eeabe3b0d
上级
88a04c86
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
31 addition
and
25 deletion
+31
-25
oneflow/core/graph/regst_lifetime_graph.cpp
oneflow/core/graph/regst_lifetime_graph.cpp
+7
-7
oneflow/core/graph/regst_lifetime_graph.h
oneflow/core/graph/regst_lifetime_graph.h
+5
-5
oneflow/core/graph/sharable_mem_block_graph.cpp
oneflow/core/graph/sharable_mem_block_graph.cpp
+3
-3
oneflow/core/job/improver.cpp
oneflow/core/job/improver.cpp
+16
-10
未找到文件。
oneflow/core/graph/regst_lifetime_graph.cpp
浏览文件 @
a3dd5a34
...
...
@@ -3,17 +3,17 @@
namespace
oneflow
{
RegstLifetimeGraph
::
RegstLifetimeGraph
(
const
std
::
list
<
const
RegstDescProto
*>&
regst_descs
,
const
std
::
vector
<
const
RegstDescProto
*>&
regst_descs
,
const
std
::
function
<
void
(
const
RegstDescProto
*
,
HashSet
<
int64_t
>*
)
>&
ComputeLifetimeActorIds
)
{
std
::
list
<
RegstLifetimeNode
*>
nodes
;
std
::
vector
<
RegstLifetimeNode
*>
nodes
;
InitNodes
(
regst_descs
,
ComputeLifetimeActorIds
,
&
nodes
);
InitEdges
(
nodes
);
}
void
RegstLifetimeGraph
::
InitNodes
(
const
std
::
list
<
const
RegstDescProto
*>&
regst_descs
,
const
std
::
vector
<
const
RegstDescProto
*>&
regst_descs
,
const
std
::
function
<
void
(
const
RegstDescProto
*
,
HashSet
<
int64_t
>*
)
>&
ComputeLifetimeActorIds
,
std
::
list
<
RegstLifetimeNode
*>*
nodes
)
{
std
::
vector
<
RegstLifetimeNode
*>*
nodes
)
{
for
(
const
RegstDescProto
*
regst_desc
:
regst_descs
)
{
auto
lifetime_actor_ids
=
std
::
make_unique
<
HashSet
<
int64_t
>>
();
ComputeLifetimeActorIds
(
regst_desc
,
lifetime_actor_ids
.
get
());
...
...
@@ -23,7 +23,7 @@ void RegstLifetimeGraph::InitNodes(
}
}
void
RegstLifetimeGraph
::
InitEdges
(
const
std
::
list
<
RegstLifetimeNode
*>&
nodes
)
{
void
RegstLifetimeGraph
::
InitEdges
(
const
std
::
vector
<
RegstLifetimeNode
*>&
nodes
)
{
HashMap
<
int64_t
,
HashSet
<
RegstLifetimeNode
*>>
task_id2intersected_nodes
;
for
(
RegstLifetimeNode
*
node
:
nodes
)
{
for
(
int64_t
task_id
:
node
->
lifetime_actor_ids
())
{
...
...
@@ -46,7 +46,7 @@ void RegstLifetimeGraph::InitEdges(const std::list<RegstLifetimeNode*>& nodes) {
}
void
RegstLifetimeGraph
::
ForEachSameColoredRegstDescs
(
const
std
::
function
<
void
(
const
std
::
list
<
const
RegstDescProto
*>&
)
>&
Handler
)
const
{
const
std
::
function
<
void
(
const
std
::
vector
<
const
RegstDescProto
*>&
)
>&
Handler
)
const
{
std
::
vector
<
const
RegstLifetimeNode
*>
nodes
;
ForEachNode
([
&
](
const
RegstLifetimeNode
*
node
)
{
nodes
.
push_back
(
node
);
});
std
::
sort
(
nodes
.
begin
(),
nodes
.
end
(),
...
...
@@ -65,7 +65,7 @@ void RegstLifetimeGraph::ForEachSameColoredRegstDescs(
node2excluded_color_ids
[
intersected
].
insert
(
color_id
);
});
}
HashMap
<
int32_t
,
std
::
list
<
const
RegstDescProto
*>>
color_id2regst_descs
;
HashMap
<
int32_t
,
std
::
vector
<
const
RegstDescProto
*>>
color_id2regst_descs
;
for
(
const
auto
&
pair
:
node2color_id
)
{
color_id2regst_descs
[
pair
.
second
].
push_back
(
&
pair
.
first
->
regst_desc
());
}
...
...
oneflow/core/graph/regst_lifetime_graph.h
浏览文件 @
a3dd5a34
...
...
@@ -43,19 +43,19 @@ class RegstLifetimeGraph final : public Graph<const RegstLifetimeNode, RegstLife
public:
OF_DISALLOW_COPY_AND_MOVE
(
RegstLifetimeGraph
);
RegstLifetimeGraph
(
const
std
::
list
<
const
RegstDescProto
*>&
regst_descs
,
const
std
::
vector
<
const
RegstDescProto
*>&
regst_descs
,
const
std
::
function
<
void
(
const
RegstDescProto
*
,
HashSet
<
int64_t
>*
)
>&
ComputeLifetimeActorIds
);
~
RegstLifetimeGraph
()
=
default
;
void
ForEachSameColoredRegstDescs
(
const
std
::
function
<
void
(
const
std
::
list
<
const
RegstDescProto
*>&
)
>&
Handler
)
const
;
const
std
::
function
<
void
(
const
std
::
vector
<
const
RegstDescProto
*>&
)
>&
Handler
)
const
;
private:
void
InitNodes
(
const
std
::
list
<
const
RegstDescProto
*>&
regst_descs
,
const
std
::
vector
<
const
RegstDescProto
*>&
regst_descs
,
const
std
::
function
<
void
(
const
RegstDescProto
*
,
HashSet
<
int64_t
>*
)
>&
ComputeLifetimeActorIds
,
std
::
list
<
RegstLifetimeNode
*>*
nodes
);
void
InitEdges
(
const
std
::
list
<
RegstLifetimeNode
*>&
nodes
);
std
::
vector
<
RegstLifetimeNode
*>*
nodes
);
void
InitEdges
(
const
std
::
vector
<
RegstLifetimeNode
*>&
nodes
);
};
}
// namespace oneflow
...
...
oneflow/core/graph/sharable_mem_block_graph.cpp
浏览文件 @
a3dd5a34
...
...
@@ -69,11 +69,11 @@ SharableMemBlockGraph::SharableMemBlockGraph(
void
SharableMemBlockGraph
::
ForEachSourceNodeGroup
(
const
std
::
function
<
int64_t
(
const
SharableMemBlockNode
*
)
>&
GroupBy
,
const
std
::
function
<
void
(
const
std
::
vector
<
const
SharableMemBlockNode
*>&
)
>&
Handler
)
const
{
HashMap
<
int64_t
,
std
::
vector
<
const
SharableMemBlockNode
*>>
chain_id
2source_nodes
;
HashMap
<
int64_t
,
std
::
vector
<
const
SharableMemBlockNode
*>>
group_key
2source_nodes
;
for
(
const
SharableMemBlockNode
*
source
:
source_nodes
())
{
chain_id
2source_nodes
[
GroupBy
(
source
)].
push_back
(
source
);
group_key
2source_nodes
[
GroupBy
(
source
)].
push_back
(
source
);
}
for
(
const
auto
&
pair
:
chain_id
2source_nodes
)
{
Handler
(
pair
.
second
);
}
for
(
const
auto
&
pair
:
group_key
2source_nodes
)
{
Handler
(
pair
.
second
);
}
}
}
// namespace oneflow
oneflow/core/job/improver.cpp
浏览文件 @
a3dd5a34
...
...
@@ -29,8 +29,9 @@ bool IsConsumersAndProducerInSameChain(const RegstDescProto& regst_desc,
}
void
ForEachSharableStreamRegstDescsWithoutConsumer
(
const
Plan
&
plan
,
const
std
::
function
<
void
(
const
std
::
list
<
const
RegstDescProto
*>&
)
>&
Handler
)
{
HashMap
<
int64_t
,
std
::
list
<
const
RegstDescProto
*>>
global_work_stream_id2regst_descs
;
const
Plan
&
plan
,
const
std
::
function
<
void
(
const
std
::
vector
<
const
RegstDescProto
*>&
)
>&
Handler
)
{
HashMap
<
int64_t
,
std
::
vector
<
const
RegstDescProto
*>>
global_work_stream_id2regst_descs
;
for
(
const
auto
&
task
:
plan
.
task
())
{
int64_t
global_work_stream_id
=
Global
<
IDMgr
>::
Get
()
->
GlobalWorkStreamId4TaskId
(
task
.
task_id
());
for
(
const
auto
&
pair
:
task
.
produced_regst_desc
())
{
...
...
@@ -45,20 +46,21 @@ void ForEachSharableStreamRegstDescsWithoutConsumer(
}
void
ForEachSameColoredStreamRegstDescWithoutConsumer
(
const
Plan
&
plan
,
const
std
::
function
<
void
(
const
std
::
list
<
const
RegstDescProto
*>&
)
>&
Handler
)
{
const
Plan
&
plan
,
const
std
::
function
<
void
(
const
std
::
vector
<
const
RegstDescProto
*>&
)
>&
Handler
)
{
auto
GetProducerTaskId
=
[](
const
RegstDescProto
*
regst_desc
,
HashSet
<
int64_t
>*
ret_actor_ids
)
{
CHECK
(
regst_desc
->
enable_mem_sharing
());
ret_actor_ids
->
insert
(
regst_desc
->
producer_task_id
());
};
ForEachSharableStreamRegstDescsWithoutConsumer
(
plan
,
[
&
](
const
std
::
list
<
const
RegstDescProto
*>&
regst_descs
)
{
plan
,
[
&
](
const
std
::
vector
<
const
RegstDescProto
*>&
regst_descs
)
{
RegstLifetimeGraph
(
regst_descs
,
GetProducerTaskId
).
ForEachSameColoredRegstDescs
(
Handler
);
});
}
void
ForEachSameColoredChainRegstDescWithConsumer
(
const
PlanTaskGraph
&
plan_task_graph
,
const
std
::
function
<
void
(
const
std
::
list
<
const
RegstDescProto
*>&
)
>&
Handler
)
{
const
std
::
function
<
void
(
const
std
::
vector
<
const
RegstDescProto
*>&
)
>&
Handler
)
{
// construct SharableMemBlockGraph
auto
ChainId4TaskId
=
[
&
](
int64_t
task_id
)
{
return
plan_task_graph
.
TaskProto4TaskId
(
task_id
)
->
task_set_info
().
chain_id
();
...
...
@@ -92,7 +94,7 @@ void ForEachSameColoredChainRegstDescWithConsumer(
header2members
.
emplace
(
regst_descs
.
at
(
0
),
regst_descs
);
}
auto
GetRegstDescs
=
[
&
](
const
std
::
vector
<
const
SharableMemBlockNode
*>&
sharable_mem_blocks
)
{
std
::
list
<
const
RegstDescProto
*>
ret
;
std
::
vector
<
const
RegstDescProto
*>
ret
;
for
(
const
SharableMemBlockNode
*
sharable_mem_block
:
sharable_mem_blocks
)
{
for
(
const
RegstDescProto
*
regst_desc
:
sharable_mem_block
->
regst_descs
())
{
if
(
header2members
.
find
(
regst_desc
)
!=
header2members
.
end
())
{
...
...
@@ -111,8 +113,8 @@ void ForEachSameColoredChainRegstDescWithConsumer(
plan_task_graph
.
ComputeLifetimeSameChainActorIds
(
member
,
ret_actor_ids
);
}
};
auto
AppendGroupMembers
=
[
&
](
const
std
::
list
<
const
RegstDescProto
*>&
regst_descs
)
{
std
::
list
<
const
RegstDescProto
*>
members
;
auto
AppendGroupMembers
=
[
&
](
const
std
::
vector
<
const
RegstDescProto
*>&
regst_descs
)
{
std
::
vector
<
const
RegstDescProto
*>
members
;
for
(
const
auto
*
header
:
regst_descs
)
{
for
(
const
auto
*
member
:
header2members
.
at
(
header
))
{
members
.
push_back
(
member
);
}
}
...
...
@@ -121,6 +123,11 @@ void ForEachSameColoredChainRegstDescWithConsumer(
sharable_mem_block_gph
.
ForEachSourceNodeGroup
(
&
SharableMemBlockNode
::
chain_id
,
[
&
](
const
std
::
vector
<
const
SharableMemBlockNode
*>&
sharable_mem_blocks
)
{
if
(
sharable_mem_blocks
.
size
()
==
1
)
{
const
auto
&
regst_descs
=
sharable_mem_blocks
.
at
(
0
)
->
regst_descs
();
if
(
regst_descs
.
size
()
>
1
)
{
Handler
(
regst_descs
);
}
return
;
}
RegstLifetimeGraph
(
GetRegstDescs
(
sharable_mem_blocks
),
ComputeLifetimeSameChainActorIds
)
.
ForEachSameColoredRegstDescs
(
AppendGroupMembers
);
});
...
...
@@ -128,9 +135,8 @@ void ForEachSameColoredChainRegstDescWithConsumer(
void
ForEachImprovedMemSharedId
(
const
PlanTaskGraph
&
plan_task_graph
,
const
std
::
function
<
void
(
int64_t
,
int64_t
)
>&
Handler
)
{
using
RegstDescs
=
std
::
list
<
const
RegstDescProto
*>
;
const
Plan
&
plan
=
plan_task_graph
.
plan
();
auto
HandleMemSharedId
=
[
&
](
const
RegstDescs
&
regst_descs
)
{
auto
HandleMemSharedId
=
[
&
](
const
std
::
vector
<
const
RegstDescProto
*>
&
regst_descs
)
{
int64_t
mem_shared_id
=
Global
<
IDMgr
>::
Get
()
->
NewMemSharedId
();
for
(
const
RegstDescProto
*
regst_desc
:
regst_descs
)
{
Handler
(
regst_desc
->
regst_desc_id
(),
mem_shared_id
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录