Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
4fcf8b49
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
4fcf8b49
编写于
10月 10, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
chore(mge): improve mem borrow
GitOrigin-RevId: 599562260cc1a668788a173b09bdd42a6a6615ca
上级
a7ca0588
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
56 addition
and
16 deletion
+56
-16
imperative/src/impl/physical_tensor.cpp
imperative/src/impl/physical_tensor.cpp
+56
-16
未找到文件。
imperative/src/impl/physical_tensor.cpp
浏览文件 @
4fcf8b49
...
@@ -59,10 +59,7 @@ class CompNodeSyncManager {
...
@@ -59,10 +59,7 @@ class CompNodeSyncManager {
void
emplace
(
uint64_t
t
,
A
&&
a
)
{
void
emplace
(
uint64_t
t
,
A
&&
a
)
{
map
.
emplace_hint
(
map
.
end
(),
t
,
std
::
forward
<
A
>
(
a
));
map
.
emplace_hint
(
map
.
end
(),
t
,
std
::
forward
<
A
>
(
a
));
}
}
void
release
(
uint64_t
t
)
{
void
release
(
uint64_t
t
)
{
map
.
erase
(
map
.
begin
(),
map
.
upper_bound
(
t
));
}
auto
it
=
map
.
upper_bound
(
t
);
map
.
erase
(
map
.
begin
(),
it
);
}
};
};
//! next virtual event
//! next virtual event
...
@@ -99,6 +96,7 @@ class CompNodeSyncManager {
...
@@ -99,6 +96,7 @@ class CompNodeSyncManager {
return
cndata
.
events
.
emplace_hint
(
cndata
.
events
.
end
(),
cndata
.
next
++
,
e
);
return
cndata
.
events
.
emplace_hint
(
cndata
.
events
.
end
(),
cndata
.
next
++
,
e
);
}
}
// get a real event t' such that t <= t'
std
::
pair
<
uint64_t
,
CompNode
::
Event
*>
get_event
(
std
::
pair
<
uint64_t
,
CompNode
::
Event
*>
get_event
(
CompNode
cn
,
size_t
cnid
,
uint64_t
t
,
std
::
unique_lock
<
std
::
mutex
>&
lock
)
{
CompNode
cn
,
size_t
cnid
,
uint64_t
t
,
std
::
unique_lock
<
std
::
mutex
>&
lock
)
{
auto
&
cndata
=
m_cndata
[
cnid
];
auto
&
cndata
=
m_cndata
[
cnid
];
...
@@ -145,8 +143,11 @@ class CompNodeSyncManager {
...
@@ -145,8 +143,11 @@ class CompNodeSyncManager {
std
::
vector
<
Stat
>
stats
;
std
::
vector
<
Stat
>
stats
;
std
::
vector
<
Item
>
todos
;
std
::
vector
<
Item
>
todos
;
std
::
vector
<
bool
>
updated
;
std
::
unique_lock
lock
(
m_mtx
);
std
::
unique_lock
lock
(
m_mtx
);
for
(;;)
{
for
(;;)
{
updated
.
clear
();
updated
.
resize
(
m_cndata
.
size
(),
false
);
// copy events to a temporary storage so that we may unlock while polling
// copy events to a temporary storage so that we may unlock while polling
stats
.
resize
(
m_cndata
.
size
());
stats
.
resize
(
m_cndata
.
size
());
for
(
size_t
cnid
=
0
;
cnid
<
m_cndata
.
size
();
++
cnid
)
{
for
(
size_t
cnid
=
0
;
cnid
<
m_cndata
.
size
();
++
cnid
)
{
...
@@ -192,35 +193,63 @@ class CompNodeSyncManager {
...
@@ -192,35 +193,63 @@ class CompNodeSyncManager {
lock
.
lock
();
lock
.
lock
();
// update completed
for
(
auto
[
cnid
,
stat
]
:
views
::
enumerate
(
stats
))
{
if
(
stat
.
num_success
==
0
)
{
continue
;
}
auto
t
=
stat
.
it
->
first
;
auto
&
cndata
=
m_cndata
[
cnid
];
if
(
cndata
.
completed
<
t
)
{
cndata
.
completed
=
t
;
updated
[
cnid
]
=
true
;
// also propagate by the transitive <= relation to ensure that
// we can safely delete ordering information without performance
// degradation even if some completion events are missed by our query
auto
it
=
cndata
.
ordering
.
upper_bound
(
t
);
if
(
it
!=
cndata
.
ordering
.
begin
())
{
it
=
std
::
prev
(
it
);
for
(
auto
[
cnid
,
t
]
:
views
::
enumerate
(
it
->
second
))
{
auto
&
cndata
=
m_cndata
[
cnid
];
if
(
cndata
.
completed
<
t
)
{
cndata
.
completed
=
t
;
updated
[
cnid
]
=
true
;
}
}
}
}
}
// release dev storage
// release dev storage
for
(
size_t
receiver_cnid
=
0
;
receiver_cnid
<
m_cndata
.
size
();
for
(
size_t
receiver_cnid
=
0
;
receiver_cnid
<
m_cndata
.
size
();
++
receiver_cnid
)
{
++
receiver_cnid
)
{
for
(
size_t
releaser_cnid
=
0
;
for
(
size_t
releaser_cnid
=
0
;
releaser_cnid
<
m_cndata
[
receiver_cnid
].
release_queues
.
size
();
releaser_cnid
<
m_cndata
[
receiver_cnid
].
release_queues
.
size
();
++
releaser_cnid
)
{
++
releaser_cnid
)
{
if
(
releaser_cnid
>=
stats
.
size
()
||
if
(
!
(
releaser_cnid
<
updated
.
size
()
&&
updated
[
releaser_cnid
]))
{
stats
[
releaser_cnid
].
num_success
==
0
)
{
continue
;
continue
;
}
}
auto
&
q
=
m_cndata
[
receiver_cnid
].
release_queues
[
releaser_cnid
];
auto
&
q
=
m_cndata
[
receiver_cnid
].
release_queues
[
releaser_cnid
];
q
.
release
(
stats
[
releaser_cnid
].
it
->
first
);
q
.
release
(
m_cndata
[
releaser_cnid
].
completed
);
}
}
}
}
for
(
size_t
cnid
=
0
;
cnid
<
stats
.
size
();
++
cnid
)
{
for
(
size_t
cnid
=
0
;
cnid
<
updated
.
size
();
++
cnid
)
{
if
(
stats
[
cnid
].
num_success
==
0
)
{
if
(
!
updated
[
cnid
]
)
{
continue
;
continue
;
}
}
auto
&
cndata
=
m_cndata
[
cnid
];
auto
&
cndata
=
m_cndata
[
cnid
];
auto
it
=
stats
[
cnid
].
it
;
auto
t
=
cndata
.
completed
;
auto
t
=
it
->
first
;
// update completed
cndata
.
completed
=
t
;
// release host storage
// release host storage
cndata
.
host_release_queue
.
release
(
t
);
cndata
.
host_release_queue
.
release
(
t
);
// remove completed events
// remove completed events
auto
&
events
=
cndata
.
events
;
[
&
](
auto
&
map
)
{
events
.
erase
(
events
.
begin
(),
std
::
next
(
it
));
map
.
erase
(
map
.
begin
(),
map
.
upper_bound
(
t
));
}(
cndata
.
events
);
// delete ordering information
[
&
](
auto
&
map
)
{
map
.
erase
(
map
.
begin
(),
map
.
upper_bound
(
t
));
}(
cndata
.
ordering
);
}
}
using
namespace
std
::
literals
;
using
namespace
std
::
literals
;
...
@@ -287,6 +316,17 @@ public:
...
@@ -287,6 +316,17 @@ public:
auto
waitee_id
=
get_cnid_unsafe
(
waitee
);
auto
waitee_id
=
get_cnid_unsafe
(
waitee
);
auto
&
waiter_data
=
m_cndata
.
at
(
waiter_id
);
auto
&
waiter_data
=
m_cndata
.
at
(
waiter_id
);
auto
&
waitee_data
=
m_cndata
.
at
(
waitee_id
);
auto
&
waitee_data
=
m_cndata
.
at
(
waitee_id
);
if
(
t
<=
waitee_data
.
completed
)
{
return
;
}
if
(
waiter_data
.
ordering
.
size
()
&&
waitee_id
<
waiter_data
.
ordering
.
rbegin
()
->
second
.
size
()
&&
t
<=
waiter_data
.
ordering
.
rbegin
()
->
second
[
waitee_id
])
{
return
;
}
auto
[
t_waitee
,
e
]
=
get_event
(
waitee
,
waitee_id
,
t
,
lock
);
auto
[
t_waitee
,
e
]
=
get_event
(
waitee
,
waitee_id
,
t
,
lock
);
// DO NOT unlock around this line! Event* could be invalidated!
// DO NOT unlock around this line! Event* could be invalidated!
...
@@ -301,7 +341,7 @@ public:
...
@@ -301,7 +341,7 @@ public:
ordering
[
waitee_id
]
=
t_waitee
;
ordering
[
waitee_id
]
=
t_waitee
;
ordering
[
waiter_id
]
=
t_waiter
;
ordering
[
waiter_id
]
=
t_waiter
;
{
{
auto
it
=
waitee_data
.
ordering
.
low
er_bound
(
t_waitee
);
auto
it
=
waitee_data
.
ordering
.
upp
er_bound
(
t_waitee
);
if
(
it
!=
waitee_data
.
ordering
.
begin
())
{
if
(
it
!=
waitee_data
.
ordering
.
begin
())
{
for
(
auto
[
a
,
b
]
:
views
::
zip
(
ordering
,
std
::
prev
(
it
)
->
second
))
{
for
(
auto
[
a
,
b
]
:
views
::
zip
(
ordering
,
std
::
prev
(
it
)
->
second
))
{
static_assert
(
std
::
is_lvalue_reference_v
<
decltype
(
a
)
>
);
static_assert
(
std
::
is_lvalue_reference_v
<
decltype
(
a
)
>
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录