diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index 4163edca9ff5f954fc53043ad4115e187dd26c60..98379a0151f58ec0d799deda0d6b7b27aae206f7 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -140,7 +140,7 @@ def astensor1d(x, *reference, dtype=None, device=None): else: if ndim != 0 and ndim != 1: raise ValueError("ndim != 1 or 0, get : %d" % ndim) - if not isinstance(x, Tensor): + if not isinstance(x, (Tensor, SymbolVar)): (x,) = Const(x, dtype=dtype, device=device)(*reference) return x diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 6b50027be0a9fec738e5f5c2f2ffdd72e9eb2101..7cf0472b74003346c7408448f72744076c4b9ff6 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -334,7 +334,7 @@ def split(inp, nsplits_or_sections, axis=0): x = tensor(np.random.random((10, 20)), dtype=np.float32) y = F.split(x, 3) z = F.split(x, [6, 17], axis=1) - + print([i.numpy().shape for i in y]) print([i.numpy().shape for i in z]) @@ -686,9 +686,9 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor: [1. 4.] [0 3] """ - if not isinstance(x, Tensor): + if not isinstance(x, (Tensor, SymbolVar)): raise TypeError("input must be a tensor") - if not isinstance(mask, Tensor): + if not isinstance(mask, (Tensor, SymbolVar)): raise TypeError("mask must be a tensor") if mask.dtype != np.bool_: raise ValueError("mask must be bool") diff --git a/imperative/python/megengine/utils/network.py b/imperative/python/megengine/utils/network.py index 80121d33f954512f55121496483b9f2b74133951..bc56f00bb8ef3500bfd7fceacacce7288d1f35cd 100644 --- a/imperative/python/megengine/utils/network.py +++ b/imperative/python/megengine/utils/network.py @@ -17,6 +17,7 @@ import numpy as np from ..core._imperative_rt import ComputingGraph from ..core._imperative_rt.core2 import SymbolVar +from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape from ..core.tensor import megbrain_graph as G from ..logger import get_logger from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq @@ -182,8 +183,13 @@ class Network: """ + def _set_var_name(var): + graph_var = G.VarNode(var.var) + graph_var.name = var.name + return graph_var + self._compile() - out = [G.VarNode(var.var) for var in self.output_vars] + out = list(map(_set_var_name, self.output_vars)) if kwargs.pop("arg_names", False): logger.warning( @@ -231,15 +237,20 @@ class Network: if not all([var.owner for var in vars]): self.add_dep_oprs(*vars) for var in vars: - if var not in self.output_vars: + # use method 'is' instead of 'in' to avoid + # compare VarNode use elemwise equal + if not any(var is _ for _ in self.output_vars): self.output_vars.append(var) def remove_output(self, *vars: VarNode): """Removes vars from the network output node list. """ for var in vars: - if var in self.output_vars: - self.output_vars.remove(var) + # use list pop instead of remove to avoid + # compare VarNode use elemwise equal + for idx, out_var in enumerate(self.output_vars): + if var is out_var: + self.output_vars.pop(idx) def add_dep_oprs(self, *vars): if len(vars) == 0: @@ -434,6 +445,15 @@ class Network: opnode.add_out_var(self._get_var(var)) return opnode else: + # overwrite the opnode 'new' output VarNode with + # original one when output number larger than 1, + # or will cause dependence issue in _compiler step. + if len(opr.outputs) > 1: + opnode = self.all_oprs_map[opr.id] + for idx, output in enumerate(opnode.outputs): + if output.var.id in self.all_vars_map: + opnode.outputs[idx] = self.all_vars_map[output.var.id] + return None def _get_opr(self, x): @@ -449,6 +469,15 @@ class Network: return self.all_vars_map[x.id] +def set_symbolic_shape(option: bool): + """ + Set the VarNode use symbolic shape or not, return the last status. + Please set to True and must recover after dump if want to change the input batch size. + :param option: True for enable symbolic shape. + """ + return _set_symbolic_shape(option) + + def as_varnode(obj): """convert a :class:`.VarNode` compatible object to :class:`.VarNode`. diff --git a/imperative/python/megengine/utils/network_node.py b/imperative/python/megengine/utils/network_node.py index cbd0665e5b8d39f5da9031831b26f121394cd85f..2bb3693e472aa5ed4d97368db4d356348b944b8c 100644 --- a/imperative/python/megengine/utils/network_node.py +++ b/imperative/python/megengine/utils/network_node.py @@ -14,7 +14,8 @@ from typing import Callable, Sequence import numpy as np from ..core import _imperative_rt as rt -from ..core._imperative_rt.core2 import SymbolVar +from ..core._imperative_rt.core2 import SymbolVar, apply +from ..core._trace_option import use_symbolic_shape from ..core._wrap import Device from ..core.ops import builtin from ..core.tensor.array_method import ArrayMethodMixin @@ -53,15 +54,41 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): obj.owner = owner_opr return obj + def _get_var_shape(self, axis=None): + opdef = ( + builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis) + ) + return apply(opdef, self)[0] + + @property + def partial_shape(self): + """Return the tuple type inferred shape of VarNode + """ + return tuple(self._get_var_shape().numpy()) + + def shapeof(self, axis): + """Return the symbolic shape of axis + """ + return self._get_var_shape(axis=axis) if self.var else None + + @property + def _tuple_shape(self): + return self.partial_shape + @property def shape(self): + """Return the symbolic shape if using set_symbolic_shape(True) + else inferred shape + """ rst = None if self.var: try: rst = self.var.shape except: rst = None - return rst + if not use_symbolic_shape(): + return rst + return self._get_var_shape() if self.var else None @property def dtype(self): @@ -78,10 +105,6 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): def __hash__(self): return id(self) - @property - def _tuple_shape(self): - return self.var.shape - def numpy(self): o = OutputNode(self.var) self.graph.compile(o.outputs).execute() diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index b5a9199372bcf0cb844c7aebbb02cfab6bfea186..a0ed6d1f64dce63aabac2a3b57bd15b111581ca3 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -19,7 +19,7 @@ from megengine.core._trace_option import use_symbolic_shape from megengine.core.tensor import megbrain_graph as G from megengine.core.tensor.utils import astensor1d from megengine.distributed.helper import get_device_count_by_fork -from megengine.utils.network import Network +from megengine.utils.network import Network, set_symbolic_shape from megengine.utils.network_node import VarNode @@ -62,6 +62,22 @@ def test_concat(is_varnode): opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network) +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_condtake(is_varnode): + if is_varnode: + network = Network() + else: + network = None + + x = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32") + y = np.array([[True, False, True], [False, True, True]]) + xx = make_tensor(x, network) + yy = make_tensor(y, network) + val, idx = F.cond_take(yy, xx) + np.testing.assert_equal(val.numpy(), x[y]) + np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) + + @pytest.mark.parametrize("is_varnode", [True, False]) def test_concat_device(is_varnode): if is_varnode: @@ -102,6 +118,7 @@ def test_stack(is_varnode): def test_split(is_varnode): if is_varnode: network = Network() + saved_symbolic_shape = set_symbolic_shape(False) else: network = None @@ -134,6 +151,9 @@ def test_split(is_varnode): except ValueError as e: assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]" + if is_varnode: + set_symbolic_shape(saved_symbolic_shape) + @pytest.mark.parametrize("is_varnode", [True, False]) def test_reshape(is_varnode): @@ -161,6 +181,7 @@ def test_reshape(is_varnode): def test_reshape_shape_inference(is_varnode): if is_varnode: network = Network() + saved_symbolic_shape = set_symbolic_shape(False) else: network = None @@ -192,12 +213,15 @@ def test_reshape_shape_inference(is_varnode): {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]}, ] opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network) + if is_varnode: + set_symbolic_shape(saved_symbolic_shape) @pytest.mark.parametrize("is_varnode", [True, False]) def test_squeeze(is_varnode): if is_varnode: network = Network() + saved_symbolic_shape = set_symbolic_shape(False) else: network = None @@ -209,6 +233,9 @@ def test_squeeze(is_varnode): yy = F.squeeze(xx, axis) np.testing.assert_equal(y, yy.numpy()) + if is_varnode: + set_symbolic_shape(saved_symbolic_shape) + @pytest.mark.parametrize("is_varnode", [True, False]) def test_expand_dims(is_varnode): @@ -358,7 +385,7 @@ def test_flatten(is_varnode): data1 = np.random.random(data1_shape).astype(np.float32) def compare_fn(x, y): - assert x.shape[0] == y + assert x._tuple_shape[0] == y output0 = (2 * 3 * 4 * 5,) output1 = (4 * 5 * 6 * 7,) @@ -420,7 +447,7 @@ def test_broadcast(is_varnode): data3 = np.random.random(input3_shape).astype(np.float32) def compare_fn(x, y): - assert x.shape[0] == y + assert x._tuple_shape[0] == y cases = [ {"input": [data1, output1_shape], "output": output1_shape}, diff --git a/imperative/python/test/unit/utils/test_network.py b/imperative/python/test/unit/utils/test_network.py index e578c06412bf3901ae58c996c7e8b3d2c53107b4..a3935e362b31e8b7d4158a1b9618ad6becdfd39f 100644 --- a/imperative/python/test/unit/utils/test_network.py +++ b/imperative/python/test/unit/utils/test_network.py @@ -10,7 +10,7 @@ from megengine.jit.tracing import trace from megengine.tensor import Tensor from megengine.utils.comp_graph_tools import GraphInference from megengine.utils.network import Network as Net -from megengine.utils.network import as_oprnode +from megengine.utils.network import as_oprnode, set_symbolic_shape from megengine.utils.network_node import Host2DeviceCopy, VarNode @@ -181,19 +181,22 @@ def test_add_input(): np.testing.assert_equal(out["o1"], ((a + b) * 2 + a).numpy()) -def test_add_output(): +def test_add_remove_output(): a = Tensor([1.0, 2.0]) b = Tensor([3.0, 4.0]) @trace(symbolic=True, capture_as_const=True) def fwd(a, b): - return (a + b) * 2 + return (a + b) * 2, (a - b) fwd(a, b) orig_model = io.BytesIO() fwd.dump( - orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False + orig_model, + arg_names=["a", "b"], + output_names=["o1", "o2"], + optimize_for_inference=False, ) orig_model.seek(0) @@ -201,11 +204,13 @@ def test_add_output(): var_a = net.var_filter.name("a").as_unique() var_b = net.var_filter.name("b").as_unique() - y = F.add(var_a, var_b) - y = F.sigmoid(y) + y1 = (var_a + var_b) * 3 + y2 = F.sigmoid(var_a + var_b) - y.name = "o1" - net.add_output(y) + net.remove_output(*net.output_vars) + y1.name = "new_o1" + y2.name = "new_o2" + net.add_output(y1, y2) modified_model = io.BytesIO() net.dump(modified_model) @@ -214,8 +219,8 @@ def test_add_output(): g = GraphInference(modified_model) out = g.run(a.numpy(), b.numpy()) - np.testing.assert_equal(out["o"], ((a + b) * 2).numpy()) - np.testing.assert_equal(out["o1"], (F.sigmoid((a + b))).numpy()) + np.testing.assert_equal(out["new_o1"], ((a + b) * 3).numpy()) + np.testing.assert_equal(out["new_o2"], (F.sigmoid((a + b))).numpy()) def test_query(): @@ -343,3 +348,68 @@ def test_modify_opr_name(): net1 = Net.load(modified_model) assert net1.data_providers_filter.as_unique().name == "net1.net.a" + + +def test_dump_cond_take(): + + a = Tensor([1.0, 2.0]) + + @trace(symbolic=True, capture_as_const=True) + def fwd(a): + return F.cond_take(a > 1, a) + + fwd(a) + orig_model = io.BytesIO() + fwd.dump( + orig_model, + arg_names=["a"], + output_names=["o1", "o2"], + optimize_for_inference=False, + ) + orig_model.seek(0) + + net = Net.load(orig_model) + var_a = net.input_vars[0] + + val, idx = F.cond_take(var_a > 1, var_a) + + net.remove_output(*net.output_vars) + val.name = "value" + idx.name = "index" + net.add_output(val, idx) + + modified_model = io.BytesIO() + net.dump(modified_model) + modified_model.seek(0) + + g = GraphInference(modified_model) + out = g.run(a.numpy()) + + data = a.numpy() + mask = a.numpy() > 1 + np.testing.assert_equal(out["index"], np.where(mask.reshape(-1))[0]) + np.testing.assert_equal(out["value"], data[mask]) + + +def test_set_symbolic_shape(): + + a = Tensor([1.0, 2.0]) + + @trace(symbolic=True, capture_as_const=True) + def fwd(a): + return F.relu(a * 2) + + fwd(a) + orig_model = io.BytesIO() + fwd.dump( + orig_model, arg_names=["a"], output_names=["o"], optimize_for_inference=False, + ) + orig_model.seek(0) + net = Net.load(orig_model) + var_a = net.input_vars[0] + + saved_symbolic_shape = set_symbolic_shape(True) + assert isinstance(var_a.shape, VarNode) + set_symbolic_shape(False) + assert var_a.shape == var_a.partial_shape + set_symbolic_shape(saved_symbolic_shape)