提交 8e5bf948 编写于 作者: M Megvii Engine Team

fix(mge/utils): fix bug of VarNode inplace operations

GitOrigin-RevId: fa9eec7079671a117809c3da8ae7338e12f345f0
上级 13b15fb0
......@@ -6,10 +6,9 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import json
import sys
from typing import Callable, Sequence
from typing import Sequence
import numpy as np
......@@ -19,10 +18,7 @@ from ..core._trace_option import use_symbolic_shape
from ..core._wrap import Device
from ..core.ops import builtin
from ..core.tensor.array_method import ArrayMethodMixin
from ..core.tensor.indexing import getitem as _getitem
from ..core.tensor.indexing import setitem as _setitem
from ..core.tensor.megbrain_graph import InputNode, OutputNode
from ..tensor import Tensor
from ..core.tensor.megbrain_graph import OutputNode
from .comp_graph_tools import replace_vars
from .module_stats import (
preprocess_receptive_field,
......@@ -110,18 +106,18 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
self.graph.compile(o.outputs).execute()
return o.get_value().numpy()
def __getitem__(self, index):
return _getitem(self, index)
def __setitem__(self, index, value):
if index is not Ellipsis:
value = _setitem(self, index, value)
def _reset(self, other):
if not isinstance(other, VarNode):
assert self.graph, "VarNode _reset must have graph"
node = ImmutableTensor(other, graph=self.graph)
node.compile(self.graph)
other = node.outputs[0]
if self.owner is not None:
idx = self.owner.outputs.index(self)
self.owner.outputs[idx] = VarNode(
self.var, owner_opr=self.owner, name=self.var.name
)
self.var = value.var
self.var = other.var
self.owner = None
def set_owner_opr(self, owner_opr):
......
......@@ -9,38 +9,81 @@
import copy
import numpy as np
import pytest
from utils import make_tensor
from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
from megengine.tensor import Tensor
from megengine.utils.network import Network
def test_basic():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_basic(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x_np = np.random.rand(10).astype("float32")
x = Tensor(x_np)
x = make_tensor(x_np, network)
y = x * x
y_np = y.numpy()
np.testing.assert_almost_equal(y_np, x_np * x_np)
def test_literal_arith():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_literal_arith(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x_np = np.random.rand(10).astype("float32")
x = Tensor(x_np)
x = make_tensor(x_np, network)
y = x * 2
y_np = y.numpy()
np.testing.assert_almost_equal(y_np, x_np * 2)
def test_matmul():
A = Tensor(np.random.rand(5, 7).astype("float32"))
B = Tensor(np.random.rand(7, 10).astype("float32"))
@pytest.mark.parametrize("is_varnode", [True, False])
def test_matmul(is_varnode):
if is_varnode:
network = Network()
else:
network = None
A = make_tensor(np.random.rand(5, 7).astype("float32"), network)
B = make_tensor(np.random.rand(7, 10).astype("float32"), network)
C = A @ B
np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6)
def test_reduce():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_inplace_add(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x_np = np.random.rand(10).astype("float32")
y_np = np.random.rand(10).astype("float32")
x = make_tensor(x_np, network)
y = make_tensor(y_np, network)
y += x
out_np = y.numpy()
np.testing.assert_almost_equal(out_np, x_np + y_np)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_reduce(is_varnode):
if is_varnode:
network = Network()
else:
network = None
def test_x(x_np):
for m in ["sum", "prod", "min", "max", "mean"]:
x = Tensor(x_np)
x = make_tensor(x_np, network)
y = getattr(x, m)(axis=-1, keepdims=True)
np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6)
......@@ -50,16 +93,28 @@ def test_reduce():
test_x(np.array([True, False, True]))
def test_set_value():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_set_value(is_varnode):
if is_varnode:
network = Network()
else:
network = None
v0 = np.random.random((2, 3)).astype(np.float32)
param = Tensor(v0)
param = make_tensor(v0, network)
v1 = np.random.random((2, 3)).astype(np.float32)
param[...] = v1
np.testing.assert_allclose(param.numpy(), v1, atol=5e-6)
def test_set_subtensor():
x = Tensor([1, 2, 3])
@pytest.mark.parametrize("is_varnode", [True, False])
def test_set_subtensor(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = make_tensor([1, 2, 3], network)
x[:] = [1, 1, 1]
np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6)
x[[0, 2]] = [3, 2]
......@@ -78,14 +133,27 @@ def test_computing_with_numpy_array():
np.testing.assert_equal(np.equal(xx, xx).numpy(), np.equal(x, x))
def test_transpose():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_transpose(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = np.random.rand(2, 5).astype("float32")
xx = Tensor(x)
xx = make_tensor(x, network)
np.testing.assert_almost_equal(xx.T.numpy(), x.T)
def test_as_type():
x = Tensor([1, 2, 3], dtype=np.float32)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_as_type(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x_np = np.array([1, 2, 3], dtype=np.float32)
x = make_tensor(x_np, network)
y = x.astype(qint8(0.1))
np.testing.assert_almost_equal(get_scale(y.dtype), 0.1)
z = y.astype(qint8(0.2))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册