提交 9dbe71dd 编写于 作者: M Megvii Engine Team

feat(mge/utils): add array method for varnode

GitOrigin-RevId: 6e4d05b475667cbffe4573ede4f2f581672ae3c9
上级 9b0bd695
......@@ -8,6 +8,9 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from .._imperative_rt import make_const
from .._imperative_rt.core2 import SymbolVar, Tensor
class Const:
def __init__(self, value=None, *, dtype=None, device=None):
......@@ -19,7 +22,19 @@ class Const:
from ...tensor import Tensor
device = self.device
if device is None:
device = reference[0].device
if len(reference) != 0:
reference = reference[0]
assert isinstance(
reference, (SymbolVar, Tensor)
), "Reference should be Tensor or VarNode"
if device is None:
device = reference.device
if isinstance(reference, SymbolVar):
cls = type(reference)
rst = cls(make_const(reference.graph, self.value, device, self.dtype))
return (rst,)
return (Tensor(self.value, self.dtype, self.device, True),)
......@@ -13,7 +13,7 @@ from typing import Union
import numpy as np
from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import Tensor, apply
from .._imperative_rt.core2 import SymbolVar, Tensor, apply
from ..ops import builtin
from ..ops.builtin import Elemwise, GetVarShape
from . import utils
......@@ -230,7 +230,9 @@ def _todo(*_):
def _expand_args(args):
if len(args) == 1:
if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray),):
if isinstance(
args[0], (collections.abc.Sequence, Tensor, SymbolVar, np.ndarray),
):
args = args[0]
return args
......
......@@ -10,7 +10,7 @@ from typing import Iterable
import numpy as np
from .._imperative_rt.core2 import Tensor, apply
from .._imperative_rt.core2 import SymbolVar, Tensor, apply
from .._trace_option import use_symbolic_shape
from ..ops import builtin
from ..ops.special import Const
......@@ -149,13 +149,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
return True
def get_index(i):
if not isinstance(i, (Tensor)):
if not isinstance(i, (Tensor, SymbolVar)):
if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_:
(i,) = Const(i, dtype=np.bool_, device=inp.device)()
(i,) = Const(i, dtype=np.bool_, device=inp.device)(inp)
else:
(i,) = Const(i, dtype=np.int32, device=inp.device)()
(i,) = Const(i, dtype=np.int32, device=inp.device)(inp)
return i
assert isinstance(i, Tensor)
assert isinstance(i, (Tensor, SymbolVar))
if i.dtype != np.bool_:
return i
_, ind = apply(builtin.CondTake(), i, i)
......@@ -197,9 +197,9 @@ def try_condtake(tensor, index):
):
return []
if isinstance(index, np.ndarray):
(index,) = Const(index, dtype=np.bool_, device=tensor.device)()
assert isinstance(index, Tensor)
if not isinstance(tensor, Tensor):
(index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor)
assert isinstance(index, (Tensor, SymbolVar))
if not isinstance(tensor, (Tensor, SymbolVar)):
raise TypeError("input must be a tensor")
if tensor.device != index.device:
raise ValueError(
......@@ -214,11 +214,16 @@ def getitem(tensor, index):
return try_result[0]
tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index)
for v in tensors:
if v.shape is None:
break
if isinstance(v.shape, v.__class__):
break
if len(v.shape) > 0 and v.shape[0] == 0:
(empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)()
(empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)(
tensor
)
return empty_tensor
if use_subtensor:
op = builtin.Subtensor(items=items)
else:
......@@ -235,8 +240,8 @@ def setitem(tensor, index, value):
if len(try_result) == 2:
index = try_result[1]
tensor = tensor.reshape(-1)
if not isinstance(value, Tensor):
(value,) = Const(value, dtype=tensor.dtype, device=tensor.device)()
if not isinstance(value, (Tensor, SymbolVar)):
(value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor)
tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
if use_subtensor:
op = builtin.Subtensor(items=items)
......
......@@ -11,8 +11,9 @@ from typing import Iterable, Union
import numpy as np
from .._imperative_rt import VarNode
from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device
from .._imperative_rt import make_const
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device
from .._wrap import device as as_device
from ..ops import builtin
from ..ops.special import Const
from .dtype import is_dtype_equal, is_quantize
......@@ -38,13 +39,9 @@ def set_convert_inputs(flag):
def concatenate(inputs, axis=0, *, device=None):
dtype = dtype_promotion(inputs)
device = get_device(inputs)
def convert(x):
return convert_single_value(x, dtype=dtype, device=device)
inputs = tuple(map(convert, inputs))
inputs = convert_inputs(*inputs)
if device is None:
device = get_device(inputs)
(result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs)
return result
......@@ -60,7 +57,7 @@ def astype(x, dtype):
def convert_single_value(v, *, dtype=None, device=None):
if isinstance(v, (Tensor, VarNode)):
if isinstance(v, (Tensor, SymbolVar)):
if not is_quantize(v.dtype):
v = astype(v, dtype)
else:
......@@ -68,17 +65,35 @@ def convert_single_value(v, *, dtype=None, device=None):
return v
def convert_inputs(*args: Tensor):
def convert_inputs(*args, device=None):
if not _enable_convert_inputs:
return args
dtype = dtype_promotion(args)
device = get_device(args)
if device is None:
device = get_device(args)
device = as_device(device)
graph = None
sym_type = None
for a in args:
if isinstance(a, SymbolVar):
if graph is None:
graph = a.var.graph
sym_type = type(a)
else:
assert graph == a.var.graph
args = list(args)
if graph is not None:
for i in range(len(args)):
if not isinstance(args[i], SymbolVar):
rst = make_const(graph, np.array(args[i]), device.to_c(), dtype)
args[i] = sym_type(rst)
def convert(value):
if value is None:
return value
return convert_single_value(value, dtype=dtype, device=device)
return convert_single_value(value, dtype=dtype, device=device.to_c())
return tuple(map(convert, args))
......@@ -98,14 +113,14 @@ def result_type(*args):
def isscalar(x):
if isinstance(x, Tensor):
if isinstance(x, (Tensor, SymbolVar)):
return x._isscalar()
return np.isscalar(x)
def setscalar(x):
if isinstance(x, Tensor):
if isinstance(x, (Tensor, SymbolVar)):
x._setscalar()
else:
raise NotImplementedError("Unsupport type {}".format(type(x)))
......@@ -132,7 +147,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
if not isinstance(x, collections.abc.Sequence):
raise TypeError
if any(isinstance(i, Tensor) for i in x):
if any(isinstance(i, (Tensor, SymbolVar)) for i in x):
x = concatenate(x, device=device)
if dtype is not None:
x = astype(x, dtype)
......@@ -142,7 +157,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
def _expand_int(s, i):
if isinstance(i, Tensor):
if isinstance(i, (Tensor, SymbolVar)):
i_np = i.numpy()
if i_np.ndim == 0:
s.append(int(i_np))
......
......@@ -9,8 +9,7 @@
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
import numpy as np
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.graph import VarNode
from ..core._imperative_rt.core2 import SymbolVar, apply
from ..core.ops import builtin
from ..core.ops.builtin import Elemwise
from ..core.tensor import utils
......@@ -72,7 +71,7 @@ __all__ = [
def _elwise(*args, mode):
tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args))
tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args))
if len(tensor_args) == 0:
dtype = utils.dtype_promotion(args)
first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
......
......@@ -12,7 +12,7 @@ from typing import Iterable, Optional, Sequence, Union
import numpy as np
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.core2 import SymbolVar, apply
from ..core._wrap import device as as_device
from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity
......@@ -101,7 +101,7 @@ def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Ten
return result
def full(shape, value, dtype="float32", device=None):
def full(shape, value, dtype="float32", device=None) -> Tensor:
"""
Returns a tensor with given shape and value.
"""
......@@ -115,7 +115,7 @@ def full(shape, value, dtype="float32", device=None):
return broadcast_to(x, shape)
def ones(shape, dtype="float32", device=None):
def ones(shape, dtype="float32", device=None) -> Tensor:
"""
Returns a ones tensor with given shape.
......@@ -142,14 +142,14 @@ def ones(shape, dtype="float32", device=None):
return full(shape, 1.0, dtype=dtype, device=device)
def zeros(shape, dtype="float32", device=None):
def zeros(shape, dtype="float32", device=None) -> Tensor:
"""
Returns a zero tensor with given shape.
"""
return full(shape, 0.0, dtype=dtype, device=device)
def zeros_like(inp: Tensor) -> Tensor:
def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
"""
Returns a zero tensor with the same shape as input tensor.
......@@ -176,21 +176,26 @@ def zeros_like(inp: Tensor) -> Tensor:
[0 0 0]]
"""
return zeros(inp.shape, dtype=inp.dtype, device=inp.device)
return full_like(inp, 0.0)
def ones_like(inp: Tensor) -> Tensor:
def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
"""
Returns a ones tensor with the same shape as input tensor.
"""
return ones(inp.shape, dtype=inp.dtype, device=inp.device)
return full_like(inp, 1.0)
def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
def full_like(
inp: Union[Tensor, SymbolVar], value: Union[int, float]
) -> Union[Tensor, SymbolVar]:
"""
Returns a tensor filled with given value with the same shape as input tensor.
"""
return full(inp.shape, value, dtype=inp.dtype, device=inp.device)
(x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
if inp.shape is ():
return x
return broadcast_to(x, inp.shape)
def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
......@@ -259,15 +264,10 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
if len(inps) == 1:
return inps[0]
dtype = dtype_promotion(inps)
inps = convert_inputs(*inps, device=device)
if device is None:
device = get_device(inps)
device = as_device(device)
def convert(x):
return convert_single_value(x, dtype=dtype, device=device)
inps = tuple(map(convert, inps))
(result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps)
return result
......@@ -379,8 +379,14 @@ def split(inp, nsplits_or_sections, axis=0):
Ntotal, axis, Nsections
)
)
func = (
floor_div
if isinstance(Nsections, (SymbolVar, Tensor))
else lambda x, y: x // y
)
div_points = [0] + [
floor_div(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections)
func(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections)
]
for i in range(2, Nsections + 1):
div_points[i] = div_points[i - 1] + div_points[i]
......@@ -925,11 +931,15 @@ def linspace(
if not (cur_device is None or device == cur_device):
raise ("ambiguous device for linspace opr")
if not isinstance(start, Tensor):
is_symbolvar = list(isinstance(x, SymbolVar) for x in [start, stop, num])
if any(is_symbolvar) and not all(is_symbolvar):
raise TypeError("start, stop and num should all be VarNode or none of them")
if not isinstance(start, (Tensor, SymbolVar)):
start = Tensor(start, device=device)
if not isinstance(stop, Tensor):
if not isinstance(stop, (Tensor, SymbolVar)):
stop = Tensor(stop, device=device)
if not isinstance(num, Tensor):
if not isinstance(num, (Tensor, SymbolVar)):
num = Tensor(num, device=device)
op = builtin.Linspace(comp_node=device)
......@@ -983,7 +993,7 @@ def arange(
stop = stop.astype("float32")
if isinstance(step, Tensor):
step = step.astype("float32")
num = ceil(Tensor((stop - start) / step, device=device))
num = ceil((stop - start) / step)
stop = start + step * (num - 1)
result = linspace(start, stop, num, device=device)
if np.dtype(dtype) == np.int32:
......
......@@ -16,6 +16,7 @@ from typing import Dict, List
import numpy as np
from ..core._imperative_rt import ComputingGraph
from ..core._imperative_rt.core2 import SymbolVar
from ..core.tensor import megbrain_graph as G
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq
from .network_node import (
......@@ -60,12 +61,12 @@ class Network:
)
outputs = [new_outputs[i] for i in outspec]
self._orig_outputs = outputs
self.add_dep_oprs(*outputs)
for x in self._orig_outputs:
self.output_vars.append(self._get_var(x))
self.add_dep_oprs()
for x in self._orig_inputs:
self.input_vars.append(self._get_var(x))
for x in self._orig_outputs:
self.output_vars.append(self._get_var(x))
self.graph = self._orig_outputs[0].graph
return self
......@@ -197,6 +198,8 @@ class Network:
def add_output(self, *vars: VarNode):
"""Adds vars into the network output node list
"""
if not all([var.owner for var in vars]):
self.add_dep_oprs(*vars)
for var in vars:
if var not in self.output_vars:
self.output_vars.append(var)
......@@ -209,21 +212,25 @@ class Network:
self.output_vars.remove(var)
def add_dep_oprs(self, *vars):
"""Adds dependent opnodes and varnodes of vars into network
"""
oprs = get_oprs_seq(vars, False, False)
for mge_opr in oprs:
if len(vars) == 0:
vars = self.output_vars
q = list(vars)
while len(q) > 0:
cur = q.pop(0)
if cur.owner is not None:
continue
if cur.name is None:
cur.name = cur.var.name
self.all_vars_map[cur.var.id] = cur
mge_opr = cur.var.owner
if get_opr_type(mge_opr) == "Host2DeviceCopy":
self._orig_inputs.extend(mge_opr.outputs)
opr = self._add_opr(mge_opr)
if opr is not None:
for x in mge_opr.inputs:
opr.add_inp_var(self._get_var(x))
# set out var
for x in mge_opr.outputs:
opr.add_out_var(self._get_var(x))
return [self.all_vars_map[var.id] for var in vars]
cur.owner = self._add_opr(mge_opr)
if cur.owner is None:
cur.owner = self.all_oprs_map[mge_opr.id]
continue
q.extend(cur.owner.inputs)
return list(vars)
def modify_opr_names(self, modifier):
"""Modifies names of operators **inplace**; useful for merging loaded
......@@ -275,6 +282,9 @@ class Network:
Replaces vars in the graph.
:param repl_dict: the map {old_var: new_var} that specifies how to replace the vars.
"""
if not all([var.owner for var in repl_dict.values()]):
print(repl_dict.values())
self.add_dep_oprs(*list(repl_dict.values()))
for var in self.all_vars:
if var in repl_dict:
repl_var = repl_dict[var]
......@@ -282,6 +292,7 @@ class Network:
idx = owner.outputs.index(repl_var)
owner.outputs[idx] = var
var.__dict__.update(repl_var.__dict__)
var.var = repl_var.var
def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]):
"""
......@@ -297,6 +308,7 @@ class Network:
for ind, var in enumerate(opr.outputs):
var.owner = repl_dict[opr]
var.__dict__.update(repl_dict[opr].outputs[ind].__dict__)
var.var = repl_dict[opr].outputs[ind].var
def get_opr_by_type(self, oprcls, unique=True):
assert issubclass(oprcls, OpNode)
......@@ -381,11 +393,16 @@ class Network:
return self.opr_filter.as_dict()
# used for loading and building graph
def _add_opr(self, x):
def _add_opr(self, opr):
# TODO: use megbrain C++ RTTI to replace type string
if x.id not in self.all_oprs_map:
self.all_oprs_map[x.id] = str_to_mge_class(get_opr_type(x)).load(x)
return self.all_oprs_map[x.id]
if opr.id not in self.all_oprs_map:
opnode = str_to_mge_class(get_opr_type(opr)).load(opr)
self.all_oprs_map[opr.id] = opnode
for var in opr.inputs:
opnode.add_inp_var(self._get_var(var))
for var in opr.outputs:
opnode.add_out_var(self._get_var(var))
return opnode
else:
return None
......@@ -397,7 +414,7 @@ class Network:
def _get_var(self, x):
# auto convert to VarNode of Network
if x.id not in self.all_vars_map:
if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x:
self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner))
return self.all_vars_map[x.id]
......@@ -652,7 +669,7 @@ class NodeFilterHasInput(NodeFilter):
assert isinstance(
i, OpNode
), "has_input() must be used with OpNode; " "got {!r}".format(i)
if self.var in i.inputs:
if any(self.var is _ for _ in i.inputs):
yield i
......
......@@ -6,16 +6,21 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import json
import sys
from typing import Callable
from typing import Callable, Sequence
import numpy as np
from ..core import _imperative_rt as rt
from ..core._imperative_rt.core2 import SymbolVar
from ..core._wrap import Device
from ..core.ops import builtin
from ..core.tensor.megbrain_graph import InputNode
from ..core.tensor.array_method import ArrayMethodMixin
from ..core.tensor.indexing import getitem as _getitem
from ..core.tensor.indexing import setitem as _setitem
from ..core.tensor.megbrain_graph import InputNode, OutputNode
from ..tensor import Tensor
from .comp_graph_tools import replace_vars
from .module_stats import (
......@@ -29,9 +34,13 @@ class NetworkNode:
pass
class VarNode(NetworkNode):
def __init__(self, owner_opr=None, name=None):
self.var = None
class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
pass
class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
def __init__(self, var=None, *, owner_opr=None, name=None):
SymbolVar.__init__(self, var)
self.owner = owner_opr
self.name = name
self.id = id(self)
......@@ -58,6 +67,40 @@ class VarNode(NetworkNode):
def dtype(self):
return self.var.dtype if self.var else None
def __bool__(self):
return False
__index__ = None
__int__ = None
__float__ = None
__complex__ = None
def __hash__(self):
return id(self)
@property
def _tuple_shape(self):
return self.var.shape
def numpy(self):
o = OutputNode(self.var)
self.graph.compile(o.outputs).execute()
return o.get_value().numpy()
def __getitem__(self, index):
return _getitem(self, index)
def __setitem__(self, index, value):
if index is not Ellipsis:
value = _setitem(self, index, value)
if self.owner is not None:
idx = self.owner.outputs.index(self)
self.owner.outputs[idx] = VarNode(
self.var, owner_opr=self.owner, name=self.var.name
)
self.var = value.var
self.owner = None
def set_owner_opr(self, owner_opr):
self.owner = owner_opr
......@@ -138,7 +181,7 @@ class Host2DeviceCopy(OpNode):
outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
self._opr = outputs.owner
if len(self.outputs) == 0:
self.outputs.append(VarNode(self, self.name))
self.outputs.append(VarNode(owner_opr=self, name=self.name))
self.outputs[0].var = outputs
assert self.outputs[0].owner is self
......@@ -176,8 +219,8 @@ class ImmutableTensor(OpNode):
def set_value(self, data, device=None):
assert self.graph is not None
cn = device if device else self.device
assert isinstance(data, (int, float, np.ndarray))
if isinstance(data, (int, float)):
assert isinstance(data, (int, float, Sequence, np.ndarray))
if not isinstance(data, np.ndarray):
data = np.array(data)
if data.dtype == np.float64:
data = data.astype(np.float32)
......@@ -185,7 +228,7 @@ class ImmutableTensor(OpNode):
data = data.astype(np.int32)
varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
if len(self.outputs) == 0:
self.outputs.append(VarNode(self, self.name))
self.outputs.append(VarNode(owner_opr=self, name=self.name))
self.outputs[0].var = varnode
self._opr = varnode.owner
......
......@@ -160,16 +160,21 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
if (ctx.op->same_type<BackwardGraph>()) {
ctx.backward = true;
}
if (py::isinstance<cg::VarNode>(py::handle(args[0]))){
SmallVector<cg::VarNode*> vinputs(nargs);
for (size_t i = 0; i < nargs; ++i) {
vinputs[i] = py::handle(args[i]).cast<cg::VarNode *>();
}
auto op = ctx.op.get();
return to_tuple(OpDef::apply_on_var_node(*op, vinputs)).release().ptr();
}
if (py::isinstance<PySymbolVar>(py::handle(args[0]))){
SmallVector<cg::VarNode*> vinputs(nargs);
for (size_t i = 0; i < nargs; ++i) {
vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
}
auto op = ctx.op.get();
auto rst = OpDef::apply_on_var_node(*op, vinputs);
auto ret = pybind11::tuple(rst.size());
auto typeobj = py::handle(args[0]).get_type();
for (size_t i = 0; i<rst.size(); ++i) {
ret[i] = typeobj(pybind11::cast(rst[i], pybind11::return_value_policy::automatic));
}
return ret.release().ptr();
}
for (size_t i = 0; i < nargs; ++i) {
if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
......@@ -686,9 +691,9 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) {
continue;
}
if (py::isinstance<cg::VarNode>(py::handle(handle))){
auto var = py::handle(handle).cast<cg::VarNode *>();
mgb::DType type = var->dtype();
if (py::isinstance<PySymbolVar>(py::handle(handle))){
auto var = py::handle(handle).cast<PySymbolVar*>();
mgb::DType type = var->m_node->dtype();
auto && descr = npy::dtype_mgb2np_descr(type);
Py_INCREF(descr.get());
tensors.emplace_back(descr.get());
......@@ -737,19 +742,26 @@ CompNode _get_device(PyObject*const* args, size_t nargs) {
bool valid = false;
CompNode cn;
for (size_t i = 0; i < nargs; ++i) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i];
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
TensorWrapper* tw = TensorWrapper::try_cast(handle);
bool is_var = py::isinstance<cg::VarNode>(py::handle(handle));
if (tw || is_var) {
bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
if (tw || is_symvar) {
if (!valid) {
cn = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node();
cn = tw ? tw->m_tensor->comp_node()
: py::handle(handle)
.cast<PySymbolVar*>()
->m_node->comp_node();
valid = true;
} else {
CompNode cn1 = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node();
CompNode cn1 = tw ? tw->m_tensor->comp_node()
: py::handle(handle)
.cast<PySymbolVar*>()
->m_node->comp_node();
if (cn1 != cn) {
throw py::value_error(ssprintf("ambiguous device: %s vs %s",
cn.to_string().c_str(), cn1.to_string().c_str()));
cn.to_string().c_str(),
cn1.to_string().c_str()));
}
}
}
......@@ -849,6 +861,32 @@ void init_tensor(py::module m) {
.def("__call__", &TensorWeakRef::operator())
.def("_use_cnt", &TensorWeakRef::_use_cnt);
py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
.def_property_readonly(
"dtype", [](PySymbolVar* v) { return v->m_node->dtype(); })
.def_property("var", [](PySymbolVar* v) { return v->m_node; },
[](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; })
.def_property_readonly(
"device",
[](PySymbolVar* v) { return v->m_node->comp_node(); })
.def_property_readonly(
"graph",
[](PySymbolVar* v) { return v->m_node->owner_graph(); })
.def_property_readonly(
"shape",
[](PySymbolVar* v) -> const TensorShape* {
auto&& mgr = v->m_node->owner_graph()
->static_infer_manager();
return mgr.infer_shape_fallible(v->m_node);
})
.def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
.def("_setscalar",
[](PySymbolVar* v) { return v->is_scalar = true; })
.def(py::init([](cg::VarNode* node) {
return std::make_shared<PySymbolVar>(node);
}),
py::arg() = nullptr);
static PyMethodDef method_defs[] = {
MGE_PY_INTERFACE(apply, py_apply),
MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
......
......@@ -181,6 +181,12 @@ struct TensorWrapper {
PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
};
struct PySymbolVar {
cg::VarNode* m_node = nullptr;
bool is_scalar = false;
PySymbolVar() = default;
PySymbolVar(VarNode *m): m_node(m){}
};
PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */);
......
......@@ -2,9 +2,11 @@ import io
import numpy as np
import megengine.core.tensor.megbrain_graph as G
import megengine.utils.comp_graph_tools as cgtools
from megengine import tensor
from megengine.jit import trace
from megengine.utils.network_node import VarNode
def _default_compare_fn(x, y):
......@@ -14,8 +16,23 @@ def _default_compare_fn(x, y):
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6)
def make_tensor(x, network=None, device=None):
if network is not None:
if isinstance(x, VarNode):
return VarNode(x.var)
return network.make_const(x, device=device)
else:
return tensor(x, device=device)
def opr_test(
cases, func, compare_fn=_default_compare_fn, ref_fn=None, test_trace=True, **kwargs
cases,
func,
compare_fn=_default_compare_fn,
ref_fn=None,
test_trace=True,
network=None,
**kwargs
):
"""
:param cases: the list which have dict element, the list length should be 2 for dynamic shape test.
......@@ -44,7 +61,7 @@ def opr_test(
if not isinstance(results, (tuple, list)):
results = (results,)
for r, e in zip(results, expected):
if not isinstance(r, tensor):
if not isinstance(r, (tensor, VarNode)):
r = tensor(r)
compare_fn(r, e)
......@@ -72,9 +89,9 @@ def opr_test(
raise ValueError("the input func should be callable")
inp, outp = get_param(cases, 0)
inp_tensor = [tensor(inpi) for inpi in inp]
inp_tensor = [make_tensor(inpi, network) for inpi in inp]
if test_trace:
if test_trace and not network:
copied_inp = inp_tensor.copy()
for symbolic in [False, True]:
traced_func = trace(symbolic=symbolic)(func)
......
......@@ -10,12 +10,17 @@ import collections
import numpy as np
import pytest
from utils import make_tensor
import megengine
import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
from megengine.core._imperative_rt.core2 import apply
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.ops import builtin
from megengine.tensor import Tensor
from megengine.utils.network import Network
from megengine.utils.network_node import VarNode
def cvt_to_shape_desc(val, inpvar, config=None):
......@@ -387,108 +392,130 @@ def test_batched_mesh_indexing():
# high level
def get_value(x):
if isinstance(x, VarNode):
var = x.var
o = G.OutputNode(var)
graph = x.graph
graph.compile(o.outputs).execute()
return o.get_value().numpy()
else:
return x.numpy()
@pytest.mark.parametrize("test_varnode", [True, False])
def test_advance_indexing_high_level(test_varnode):
if test_varnode:
network = Network()
else:
network = None
def test_advance_indexing_high_level():
x = np.arange(25).reshape(5, 5).astype("int32")
d = np.arange(15).reshape(3, 5).astype("int32")
xx = Tensor(x)
xx = make_tensor(x, network)
np.testing.assert_equal(x[1, :], xx[1, :].numpy())
np.testing.assert_equal(x[:, 1], xx[:, 1].numpy())
np.testing.assert_equal(x[1:3, :], xx[1:3, :].numpy())
np.testing.assert_equal(x[1, :], get_value(xx[1, :]))
np.testing.assert_equal(x[:, 1], get_value(xx[:, 1]))
np.testing.assert_equal(x[1:3, :], get_value(xx[1:3, :]))
np.testing.assert_equal(x[:, :], xx[:, :].numpy())
np.testing.assert_equal(x[1, 1], xx[1, 1].numpy())
np.testing.assert_equal(x[:, :], get_value(xx[:, :]))
np.testing.assert_equal(x[1, 1], get_value(xx[1, 1]))
yy = xx[(0, 4, 2), :]
np.testing.assert_equal(x[(0, 4, 2), :], yy.numpy())
np.testing.assert_equal(x[(0, 4, 2), :], get_value(yy))
x_ = x.copy()
x_[(0, 4, 2), :] = d
xx_ = Tensor(xx)
xx_ = make_tensor(xx, network)
xx_[(0, 4, 2), :] = d
np.testing.assert_equal(x_, xx_.numpy())
np.testing.assert_equal(x_, get_value(xx_))
x = np.arange(27).reshape(3, 3, 3).astype("int32")
xx = Tensor(x)
xx = make_tensor(x, network)
np.testing.assert_equal(x[1, :, :], xx[1, :, :].numpy())
np.testing.assert_equal(x[1, :, 1], xx[1, :, 1].numpy())
np.testing.assert_equal(x[1, 0:1, :], xx[1, 0:1, :].numpy())
np.testing.assert_equal(x[0:1, 1, 1], xx[0:1, 1, 1].numpy())
np.testing.assert_equal(x[:, 1, 1], xx[:, 1, 1].numpy())
np.testing.assert_equal(x[:, 1], xx[:, 1].numpy())
np.testing.assert_equal(x[1, 1:2], xx[1, 1:2].numpy())
np.testing.assert_equal(x[1, :, :], get_value(xx[1, :, :]))
np.testing.assert_equal(x[1, :, 1], get_value(xx[1, :, 1]))
np.testing.assert_equal(x[1, 0:1, :], get_value(xx[1, 0:1, :]))
np.testing.assert_equal(x[0:1, 1, 1], get_value(xx[0:1, 1, 1]))
np.testing.assert_equal(x[:, 1, 1], get_value(xx[:, 1, 1]))
np.testing.assert_equal(x[:, 1], get_value(xx[:, 1]))
np.testing.assert_equal(x[1, 1:2], get_value(xx[1, 1:2]))
x_ = x.copy()
x_[1, 1, 1] = -1
xx[1, 1, 1] = -1
np.testing.assert_equal(x_, xx.numpy())
np.testing.assert_equal(x_, get_value(xx))
x_[:, 1, 1] = -2
xx[:, 1, 1] = x_[:, 1, 1]
np.testing.assert_equal(x_, xx.numpy())
np.testing.assert_equal(x_, get_value(xx))
x_[0:1, :, 1] = -3
xx[0:1, :, 1] = x_[0:1, :, 1]
np.testing.assert_equal(x_, xx.numpy())
np.testing.assert_equal(x_, get_value(xx))
x_[0:1, :, 1] = -4
y = Tensor(x_)
y = make_tensor(x_, network)
xx[0:1, :, 1] = y[0:1, :, 1]
np.testing.assert_equal(y.numpy(), xx.numpy())
np.testing.assert_equal(get_value(y), get_value(xx))
x[:] = 1
xx[:] = 1
np.testing.assert_equal(x, xx.numpy())
np.testing.assert_equal(x, get_value(xx))
x = np.arange(9).reshape(3, 3).astype("int32")
xx = Tensor(x)
xx = make_tensor(x, network)
y = np.array([1, 2])
yy = Tensor(y)
np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
np.testing.assert_equal(x[:, y], xx[:, y].numpy())
np.testing.assert_equal(x[:, y], xx[:, yy].numpy())
yy = make_tensor(y, network)
np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]]))
np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]]))
np.testing.assert_equal(x[:, y], get_value(xx[:, y]))
np.testing.assert_equal(x[:, y], get_value(xx[:, yy]))
x_ = x.copy()
x_[:, y[0]] = -1
xx_ = Tensor(x_)
xx_ = make_tensor(x_, network)
xx[:, yy[0]] = xx_[:, yy[0]]
np.testing.assert_equal(x_, xx.numpy())
np.testing.assert_equal(x_, get_value(xx))
x_[:, y] = -1
xx_ = Tensor(x_)
xx_ = make_tensor(x_, network)
xx[:, yy] = xx_[:, yy]
np.testing.assert_equal(x_, xx.numpy())
np.testing.assert_equal(x_, get_value(xx))
x = np.arange(9).reshape(3, 3).astype("int32")
xx = Tensor(x)
xx = make_tensor(x, network)
y = np.array([1])
yy = Tensor(y)
np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
np.testing.assert_equal(x[:, y], xx[:, y].numpy())
yy = make_tensor(y, network)
np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]]))
np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]]))
np.testing.assert_equal(x[:, y], get_value(xx[:, y]))
np.testing.assert_equal(x[:, y], xx[:, yy].numpy())
np.testing.assert_equal(x[:, y], get_value(xx[:, yy]))
x = np.arange(9).reshape(3, 3).astype("int32")
xx = Tensor(x)
np.testing.assert_equal(x[[0, 1], 0], xx[[0, 1], 0].numpy())
np.testing.assert_equal(x[0:2, 0], xx[0:2, 0].numpy())
def test_advance_indexing_with_bool():
xx = make_tensor(x, network)
np.testing.assert_equal(x[[0, 1], 0], get_value(xx[[0, 1], 0]))
np.testing.assert_equal(x[0:2, 0], get_value(xx[0:2, 0]))
@pytest.mark.parametrize(
"test_varnode", [True, False],
)
def test_advance_indexing_with_bool(test_varnode):
if test_varnode:
network = Network()
else:
network = None
a = np.arange(9).reshape(3, 3).astype(np.float32)
b = np.array([1, 2, 3])
c = np.array([1, 2, 3])
aa = Tensor(a)
bb = Tensor(b)
cc = Tensor(c)
np.testing.assert_equal(a[b == 1, c == 2], aa[bb == 1, cc == 2].numpy())
aa = make_tensor(a, network)
bb = make_tensor(b, network)
cc = make_tensor(c, network)
np.testing.assert_equal(a[b == 1, c == 2], get_value(aa[bb == 1, cc == 2]))
a[b == 1, c == 2] = -1.0
aa[bb == 1, cc == 2] = -1.0
np.testing.assert_equal(a, aa.numpy())
np.testing.assert_equal(a, get_value(aa))
a = np.arange(9).reshape(3, 3).astype(np.float32)
b = np.array([False, True, True])
......
......@@ -11,13 +11,16 @@ import platform
import numpy as np
import pytest
from utils import opr_test
from utils import make_tensor, opr_test
import megengine.functional as F
from megengine import tensor
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.utils import astensor1d
from megengine.distributed.helper import get_device_count_by_fork
from megengine.utils.network import Network
from megengine.utils.network_node import VarNode
def test_eye():
......@@ -38,7 +41,13 @@ def test_eye():
)
def test_concat():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_concat(is_varnode):
if is_varnode:
network = Network()
else:
network = None
def get_data_shape(length: int):
return (length, 2, 3)
......@@ -50,18 +59,30 @@ def test_concat():
return F.concat([data1, data2])
cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]))
opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)
def test_concat_device():
data1 = tensor(np.random.random((3, 2, 2)).astype("float32"), device="cpu0")
data2 = tensor(np.random.random((2, 2, 2)).astype("float32"), device="cpu1")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_concat_device(is_varnode):
if is_varnode:
network = Network()
else:
network = None
data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0")
data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")
out = F.concat([data1, data2], device="cpu0")
assert str(out.device).split(":")[0] == "cpu0"
def test_stack():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_stack(is_varnode):
if is_varnode:
network = Network()
else:
network = None
data1 = np.random.random((3, 2, 2)).astype("float32")
data2 = np.random.random((3, 2, 2)).astype("float32")
data3 = np.random.random((3, 2, 2)).astype("float32")
......@@ -72,12 +93,20 @@ def test_stack():
def run(data1, data2):
return F.stack([data1, data2], axis=ai)
opr_test(cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai))
opr_test(
cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network
)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_split(is_varnode):
if is_varnode:
network = Network()
else:
network = None
def test_split():
data = np.random.random((2, 3, 4, 5)).astype(np.float32)
inp = tensor(data)
inp = make_tensor(data, network)
mge_out0 = F.split(inp, 2, axis=3)
mge_out1 = F.split(inp, [3], axis=3)
......@@ -106,26 +135,42 @@ def test_split():
assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]"
def test_reshape():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_reshape(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = np.arange(6, dtype="float32")
xx = tensor(x)
xx = make_tensor(x, network)
y = x.reshape(1, 2, 3)
for shape in [
(1, 2, 3),
(1, -1, 3),
(1, tensor(-1), 3),
(1, make_tensor(-1, network), 3),
np.array([1, -1, 3], dtype="int32"),
tensor([1, -1, 3]),
make_tensor([1, -1, 3], network),
]:
yy = F.reshape(xx, shape)
np.testing.assert_equal(yy.numpy(), y)
def test_reshape_shape_inference():
x_shape_known = tensor([1, 2, 3, 4], dtype="float32")
x_shape_unknown = F.broadcast_to(tensor([1.0]), shape=tensor([1, 1, 1, 1]).sum())
tshp_unknown = astensor1d((tensor([2]), tensor([2])), x_shape_known)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_reshape_shape_inference(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x_shape_known = make_tensor([1, 2, 3, 4], network)
x_shape_unknown = F.broadcast_to(
make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum()
)
tshp_unknown = astensor1d(
(make_tensor([2], network), make_tensor([2], network)), x_shape_known
)
tshp_known = astensor1d((2, 2), x_shape_known)
tshp_known_unspec = astensor1d((2, -1), x_shape_known)
......@@ -146,12 +191,18 @@ def test_reshape_shape_inference():
{"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]},
{"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]},
]
opr_test(cases, func, compare_fn=check_shape, test_trace=True)
opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_squeeze(is_varnode):
if is_varnode:
network = Network()
else:
network = None
def test_squeeze():
x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
xx = tensor(x)
xx = make_tensor(x, network)
for axis in [None, 3, -4, (3, -4)]:
y = np.squeeze(x, axis)
......@@ -159,9 +210,15 @@ def test_squeeze():
np.testing.assert_equal(y, yy.numpy())
def test_expand_dims():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_expand_dims(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = np.arange(6, dtype="float32").reshape(2, 3)
xx = tensor(x)
xx = make_tensor(x, network)
for axis in [2, -3, (3, -4), (1, -4)]:
y = np.expand_dims(x, axis)
......@@ -169,11 +226,17 @@ def test_expand_dims():
np.testing.assert_equal(y, yy.numpy())
def test_elemwise_dtype_promotion():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_elemwise_dtype_promotion(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = np.random.rand(2, 3).astype("float32")
y = np.random.rand(1, 3).astype("float16")
xx = tensor(x)
yy = tensor(y)
xx = make_tensor(x, network)
yy = make_tensor(y, network)
z = xx * yy
np.testing.assert_equal(z.numpy(), x * y)
......@@ -184,7 +247,13 @@ def test_elemwise_dtype_promotion():
np.testing.assert_equal(z.numpy(), x - y)
def test_linspace():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_linspace(is_varnode):
if is_varnode:
network = Network()
else:
network = None
cases = [
{"input": [1, 9, 9]},
{"input": [3, 10, 8]},
......@@ -193,6 +262,7 @@ def test_linspace():
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
network=network,
)
cases = [
......@@ -203,20 +273,28 @@ def test_linspace():
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
network=network,
)
cases = [
{"input": [1, tensor(9), 9]},
{"input": [tensor(1), 9, tensor(9)]},
{"input": [1, make_tensor(9, network), 9]},
{"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
]
opr_test(
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
network=network,
)
def test_arange():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_arange(is_varnode):
if is_varnode:
network = Network()
else:
network = None
cases = [
{"input": [1, 9, 1]},
{"input": [2, 10, 2]},
......@@ -225,6 +303,7 @@ def test_arange():
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
network=network,
)
cases = [
......@@ -235,6 +314,7 @@ def test_arange():
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
network=network,
)
cases = [
......@@ -245,20 +325,33 @@ def test_arange():
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
network=network,
)
def test_round():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_round(is_varnode):
if is_varnode:
network = Network()
else:
network = None
data1_shape = (15,)
data2_shape = (25,)
data1 = np.random.random(data1_shape).astype(np.float32)
data2 = np.random.random(data2_shape).astype(np.float32)
cases = [{"input": data1}, {"input": data2}]
opr_test(cases, F.round, ref_fn=np.round)
opr_test(cases, F.round, ref_fn=np.round, network=network)
def test_flatten():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_flatten(is_varnode):
if is_varnode:
network = Network()
else:
network = None
data0_shape = (2, 3, 4, 5)
data1_shape = (4, 5, 6, 7)
data0 = np.random.random(data0_shape).astype(np.float32)
......@@ -273,7 +366,7 @@ def test_flatten():
{"input": data0, "output": output0},
{"input": data1, "output": output1},
]
opr_test(cases, F.flatten, compare_fn=compare_fn)
opr_test(cases, F.flatten, compare_fn=compare_fn, network=network)
output0 = (2, 3 * 4 * 5)
output1 = (4, 5 * 6 * 7)
......@@ -281,7 +374,7 @@ def test_flatten():
{"input": data0, "output": output0},
{"input": data1, "output": output1},
]
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1)
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network)
output0 = (2, 3, 4 * 5)
output1 = (4, 5, 6 * 7)
......@@ -289,7 +382,7 @@ def test_flatten():
{"input": data0, "output": output0},
{"input": data1, "output": output1},
]
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2)
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network)
output0 = (2, 3 * 4, 5)
output1 = (4, 5 * 6, 7)
......@@ -297,10 +390,23 @@ def test_flatten():
{"input": data0, "output": output0},
{"input": data1, "output": output1},
]
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, end_axis=2)
opr_test(
cases,
F.flatten,
compare_fn=compare_fn,
start_axis=1,
end_axis=2,
network=network,
)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_broadcast(is_varnode):
if is_varnode:
network = Network()
else:
network = None
def test_broadcast():
input1_shape = (20, 30)
output1_shape = (30, 20, 30)
data1 = np.random.random(input1_shape).astype(np.float32)
......@@ -321,7 +427,7 @@ def test_broadcast():
{"input": [data2, output2_shape], "output": output2_shape},
{"input": [data3, output3_shape], "output": output3_shape},
]
opr_test(cases, F.broadcast_to, compare_fn=compare_fn)
opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network)
x = F.ones((2, 1, 3))
with pytest.raises(RuntimeError):
......@@ -334,35 +440,41 @@ def test_broadcast():
F.broadcast_to(x, (1, 3))
def test_utils_astensor1d():
reference = tensor(0)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_utils_astensor1d(is_varnode):
if is_varnode:
network = Network()
else:
network = None
reference = make_tensor(0, network)
# literal
x = [1, 2, 3]
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert type(xx) is tensor
assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), x)
# numpy array
x = np.asarray([1, 2, 3], dtype="int32")
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert type(xx) is tensor
assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x)
# tensor
x = tensor([1, 2, 3], dtype="int32")
x = make_tensor([1, 2, 3], network)
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert type(xx) is tensor
assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), x.numpy())
# mixed
x = [1, tensor(2), 3]
x = [1, make_tensor(2, network), 3]
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert type(xx) is tensor
assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), [1, 2, 3])
......@@ -382,35 +494,60 @@ def test_device():
np.testing.assert_almost_equal(y5.numpy(), y6.numpy())
def test_identity():
x = tensor(np.random.random((5, 10)).astype(np.float32))
@pytest.mark.parametrize("is_varnode", [True, False])
def test_identity(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = make_tensor(np.random.random((5, 10)).astype(np.float32), network)
y = F.copy(x)
np.testing.assert_equal(y.numpy(), x)
def copy_test(dst, src):
def copy_test(dst, src, network):
data = np.random.random((2, 3)).astype(np.float32)
x = tensor(data, device=src)
x = make_tensor(data, device=src, network=network)
y = F.copy(x, dst)
assert np.allclose(data, y.numpy())
z = x.to(dst)
assert np.allclose(data, z.numpy())
if network is None:
z = x.to(dst)
assert np.allclose(data, z.numpy())
@pytest.mark.require_ngpu(1)
def test_copy_h2d():
copy_test("cpu0", "gpu0")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_copy_h2d(is_varnode):
if is_varnode:
network = Network()
else:
network = None
copy_test("cpu0", "gpu0", network=network)
@pytest.mark.require_ngpu(1)
def test_copy_d2h():
copy_test("gpu0", "cpu0")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_copy_d2h(is_varnode):
if is_varnode:
network = Network()
else:
network = None
copy_test("gpu0", "cpu0", network=network)
@pytest.mark.require_ngpu(2)
def test_copy_d2d():
copy_test("gpu0", "gpu1")
copy_test("gpu0:0", "gpu0:1")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_copy_d2d(is_varnode):
if is_varnode:
network = Network()
else:
network = None
copy_test("gpu0", "gpu1", network=network)
copy_test("gpu0:0", "gpu0:1", network=network)
@pytest.mark.parametrize(
......@@ -425,7 +562,13 @@ def test_copy_d2d():
((), 10, None),
],
)
def test_repeat(shape, repeats, axis):
@pytest.mark.parametrize("is_varnode", [True, False])
def test_repeat(shape, repeats, axis, is_varnode):
if is_varnode:
network = Network()
else:
network = None
def repeat_func(inp):
return F.repeat(inp=inp, repeats=repeats, axis=axis)
......@@ -437,7 +580,10 @@ def test_repeat(shape, repeats, axis):
cases = [{"input": np.array(1.23)}]
opr_test(
cases, repeat_func, ref_fn=lambda inp: np.repeat(inp, repeats, axis),
cases,
repeat_func,
ref_fn=lambda inp: np.repeat(inp, repeats, axis),
network=network,
)
......@@ -450,14 +596,16 @@ def test_repeat(shape, repeats, axis):
((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
],
)
def test_tile(shape, reps):
@pytest.mark.parametrize("is_varnode", [True])
def test_tile(shape, reps, is_varnode):
if is_varnode:
network = Network()
else:
network = None
def tile_func(inp):
return F.tile(inp=inp, reps=reps)
cases = [
{"input": np.random.randn(*shape).astype("float32")},
]
cases = [{"input": np.random.randn(*shape).astype("float32")}]
opr_test(
cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps),
)
opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network)
......@@ -34,13 +34,11 @@ def test_replace_var():
vara = graph.var_filter.name("a").as_unique()
varb = graph.var_filter.name("b").as_unique()
out = F.mul(vara.var, varb.var)
out = F.mul(vara, varb)
out = F.relu(out)
var_list = graph.add_dep_oprs(out)
opnode = list(graph.opr_filter.has_input(vara))
repl_dict = {opnode[0].outputs[0]: var_list[0]}
repl_dict = {opnode[0].outputs[0]: out}
graph.replace_vars(repl_dict)
modified_model = io.BytesIO()
......@@ -72,14 +70,12 @@ def test_replace_opr():
vara = graph.var_filter.name("a").as_unique()
varb = graph.var_filter.name("b").as_unique()
out1 = F.sub(vara.var, varb.var)
out1 = F.sub(vara, varb)
out1 = F.relu(out1)
var_list = graph.add_dep_oprs(out1)
repl_opr = as_oprnode(var_list)
out1 = graph.add_dep_oprs(out1)
orig_opr = graph.opr_filter.has_input(vara).as_unique()
repl_dict = {orig_opr: repl_opr}
repl_dict = {orig_opr: out1[0].owner}
graph.replace_oprs(repl_dict)
modified_model1 = io.BytesIO()
graph.dump(modified_model1)
......@@ -171,8 +167,7 @@ def test_add_input():
inp_c = graph.make_input_node((2,), np.int32, name="c")
varo = graph.var_filter.name("o").as_unique()
out = F.add(varo.var, inp_c.var)
out = graph.add_dep_oprs(out)[0]
out = F.add(varo, inp_c)
out.name = "o1"
graph.remove_output(varo)
graph.add_output(out)
......@@ -206,12 +201,11 @@ def test_add_output():
var_a = net.var_filter.name("a").as_unique()
var_b = net.var_filter.name("b").as_unique()
y = F.add(var_a.var, var_b.var)
y = F.add(var_a, var_b)
y = F.sigmoid(y)
new_vars = net.add_dep_oprs(y)[0]
new_vars.name = "o1"
net.add_output(new_vars)
y.name = "o1"
net.add_output(y)
modified_model = io.BytesIO()
net.dump(modified_model)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册