Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
e417798f
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
e417798f
编写于
5月 21, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge): correct pytype when calling apply from python
GitOrigin-RevId: 6abfa06adac1c857ace451dc4249da3438aee364
上级
c4048519
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
26 addition
and
1 deletion
+26
-1
imperative/python/megengine/tensor.py
imperative/python/megengine/tensor.py
+7
-0
imperative/python/src/pyext17.h
imperative/python/src/pyext17.h
+4
-0
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+6
-0
imperative/python/test/unit/core/test_tensor_wrapper.py
imperative/python/test/unit/core/test_tensor_wrapper.py
+9
-1
未找到文件。
imperative/python/megengine/tensor.py
浏览文件 @
e417798f
...
...
@@ -246,4 +246,11 @@ tensor = Tensor
class
Parameter
(
Tensor
):
r
"""
A kind of Tensor that is to be considered a module parameter.
.. note::
Operations happened on Parameter usually return a Tensor instead of Parameter.
For example, with a Parameter ``x``, ``x.reshape/to/sum/...`` will result into a Tensor.
Any operations between Parameter and Tensor will have Tensor as outputs.
"""
imperative/python/src/pyext17.h
浏览文件 @
e417798f
...
...
@@ -397,6 +397,10 @@ public:
return
Py_TYPE
(
op
)
==
&
m_type
;
}
bool
same_pytype
(
PyTypeObject
*
pt
)
{
return
pt
==
&
m_type
;
}
PyObject
*
finalize
()
{
if
(
!
m_finalized
)
{
m_finalized
=
true
;
...
...
imperative/python/src/tensor.cpp
浏览文件 @
e417798f
...
...
@@ -140,6 +140,12 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
auto
*
op
=
args
[
0
];
PyTypeObject
*
pytype
=
args
[
1
]
->
ob_type
;
// check if pytype is Parameter(and all other python Tensor's derived class),
// if yes, using it's tp_base(python Tensor)
if
(
TensorWrapper
::
wrap_t
::
type
().
same_pytype
(
pytype
->
tp_base
->
tp_base
))
{
pytype
=
pytype
->
tp_base
;
}
++
args
;
--
nargs
;
...
...
imperative/python/test/unit/core/test_tensor_wrapper.py
浏览文件 @
e417798f
...
...
@@ -13,7 +13,7 @@ import pytest
from
utils
import
make_tensor
from
megengine.core.tensor.dtype
import
get_scale
,
get_zero_point
,
qint8
,
quint8
from
megengine.tensor
import
Tensor
from
megengine.tensor
import
Parameter
,
Tensor
from
megengine.utils.network
import
Network
...
...
@@ -198,3 +198,11 @@ def test_name():
assert
x
.
name
==
"x"
x
=
Tensor
(
0
,
name
=
"x"
)
assert
x
.
name
==
"x"
def
test_tensor_type
():
x1
=
Parameter
(
1
)
x2
=
Tensor
(
2
)
y1
=
x1
+
x2
y2
=
x2
+
x1
assert
type
(
y1
)
==
type
(
y2
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录