MegEngine commit 50c4daac
Authored Nov 02, 2020 by Megvii Engine Team
feat(mge/interpreter): add async_level mechanism for Interpreter
GitOrigin-RevId: 8615a23b75b7e3172d724acc8f7fffd2cf9b73d5
Parent: 82b0f677
Showing 5 changed files with 64 additions and 9 deletions (+64, -9)
imperative/python/src/imperative_rt.cpp                      +3   -1
imperative/python/test/unit/core/test_async_level.py         +35  -0
imperative/src/impl/interpreter_impl.cpp                      +19  -7
imperative/src/impl/interpreter_impl.h                        +6   -1
imperative/src/include/megbrain/imperative/interpreter.h      +1   -0
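For users, the commit boils down to two new functions on the imperative runtime. A minimal usage sketch, based on the new test file and on the level descriptions added in interpreter_impl.h (level 2: device- and user-side errors are asynchronous; level 1: user-side errors are synchronous; level 0: both are synchronous):

import pytest
from megengine.core._imperative_rt.imperative import config_async_level, get_async_level

# Choose how eagerly errors should surface: 0, 1 or 2.
config_async_level(2)
assert get_async_level() == 2

# Levels outside [0, 2] are rejected by the assertion added in this commit.
with pytest.raises(RuntimeError):
    config_async_level(3)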
imperative/python/src/imperative_rt.cpp

@@ -77,12 +77,14 @@ void init_imperative_rt(py::module m) {
             .def("get_shape", &Interpreter::Channel::get_shape)
             .def("_get_dev_tensor", &Interpreter::Channel::get_dev_tensor)
             .def("apply_op", &Interpreter::Channel::apply_op)
+            .def("config_async_level", &Interpreter::Channel::config_async_level)
+            .def("get_async_level", &Interpreter::Channel::get_async_level)
             .def("sync", &Interpreter::Channel::sync, py::call_guard<py::gil_scoped_release>());
 
     std::unique_ptr<Interpreter::Channel> ch = Interpreter::inst().create_channel();
     m.attr("interpreter") = py::detail::make_caster<decltype(ch)>::cast(
             std::move(ch), py::return_value_policy::move, {});
-    for (auto name : {"put", "delete", "get_value", "get_dtype", "get_device", "get_shape", "_get_dev_tensor", "apply_op"}) {
+    for (auto name : {"put", "delete", "get_value", "get_dtype", "get_device", "get_shape", "_get_dev_tensor", "apply_op", "config_async_level", "get_async_level"}) {
         m.attr(name) = m.attr("interpreter").attr(name);
     }
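The re-export loop above makes the two new channel methods available both as attributes of the interpreter object and as module-level names. A short sketch of the two access paths; that the submodule also exposes the `interpreter` attribute assigned above is an assumption on my part, the tests only import the module-level functions:

from megengine.core._imperative_rt import imperative

# Module-level names are re-exports of the default channel's bound methods,
# so both calls below are expected to hit the same channel.
imperative.config_async_level(1)
assert imperative.get_async_level() == 1
assert imperative.interpreter.get_async_level() == 1  # assumed equivalent path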
imperative/python/test/unit/core/test_async_level.py (new file, mode 100644)

import pytest

import megengine as mge
import megengine.functional as F
from megengine.core._imperative_rt.imperative import config_async_level, get_async_level


def test_basic():
    config_async_level(2)
    assert get_async_level() == 2
    with pytest.raises(RuntimeError):
        config_async_level(3)


def test_level1_infer_value():
    config_async_level(1)
    a = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32")
    b = mge.tensor([1, 1], dtype="float32")
    # make DepType::VALUE unknown
    c = b * 2
    with pytest.raises(RuntimeError):
        d = F.reshape(a, c)


def test_level1_infer_shape_with_unknown():
    config_async_level(2)
    a = mge.tensor([[1, 2, 2, 3]], dtype="float32")
    b = mge.tensor([1, 1])
    c = b * 2
    # make DepType::SHAPE unknown
    d = F.reshape(a, c)
    config_async_level(1)
    e = mge.tensor([[1, 2]], dtype="float32")
    with pytest.raises(RuntimeError):
        f = F.matmul(d, e)
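The new cases are plain pytest tests; one possible way to run just this file from a checkout (assuming pytest and a built MegEngine are importable):

import pytest

# -v prints one line per test; the path is relative to the repository root.
raise SystemExit(pytest.main(["-v", "imperative/python/test/unit/core/test_async_level.py"]))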
imperative/src/impl/interpreter_impl.cpp

@@ -54,21 +54,25 @@ void ChannelImpl::del(void* handle) {
 SmallVector<void*> ChannelImpl::apply_op(
         std::shared_ptr<OpDef> op,
         const SmallVector<void*>& inputs) {
+    SmallVector<TensorInfo*> input_infos;
+    input_infos.reserve(inputs.size());
     SmallVector<LogicalTensorDesc> input_descs;
     input_descs.reserve(inputs.size());
-    for (auto h : inputs) {
-        auto info = reinterpret_cast<TensorInfo*>(h);
+    for (auto i : inputs) {
+        auto info = reinterpret_cast<TensorInfo*>(i);
+        input_infos.push_back(info);
         input_descs.push_back(info->desc);
     }
     auto output_descs = OpDef::infer_output_attrs_fallible(*op, input_descs);
     ApplyOp cmd{std::move(op)};
-    cmd.inputs.reserve(inputs.size());
-    for (auto i : inputs) {
-        cmd.inputs.push_back(reinterpret_cast<TensorInfo*>(i));
-    }
+    cmd.inputs = std::move(input_infos);
     cmd.outputs.reserve(output_descs.size());
     SmallVector<void*> outputs;
+    bool is_fallible = false;
     for (auto&& desc : output_descs) {
+        if (desc.layout.ndim == 0) {
+            is_fallible = true;
+        }
         auto info = alloc();
         info->desc = desc;
         m_valid_handle.insert(info);
@@ -76,6 +80,9 @@ SmallVector<void*> ChannelImpl::apply_op(
         outputs.push_back(info);
     }
     m_worker.add_task(std::move(cmd));
+    if (is_fallible && m_async_level <= 1) {
+        sync();
+    }
     return outputs;
 }
 
@@ -162,7 +169,12 @@ void ChannelImpl::close() {
 }
 
 void ChannelImpl::config_async_level(int level) {
-    mgb_assert(0);
+    mgb_assert(level <= 2 and level >= 0, "async_level should be 0, 1 or 2");
+    m_async_level = level;
+}
+
+int ChannelImpl::get_async_level() {
+    return m_async_level;
 }
 
 TensorInfo* ChannelImpl::alloc() {
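The `is_fallible && m_async_level <= 1` branch is what gives level 1 its meaning: whenever static shape inference for an op falls back (an output with `ndim == 0`), the channel synchronizes before returning, so a user-side error is raised at the failing call instead of at some later sync point. A rough Python-level sketch of the observable difference, reusing the setup from the new tests; the level-2 behavior described in the trailing comment is inferred from the code rather than stated in the commit:

import megengine as mge
import megengine.functional as F
from megengine.core._imperative_rt.imperative import config_async_level

a = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32")  # 6 elements
target = mge.tensor([1, 1], dtype="float32") * 2           # value produced asynchronously

config_async_level(1)
try:
    F.reshape(a, target)  # 6 elements cannot be reshaped to (2, 2)
except RuntimeError:
    print("level 1: the error surfaces here, at the failing call")

# At level 2 the same apply_op would only be enqueued; the failure would
# surface later, e.g. on a subsequent op or an explicit sync.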
imperative/src/impl/interpreter_impl.h

@@ -74,6 +74,7 @@ struct ChannelImpl : Interpreter::Channel {
     void close() override;
 
     void config_async_level(int level) override;
+    int get_async_level() override;
 
 private:
     TensorInfo* alloc();
@@ -101,7 +102,11 @@ private:
         ChannelImpl* m_owner;
     } m_worker;
 
-    int m_async_level = 2;
+    //! config whether raise error exactly when invoking op.
+    //! level 2: both device and user side errors are async;
+    //! level 1: user side errors are sync;
+    //! level 0: both sync.
+    int m_async_level = 1;
 };
 
 } // namespace mgb::imperative::interpreter::intl
imperative/src/include/megbrain/imperative/interpreter.h

@@ -41,6 +41,7 @@ struct Interpreter {
         virtual void close() = 0;
 
         virtual void config_async_level(int level) = 0;
+        virtual int get_async_level() = 0;
     };
 
     virtual std::unique_ptr<Channel> create_channel() = 0;