From 7225b0f09f61a55e14838eb465af627b6dfc8fc1 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 10 Jun 2021 17:29:19 +0800 Subject: [PATCH] fix(mge/utils): use static infer manager to get value of network.varnode GitOrigin-RevId: ecc47edab8334e3f41a409020db2a9090db62147 --- .../python/megengine/utils/network_node.py | 11 +++------- imperative/python/test/helpers/utils.py | 14 ++++++++++++- .../test/unit/core/test_tensor_wrapper.py | 21 ++++++++++++++----- .../test/unit/functional/test_tensor.py | 10 ++++++--- 4 files changed, 39 insertions(+), 17 deletions(-) diff --git a/imperative/python/megengine/utils/network_node.py b/imperative/python/megengine/utils/network_node.py index 7357d1749..f94fa86d4 100644 --- a/imperative/python/megengine/utils/network_node.py +++ b/imperative/python/megengine/utils/network_node.py @@ -18,7 +18,6 @@ from ..core._trace_option import use_symbolic_shape from ..core._wrap import Device from ..core.ops import builtin from ..core.tensor.array_method import ArrayMethodMixin -from ..core.tensor.megbrain_graph import OutputNode from .comp_graph_tools import replace_vars from .module_stats import ( preprocess_receptive_field, @@ -106,9 +105,7 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): return id(self) def numpy(self): - o = OutputNode(self.var) - self.graph.compile(o.outputs).execute() - return o.get_value().numpy() + return super().numpy() def _reset(self, other): if not isinstance(other, VarNode): @@ -141,15 +138,13 @@ class OpNode(NetworkNode): @property def id(self): - if self._opr is not None: - return self._opr.id return id(self) @property def priority(self): if self._opr is not None: - return self._opr.priority - return 0 + return (self._opr.priority, self._opr.id) + return (0, 0) @classmethod def load(cls, opr): diff --git a/imperative/python/test/helpers/utils.py b/imperative/python/test/helpers/utils.py index 35d00801d..10cd4a8b5 100644 --- a/imperative/python/test/helpers/utils.py +++ b/imperative/python/test/helpers/utils.py @@ -5,6 +5,7 @@ import numpy as np import megengine.core.tensor.megbrain_graph as G import megengine.utils.comp_graph_tools as cgtools from megengine import tensor +from megengine.core.tensor.megbrain_graph import OutputNode from megengine.jit import trace from megengine.utils.network_node import VarNode @@ -12,8 +13,10 @@ from megengine.utils.network_node import VarNode def _default_compare_fn(x, y): if isinstance(x, np.ndarray): np.testing.assert_allclose(x, y, rtol=1e-6) - else: + elif isinstance(x, tensor): np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) + else: + np.testing.assert_allclose(get_var_value(x), y, rtol=1e-6) def make_tensor(x, network=None, device=None): @@ -25,6 +28,15 @@ def make_tensor(x, network=None, device=None): return tensor(x, device=device) +def get_var_value(x): + try: + o = OutputNode(x.var) + o.graph.compile(o.outputs).execute() + return o.get_value().numpy() + except RuntimeError: + raise ValueError("value invalid!") + + def opr_test( cases, func, diff --git a/imperative/python/test/unit/core/test_tensor_wrapper.py b/imperative/python/test/unit/core/test_tensor_wrapper.py index ba85be704..18dcada79 100644 --- a/imperative/python/test/unit/core/test_tensor_wrapper.py +++ b/imperative/python/test/unit/core/test_tensor_wrapper.py @@ -10,7 +10,7 @@ import copy import numpy as np import pytest -from utils import make_tensor +from utils import get_var_value, make_tensor from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 from megengine.tensor import Parameter, Tensor @@ -55,7 +55,12 @@ def test_matmul(is_varnode): A = make_tensor(np.random.rand(5, 7).astype("float32"), network) B = make_tensor(np.random.rand(7, 10).astype("float32"), network) C = A @ B - np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6) + if is_varnode: + np.testing.assert_almost_equal( + get_var_value(C), get_var_value(A) @ get_var_value(B), decimal=6 + ) + else: + np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6) @pytest.mark.parametrize("is_varnode", [True, False]) @@ -116,11 +121,17 @@ def test_set_subtensor(is_varnode): x = make_tensor([1, 2, 3], network) x[:] = [1, 1, 1] - np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6) + np.testing.assert_almost_equal( + get_var_value(x) if is_varnode else x.numpy(), [1, 1, 1], decimal=6 + ) x[[0, 2]] = [3, 2] - np.testing.assert_almost_equal(x.numpy(), [3, 1, 2], decimal=6) + np.testing.assert_almost_equal( + get_var_value(x) if is_varnode else x.numpy(), [3, 1, 2], decimal=6 + ) x[1:3] = [4, 5] - np.testing.assert_almost_equal(x.numpy(), [3, 4, 5], decimal=6) + np.testing.assert_almost_equal( + get_var_value(x) if is_varnode else x.numpy(), [3, 4, 5], decimal=6 + ) def test_computing_with_numpy_array(): diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 5f0305a4a..63965d37b 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -11,7 +11,7 @@ import platform import numpy as np import pytest -from utils import make_tensor, opr_test +from utils import get_var_value, make_tensor, opr_test import megengine.functional as F from megengine import tensor @@ -75,8 +75,12 @@ def test_condtake(is_varnode): xx = make_tensor(x, network) yy = make_tensor(y, network) val, idx = F.cond_take(yy, xx) - np.testing.assert_equal(val.numpy(), x[y]) - np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) + if is_varnode: + np.testing.assert_equal(get_var_value(val), x[y]) + np.testing.assert_equal(get_var_value(idx), np.where(y.reshape(-1))[0]) + else: + np.testing.assert_equal(val.numpy(), x[y]) + np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) @pytest.mark.parametrize("is_varnode", [True, False]) -- GitLab