Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
76f4f975
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
76f4f975
编写于
5月 24, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(sublinear): add SeqModifierBase
GitOrigin-RevId: 2d0393be6b950690c5960ac63bd47931a3afb324
上级
f584416a
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
438 addition
and
323 deletion
+438
-323
src/core/impl/graph/seq_modifier_base.cpp
src/core/impl/graph/seq_modifier_base.cpp
+161
-0
src/core/impl/graph/seq_modifier_base.h
src/core/impl/graph/seq_modifier_base.h
+237
-0
src/core/impl/graph/seq_sublinear_memory.cpp
src/core/impl/graph/seq_sublinear_memory.cpp
+24
-266
src/core/impl/graph/seq_sublinear_memory.h
src/core/impl/graph/seq_sublinear_memory.h
+16
-57
未找到文件。
src/core/impl/graph/seq_modifier_base.cpp
0 → 100644
浏览文件 @
76f4f975
/**
* \file src/core/impl/graph/seq_modifier_base.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./seq_modifier_base.h"
#if MGB_ENABLE_SUBLINEAR
using
namespace
mgb
;
using
namespace
cg
;
void
SeqModifierBase
::
ModifyActionPlannerBase
::
init_seq
(
const
OprNodeArray
&
opr_seq
)
{
m_orig_opr_seq
=
&
opr_seq
;
m_var_storage
.
clear
();
m_seq
.
clear
();
m_var_mempool
.
reorder_free
();
m_opr_mempool
.
reorder_free
();
m_nr_endpoint_oprs
=
0
;
ThinHashMap
<
VarNode
*
,
Var
*>
varmap
;
for
(
auto
orig_opr
:
*
m_orig_opr_seq
)
{
auto
time
=
m_seq
.
size
();
m_seq
.
emplace_back
(
m_opr_mempool
.
alloc_unique
(
orig_opr
,
time
));
auto
opr
=
m_seq
.
back
().
get
();
m_nr_endpoint_oprs
+=
opr
->
is_endpoint
;
for
(
auto
&&
dep
:
orig_opr
->
node_prop
().
dep_map
())
{
if
(
!
OperatorNodeBase
::
NodeProp
::
is_device_value_dep
(
dep
.
second
))
continue
;
auto
iter
=
varmap
.
find
(
dep
.
first
);
if
(
iter
==
varmap
.
end
())
{
// input var needs not to be considered
continue
;
}
auto
ivar
=
iter
->
second
;
bool
exist
=
false
;
for
(
auto
i
:
opr
->
input
)
{
if
(
i
==
ivar
)
{
exist
=
true
;
break
;
}
}
if
(
exist
)
{
// same var for different inputs
continue
;
}
opr
->
input
.
push_back
(
ivar
);
auto
&&
prev_rec
=
ivar
->
access_rec
.
back
();
prev_rec
.
stride
=
time
-
prev_rec
.
opr
->
time
;
ivar
->
access_rec
.
emplace_back
(
opr
);
}
for
(
auto
i
:
orig_opr
->
output
())
{
auto
var2memsize
=
m_par_modifier
->
m_mem_opt
.
var2memsize
();
auto
iter
=
var2memsize
->
find
(
i
);
if
(
iter
==
var2memsize
->
end
())
{
// some vars are ignored; see split_into_cn2oprseq()
continue
;
}
m_var_storage
.
emplace_back
(
m_var_mempool
.
alloc_unique
(
i
,
iter
->
second
,
opr
));
auto
ovar
=
m_var_storage
.
back
().
get
();
varmap
[
i
]
=
ovar
;
opr
->
output
.
push_back
(
ovar
);
}
mgb_assert
(
!
opr
->
output
.
empty
());
}
// remove unused output
for
(
auto
&&
i
:
m_seq
)
{
auto
&&
oarr
=
i
->
output
;
for
(
size_t
j
=
0
;
j
<
oarr
.
size
();)
{
if
(
oarr
[
j
]
->
access_rec
.
size
()
==
1
)
{
std
::
swap
(
oarr
[
j
],
oarr
.
back
());
oarr
.
pop_back
();
}
else
++
j
;
}
}
}
/*!
 * \brief copy \p inputs into m_new_inputs, substituting every var that has a
 *      replacement registered in m_var_map
 * \return true if at least one var was substituted
 */
bool SeqModifierBase::replace_vars(const VarNodeArray& inputs) {
    m_new_inputs.assign(inputs.begin(), inputs.end());

    bool any_replaced = false;
    for (auto&& var : m_new_inputs) {
        auto found = m_var_map.find(var);
        if (found == m_var_map.end())
            continue;  // no replacement registered: keep the original var
        var = found->second;
        any_replaced = true;
    }
    return any_replaced;
}
/*!
 * \brief shallow-copy \p opr onto the inputs currently stored in m_new_inputs
 *      and register the old-output -> new-output mapping in m_var_map
 *
 * \param opr operator to copy
 * \param recomp selects the name suffix (":recomp" when true, ":dup" when
 *      false); a comp node difference between old and new outputs is only
 *      accepted when this is true
 * \param recomp_cnt appended to the new name and used to offset the instance
 *      id
 * \return the newly created operator (asserted to differ from \p opr)
 */
OperatorNodeBase* SeqModifierBase::copy_opr_from_new_inputs(
        OperatorNodeBase* opr, bool recomp, size_t recomp_cnt) {
    auto config = opr->config();

    // New name: <orig name><":recomp" | ":dup"><recomp_cnt>
    const char* suffix = recomp ? ":recomp" : ":dup";
    config.name(opr->name() + suffix + std::to_string(recomp_cnt));

    // The instance id is always updated here (for dup and recomp copies
    // alike). It is derived from `this`, shifted by
    // ((recomp + 1) << 10) * recomp_cnt, so recomp and dup copies with the
    // same non-zero recomp_cnt get different ids; with recomp_cnt == 0 the
    // offset is zero and the id is `this` itself. The stated intent of the
    // original code is to bypass the shallow copy's cache and to tell apart
    // recomp-opr / dup-opr pairs that share inputs, params and config.
    const size_t id_offset =
            ((static_cast<size_t>(recomp) + 1) << 10) * recomp_cnt;
    config.update_instance_id(reinterpret_cast<void*>(
            reinterpret_cast<size_t>(this) + id_offset));

    // Note: if all outputs of op were placed on the same comp_node, since its
    // stream maybe changed during seq_comp_node_opt, output's comp_node has
    // higher priority than opr->config()
    auto shared_cn = opr->output(0)->comp_node();
    for (auto out_var : opr->output()) {
        auto this_cn = out_var->comp_node();
        if (shared_cn != this_cn) {
            // outputs disagree: fall back to whatever the config specifies
            shared_cn = {};
            break;
        }
    }
    if (shared_cn.valid())
        config.comp_node(shared_cn);

    auto opr_new = serialization::copy_opr_shallow(*opr, m_new_inputs, config);
    mgb_assert(opr_new != opr);

    auto&& out0 = opr->output();
    auto&& out1 = opr_new->output();
    mgb_assert(out0.size() == out1.size());

    bool need_notify = false;
    for (size_t idx = 0; idx < out0.size(); ++idx) {
        auto&& cn0 = out0[idx]->comp_node();
        auto&& cn1 = out1[idx]->comp_node();
        if (cn0 != cn1) {
            // only a recomputed opr may land on another comp node, and even
            // then locator type and device must match
            mgb_assert(recomp);
            mgb_assert(cn0.locator().type == cn1.locator().type &&
                       cn0.locator().device == cn1.locator().device);
            out1[idx]->comp_node(cn0);
            need_notify = true;
        }
        m_var_map[out0[idx]] = out1[idx];
    }

    if (need_notify) {
        opr_new->on_output_comp_node_stream_changed();
    }
    return opr_new;
}
#endif // MGB_ENABLE_SUBLINEAR
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
src/core/impl/graph/seq_modifier_base.h
0 → 100644
浏览文件 @
76f4f975
/**
* \file src/core/impl/graph/seq_modifier_base.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "./memory_optimizer.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/cg.h"
#include "megbrain/plugin/opr_footprint.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/system.h"
#include "megbrain/utils/async_worker.h"
#include "megbrain/utils/arith_helper.h"
#include "megbrain/utils/mempool.h"
#include "megbrain/utils/timer.h"
#if MGB_ENABLE_SUBLINEAR
namespace
mgb
{
namespace
cg
{
/*!
 * \brief modifying computing sequence, with basically the same idea of Training
 *      Deep Nets with Sublinear Memory Cost and Dynamic Tensor Rematerialization
 *
 * Holds the state shared by concrete sequence modifiers: the memory optimizer
 * helper, the owner graph, the original-var -> replaced-var map and the
 * scratch input array filled by replace_vars().
 */
class SeqModifierBase {
public:
    /*!
     * describes modifications that should be applied to an operator sequence:
     * maps from an opr to the oprs that should be duplicated and inserted
     * before it.
     */
    using SeqModifyAction =
            std::unordered_map<OperatorNodeBase*, OprNodeArray>;

    struct Var;
    struct Opr;

    /*!
     * \brief common storage for action planners: the abstract Opr/Var
     *      sequence built by init_seq() and the pools that own its elements
     */
    class ModifyActionPlannerBase {
        const SeqModifierBase* const m_par_modifier;
        //! set by init_seq(); null until then
        const OprNodeArray* m_orig_opr_seq = nullptr;
        MemPool<Var> m_var_mempool;
        MemPool<Opr> m_opr_mempool;
        std::vector<MemPool<Var>::UniquePtr> m_var_storage;
        std::vector<MemPool<Opr>::UniquePtr> m_seq;
        size_t m_nr_endpoint_oprs = 0;

    public:
        //! special creation time used for oprs duplicated from others
        static constexpr size_t DUPOPR_TIME =
                std::numeric_limits<size_t>::max() - 1;

        //! the modifier this planner was constructed with
        const SeqModifierBase* par_modifier() const { return m_par_modifier; }

        //! the sequence passed to the last init_seq() call
        const OprNodeArray* orig_opr_seq() const { return m_orig_opr_seq; }

        MemPool<Var>& var_mempool() { return m_var_mempool; }

        MemPool<Opr>& opr_mempool() { return m_opr_mempool; }

        std::vector<MemPool<Var>::UniquePtr>& var_storage() {
            return m_var_storage;
        }

        std::vector<MemPool<Opr>::UniquePtr>& seq() { return m_seq; }

        //! number of oprs in seq() whose is_endpoint flag is set
        size_t& nr_endpoint_oprs() { return m_nr_endpoint_oprs; }

        ModifyActionPlannerBase(SeqModifierBase* par) : m_par_modifier{par} {}

        ~ModifyActionPlannerBase() noexcept {
            // the freelists of both pools are disabled before the members
            // (and with them the pooled objects) are destroyed
            m_opr_mempool.disable_freelist();
            m_var_mempool.disable_freelist();
        }

        //! init m_orig_opr_seq from opr_seq, should be called first.
        void init_seq(const OprNodeArray& opr_seq);
    };

    SeqModifierBase(ComputingGraphImpl* owner)
            : m_mem_opt(owner), m_owner_graph(owner) {}

    MemoryOptimizerHelper& mem_opt() { return m_mem_opt; }

    //! the graph passed to the constructor
    ComputingGraphImpl* owner_graph() const { return m_owner_graph; }

    //! map from original var to replaced var
    ThinHashMap<VarNode*, VarNode*>& var_map() { return m_var_map; }

    /*!
     * \brief copy opr and set inputs to m_new_inputs, and add outputs in
     *      m_var_map
     * \return new operator
     */
    OperatorNodeBase* copy_opr_from_new_inputs(OperatorNodeBase* opr,
                                               bool recomp,
                                               size_t recomp_cnt = 0);

    /*!
     * \brief replace input vars according to m_var_map, and store results in
     *      m_new_inputs;
     * \return whether any var is changed
     */
    bool replace_vars(const VarNodeArray& inputs);

    //! see memory_optimizer set_priority_before_opt
    void set_priority_before_opt(const VarNodeArray& endpoints) {
        m_mem_opt.set_priority_before_opt(endpoints);
    }

    //! see memory_optimizer restore_graph_option
    void restore_graph_option() { m_mem_opt.restore_graph_option(); }

private:
    MemoryOptimizerHelper m_mem_opt;

    ComputingGraphImpl* const m_owner_graph = nullptr;

    //! map from original var to replaced var
    ThinHashMap<VarNode*, VarNode*> m_var_map;

    VarNodeArray m_new_inputs;  //!< setup by replace_vars
};
/*!
 * \brief planner-side record for one operator in the sequence
 *
 * Wraps the original operator together with the abstract vars it reads and
 * writes, plus scratch fields used while a discard plan is made and applied.
 */
struct SeqModifierBase::Opr {
    //! the graph operator this record stands for
    OperatorNodeBase* const orig_opr;

    //! abstract vars read by this opr
    std::vector<Var*> input;
    //! abstract vars written by this opr
    std::vector<Var*> output;

    const size_t time;  //!< index in opr sequence

    //! value of the graph's sublinear-memory-endpoint attribute for orig_opr,
    //! queried once at construction
    const bool is_endpoint;

    double estimate_compute_time = 1.0;

    //! input vars that have been discarded and need to be recomputed before
    //! this opr; for internal use by apply_discard_plan()
    std::vector<Var*> inputs_to_recompute;

    //! new oprs to be inserted before this opr; setup by apply_discard_plan()
    std::vector<MemPool<Opr>::UniquePtr> oprs_insert_before;

    //! [begin, end) interval of *time* for oprs belonging to this block; setup
    //! by make_discard_plan()
    size_t block_begin_time = 0;
    size_t block_end_time = 0;

    /*!
     * \param opr original operator; its owner graph is queried for the
     *      endpoint attribute
     * \param t value stored in the time field
     */
    Opr(OperatorNodeBase* opr, size_t t)
            : orig_opr(opr),
              time(t),
              is_endpoint(opr->owner_graph()
                                  ->options()
                                  .opr_attribute.get_sublinear_memory_endpoint(
                                          opr)) {}
};
/*!
 * \brief planner-side record for one var produced inside the sequence
 *
 * Keeps the original VarNode, its memory size and the ordered list of
 * accesses (creation first, then readers).
 */
struct SeqModifierBase::Var {
    VarNode* const orig_var;
    size_t size;  //!< memory usage in bytes of this var
    size_t recomp_id = 0;
    double last_access_time = 0;

    //! write or read access of a var
    struct AccessRecord {
        Opr* const opr;
        const size_t time;
        //! time distance until next read; 0 for last access
        size_t stride;

        /*!
         * \param o accessing opr; may be null (the default), in which case
         *      time is set to 0 instead of dereferencing a null pointer
         */
        explicit AccessRecord(Opr* o = nullptr)
                : opr{o}, time{o ? o->time : 0}, stride{0} {}
    };

    //! access_rec[0] is the creation opr, and others are reader oprs
    std::vector<AccessRecord> access_rec;

    /*!
     * An index in access_rec
     *
     * if valid, then the var should be discarded after
     * discard_tailing_access->opr finishes
     *
     * setup by make_discard_plan
     */
    Maybe<size_t> discard_tailing_access;

    /*!
     * An index in access_rec
     * maintained during make_discard_plan(), for the next access relative to
     * current operator
     */
    Maybe<size_t> next_access;

    //! record selected by discard_tailing_access, or nullptr if it is unset
    AccessRecord* visit_discard_tailing_access() {
        return discard_tailing_access.valid()
                       ? &access_rec.at(discard_tailing_access.val())
                       : nullptr;
    }

    //! record selected by next_access, or nullptr if it is unset
    AccessRecord* visit_next_access() {
        return next_access.valid() ? &access_rec.at(next_access.val())
                                   : nullptr;
    }

    //! opr of the first access record, i.e. the creation opr
    auto owner_opr() const { return access_rec[0].opr; }

    //! opr of the most recently appended access record
    auto last_access_opr() const { return access_rec.back().opr; }

    /*!
     * \param var original var node
     * \param s memory size in bytes
     * \param opr creating opr; becomes access_rec[0]
     */
    Var(VarNode* var, size_t s, Opr* opr) : orig_var{var}, size{s} {
        access_rec.emplace_back(opr);
    }
};
}
// namespace cg
}
// namespace mgb
#endif // MGB_ENABLE_SUBLINEAR
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
src/core/impl/graph/seq_sublinear_memory.cpp
浏览文件 @
76f4f975
...
...
@@ -61,108 +61,15 @@ bool is_bad_opr(OperatorNodeBase* opr) {
}
}
// namespace
/* ====================== Abstract Opr & Var ====================== */
struct
SeqModifierForSublinearMemory
::
Opr
{
OperatorNodeBase
*
const
orig_opr
;
std
::
vector
<
Var
*>
input
,
output
;
const
size_t
time
;
//!< index in opr sequence
const
bool
is_endpoint
;
//! input vars that have been discarded and need to be recomputed before
//! this opr; for internal use by apply_discard_plan()
std
::
vector
<
Var
*>
inputs_to_recompute
;
//! new oprs to be inserted before this opr; setup by apply_discard_plan()
std
::
vector
<
MemPool
<
Opr
>::
UniquePtr
>
oprs_insert_before
;
//! [begin, end) interval of *time* for oprs belonging to this block; setup
//! by make_discard_plan()
size_t
block_begin_time
=
0
,
block_end_time
=
0
;
Opr
(
OperatorNodeBase
*
opr
,
size_t
t
)
:
orig_opr
{
opr
},
time
{
t
},
is_endpoint
{
opr
->
owner_graph
()
->
options
()
.
opr_attribute
.
get_sublinear_memory_endpoint
(
opr
)}
{}
};
struct
SeqModifierForSublinearMemory
::
Var
{
//! write or read access of a var
struct
AccessRecord
{
Opr
*
const
opr
;
const
size_t
time
;
size_t
stride
;
//!< time distance until next read; 0 for last access
explicit
AccessRecord
(
Opr
*
o
=
nullptr
)
:
opr
{
o
},
time
{
o
->
time
},
stride
{
0
}
{}
};
VarNode
*
const
orig_var
;
const
size_t
size
;
//!< memory usage in bytes of this var
//! access_rec[0] is the creation opr, and others are reader oprs
std
::
vector
<
AccessRecord
>
access_rec
;
/*!
* An index in access_rec
*
* if valid, then the var should be discarded after
* discard_tailing_access->opr finishes
*
* setup by make_discard_plan
*/
Maybe
<
size_t
>
discard_tailing_access
;
/*!
* An index in access_rec
* maintained during make_discard_plan(), for the next access relative to
* current operator
*/
Maybe
<
size_t
>
next_access
;
AccessRecord
*
visit_discard_tailing_access
()
{
return
discard_tailing_access
.
valid
()
?
&
access_rec
.
at
(
discard_tailing_access
.
val
())
:
nullptr
;
}
AccessRecord
*
visit_next_access
()
{
return
next_access
.
valid
()
?
&
access_rec
.
at
(
next_access
.
val
())
:
nullptr
;
}
auto
owner_opr
()
const
{
return
access_rec
[
0
].
opr
;
}
auto
last_access_opr
()
const
{
return
access_rec
.
back
().
opr
;
}
Var
(
VarNode
*
var
,
size_t
s
,
Opr
*
opr
)
:
orig_var
{
var
},
size
{
s
}
{
access_rec
.
emplace_back
(
opr
);
}
};
/* ====================== ModifyActionPlanner ====================== */
class
SeqModifierForSublinearMemory
::
ModifyActionPlanner
{
//! special creation time used for oprs duplicated from others
static
constexpr
size_t
DUPOPR_TIME
=
std
::
numeric_limits
<
size_t
>::
max
()
-
1
;
class
SeqModifierForSublinearMemory
::
ModifyActionPlanner
:
public
ModifyActionPlannerBase
{
using
VarArray
=
std
::
vector
<
Var
*>
;
using
VarSet
=
ThinHashSet
<
Var
*>
;
using
OprArray
=
std
::
vector
<
Opr
*>
;
const
SeqModifierForSublinearMemory
*
const
m_par_modifier
;
const
OprNodeArray
*
m_orig_opr_seq
;
MemPool
<
Var
>
m_var_mempool
;
MemPool
<
Opr
>
m_opr_mempool
;
std
::
vector
<
MemPool
<
Var
>::
UniquePtr
>
m_var_storage
;
std
::
vector
<
MemPool
<
Opr
>::
UniquePtr
>
m_seq
;
size_t
m_nr_endpoint_oprs
=
0
;
VarSet
m_prev_block_discard_vars
;
std
::
vector
<
OprArray
>
m_blocks
;
SeqModifyAction
m_action
;
//! split_point_set to block
void
split_into_blocks
(
const
SplitPointSet
&
split_point_set
);
...
...
@@ -188,14 +95,7 @@ class SeqModifierForSublinearMemory::ModifyActionPlanner {
public:
ModifyActionPlanner
(
SeqModifierForSublinearMemory
*
par
)
:
m_par_modifier
{
par
}
{}
~
ModifyActionPlanner
()
noexcept
{
m_opr_mempool
.
disable_freelist
();
m_var_mempool
.
disable_freelist
();
}
//! init m_orig_opr_seq from opr_seq, should be called first.
void
init_seq
(
const
OprNodeArray
&
opr_seq
);
:
ModifyActionPlannerBase
{
par
}
{}
//! generate split point set from thresh
SplitPointSet
get_split_point_set
(
size_t
block_size_thresh
);
...
...
@@ -213,7 +113,7 @@ public:
void
SeqModifierForSublinearMemory
::
ModifyActionPlanner
::
get_prev_action
(
SeqModifyAction
&
action
)
{
action
.
clear
();
for
(
auto
&&
opr
:
m_seq
)
{
for
(
auto
&&
opr
:
seq
()
)
{
auto
&&
arr
=
opr
->
oprs_insert_before
;
if
(
arr
.
empty
())
continue
;
...
...
@@ -261,8 +161,8 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set(
cur_block_alive_vars
.
clear
();
};
for
(
size_t
i
=
0
;
i
<
m_seq
.
size
();
++
i
)
{
auto
opr
=
m_seq
[
i
].
get
();
for
(
size_t
i
=
0
;
i
<
seq
()
.
size
();
++
i
)
{
auto
opr
=
seq
()
[
i
].
get
();
for
(
auto
i
:
opr
->
output
)
add_alive
(
i
);
...
...
@@ -272,8 +172,8 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set(
remove_alive
(
i
);
}
if
(
i
+
1
<
m_seq
.
size
()
&&
(
cur_block_usage
<
block_size_thresh
||
(
m_nr_endpoint_oprs
&&
!
opr
->
is_endpoint
)))
if
(
i
+
1
<
seq
()
.
size
()
&&
(
cur_block_usage
<
block_size_thresh
||
(
nr_endpoint_oprs
()
&&
!
opr
->
is_endpoint
)))
continue
;
flush_block_member
(
i
);
...
...
@@ -281,81 +181,6 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set(
return
split_point_set
;
}
void
SeqModifierForSublinearMemory
::
ModifyActionPlanner
::
init_seq
(
const
OprNodeArray
&
opr_seq
)
{
m_orig_opr_seq
=
&
opr_seq
;
m_var_storage
.
clear
();
m_seq
.
clear
();
m_var_mempool
.
reorder_free
();
m_opr_mempool
.
reorder_free
();
m_nr_endpoint_oprs
=
0
;
ThinHashMap
<
VarNode
*
,
Var
*>
varmap
;
for
(
auto
orig_opr
:
*
m_orig_opr_seq
)
{
auto
time
=
m_seq
.
size
();
m_seq
.
emplace_back
(
m_opr_mempool
.
alloc_unique
(
orig_opr
,
time
));
auto
opr
=
m_seq
.
back
().
get
();
m_nr_endpoint_oprs
+=
opr
->
is_endpoint
;
for
(
auto
&&
dep
:
orig_opr
->
node_prop
().
dep_map
())
{
if
(
!
OperatorNodeBase
::
NodeProp
::
is_device_value_dep
(
dep
.
second
))
continue
;
auto
iter
=
varmap
.
find
(
dep
.
first
);
if
(
iter
==
varmap
.
end
())
{
// input var needs not to be considered
continue
;
}
auto
ivar
=
iter
->
second
;
bool
exist
=
false
;
for
(
auto
i
:
opr
->
input
)
{
if
(
i
==
ivar
)
{
exist
=
true
;
break
;
}
}
if
(
exist
)
{
// same var for different inputs
continue
;
}
opr
->
input
.
push_back
(
ivar
);
auto
&&
prev_rec
=
ivar
->
access_rec
.
back
();
prev_rec
.
stride
=
time
-
prev_rec
.
opr
->
time
;
ivar
->
access_rec
.
emplace_back
(
opr
);
}
for
(
auto
i
:
orig_opr
->
output
())
{
auto
var2memsize
=
m_par_modifier
->
m_mem_opt
.
var2memsize
();
auto
iter
=
var2memsize
->
find
(
i
);
if
(
iter
==
var2memsize
->
end
())
{
// some vars are ignored; see split_into_cn2oprseq()
continue
;
}
m_var_storage
.
emplace_back
(
m_var_mempool
.
alloc_unique
(
i
,
iter
->
second
,
opr
));
auto
ovar
=
m_var_storage
.
back
().
get
();
varmap
[
i
]
=
ovar
;
opr
->
output
.
push_back
(
ovar
);
}
mgb_assert
(
!
opr
->
output
.
empty
());
}
// remove unused output
for
(
auto
&&
i
:
m_seq
)
{
auto
&&
oarr
=
i
->
output
;
for
(
size_t
j
=
0
;
j
<
oarr
.
size
();)
{
if
(
oarr
[
j
]
->
access_rec
.
size
()
==
1
)
{
std
::
swap
(
oarr
[
j
],
oarr
.
back
());
oarr
.
pop_back
();
}
else
++
j
;
}
}
}
size_t
SeqModifierForSublinearMemory
::
ModifyActionPlanner
::
calc_bottleneck_from_discard_plan
()
{
size_t
cur_usage
=
0
,
max_usage
=
0
;
...
...
@@ -394,7 +219,7 @@ size_t SeqModifierForSublinearMemory::ModifyActionPlanner::
++
time
;
};
for
(
auto
&&
opr
:
m_seq
)
{
for
(
auto
&&
opr
:
seq
()
)
{
for
(
auto
&&
i
:
opr
->
oprs_insert_before
)
process_opr
(
i
.
get
());
process_opr
(
opr
.
get
());
...
...
@@ -480,7 +305,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() {
mgb_assert
(
opr
->
time
<
block_end
);
auto
new_opr_storage
=
m_opr_mempool
.
alloc_unique
(
auto
new_opr_storage
=
opr_mempool
()
.
alloc_unique
(
opr
->
orig_opr
,
static_cast
<
size_t
>
(
DUPOPR_TIME
));
auto
new_opr
=
new_opr_storage
.
get
();
...
...
@@ -497,7 +322,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() {
Var
*
new_var
=
nullptr
;
for
(
auto
i
:
opr
->
output
)
{
auto
&&
ovar
=
m_var_mempool
.
alloc_unique
(
i
->
orig_var
,
i
->
size
,
auto
&&
ovar
=
var_mempool
()
.
alloc_unique
(
i
->
orig_var
,
i
->
size
,
new_opr
);
new_opr
->
output
.
push_back
(
ovar
.
get
());
if
(
i
==
var
)
...
...
@@ -507,7 +332,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() {
auto
ins
=
var_map
.
insert
({
i
,
ovar
.
get
()});
mgb_assert
(
ins
.
second
);
m_var_storage
.
emplace_back
(
std
::
move
(
ovar
));
var_storage
()
.
emplace_back
(
std
::
move
(
ovar
));
}
mgb_assert
(
new_var
);
return
new_var
;
...
...
@@ -515,7 +340,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() {
add_dep
(
var
);
};
for
(
auto
&&
_raw_opr
:
m_seq
)
{
for
(
auto
&&
_raw_opr
:
seq
()
)
{
auto
opr
=
_raw_opr
.
get
();
for
(
auto
i
:
opr
->
inputs_to_recompute
)
...
...
@@ -640,8 +465,8 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::split_into_blocks(
m_blocks
.
clear
();
std
::
vector
<
Opr
*>
cur_block_member
;
size_t
i
,
j
;
for
(
i
=
j
=
0
;
i
<
m_seq
.
size
()
&&
j
<
split_point_set
->
size
();
++
i
)
{
auto
opr
=
m_seq
[
i
].
get
();
for
(
i
=
j
=
0
;
i
<
seq
()
.
size
()
&&
j
<
split_point_set
->
size
();
++
i
)
{
auto
opr
=
seq
()
[
i
].
get
();
cur_block_member
.
push_back
(
opr
);
if
(
i
!=
split_point_set
->
at
(
j
))
continue
;
...
...
@@ -649,7 +474,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::split_into_blocks(
cur_block_member
.
clear
();
j
++
;
}
mgb_assert
(
i
>=
m_seq
.
size
());
mgb_assert
(
i
>=
seq
()
.
size
());
mgb_assert
(
j
>=
split_point_set
->
size
());
}
...
...
@@ -1081,7 +906,7 @@ void SeqModifierForSublinearMemory::InternalDeleter::operator()(
}
void
SeqModifierForSublinearMemory
::
reset_opr_seq
(
const
OprNodeArray
&
oprseq
)
{
m_var_map
.
clear
();
var_map
()
.
clear
();
m_opr2replace_info
.
clear
();
auto
config
=
MemoryOptimizerHelper
::
SubGraphConfig
()
...
...
@@ -1099,7 +924,7 @@ void SeqModifierForSublinearMemory::reset_opr_seq(const OprNodeArray& oprseq) {
.
add_bad_var_flag
(
VarNode
::
Flag
::
NO_SYS_MEM_ALLOC
)
.
add_bad_var_flag
(
VarNode
::
Flag
::
PERSISTENT_DEVICE_VALUE
);
auto
cn2oprseq
=
m
_mem_opt
.
split_into_cn2oprseq
(
oprseq
,
config
);
auto
cn2oprseq
=
m
em_opt
()
.
split_into_cn2oprseq
(
oprseq
,
config
);
if
(
cn2oprseq
->
empty
())
{
// empty graph
...
...
@@ -1175,7 +1000,7 @@ void SeqModifierForSublinearMemory::apply_action(SeqModifyAction& action,
// each operator should be set no more than once
auto
set_priority
=
[
&
](
OperatorNodeBase
*
opr
)
{
mgb_assert
(
modified_opr
.
insert
(
opr
).
second
);
m
_mem_opt
.
set_priority
(
opr
,
cur_priority
++
);
m
em_opt
()
.
set_priority
(
opr
,
cur_priority
++
);
};
auto
on_opr_visited
=
[
&
](
OperatorNodeBase
*
opr
)
{
...
...
@@ -1218,80 +1043,13 @@ void SeqModifierForSublinearMemory::apply_action(SeqModifyAction& action,
mgb_assert
(
action
.
empty
());
}
bool
SeqModifierForSublinearMemory
::
replace_vars
(
const
VarNodeArray
&
inputs
)
{
m_new_inputs
.
assign
(
inputs
.
begin
(),
inputs
.
end
());
bool
changed
=
false
;
for
(
auto
&&
i
:
m_new_inputs
)
{
auto
iter
=
m_var_map
.
find
(
i
);
if
(
iter
!=
m_var_map
.
end
())
{
i
=
iter
->
second
;
changed
=
true
;
}
}
return
changed
;
}
OperatorNodeBase
*
SeqModifierForSublinearMemory
::
copy_opr_from_new_inputs
(
OperatorNodeBase
*
opr
,
bool
recomp
)
{
auto
config
=
opr
->
config
();
// update operator instance id to bybass the shallow copy's cache if
// it's a dup-opr-copying due to discarding.
// Don't update instance id by `this` pointer if it's a recomp-opr-copying
// because:
// 0) recomp-opr would be copied iff its input vars is changed
// 1) some pair of recomp-opr and dup-opr have the same inputs, params
// and config, we use instance id to differentiate them.
config
.
name
(
opr
->
name
()
+
(
recomp
?
":recomp"
:
":dup"
));
if
(
!
recomp
)
{
config
.
update_instance_id
(
this
);
}
// Note: if all outputs of op were placed on the same comp_node, since its
// stream maybe changed during seq_comp_node_opt, output's comp_node has
// higher priority than opr->config()
auto
out_cn
=
opr
->
output
(
0
)
->
comp_node
();
for
(
auto
i
:
opr
->
output
())
{
auto
cn
=
i
->
comp_node
();
if
(
out_cn
!=
cn
)
{
out_cn
=
{};
break
;
}
}
if
(
out_cn
.
valid
())
config
.
comp_node
(
out_cn
);
auto
opr_new
=
serialization
::
copy_opr_shallow
(
*
opr
,
m_new_inputs
,
config
);
mgb_assert
(
opr_new
!=
opr
);
auto
&&
out0
=
opr
->
output
();
auto
&&
out1
=
opr_new
->
output
();
mgb_assert
(
out0
.
size
()
==
out1
.
size
());
bool
stream_changed
=
false
;
for
(
size_t
i
=
0
;
i
<
out0
.
size
();
++
i
)
{
auto
&&
cn0
=
out0
[
i
]
->
comp_node
(),
&&
cn1
=
out1
[
i
]
->
comp_node
();
if
(
cn0
!=
cn1
)
{
mgb_assert
(
recomp
);
mgb_assert
(
cn0
.
locator
().
type
==
cn1
.
locator
().
type
&&
cn0
.
locator
().
device
==
cn1
.
locator
().
device
);
out1
[
i
]
->
comp_node
(
cn0
);
stream_changed
=
true
;
}
m_var_map
[
out0
[
i
]]
=
out1
[
i
];
}
if
(
stream_changed
)
{
opr_new
->
on_output_comp_node_stream_changed
();
}
return
opr_new
;
}
void
SeqModifierForSublinearMemory
::
modify_endpoint_vars
(
VarNodeArray
&
endpoints
)
{
auto
comp_seq
=
MemoryOptimizerHelper
::
CompSeq
(
m_owner_graph
,
endpoints
);
auto
comp_seq
=
MemoryOptimizerHelper
::
CompSeq
(
owner_graph
()
,
endpoints
);
reset_opr_seq
(
*
comp_seq
.
m_seq
);
for
(
auto
&&
i
:
endpoints
)
{
auto
iter
=
m_var_map
.
find
(
i
);
if
(
iter
!=
m_var_map
.
end
())
{
auto
iter
=
var_map
()
.
find
(
i
);
if
(
iter
!=
var_map
()
.
end
())
{
i
=
iter
->
second
;
}
}
...
...
@@ -1357,7 +1115,7 @@ SeqModifierForSublinearMemory::prev_min_bottleneck() {
SeqModifierForSublinearMemory
::
SeqModifierForSublinearMemory
(
ComputingGraphImpl
*
owner
,
Config
*
config_p
)
:
m_config
(
config_p
),
m_mem_opt
(
owner
),
m_owner_graph
(
owner
)
{}
:
SeqModifierBase
(
owner
),
m_config
(
config_p
)
{}
#endif // !MGB_ENABLE_SUBLINEAR
...
...
src/core/impl/graph/seq_sublinear_memory.h
浏览文件 @
76f4f975
...
...
@@ -12,6 +12,7 @@
#pragma once
#include "./memory_optimizer.h"
#include "./seq_modifier_base.h"
#include "megbrain/graph/cg.h"
#include "megbrain/utils/async_worker.h"
...
...
@@ -23,28 +24,31 @@ namespace cg {
* \brief modifying computing sequence, with basically the same idea of Training
* Deep Nets with Sublinear Memory Cost
*/
class
SeqModifierForSublinearMemory
{
/*!
* describes modifications that should be applied to an operator sequnce:
* maps from an opr to the oprs that should be duplicated and inserted
* before it.
*/
using
SeqModifyAction
=
std
::
unordered_map
<
OperatorNodeBase
*
,
OprNodeArray
>
;
using
SplitPointSet
=
std
::
shared_ptr
<
std
::
vector
<
size_t
>>
;
class
SeqModifierForSublinearMemory
:
public
SeqModifierBase
{
//! Config options
using
Config
=
mgb
::
cg
::
ComputingGraph
::
Options
::
SublinearMemConfig
;
Config
*
m_config
;
public:
SeqModifierForSublinearMemory
(
ComputingGraphImpl
*
owner
,
Config
*
config_g
);
//! replace endpoint vars by the ones that require more computing
void
modify_endpoint_vars
(
VarNodeArray
&
endpoints
);
//! check whether actual opr_seq is what we expect; throw InternalError
void
sanity_check
(
const
OprNodeArray
&
opr_seq
);
const
CompNode
::
UnorderedMap
<
size_t
>&
prev_min_bottleneck
();
private:
using
SplitPointSet
=
std
::
shared_ptr
<
std
::
vector
<
size_t
>>
;
//! get modifications to be taken under some specific constraints
class
ModifyActionPlanner
;
//! search best modify action for opr seq on a single comp node
class
ActionSearcherSingleCN
;
struct
Opr
;
struct
Var
;
struct
InternalDeleter
{
void
operator
()(
ActionSearcherSingleCN
*
)
const
;
void
operator
()(
ModifyActionPlanner
*
)
const
;
...
...
@@ -67,32 +71,8 @@ class SeqModifierForSublinearMemory {
//! thread pool to run ModifyActionPlanner
FutureThreadPool
<
void
>
m_planner_thread_pool
;
//! map from original var to replaced var
ThinHashMap
<
VarNode
*
,
VarNode
*>
m_var_map
;
VarNodeArray
m_new_inputs
;
//!< setup by replace_vars
MemoryOptimizerHelper
m_mem_opt
;
ComputingGraphImpl
*
const
m_owner_graph
=
nullptr
;
CompNode
::
UnorderedMap
<
size_t
>
m_prev_min_bottleneck
;
/*!
* \brief replace input vars according to m_var_map, and store results in
* m_new_inputs;
* \return whether any var is changed
*/
bool
replace_vars
(
const
VarNodeArray
&
inputs
);
/*!
* \brief copy opr and set inputs to m_new_inputs, and add outputs in
* m_var_map
* \return new operator
*/
OperatorNodeBase
*
copy_opr_from_new_inputs
(
OperatorNodeBase
*
opr
,
bool
recomp
);
//! restore computing sequence and modify operator priority
void
reset_opr_seq
(
const
OprNodeArray
&
oprseq
);
...
...
@@ -107,27 +87,6 @@ class SeqModifierForSublinearMemory {
return
std
::
make_shared
<
SplitPointSet
::
element_type
>
(
std
::
forward
<
Args
>
(
args
)...);
}
public:
SeqModifierForSublinearMemory
(
ComputingGraphImpl
*
owner
,
Config
*
config_g
);
//! see memory_optimizer set_priority_before_opt
void
set_priority_before_opt
(
const
VarNodeArray
&
endpoints
)
{
m_mem_opt
.
set_priority_before_opt
(
endpoints
);
}
//! see memory_optimizer restore_graph_option
void
restore_graph_option
()
{
m_mem_opt
.
restore_graph_option
();
}
//! replace endpoint vars by the ones that require more computing
void
modify_endpoint_vars
(
VarNodeArray
&
endpoints
);
//! check whether actual opr_seq is what we expect; throw InternalError
void
sanity_check
(
const
OprNodeArray
&
opr_seq
);
const
CompNode
::
UnorderedMap
<
size_t
>&
prev_min_bottleneck
();
};
}
// namespace cg
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录