提交 7225b0f0 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(mge/utils): use static infer manager to get value of network.varnode

GitOrigin-RevId: ecc47edab8334e3f41a409020db2a9090db62147
上级 ffe2bb2e
...@@ -18,7 +18,6 @@ from ..core._trace_option import use_symbolic_shape ...@@ -18,7 +18,6 @@ from ..core._trace_option import use_symbolic_shape
from ..core._wrap import Device from ..core._wrap import Device
from ..core.ops import builtin from ..core.ops import builtin
from ..core.tensor.array_method import ArrayMethodMixin from ..core.tensor.array_method import ArrayMethodMixin
from ..core.tensor.megbrain_graph import OutputNode
from .comp_graph_tools import replace_vars from .comp_graph_tools import replace_vars
from .module_stats import ( from .module_stats import (
preprocess_receptive_field, preprocess_receptive_field,
...@@ -106,9 +105,7 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): ...@@ -106,9 +105,7 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
return id(self) return id(self)
def numpy(self): def numpy(self):
o = OutputNode(self.var) return super().numpy()
self.graph.compile(o.outputs).execute()
return o.get_value().numpy()
def _reset(self, other): def _reset(self, other):
if not isinstance(other, VarNode): if not isinstance(other, VarNode):
...@@ -141,15 +138,13 @@ class OpNode(NetworkNode): ...@@ -141,15 +138,13 @@ class OpNode(NetworkNode):
@property @property
def id(self): def id(self):
if self._opr is not None:
return self._opr.id
return id(self) return id(self)
@property @property
def priority(self): def priority(self):
if self._opr is not None: if self._opr is not None:
return self._opr.priority return (self._opr.priority, self._opr.id)
return 0 return (0, 0)
@classmethod @classmethod
def load(cls, opr): def load(cls, opr):
......
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
import megengine.core.tensor.megbrain_graph as G import megengine.core.tensor.megbrain_graph as G
import megengine.utils.comp_graph_tools as cgtools import megengine.utils.comp_graph_tools as cgtools
from megengine import tensor from megengine import tensor
from megengine.core.tensor.megbrain_graph import OutputNode
from megengine.jit import trace from megengine.jit import trace
from megengine.utils.network_node import VarNode from megengine.utils.network_node import VarNode
...@@ -12,8 +13,10 @@ from megengine.utils.network_node import VarNode ...@@ -12,8 +13,10 @@ from megengine.utils.network_node import VarNode
def _default_compare_fn(x, y): def _default_compare_fn(x, y):
if isinstance(x, np.ndarray): if isinstance(x, np.ndarray):
np.testing.assert_allclose(x, y, rtol=1e-6) np.testing.assert_allclose(x, y, rtol=1e-6)
else: elif isinstance(x, tensor):
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) np.testing.assert_allclose(x.numpy(), y, rtol=1e-6)
else:
np.testing.assert_allclose(get_var_value(x), y, rtol=1e-6)
def make_tensor(x, network=None, device=None): def make_tensor(x, network=None, device=None):
...@@ -25,6 +28,15 @@ def make_tensor(x, network=None, device=None): ...@@ -25,6 +28,15 @@ def make_tensor(x, network=None, device=None):
return tensor(x, device=device) return tensor(x, device=device)
def get_var_value(x):
try:
o = OutputNode(x.var)
o.graph.compile(o.outputs).execute()
return o.get_value().numpy()
except RuntimeError:
raise ValueError("value invalid!")
def opr_test( def opr_test(
cases, cases,
func, func,
......
...@@ -10,7 +10,7 @@ import copy ...@@ -10,7 +10,7 @@ import copy
import numpy as np import numpy as np
import pytest import pytest
from utils import make_tensor from utils import get_var_value, make_tensor
from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
from megengine.tensor import Parameter, Tensor from megengine.tensor import Parameter, Tensor
...@@ -55,7 +55,12 @@ def test_matmul(is_varnode): ...@@ -55,7 +55,12 @@ def test_matmul(is_varnode):
A = make_tensor(np.random.rand(5, 7).astype("float32"), network) A = make_tensor(np.random.rand(5, 7).astype("float32"), network)
B = make_tensor(np.random.rand(7, 10).astype("float32"), network) B = make_tensor(np.random.rand(7, 10).astype("float32"), network)
C = A @ B C = A @ B
np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6) if is_varnode:
np.testing.assert_almost_equal(
get_var_value(C), get_var_value(A) @ get_var_value(B), decimal=6
)
else:
np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6)
@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])
...@@ -116,11 +121,17 @@ def test_set_subtensor(is_varnode): ...@@ -116,11 +121,17 @@ def test_set_subtensor(is_varnode):
x = make_tensor([1, 2, 3], network) x = make_tensor([1, 2, 3], network)
x[:] = [1, 1, 1] x[:] = [1, 1, 1]
np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6) np.testing.assert_almost_equal(
get_var_value(x) if is_varnode else x.numpy(), [1, 1, 1], decimal=6
)
x[[0, 2]] = [3, 2] x[[0, 2]] = [3, 2]
np.testing.assert_almost_equal(x.numpy(), [3, 1, 2], decimal=6) np.testing.assert_almost_equal(
get_var_value(x) if is_varnode else x.numpy(), [3, 1, 2], decimal=6
)
x[1:3] = [4, 5] x[1:3] = [4, 5]
np.testing.assert_almost_equal(x.numpy(), [3, 4, 5], decimal=6) np.testing.assert_almost_equal(
get_var_value(x) if is_varnode else x.numpy(), [3, 4, 5], decimal=6
)
def test_computing_with_numpy_array(): def test_computing_with_numpy_array():
......
...@@ -11,7 +11,7 @@ import platform ...@@ -11,7 +11,7 @@ import platform
import numpy as np import numpy as np
import pytest import pytest
from utils import make_tensor, opr_test from utils import get_var_value, make_tensor, opr_test
import megengine.functional as F import megengine.functional as F
from megengine import tensor from megengine import tensor
...@@ -75,8 +75,12 @@ def test_condtake(is_varnode): ...@@ -75,8 +75,12 @@ def test_condtake(is_varnode):
xx = make_tensor(x, network) xx = make_tensor(x, network)
yy = make_tensor(y, network) yy = make_tensor(y, network)
val, idx = F.cond_take(yy, xx) val, idx = F.cond_take(yy, xx)
np.testing.assert_equal(val.numpy(), x[y]) if is_varnode:
np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) np.testing.assert_equal(get_var_value(val), x[y])
np.testing.assert_equal(get_var_value(idx), np.where(y.reshape(-1))[0])
else:
np.testing.assert_equal(val.numpy(), x[y])
np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册