Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c0b9e071
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
c0b9e071
编写于
2月 14, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(traced_module): fix compatible test and fix functional compatiblity
GitOrigin-RevId: 5824d232b36199d1b65d779fd442bb039c4ede6a
上级
2c48dc22
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
110 addition
and
2 deletion
+110
-2
imperative/python/megengine/traced_module/compat.py
imperative/python/megengine/traced_module/compat.py
+101
-2
imperative/python/megengine/traced_module/pytree.py
imperative/python/megengine/traced_module/pytree.py
+6
-0
imperative/python/megengine/traced_module/utils.py
imperative/python/megengine/traced_module/utils.py
+3
-0
未找到文件。
imperative/python/megengine/traced_module/compat.py
浏览文件 @
c0b9e071
...
@@ -229,7 +229,11 @@ def square_func_loader(expr):
...
@@ -229,7 +229,11 @@ def square_func_loader(expr):
@
register_functional_loader
((
"megengine.functional.math"
,
"topk"
))
@
register_functional_loader
((
"megengine.functional.math"
,
"topk"
))
def
topk_loader
(
expr
):
def
topk_loader
(
expr
):
if
not
hasattr
(
expr
,
"version"
):
# for mge 1.6
import
pkg_resources
as
pkg
if
not
hasattr
(
expr
,
"version"
)
or
pkg
.
parse_version
(
expr
.
version
)
<=
pkg
.
parse_version
(
"1.12.0"
):
def
origin_topk_signature
(
def
origin_topk_signature
(
inp
,
k
,
descending
=
False
,
kth_only
=
False
,
no_sort
=
False
inp
,
k
,
descending
=
False
,
kth_only
=
False
,
no_sort
=
False
...
@@ -260,7 +264,7 @@ def arange_func_loader(expr):
...
@@ -260,7 +264,7 @@ def arange_func_loader(expr):
if
len
(
args
)
==
5
:
if
len
(
args
)
==
5
:
device
=
args
[
-
1
]
device
=
args
[
-
1
]
dtype
=
args
[
-
2
]
dtype
=
args
[
-
2
]
args
=
args
[:
len
(
args
)
-
2
]
args
=
args
[:
-
2
]
kwargs
[
"dtype"
]
=
dtype
kwargs
[
"dtype"
]
=
dtype
kwargs
[
"device"
]
=
device
kwargs
[
"device"
]
=
device
...
@@ -268,6 +272,24 @@ def arange_func_loader(expr):
...
@@ -268,6 +272,24 @@ def arange_func_loader(expr):
expr
.
set_args_kwargs
(
*
args
,
**
kwargs
)
expr
.
set_args_kwargs
(
*
args
,
**
kwargs
)
@
register_functional_loader
((
"megengine.functional.tensor"
,
"linspace"
))
def
linespace_loader
(
expr
):
args
,
kwargs
=
expr
.
args
,
expr
.
kwargs
if
not
hasattr
(
expr
,
"version"
):
def
orig_linspace_signature
(
start
,
stop
,
num
,
dtype
=
"float32"
,
device
=
None
):
pass
args
,
kwargs
=
_convert_kwargs_to_args
(
orig_linspace_signature
,
expr
.
args
,
expr
.
kwargs
)
expr
.
set_args_kwargs
(
*
args
,
**
kwargs
)
if
len
(
args
)
==
5
:
new_args
=
args
[
0
:
-
2
]
new_kwargs
=
{
"dtype"
:
args
[
-
2
],
"device"
:
args
[
-
1
]}
expr
.
set_args_kwargs
(
*
new_args
,
**
new_kwargs
)
@
register_functional_loader
((
"megengine.functional.tensor"
,
"full"
))
@
register_functional_loader
((
"megengine.functional.tensor"
,
"full"
))
def
full_func_loader
(
expr
):
def
full_func_loader
(
expr
):
kwargs
=
expr
.
kwargs
kwargs
=
expr
.
kwargs
...
@@ -290,3 +312,80 @@ def full_func_loader(expr):
...
@@ -290,3 +312,80 @@ def full_func_loader(expr):
kwargs
[
"device"
]
=
device
kwargs
[
"device"
]
=
device
expr
.
set_args_kwargs
(
*
args
,
**
kwargs
)
expr
.
set_args_kwargs
(
*
args
,
**
kwargs
)
@
register_functional_loader
((
"megengine.functional.nn"
,
"conv_transpose2d"
))
def
deconv_loader
(
expr
):
args
,
kwargs
=
expr
.
args
,
expr
.
kwargs
if
not
hasattr
(
expr
,
"version"
):
def
orig_conv_transpose2d_signature
(
inp
,
weight
,
bias
=
None
,
stride
=
1
,
padding
=
0
,
dilation
=
1
,
groups
=
1
,
conv_mode
=
"cross_correlation"
,
compute_mode
=
"default"
,
):
pass
args
,
kwargs
=
_convert_kwargs_to_args
(
orig_conv_transpose2d_signature
,
expr
.
args
,
expr
.
kwargs
)
expr
.
set_args_kwargs
(
*
args
,
**
kwargs
)
if
len
(
args
)
==
9
:
args
=
list
(
args
)
args
.
insert
(
4
,
0
)
# output padding = 0
expr
.
set_args_kwargs
(
*
args
,
**
kwargs
)
@
register_functional_loader
((
"megengine.functional.quantized"
,
"conv_transpose2d"
))
def
deconv_loader
(
expr
):
args
,
kwargs
=
expr
.
args
,
expr
.
kwargs
if
not
hasattr
(
expr
,
"version"
):
def
orig_conv_transpose2d_signature
(
inp
,
weight
,
bias
=
None
,
dtype
=
None
,
stride
=
1
,
padding
=
0
,
dilation
=
1
,
groups
=
1
,
conv_mode
=
"cross_correlation"
,
compute_mode
=
"default"
,
):
pass
args
,
kwargs
=
_convert_kwargs_to_args
(
orig_conv_transpose2d_signature
,
expr
.
args
,
expr
.
kwargs
)
expr
.
set_args_kwargs
(
*
args
,
**
kwargs
)
if
len
(
args
)
==
10
:
args
=
list
(
args
)
args
.
insert
(
5
,
0
)
# output padding = 0
expr
.
set_args_kwargs
(
*
args
,
**
kwargs
)
@
register_functional_loader
((
"megengine.functional.nn"
,
"conv_transpose3d"
))
def
deconv3d_loader
(
expr
):
args
,
kwargs
=
expr
.
args
,
expr
.
kwargs
if
not
hasattr
(
expr
,
"version"
):
def
origin_conv_transpose3d_signature
(
inp
,
weight
,
bias
=
None
,
stride
=
1
,
padding
=
0
,
dilation
=
1
,
groups
=
1
,
):
pass
args
,
kwargs
=
_convert_kwargs_to_args
(
origin_conv_transpose3d_signature
,
expr
.
args
,
expr
.
kwargs
)
expr
.
set_args_kwargs
(
*
args
,
**
kwargs
)
if
len
(
args
)
==
7
:
args
=
list
(
args
)
args
.
insert
(
4
,
0
)
expr
.
set_args_kwargs
(
*
args
,
**
kwargs
)
imperative/python/megengine/traced_module/pytree.py
浏览文件 @
c0b9e071
...
@@ -328,6 +328,12 @@ class LeafDef(TreeDef):
...
@@ -328,6 +328,12 @@ class LeafDef(TreeDef):
assert
isinstance
(
leaves
[
0
],
self
.
type
),
self
.
type
assert
isinstance
(
leaves
[
0
],
self
.
type
),
self
.
type
return
leaves
[
0
]
return
leaves
[
0
]
def
__setstate__
(
self
,
state
):
for
k
,
v
in
state
.
items
():
setattr
(
self
,
k
,
v
)
if
hasattr
(
self
,
"const_val"
)
and
isinstance
(
self
.
const_val
,
np
.
dtype
):
self
.
type
=
_leaf_type
(
self
.
const_val
)
def
__ne__
(
self
,
other
)
->
bool
:
def
__ne__
(
self
,
other
)
->
bool
:
return
not
self
.
__eq__
(
other
)
return
not
self
.
__eq__
(
other
)
...
...
imperative/python/megengine/traced_module/utils.py
浏览文件 @
c0b9e071
...
@@ -56,6 +56,9 @@ def _convert_kwargs_to_args(
...
@@ -56,6 +56,9 @@ def _convert_kwargs_to_args(
argspecs
:
Union
[
Callable
,
FullArgSpec
],
args
,
kwargs
,
is_bounded
=
False
argspecs
:
Union
[
Callable
,
FullArgSpec
],
args
,
kwargs
,
is_bounded
=
False
):
):
# is_bounded = True when func is a method and provided args don't include 'self'
# is_bounded = True when func is a method and provided args don't include 'self'
if
isinstance
(
argspecs
,
Callable
)
and
hasattr
(
argspecs
,
"__wrapped__"
):
argspecs
=
inspect
.
unwrap
(
argspecs
)
arg_specs
=
(
arg_specs
=
(
inspect
.
getfullargspec
(
argspecs
)
if
isinstance
(
argspecs
,
Callable
)
else
argspecs
inspect
.
getfullargspec
(
argspecs
)
if
isinstance
(
argspecs
,
Callable
)
else
argspecs
)
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录