MegEngine 天元 / MegEngine

Commit 3eb529d6
Authored Jan 19, 2021 by Megvii Engine Team
refactor(interpreter): recognize recomp on main thread rather than worker
GitOrigin-RevId: 4ba3942ce475284b2adb14832e3c136aec602016
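Reviewer's summary of the change, derived from the hunks below: eviction recovery moves off the worker. The main thread now records a ComputePath per op at dispatch time and, whenever a handle is touched (apply_op, get_value, get_dev_tensor), calls regenerate() itself, which either re-enqueues the producing op (recompute()) for a dropped tensor or enqueues a SwapIn for a swapped-out one. The worker then only executes commands whose inputs are guaranteed to materialize, and the shared-pointer side table (SharedTensorInfoMap / m_st) is replaced by raw producer/users links with explicit reference counting. A self-contained toy of the recursive-regeneration idea (all types here are illustrative stand-ins, not MegEngine's):

#include <vector>

// Toy model of main-thread recomputation: a dropped node is rebuilt by
// re-running its recorded producer, after regenerating the producer's own
// inputs recursively. Node/Producer stand in for TensorInfo and
// TensorInfo::ComputePath.
struct Node;

struct Producer {
    std::vector<Node*> inputs;
    std::vector<Node*> outputs;
};

struct Node {
    bool dropped = false;
    Producer* producer = nullptr;  // recorded when the op was dispatched
};

void enqueue_apply(Producer*) { /* stands in for m_buffer.enqueue(ApplyOp{...}) */ }

void regenerate(Node* n) {
    if (!n->dropped) return;          // tensor still resident; nothing to do
    for (Node* in : n->producer->inputs)
        regenerate(in);               // inputs first, depth-first
    for (Node* out : n->producer->outputs)
        out->dropped = false;         // will be materialized by the worker
    enqueue_apply(n->producer);       // worker replays the op in queue order
}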
Parent: df976782

Showing 2 changed files with 155 additions and 173 deletions
imperative/src/impl/interpreter_impl.cpp  (+89 −143)
imperative/src/impl/interpreter_impl.h    (+66 −30)
imperative/src/impl/interpreter_impl.cpp
@@ -35,7 +35,6 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
     info->desc.comp_node = value.comp_node();
     info->desc.value = value.proxy_to_default_cpu();
     info->h_value = value;
-    m_valid_handle.insert(info);
     m_buffer.enqueue(Put{info, value, no_cache});
     if (m_async_level == 0) {
         sync();
@@ -49,20 +48,25 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
     info->desc.layout = data.layout();
     info->desc.comp_node = data.comp_node();
     info->ptr = Tensor::make(data);
-    m_valid_handle.insert(info);
     return info;
 }
 
 void ChannelImpl::del(Handle handle) {
-    mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
-    m_buffer.enqueue(Del{reinterpret_cast<TensorInfo*>(handle)});
+    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
+    auto* info = reinterpret_cast<TensorInfo*>(handle);
+    detach_users(info);
+    info->detach_producer();
+    m_valid_handle.erase(handle);
+    m_buffer.enqueue(Del{info});
 }
 
 void ChannelImpl::swap_in(Handle handle) {
     if (m_enable_evict & SWAP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
-        m_buffer.enqueue(SwapIn{reinterpret_cast<TensorInfo*>(handle)});
+        auto* info = reinterpret_cast<TensorInfo*>(handle);
+        m_buffer.enqueue(SwapIn{info});
+        info->evict_type = NONE;
     }
 }
@@ -70,7 +74,9 @@ void ChannelImpl::swap_out(Handle handle) {
     if (m_enable_evict & SWAP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
-        m_buffer.enqueue(SwapOut{reinterpret_cast<TensorInfo*>(handle)});
+        auto* info = reinterpret_cast<TensorInfo*>(handle);
+        m_buffer.enqueue(SwapOut{info});
+        info->evict_type = SWAP;
     }
 }
@@ -78,7 +84,13 @@ void ChannelImpl::drop(Handle handle) {
     if (m_enable_evict & DROP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
-        m_buffer.enqueue(Drop{reinterpret_cast<TensorInfo*>(handle)});
+        auto* info = reinterpret_cast<TensorInfo*>(handle);
+        if (!info->producer) {
+            mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", info);
+            return;
+        }
+        info->evict_type = DROP;
+        m_buffer.enqueue(Drop{info});
     }
 }
@@ -134,18 +146,8 @@ void ChannelImpl::dispatch_default_cpu(
         output_infos.push_back(info);
         outputs->push_back(info);
     }
     if (m_enable_evict & DROP) {
-        for (auto out : output_infos) {
-            out->path.op = op;
-            for (auto out_ : output_infos) {
-                out->path.outputs.push_back(m_st.at(out_));
-            }
-            for (auto inp : input_infos) {
-                out->path.inputs.push_back(m_st.at(inp));
-                inp->path.dep_outputs.push_back(m_st.at(out));
-            }
-        }
+        TensorInfo::ComputePath::make(op, input_infos, output_infos);
     }
 }
@@ -168,21 +170,11 @@ void ChannelImpl::dispatch_kernel(
             info->h_value = HostTensorND::make_proxy(desc.value)
                 .proxy_to_comp_node(desc.comp_node);
         }
-        m_valid_handle.insert(info);
         cmd.outputs.push_back(info);
         outputs->push_back(info);
     }
     if (m_enable_evict & DROP) {
-        for (auto out : cmd.outputs) {
-            out->path.op = cmd.op;
-            for (auto out_ : cmd.outputs) {
-                out->path.outputs.push_back(m_st.at(out_));
-            }
-            for (auto inp : cmd.inputs) {
-                out->path.inputs.push_back(m_st.at(inp));
-                inp->path.dep_outputs.push_back(m_st.at(out));
-            }
-        }
+        TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs);
     }
     m_buffer.enqueue(std::move(cmd));
     if (!validated && m_async_level == 1) {
@@ -215,6 +207,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
         mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
         input_infos.push_back(info);
         input_descs.push_back(info->desc);
+        regenerate(info);
     }
 }
@@ -233,23 +226,31 @@ SmallVector<Handle> ChannelImpl::apply_op(
 }
 
 HostTensorND ChannelImpl::get_value(Handle handle) {
     // TODO: maybe get_value should be done on host. i.e. delete GetValue
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
-    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
-    mgb_assert(!m_waitee);
-    if (!info->value_fetched) {
-        mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
+    // donnot use info->value_fetched, it's unsafe
+    mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
+    TensorPtr tensor_ptr = info->ptr;
+    auto value_fetched = [&]() {
+        return tensor_ptr && tensor_ptr->value_fetched();
+    };
+    if (!value_fetched()) {
+        std::unique_lock<decltype(m_mutex)> lock(m_mutex);
         m_waitee = info;
+        regenerate(info);
         m_buffer.enqueue(GetValue{info});
         m_cv.wait(lock, [&]() {
             check_worker_exc_unsafe();
-            return info->value_fetched;
+            // get tensor ptr in lock to ensure safety
+            tensor_ptr = info->ptr;
+            return value_fetched();
         });
         m_waitee = nullptr;
     }
-    mgb_assert(info->ptr->value_fetched());
-    return info->ptr->get_value();
+    return tensor_ptr->get_value();
 }
 
 TensorShape ChannelImpl::get_shape(Handle handle) {
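A note on the get_value() rewrite above: the old code trusted info->value_fetched, a plain flag the worker flips, and dereferenced info->ptr after the wait without re-reading it under the lock. The new code keeps a local TensorPtr snapshot and refreshes it inside the m_cv.wait predicate, which runs with m_mutex held, so the pointer used after the wait is one the worker cannot be concurrently resetting. A distilled, self-contained version of that pattern (hypothetical types, not MegEngine API; the diff's unlocked fast path is omitted for brevity):

#include <condition_variable>
#include <memory>
#include <mutex>

// Distilled wait pattern: re-read the shared slot inside the predicate,
// while the lock is held, and keep a snapshot that stays valid after the
// lock is released. `Value` and `slot` are hypothetical.
struct Value { int payload; };

std::mutex m;
std::condition_variable cv;
std::shared_ptr<Value> slot;  // written by a worker thread under `m`

std::shared_ptr<Value> wait_for_value() {
    std::unique_lock<std::mutex> lock(m);
    std::shared_ptr<Value> snapshot;
    cv.wait(lock, [&] {
        snapshot = slot;          // safe: the predicate runs with `m` held
        return snapshot != nullptr;
    });
    return snapshot;              // usable after unlock; refcount keeps it alive
}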
@@ -298,6 +299,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
     std::unique_lock<decltype(m_mutex)> lock(m_mutex);
     mgb_assert(!m_waitee);
     m_waitee = info;
+    regenerate(info);
     m_buffer.enqueue(Flush{info});
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
@@ -332,17 +334,12 @@ int ChannelImpl::get_async_level() {
 TensorInfo* ChannelImpl::alloc() {
     MGB_LOCK_GUARD(m_mutex);
     auto info = m_pool.alloc();
-    m_st.insert(info);
+    m_valid_handle.insert(info);
     return info;
 }
 
 void ChannelImpl::free(TensorInfo* ptr) {
     MGB_LOCK_GUARD(m_mutex);
-    if (ptr->path.dep_outputs.size() > 0) {
-        remove_dep(ptr);
-    }
-    m_st.erase(ptr);
-    mgb_assert(ptr->allow_delete, "delete before ref_cnt = 0");
     m_pool.free(ptr);
 }
@@ -350,77 +347,64 @@ ChannelImpl::~ChannelImpl() {
     close();
 }
 
-void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice = true) {
-    auto lock = notice ? std::unique_lock<std::mutex>(m_mutex)
-                       : std::unique_lock<std::mutex>();
+void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
+    MGB_LOCK_GUARD(m_mutex);
     dest->value_fetched = ptr->value_fetched();
     // update tensor desc for static infer
     dest->desc.layout = ptr->layout();
     dest->desc.comp_node = ptr->comp_node();
     dest->ptr = std::move(ptr);
-    if (notice && m_waitee == dest) {
+    if (m_waitee == dest) {
         m_cv.notify_all();
     }
 }
 
-void ChannelImpl::do_swap_out(TensorInfo* dest) {
-    if (dest->evict_type == DROP) {
-        mgb_log_warn("the evict type of tensor %p was set to DROP, this SWAP operation will be ignored", dest);
-        return;
-    }
-    if (!dest->ptr) {
-        return;
+void ChannelImpl::regenerate(TensorInfo* dest) {
+    if (dest->evict_type == DROP) {
+        recompute(dest->producer);
+    } else if (dest->evict_type == SWAP) {
+        swap_in(dest);
     }
-    dest->evict_type = SWAP;
-    dest->value_fetched = false;
-    // TODO: swap in parallel
-    dest->h_value = dest->ptr->get_value();
-    dest->ptr.reset();
+    mgb_assert(dest->evict_type == NONE);
 }
 
-void ChannelImpl::do_swap_in(TensorInfo* dest) {
-    if (dest->ptr) {
-        return;
-    }
-    if (dest->h_value.empty()) {
-        mgb_log_error("backup of the tensor %p not found", dest);
-        return;
-    }
-    produce_tensor(dest, Tensor::make(dest->h_value), false);
-    dest->evict_type = NONE;
-}
-
-void ChannelImpl::remove_dep(TensorInfo* dest) {
-    for (auto i : dest->path.dep_outputs) {
-        auto out_ptr = i.lock();
-        if (out_ptr) {
-            regenerate(out_ptr.get(), true);
-        }
+void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
+    SmallVector<TensorInfo*> workspaces(path->outputs.size(), nullptr);
+    for (auto&& input : path->inputs) {
+        regenerate(input);
+    }
+    for (auto&& output : path->outputs) {
+        if (output == nullptr) {
+            continue;
+        }
+        output->evict_type = NONE;
     }
+    m_buffer.enqueue(ApplyOp{path->op, path->inputs, path->outputs});
 }
 
-void ChannelImpl::do_drop(TensorInfo* dest) {
-    if (dest->evict_type == SWAP) {
-        mgb_log_warn("the evict type of tensor %p was set to SWAP, this DROP operation will be ignored", dest);
-        return;
-    }
-    if (!dest->path.op) {
-        mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", dest);
-        return;
-    }
-    if (dest->recompute_times >= m_max_recompute_time) {
-        mgb_log_warn("the recomputation time for tensor %p exceeds the limit, this drop operation will be ignored", dest);
-        return;
-    }
-    if (!dest->ptr) {
-        return;
+void ChannelImpl::detach_users(TensorInfo* dest) {
+    SmallVector<TensorInfo::ComputePath*> users = dest->users;
+    for (auto* user : users) {
+        for (auto* output : user->outputs) {
+            if (output == nullptr) {
+                continue;
+            }
+            regenerate(output);
+            output->detach_producer();
+        }
     }
-    dest->evict_type = DROP;
-    dest->value_fetched = false;
-    dest->ptr.reset();
+    dest->users.clear();
 }
 
 void ChannelImpl::set_swap_flag(bool flag) {
+    if ((!flag) && (m_enable_evict & SWAP)) {
+        for (auto handle : m_valid_handle) {
+            auto* info = reinterpret_cast<TensorInfo*>(handle);
+            if (info->evict_type == SWAP) {
+                swap_in(info);
+            }
+        }
+    }
     if (flag) {
         m_enable_evict |= SWAP;
     } else {
@@ -429,6 +413,14 @@ void ChannelImpl::set_swap_flag(bool flag) {
 }
 
 void ChannelImpl::set_drop_flag(bool flag) {
+    if ((!flag) && (m_enable_evict & DROP)) {
+        for (auto handle : m_valid_handle) {
+            auto* info = reinterpret_cast<TensorInfo*>(handle);
+            if (info->evict_type == DROP) {
+                recompute(info->producer);
+            }
+        }
+    }
     if (flag) {
         m_enable_evict |= DROP;
     } else {
@@ -440,46 +432,6 @@ void ChannelImpl::set_buffer_length(int length) {
     m_buffer.set_capacity(length);
 }
 
-void ChannelImpl::regenerate(TensorInfo* info, bool must_drop = false) {
-    if (!info->ptr && info->evict_type != NONE) {
-        if (info->evict_type == SWAP) {
-            do_swap_in(info);
-        } else {
-            mgb_assert(info->evict_type == DROP);
-            mgb_assert(info->path.op, "recomputation path not found");
-            auto path = info->path;
-            SmallVector<TensorPtr> inputs;
-            inputs.reserve(path.inputs.size());
-            for (auto i : path.inputs) {
-                mgb_assert(i, "invalid history input");
-                if (!i->ptr) {
-                    regenerate(i.get(), must_drop);
-                }
-                inputs.push_back(i->ptr);
-            }
-            auto outputs = OpDef::apply_on_physical_tensor(*path.op, inputs);
-            for (size_t i = 0; i < outputs.size(); i++) {
-                auto out_ptr = path.outputs[i].lock();
-                if (out_ptr) {
-                    out_ptr->recompute_times++;
-                    if (!out_ptr->ptr && out_ptr->evict_type == DROP) {
-                        produce_tensor(out_ptr.get(), std::move(outputs[i]), false);
-                    }
-                }
-            }
-        }
-    }
-    if (must_drop) {
-        if (info->path.op) {
-            info->path.op.reset();
-            info->path.inputs.clear();
-            if (info->evict_type == DROP) {
-                info->evict_type = NONE;
-            }
-        }
-    }
-}
-
 void ChannelImpl::process_one_task(Command& cmd) {
     //TODO: remove std::visit for support osx 10.12
     std::visit([this](auto& cmd) {
@@ -493,11 +445,6 @@ void ChannelImpl::process_one_task(Command& cmd) {
             tensor_inputs.reserve(cmd.inputs.size());
             // refcnt == 1, owners: [TensorInfo::ptr]
             for (auto i : cmd.inputs) {
-                if (m_enable_evict && i->evict_type != NONE) {
-                    if (!i->ptr) {
-                        regenerate(i);
-                    }
-                }
                 mgb_assert(i->ptr, "Invalid input tensor ptr!");
                 // refcnt ++, owners: [i->ptr, tensor_inputs]
                 tensor_inputs.push_back(i->ptr);
@@ -515,16 +462,14 @@ void ChannelImpl::process_one_task(Command& cmd) {
                     *cmd.op, std::move(tensor_inputs));
             mgb_assert(tensor_outputs.size() == cmd.outputs.size());
             for (size_t i = 0; i < tensor_outputs.size(); ++i) {
+                if (cmd.outputs[i] == nullptr) {
+                    continue;
+                }
                 produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i]));
             }
         } else if constexpr (std::is_same_v<T, Del>) {
             free(cmd.dest);
         } else if constexpr (std::is_same_v<T, GetValue>) {
-            if (m_enable_evict && cmd.dest->evict_type != NONE) {
-                if (!cmd.dest->ptr) {
-                    regenerate(cmd.dest);
-                }
-            }
             mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!");
             cmd.dest->ptr->fetch_value();
             MGB_LOCK_GUARD(m_mutex);
@@ -533,11 +478,12 @@ void ChannelImpl::process_one_task(Command& cmd) {
                 m_cv.notify_all();
             }
         } else if constexpr (std::is_same_v<T, SwapIn>) {
-            do_swap_in(cmd.dest);
+            produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
         } else if constexpr (std::is_same_v<T, SwapOut>) {
-            do_swap_out(cmd.dest);
+            cmd.dest->h_value = cmd.dest->ptr->get_value();
+            cmd.dest->ptr.reset();
        } else if constexpr (std::is_same_v<T, Drop>) {
-            do_drop(cmd.dest);
+            cmd.dest->ptr.reset();
        } else if constexpr (std::is_same_v<T, Move>) {
            produce_tensor(cmd.dest, cmd.src->ptr);
            free(cmd.src);
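With the hunks above, the worker's SwapIn/SwapOut/Drop cases collapse to one or two lines: every decision that needs graph knowledge (producer liveness, recompute limits, evict-type conflicts) has already happened on the main thread, and the worker only serializes data movement in command order. The division of labor amounts to a single-consumer command queue; a self-contained toy model of that split (illustrative only; MegEngine's actual worker and buffer classes differ):

#include <condition_variable>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>

// Toy single-consumer command queue: the main thread enqueues closures
// (ApplyOp / SwapIn / SwapOut / Drop in the real code); one worker drains
// them in FIFO order.
class CommandQueue {
public:
    void enqueue(std::function<void()> cmd) {
        {
            std::lock_guard<std::mutex> lock(m_mutex);
            m_queue.push_back(std::move(cmd));
        }
        m_cv.notify_one();
    }

    ~CommandQueue() {
        {
            std::lock_guard<std::mutex> lock(m_mutex);
            m_closed = true;
        }
        m_cv.notify_one();
        m_worker.join();
    }

private:
    void run() {
        for (;;) {
            std::function<void()> cmd;
            {
                std::unique_lock<std::mutex> lock(m_mutex);
                m_cv.wait(lock, [&] { return m_closed || !m_queue.empty(); });
                if (m_queue.empty()) return;  // closed and fully drained
                cmd = std::move(m_queue.front());
                m_queue.pop_front();
            }
            cmd();  // executed without the lock, like process_one_task()
        }
    }

    std::deque<std::function<void()>> m_queue;
    std::mutex m_mutex;
    std::condition_variable m_cv;
    bool m_closed = false;
    std::thread m_worker{[this] { run(); }};  // declared last, so started last
};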
imperative/src/impl/interpreter_impl.h
@@ -38,22 +38,77 @@ using TensorInfoPtr = std::shared_ptr<TensorInfo>;
 struct TensorInfo {
     TensorPtr ptr;
     LogicalTensorDesc desc;
     // FIXME: broken by drop
     bool value_fetched = false;
     bool invalid = false;
-    bool allow_delete = false;
 
     EvictType evict_type = NONE;
 
     HostTensorND h_value;
-    size_t locked = 0;
+
+    // reserved for auto drop
+    size_t pinned = 0;
     size_t recompute_times = 0;
 
     struct ComputePath {
         std::shared_ptr<OpDef> op;
-        SmallVector<TensorInfoPtr> inputs;
-        SmallVector<std::weak_ptr<TensorInfo>> outputs;
-        SmallVector<std::weak_ptr<TensorInfo>> dep_outputs;
-    } path;
+        SmallVector<TensorInfo*> inputs;
+        SmallVector<TensorInfo*> unique_inputs;
+        SmallVector<TensorInfo*> outputs;
+
+        size_t ref_cnt() {
+            return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr);
+        }
+
+        static ComputePath* make(std::shared_ptr<OpDef> op,
+                                 SmallVector<TensorInfo*> inputs,
+                                 SmallVector<TensorInfo*> outputs) {
+            auto* path = new TensorInfo::ComputePath();
+            path->op = op;
+            path->inputs = inputs;
+            path->outputs = outputs;
+            // dedup
+            SmallVector<TensorInfo*> unique_inputs = inputs;
+            std::sort(unique_inputs.begin(), unique_inputs.end());
+            unique_inputs.erase(std::unique(unique_inputs.begin(), unique_inputs.end()),
+                                unique_inputs.end());
+            path->unique_inputs = unique_inputs;
+            // attach users
+            for (auto input : unique_inputs) {
+                input->users.push_back(path);
+            }
+            // attach producer
+            for (auto output : outputs) {
+                output->producer = path;
+            }
+            return path;
+        }
+    }* producer = nullptr;
+
+    void pin() {
+        ++pinned;
+    }
+
+    void unpin() {
+        --pinned;
+    }
+
+    void detach_producer() {
+        if (!producer) {
+            return;
+        }
+        auto output = std::find(producer->outputs.begin(), producer->outputs.end(), this);
+        mgb_assert(output != producer->outputs.end());
+        *output = nullptr;
+        if (producer->ref_cnt() == 0) {
+            for (auto* input : producer->unique_inputs) {
+                input->users.erase(std::find(input->users.begin(), input->users.end(), producer));
+            }
+            delete producer;
+        }
+        producer = nullptr;
+    }
+
+    SmallVector<ComputePath*> users;
 };
 
 struct Put {
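The ownership rules encoded above are worth spelling out: ComputePath::make() dedups the inputs before registering the path in each input's users list, every output points back via producer, and detach_producer() blanks the output's slot (keeping slot positions aligned with the op's outputs) instead of erasing it; the path frees itself once ref_cnt() reaches zero. A hypothetical driver exercising that lifecycle (assumes the header above is in scope; op and a..d are presumed to exist; SmallVector is assumed to be constructible from a braced list):

// Hypothetical walkthrough, not code from this commit:
void compute_path_lifecycle(std::shared_ptr<OpDef> op,
                            TensorInfo* a, TensorInfo* b,
                            TensorInfo* c, TensorInfo* d) {
    auto* path = TensorInfo::ComputePath::make(op, {a, b, a}, {c, d});
    // After make(): path->unique_inputs holds a and b exactly once;
    // a->users and b->users both contain path;
    // c->producer == d->producer == path; path->ref_cnt() == 2.

    c->detach_producer();
    // path->outputs == {nullptr, d}; ref_cnt() == 1; path stays alive
    // because d may still need it for recompute().

    d->detach_producer();
    // ref_cnt() reaches 0: path removes itself from a->users / b->users
    // and deletes itself; both producer pointers are now nullptr.
}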
@@ -186,17 +241,16 @@ struct ChannelImpl : Interpreter::Channel {
 private:
     TensorInfo* alloc();
     void free(TensorInfo*);
-    void remove_dep(TensorInfo*);
+    void detach_users(TensorInfo*);
 
     void process_one_task(Command&);
     void check_worker_exc_unsafe();
-    void produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice);
-    void do_swap_out(TensorInfo* dest);
-    void do_swap_in(TensorInfo* dest);
-    void do_drop(TensorInfo* dest);
-    void regenerate(TensorInfo* dest, bool must_drop);
+    void produce_tensor(TensorInfo* dest, TensorPtr ptr);
+    void regenerate(TensorInfo* dest);
+    void recompute(TensorInfo::ComputePath* path);
 
     void dispatch_default_cpu(
             std::shared_ptr<OpDef> op,
@@ -235,24 +289,6 @@ private:
         ChannelImpl* m_owner;
     } m_worker;
 
-    struct SharedTensorInfoMap {
-        void insert(TensorInfo* info) {
-            MGB_LOCK_GUARD(mtx);
-            tmap.emplace(info, TensorInfoPtr{info, [](TensorInfo* ptr){ ptr->allow_delete = true; }});
-        }
-        void erase(TensorInfo* info) {
-            MGB_LOCK_GUARD(mtx);
-            tmap.erase(info);
-        }
-        TensorInfoPtr at(TensorInfo* info) {
-            MGB_LOCK_GUARD(mtx);
-            return tmap.at(info);
-        }
-    private:
-        std::mutex mtx;
-        std::unordered_map<TensorInfo*, TensorInfoPtr> tmap;
-    } m_st;
-
     /**
      * Buf a command window for following fuse
      * example: