diff --git a/imperative/python/megengine/core/ops/special.py b/imperative/python/megengine/core/ops/special.py index 1e496c1c828c9bac33810b79b499ad6eaf3be036..a4a8d693426f3fdb34eef259bf1b44c692145bc7 100644 --- a/imperative/python/megengine/core/ops/special.py +++ b/imperative/python/megengine/core/ops/special.py @@ -8,6 +8,9 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import numpy as np +from .._imperative_rt import make_const +from .._imperative_rt.core2 import SymbolVar, Tensor + class Const: def __init__(self, value=None, *, dtype=None, device=None): @@ -19,7 +22,19 @@ class Const: from ...tensor import Tensor device = self.device - if device is None: - device = reference[0].device + + if len(reference) != 0: + reference = reference[0] + assert isinstance( + reference, (SymbolVar, Tensor) + ), "Reference should be Tensor or VarNode" + + if device is None: + device = reference.device + + if isinstance(reference, SymbolVar): + cls = type(reference) + rst = cls(make_const(reference.graph, self.value, device, self.dtype)) + return (rst,) return (Tensor(self.value, self.dtype, self.device, True),) diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index 6647642b84e14ae7273f45223e6f2723877ddbad..cda2306678b4568094e3ab09172584cc474fa342 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -13,7 +13,7 @@ from typing import Union import numpy as np from .._imperative_rt.common import CompNode -from .._imperative_rt.core2 import Tensor, apply +from .._imperative_rt.core2 import SymbolVar, Tensor, apply from ..ops import builtin from ..ops.builtin import Elemwise, GetVarShape from . 
import utils @@ -230,7 +230,9 @@ def _todo(*_): def _expand_args(args): if len(args) == 1: - if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray),): + if isinstance( + args[0], (collections.abc.Sequence, Tensor, SymbolVar, np.ndarray), + ): args = args[0] return args diff --git a/imperative/python/megengine/core/tensor/indexing.py b/imperative/python/megengine/core/tensor/indexing.py index e912d471863d0c0dc83a79365b4946eb601a6df5..8017befcfa15dca62ec2cf205410351af75581c5 100644 --- a/imperative/python/megengine/core/tensor/indexing.py +++ b/imperative/python/megengine/core/tensor/indexing.py @@ -10,7 +10,7 @@ from typing import Iterable import numpy as np -from .._imperative_rt.core2 import Tensor, apply +from .._imperative_rt.core2 import SymbolVar, Tensor, apply from .._trace_option import use_symbolic_shape from ..ops import builtin from ..ops.special import Const @@ -149,13 +149,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): return True def get_index(i): - if not isinstance(i, (Tensor)): + if not isinstance(i, (Tensor, SymbolVar)): if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: - (i,) = Const(i, dtype=np.bool_, device=inp.device)() + (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) else: - (i,) = Const(i, dtype=np.int32, device=inp.device)() + (i,) = Const(i, dtype=np.int32, device=inp.device)(inp) return i - assert isinstance(i, Tensor) + assert isinstance(i, (Tensor, SymbolVar)) if i.dtype != np.bool_: return i _, ind = apply(builtin.CondTake(), i, i) @@ -197,9 +197,9 @@ def try_condtake(tensor, index): ): return [] if isinstance(index, np.ndarray): - (index,) = Const(index, dtype=np.bool_, device=tensor.device)() - assert isinstance(index, Tensor) - if not isinstance(tensor, Tensor): + (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) + assert isinstance(index, (Tensor, SymbolVar)) + if not isinstance(tensor, (Tensor, SymbolVar)): raise TypeError("input must be a tensor") if tensor.device != index.device: raise ValueError( @@ -214,11 +214,16 @@ def getitem(tensor, index): return try_result[0] tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index) for v in tensors: + if v.shape is None: + break if isinstance(v.shape, v.__class__): break if len(v.shape) > 0 and v.shape[0] == 0: - (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)() + (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)( + tensor + ) return empty_tensor + if use_subtensor: op = builtin.Subtensor(items=items) else: @@ -235,8 +240,8 @@ def setitem(tensor, index, value): if len(try_result) == 2: index = try_result[1] tensor = tensor.reshape(-1) - if not isinstance(value, Tensor): - (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)() + if not isinstance(value, (Tensor, SymbolVar)): + (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor) tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) if use_subtensor: op = builtin.Subtensor(items=items) diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index 9a9a0f97d64c22c27f020ff033b2f26584371a40..4163edca9ff5f954fc53043ad4115e187dd26c60 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -11,8 +11,9 @@ from typing import Iterable, Union import numpy as np -from .._imperative_rt import VarNode -from .._imperative_rt.core2 import Tensor, apply, 
dtype_promotion, get_device +from .._imperative_rt import make_const +from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device +from .._wrap import device as as_device from ..ops import builtin from ..ops.special import Const from .dtype import is_dtype_equal, is_quantize @@ -38,13 +39,9 @@ def set_convert_inputs(flag): def concatenate(inputs, axis=0, *, device=None): - dtype = dtype_promotion(inputs) - device = get_device(inputs) - - def convert(x): - return convert_single_value(x, dtype=dtype, device=device) - - inputs = tuple(map(convert, inputs)) + inputs = convert_inputs(*inputs) + if device is None: + device = get_device(inputs) (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs) return result @@ -60,7 +57,7 @@ def astype(x, dtype): def convert_single_value(v, *, dtype=None, device=None): - if isinstance(v, (Tensor, VarNode)): + if isinstance(v, (Tensor, SymbolVar)): if not is_quantize(v.dtype): v = astype(v, dtype) else: @@ -68,17 +65,35 @@ def convert_single_value(v, *, dtype=None, device=None): return v -def convert_inputs(*args: Tensor): +def convert_inputs(*args, device=None): if not _enable_convert_inputs: return args dtype = dtype_promotion(args) - device = get_device(args) + if device is None: + device = get_device(args) + device = as_device(device) + + graph = None + sym_type = None + for a in args: + if isinstance(a, SymbolVar): + if graph is None: + graph = a.var.graph + sym_type = type(a) + else: + assert graph == a.var.graph + args = list(args) + if graph is not None: + for i in range(len(args)): + if not isinstance(args[i], SymbolVar): + rst = make_const(graph, np.array(args[i]), device.to_c(), dtype) + args[i] = sym_type(rst) def convert(value): if value is None: return value - return convert_single_value(value, dtype=dtype, device=device) + return convert_single_value(value, dtype=dtype, device=device.to_c()) return tuple(map(convert, args)) @@ -98,14 +113,14 @@ def result_type(*args): def isscalar(x): - if isinstance(x, Tensor): + if isinstance(x, (Tensor, SymbolVar)): return x._isscalar() return np.isscalar(x) def setscalar(x): - if isinstance(x, Tensor): + if isinstance(x, (Tensor, SymbolVar)): x._setscalar() else: raise NotImplementedError("Unsupport type {}".format(type(x))) @@ -132,7 +147,7 @@ def astensor1d(x, *reference, dtype=None, device=None): if not isinstance(x, collections.abc.Sequence): raise TypeError - if any(isinstance(i, Tensor) for i in x): + if any(isinstance(i, (Tensor, SymbolVar)) for i in x): x = concatenate(x, device=device) if dtype is not None: x = astype(x, dtype) @@ -142,7 +157,7 @@ def astensor1d(x, *reference, dtype=None, device=None): def _expand_int(s, i): - if isinstance(i, Tensor): + if isinstance(i, (Tensor, SymbolVar)): i_np = i.numpy() if i_np.ndim == 0: s.append(int(i_np)) diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index 69bebd6479445ee3a50535c7fc9862ce60d917d8..c1f1fff6390cceb245551f72aa4ab41cc9fc1c84 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -9,8 +9,7 @@ # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order import numpy as np -from ..core._imperative_rt.core2 import apply -from ..core._imperative_rt.graph import VarNode +from ..core._imperative_rt.core2 import SymbolVar, apply from ..core.ops import builtin from ..core.ops.builtin import Elemwise from ..core.tensor import utils @@ -72,7 
+71,7 @@ __all__ = [ def _elwise(*args, mode): - tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args)) + tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args)) if len(tensor_args) == 0: dtype = utils.dtype_promotion(args) first_arg = Tensor(args[0], dtype=dtype, device=get_default_device()) diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index c67b44286eb06f902df473920572544316f3b32e..0b45ec9c97711b33f394af066d64274e2bdb508d 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -12,7 +12,7 @@ from typing import Iterable, Optional, Sequence, Union import numpy as np from ..core._imperative_rt import CompNode -from ..core._imperative_rt.core2 import apply +from ..core._imperative_rt.core2 import SymbolVar, apply from ..core._wrap import device as as_device from ..core.ops import builtin from ..core.ops.builtin import Copy, Identity @@ -101,7 +101,7 @@ def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Ten return result -def full(shape, value, dtype="float32", device=None): +def full(shape, value, dtype="float32", device=None) -> Tensor: """ Returns a tensor with given shape and value. """ @@ -115,7 +115,7 @@ def full(shape, value, dtype="float32", device=None): return broadcast_to(x, shape) -def ones(shape, dtype="float32", device=None): +def ones(shape, dtype="float32", device=None) -> Tensor: """ Returns a ones tensor with given shape. @@ -142,14 +142,14 @@ def ones(shape, dtype="float32", device=None): return full(shape, 1.0, dtype=dtype, device=device) -def zeros(shape, dtype="float32", device=None): +def zeros(shape, dtype="float32", device=None) -> Tensor: """ Returns a zero tensor with given shape. """ return full(shape, 0.0, dtype=dtype, device=device) -def zeros_like(inp: Tensor) -> Tensor: +def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: """ Returns a zero tensor with the same shape as input tensor. @@ -176,21 +176,26 @@ def zeros_like(inp: Tensor) -> Tensor: [0 0 0]] """ - return zeros(inp.shape, dtype=inp.dtype, device=inp.device) + return full_like(inp, 0.0) -def ones_like(inp: Tensor) -> Tensor: +def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: """ Returns a ones tensor with the same shape as input tensor. """ - return ones(inp.shape, dtype=inp.dtype, device=inp.device) + return full_like(inp, 1.0) -def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: +def full_like( + inp: Union[Tensor, SymbolVar], value: Union[int, float] +) -> Union[Tensor, SymbolVar]: """ Returns a tensor filled with given value with the same shape as input tensor. 
""" - return full(inp.shape, value, dtype=inp.dtype, device=inp.device) + (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp) + if inp.shape is (): + return x + return broadcast_to(x, inp.shape) def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: @@ -259,15 +264,10 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: if len(inps) == 1: return inps[0] - dtype = dtype_promotion(inps) + inps = convert_inputs(*inps, device=device) if device is None: device = get_device(inps) device = as_device(device) - - def convert(x): - return convert_single_value(x, dtype=dtype, device=device) - - inps = tuple(map(convert, inps)) (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) return result @@ -379,8 +379,14 @@ def split(inp, nsplits_or_sections, axis=0): Ntotal, axis, Nsections ) ) + + func = ( + floor_div + if isinstance(Nsections, (SymbolVar, Tensor)) + else lambda x, y: x // y + ) div_points = [0] + [ - floor_div(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections) + func(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections) ] for i in range(2, Nsections + 1): div_points[i] = div_points[i - 1] + div_points[i] @@ -925,11 +931,15 @@ def linspace( if not (cur_device is None or device == cur_device): raise ("ambiguous device for linspace opr") - if not isinstance(start, Tensor): + is_symbolvar = list(isinstance(x, SymbolVar) for x in [start, stop, num]) + if any(is_symbolvar) and not all(is_symbolvar): + raise TypeError("start, stop and num should all be VarNode or none of them") + + if not isinstance(start, (Tensor, SymbolVar)): start = Tensor(start, device=device) - if not isinstance(stop, Tensor): + if not isinstance(stop, (Tensor, SymbolVar)): stop = Tensor(stop, device=device) - if not isinstance(num, Tensor): + if not isinstance(num, (Tensor, SymbolVar)): num = Tensor(num, device=device) op = builtin.Linspace(comp_node=device) @@ -983,7 +993,7 @@ def arange( stop = stop.astype("float32") if isinstance(step, Tensor): step = step.astype("float32") - num = ceil(Tensor((stop - start) / step, device=device)) + num = ceil((stop - start) / step) stop = start + step * (num - 1) result = linspace(start, stop, num, device=device) if np.dtype(dtype) == np.int32: diff --git a/imperative/python/megengine/utils/network.py b/imperative/python/megengine/utils/network.py index 8738a5b54526f680955a073a00321e5636fc5c4d..3ba992197041dbef199466a2cc00de1853b1eac8 100644 --- a/imperative/python/megengine/utils/network.py +++ b/imperative/python/megengine/utils/network.py @@ -16,6 +16,7 @@ from typing import Dict, List import numpy as np from ..core._imperative_rt import ComputingGraph +from ..core._imperative_rt.core2 import SymbolVar from ..core.tensor import megbrain_graph as G from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq from .network_node import ( @@ -60,12 +61,12 @@ class Network: ) outputs = [new_outputs[i] for i in outspec] self._orig_outputs = outputs - self.add_dep_oprs(*outputs) + for x in self._orig_outputs: + self.output_vars.append(self._get_var(x)) + self.add_dep_oprs() for x in self._orig_inputs: self.input_vars.append(self._get_var(x)) - for x in self._orig_outputs: - self.output_vars.append(self._get_var(x)) self.graph = self._orig_outputs[0].graph return self @@ -197,6 +198,8 @@ class Network: def add_output(self, *vars: VarNode): """Adds vars into the network output node list """ + if not all([var.owner for var in vars]): + self.add_dep_oprs(*vars) for var in vars: if var 
diff --git a/imperative/python/megengine/utils/network.py b/imperative/python/megengine/utils/network.py
index 8738a5b54526f680955a073a00321e5636fc5c4d..3ba992197041dbef199466a2cc00de1853b1eac8 100644
--- a/imperative/python/megengine/utils/network.py
+++ b/imperative/python/megengine/utils/network.py
@@ -16,6 +16,7 @@ from typing import Dict, List
 import numpy as np
 
 from ..core._imperative_rt import ComputingGraph
+from ..core._imperative_rt.core2 import SymbolVar
 from ..core.tensor import megbrain_graph as G
 from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq
 from .network_node import (
@@ -60,12 +61,12 @@ class Network:
             )
             outputs = [new_outputs[i] for i in outspec]
         self._orig_outputs = outputs
-        self.add_dep_oprs(*outputs)
+        for x in self._orig_outputs:
+            self.output_vars.append(self._get_var(x))
+        self.add_dep_oprs()
         for x in self._orig_inputs:
             self.input_vars.append(self._get_var(x))
-        for x in self._orig_outputs:
-            self.output_vars.append(self._get_var(x))
 
         self.graph = self._orig_outputs[0].graph
         return self
@@ -197,6 +198,8 @@ class Network:
     def add_output(self, *vars: VarNode):
         """Adds vars into the network output node list
         """
+        if not all([var.owner for var in vars]):
+            self.add_dep_oprs(*vars)
         for var in vars:
             if var not in self.output_vars:
                 self.output_vars.append(var)
@@ -209,21 +212,25 @@ class Network:
             self.output_vars.remove(var)
 
     def add_dep_oprs(self, *vars):
-        """Adds dependent opnodes and varnodes of vars into network
-        """
-        oprs = get_oprs_seq(vars, False, False)
-        for mge_opr in oprs:
+        if len(vars) == 0:
+            vars = self.output_vars
+        q = list(vars)
+        while len(q) > 0:
+            cur = q.pop(0)
+            if cur.owner is not None:
+                continue
+            if cur.name is None:
+                cur.name = cur.var.name
+            self.all_vars_map[cur.var.id] = cur
+            mge_opr = cur.var.owner
             if get_opr_type(mge_opr) == "Host2DeviceCopy":
                 self._orig_inputs.extend(mge_opr.outputs)
-            opr = self._add_opr(mge_opr)
-            if opr is not None:
-                for x in mge_opr.inputs:
-                    opr.add_inp_var(self._get_var(x))
-                # set out var
-                for x in mge_opr.outputs:
-                    opr.add_out_var(self._get_var(x))
-
-        return [self.all_vars_map[var.id] for var in vars]
+            cur.owner = self._add_opr(mge_opr)
+            if cur.owner is None:
+                cur.owner = self.all_oprs_map[mge_opr.id]
+                continue
+            q.extend(cur.owner.inputs)
+        return list(vars)
 
     def modify_opr_names(self, modifier):
         """Modifies names of operators **inplace**; useful for merging loaded
@@ -275,6 +282,8 @@ class Network:
         Replaces vars in the graph.
         :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars.
        """
+        if not all([var.owner for var in repl_dict.values()]):
+            self.add_dep_oprs(*list(repl_dict.values()))
         for var in self.all_vars:
             if var in repl_dict:
                 repl_var = repl_dict[var]
@@ -282,6 +292,7 @@ class Network:
                     idx = owner.outputs.index(repl_var)
                     owner.outputs[idx] = var
                 var.__dict__.update(repl_var.__dict__)
+                var.var = repl_var.var
 
     def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]):
         """
@@ -297,6 +308,7 @@ class Network:
             for ind, var in enumerate(opr.outputs):
                 var.owner = repl_dict[opr]
                 var.__dict__.update(repl_dict[opr].outputs[ind].__dict__)
+                var.var = repl_dict[opr].outputs[ind].var
 
     def get_opr_by_type(self, oprcls, unique=True):
         assert issubclass(oprcls, OpNode)
@@ -381,11 +393,16 @@ class Network:
         return self.opr_filter.as_dict()
 
     # used for loading and building graph
-    def _add_opr(self, x):
+    def _add_opr(self, opr):
         # TODO: use megbrain C++ RTTI to replace type string
-        if x.id not in self.all_oprs_map:
-            self.all_oprs_map[x.id] = str_to_mge_class(get_opr_type(x)).load(x)
-            return self.all_oprs_map[x.id]
+        if opr.id not in self.all_oprs_map:
+            opnode = str_to_mge_class(get_opr_type(opr)).load(opr)
+            self.all_oprs_map[opr.id] = opnode
+            for var in opr.inputs:
+                opnode.add_inp_var(self._get_var(var))
+            for var in opr.outputs:
+                opnode.add_out_var(self._get_var(var))
+            return opnode
         else:
             return None
 
@@ -397,7 +414,7 @@ class Network:
     def _get_var(self, x):
         # auto convert to VarNode of Network
-        if x.id not in self.all_vars_map:
+        if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x:
             self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner))
         return self.all_vars_map[x.id]
 
@@ -652,7 +669,7 @@ class NodeFilterHasInput(NodeFilter):
             assert isinstance(
                 i, OpNode
             ), "has_input() must be used with OpNode; " "got {!r}".format(i)
-            if self.var in i.inputs:
+            if any(self.var is _ for _ in i.inputs):
                 yield i
 
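Note on the add_dep_oprs rewrite above: it switches from materializing the whole operator sequence up front (get_oprs_seq) to a lazy breadth-first walk — a VarNode whose owner is still None gets its producing operator wrapped on demand, and the queue then continues into that operator's inputs. In practice, vars freshly produced by functional ops can now be handed straight to add_output or replace_vars, roughly like this (names "a" and "b" are placeholders for vars that exist in the loaded model):

    import megengine.functional as F
    from megengine.utils.network import Network

    net = Network.load(model_bytes)      # model_bytes: a dumped model (placeholder)
    var_a = net.var_filter.name("a").as_unique()
    var_b = net.var_filter.name("b").as_unique()
    y = F.sigmoid(F.add(var_a, var_b))   # a VarNode whose owner is still None
    net.add_output(y)                    # add_output sees owner is None -> add_dep_oprs
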
diff --git a/imperative/python/megengine/utils/network_node.py b/imperative/python/megengine/utils/network_node.py
index 155b057965a8513e5eb3441005135b47a5b03171..cbd0665e5b8d39f5da9031831b26f121394cd85f 100644
--- a/imperative/python/megengine/utils/network_node.py
+++ b/imperative/python/megengine/utils/network_node.py
@@ -6,16 +6,21 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import abc
 import json
 import sys
-from typing import Callable
+from typing import Callable, Sequence
 
 import numpy as np
 
 from ..core import _imperative_rt as rt
+from ..core._imperative_rt.core2 import SymbolVar
 from ..core._wrap import Device
 from ..core.ops import builtin
-from ..core.tensor.megbrain_graph import InputNode
+from ..core.tensor.array_method import ArrayMethodMixin
+from ..core.tensor.indexing import getitem as _getitem
+from ..core.tensor.indexing import setitem as _setitem
+from ..core.tensor.megbrain_graph import InputNode, OutputNode
 from ..tensor import Tensor
 from .comp_graph_tools import replace_vars
 from .module_stats import (
@@ -29,9 +34,13 @@ class NetworkNode:
     pass
 
 
-class VarNode(NetworkNode):
-    def __init__(self, owner_opr=None, name=None):
-        self.var = None
+class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
+    pass
+
+
+class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
+    def __init__(self, var=None, *, owner_opr=None, name=None):
+        SymbolVar.__init__(self, var)
         self.owner = owner_opr
         self.name = name
         self.id = id(self)
@@ -58,6 +67,40 @@ class VarNode(NetworkNode):
     def dtype(self):
         return self.var.dtype if self.var else None
 
+    def __bool__(self):
+        return False
+
+    __index__ = None
+    __int__ = None
+    __float__ = None
+    __complex__ = None
+
+    def __hash__(self):
+        return id(self)
+
+    @property
+    def _tuple_shape(self):
+        return self.var.shape
+
+    def numpy(self):
+        o = OutputNode(self.var)
+        self.graph.compile(o.outputs).execute()
+        return o.get_value().numpy()
+
+    def __getitem__(self, index):
+        return _getitem(self, index)
+
+    def __setitem__(self, index, value):
+        if index is not Ellipsis:
+            value = _setitem(self, index, value)
+        if self.owner is not None:
+            idx = self.owner.outputs.index(self)
+            self.owner.outputs[idx] = VarNode(
+                self.var, owner_opr=self.owner, name=self.var.name
+            )
+        self.var = value.var
+        self.owner = None
+
     def set_owner_opr(self, owner_opr):
         self.owner = owner_opr
 
@@ -138,7 +181,7 @@ class Host2DeviceCopy(OpNode):
         outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
         self._opr = outputs.owner
         if len(self.outputs) == 0:
-            self.outputs.append(VarNode(self, self.name))
+            self.outputs.append(VarNode(owner_opr=self, name=self.name))
         self.outputs[0].var = outputs
         assert self.outputs[0].owner is self
 
@@ -176,8 +219,8 @@ class ImmutableTensor(OpNode):
     def set_value(self, data, device=None):
         assert self.graph is not None
         cn = device if device else self.device
-        assert isinstance(data, (int, float, np.ndarray))
-        if isinstance(data, (int, float)):
+        assert isinstance(data, (int, float, Sequence, np.ndarray))
+        if not isinstance(data, np.ndarray):
             data = np.array(data)
             if data.dtype == np.float64:
                 data = data.astype(np.float32)
@@ -185,7 +228,7 @@ class ImmutableTensor(OpNode):
             data = data.astype(np.int32)
         varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
         if len(self.outputs) == 0:
-            self.outputs.append(VarNode(self, self.name))
+            self.outputs.append(VarNode(owner_opr=self, name=self.name))
         self.outputs[0].var = varnode
         self._opr = varnode.owner
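
Note on VarNodeMeta above: VarNode now inherits both from the pybind11-generated SymbolVar and from ArrayMethodMixin, and those two bases carry different metaclasses (pybind11's type object vs. the abc machinery), so Python requires a metaclass deriving from both — that is all VarNodeMeta does. The general pattern, as a self-contained illustration:

    class MetaA(type): pass
    class MetaB(type): pass

    class A(metaclass=MetaA): pass
    class B(metaclass=MetaB): pass

    # A subclass of A and B needs a metaclass deriving from both,
    # otherwise Python raises "metaclass conflict":
    class MetaAB(MetaA, MetaB): pass
    class AB(A, B, metaclass=MetaAB): pass

Also worth noting: VarNode.numpy() compiles and executes a throwaway graph on every call, which is convenient for tests but not cheap.
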
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index f5eac698e00003af002b1ee1551d771e2069dda5..f809685be13e403e9de21ff5ed809ae6d7b7d36b 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -160,16 +160,21 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) {
     if (ctx.op->same_type<BackwardGraph>()) {
         ctx.backward = true;
     }
-
-
-    if (py::isinstance<cg::VarNode>(py::handle(args[0]))){
-        SmallVector<cg::VarNode*> vinputs(nargs);
-        for (size_t i = 0; i < nargs; ++i) {
-            vinputs[i] = py::handle(args[i]).cast<cg::VarNode*>();
-        }
-        auto op = ctx.op.get();
-        return to_tuple(OpDef::apply_on_var_node(*op, vinputs)).release().ptr();
-    }
+
+    if (py::isinstance<PySymbolVar>(py::handle(args[0]))){
+        SmallVector<cg::VarNode*> vinputs(nargs);
+        for (size_t i = 0; i < nargs; ++i) {
+            vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
+        }
+        auto op = ctx.op.get();
+        auto rst = OpDef::apply_on_var_node(*op, vinputs);
+        auto ret = pybind11::tuple(rst.size());
+        auto typeobj = py::handle(args[0]).get_type();
+        for (size_t i = 0; i < rst.size(); ++i) {
+            ret[i] = typeobj(pybind11::cast(rst[i]));
+        }
+        return ret.release().ptr();
+    }
@@ ... @@
-        if (py::isinstance<cg::VarNode>(py::handle(handle))){
-            auto var = py::handle(handle).cast<cg::VarNode*>();
-            mgb::DType type = var->dtype();
+        if (py::isinstance<PySymbolVar>(py::handle(handle))){
+            auto var = py::handle(handle).cast<PySymbolVar*>();
+            mgb::DType type = var->m_node->dtype();
             auto && descr = npy::dtype_mgb2np_descr(type);
             Py_INCREF(descr.get());
             tensors.emplace_back(descr.get());
@@ -737,19 +742,26 @@ CompNode _get_device(PyObject*const* args, size_t nargs) {
     bool valid = false;
     CompNode cn;
     for (size_t i = 0; i < nargs; ++i) {
-        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i];
+        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
         TensorWrapper* tw = TensorWrapper::try_cast(handle);
-        bool is_var = py::isinstance<cg::VarNode>(py::handle(handle));
-        if (tw || is_var) {
+        bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
+        if (tw || is_symvar) {
             if (!valid) {
-                cn = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode*>()->comp_node();
+                cn = tw ? tw->m_tensor->comp_node()
+                        : py::handle(handle)
+                                  .cast<PySymbolVar*>()
+                                  ->m_node->comp_node();
                 valid = true;
             } else {
-                CompNode cn1 = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode*>()->comp_node();
+                CompNode cn1 = tw ? tw->m_tensor->comp_node()
+                                  : py::handle(handle)
+                                            .cast<PySymbolVar*>()
+                                            ->m_node->comp_node();
                 if (cn1 != cn) {
                     throw py::value_error(ssprintf("ambiguous device: %s vs %s",
-                        cn.to_string().c_str(), cn1.to_string().c_str()));
+                                                   cn.to_string().c_str(),
+                                                   cn1.to_string().c_str()));
                 }
             }
         }
@@ -849,6 +861,32 @@ void init_tensor(py::module m) {
         .def("__call__", &TensorWeakRef::operator())
         .def("_use_cnt", &TensorWeakRef::_use_cnt);
 
+    py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
+            .def_property_readonly(
+                    "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); })
+            .def_property("var", [](PySymbolVar* v) { return v->m_node; },
+                          [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; })
+            .def_property_readonly(
+                    "device",
+                    [](PySymbolVar* v) { return v->m_node->comp_node(); })
+            .def_property_readonly(
+                    "graph",
+                    [](PySymbolVar* v) { return v->m_node->owner_graph(); })
+            .def_property_readonly(
+                    "shape",
+                    [](PySymbolVar* v) -> const TensorShape* {
+                        auto&& mgr = v->m_node->owner_graph()
+                                             ->static_infer_manager();
+                        return mgr.infer_shape_fallible(v->m_node);
+                    })
+            .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
+            .def("_setscalar",
+                 [](PySymbolVar* v) { return v->is_scalar = true; })
+            .def(py::init([](cg::VarNode* node) {
+                     return std::make_shared<PySymbolVar>(node);
+                 }),
+                 py::arg() = nullptr);
+
     static PyMethodDef method_defs[] = {
             MGE_PY_INTERFACE(apply, py_apply),
             MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h
index d3d5fef48feacf5c6aa68e554a6e141be0bae555..f2a36568935b88829a5eb8ffb7b3ede6d24a400f 100644
--- a/imperative/python/src/tensor.h
+++ b/imperative/python/src/tensor.h
@@ -181,6 +181,12 @@ struct TensorWrapper {
     PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
 };
 
+struct PySymbolVar {
+    cg::VarNode* m_node = nullptr;
+    bool is_scalar = false;
+    PySymbolVar() = default;
+    PySymbolVar(VarNode *m): m_node(m){}
+};
 
 PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */);
diff --git a/imperative/python/test/helpers/utils.py b/imperative/python/test/helpers/utils.py
index 63de93469810904279bc6a19729f32081fae2ac1..35d00801d2a643081e67eee77d8f92140fa35cea 100644
--- a/imperative/python/test/helpers/utils.py
+++ b/imperative/python/test/helpers/utils.py
@@ -2,9 +2,11 @@ import io
 
 import numpy as np
 
+import megengine.core.tensor.megbrain_graph as G
 import megengine.utils.comp_graph_tools as cgtools
 from megengine import tensor
 from megengine.jit import trace
+from megengine.utils.network_node import VarNode
 
 
 def _default_compare_fn(x, y):
@@ -14,8 +16,23 @@ def _default_compare_fn(x, y):
         np.testing.assert_allclose(x.numpy(), y, rtol=1e-6)
 
 
+def make_tensor(x, network=None, device=None):
+    if network is not None:
+        if isinstance(x, VarNode):
+            return VarNode(x.var)
+        return network.make_const(x, device=device)
+    else:
+        return tensor(x, device=device)
+
+
 def opr_test(
-    cases, func, compare_fn=_default_compare_fn, ref_fn=None, test_trace=True, **kwargs
+    cases,
+    func,
+    compare_fn=_default_compare_fn,
+    ref_fn=None,
+    test_trace=True,
+    network=None,
+    **kwargs
 ):
     """
     :param cases: the list which have dict element, the list length should be 2 for dynamic shape test.
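Note on make_tensor above: it is the pivot for all the test-suite changes that follow — with network=None it builds an ordinary eager Tensor, otherwise it returns a graph-mode VarNode backed by network. A test parametrized over both modes then looks roughly like this (a sketch; test_example is not part of the patch):

    import numpy as np
    import pytest
    from utils import make_tensor
    from megengine.utils.network import Network

    @pytest.mark.parametrize("is_varnode", [True, False])
    def test_example(is_varnode):
        network = Network() if is_varnode else None
        x = make_tensor(np.arange(4, dtype="float32"), network)
        y = (x + 1) * 2  # same array methods on Tensor and VarNode
        np.testing.assert_equal(y.numpy(), (np.arange(4, dtype="float32") + 1) * 2)
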
@@ -44,7 +61,7 @@ def opr_test( if not isinstance(results, (tuple, list)): results = (results,) for r, e in zip(results, expected): - if not isinstance(r, tensor): + if not isinstance(r, (tensor, VarNode)): r = tensor(r) compare_fn(r, e) @@ -72,9 +89,9 @@ def opr_test( raise ValueError("the input func should be callable") inp, outp = get_param(cases, 0) - inp_tensor = [tensor(inpi) for inpi in inp] + inp_tensor = [make_tensor(inpi, network) for inpi in inp] - if test_trace: + if test_trace and not network: copied_inp = inp_tensor.copy() for symbolic in [False, True]: traced_func = trace(symbolic=symbolic)(func) diff --git a/imperative/python/test/unit/core/test_indexing_op.py b/imperative/python/test/unit/core/test_indexing_op.py index be4d918a6f840d8195dc953dd3c9e69bea18c30e..114dc45e5da9479b7b219fefde3c5f8a00bfebef 100644 --- a/imperative/python/test/unit/core/test_indexing_op.py +++ b/imperative/python/test/unit/core/test_indexing_op.py @@ -10,12 +10,17 @@ import collections import numpy as np import pytest +from utils import make_tensor import megengine +import megengine.core.tensor.megbrain_graph as G +import megengine.functional as F from megengine.core._imperative_rt.core2 import apply from megengine.core._trace_option import use_symbolic_shape from megengine.core.ops import builtin from megengine.tensor import Tensor +from megengine.utils.network import Network +from megengine.utils.network_node import VarNode def cvt_to_shape_desc(val, inpvar, config=None): @@ -387,108 +392,130 @@ def test_batched_mesh_indexing(): # high level +def get_value(x): + if isinstance(x, VarNode): + var = x.var + o = G.OutputNode(var) + graph = x.graph + graph.compile(o.outputs).execute() + return o.get_value().numpy() + else: + return x.numpy() + + +@pytest.mark.parametrize("test_varnode", [True, False]) +def test_advance_indexing_high_level(test_varnode): + if test_varnode: + network = Network() + else: + network = None - -def test_advance_indexing_high_level(): x = np.arange(25).reshape(5, 5).astype("int32") d = np.arange(15).reshape(3, 5).astype("int32") - xx = Tensor(x) + xx = make_tensor(x, network) - np.testing.assert_equal(x[1, :], xx[1, :].numpy()) - np.testing.assert_equal(x[:, 1], xx[:, 1].numpy()) - np.testing.assert_equal(x[1:3, :], xx[1:3, :].numpy()) + np.testing.assert_equal(x[1, :], get_value(xx[1, :])) + np.testing.assert_equal(x[:, 1], get_value(xx[:, 1])) + np.testing.assert_equal(x[1:3, :], get_value(xx[1:3, :])) - np.testing.assert_equal(x[:, :], xx[:, :].numpy()) - np.testing.assert_equal(x[1, 1], xx[1, 1].numpy()) + np.testing.assert_equal(x[:, :], get_value(xx[:, :])) + np.testing.assert_equal(x[1, 1], get_value(xx[1, 1])) yy = xx[(0, 4, 2), :] - np.testing.assert_equal(x[(0, 4, 2), :], yy.numpy()) + np.testing.assert_equal(x[(0, 4, 2), :], get_value(yy)) x_ = x.copy() x_[(0, 4, 2), :] = d - xx_ = Tensor(xx) + xx_ = make_tensor(xx, network) xx_[(0, 4, 2), :] = d - np.testing.assert_equal(x_, xx_.numpy()) + np.testing.assert_equal(x_, get_value(xx_)) x = np.arange(27).reshape(3, 3, 3).astype("int32") - xx = Tensor(x) + xx = make_tensor(x, network) - np.testing.assert_equal(x[1, :, :], xx[1, :, :].numpy()) - np.testing.assert_equal(x[1, :, 1], xx[1, :, 1].numpy()) - np.testing.assert_equal(x[1, 0:1, :], xx[1, 0:1, :].numpy()) - np.testing.assert_equal(x[0:1, 1, 1], xx[0:1, 1, 1].numpy()) - np.testing.assert_equal(x[:, 1, 1], xx[:, 1, 1].numpy()) - np.testing.assert_equal(x[:, 1], xx[:, 1].numpy()) - np.testing.assert_equal(x[1, 1:2], xx[1, 1:2].numpy()) + 
np.testing.assert_equal(x[1, :, :], get_value(xx[1, :, :])) + np.testing.assert_equal(x[1, :, 1], get_value(xx[1, :, 1])) + np.testing.assert_equal(x[1, 0:1, :], get_value(xx[1, 0:1, :])) + np.testing.assert_equal(x[0:1, 1, 1], get_value(xx[0:1, 1, 1])) + np.testing.assert_equal(x[:, 1, 1], get_value(xx[:, 1, 1])) + np.testing.assert_equal(x[:, 1], get_value(xx[:, 1])) + np.testing.assert_equal(x[1, 1:2], get_value(xx[1, 1:2])) x_ = x.copy() x_[1, 1, 1] = -1 xx[1, 1, 1] = -1 - np.testing.assert_equal(x_, xx.numpy()) + np.testing.assert_equal(x_, get_value(xx)) x_[:, 1, 1] = -2 xx[:, 1, 1] = x_[:, 1, 1] - np.testing.assert_equal(x_, xx.numpy()) + np.testing.assert_equal(x_, get_value(xx)) x_[0:1, :, 1] = -3 xx[0:1, :, 1] = x_[0:1, :, 1] - np.testing.assert_equal(x_, xx.numpy()) + np.testing.assert_equal(x_, get_value(xx)) x_[0:1, :, 1] = -4 - y = Tensor(x_) + y = make_tensor(x_, network) xx[0:1, :, 1] = y[0:1, :, 1] - np.testing.assert_equal(y.numpy(), xx.numpy()) + np.testing.assert_equal(get_value(y), get_value(xx)) x[:] = 1 xx[:] = 1 - np.testing.assert_equal(x, xx.numpy()) + np.testing.assert_equal(x, get_value(xx)) x = np.arange(9).reshape(3, 3).astype("int32") - xx = Tensor(x) + xx = make_tensor(x, network) y = np.array([1, 2]) - yy = Tensor(y) - np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy()) - np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) - np.testing.assert_equal(x[:, y], xx[:, y].numpy()) - np.testing.assert_equal(x[:, y], xx[:, yy].numpy()) + yy = make_tensor(y, network) + np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]])) + np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]])) + np.testing.assert_equal(x[:, y], get_value(xx[:, y])) + np.testing.assert_equal(x[:, y], get_value(xx[:, yy])) x_ = x.copy() x_[:, y[0]] = -1 - xx_ = Tensor(x_) + xx_ = make_tensor(x_, network) xx[:, yy[0]] = xx_[:, yy[0]] - np.testing.assert_equal(x_, xx.numpy()) + np.testing.assert_equal(x_, get_value(xx)) x_[:, y] = -1 - xx_ = Tensor(x_) + xx_ = make_tensor(x_, network) xx[:, yy] = xx_[:, yy] - np.testing.assert_equal(x_, xx.numpy()) + np.testing.assert_equal(x_, get_value(xx)) x = np.arange(9).reshape(3, 3).astype("int32") - xx = Tensor(x) + xx = make_tensor(x, network) y = np.array([1]) - yy = Tensor(y) - np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy()) - np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) - np.testing.assert_equal(x[:, y], xx[:, y].numpy()) + yy = make_tensor(y, network) + np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]])) + np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]])) + np.testing.assert_equal(x[:, y], get_value(xx[:, y])) - np.testing.assert_equal(x[:, y], xx[:, yy].numpy()) + np.testing.assert_equal(x[:, y], get_value(xx[:, yy])) x = np.arange(9).reshape(3, 3).astype("int32") - xx = Tensor(x) - np.testing.assert_equal(x[[0, 1], 0], xx[[0, 1], 0].numpy()) - np.testing.assert_equal(x[0:2, 0], xx[0:2, 0].numpy()) - - -def test_advance_indexing_with_bool(): + xx = make_tensor(x, network) + np.testing.assert_equal(x[[0, 1], 0], get_value(xx[[0, 1], 0])) + np.testing.assert_equal(x[0:2, 0], get_value(xx[0:2, 0])) + + +@pytest.mark.parametrize( + "test_varnode", [True, False], +) +def test_advance_indexing_with_bool(test_varnode): + if test_varnode: + network = Network() + else: + network = None a = np.arange(9).reshape(3, 3).astype(np.float32) b = np.array([1, 2, 3]) c = np.array([1, 2, 3]) - aa = Tensor(a) - bb = Tensor(b) - cc = Tensor(c) - np.testing.assert_equal(a[b == 1, c == 2], aa[bb == 1, cc == 
2].numpy()) + aa = make_tensor(a, network) + bb = make_tensor(b, network) + cc = make_tensor(c, network) + np.testing.assert_equal(a[b == 1, c == 2], get_value(aa[bb == 1, cc == 2])) a[b == 1, c == 2] = -1.0 aa[bb == 1, cc == 2] = -1.0 - np.testing.assert_equal(a, aa.numpy()) + np.testing.assert_equal(a, get_value(aa)) a = np.arange(9).reshape(3, 3).astype(np.float32) b = np.array([False, True, True]) diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index e43386b8c94227ca2cb6a7275e43b2422d943e34..b5a9199372bcf0cb844c7aebbb02cfab6bfea186 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -11,13 +11,16 @@ import platform import numpy as np import pytest -from utils import opr_test +from utils import make_tensor, opr_test import megengine.functional as F from megengine import tensor from megengine.core._trace_option import use_symbolic_shape +from megengine.core.tensor import megbrain_graph as G from megengine.core.tensor.utils import astensor1d from megengine.distributed.helper import get_device_count_by_fork +from megengine.utils.network import Network +from megengine.utils.network_node import VarNode def test_eye(): @@ -38,7 +41,13 @@ def test_eye(): ) -def test_concat(): +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_concat(is_varnode): + if is_varnode: + network = Network() + else: + network = None + def get_data_shape(length: int): return (length, 2, 3) @@ -50,18 +59,30 @@ def test_concat(): return F.concat([data1, data2]) cases = [{"input": [data1, data2]}, {"input": [data1, data3]}] - opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y])) + opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network) -def test_concat_device(): - data1 = tensor(np.random.random((3, 2, 2)).astype("float32"), device="cpu0") - data2 = tensor(np.random.random((2, 2, 2)).astype("float32"), device="cpu1") +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_concat_device(is_varnode): + if is_varnode: + network = Network() + else: + network = None + + data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0") + data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1") out = F.concat([data1, data2], device="cpu0") assert str(out.device).split(":")[0] == "cpu0" -def test_stack(): +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_stack(is_varnode): + if is_varnode: + network = Network() + else: + network = None + data1 = np.random.random((3, 2, 2)).astype("float32") data2 = np.random.random((3, 2, 2)).astype("float32") data3 = np.random.random((3, 2, 2)).astype("float32") @@ -72,12 +93,20 @@ def test_stack(): def run(data1, data2): return F.stack([data1, data2], axis=ai) - opr_test(cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai)) + opr_test( + cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network + ) + +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_split(is_varnode): + if is_varnode: + network = Network() + else: + network = None -def test_split(): data = np.random.random((2, 3, 4, 5)).astype(np.float32) - inp = tensor(data) + inp = make_tensor(data, network) mge_out0 = F.split(inp, 2, axis=3) mge_out1 = F.split(inp, [3], axis=3) @@ -106,26 +135,42 @@ def test_split(): assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]" -def test_reshape(): +@pytest.mark.parametrize("is_varnode", 
[True, False]) +def test_reshape(is_varnode): + if is_varnode: + network = Network() + else: + network = None + x = np.arange(6, dtype="float32") - xx = tensor(x) + xx = make_tensor(x, network) y = x.reshape(1, 2, 3) for shape in [ (1, 2, 3), (1, -1, 3), - (1, tensor(-1), 3), + (1, make_tensor(-1, network), 3), np.array([1, -1, 3], dtype="int32"), - tensor([1, -1, 3]), + make_tensor([1, -1, 3], network), ]: yy = F.reshape(xx, shape) np.testing.assert_equal(yy.numpy(), y) -def test_reshape_shape_inference(): - x_shape_known = tensor([1, 2, 3, 4], dtype="float32") - x_shape_unknown = F.broadcast_to(tensor([1.0]), shape=tensor([1, 1, 1, 1]).sum()) - tshp_unknown = astensor1d((tensor([2]), tensor([2])), x_shape_known) +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_reshape_shape_inference(is_varnode): + if is_varnode: + network = Network() + else: + network = None + + x_shape_known = make_tensor([1, 2, 3, 4], network) + x_shape_unknown = F.broadcast_to( + make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum() + ) + tshp_unknown = astensor1d( + (make_tensor([2], network), make_tensor([2], network)), x_shape_known + ) tshp_known = astensor1d((2, 2), x_shape_known) tshp_known_unspec = astensor1d((2, -1), x_shape_known) @@ -146,12 +191,18 @@ def test_reshape_shape_inference(): {"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]}, {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]}, ] - opr_test(cases, func, compare_fn=check_shape, test_trace=True) + opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network) + +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_squeeze(is_varnode): + if is_varnode: + network = Network() + else: + network = None -def test_squeeze(): x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1) - xx = tensor(x) + xx = make_tensor(x, network) for axis in [None, 3, -4, (3, -4)]: y = np.squeeze(x, axis) @@ -159,9 +210,15 @@ def test_squeeze(): np.testing.assert_equal(y, yy.numpy()) -def test_expand_dims(): +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_expand_dims(is_varnode): + if is_varnode: + network = Network() + else: + network = None + x = np.arange(6, dtype="float32").reshape(2, 3) - xx = tensor(x) + xx = make_tensor(x, network) for axis in [2, -3, (3, -4), (1, -4)]: y = np.expand_dims(x, axis) @@ -169,11 +226,17 @@ def test_expand_dims(): np.testing.assert_equal(y, yy.numpy()) -def test_elemwise_dtype_promotion(): +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_elemwise_dtype_promotion(is_varnode): + if is_varnode: + network = Network() + else: + network = None + x = np.random.rand(2, 3).astype("float32") y = np.random.rand(1, 3).astype("float16") - xx = tensor(x) - yy = tensor(y) + xx = make_tensor(x, network) + yy = make_tensor(y, network) z = xx * yy np.testing.assert_equal(z.numpy(), x * y) @@ -184,7 +247,13 @@ def test_elemwise_dtype_promotion(): np.testing.assert_equal(z.numpy(), x - y) -def test_linspace(): +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_linspace(is_varnode): + if is_varnode: + network = Network() + else: + network = None + cases = [ {"input": [1, 9, 9]}, {"input": [3, 10, 8]}, @@ -193,6 +262,7 @@ def test_linspace(): cases, F.linspace, ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), + network=network, ) cases = [ @@ -203,20 +273,28 @@ def test_linspace(): cases, F.linspace, ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), + 
network=network, ) cases = [ - {"input": [1, tensor(9), 9]}, - {"input": [tensor(1), 9, tensor(9)]}, + {"input": [1, make_tensor(9, network), 9]}, + {"input": [make_tensor(1, network), 9, make_tensor(9, network)]}, ] opr_test( cases, F.linspace, ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32), + network=network, ) -def test_arange(): +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_arange(is_varnode): + if is_varnode: + network = Network() + else: + network = None + cases = [ {"input": [1, 9, 1]}, {"input": [2, 10, 2]}, @@ -225,6 +303,7 @@ def test_arange(): cases, F.arange, ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), + network=network, ) cases = [ @@ -235,6 +314,7 @@ def test_arange(): cases, F.arange, ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), + network=network, ) cases = [ @@ -245,20 +325,33 @@ def test_arange(): cases, F.arange, ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), + network=network, ) -def test_round(): +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_round(is_varnode): + if is_varnode: + network = Network() + else: + network = None + data1_shape = (15,) data2_shape = (25,) data1 = np.random.random(data1_shape).astype(np.float32) data2 = np.random.random(data2_shape).astype(np.float32) cases = [{"input": data1}, {"input": data2}] - opr_test(cases, F.round, ref_fn=np.round) + opr_test(cases, F.round, ref_fn=np.round, network=network) -def test_flatten(): +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_flatten(is_varnode): + if is_varnode: + network = Network() + else: + network = None + data0_shape = (2, 3, 4, 5) data1_shape = (4, 5, 6, 7) data0 = np.random.random(data0_shape).astype(np.float32) @@ -273,7 +366,7 @@ def test_flatten(): {"input": data0, "output": output0}, {"input": data1, "output": output1}, ] - opr_test(cases, F.flatten, compare_fn=compare_fn) + opr_test(cases, F.flatten, compare_fn=compare_fn, network=network) output0 = (2, 3 * 4 * 5) output1 = (4, 5 * 6 * 7) @@ -281,7 +374,7 @@ def test_flatten(): {"input": data0, "output": output0}, {"input": data1, "output": output1}, ] - opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1) + opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network) output0 = (2, 3, 4 * 5) output1 = (4, 5, 6 * 7) @@ -289,7 +382,7 @@ def test_flatten(): {"input": data0, "output": output0}, {"input": data1, "output": output1}, ] - opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2) + opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network) output0 = (2, 3 * 4, 5) output1 = (4, 5 * 6, 7) @@ -297,10 +390,23 @@ def test_flatten(): {"input": data0, "output": output0}, {"input": data1, "output": output1}, ] - opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, end_axis=2) + opr_test( + cases, + F.flatten, + compare_fn=compare_fn, + start_axis=1, + end_axis=2, + network=network, + ) + +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_broadcast(is_varnode): + if is_varnode: + network = Network() + else: + network = None -def test_broadcast(): input1_shape = (20, 30) output1_shape = (30, 20, 30) data1 = np.random.random(input1_shape).astype(np.float32) @@ -321,7 +427,7 @@ def test_broadcast(): {"input": [data2, output2_shape], "output": output2_shape}, {"input": [data3, output3_shape], "output": output3_shape}, ] - opr_test(cases, F.broadcast_to, compare_fn=compare_fn) + 
opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network) x = F.ones((2, 1, 3)) with pytest.raises(RuntimeError): @@ -334,35 +440,41 @@ def test_broadcast(): F.broadcast_to(x, (1, 3)) -def test_utils_astensor1d(): - reference = tensor(0) +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_utils_astensor1d(is_varnode): + if is_varnode: + network = Network() + else: + network = None + + reference = make_tensor(0, network) # literal x = [1, 2, 3] for dtype in [None, "float32"]: xx = astensor1d(x, reference, dtype=dtype) - assert type(xx) is tensor + assert isinstance(xx, type(reference)) np.testing.assert_equal(xx.numpy(), x) # numpy array x = np.asarray([1, 2, 3], dtype="int32") for dtype in [None, "float32"]: xx = astensor1d(x, reference, dtype=dtype) - assert type(xx) is tensor + assert isinstance(xx, type(reference)) np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x) # tensor - x = tensor([1, 2, 3], dtype="int32") + x = make_tensor([1, 2, 3], network) for dtype in [None, "float32"]: xx = astensor1d(x, reference, dtype=dtype) - assert type(xx) is tensor + assert isinstance(xx, type(reference)) np.testing.assert_equal(xx.numpy(), x.numpy()) # mixed - x = [1, tensor(2), 3] + x = [1, make_tensor(2, network), 3] for dtype in [None, "float32"]: xx = astensor1d(x, reference, dtype=dtype) - assert type(xx) is tensor + assert isinstance(xx, type(reference)) np.testing.assert_equal(xx.numpy(), [1, 2, 3]) @@ -382,35 +494,60 @@ def test_device(): np.testing.assert_almost_equal(y5.numpy(), y6.numpy()) -def test_identity(): - x = tensor(np.random.random((5, 10)).astype(np.float32)) +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_identity(is_varnode): + if is_varnode: + network = Network() + else: + network = None + + x = make_tensor(np.random.random((5, 10)).astype(np.float32), network) y = F.copy(x) np.testing.assert_equal(y.numpy(), x) -def copy_test(dst, src): +def copy_test(dst, src, network): data = np.random.random((2, 3)).astype(np.float32) - x = tensor(data, device=src) + x = make_tensor(data, device=src, network=network) y = F.copy(x, dst) assert np.allclose(data, y.numpy()) - z = x.to(dst) - assert np.allclose(data, z.numpy()) + if network is None: + z = x.to(dst) + assert np.allclose(data, z.numpy()) @pytest.mark.require_ngpu(1) -def test_copy_h2d(): - copy_test("cpu0", "gpu0") +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_copy_h2d(is_varnode): + if is_varnode: + network = Network() + else: + network = None + + copy_test("cpu0", "gpu0", network=network) @pytest.mark.require_ngpu(1) -def test_copy_d2h(): - copy_test("gpu0", "cpu0") +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_copy_d2h(is_varnode): + if is_varnode: + network = Network() + else: + network = None + + copy_test("gpu0", "cpu0", network=network) @pytest.mark.require_ngpu(2) -def test_copy_d2d(): - copy_test("gpu0", "gpu1") - copy_test("gpu0:0", "gpu0:1") +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_copy_d2d(is_varnode): + if is_varnode: + network = Network() + else: + network = None + + copy_test("gpu0", "gpu1", network=network) + copy_test("gpu0:0", "gpu0:1", network=network) @pytest.mark.parametrize( @@ -425,7 +562,13 @@ def test_copy_d2d(): ((), 10, None), ], ) -def test_repeat(shape, repeats, axis): +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_repeat(shape, repeats, axis, is_varnode): + if is_varnode: + network = Network() + else: + network = None + def repeat_func(inp): return 
F.repeat(inp=inp, repeats=repeats, axis=axis) @@ -437,7 +580,10 @@ def test_repeat(shape, repeats, axis): cases = [{"input": np.array(1.23)}] opr_test( - cases, repeat_func, ref_fn=lambda inp: np.repeat(inp, repeats, axis), + cases, + repeat_func, + ref_fn=lambda inp: np.repeat(inp, repeats, axis), + network=network, ) @@ -450,14 +596,16 @@ def test_repeat(shape, repeats, axis): ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)), ], ) -def test_tile(shape, reps): +@pytest.mark.parametrize("is_varnode", [True]) +def test_tile(shape, reps, is_varnode): + if is_varnode: + network = Network() + else: + network = None + def tile_func(inp): return F.tile(inp=inp, reps=reps) - cases = [ - {"input": np.random.randn(*shape).astype("float32")}, - ] + cases = [{"input": np.random.randn(*shape).astype("float32")}] - opr_test( - cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), - ) + opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network) diff --git a/imperative/python/test/unit/utils/test_network.py b/imperative/python/test/unit/utils/test_network.py index e398fd64abd7be912f86b73b241ac146e8a93a84..e578c06412bf3901ae58c996c7e8b3d2c53107b4 100644 --- a/imperative/python/test/unit/utils/test_network.py +++ b/imperative/python/test/unit/utils/test_network.py @@ -34,13 +34,11 @@ def test_replace_var(): vara = graph.var_filter.name("a").as_unique() varb = graph.var_filter.name("b").as_unique() - out = F.mul(vara.var, varb.var) + out = F.mul(vara, varb) out = F.relu(out) - var_list = graph.add_dep_oprs(out) - opnode = list(graph.opr_filter.has_input(vara)) - repl_dict = {opnode[0].outputs[0]: var_list[0]} + repl_dict = {opnode[0].outputs[0]: out} graph.replace_vars(repl_dict) modified_model = io.BytesIO() @@ -72,14 +70,12 @@ def test_replace_opr(): vara = graph.var_filter.name("a").as_unique() varb = graph.var_filter.name("b").as_unique() - out1 = F.sub(vara.var, varb.var) + out1 = F.sub(vara, varb) out1 = F.relu(out1) - - var_list = graph.add_dep_oprs(out1) - repl_opr = as_oprnode(var_list) + out1 = graph.add_dep_oprs(out1) orig_opr = graph.opr_filter.has_input(vara).as_unique() - repl_dict = {orig_opr: repl_opr} + repl_dict = {orig_opr: out1[0].owner} graph.replace_oprs(repl_dict) modified_model1 = io.BytesIO() graph.dump(modified_model1) @@ -171,8 +167,7 @@ def test_add_input(): inp_c = graph.make_input_node((2,), np.int32, name="c") varo = graph.var_filter.name("o").as_unique() - out = F.add(varo.var, inp_c.var) - out = graph.add_dep_oprs(out)[0] + out = F.add(varo, inp_c) out.name = "o1" graph.remove_output(varo) graph.add_output(out) @@ -206,12 +201,11 @@ def test_add_output(): var_a = net.var_filter.name("a").as_unique() var_b = net.var_filter.name("b").as_unique() - y = F.add(var_a.var, var_b.var) + y = F.add(var_a, var_b) y = F.sigmoid(y) - new_vars = net.add_dep_oprs(y)[0] - new_vars.name = "o1" - net.add_output(new_vars) + y.name = "o1" + net.add_output(y) modified_model = io.BytesIO() net.dump(modified_model)
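
Taken together, the test_network.py changes above show the point of the whole patch: graph surgery no longer needs manual .var unwrapping or explicit add_dep_oprs round-trips. An end-to-end sketch (assumes a dumped model with vars named "a" and "b", as in the tests):

    import io
    import megengine.functional as F
    from megengine.utils.network import Network

    net = Network.load(orig_model)   # orig_model: a file or BytesIO of a dumped graph
    vara = net.var_filter.name("a").as_unique()
    varb = net.var_filter.name("b").as_unique()

    out = F.relu(F.mul(vara, varb))  # functional ops accept VarNodes directly
    opnode = list(net.opr_filter.has_input(vara))
    net.replace_vars({opnode[0].outputs[0]: out})

    modified = io.BytesIO()
    net.dump(modified)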