MegEngine 天元 / MegEngine · Commit 4f240ec2
Authored on Apr 06, 2021 by Megvii Engine Team
refactor(mge/jit): make trace return any kind of output
GitOrigin-RevId: fd1265c661e7f4d750f2c13599113b874c949ba5
Parent: c6b552cf
Showing 11 changed files with 91 additions and 60 deletions (+91 -60)
imperative/python/megengine/jit/tracing.py                  +29  -48
imperative/python/src/grad.cpp                                +4   -2
imperative/python/src/grad.h                                  +1   -0
imperative/python/src/grad_override.cpp                      +16   -0
imperative/python/src/graph_rt.cpp                            +1   -0
imperative/python/src/tensor.cpp                             +10   -7
imperative/python/src/tensor.h                                +1   -0
imperative/python/src/trace.cpp                               +1   -1
imperative/python/src/trace_info.h                            +0   -2
imperative/src/impl/ops/utility.cpp                          +14   -0
imperative/src/include/megbrain/imperative/ops/utility.h    +14   -0
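The point of the refactor, per the title, is that a traced function no longer needs special handling of its return value: the output can be a single tensor, a sequence, or a dict, and the bookkeeping below treats them uniformly. A minimal usage sketch (illustrative only; it assumes MegEngine's public megengine.jit.trace API and is not taken from this commit's tests):

import numpy as np
import megengine as mge
from megengine.jit import trace

@trace(symbolic=True)
def step(x):
    y = x * 2
    # returning a dict rather than a bare tensor is the kind of
    # "any kind of output" the commit title refers to
    return {"double": y, "square": x * x}

out = step(mge.tensor(np.arange(4, dtype="float32")))
print(out["double"].numpy(), out["square"].numpy())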
imperative/python/megengine/jit/tracing.py

@@ -170,9 +170,9 @@ class trace:
         self._graph = None
         self._need_reset_nodes = None
         self._lazy_eval_graph = None
-        self._lazy_eval_tensors = {}
+        self._lazy_eval_tensors = set()
         self._lazy_eval_links = None
-        self._active_tensors = {}
+        self._active_tensors = set()
         self._tensor_remaps = None
         self._inputs_to_restore = None
         self._arg_bindings = None
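The hunk above replaces the dict-based bookkeeping ({mixin_handle: TensorWeakRef}) with plain sets of weak references; live tensors are recovered by dereferencing on demand rather than by handle lookup. A minimal sketch of the pattern (hypothetical names, not MegEngine code):

import weakref

class Tracker:
    def __init__(self):
        self._active = set()        # was a dict keyed by handle before this commit

    def record(self, tensor):
        self._active.add(weakref.ref(tensor))

    def live(self):
        # dereference and skip entries whose target has been garbage-collected
        return [r() for r in self._active if r() is not None]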
@@ -258,7 +258,7 @@ class trace:
                 y._compiled_info = CompiledTensorProxy(h)
                 y._mixin_handle = h
                 outputs += [y]
-                self._active_tensors[h] = TensorWeakRef(y)
+                self._active_tensors.add(TensorWeakRef(y))
             self._output_handles.update(ohandles)
             return outputs
@@ -318,9 +318,9 @@ class trace:
             x._mixin_handle = h
             x._recording = True
             x._trace_mixin_info = info
-            self._active_tensors[h] = TensorWeakRef(x)
+            self._active_tensors.add(TensorWeakRef(x))
             if self._symbolic:
-                self._lazy_eval_tensors[h] = TensorWeakRef(x)
+                self._lazy_eval_tensors.add(TensorWeakRef(x))

         self._seq.append((op, tuple(ihandles), tuple(ohandles)))
@@ -345,7 +345,7 @@ class trace:
             x._recording = True
             x._trace_mixin_info = info
             if self._symbolic:
-                self._lazy_eval_tensors[h] = TensorWeakRef(x)
+                self._lazy_eval_tensors.add(TensorWeakRef(x))
         self._seq.append(("Const", tuple(), tuple(ohandles)))

     def _set_active(self, active: bool):
@@ -365,17 +365,14 @@ class trace:
             self._lazy_eval_links = ()

     def _take_escaped_tensors(self):
-        escaped_tensors = tuple(
-            filter(lambda x: x() is not None, self._active_tensors.values())
-        )
+        escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors))
         self._active_tensors.clear()
         return escaped_tensors

     def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
-        lazy_eval_tensors = list(
-            filter(lambda x: x() is not None, lazy_eval_tensors.values())
-        )
-        readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors]
+        lazy_eval_tensors = [x() for x in lazy_eval_tensors]
+        lazy_eval_tensors = [x for x in lazy_eval_tensors if x is not None]
+        readers = [G.OutputNode(x._varnode).outputs[0] for x in lazy_eval_tensors]
         self._apply_graph_options(lazy_eval_graph)
         lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
         lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers])
@@ -383,8 +380,8 @@ class trace:
         lazy_eval_graph()
         for r, x in zip(readers, lazy_eval_tensors):
             # get values from lazy_eval_graph and assign to lazy_eval tensor
-            x()._handle = RawTensor(r.op.get_value())._handle
-            x()._reset_varnode()
+            x._handle = RawTensor(r.op.get_value())._handle
+            x._reset_varnode()

     @contextlib.contextmanager
     def _setup(self):
@@ -454,13 +451,14 @@ class trace:
                     raise TraceMismatchError("premature end")
                 if not self._symbolic or not self._untraced:
                     # reset output tensors
-                    for x in self._active_tensors.values():
-                        if x() is not None:
-                            x()._dev_tensor()
-                            x()._reset_varnode()
-                            x()._mixin_handle = -1
-                            x()._recording = False
-                            x()._trace_mixin_info = None
+                    for x in self._active_tensors.copy():
+                        strong_x = x()
+                        if strong_x is not None:
+                            strong_x._dev_tensor()
+                            strong_x._reset_varnode()
+                            strong_x._mixin_handle = -1
+                            strong_x._recording = False
+                            strong_x._trace_mixin_info = None

         try:
             do_enter()
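Besides iterating over a snapshot of the set, the loop now dereferences each weak reference once into strong_x and works on that, so the tensor cannot be collected between the None check and the later calls, and x() is not re-evaluated on every statement. The shape of the pattern (hypothetical names):

import weakref

def reset_all(active_refs):
    # iterate over a copy, mirroring self._active_tensors.copy() above
    for ref in set(active_refs):
        strong = ref()              # dereference exactly once and pin the target
        if strong is not None:
            strong.close()          # hypothetical stand-in for _dev_tensor(), _reset_varnode(), ...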
@@ -482,15 +480,17 @@ class trace:
         if self._untraced:
             # conditionally reading a compiled tensor in excluded region
             # is permitted, so we have to assume every tensor might be read
-            for x in self._active_tensors.values():
-                if x():
-                    info = self._tinfo[x()._mixin_handle]
+            for x in self._active_tensors:
+                strong_x = x()
+                if strong_x:
+                    info = self._tinfo[strong_x._mixin_handle]
                     info.exported = True
                     info.data_read = True
         else:
-            for x in self._active_tensors.values():
-                if x():
-                    x()._dev_tensor()
+            for x in self._active_tensors:
+                strong_x = x()
+                if strong_x:
+                    strong_x._dev_tensor()

     def _apply_graph_options(self, graph):
@@ -520,7 +520,6 @@ class trace:
         graph = self._graph = G.Graph()
         graph.options.async_exec_level = 0b100
         self._apply_graph_options(graph)
-        # graph.options.graph_opt_level = 0
         need_reset_nodes = self._need_reset_nodes = []
         # links enforce ordering of I/O nodes
         in_out_links = ()
@@ -563,7 +562,7 @@ class trace:
             if not hasattr(info, "varnode"):
                 assert info.external
                 if info.bound_data:
-                    if hasattr(info, "is_const") and info.is_const:
+                    if getattr(info, "is_const", False):
                         info.varnode = graph.make_const(
                             info.bound_data.numpy(),
                             info.bound_data.dtype,
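The only change in this hunk folds the hasattr guard into a single getattr with a default; the two forms evaluate to the same value in every case. For example:

class Info:
    pass

info = Info()
old_way = hasattr(info, "is_const") and info.is_const
new_way = getattr(info, "is_const", False)
assert old_way is new_way is False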
@@ -635,30 +634,12 @@ class trace:
             opnode.reset()

     def __call__(self, *args, **kwargs):
-        if is_tracing():
-            return self.__wrapped__(*args, **kwargs)
         with self._setup():
             if self._capture_as_const:
                 self._process_inputs(*args, **kwargs)
             outputs = self.__wrapped__(*args, **kwargs)
             if self._capture_as_const:
                 self._process_outputs(outputs)
-            # outputs could be None
-            if outputs is not None:
-                list_outputs = outputs
-                if isinstance(outputs, collections.abc.Mapping):
-                    _, list_outputs = zip(*sorted(outputs.items()))
-                elif not isinstance(outputs, collections.abc.Sequence):
-                    list_outputs = (outputs,)
-                for o in list_outputs:
-                    # if outputs are copied, then use the newest info in trace data structure
-                    if o._copied:
-                        self._active_tensors[o._mixin_handle] = TensorWeakRef(o)
-                        if self._untraced and self._symbolic:
-                            self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o)
             return outputs

     def dump(
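The block deleted from __call__ normalized dict, sequence, and scalar outputs so that output tensors copied during gradient bookkeeping could be re-registered under their mixin handle. With the set-based bookkeeping above and the FastpathCopy op introduced in the C++ changes below, the copy goes through the regular apply path and is recorded like any other op, so this per-output fix-up is no longer needed. For reference, the normalization the deleted code performed was equivalent to:

import collections.abc

def normalize_outputs(outputs):
    # dict -> values in key order, scalar -> 1-tuple, sequence -> unchanged
    if outputs is None:
        return ()
    if isinstance(outputs, collections.abc.Mapping):
        return tuple(v for _, v in sorted(outputs.items()))
    if not isinstance(outputs, collections.abc.Sequence):
        return (outputs,)
    return tuple(outputs)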
imperative/python/src/grad.cpp

@@ -9,11 +9,12 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */

-#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
 #include "./grad.h"
 #include "megbrain/imperative/proxy_graph_detail.h"
 #include "megbrain/imperative/backward_graph_opt.h"
 #include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/imperative/ops/utility.h"
 #include "megbrain/utils/mempool.h"

 #include "range/v3/all.hpp"
@@ -434,7 +435,8 @@ apply_result_t apply_grad(ApplyContext& ctx) {
         if (backward.output_requires_grad(i)) {
             if (backward.output_captured(i)) {
                 // avoid reference cycle [Tensor <-> GradFn]
-                outputs[i] = outputs[i]->copy();
+                static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new FastpathCopy());
+                outputs[i] = python::apply(op, outputs[i])[0];
             }
             // populate grad info of output tensor
             auto& grad_info = outputs[i]->m_grad_info;
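Instead of duplicating the captured output with Tensor::copy(), the duplicate is now produced by applying a FastpathCopy op through the normal python::apply dispatch, so the trace and autograd machinery observe the copy as an ordinary op while the Tensor <-> GradFn reference cycle is still broken. Conceptually, in a Python sketch (hypothetical names; the real op is the C++ FastpathCopy added later in this commit):

class FakeTensor:
    """Toy stand-in: a tensor is just a wrapper around a storage handle."""
    def __init__(self, handle):
        self.handle = handle

def fastpath_copy(x):
    # forward: a new wrapper object over the same underlying handle, no data copied
    return FakeTensor(x.handle)

def fastpath_copy_grad(dy):
    # backward: identity, mirroring fastpathcopy_grad_rule in grad_override.cpp
    return dy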
imperative/python/src/grad.h

@@ -12,6 +12,7 @@
 #pragma once

 #include "./tensor.h"
+#include "megbrain/imperative/ops/utility.h"
 #include <megbrain/utils/small_vector.h>

 #include <memory>
imperative/python/src/grad_override.cpp

@@ -221,6 +221,21 @@ apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& ma
     return apply(ctx);
 }

+apply_result_t fastpathcopy_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+    mgb_assert(ctx.nargs == 1);
+    maker.output_size(1).output_captured(0, false);
+    maker.backward([](BackwardContext&, Tensor*const* grads, size_t ngrads) {
+        mgb_assert(ngrads == 1);
+        Tensor* grad = grads[0];
+        apply_result_t ret(1);
+        if (grad) {
+            ret[0] = grad->shared_from_this();
+        }
+        return ret;
+    });
+    return apply(ctx);
+}
+
 struct Init {
     Init() {
         auto& reg = grad_rule_registry();
@@ -231,6 +246,7 @@ struct Init {
         reg.emplace(Reduce::typeinfo(), reduce_grad_rule);
         reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule);
         reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule);
+        reg.emplace(FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
     }
 } _;
imperative/python/src/graph_rt.cpp

@@ -23,6 +23,7 @@
 #include "./common.h"
 #include "./ops.h"
 #include "megbrain/gopt/inference.h"
+#include "megbrain/imperative/ops/utility.h"

 namespace py = pybind11;
imperative/python/src/tensor.cpp

@@ -118,9 +118,18 @@ apply_result_t apply(ApplyContext& ctx) {
         handles[i] = ctx.args[i]->m_handle.get();
     }

+    apply_result_t outputs;
+
+    // fast copy without really applying
+    if (ctx.op->same_type<FastpathCopy>()) {
+        mgb_assert(ctx.nargs == 1);
+        outputs.reserve(ctx.nargs);
+        outputs.emplace_back(std::make_shared<Tensor>(ctx.args[0]->m_handle));
+        return outputs;
+    }
+
     auto output_handles = interpreter_for_py->apply_op(ctx.op, handles);

-    apply_result_t outputs;
     outputs.reserve(output_handles.size());
     for (auto h : output_handles) {
         outputs.emplace_back(std::make_shared<Tensor>(h));
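The new branch in apply() short-circuits FastpathCopy: rather than dispatching to interpreter_for_py->apply_op, it wraps the single input's handle in a fresh Tensor, so the "copy" costs one wrapper allocation and launches no kernel. The shape of the shortcut in a Python sketch (hypothetical names):

class Tensor:
    def __init__(self, handle):
        self.handle = handle

def apply_op(op, args, interpreter):
    if op == "FastpathCopy":             # fast copy without really applying
        assert len(args) == 1
        return [Tensor(args[0].handle)]  # same handle, interpreter never invoked
    handles = [t.handle for t in args]
    return [Tensor(h) for h in interpreter.apply_op(op, handles)]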
@@ -303,11 +312,6 @@ REGISTE_TENSORWRAPPER_FUNC(bool, recording)

 #undef REGISTE_TENSORWRAPPER_FUNC

-PyObject* TensorWrapper::copied() {
-    return py::cast(m_tensor->m_trace_info.copied).release().ptr();
-}
-
 #define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \
     PyObject* TensorWrapper::member() { \
         if (m_tensor->m_trace_info.member) { \
@@ -841,7 +845,6 @@ void init_tensor(py::module m) {
         .def<&TensorWrapper::reset_varnode>("_reset_varnode")
         .def<&TensorWrapper::_use_cnt>("_use_cnt")
         .def_getset<&TensorWrapper::varnode>("_varnode")
-        .def_getset<&TensorWrapper::copied>("_copied")
         .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle")
         .def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("_recording")
         .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
imperative/python/src/tensor.h

@@ -10,6 +10,7 @@
  */

 #pragma once
+#pragma GCC diagnostic ignored "-Wmissing-field-initializers"

 #include <variant>
imperative/python/src/trace.cpp

@@ -35,7 +35,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
         // assumption: python function always returns PyList
         auto tup = py::reinterpret_borrow<py::list>(ret);
-        for (auto i = 0; i < tup.size(); i++) {
+        for (size_t i = 0; i < tup.size(); i++) {
             auto pitem = tup[i].cast<cg::VarNode*>();
             outputs.emplace_back(std::make_shared<Tensor>(pitem));
         }
imperative/python/src/trace_info.h

@@ -17,7 +17,6 @@ namespace mgb::imperative::python {
 struct TraceInfo {
     int64_t mixin_handle = -1;
     bool recording = false;
-    bool copied = false;

     // refer to CompiledTensorProxy in tracing.py, works from second trace step
     PyObject* compiled_info = nullptr;

@@ -35,7 +34,6 @@ struct TraceInfo {
         compiled_info = that.compiled_info;
         Py_XINCREF(compiled_info);

-        copied = true;
         return *this;
     }
imperative/src/impl/ops/utility.cpp

@@ -18,4 +18,18 @@ namespace mgb::imperative {

 MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp);

+namespace { namespace fastpathcopy {
+auto apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    return inputs;
+}
+
+OP_TRAIT_REG(FastpathCopy, FastpathCopy)
+    .apply_on_var_node(apply_on_var_node)
+    .fallback();
+}} // fastpathcopy
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(FastpathCopy);
+
 } // namespace mgb::imperative
imperative/src/include/megbrain/imperative/ops/utility.h

@@ -35,4 +35,18 @@ struct GenericPyOp final : OpDefImplBase<GenericPyOp> {
     MGB_DYN_TYPE_OBJ_FINAL_DECL;
 };

+struct FastpathCopy final : OpDefImplBase<FastpathCopy> {
+    FastpathCopy() = default;
+
+    size_t hash() const override {
+        return mgb::hash(this->dyn_typeinfo());
+    }
+
+    bool is_same_st(const Hashable& rhs) const override {
+        return this->dyn_typeinfo() == rhs.dyn_typeinfo();
+    }
+
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+};
+
 } // namespace mgb::imperative