Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
619d78ed
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
411
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
619d78ed
编写于
11月 17, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(imperative): check async error when getting value
GitOrigin-RevId: 52b8a29932d2abb33f4bb3d4acff91fe53a6a998
上级
2afa0af9
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
37 addition
and
8 deletion
+37
-8
imperative/python/megengine/functional/vision.py
imperative/python/megengine/functional/vision.py
+1
-0
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+13
-7
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+4
-0
imperative/python/test/unit/core/test_interpreter.py
imperative/python/test/unit/core/test_interpreter.py
+9
-0
imperative/src/impl/interpreter/interpreter_impl.cpp
imperative/src/impl/interpreter/interpreter_impl.cpp
+10
-1
未找到文件。
imperative/python/megengine/functional/vision.py
浏览文件 @
619d78ed
...
...
@@ -420,6 +420,7 @@ def warp_affine(
Here all available options for params are listed,
however it does not mean that you can use all the combinations.
On different platforms, different combinations are supported.
``warp_affine`` only support forward inference, Please refer to ``warp_perspective`` if backward is needed.
"""
conv_format
=
_config
.
_get_actual_op_param
(
format
,
_config
.
__conv_format
)
...
...
imperative/python/megengine/jit/tracing.py
浏览文件 @
619d78ed
...
...
@@ -104,6 +104,7 @@ class TensorInfo:
"shape"
,
"is_const"
,
"bound_data"
,
"bound_data_numpy"
,
# resources for execution
"varnode"
,
"data_setter"
,
...
...
@@ -119,12 +120,18 @@ class TensorInfo:
self
.
shape_read
=
None
self
.
value_read
=
None
self
.
bound_data
=
None
self
.
bound_data_numpy
=
None
self
.
data_setter
=
None
self
.
shape_reader
=
None
self
.
value_reader
=
None
self
.
data_reader
=
None
def
get_numpy
(
self
):
if
self
.
bound_data_numpy
is
None
:
self
.
bound_data_numpy
=
self
.
bound_data
.
numpy
()
return
self
.
bound_data_numpy
_io_op_types
=
{
AssertEqual
,
CollectiveComm
,
RemoteSend
,
RemoteRecv
}
...
...
@@ -292,7 +299,7 @@ class trace:
# Const op is represented by a str
assert
isinstance
(
op_
,
str
)
and
op_
==
"Const"
expected
=
self
.
_tinfo
[
ohandles
[
0
]].
bound_data
.
numpy
()
expected
=
self
.
_tinfo
[
ohandles
[
0
]].
get_
numpy
()
shape
=
value
.
shape
if
shape
!=
expected
.
shape
or
dtype
!=
expected
.
dtype
:
eq
=
False
...
...
@@ -369,6 +376,7 @@ class trace:
info
.
dtype
=
x
.
dtype
info
.
shape
=
x
.
shape
info
.
bound_data
=
x
info
.
bound_data_numpy
=
None
info
.
is_const
=
True
x
.
_mixin_handle
=
h
x
.
_recording
=
True
...
...
@@ -612,9 +620,7 @@ class trace:
assert
info
.
external
assert
info
.
bound_data
info
.
varnode
=
graph
.
make_const
(
info
.
bound_data
.
numpy
(),
info
.
bound_data
.
dtype
,
info
.
bound_data
.
device
,
info
.
get_numpy
(),
info
.
bound_data
.
dtype
,
info
.
bound_data
.
device
,
)
continue
...
...
@@ -627,7 +633,7 @@ class trace:
if
info
.
bound_data
:
if
getattr
(
info
,
"is_const"
,
False
):
info
.
varnode
=
graph
.
make_const
(
info
.
bound_data
.
numpy
(),
info
.
get_
numpy
(),
info
.
bound_data
.
dtype
,
info
.
bound_data
.
device
,
)
...
...
@@ -1174,7 +1180,7 @@ class trace:
assert
info
.
external
assert
info
.
bound_data
h2v
[
h
]
=
graph
.
make_const
(
info
.
bound_data
.
numpy
(),
info
.
get_
numpy
(),
dtype
=
info
.
dtype
,
device
=
dumped_device
(
info
),
name
=
info
.
name
,
...
...
@@ -1187,7 +1193,7 @@ class trace:
assert
info
.
external
assert
info
.
bound_data
h2v
[
h
]
=
graph
.
make_const
(
info
.
bound_data
.
numpy
(),
info
.
get_
numpy
(),
dtype
=
info
.
dtype
,
device
=
dumped_device
(
info
),
name
=
info
.
name
,
...
...
imperative/python/src/tensor.cpp
浏览文件 @
619d78ed
...
...
@@ -1074,6 +1074,10 @@ void init_tensor(py::module m) {
[]()
{
interpreter_for_py
->
sync
();
CompNode
::
sync_all
();
CompNode
::
foreach
([](
CompNode
cn
)
{
auto
err
=
cn
.
check_async_error
();
mgb_assert
(
!
err
,
"%s"
,
err
->
what
());
});
sync_py_task_q
();
},
py
::
call_guard
<
py
::
gil_scoped_release
>
());
...
...
imperative/python/test/unit/core/test_interpreter.py
浏览文件 @
619d78ed
...
...
@@ -96,6 +96,15 @@ def test_regression_2870():
(
x
+
x
).
numpy
()
@
pytest
.
mark
.
require_ngpu
(
1
)
def
test_async_error_check
():
src
=
mge
.
tensor
([[
1.0
,
2.0
]])
index
=
mge
.
tensor
([
3
])
val
=
F
.
indexing_one_hot
(
src
,
index
)
with
pytest
.
raises
(
RuntimeError
):
val
.
numpy
()
# NOTE: DO NOT REMOVE THIS TEST
# This is also a compatibility test for
# mge.core.set_option('async_level', 0).
...
...
imperative/src/impl/interpreter/interpreter_impl.cpp
浏览文件 @
619d78ed
...
...
@@ -156,6 +156,8 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
if
(
m_async_level
==
0
)
{
sync_impl
();
info
->
desc
.
comp_node
.
sync
();
auto
err
=
info
->
desc
.
comp_node
.
check_async_error
();
mgb_assert
(
!
err
,
"%s"
,
err
->
what
());
}
return
info
;
}
...
...
@@ -336,6 +338,8 @@ void ChannelImpl::dispatch_kernel(
for
(
auto
&&
oup
:
*
outputs
)
{
auto
info
=
reinterpret_cast
<
TensorInfo
*>
(
oup
);
info
->
ptr
->
comp_node
().
sync
();
auto
err
=
info
->
ptr
->
comp_node
().
check_async_error
();
mgb_assert
(
!
err
,
"%s"
,
err
->
what
());
}
}
}
...
...
@@ -931,7 +935,8 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
MGB_RECORD_EVENT
(
TensorWaitPropEvent
,
info
->
id
,
m_waitee_id
,
prop
);
bool
require_host
=
prop
==
TensorProp
::
HostValue
;
auto
host_available
=
[
&
]
{
return
info
->
ptr
&&
info
->
ptr
->
value_fetched
();
};
if
(
require_host
&&
!
host_available
())
{
bool
wait_host
=
!
host_available
();
if
(
require_host
&&
wait_host
)
{
// avoid dead lock
lock
.
unlock
();
m_buffer
.
enqueue
(
GetValue
{
info
});
...
...
@@ -944,6 +949,10 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
});
MGB_RECORD_EVENT
(
TensorWaitPropFinishEvent
,
info
->
id
,
m_waitee_id
,
prop
);
m_waitee
=
nullptr
;
if
(
require_host
&&
wait_host
)
{
auto
err
=
info
->
ptr
->
comp_node
().
check_async_error
();
mgb_assert
(
!
err
,
"%s"
,
err
->
what
());
}
return
info
->
ptr
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录