Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
d94a17d3
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
d94a17d3
编写于
5月 13, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mge/jit): using static global_enable for apply ctx insted of global variable
GitOrigin-RevId: dd82b53faf55aa0d01ab181b54f48d63e143384c
上级
9eacb9df
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
12 addition
and
34 deletion
+12
-34
imperative/python/megengine/autodiff/grad_manager.py
imperative/python/megengine/autodiff/grad_manager.py
+0
-2
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+0
-4
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+6
-19
imperative/python/src/tensor.h
imperative/python/src/tensor.h
+2
-6
imperative/python/src/trace.cpp
imperative/python/src/trace.cpp
+4
-3
未找到文件。
imperative/python/megengine/autodiff/grad_manager.py
浏览文件 @
d94a17d3
import
weakref
import
weakref
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
typing
import
Callable
,
Iterable
from
typing
import
Callable
,
Iterable
from
..core._imperative_rt.core2
import
pop_scope
,
push_scope
,
set_option
from
..core._imperative_rt.core2
import
pop_scope
,
push_scope
,
set_option
...
...
imperative/python/megengine/jit/tracing.py
浏览文件 @
d94a17d3
...
@@ -1125,10 +1125,6 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor):
...
@@ -1125,10 +1125,6 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor):
def
apply_const_compiled_mode
(
value
,
dtype
,
device
,
is_const
,
no_cache
,
name
):
def
apply_const_compiled_mode
(
value
,
dtype
,
device
,
is_const
,
no_cache
,
name
):
if
skip_tracing
:
if
skip_tracing
:
args
=
[
RawTensor
(
x
.
_dev_tensor
())
if
x
.
__class__
is
CompiledTensorProxy
else
x
for
x
in
args
]
unset_tracing
()
unset_tracing
()
ret
=
RawTensor
(
value
,
dtype
,
device
,
False
,
name
)
ret
=
RawTensor
(
value
,
dtype
,
device
,
False
,
name
)
set_tracing
()
set_tracing
()
...
...
imperative/python/src/tensor.cpp
浏览文件 @
d94a17d3
...
@@ -50,29 +50,20 @@ REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)
...
@@ -50,29 +50,20 @@ REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)
#undef REGISTE_APPLY_FUNC
#undef REGISTE_APPLY_FUNC
bool
is_tracing
=
false
;
Tensor
::
flags_t
ApplyContext
::
global_disable
=
0
;
Tensor
::
flags_t
ApplyContext
::
global_enable
=
0
;
#define SET_UNSET_PROP(mode) \
void set_##mode() { \
is_##mode = true; \
} \
void unset_##mode() { \
is_##mode = false; \
} \
SET_UNSET_PROP
(
tracing
)
#undef SET_UNSET_PROP
void
set_tracing
()
{
ApplyContext
::
global_enable
|=
Tensor
::
Flags
::
TRACE
;
}
void
unset_tracing
()
{
ApplyContext
::
global_enable
&=
~
Tensor
::
Flags
::
TRACE
;
}
bool
skip_tracing
=
false
;
bool
skip_tracing
=
false
;
Tensor
::
flags_t
ApplyContext
::
global_disable
=
0
;
apply_result_t
apply
(
ApplyContext
&
ctx
)
{
apply_result_t
apply
(
ApplyContext
&
ctx
)
{
// emulating scalar should be put to specific op's apply, e.g.,
// emulating scalar should be put to specific op's apply, e.g.,
// elementwise, reduce, typecvt. Currently it's still handled at python
// elementwise, reduce, typecvt. Currently it's still handled at python
// side. It could be move to C++ side if it has an impact on performance
// side. It could be move to C++ side if it has an impact on performance
auto
flags
=
ctx
.
flags
&
~
ApplyContext
::
global_disable
;
auto
flags
=
ctx
.
flags
&
~
ApplyContext
::
global_disable
;
flags
=
flags
|
ApplyContext
::
global_enable
;
if
(
flags
&
Tensor
::
Flags
::
SCALAR
)
{
if
(
flags
&
Tensor
::
Flags
::
SCALAR
)
{
// TODO: emulate scalar
// TODO: emulate scalar
...
@@ -190,10 +181,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
...
@@ -190,10 +181,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
}
}
}
}
if
(
is_tracing
)
{
ctx
.
flags
|=
Tensor
::
Flags
::
TRACE
;
}
auto
outputs
=
apply
(
ctx
);
auto
outputs
=
apply
(
ctx
);
size_t
nout
=
outputs
.
size
();
size_t
nout
=
outputs
.
size
();
auto
ret
=
py
::
tuple
(
nout
);
auto
ret
=
py
::
tuple
(
nout
);
...
@@ -255,7 +242,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
...
@@ -255,7 +242,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
if
(
tup
[
nargs
-
1
].
ptr
()
!=
Py_None
)
name
=
tup
[
nargs
-
1
].
cast
<
std
::
string
>
();
if
(
tup
[
nargs
-
1
].
ptr
()
!=
Py_None
)
name
=
tup
[
nargs
-
1
].
cast
<
std
::
string
>
();
// const op
// const op
if
(
is_const
&&
is_tracing
)
{
if
(
is_const
&&
(
ApplyContext
::
global_enable
==
Tensor
::
Flags
::
TRACE
)
)
{
auto
py_ret
=
PyObject_Call
(
cpp_apply_const_with_tracing
,
tup
.
ptr
(),
nullptr
);
auto
py_ret
=
PyObject_Call
(
cpp_apply_const_with_tracing
,
tup
.
ptr
(),
nullptr
);
if
(
!
py_ret
)
throw
py
::
error_already_set
();
if
(
!
py_ret
)
throw
py
::
error_already_set
();
auto
py_list
=
py
::
reinterpret_steal
<
py
::
list
>
(
py_ret
);
auto
py_list
=
py
::
reinterpret_steal
<
py
::
list
>
(
py_ret
);
...
...
imperative/python/src/tensor.h
浏览文件 @
d94a17d3
...
@@ -193,8 +193,9 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
...
@@ -193,8 +193,9 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
struct
ApplyContext
{
struct
ApplyContext
{
static
Tensor
::
flags_t
global_disable
;
static
Tensor
::
flags_t
global_disable
;
static
Tensor
::
flags_t
global_enable
;
Tensor
::
flags_t
flags
;
Tensor
::
flags_t
flags
=
0
;
std
::
shared_ptr
<
OpDef
>
op
;
std
::
shared_ptr
<
OpDef
>
op
;
Tensor
*
const
*
args
;
Tensor
*
const
*
args
;
size_t
nargs
;
size_t
nargs
;
...
@@ -236,14 +237,11 @@ decltype(auto) resolve_arrow(T&& p) {
...
@@ -236,14 +237,11 @@ decltype(auto) resolve_arrow(T&& p) {
template
<
typename
...
Args
>
template
<
typename
...
Args
>
constexpr
bool
is_all_tensor_ptr
=
(...
&&
std
::
is_same_v
<
decltype
(
resolve_arrow
(
std
::
declval
<
Args
>
())),
Tensor
*>
);
constexpr
bool
is_all_tensor_ptr
=
(...
&&
std
::
is_same_v
<
decltype
(
resolve_arrow
(
std
::
declval
<
Args
>
())),
Tensor
*>
);
extern
bool
is_tracing
;
// FIXME: should use ApplyContext::global_enable
template
<
typename
...
Args
,
std
::
enable_if_t
<
is_all_tensor_ptr
<
Args
...>,
int
>
=
0
>
template
<
typename
...
Args
,
std
::
enable_if_t
<
is_all_tensor_ptr
<
Args
...>,
int
>
=
0
>
apply_result_t
apply
(
std
::
shared_ptr
<
OpDef
>
op
,
Args
&&
...
args
)
{
apply_result_t
apply
(
std
::
shared_ptr
<
OpDef
>
op
,
Args
&&
...
args
)
{
ApplyContext
ctx
;
ApplyContext
ctx
;
Tensor
*
arg_arr
[]
=
{
resolve_arrow
(
args
)...};
Tensor
*
arg_arr
[]
=
{
resolve_arrow
(
args
)...};
ctx
.
flags
=
(
0
|
...
|
args
->
m_flags
);
ctx
.
flags
=
(
0
|
...
|
args
->
m_flags
);
ctx
.
flags
|=
is_tracing
?
Tensor
::
Flags
::
TRACE
:
0
;
ctx
.
args
=
arg_arr
;
ctx
.
args
=
arg_arr
;
ctx
.
nargs
=
sizeof
...(
args
);
ctx
.
nargs
=
sizeof
...(
args
);
ctx
.
op
=
std
::
move
(
op
);
ctx
.
op
=
std
::
move
(
op
);
...
@@ -256,7 +254,6 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
...
@@ -256,7 +254,6 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
apply_result_t
>
{
apply_result_t
>
{
ApplyContext
ctx
;
ApplyContext
ctx
;
ctx
.
op
=
std
::
move
(
op
);
ctx
.
op
=
std
::
move
(
op
);
ctx
.
flags
=
is_tracing
?
Tensor
::
Flags
::
TRACE
:
0
;
ctx
.
nargs
=
tensors
.
size
();
ctx
.
nargs
=
tensors
.
size
();
Tensor
*
args
[
ctx
.
nargs
];
Tensor
*
args
[
ctx
.
nargs
];
ctx
.
args
=
args
;
ctx
.
args
=
args
;
...
@@ -270,7 +267,6 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
...
@@ -270,7 +267,6 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
inline
auto
apply
(
std
::
shared_ptr
<
OpDef
>
op
,
Tensor
*
const
*
args
,
size_t
nargs
)
{
inline
auto
apply
(
std
::
shared_ptr
<
OpDef
>
op
,
Tensor
*
const
*
args
,
size_t
nargs
)
{
ApplyContext
ctx
;
ApplyContext
ctx
;
ctx
.
op
=
std
::
move
(
op
);
ctx
.
op
=
std
::
move
(
op
);
ctx
.
flags
=
is_tracing
?
Tensor
::
Flags
::
TRACE
:
0
;
ctx
.
nargs
=
nargs
;
ctx
.
nargs
=
nargs
;
ctx
.
args
=
args
;
ctx
.
args
=
args
;
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
...
...
imperative/python/src/trace.cpp
浏览文件 @
d94a17d3
...
@@ -28,12 +28,12 @@ apply_result_t apply_trace(ApplyContext& ctx) {
...
@@ -28,12 +28,12 @@ apply_result_t apply_trace(ApplyContext& ctx) {
for
(
size_t
i
=
0
;
i
<
ctx
.
nargs
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ctx
.
nargs
;
i
++
)
{
args
[
i
+
1
]
=
py
::
cast
(
ctx
.
args
[
i
]
->
m_var
);
args
[
i
+
1
]
=
py
::
cast
(
ctx
.
args
[
i
]
->
m_var
);
}
}
py
::
object
re
t
=
py
::
reinterpret_steal
<
py
::
object
>
(
py
::
object
pyou
t
=
py
::
reinterpret_steal
<
py
::
object
>
(
PyObject_Call
(
cpp_apply_backward_varnode
,
args
.
ptr
(),
nullptr
));
PyObject_Call
(
cpp_apply_backward_varnode
,
args
.
ptr
(),
nullptr
));
if
(
!
re
t
)
throw
py
::
error_already_set
();
if
(
!
pyou
t
)
throw
py
::
error_already_set
();
// assumption: python function always returns PyList
// assumption: python function always returns PyList
auto
tup
=
py
::
reinterpret_borrow
<
py
::
list
>
(
re
t
);
auto
tup
=
py
::
reinterpret_borrow
<
py
::
list
>
(
pyou
t
);
for
(
size_t
i
=
0
;
i
<
tup
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
tup
.
size
();
i
++
)
{
auto
pitem
=
tup
[
i
].
cast
<
cg
::
VarNode
*>
();
auto
pitem
=
tup
[
i
].
cast
<
cg
::
VarNode
*>
();
outputs
.
emplace_back
(
std
::
make_shared
<
Tensor
>
(
pitem
));
outputs
.
emplace_back
(
std
::
make_shared
<
Tensor
>
(
pitem
));
...
@@ -48,6 +48,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
...
@@ -48,6 +48,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
}
}
auto
pyout
=
PyObject_Call
(
cpp_apply_with_tracing
,
args
.
ptr
(),
nullptr
);
auto
pyout
=
PyObject_Call
(
cpp_apply_with_tracing
,
args
.
ptr
(),
nullptr
);
if
(
!
pyout
)
throw
py
::
error_already_set
();
if
(
!
pyout
)
throw
py
::
error_already_set
();
// assumption: python function always returns PyList
// assumption: python function always returns PyList
auto
tup
=
py
::
reinterpret_steal
<
py
::
list
>
(
pyout
);
auto
tup
=
py
::
reinterpret_steal
<
py
::
list
>
(
pyout
);
for
(
size_t
i
=
0
;
i
<
tup
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
tup
.
size
();
i
++
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录