提交 638ab52f 编写于 作者: M Megvii Engine Team

feat(mge/imperative): simulates scalar

GitOrigin-RevId: e81630e25647688d2c2385e9faec0ee3b8c8174c
上级 7167fdbd
......@@ -26,4 +26,6 @@ def set_symbolic_shape(option: bool):
""" Sets whether tensor.shape returns a tensor instead of a tuple
"""
global _use_symbolic_shape
_org = _use_symbolic_shape
_use_symbolic_shape = option
return _org
......@@ -14,7 +14,7 @@ from .._trace_option import use_symbolic_shape
from ..ops import builtin
from ..ops.special import Const
from .core import TensorBase, TensorWrapperBase, apply
from .utils import astensor1d, make_shape_tuple
from .utils import astensor1d, isscalar, make_shape_tuple
def remove_ellipsis(tensor, tuple_val):
......@@ -89,9 +89,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
if not isinstance(tuple_val, tuple):
tuple_val = (tuple_val,)
ndim_indexed = 0
ndim_indexed_scalar = 0
for i in tuple_val:
if not i is Ellipsis:
ndim_indexed += 1 if not hasattr(i, "ndim") else i.ndim
if isscalar(i):
ndim_indexed_scalar += 1
if ndim_indexed > inp.ndim:
raise IndexError(
"too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
......@@ -103,15 +107,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
use_subtensor = True
inp, tuple_val = check_bool_index(inp, tuple_val)
def is_scalar(d):
if isinstance(i, int):
return True
if type(d).__module__ == np.__name__:
return np.isscalar(d)
# if isinstance(d, (TensorBase, TensorWrapperBase)):
# return d.shape == (1,)
return False
new_axes = []
tensors = []
items = []
......@@ -134,7 +129,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
continue
if (
not is_scalar(i)
not isscalar(i)
and not i is np.newaxis
and not i is Ellipsis
and not isinstance(i, slice)
......@@ -191,7 +186,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
items.append(item)
if new_axes:
raise IndexError("newaxis is not allowed here")
return inp, tensors, items, use_subtensor
return inp, tensors, items, use_subtensor, ndim_indexed_scalar == inp.ndim
def try_condtake(tensor, index):
......@@ -217,11 +212,11 @@ def getitem(tensor, index):
try_result = try_condtake(tensor, index)
if len(try_result) == 2:
return try_result[0]
tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index)
for v in tensors:
if isinstance(v.shape, v.__class__):
break
if v.shape[0] == 0:
if len(v.shape) > 0 and v.shape[0] == 0:
(empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)(
tensor
)
......@@ -231,6 +226,8 @@ def getitem(tensor, index):
else:
op = builtin.IndexingMultiAxisVec(items=items)
(result,) = apply(op, tensor, *tensors)
if ret_scalar:
result.__wrapped__._data._isscalar = True
return result
......@@ -245,9 +242,9 @@ def setitem(tensor, index, value):
if not isinstance(value, (TensorBase, TensorWrapperBase)):
op = Const(value, dtype=tensor.dtype, device=tensor.device)
(value,) = op(tensor)
tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
for v in tensors:
if v.shape[0] == 0:
if len(v.shape) > 0 and v.shape[0] == 0:
return tensor
if use_subtensor:
op = builtin.Subtensor(items=items)
......
......@@ -102,8 +102,9 @@ class Graph(_imperative_rt.ComputingGraph):
class VarNode(TensorBase):
def __init__(self, node: _imperative_rt.VarNode):
def __init__(self, node: _imperative_rt.VarNode, isscalar=False):
self._node = node
self._isscalar = isscalar
if hasattr(self.graph, "_var_cache"):
self.graph._var_cache[node] = self
......
......@@ -33,8 +33,9 @@ class RawTensor(TensorBase):
_del_cb = None
_handle = None
def __init__(self, handle=None):
def __init__(self, handle=None, isscalar=False):
self._handle = handle
self._isscalar = isscalar
if handle is not None:
if self._init_cb:
self._init_cb()
......@@ -49,10 +50,15 @@ class RawTensor(TensorBase):
@property
def shape(self):
if self._isscalar:
return ()
return get_shape(self._handle)
def numpy(self):
return get_value(self._handle)
ret = get_value(self._handle)
if self._isscalar:
ret = ret.squeeze()
return ret
def _dev_tensor(self):
return _get_dev_tensor(self._handle)
......@@ -102,7 +108,7 @@ def _(array: np.ndarray, dtype=None, device=None):
device = None if device is None else as_device(device).to_c()
if 0 in array.strides:
array = array.squeeze().reshape(array.shape)
return RawTensor(put(array, dtype=dtype, device=device))
return RawTensor(put(array, dtype=dtype, device=device), isscalar=(array.ndim == 0))
@as_raw_tensor.register(RawTensor)
......
......@@ -21,7 +21,9 @@ from .indexing import getitem as _getitem
from .indexing import setitem as _setitem
from .raw_tensor import RawTensor, as_raw_tensor
from .tensor import Tensor
from .utils import isscalar
from .utils import make_shape_tuple as _make_shape_tuple
from .utils import setscalar
_ElwMod = Elemwise.Mode
......@@ -39,6 +41,13 @@ def _elwise(*args, mode):
)
args = utils.convert_inputs(*args)
(result,) = apply(op, *args)
_isscalar = True
for i in args:
if isscalar(i) == False:
_isscalar = False
break
if _isscalar:
setscalar(result)
return result
......@@ -153,6 +162,8 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis))
op = builtin.AxisAddRemove(param=param)
(result,) = apply(op, inp)
if len(axis) == inp.ndim:
setscalar(result)
return result
......@@ -189,6 +200,8 @@ def _reduce(mode):
if self.dtype == np.bool_:
if mode in ["MIN", "MAX"]:
result = result.astype("bool")
if axis is None or self.ndim == 1:
setscalar(result)
return result
return f
......@@ -321,9 +334,7 @@ class ArrayMethodMixin(abc.ABC):
__complex__ = lambda self: complex(self.item())
def __len__(self):
shape = self.shape
if use_symbolic_shape():
shape = shape.numpy()
shape = self.__wrapped__.shape
if shape:
return int(shape[0])
raise TypeError("ndim is 0")
......@@ -344,18 +355,17 @@ class ArrayMethodMixin(abc.ABC):
@property
def ndim(self):
shape = self.shape
if isinstance(shape, self.__class__):
# XXX: assume ndim is not changed during trace
ndim = shape.__wrapped__.shape[0]
return ndim
shape = self.__wrapped__.shape
if shape is None:
raise ValueError("unkown ndim")
return len(shape)
@property
def size(self):
if use_symbolic_shape():
return self.shape.prod()
shape = self.shape
if shape.__class__ is tuple:
return np.prod(self.shape).item()
return shape.prod()
@property
def T(self):
......@@ -416,8 +426,8 @@ class ArrayMethodMixin(abc.ABC):
.. testoutput::
[2]
[10.]
2
10.
"""
return _reduce("SUM")(self, axis, keepdims)
......@@ -444,10 +454,10 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase):
@property
def shape(self):
if use_symbolic_shape():
shape = self.__wrapped__.shape
if shape == () or not use_symbolic_shape():
return shape
return apply(GetVarShape(), self)[0]
else:
return self.__wrapped__.shape
@property
def device(self):
......
......@@ -133,7 +133,9 @@ def concatenate(inputs, axis=0, *, device=None):
def astype(x, dtype):
dtype = np.dtype(dtype)
if not is_equal(x.dtype, dtype):
isscalar = x.__wrapped__._data._isscalar
(x,) = apply(builtin.TypeCvt(param=dtype), x)
x.__wrapped__._data._isscalar = isscalar
return x
......@@ -176,13 +178,29 @@ def result_type(*args):
def isscalar(x):
try:
return x.ndim == 0
except:
pass
if isinstance(x, TensorWrapperBase):
x = x.__wrapped__
if hasattr(x, "_isscalar"):
return x._isscalar
if isinstance(x, TensorBase):
return x._data._isscalar
return np.isscalar(x)
def setscalar(x):
if isinstance(x, TensorWrapperBase):
x = x.__wrapped__
if hasattr(x, "_isscalar"):
x._isscalar = True
elif isinstance(x, TensorBase):
x._data._isscalar = True
else:
raise NotImplementedError("Unsupport type {}".format(type(x)))
def astensor1d(x, *reference, dtype=None, device=None):
"""
Convert something to 1D tensor. Support following types
......@@ -195,8 +213,8 @@ def astensor1d(x, *reference, dtype=None, device=None):
except AttributeError:
pass
else:
if ndim != 1:
raise ValueError("ndim != 1: %d" % ndim)
if ndim != 0 and ndim != 1:
raise ValueError("ndim != 1 or 0, get : %d" % ndim)
if not isinstance(x, (TensorBase, TensorWrapperBase)):
(x,) = Const(x, dtype=dtype, device=device)(*reference)
return x
......@@ -216,7 +234,11 @@ def astensor1d(x, *reference, dtype=None, device=None):
def _expand_int(s, i):
if isinstance(i, (TensorBase, TensorWrapperBase)):
s += list(i.numpy())
i_np = i.numpy()
if i_np.ndim == 0:
s.append(int(i_np))
else:
s += list(i_np)
return
if isinstance(i, Iterable):
for ii in i:
......
......@@ -63,8 +63,12 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
"""
op = ParamPackSplit()
op.offsets = offsets
op.shapes = shapes
return apply(op, inp)
op.shapes = [s or (1,) for s in shapes]
outputs = apply(op, inp)
for s, x in zip(shapes, outputs):
if not s:
x._isscalar = True
return outputs
def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
......
......@@ -13,6 +13,7 @@ from ..core.ops import builtin
from ..core.ops.builtin import Elemwise
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import apply
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..jit.tracing import is_tracing
from ..tensor import Tensor
......@@ -105,7 +106,14 @@ def _elwise(*args, mode):
args = utils.convert_inputs(*args)
if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"):
args = tuple(map(lambda x: x.astype("float32"), args))
_isscalar = True
for i in args:
if isscalar(i) == False:
_isscalar = False
break
(result,) = apply(op, *args)
if _isscalar:
setscalar(result)
return result
......
......@@ -63,7 +63,7 @@ def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
.. testoutput::
[2.75]
2.75
"""
diff = pred - label
......@@ -115,7 +115,7 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor:
.. testoutput::
[9.75]
9.75
"""
diff = pred - label
......@@ -170,7 +170,7 @@ def cross_entropy(
.. testoutput::
[0.6931]
0.6931
"""
n0 = pred.ndim
......@@ -226,7 +226,7 @@ def binary_cross_entropy(
.. testoutput::
[0.6931]
0.6931
"""
if not with_logits:
......@@ -265,7 +265,7 @@ def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
.. testoutput::
[1.5]
1.5
"""
assert norm in ["L1", "L2"], "norm must be L1 or L2"
......
......@@ -155,7 +155,7 @@ def sum(
.. testoutput::
[21]
21
"""
return inp.sum(axis=axis, keepdims=keepdims)
......@@ -189,7 +189,7 @@ def prod(
.. testoutput::
[720]
720
"""
return inp.prod(axis=axis, keepdims=keepdims)
......@@ -226,7 +226,7 @@ def mean(
.. testoutput::
[3.5]
3.5
"""
return inp.mean(axis=axis, keepdims=keepdims)
......@@ -263,7 +263,7 @@ def var(
.. testoutput::
[2.9167]
2.9167
"""
if axis is None:
m = mean(inp, axis=axis, keepdims=False)
......@@ -340,7 +340,7 @@ def min(
.. testoutput::
[1]
1
"""
return inp.min(axis=axis, keepdims=keepdims)
......@@ -377,7 +377,7 @@ def max(
.. testoutput::
[6]
6
"""
return inp.max(axis=axis, keepdims=keepdims)
......@@ -412,7 +412,7 @@ def norm(
.. testoutput::
[4.3589]
4.3589
"""
if axis is None:
......@@ -460,7 +460,7 @@ def argmin(
.. testoutput::
[0]
0
"""
if isinstance(axis, collections.abc.Iterable):
......@@ -519,7 +519,7 @@ def argmax(
.. testoutput::
[5]
5
"""
if isinstance(axis, collections.abc.Iterable):
......
......@@ -111,6 +111,8 @@ def full(shape, value, dtype="float32", device=None):
(x,) = Const(value, dtype=dtype, device=device)(
Tensor(value, dtype=dtype, device=device)
)
if len(shape) == 0: # scalar
return x
return broadcast_to(x, shape)
......
......@@ -53,7 +53,7 @@ def topk_accuracy(
.. testoutput::
[0.] [0.375]
0.0 0.375
"""
if isinstance(topk, int):
topk = (topk,)
......
......@@ -168,8 +168,6 @@ class trace:
self._output_bindings = None
self._output_names = None
set_symbolic_shape(self._symbolic_shape)
def _new_handle(self):
handle = len(self._tinfo)
info = TensorInfo()
......@@ -368,6 +366,7 @@ class trace:
interrupted = False
def do_enter():
self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
self._set_active(True)
if self._untraced:
self._init_trace(self._symbolic)
......@@ -423,6 +422,8 @@ class trace:
apply.disable(apply_compiled_mode)
apply.disable(apply_const_compiled_mode)
self._set_active(False)
# Restore global variable
set_symbolic_shape(self._save_symbolic_shape)
def do_exit():
if not self._untraced and self._pc != len(self._seq):
......@@ -498,7 +499,7 @@ class trace:
opnode = info.data_setter = G.InputNode(
device=info.device,
dtype=info.dtype,
shape=info.shape,
shape=info.shape or (1,),
graph=graph,
use_static_shape=_input_node_use_static_shape(),
)
......@@ -544,7 +545,7 @@ class trace:
*links,
device=info.device,
dtype=info.dtype,
shape=info.shape,
shape=info.shape or (1,),
graph=graph,
use_static_shape=_input_node_use_static_shape(),
)
......@@ -719,13 +720,13 @@ class trace:
h2v[h] = graph.make_h2d(
dtype=info.dtype,
device=dumped_device,
shape=info.shape,
shape=info.shape or (1,),
name=arg_names[i] if arg_names else None,
)
for k, h in self._kwarg_bindings.items():
info = self._tinfo[h]
h2v[h] = graph.make_h2d(
dtype=info.dtype, device=dumped_device, shape=info.shape, name=k
dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
)
for op, ihandles, ohandles in self._seq:
......@@ -919,6 +920,7 @@ class CompiledTensorProxy(RawTensor):
def __init__(self, handle):
self.__handle = handle
self._isscalar = False
self.__info = active_trace._tinfo[handle]
self.__shape = None
self.__data = None
......@@ -934,6 +936,8 @@ class CompiledTensorProxy(RawTensor):
@property
def shape(self):
if self._isscalar:
return ()
if self.__shape is None:
if self.__info.shape_read:
self.__shape = self.__info.shape_reader.get_value().shape
......@@ -951,6 +955,8 @@ class CompiledTensorProxy(RawTensor):
self.__value = self._dev_tensor().numpy()
else:
raise TraceMismatchError("value of this tensor is not read in trace")
if self._isscalar:
self.__value = self.__value.squeeze()
return self.__value
def _dev_tensor(self):
......@@ -970,9 +976,10 @@ class CompiledTensorProxy(RawTensor):
class LazyEvalTensor(RawTensor):
def __init__(self, varnode):
super(LazyEvalTensor, self).__init__()
def __init__(self, varnode, isscalar=False):
super().__init__()
self.__varnode = varnode
self._isscalar = isscalar
@property
def dtype(self):
......@@ -984,10 +991,15 @@ class LazyEvalTensor(RawTensor):
@property
def shape(self):
if self._isscalar:
return ()
return self.__varnode.shape
def numpy(self):
return self.__varnode.value
ret = self.__varnode.value
if self._isscalar:
ret = ret.squeeze()
return ret
def _dev_tensor(self):
raise RuntimeError("cannot access data during symbolic tracing")
......@@ -1041,10 +1053,12 @@ class TracedLazyTensor(TraceMixin, LazyEvalTensor):
def assign_raw_tensor(lhs, rhs):
handle = rhs._handle
# Keep isscalar of lhs
isscalar = lhs._isscalar
rhs.__dict__.clear()
lhs.__dict__.clear()
lhs.__class__ = RawTensor
lhs.__init__(handle)
lhs.__init__(handle, isscalar=isscalar)
# this hook turns RawTensor into LazyEvalTensor
......@@ -1060,7 +1074,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
data_setter = G.InputNode(
device=x.device,
dtype=x.dtype,
shape=x.shape,
shape=x.shape or (1,),
graph=graph,
use_static_shape=True,
)
......@@ -1091,7 +1105,9 @@ apply.disable(apply_symbolic_mode)
@apply.register()
def apply_const_symbolic_mode(op: Const, *args: RawTensor):
graph = active_trace._lazy_eval_graph
ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
ret = LazyEvalTensor(
graph.make_const(op.value, dtype=op.dtype, device=op.device), isscalar=True
)
active_trace._lazy_eval_tensors.add(ret)
return (ret,)
......
......@@ -46,9 +46,9 @@ class Observer(Module):
def get_dtype(self):
q_dict = self.get_qparams()
numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0]
numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()
numpy_zero_point = (
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0]
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()
)
return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
......
......@@ -18,7 +18,7 @@ from megengine.module import Module
class Simple(Module):
def __init__(self):
super().__init__()
self.a = Parameter(1.0, dtype=np.float32)
self.a = Parameter([1.0], dtype=np.float32)
def forward(self, x, y):
x = x[y] * self.a
......@@ -28,7 +28,7 @@ class Simple(Module):
class Simple2(Module):
def __init__(self):
super().__init__()
self.a = Parameter(1.0, dtype=np.float32)
self.a = Parameter([1.0], dtype=np.float32)
def forward(self, x):
x = x[1, ..., :, 0:4:2, 0:2] * self.a
......
......@@ -18,7 +18,7 @@ from megengine.module import Module
class Simple(Module):
def __init__(self):
super().__init__()
self.a = Parameter(1.0, dtype=np.float32)
self.a = Parameter([1.0], dtype=np.float32)
def forward(self, x):
x = x[:, 0] * self.a
......
......@@ -18,8 +18,8 @@ from megengine.module import Module
class Simple(Module):
def __init__(self):
super().__init__()
self.a = Parameter(1.0, dtype=np.float32)
self.b = Parameter(1.0, dtype=np.float32)
self.a = Parameter([1.0], dtype=np.float32)
self.b = Parameter([1.0], dtype=np.float32)
def forward(self, x):
x = x * self.a
......
......@@ -21,7 +21,7 @@ from megengine.module import Module
class Simple(Module):
def __init__(self):
super().__init__()
self.a = Parameter(1.23, dtype=np.float32)
self.a = Parameter([1.23], dtype=np.float32)
def forward(self, x):
x = x * self.a
......
......@@ -18,7 +18,7 @@ from megengine.optimizer import SGD, MultiStepLR
class Simple(Module):
def __init__(self):
super().__init__()
self.a = Parameter(1.23, dtype=np.float32)
self.a = Parameter([1.23], dtype=np.float32)
def forward(self, x):
x = x * self.a
......
......@@ -32,7 +32,7 @@ class MLP(Module):
class Simple(Module):
def __init__(self):
super().__init__()
self.a = Parameter(1.23, dtype=np.float32)
self.a = Parameter([1.23], dtype=np.float32)
def forward(self, x):
x = x * self.a
......
......@@ -11,7 +11,7 @@ from megengine.module import Module
class Simple(Module):
def __init__(self):
super().__init__()
self.a = Parameter(1.23, dtype=np.float32)
self.a = Parameter([1.23], dtype=np.float32)
def forward(self, x):
x = x * self.a
......
......@@ -19,7 +19,7 @@ from megengine.module import Module
class Simple(Module):
def __init__(self):
super().__init__()
self.a = Parameter(1.23, dtype=np.float32)
self.a = Parameter([1.23], dtype=np.float32)
def forward(self, x):
x = x * self.a
......
......@@ -107,7 +107,7 @@ def test_xornet_trace_dump():
if step % 50 == 0:
minibatch = next(val_dataset)
_, loss = val_fun(data, label)
loss = loss.numpy()[0]
loss = loss.numpy()
val_loss.append((step, loss))
print("Step: {} loss={}".format(step, loss))
opt.step()
......
......@@ -449,7 +449,7 @@ def test_advance_indexing_high_level():
y = np.array([1, 2])
yy = Tensor(y)
np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
# np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) # FIXME
np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
np.testing.assert_equal(x[:, y], xx[:, y].numpy())
np.testing.assert_equal(x[:, y], xx[:, yy].numpy())
......@@ -469,10 +469,9 @@ def test_advance_indexing_high_level():
y = np.array([1])
yy = Tensor(y)
np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
# np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) # FIXME
np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
np.testing.assert_equal(x[:, y], xx[:, y].numpy())
# XXX: no way to tell whether yy is scalar or ndim=1 array
np.testing.assert_equal(x[:, y], xx[:, yy].numpy())
x = np.arange(9).reshape(3, 3).astype("int32")
......
......@@ -21,6 +21,7 @@ from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.core.tensor.utils import isscalar
from megengine.functional import exp, log
from megengine.jit import exclude_from_trace, trace
from megengine.random import normal, uniform
......@@ -263,20 +264,21 @@ def test_optimize_for_inference_broadcast():
def test_trace_cvt_bool():
set_symbolic_shape(True)
x = tensor([0], dtype=np.int32)
@trace(symbolic=True)
def f(x):
return x.shape[0] == 0
a = x.shape
b = a[0]
assert isscalar(b)
return b == 0
for i in range(3):
np.testing.assert_equal(f(x).numpy()[0], False)
np.testing.assert_equal(f(x).numpy(), False)
def test_trace_reshape():
for symbolic in [False, True]:
set_symbolic_shape(True)
x1 = tensor(np.random.randn(2, 10, 10))
x2 = tensor(np.random.randn(4, 10, 10))
x3 = tensor(np.random.randn(8, 10, 10))
......@@ -359,7 +361,6 @@ def test_raise_on_trace():
def test_trace_broadcast():
for symbolic in [False, True]:
set_symbolic_shape(True)
x1 = tensor(np.random.randn(3, 1, 1))
x2 = tensor(np.random.randn(1, 4, 1))
x3 = tensor(np.random.randn(1, 1, 5))
......@@ -397,7 +398,6 @@ def test_trace_nms():
def test_trace_valid_broadcast():
set_symbolic_shape(True)
x1 = tensor(np.random.randn(1, 1))
x2 = tensor(np.random.randn(1, 2))
shape = (tensor([2]), tensor([2]))
......
import numpy as np
import megengine.functional as F
from megengine import Tensor
from megengine.core._trace_option import use_symbolic_shape
def test_zero_dim():
a = Tensor(1)
a_np = np.array(1, dtype=np.int32)
np.testing.assert_equal(a, a_np)
if use_symbolic_shape():
np.testing.assert_equal(a.shape, np.array(a_np.shape))
else:
np.testing.assert_equal(a.shape, a_np.shape)
def test_sum():
a = Tensor([1, 2])
a = a.reshape((1, 2))
assert a.sum().ndim == 0
assert a.sum(axis=1).ndim == 1
def test_max():
a = Tensor([1, 2])
a = a.reshape((1, 2))
assert a.max().ndim == 0
assert a.max(axis=1).ndim == 1
def test_reshape():
a = Tensor(1)
a = a.reshape((1, 1))
def test_squeeze():
a = Tensor(1)
a = a.reshape((1, 1))
assert F.squeeze(a).ndim == 0
def test_elemementwise():
a = Tensor(1.0)
assert F.exp(a).ndim == 0
assert (a + a).ndim == 0
assert (a + 1).ndim == 0
def test_astype():
a = Tensor(1.0)
assert a.astype("int32").ndim == 0
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册