From e417798f9eac3cd1b0d398c262f3f41901432a34 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 21 May 2021 18:24:25 +0800 Subject: [PATCH] fix(mge): correct pytype when calling apply from python GitOrigin-RevId: 6abfa06adac1c857ace451dc4249da3438aee364 --- imperative/python/megengine/tensor.py | 7 +++++++ imperative/python/src/pyext17.h | 4 ++++ imperative/python/src/tensor.cpp | 6 ++++++ .../python/test/unit/core/test_tensor_wrapper.py | 10 +++++++++- 4 files changed, 26 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index 574e78511..cc17d924a 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -246,4 +246,11 @@ tensor = Tensor class Parameter(Tensor): r""" A kind of Tensor that is to be considered a module parameter. + + .. note:: + + Operations happened on Parameter usually return a Tensor instead of Parameter. + For example, with a Parameter ``x``, ``x.reshape/to/sum/...`` will result into a Tensor. + Any operations between Parameter and Tensor will have Tensor as outputs. + """ diff --git a/imperative/python/src/pyext17.h b/imperative/python/src/pyext17.h index feced9517..4f10a207a 100644 --- a/imperative/python/src/pyext17.h +++ b/imperative/python/src/pyext17.h @@ -397,6 +397,10 @@ public: return Py_TYPE(op) == &m_type; } + bool same_pytype(PyTypeObject *pt) { + return pt == &m_type; + } + PyObject* finalize() { if (!m_finalized) { m_finalized = true; diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 8d44ae05c..67be8ee12 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -140,6 +140,12 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje auto* op = args[0]; PyTypeObject* pytype = args[1]->ob_type; + + // check if pytype is Parameter(and all other python Tensor's derived class), + // if yes, using it's tp_base(python Tensor) + if (TensorWrapper::wrap_t::type().same_pytype(pytype->tp_base->tp_base)) { + pytype = pytype->tp_base; + } ++args; --nargs; diff --git a/imperative/python/test/unit/core/test_tensor_wrapper.py b/imperative/python/test/unit/core/test_tensor_wrapper.py index b3aa6dfc9..ba85be704 100644 --- a/imperative/python/test/unit/core/test_tensor_wrapper.py +++ b/imperative/python/test/unit/core/test_tensor_wrapper.py @@ -13,7 +13,7 @@ import pytest from utils import make_tensor from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 -from megengine.tensor import Tensor +from megengine.tensor import Parameter, Tensor from megengine.utils.network import Network @@ -198,3 +198,11 @@ def test_name(): assert x.name == "x" x = Tensor(0, name="x") assert x.name == "x" + + +def test_tensor_type(): + x1 = Parameter(1) + x2 = Tensor(2) + y1 = x1 + x2 + y2 = x2 + x1 + assert type(y1) == type(y2) -- GitLab