MegEngine 天元 / MegEngine
Commit 55e9c831
Authored on Aug 02, 2022 by Megvii Engine Team

feat(trace): add imperative mode for debug
GitOrigin-RevId: 067b7d235e107d459b4d09f4f04627676b9073cc
Parent: 281ecd0b

Showing 6 changed files with 127 additions and 44 deletions:
  imperative/python/megengine/jit/tracing.py                            +4   -0
  imperative/python/src/tensor.cpp                                      +4   -1
  imperative/python/src/tensor_utils.cpp                                +3   -0
  imperative/python/test/unit/jit/test_tracing.py                       +4   -3
  imperative/src/impl/transformations/trace.cpp                         +79  -33
  imperative/src/include/megbrain/imperative/transformations/trace.h    +33  -7
imperative/python/megengine/jit/tracing.py

@@ -112,6 +112,7 @@ class trace:
         without_host: if True, will run python code of wrapped function on the first call,
             and run the compiled graph/function on subsequent calls. if False, will run python code every time.
             Default: False
+        imperative: if True, will use imperative runtime to execute captured op seq. Default: False
     """

     third_party_backend = False

@@ -124,6 +125,7 @@ class trace:
     def __init__(
         self,
         function,
         *,
         symbolic=False,
         capture_as_const=False,
         record_only=False,

@@ -134,6 +136,7 @@ class trace:
         graph_opt_config: GraphOptimizationConfig = None,
         symbolic_shape: bool = True,
         without_host: bool = False,
+        imperative: bool = False,
     ):
         self.__wrapped__ = function
         self._capture_as_const = capture_as_const or record_only

@@ -204,6 +207,7 @@ class trace:
         self._trace.symbolic = symbolic or record_only
         self._trace.capture_as_const = capture_as_const or record_only
         self._trace.no_exec = record_only
+        self._trace.imperative = imperative
         self._trace.options_visitor = apply_options
         self._trace.profile = profiling
         self._trace.array_comparator = array_comparator
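For context, the sketch below shows roughly how the new flag is exercised from the Python side. It is a minimal example inferred from the docstring and test changes in this commit, not part of the diff itself, and it assumes a standard megengine installation; everything except the `imperative` keyword is existing megengine API.

# Minimal usage sketch of the new `imperative` flag (not part of this commit's diff);
# mirrors the decorator pattern used in test_tracing.py further below.
from megengine import tensor
from megengine.jit import trace

a = tensor([1, 2, 3, 4])
b = tensor([5, 6, 7, 8])

@trace(symbolic=True, imperative=True)  # replay the captured op sequence via the imperative runtime
def fn(x, y):
    return x + y

fn(a, b)  # first call traces and records the op sequence
fn(a, b)  # later calls replay it; with imperative=True each recorded op runs eagerly

As the CompiledTransformation changes below suggest, imperative mode skips the compiled-graph worker thread and replays each recorded op through imperative::apply, so a mismatch (TraceError) surfaces at the offending op, which appears to be the debugging benefit the commit title refers to.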
imperative/python/src/tensor.cpp

@@ -1201,6 +1201,8 @@ void init_tensor(py::module m) {
         bool without_host = false;
         bool check_external = true;
         bool remove_unused_data_required = true;
+        bool imperative = false;
         py::function options_visitor;
         std::shared_ptr<TracingTransformation> tracing;
         std::shared_ptr<CompiledTransformation> compiled;

@@ -1257,7 +1259,7 @@ void init_tensor(py::module m) {
         } else if (!self.compiled) {
             // traced but not compiled
             using namespace std::placeholders;
             self.compiled = std::make_shared<CompiledTransformation>(
-                    *self.trace_result, self.record_input_shapes);
+                    *self.trace_result, self.record_input_shapes, self.imperative);
             self.compiled->set_value_comparator(
                     std::bind(&Trace::compare_value, this, _1, _2));
             self.options_visitor(py::cast(&self.compiled->options()));

@@ -1405,6 +1407,7 @@ void init_tensor(py::module m) {
             .def_readwrite("symbolic", &Trace::symbolic)
             .def_readwrite("capture_as_const", &Trace::capture_as_const)
             .def_readwrite("no_exec", &Trace::no_exec)
+            .def_readwrite("imperative", &Trace::imperative)
             .def_readwrite("options_visitor", &Trace::options_visitor)
             .def("enter", &Trace::enter)
             .def("exit", &Trace::exit)
imperative/python/src/tensor_utils.cpp

@@ -555,6 +555,9 @@ py::object _astensor1d_cpp(
     c_args[flat_list.size()] = Py_None;
     py::tuple inp_tup = py::reinterpret_steal<py::tuple>(
             convert_inputs_cpp(NULL, c_args.data(), c_args.size()));
+    if (!inp_tup) {
+        throw py::error_already_set();
+    }
     if (device_obj.is_none()) {
         std::vector<PyObject*> inp(inp_tup.size());
         for (size_t i = 0; i < inp_tup.size(); ++i) {
imperative/python/test/unit/jit/test_tracing.py

@@ -510,14 +510,15 @@ def test_trace_warp_perspective():
         ("((a + b), (b + c))[1] + a", "((a + b), (b + c))[0] + a", "input id mismatch"),
     ],
 )
-def test_trace_mismatch(normal_expr, mismatch_expr, reason):
+@pytest.mark.parametrize("imperative", [True, False])
+def test_trace_mismatch(normal_expr, mismatch_expr, reason, imperative):
     a = tensor([1, 2, 3, 4])
     b = tensor([5, 6, 7, 8])
     c = tensor([9, 0, 1, 2])

     mismatch = False

-    @trace(symbolic=True)
+    @trace(symbolic=True, imperative=imperative)
     def fn(a, b, c):
         if not mismatch:
             result = eval(normal_expr)

@@ -525,7 +526,7 @@ def test_trace_mismatch(normal_expr, mismatch_expr, reason):
             result = eval(mismatch_expr)
         return result

-    for i in range(20):
+    for _ in range(20):
         try:
             d = fn(a, b, c)
         except TraceError as e:
imperative/src/impl/transformations/trace.cpp

@@ -2,6 +2,7 @@
 #include <chrono>
 #include <exception>
 #include <optional>
 #include "megbrain/gopt/inference.h"
 #include "megbrain/graph/helper.h"

@@ -499,11 +500,6 @@ void CompiledTransformation::compile() {
         return accessor;
     };
     std::vector<VarAccessor> var_accessors(m_vars.size());
-    auto exc_setter = std::bind(
-            &CompiledTransformation::set_exception, this, std::placeholders::_1);
-    for (auto&& accessor : var_accessors) {
-        accessor.exc_setter = exc_setter;
-    }
     for (auto&& item : m_seq) {
         bool require_link = bool(item.op) && mm_io_ops.count(item.op->dyn_typeinfo());
         VarNodeArray input_vars;

@@ -579,6 +575,12 @@
             dep_iter.add(output_spec.first);
         }
     }
+    for (auto& accessor : var_accessors) {
+        accessor.exc_setter = [this](std::exception_ptr exc) { set_exception(exc); };
+        if (m_imperative) {
+            accessor.node = nullptr;
+        }
+    }
     m_executable = m_graph->compile(output_specs);
     mgb_assert(m_executable != nullptr, "The compiled executable is nullptr.");

@@ -601,7 +603,7 @@ void CompiledTransformation::assert_tensor_equal(ValueRef lhs, ValueRef rhs) {
     trace_assert(m_value_comparator(lhs, rhs), "tensors not equals");
 }

-void CompiledTransformation::trace_input(size_t id, ValueRef value) {
+ValueRef CompiledTransformation::trace_input(size_t id, ValueRef value) {
     try {
         auto& var = m_vars[id];
         auto& var_accessor = m_var_accessors[id];

@@ -626,32 +628,43 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
                     var_accessor.data_setter(value.dev_tensor()->as_nd());
                     m_setted_extern.insert(id);
                 }
-                break;
+                return value;
             }
             case VarKind::Constant: {
                 // expect host value here
                 mgb_assert(var.bound_data, "const var without data bound");
                 assert_tensor_equal(var.bound_data, value);
-                break;
+                // TODO: use value
+                return var.bound_data;
             }
             case VarKind::Internal: {
                 trace_assert(value.is(m_value_type), "expect internal node, got external");
                 auto& traced_value = value.cast(m_value_type);
                 trace_assert(traced_value.id() == id, "input id mismatch");
-                break;
+                return traced_value.get_imperative_value();
             }
             default:
                 trace_assert(false, "unknown var kind");
         }
     } catch (TraceError&) {
         throw;
     } catch (const std::exception& exc) {
         mgb_log_error("unexpected error %s", exc.what());
         throw;
     } catch (...) {
         mgb_assert(false, "unexpected error");
     }
 }

-auto CompiledTransformation::trace_output(size_t id) -> TracedValue::ref_t {
+auto CompiledTransformation::trace_output(size_t id, ValueRef value) -> TracedValue::ref_t {
     auto traced_value = m_value_type.make(id, &m_vars[id], &m_var_accessors[id]);
     m_weak_values.push_back(traced_value);
+    if (m_imperative) {
+        mgb_assert(value, "imperative mode requires value");
+        traced_value->set_imperative_value(value);
+    }
     return traced_value;
 }

@@ -663,6 +676,9 @@ TraceResult::SeqItem& CompiledTransformation::next_instruction() {
 ShapeValue::ref_t CompiledTransformation::TracedValue::shape() const {
     if (!m_shape) {
         trace_assert(m_accessor->shape_getter, "shape unreadable");
+        if (m_accessor->is_imperative()) {
+            return m_imperative_value.shape();
+        }
         m_shape = ShapeValue::make(ValueShape::from(m_accessor->shape_getter()));
     }
     return m_shape;

@@ -675,6 +691,23 @@ DTypeValue::ref_t CompiledTransformation::TracedValue::dtype() const {
 CompNodeValue::ref_t CompiledTransformation::TracedValue::comp_node() const {
     return m_var->device;
 }

+DeviceValue::ref_t CompiledTransformation::TracedValue::data() const {
+    trace_assert(m_accessor->data_getter, "data unreadable");
+    if (m_accessor->is_imperative()) {
+        return m_imperative_value.dev_tensor();
+    }
+    return DeviceValue::make(m_accessor->data_getter());
+}
+
+HostValue::ref_t CompiledTransformation::TracedValue::value() const {
+    trace_assert(m_accessor->value_getter, "value unreadable");
+    if (m_accessor->is_imperative()) {
+        return m_imperative_value.numpy();
+    }
+    return HostValue::make(m_accessor->value_getter());
+}
+
 auto CompiledTransformation::TracedValue::accessor() const -> const VarAccessor& {
     return *m_accessor;
 }

@@ -684,12 +717,24 @@ ValueRefList CompiledTransformation::apply_op(
     auto& item = next_instruction();
     trace_assert(inputs.size() == item.inputs.size(), "input size mismatch");
     trace_assert(apply_op.op().is_same(*item.op), "operator mismatch");
-    for (size_t i = 0; i < inputs.size(); ++i) {
-        trace_input(item.inputs[i], inputs[i]);
-    }
     ValueRefList outputs(item.outputs.size());
-    for (size_t i = 0; i < item.outputs.size(); ++i) {
-        outputs[i] = trace_output(item.outputs[i]);
-    }
+    if (!m_imperative) {
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            trace_input(item.inputs[i], inputs[i]);
+        }
+        for (size_t i = 0; i < item.outputs.size(); ++i) {
+            outputs[i] = trace_output(item.outputs[i], {});
+        }
+    } else {
+        SmallVector<ValueRef> input_values;
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            input_values.push_back(trace_input(item.inputs[i], inputs[i]));
+        }
+        auto&& output_values = imperative::apply(apply_op, input_values);
+        mgb_assert(output_values.size() == outputs.size());
+        for (size_t i = 0; i < item.outputs.size(); ++i) {
+            outputs[i] = trace_output(item.outputs[i], output_values[i]);
+        }
+    }
     return outputs;
 }

@@ -698,24 +743,22 @@ ValueRefList CompiledTransformation::apply_get_attr(
         const GetAttr& get_attr, Span<ValueRef> inputs) {
     if (auto* traced_value = inputs[0].as(m_value_type)) {
         ValueRef output;
         auto& var_accessor = traced_value->accessor();
         switch (get_attr.attr()) {
             case GetAttr::Shape:
                 output = traced_value->shape();
                 break;
             case GetAttr::Data:
-                trace_assert(var_accessor.data_getter, "data unreadable");
-                output = DeviceValue::make(var_accessor.data_getter());
+                output = traced_value->data();
                 break;
             case GetAttr::Value:
-                trace_assert(var_accessor.value_getter, "value unreadable");
-                output = HostValue::make(var_accessor.value_getter());
+                output = traced_value->value();
                 break;
             case GetAttr::DType:
                 output = traced_value->dtype();
                 break;
             case GetAttr::Device:
                 output = traced_value->comp_node();
                 break;
             default:
                 break;
         }

@@ -745,8 +788,7 @@ ValueRefList CompiledTransformation::apply_create_tensor(
     if (!tensor) {
         tensor = imperative::apply(create_tensor, inputs)[0];
     }
-    trace_input(input_id, tensor);
-    return {trace_output(output_id)};
+    return {trace_output(output_id, trace_input(input_id, tensor))};
 }

 ValueRefList CompiledTransformation::apply_transformation(

@@ -762,21 +804,21 @@ ValueRefList CompiledTransformation::apply_transformation(
         trace_assert(item.op == nullptr, "operator mismatch");
         trace_assert(item.inputs.size() == 1, "inputs size mismatch");
         trace_assert(item.outputs.size() == 1, "inputs output mismatch");
-        trace_input(item.inputs[0], inputs[0]);
+        auto value = trace_input(item.inputs[0], inputs[0]);
         trace_assert(
                 trace_mark_var->mark() == m_vars[item.outputs[0]].mark, "mark mismatch");
-        return {trace_output(item.outputs[0])};
+        return {trace_output(item.outputs[0], value)};
     } else if (auto* trace_name_var = op.as<RenameValue>()) {
         auto& item = next_instruction();
         trace_assert(item.op == nullptr, "operator mismatch");
         trace_assert(item.inputs.size() == 1, "inputs size mismatch");
         trace_assert(item.outputs.size() == 1, "outputs size mismatch");
-        trace_input(item.inputs[0], inputs[0]);
+        auto value = trace_input(item.inputs[0], inputs[0]);
         trace_assert(
                 trace_name_var->name() == m_vars[item.outputs[0]].name, "name mismatch");
-        return {trace_output(item.outputs[0])};
+        return {trace_output(item.outputs[0], value)};
     } else {
         return op.fallback(inputs);
     }

@@ -786,11 +828,9 @@ void CompiledTransformation::on_unregister() noexcept {
     // resolve pending values
     for (auto&& weak_value : m_weak_values) {
         if (auto traced_value = weak_value.lock()) {
             auto& var_accessor = m_var_accessors[traced_value->id()];
             auto value = ([&]() -> ValueRef {
                 try {
-                    trace_assert(var_accessor.data_getter, "data unreadable");
-                    auto dev_value = DeviceValue::make(var_accessor.data_getter());
+                    auto dev_value = traced_value->data();
                     return imperative::apply(
                             CreateTensor(
                                     CreateTensor::Common, dev_value->device(),

@@ -821,10 +861,9 @@ void CompiledTransformation::wait() {
         trace_assert(m_pc == m_seq.size(), "mismature end");
     } catch (...) {
     }
-    mgb_assert(m_executable != nullptr);
-    std::unique_lock lock{m_mutex};
-    m_cv.wait(lock, [&] { return m_graph_status == 0; });
-    lock.unlock();
+    if (!m_imperative) {
+        wait_worker();
+    }
     for (auto&& box : m_boxes) {
         box->reset();

@@ -839,6 +878,13 @@
     }
 }

+void CompiledTransformation::wait_worker() {
+    mgb_assert(m_executable != nullptr);
+    std::unique_lock lock{m_mutex};
+    m_cv.wait(lock, [&] { return m_graph_status == 0; });
+    lock.unlock();
+}
+
 std::exception_ptr CompiledTransformation::set_exception(std::exception_ptr exc) noexcept {
     MGB_LOCK_GUARD(m_mutex);
imperative/src/include/megbrain/imperative/transformations/trace.h

@@ -265,12 +265,14 @@ public:
     using OpKind = TraceResult::SeqItem::OpKind;

     struct VarAccessor {
-        VarNode* node;
+        VarNode* node;  // use imperative mode when node == nullptr
         std::function<TensorShape()> shape_getter;
         std::function<DeviceTensorND()> data_getter;
         std::function<HostTensorND()> value_getter;
         std::function<void(DeviceTensorND)> data_setter;
         std::function<void(std::exception_ptr)> exc_setter;
+
+        bool is_imperative() const { return node == nullptr; }
     };

     class TracedValue final : public ObjectValue<TracedValue> {

@@ -281,6 +283,7 @@ public:
         mutable ShapeValue::ref_t m_shape;
         mutable DTypeValue::ref_t m_dtype;
         mutable CompNodeValue::ref_t m_comp_node;
+        mutable ValueRef m_imperative_value;

     public:
         TracedValue(size_t id, VarInfo* var, VarAccessor* accessor)

@@ -289,9 +292,12 @@ public:
         ShapeValue::ref_t shape() const;
         DTypeValue::ref_t dtype() const;
         CompNodeValue::ref_t comp_node() const;
+        DeviceValue::ref_t data() const;
+        HostValue::ref_t value() const;
         const VarAccessor& accessor() const;

         void set_exception(std::exception_ptr exc) const {
             mgb_assert(m_accessor->exc_setter, "exc setter invalid");
             m_accessor->exc_setter(exc);
         }

@@ -299,7 +305,11 @@ public:
             return ssprintf("TracedValue{\"id\"=%zu}", id());
         }

-        void clear() override {}
+        void clear() override { m_imperative_value = {}; }
+
+        void set_imperative_value(ValueRef value) const { m_imperative_value = value; }
+
+        ValueRef get_imperative_value() const { return m_imperative_value; }
     };

 private:

@@ -322,15 +332,23 @@ private:
     ComputingGraph::OutputSpec m_output_spec;
     ObjectType<TracedValue> m_value_type{"TracedValue"};
     std::set<size_t> m_setted_extern;
+    bool m_imperative = false;

 public:
-    CompiledTransformation(TraceResult result, bool input_shape_static)
+    CompiledTransformation(TraceResult result, bool input_shape_static, bool imperative)
             : m_seq(result.seq),
               m_vars(result.vars),
-              m_input_shape_static(input_shape_static) {
+              m_input_shape_static(input_shape_static),
+              m_imperative(imperative) {
         m_graph = ComputingGraph::make();
         options().no_force_inplace = true;
         options().async_exec_level = 0b100;
+        if (!m_imperative) {
+            start_worker();
+        }
+    }
+
+    void start_worker() {
         m_graph_executor = std::thread([&] {
             while (true) {
                 std::unique_lock lock{m_mutex};

@@ -384,7 +402,7 @@ public:
      * \param id
      * \param value
      */
-    void trace_input(size_t id, ValueRef value);
+    ValueRef trace_input(size_t id, ValueRef value);

    /**
     * \brief make a placeholder for output.

@@ -393,7 +411,7 @@ public:
     * \return TracedValue::ref_t output placeholder, would be reset to real value when
     * trace exits
     */
-    TracedValue::ref_t trace_output(size_t id);
+    TracedValue::ref_t trace_output(size_t id, ValueRef value);

     TraceResult::SeqItem& next_instruction();

@@ -422,6 +440,8 @@ public:
     void wait();

+    void wait_worker();
+
     std::exception_ptr set_exception(std::exception_ptr exc) noexcept;

     template <typename T>

@@ -431,7 +451,7 @@ public:
         return box;
     }

-    ~CompiledTransformation() {
+    void stop_worker() {
         {
             MGB_LOCK_GUARD(m_mutex);
             m_graph_status = 2;

@@ -439,6 +459,12 @@ public:
         m_cv.notify_all();
         m_graph_executor.join();
     }

+    ~CompiledTransformation() {
+        if (!m_imperative) {
+            stop_worker();
+        }
+    }
 };

}  // namespace mgb::imperative