From 8e5bf948576fc281acf3afe7ef3366715a263678 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 12 May 2021 19:36:52 +0800 Subject: [PATCH] fix(mge/utils): fix bug of VarNode inplace operations GitOrigin-RevId: fa9eec7079671a117809c3da8ae7338e12f345f0 --- .../python/megengine/utils/network_node.py | 22 ++-- .../test/unit/core/test_tensor_wrapper.py | 102 +++++++++++++++--- 2 files changed, 94 insertions(+), 30 deletions(-) diff --git a/imperative/python/megengine/utils/network_node.py b/imperative/python/megengine/utils/network_node.py index 2bb3693e4..314c73e8f 100644 --- a/imperative/python/megengine/utils/network_node.py +++ b/imperative/python/megengine/utils/network_node.py @@ -6,10 +6,9 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import abc import json import sys -from typing import Callable, Sequence +from typing import Sequence import numpy as np @@ -19,10 +18,7 @@ from ..core._trace_option import use_symbolic_shape from ..core._wrap import Device from ..core.ops import builtin from ..core.tensor.array_method import ArrayMethodMixin -from ..core.tensor.indexing import getitem as _getitem -from ..core.tensor.indexing import setitem as _setitem -from ..core.tensor.megbrain_graph import InputNode, OutputNode -from ..tensor import Tensor +from ..core.tensor.megbrain_graph import OutputNode from .comp_graph_tools import replace_vars from .module_stats import ( preprocess_receptive_field, @@ -110,18 +106,18 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): self.graph.compile(o.outputs).execute() return o.get_value().numpy() - def __getitem__(self, index): - return _getitem(self, index) - - def __setitem__(self, index, value): - if index is not Ellipsis: - value = _setitem(self, index, value) + def _reset(self, other): + if not isinstance(other, VarNode): + assert self.graph, "VarNode _reset must have graph" + node = ImmutableTensor(other, graph=self.graph) + node.compile(self.graph) + other = node.outputs[0] if self.owner is not None: idx = self.owner.outputs.index(self) self.owner.outputs[idx] = VarNode( self.var, owner_opr=self.owner, name=self.var.name ) - self.var = value.var + self.var = other.var self.owner = None def set_owner_opr(self, owner_opr): diff --git a/imperative/python/test/unit/core/test_tensor_wrapper.py b/imperative/python/test/unit/core/test_tensor_wrapper.py index 2d8649a21..b3aa6dfc9 100644 --- a/imperative/python/test/unit/core/test_tensor_wrapper.py +++ b/imperative/python/test/unit/core/test_tensor_wrapper.py @@ -9,38 +9,81 @@ import copy import numpy as np +import pytest +from utils import make_tensor from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 from megengine.tensor import Tensor +from megengine.utils.network import Network -def test_basic(): +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_basic(is_varnode): + if is_varnode: + network = Network() + else: + network = None + x_np = np.random.rand(10).astype("float32") - x = Tensor(x_np) + x = make_tensor(x_np, network) y = x * x y_np = y.numpy() np.testing.assert_almost_equal(y_np, x_np * x_np) -def test_literal_arith(): +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_literal_arith(is_varnode): + if is_varnode: + network = Network() + else: + network = None + x_np = np.random.rand(10).astype("float32") - x = Tensor(x_np) + x = make_tensor(x_np, network) y = x * 2 y_np = y.numpy() np.testing.assert_almost_equal(y_np, x_np * 2) -def test_matmul(): - A = Tensor(np.random.rand(5, 7).astype("float32")) - B = Tensor(np.random.rand(7, 10).astype("float32")) +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_matmul(is_varnode): + if is_varnode: + network = Network() + else: + network = None + + A = make_tensor(np.random.rand(5, 7).astype("float32"), network) + B = make_tensor(np.random.rand(7, 10).astype("float32"), network) C = A @ B np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6) -def test_reduce(): +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_inplace_add(is_varnode): + if is_varnode: + network = Network() + else: + network = None + + x_np = np.random.rand(10).astype("float32") + y_np = np.random.rand(10).astype("float32") + x = make_tensor(x_np, network) + y = make_tensor(y_np, network) + y += x + out_np = y.numpy() + np.testing.assert_almost_equal(out_np, x_np + y_np) + + +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_reduce(is_varnode): + if is_varnode: + network = Network() + else: + network = None + def test_x(x_np): for m in ["sum", "prod", "min", "max", "mean"]: - x = Tensor(x_np) + x = make_tensor(x_np, network) y = getattr(x, m)(axis=-1, keepdims=True) np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6) @@ -50,16 +93,28 @@ def test_reduce(): test_x(np.array([True, False, True])) -def test_set_value(): +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_set_value(is_varnode): + if is_varnode: + network = Network() + else: + network = None + v0 = np.random.random((2, 3)).astype(np.float32) - param = Tensor(v0) + param = make_tensor(v0, network) v1 = np.random.random((2, 3)).astype(np.float32) param[...] = v1 np.testing.assert_allclose(param.numpy(), v1, atol=5e-6) -def test_set_subtensor(): - x = Tensor([1, 2, 3]) +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_set_subtensor(is_varnode): + if is_varnode: + network = Network() + else: + network = None + + x = make_tensor([1, 2, 3], network) x[:] = [1, 1, 1] np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6) x[[0, 2]] = [3, 2] @@ -78,14 +133,27 @@ def test_computing_with_numpy_array(): np.testing.assert_equal(np.equal(xx, xx).numpy(), np.equal(x, x)) -def test_transpose(): +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_transpose(is_varnode): + if is_varnode: + network = Network() + else: + network = None + x = np.random.rand(2, 5).astype("float32") - xx = Tensor(x) + xx = make_tensor(x, network) np.testing.assert_almost_equal(xx.T.numpy(), x.T) -def test_as_type(): - x = Tensor([1, 2, 3], dtype=np.float32) +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_as_type(is_varnode): + if is_varnode: + network = Network() + else: + network = None + + x_np = np.array([1, 2, 3], dtype=np.float32) + x = make_tensor(x_np, network) y = x.astype(qint8(0.1)) np.testing.assert_almost_equal(get_scale(y.dtype), 0.1) z = y.astype(qint8(0.2)) -- GitLab