Commit d9a9d9d4
Authored Sep 07, 2021 by Megvii Engine Team

fix(imperative/fastrun): set workspace limit for imperative rt

GitOrigin-RevId: 474dc691a3ec20b09eac3e7b6682f622f6e56774
Parent: a09a2b73
Showing 5 changed files with 34 additions and 10 deletions (+34 -10)

    imperative/src/impl/proxy_graph.cpp             +13  -0
    imperative/src/impl/proxy_graph.h                 +5  -0
    src/core/impl/comp_node/cuda/comp_node.cpp        +7  -0
    src/core/include/megbrain/comp_node.h             +3  -0
    src/opr/impl/internal/megdnn_opr_wrapper.cpp      +6 -10
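Reading the commit message together with the touched files, the change wires a per-graph workspace-limit callback into the imperative (proxy-graph) runtime: instead of relying on a static-graph option, the proxy graph registers a WorkspaceLimitHook implementation derived from free device memory, which fastrun-style algorithm selection can then query. The sketch below is a minimal stand-in for that per-graph callback pattern, not MegEngine code; Graph, WorkspaceLimitRegistry and get_limit are hypothetical names used only for illustration.

    #include <cstddef>
    #include <functional>
    #include <unordered_map>
    #include <utility>

    // Callback type: given a device and the previously configured limit,
    // return the workspace limit (in bytes) that algorithm selection may use.
    using GetWorkspaceLimitImpl = std::function<size_t(int device, size_t old_limit)>;

    struct Graph {};  // stand-in for a computing graph

    class WorkspaceLimitRegistry {
        std::unordered_map<const Graph*, GetWorkspaceLimitImpl> m_impls;

    public:
        // Register (or overwrite) the callback attached to one graph.
        void set_impl(const Graph* g, GetWorkspaceLimitImpl impl) {
            m_impls[g] = std::move(impl);
        }

        // Query the limit; fall back to the caller-provided value when no
        // callback has been registered for this graph.
        size_t get_limit(const Graph* g, int device, size_t old_limit) const {
            auto it = m_impls.find(g);
            return it == m_impls.end() ? old_limit : it->second(device, old_limit);
        }
    };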
imperative/src/impl/proxy_graph.cpp

@@ -15,6 +15,7 @@
 #include "megbrain/graph/static_infer.h"
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/opr_attr.h"
+#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
 #include "megbrain/opr/io.h"
 #include "megbrain/opr/tensor_manip.h"
 #include "megbrain/opr/utility.h"

@@ -509,6 +510,8 @@ SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs(
         const OpDef& opdef, const SmallVector<Tensor*>& inputs) {
     SmallVector<LogicalTensorDesc> ret;
     CUR_OPR_GUARD(get_proxy_opr(opdef, inputs));
+    ::mgb::opr::intl::WorkspaceLimitHook::set_impl(
+            m_graph.get(), ProxyGraph::get_workspace_limit);
     do_shape_infer(true);
     for (auto&& i : m_cur_opr->usable_output()) {
         mgb_assert(i->dtype().valid() && i->comp_node().valid());

@@ -547,6 +550,14 @@ void ProxyGraph::init_output_tensor(
     // get proxy opr
     auto proxy = m_cur_opr;
+    auto get_workspace_size = [=](CompNode cn, size_t old_limit) {
+        size_t limit = 0;
+        for (auto&& var : workspaces) {
+            limit += var->dtype().size(var->shape().total_nr_elems());
+        }
+        return limit;
+    };
+    ::mgb::opr::intl::WorkspaceLimitHook::set_impl(m_graph.get(), get_workspace_size);
     do_shape_infer(true);
     size_t j = 0;

@@ -640,6 +651,8 @@ std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph::
         const SmallVector<MemoryDesc>& inputs_mems) {
     auto opr = get_proxy_opr(def, inputs_tensors);
     CUR_OPR_GUARD(opr);
+    ::mgb::opr::intl::WorkspaceLimitHook::set_impl(
+            m_graph.get(), ProxyGraph::get_workspace_limit);
     do_shape_infer(true);
     SmallVector<MemoryDesc> outputs;
     SmallVector<MemoryDesc> workspaces;
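Two different callbacks are registered here: infer_output_attrs and the MemoryDesc path use the static ProxyGraph::get_workspace_limit (based on free device memory), while init_output_tensor, where the workspace vars are already shape-inferred, registers a lambda that simply sums the byte sizes of those workspaces. The toy program below (not MegEngine code; element counts and the float32 dtype are made-up values) illustrates what that lambda computes, i.e. the sum of dtype().size(shape().total_nr_elems()) over all workspace vars.

    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t nr_elems[] = {1024, 4096};  // two hypothetical workspaces
        size_t limit = 0;
        for (size_t n : nr_elems) {
            limit += n * sizeof(float);  // 4 bytes per float32 element
        }
        std::printf("workspace limit = %zu bytes\n", limit);  // prints 20480
        return 0;
    }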
imperative/src/impl/proxy_graph.h

@@ -27,6 +27,11 @@ public:
     static std::unique_ptr<MegBrainError> get_async_error() {
         return std::move(tm_async_error);
     }
+    static size_t get_workspace_limit(CompNode cn, size_t old_limit) {
+        size_t free = cn.get_free_mem();
+        size_t lmt = cn.get_max_block_size_available();
+        return std::max(lmt, free);
+    }

     /********************** Physical Tensor API **********************/
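Note that the new helper ignores old_limit and returns whichever is larger: the free memory the backend reports, or the largest block its allocator can still hand out. A minimal sketch of that policy, with illustrative constants in place of real CompNode queries:

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    int main() {
        size_t free_mem = 6ull << 30;   // pretend the driver reports 6 GiB free
        size_t max_block = 2ull << 30;  // pretend the largest cached block is 2 GiB
        size_t limit = std::max(max_block, free_mem);  // same max-of-two rule
        std::printf("workspace limit = %zu bytes\n", limit);
        return 0;
    }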
src/core/impl/comp_node/cuda/comp_node.cpp

@@ -273,6 +273,13 @@ public:
         activate();
         return m_mem_alloc->get_max_block_size_available();
     }
+
+    size_t get_free_mem() override {
+        m_env.cuda_env().activate();
+        size_t tot, free;
+        MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot));
+        return free;
+    }
 #endif

     Locator locator() override { return m_locator; }
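The CUDA comp node implements get_free_mem() by activating its context and asking the CUDA runtime via cudaMemGetInfo, keeping the free-bytes figure and discarding the total. The standalone program below only demonstrates that runtime call; it is not MegEngine code and assumes a CUDA toolkit and device are available.

    #include <cuda_runtime.h>
    #include <cstdio>

    // cudaMemGetInfo fills in the free and total device memory, in bytes,
    // for the current device.
    int main() {
        size_t free_bytes = 0, total_bytes = 0;
        cudaError_t err = cudaMemGetInfo(&free_bytes, &total_bytes);
        if (err != cudaSuccess) {
            std::printf("cudaMemGetInfo failed: %s\n", cudaGetErrorString(err));
            return 1;
        }
        std::printf("free: %zu bytes, total: %zu bytes\n", free_bytes, total_bytes);
        return 0;
    }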
src/core/include/megbrain/comp_node.h

@@ -336,6 +336,8 @@ public:
     size_t get_max_block_size_available() const {
         return m_impl->get_max_block_size_available();
     }
+
+    size_t get_free_mem() const { return m_impl->get_free_mem(); }
 #endif

     //! change to another stream on the same memory node

@@ -519,6 +521,7 @@ protected:
     }
     virtual size_t get_used_memory() { return 0; }
     virtual size_t get_max_block_size_available() { return 0; }
+    virtual size_t get_free_mem() { return 0; }
 #endif

     virtual Locator locator() = 0;
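The public CompNode method forwards to the backend's virtual impl, and the base impl returns 0, so backends that do not implement the query (only the CUDA comp node does in this commit) report "no information" instead of failing to build. A simplified sketch of that wrapper/impl pattern, using made-up class names rather than the real CompNode types:

    #include <cstddef>
    #include <cstdio>

    struct NodeImpl {
        virtual ~NodeImpl() = default;
        virtual size_t get_free_mem() { return 0; }  // default: unknown
    };

    struct FakeCudaImpl : NodeImpl {
        size_t get_free_mem() override { return 6ull << 30; }  // pretend 6 GiB free
    };

    struct Node {
        NodeImpl* m_impl;
        // public handle just forwards to the backend implementation
        size_t get_free_mem() const { return m_impl->get_free_mem(); }
    };

    int main() {
        NodeImpl cpu;       // backend without the query
        FakeCudaImpl cuda;  // backend that reports free memory
        std::printf("%zu %zu\n", Node{&cpu}.get_free_mem(), Node{&cuda}.get_free_mem());
        return 0;
    }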
src/opr/impl/internal/megdnn_opr_wrapper.cpp

@@ -428,13 +428,12 @@ void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl /* impl */) {
     mgb_assert(false);
 }

-const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl()
-        const {
+const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() const {
     mgb_assert(false);
 }

-void WorkspaceLimitHook::set_impl(ComputingGraph* /* graph */,
-                                  GetWorkspaceLimitImpl /* impl */) {
+void WorkspaceLimitHook::set_impl(
+        ComputingGraph* /* graph */, GetWorkspaceLimitImpl /* impl */) {
     mgb_assert(false);
 }

@@ -447,13 +446,11 @@ void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl impl) {
     m_impl = std::move(impl);
 }

-const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl()
-        const {
+const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() const {
     return m_impl;
 }

-void WorkspaceLimitHook::set_impl(ComputingGraph* graph,
-                                  GetWorkspaceLimitImpl impl) {
+void WorkspaceLimitHook::set_impl(ComputingGraph* graph, GetWorkspaceLimitImpl impl) {
     mgb_assert(graph->options().imperative_proxy_graph);
     auto maker = []() { return std::make_shared<WorkspaceLimitHook>(); };
     graph->options()

@@ -464,8 +461,7 @@ void WorkspaceLimitHook::set_impl(ComputingGraph* graph,
 const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl(
         ComputingGraph* graph) {
     mgb_assert(graph->options().imperative_proxy_graph);
-    auto container =
-            graph->options().user_data.get_user_data<WorkspaceLimitHook>();
+    auto container = graph->options().user_data.get_user_data<WorkspaceLimitHook>();
     mgb_assert(container.second == 1);
     return container.first[0]->get_impl();
 }
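Beyond the formatting changes, this file is where the per-graph hook lives: set_impl(ComputingGraph*, impl) asserts the graph is an imperative proxy graph, lazily creates a WorkspaceLimitHook in the graph's user_data via the maker lambda, and stores the callback, while get_impl(ComputingGraph*) fetches it back and asserts exactly one instance exists. The sketch below imitates that per-graph user-data lookup with standard containers; GraphOptions, Hook and get_user_data_or_create are simplified stand-ins, not MegEngine's UserDataContainer API.

    #include <cassert>
    #include <cstddef>
    #include <functional>
    #include <memory>
    #include <typeindex>
    #include <typeinfo>
    #include <unordered_map>

    // A hook object holding the workspace-limit callback for one graph.
    struct Hook {
        std::function<size_t(size_t)> impl;
    };

    // Simplified stand-in for ComputingGraph::options() with a user-data map.
    struct GraphOptions {
        bool imperative_proxy_graph = true;
        std::unordered_map<std::type_index, std::shared_ptr<void>> user_data;

        // Fetch the per-graph instance of T, creating it with `maker` on first use.
        template <typename T>
        std::shared_ptr<T> get_user_data_or_create(
                const std::function<std::shared_ptr<T>()>& maker) {
            auto& slot = user_data[std::type_index(typeid(T))];
            if (!slot)
                slot = maker();
            return std::static_pointer_cast<T>(slot);
        }
    };

    int main() {
        GraphOptions options;
        assert(options.imperative_proxy_graph);  // only proxy graphs carry the hook
        auto maker = []() { return std::make_shared<Hook>(); };
        // "set_impl": create the hook lazily and store the callback.
        options.get_user_data_or_create<Hook>(maker)->impl = [](size_t old_limit) {
            return old_limit * 2;  // toy callback
        };
        // "get_impl": later lookups see the same instance and its callback.
        size_t limit = options.get_user_data_or_create<Hook>(maker)->impl(8);
        assert(limit == 16);
        return 0;
    }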