提交 e417798f 编写于 作者: M Megvii Engine Team

fix(mge): correct pytype when calling apply from python

GitOrigin-RevId: 6abfa06adac1c857ace451dc4249da3438aee364
上级 c4048519
...@@ -246,4 +246,11 @@ tensor = Tensor ...@@ -246,4 +246,11 @@ tensor = Tensor
class Parameter(Tensor): class Parameter(Tensor):
r""" r"""
A kind of Tensor that is to be considered a module parameter. A kind of Tensor that is to be considered a module parameter.
.. note::
Operations happened on Parameter usually return a Tensor instead of Parameter.
For example, with a Parameter ``x``, ``x.reshape/to/sum/...`` will result into a Tensor.
Any operations between Parameter and Tensor will have Tensor as outputs.
""" """
...@@ -397,6 +397,10 @@ public: ...@@ -397,6 +397,10 @@ public:
return Py_TYPE(op) == &m_type; return Py_TYPE(op) == &m_type;
} }
bool same_pytype(PyTypeObject *pt) {
return pt == &m_type;
}
PyObject* finalize() { PyObject* finalize() {
if (!m_finalized) { if (!m_finalized) {
m_finalized = true; m_finalized = true;
......
...@@ -140,6 +140,12 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje ...@@ -140,6 +140,12 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
auto* op = args[0]; auto* op = args[0];
PyTypeObject* pytype = args[1]->ob_type; PyTypeObject* pytype = args[1]->ob_type;
// check if pytype is Parameter(and all other python Tensor's derived class),
// if yes, using it's tp_base(python Tensor)
if (TensorWrapper::wrap_t::type().same_pytype(pytype->tp_base->tp_base)) {
pytype = pytype->tp_base;
}
++args; ++args;
--nargs; --nargs;
......
...@@ -13,7 +13,7 @@ import pytest ...@@ -13,7 +13,7 @@ import pytest
from utils import make_tensor from utils import make_tensor
from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
from megengine.tensor import Tensor from megengine.tensor import Parameter, Tensor
from megengine.utils.network import Network from megengine.utils.network import Network
...@@ -198,3 +198,11 @@ def test_name(): ...@@ -198,3 +198,11 @@ def test_name():
assert x.name == "x" assert x.name == "x"
x = Tensor(0, name="x") x = Tensor(0, name="x")
assert x.name == "x" assert x.name == "x"
def test_tensor_type():
x1 = Parameter(1)
x2 = Tensor(2)
y1 = x1 + x2
y2 = x2 + x1
assert type(y1) == type(y2)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册