Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
8109cc4b
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
8109cc4b
编写于
9月 23, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(imperative): set exception which worker thread throws for rendezvous
GitOrigin-RevId: f583888fdfdd422262a9bd0bcd3425055ce51a94
上级
2f4a75e7
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
109 addition
and
10 deletion
+109
-10
imperative/python/megengine/core/tensor/megbrain_graph.py
imperative/python/megengine/core/tensor/megbrain_graph.py
+10
-1
imperative/python/src/graph_rt.cpp
imperative/python/src/graph_rt.cpp
+39
-7
imperative/python/src/graph_rt.h
imperative/python/src/graph_rt.h
+43
-2
imperative/python/test/unit/core/test_megbrain_graph.py
imperative/python/test/unit/core/test_megbrain_graph.py
+17
-0
未找到文件。
imperative/python/megengine/core/tensor/megbrain_graph.py
浏览文件 @
8109cc4b
...
...
@@ -50,7 +50,16 @@ class Graph(_imperative_rt.ComputingGraph):
def
execute
(
self
,
*
args
):
assert
self
.
_future
is
None
self
.
_future
=
self
.
_executor
.
submit
(
self
.
_function
.
execute
,
*
args
)
def
wrapped
(
*
args
):
try
:
self
.
_function
.
execute
(
*
args
)
except
Exception
as
exc
:
for
i
in
self
.
_function
.
_all_rendezvous
:
i
.
set_exception
(
str
(
exc
))
raise
exc
self
.
_future
=
self
.
_executor
.
submit
(
wrapped
,
*
args
)
def
wait
(
self
):
assert
self
.
_future
is
not
None
...
...
imperative/python/src/graph_rt.cpp
浏览文件 @
8109cc4b
...
...
@@ -49,17 +49,28 @@ class _CompGraphProfilerImpl {
return
json
->
to_string
();
}
};
struct
WeakRendezvousArray
:
public
std
::
vector
<
std
::
weak_ptr
<
RendezvousBase
>>
,
public
UserDataContainer
::
UserData
{
MGB_TYPEINFO_OBJ_DECL
;
};
MGB_TYPEINFO_OBJ_IMPL
(
WeakRendezvousArray
);
}
#define DEF_READWRITE(name) .def_readwrite(#name, &CURRENT_CLASS::name)
template
<
typename
T
>
auto
def_rendezvous
(
py
::
object
m
,
const
char
*
name
)
{
return
py
::
class_
<
Rendezvous
<
T
>
,
std
::
shared_ptr
<
Rendezvous
<
T
>>>
(
m
,
name
)
.
def
(
py
::
init
([](){
return
std
::
make_shared
<
Rendezvous
<
T
>>
();}))
.
def
(
py
::
init
([](){
return
Rendezvous
<
T
>::
make
();}))
.
def
(
"set"
,
[](
Rendezvous
<
T
>&
r
,
T
v
)
{
r
.
set
(
std
::
move
(
v
));})
.
def
(
"get"
,
[](
Rendezvous
<
T
>&
r
)
{
return
r
.
get
();},
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"drop"
,
&
Rendezvous
<
T
>::
drop
)
.
def
(
"reset"
,
&
Rendezvous
<
T
>::
reset
);
.
def
(
"reset"
,
&
Rendezvous
<
T
>::
reset
)
.
def
(
"set_exception"
,
[](
Rendezvous
<
T
>&
r
,
std
::
string
&&
message
)
{
r
.
set_exception
(
std
::
make_exception_ptr
(
std
::
runtime_error
(
std
::
move
(
message
))));
});
}
using
TensorAttr
=
LogicalTensorDesc
;
...
...
@@ -186,7 +197,21 @@ void init_graph_rt(py::module m) {
py
::
class_
<
cg
::
AsyncExecutable
>
(
m
,
"AsyncExecutable"
)
.
def
(
"execute"
,
&
cg
::
AsyncExecutable
::
execute
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"wait"
,
&
cg
::
AsyncExecutable
::
wait
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
.
def
(
"wait"
,
&
cg
::
AsyncExecutable
::
wait
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
// only used for exception handle
.
def_property_readonly
(
"_all_rendezvous"
,
[](
cg
::
AsyncExecutable
*
exec
)
{
auto
ud
=
exec
->
owner_graph
()
->
options
().
user_data
.
get_user_data
<
WeakRendezvousArray
>
();
std
::
vector
<
std
::
shared_ptr
<
RendezvousBase
>>
ret
;
if
(
ud
.
second
)
{
for
(
auto
&&
r
:
*
ud
.
first
[
0
])
{
if
(
auto
p
=
r
.
lock
())
{
ret
.
emplace_back
(
std
::
move
(
p
));
}
}
}
return
ret
;
});
auto
PyComputingGraph
=
py
::
class_
<
cg
::
ComputingGraph
,
std
::
shared_ptr
<
cg
::
ComputingGraph
>>
(
m
,
"ComputingGraph"
)
.
def
(
py
::
init
(
py
::
overload_cast
<>
(
&
cg
::
ComputingGraph
::
make
)))
...
...
@@ -483,7 +508,14 @@ void init_graph_rt(py::module m) {
},
py
::
arg
(),
py
::
arg
(),
py
::
arg
(),
py
::
arg
()
=
py
::
none
(),
py
::
arg
()
=
py
::
tuple
(),
py
::
arg
(
"graph"
)
=
py
::
none
());
auto
output_callback
=
[](
auto
callback
,
const
std
::
vector
<
cg
::
VarNode
*>&
inputs
,
bool
borrow
=
false
,
bool
prefer_host_value
=
false
)
{
auto
output_callback
=
[](
auto
callback
,
const
std
::
vector
<
cg
::
VarNode
*>&
inputs
,
std
::
shared_ptr
<
RendezvousBase
>
r
=
{},
bool
borrow
=
false
,
bool
prefer_host_value
=
false
)
{
if
(
r
)
{
mgb_assert
(
inputs
.
size
());
auto
cg
=
inputs
[
0
]
->
owner_graph
();
cg
->
options
().
user_data
.
get_user_data_or_create
<
WeakRendezvousArray
>
()
->
emplace_back
(
r
);
}
SymbolVarArray
sinputs
;
for
(
auto
i
:
inputs
)
{
sinputs
.
emplace_back
(
i
);
...
...
@@ -508,7 +540,7 @@ void init_graph_rt(py::module m) {
auto
f
=
[
p
](
DeviceTensorND
dv
)
{
p
->
set
(
std
::
move
(
dv
));
};
return
output_callback
(
std
::
move
(
f
),
std
::
move
(
inputs
));
return
output_callback
(
std
::
move
(
f
),
std
::
move
(
inputs
)
,
p
);
});
m
.
def
(
"value_output_callback"
,
[
output_callback
](
std
::
shared_ptr
<
Rendezvous
<
HostNDWithEvent
>>
p
,
std
::
vector
<
cg
::
VarNode
*>
inputs
)
{
...
...
@@ -519,13 +551,13 @@ void init_graph_rt(py::module m) {
hv_with_event
.
second
->
record
();
p
->
set
(
std
::
move
(
hv_with_event
));
};
return
output_callback
(
std
::
move
(
f
),
std
::
move
(
inputs
),
true
,
true
);
return
output_callback
(
std
::
move
(
f
),
std
::
move
(
inputs
),
p
,
true
,
true
);
});
m
.
def
(
"attr_output_callback"
,
[
output_callback
](
std
::
shared_ptr
<
Rendezvous
<
TensorAttr
>>
p
,
std
::
vector
<
cg
::
VarNode
*>
inputs
)
{
auto
f
=
[
p
](
DeviceTensorND
dv
)
{
p
->
set
(
TensorAttr
{
TensorLayout
{
dv
.
shape
(),
dv
.
dtype
()},
dv
.
comp_node
()});
};
return
output_callback
(
std
::
move
(
f
),
std
::
move
(
inputs
),
true
);
return
output_callback
(
std
::
move
(
f
),
std
::
move
(
inputs
),
p
,
true
);
});
}
imperative/python/src/graph_rt.h
浏览文件 @
8109cc4b
...
...
@@ -35,18 +35,36 @@ public:
PYBIND11_DECLARE_HOLDER_TYPE
(
T
,
GraphNodePtr
<
T
>
,
true
);
class
RendezvousBase
{
public:
virtual
~
RendezvousBase
()
=
default
;
virtual
void
set_exception
(
std
::
exception_ptr
p
)
=
0
;
};
template
<
typename
R
>
class
Rendezvous
{
class
Rendezvous
:
public
RendezvousBase
{
std
::
mutex
m_lock
;
int
m_read_ahead
=
0
;
bool
m_drop_next
=
false
;
std
::
promise
<
R
>
m_promise
;
public:
Rendezvous
()
=
default
;
struct
Factory
{
template
<
typename
...
Args
>
static
auto
make_rendezvous
(
Args
&&
...
args
)
{
auto
ptr
=
new
Rendezvous
<
R
>
{
std
::
forward
(
args
)...};
return
std
::
shared_ptr
<
Rendezvous
<
R
>>
(
ptr
);
}
};
public:
Rendezvous
(
const
Rendezvous
&
rhs
)
=
delete
;
Rendezvous
(
Rendezvous
&&
rhs
)
=
delete
;
Rendezvous
&
operator
=
(
const
Rendezvous
&
rhs
)
=
delete
;
template
<
typename
...
Args
>
static
auto
make
(
Args
&&
...
args
)
{
return
Factory
::
make_rendezvous
(
std
::
forward
<
Args
>
(
args
)...);
}
R
get
()
{
std
::
future
<
R
>
f
;
{
...
...
@@ -96,6 +114,29 @@ public:
m_read_ahead
=
0
;
m_drop_next
=
false
;
}
void
set_exception
(
std
::
exception_ptr
e
)
{
if
(
e
)
{
MGB_LOCK_GUARD
(
m_lock
);
if
(
m_read_ahead
>=
0
)
{
mgb_assert
(
m_read_ahead
<=
1
);
if
(
m_drop_next
)
{
m_drop_next
=
false
;
}
else
{
m_promise
.
set_exception
(
e
);
}
if
(
m_read_ahead
==
1
)
{
m_promise
=
{};
}
--
m_read_ahead
;
}
else
{
mgb_assert
(
m_read_ahead
==
-
1
);
// TODO: maybe exception should be ignored
// if value was already set ?
m_promise
.
set_exception
(
e
);
}
}
}
};
void
init_graph_rt
(
pybind11
::
module
m
);
imperative/python/test/unit/core/test_megbrain_graph.py
浏览文件 @
8109cc4b
...
...
@@ -82,3 +82,20 @@ def test_op():
f
()
np
.
testing
.
assert_equal
(
x
.
numpy
(),
-
y
.
result
().
numpy
())
def
test_exception
():
err_msg
=
"QwQ"
def
throw_exc
():
raise
RuntimeError
(
err_msg
)
g
=
mgb_graph
.
Graph
()
x
,
_
=
mgb_graph
.
input_callback
(
throw_exc
,
device
=
"xpux"
,
dtype
=
"float32"
,
graph
=
g
)
y
=
mgb_graph
.
OutputNode
(
F
.
neg
(
x
))
f
=
g
.
compile
(
y
.
outputs
[
0
])
try
:
f
.
execute
()
y
.
get_value
()
except
Exception
as
exc
:
assert
err_msg
in
str
(
exc
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录