提交 69728969 编写于 作者: M Megvii Engine Team

fix(mge/utils): fix network multiple outputs issue

GitOrigin-RevId: d22e639cd3cec9dfe09d1452b9c3c352862be911
上级 f36e99d3
......@@ -140,7 +140,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
else:
if ndim != 0 and ndim != 1:
raise ValueError("ndim != 1 or 0, get : %d" % ndim)
if not isinstance(x, Tensor):
if not isinstance(x, (Tensor, SymbolVar)):
(x,) = Const(x, dtype=dtype, device=device)(*reference)
return x
......
......@@ -334,7 +334,7 @@ def split(inp, nsplits_or_sections, axis=0):
x = tensor(np.random.random((10, 20)), dtype=np.float32)
y = F.split(x, 3)
z = F.split(x, [6, 17], axis=1)
print([i.numpy().shape for i in y])
print([i.numpy().shape for i in z])
......@@ -686,9 +686,9 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor:
[1. 4.] [0 3]
"""
if not isinstance(x, Tensor):
if not isinstance(x, (Tensor, SymbolVar)):
raise TypeError("input must be a tensor")
if not isinstance(mask, Tensor):
if not isinstance(mask, (Tensor, SymbolVar)):
raise TypeError("mask must be a tensor")
if mask.dtype != np.bool_:
raise ValueError("mask must be bool")
......
......@@ -17,6 +17,7 @@ import numpy as np
from ..core._imperative_rt import ComputingGraph
from ..core._imperative_rt.core2 import SymbolVar
from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape
from ..core.tensor import megbrain_graph as G
from ..logger import get_logger
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq
......@@ -182,8 +183,13 @@ class Network:
"""
def _set_var_name(var):
graph_var = G.VarNode(var.var)
graph_var.name = var.name
return graph_var
self._compile()
out = [G.VarNode(var.var) for var in self.output_vars]
out = list(map(_set_var_name, self.output_vars))
if kwargs.pop("arg_names", False):
logger.warning(
......@@ -231,15 +237,20 @@ class Network:
if not all([var.owner for var in vars]):
self.add_dep_oprs(*vars)
for var in vars:
if var not in self.output_vars:
# use method 'is' instead of 'in' to avoid
# compare VarNode use elemwise equal
if not any(var is _ for _ in self.output_vars):
self.output_vars.append(var)
def remove_output(self, *vars: VarNode):
"""Removes vars from the network output node list.
"""
for var in vars:
if var in self.output_vars:
self.output_vars.remove(var)
# use list pop instead of remove to avoid
# compare VarNode use elemwise equal
for idx, out_var in enumerate(self.output_vars):
if var is out_var:
self.output_vars.pop(idx)
def add_dep_oprs(self, *vars):
if len(vars) == 0:
......@@ -434,6 +445,15 @@ class Network:
opnode.add_out_var(self._get_var(var))
return opnode
else:
# overwrite the opnode 'new' output VarNode with
# original one when output number larger than 1,
# or will cause dependence issue in _compiler step.
if len(opr.outputs) > 1:
opnode = self.all_oprs_map[opr.id]
for idx, output in enumerate(opnode.outputs):
if output.var.id in self.all_vars_map:
opnode.outputs[idx] = self.all_vars_map[output.var.id]
return None
def _get_opr(self, x):
......@@ -449,6 +469,15 @@ class Network:
return self.all_vars_map[x.id]
def set_symbolic_shape(option: bool):
"""
Set the VarNode use symbolic shape or not, return the last status.
Please set to True and must recover after dump if want to change the input batch size.
:param option: True for enable symbolic shape.
"""
return _set_symbolic_shape(option)
def as_varnode(obj):
"""convert a :class:`.VarNode` compatible object to :class:`.VarNode`.
......
......@@ -14,7 +14,8 @@ from typing import Callable, Sequence
import numpy as np
from ..core import _imperative_rt as rt
from ..core._imperative_rt.core2 import SymbolVar
from ..core._imperative_rt.core2 import SymbolVar, apply
from ..core._trace_option import use_symbolic_shape
from ..core._wrap import Device
from ..core.ops import builtin
from ..core.tensor.array_method import ArrayMethodMixin
......@@ -53,15 +54,41 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
obj.owner = owner_opr
return obj
def _get_var_shape(self, axis=None):
opdef = (
builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis)
)
return apply(opdef, self)[0]
@property
def partial_shape(self):
"""Return the tuple type inferred shape of VarNode
"""
return tuple(self._get_var_shape().numpy())
def shapeof(self, axis):
"""Return the symbolic shape of axis
"""
return self._get_var_shape(axis=axis) if self.var else None
@property
def _tuple_shape(self):
return self.partial_shape
@property
def shape(self):
"""Return the symbolic shape if using set_symbolic_shape(True)
else inferred shape
"""
rst = None
if self.var:
try:
rst = self.var.shape
except:
rst = None
return rst
if not use_symbolic_shape():
return rst
return self._get_var_shape() if self.var else None
@property
def dtype(self):
......@@ -78,10 +105,6 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
def __hash__(self):
return id(self)
@property
def _tuple_shape(self):
return self.var.shape
def numpy(self):
o = OutputNode(self.var)
self.graph.compile(o.outputs).execute()
......
......@@ -19,7 +19,7 @@ from megengine.core._trace_option import use_symbolic_shape
from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.utils import astensor1d
from megengine.distributed.helper import get_device_count_by_fork
from megengine.utils.network import Network
from megengine.utils.network import Network, set_symbolic_shape
from megengine.utils.network_node import VarNode
......@@ -62,6 +62,22 @@ def test_concat(is_varnode):
opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_condtake(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
y = np.array([[True, False, True], [False, True, True]])
xx = make_tensor(x, network)
yy = make_tensor(y, network)
val, idx = F.cond_take(yy, xx)
np.testing.assert_equal(val.numpy(), x[y])
np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
@pytest.mark.parametrize("is_varnode", [True, False])
def test_concat_device(is_varnode):
if is_varnode:
......@@ -102,6 +118,7 @@ def test_stack(is_varnode):
def test_split(is_varnode):
if is_varnode:
network = Network()
saved_symbolic_shape = set_symbolic_shape(False)
else:
network = None
......@@ -134,6 +151,9 @@ def test_split(is_varnode):
except ValueError as e:
assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]"
if is_varnode:
set_symbolic_shape(saved_symbolic_shape)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_reshape(is_varnode):
......@@ -161,6 +181,7 @@ def test_reshape(is_varnode):
def test_reshape_shape_inference(is_varnode):
if is_varnode:
network = Network()
saved_symbolic_shape = set_symbolic_shape(False)
else:
network = None
......@@ -192,12 +213,15 @@ def test_reshape_shape_inference(is_varnode):
{"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]},
]
opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)
if is_varnode:
set_symbolic_shape(saved_symbolic_shape)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_squeeze(is_varnode):
if is_varnode:
network = Network()
saved_symbolic_shape = set_symbolic_shape(False)
else:
network = None
......@@ -209,6 +233,9 @@ def test_squeeze(is_varnode):
yy = F.squeeze(xx, axis)
np.testing.assert_equal(y, yy.numpy())
if is_varnode:
set_symbolic_shape(saved_symbolic_shape)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_expand_dims(is_varnode):
......@@ -358,7 +385,7 @@ def test_flatten(is_varnode):
data1 = np.random.random(data1_shape).astype(np.float32)
def compare_fn(x, y):
assert x.shape[0] == y
assert x._tuple_shape[0] == y
output0 = (2 * 3 * 4 * 5,)
output1 = (4 * 5 * 6 * 7,)
......@@ -420,7 +447,7 @@ def test_broadcast(is_varnode):
data3 = np.random.random(input3_shape).astype(np.float32)
def compare_fn(x, y):
assert x.shape[0] == y
assert x._tuple_shape[0] == y
cases = [
{"input": [data1, output1_shape], "output": output1_shape},
......
......@@ -10,7 +10,7 @@ from megengine.jit.tracing import trace
from megengine.tensor import Tensor
from megengine.utils.comp_graph_tools import GraphInference
from megengine.utils.network import Network as Net
from megengine.utils.network import as_oprnode
from megengine.utils.network import as_oprnode, set_symbolic_shape
from megengine.utils.network_node import Host2DeviceCopy, VarNode
......@@ -181,19 +181,22 @@ def test_add_input():
np.testing.assert_equal(out["o1"], ((a + b) * 2 + a).numpy())
def test_add_output():
def test_add_remove_output():
a = Tensor([1.0, 2.0])
b = Tensor([3.0, 4.0])
@trace(symbolic=True, capture_as_const=True)
def fwd(a, b):
return (a + b) * 2
return (a + b) * 2, (a - b)
fwd(a, b)
orig_model = io.BytesIO()
fwd.dump(
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
orig_model,
arg_names=["a", "b"],
output_names=["o1", "o2"],
optimize_for_inference=False,
)
orig_model.seek(0)
......@@ -201,11 +204,13 @@ def test_add_output():
var_a = net.var_filter.name("a").as_unique()
var_b = net.var_filter.name("b").as_unique()
y = F.add(var_a, var_b)
y = F.sigmoid(y)
y1 = (var_a + var_b) * 3
y2 = F.sigmoid(var_a + var_b)
y.name = "o1"
net.add_output(y)
net.remove_output(*net.output_vars)
y1.name = "new_o1"
y2.name = "new_o2"
net.add_output(y1, y2)
modified_model = io.BytesIO()
net.dump(modified_model)
......@@ -214,8 +219,8 @@ def test_add_output():
g = GraphInference(modified_model)
out = g.run(a.numpy(), b.numpy())
np.testing.assert_equal(out["o"], ((a + b) * 2).numpy())
np.testing.assert_equal(out["o1"], (F.sigmoid((a + b))).numpy())
np.testing.assert_equal(out["new_o1"], ((a + b) * 3).numpy())
np.testing.assert_equal(out["new_o2"], (F.sigmoid((a + b))).numpy())
def test_query():
......@@ -343,3 +348,68 @@ def test_modify_opr_name():
net1 = Net.load(modified_model)
assert net1.data_providers_filter.as_unique().name == "net1.net.a"
def test_dump_cond_take():
a = Tensor([1.0, 2.0])
@trace(symbolic=True, capture_as_const=True)
def fwd(a):
return F.cond_take(a > 1, a)
fwd(a)
orig_model = io.BytesIO()
fwd.dump(
orig_model,
arg_names=["a"],
output_names=["o1", "o2"],
optimize_for_inference=False,
)
orig_model.seek(0)
net = Net.load(orig_model)
var_a = net.input_vars[0]
val, idx = F.cond_take(var_a > 1, var_a)
net.remove_output(*net.output_vars)
val.name = "value"
idx.name = "index"
net.add_output(val, idx)
modified_model = io.BytesIO()
net.dump(modified_model)
modified_model.seek(0)
g = GraphInference(modified_model)
out = g.run(a.numpy())
data = a.numpy()
mask = a.numpy() > 1
np.testing.assert_equal(out["index"], np.where(mask.reshape(-1))[0])
np.testing.assert_equal(out["value"], data[mask])
def test_set_symbolic_shape():
a = Tensor([1.0, 2.0])
@trace(symbolic=True, capture_as_const=True)
def fwd(a):
return F.relu(a * 2)
fwd(a)
orig_model = io.BytesIO()
fwd.dump(
orig_model, arg_names=["a"], output_names=["o"], optimize_for_inference=False,
)
orig_model.seek(0)
net = Net.load(orig_model)
var_a = net.input_vars[0]
saved_symbolic_shape = set_symbolic_shape(True)
assert isinstance(var_a.shape, VarNode)
set_symbolic_shape(False)
assert var_a.shape == var_a.partial_shape
set_symbolic_shape(saved_symbolic_shape)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册