提交 7225b0f0 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(mge/utils): use static infer manager to get value of network.varnode

GitOrigin-RevId: ecc47edab8334e3f41a409020db2a9090db62147
上级 ffe2bb2e
......@@ -18,7 +18,6 @@ from ..core._trace_option import use_symbolic_shape
from ..core._wrap import Device
from ..core.ops import builtin
from ..core.tensor.array_method import ArrayMethodMixin
from ..core.tensor.megbrain_graph import OutputNode
from .comp_graph_tools import replace_vars
from .module_stats import (
preprocess_receptive_field,
......@@ -106,9 +105,7 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
return id(self)
def numpy(self):
o = OutputNode(self.var)
self.graph.compile(o.outputs).execute()
return o.get_value().numpy()
return super().numpy()
def _reset(self, other):
if not isinstance(other, VarNode):
......@@ -141,15 +138,13 @@ class OpNode(NetworkNode):
@property
def id(self):
if self._opr is not None:
return self._opr.id
return id(self)
@property
def priority(self):
if self._opr is not None:
return self._opr.priority
return 0
return (self._opr.priority, self._opr.id)
return (0, 0)
@classmethod
def load(cls, opr):
......
......@@ -5,6 +5,7 @@ import numpy as np
import megengine.core.tensor.megbrain_graph as G
import megengine.utils.comp_graph_tools as cgtools
from megengine import tensor
from megengine.core.tensor.megbrain_graph import OutputNode
from megengine.jit import trace
from megengine.utils.network_node import VarNode
......@@ -12,8 +13,10 @@ from megengine.utils.network_node import VarNode
def _default_compare_fn(x, y):
if isinstance(x, np.ndarray):
np.testing.assert_allclose(x, y, rtol=1e-6)
else:
elif isinstance(x, tensor):
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6)
else:
np.testing.assert_allclose(get_var_value(x), y, rtol=1e-6)
def make_tensor(x, network=None, device=None):
......@@ -25,6 +28,15 @@ def make_tensor(x, network=None, device=None):
return tensor(x, device=device)
def get_var_value(x):
try:
o = OutputNode(x.var)
o.graph.compile(o.outputs).execute()
return o.get_value().numpy()
except RuntimeError:
raise ValueError("value invalid!")
def opr_test(
cases,
func,
......
......@@ -10,7 +10,7 @@ import copy
import numpy as np
import pytest
from utils import make_tensor
from utils import get_var_value, make_tensor
from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
from megengine.tensor import Parameter, Tensor
......@@ -55,7 +55,12 @@ def test_matmul(is_varnode):
A = make_tensor(np.random.rand(5, 7).astype("float32"), network)
B = make_tensor(np.random.rand(7, 10).astype("float32"), network)
C = A @ B
np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6)
if is_varnode:
np.testing.assert_almost_equal(
get_var_value(C), get_var_value(A) @ get_var_value(B), decimal=6
)
else:
np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6)
@pytest.mark.parametrize("is_varnode", [True, False])
......@@ -116,11 +121,17 @@ def test_set_subtensor(is_varnode):
x = make_tensor([1, 2, 3], network)
x[:] = [1, 1, 1]
np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6)
np.testing.assert_almost_equal(
get_var_value(x) if is_varnode else x.numpy(), [1, 1, 1], decimal=6
)
x[[0, 2]] = [3, 2]
np.testing.assert_almost_equal(x.numpy(), [3, 1, 2], decimal=6)
np.testing.assert_almost_equal(
get_var_value(x) if is_varnode else x.numpy(), [3, 1, 2], decimal=6
)
x[1:3] = [4, 5]
np.testing.assert_almost_equal(x.numpy(), [3, 4, 5], decimal=6)
np.testing.assert_almost_equal(
get_var_value(x) if is_varnode else x.numpy(), [3, 4, 5], decimal=6
)
def test_computing_with_numpy_array():
......
......@@ -11,7 +11,7 @@ import platform
import numpy as np
import pytest
from utils import make_tensor, opr_test
from utils import get_var_value, make_tensor, opr_test
import megengine.functional as F
from megengine import tensor
......@@ -75,8 +75,12 @@ def test_condtake(is_varnode):
xx = make_tensor(x, network)
yy = make_tensor(y, network)
val, idx = F.cond_take(yy, xx)
np.testing.assert_equal(val.numpy(), x[y])
np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
if is_varnode:
np.testing.assert_equal(get_var_value(val), x[y])
np.testing.assert_equal(get_var_value(idx), np.where(y.reshape(-1))[0])
else:
np.testing.assert_equal(val.numpy(), x[y])
np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
@pytest.mark.parametrize("is_varnode", [True, False])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册