diff --git a/imperative/python/megengine/core/_trace_option.py b/imperative/python/megengine/core/_trace_option.py
index 9a32e1ed43b65737d86312f24da82a0224b6e0e0..4e95222063717bf599bd6600342c142d061f8f8d 100644
--- a/imperative/python/megengine/core/_trace_option.py
+++ b/imperative/python/megengine/core/_trace_option.py
@@ -26,4 +26,6 @@ def set_symbolic_shape(option: bool):
     """ Sets whether tensor.shape returns a tensor instead of a tuple
     """
     global _use_symbolic_shape
+    _org = _use_symbolic_shape
     _use_symbolic_shape = option
+    return _org
diff --git a/imperative/python/megengine/core/tensor/indexing.py b/imperative/python/megengine/core/tensor/indexing.py
index bcb78b1883a8d1b6060c268943e800d122be5c8c..7f2934b84c95d8fe16879990439d1d19718f3af4 100644
--- a/imperative/python/megengine/core/tensor/indexing.py
+++ b/imperative/python/megengine/core/tensor/indexing.py
@@ -14,7 +14,7 @@ from .._trace_option import use_symbolic_shape
 from ..ops import builtin
 from ..ops.special import Const
 from .core import TensorBase, TensorWrapperBase, apply
-from .utils import astensor1d, make_shape_tuple
+from .utils import astensor1d, isscalar, make_shape_tuple
 
 
 def remove_ellipsis(tensor, tuple_val):
@@ -89,9 +89,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
     if not isinstance(tuple_val, tuple):
         tuple_val = (tuple_val,)
     ndim_indexed = 0
+    ndim_indexed_scalar = 0
     for i in tuple_val:
         if not i is Ellipsis:
             ndim_indexed += 1 if not hasattr(i, "ndim") else i.ndim
+            if isscalar(i):
+                ndim_indexed_scalar += 1
+
     if ndim_indexed > inp.ndim:
         raise IndexError(
             "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
@@ -103,15 +107,6 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
         use_subtensor = True
     inp, tuple_val = check_bool_index(inp, tuple_val)
 
-    def is_scalar(d):
-        if isinstance(i, int):
-            return True
-        if type(d).__module__ == np.__name__:
-            return np.isscalar(d)
-        # if isinstance(d, (TensorBase, TensorWrapperBase)):
-        #     return d.shape == (1,)
-        return False
-
     new_axes = []
     tensors = []
     items = []
@@ -134,7 +129,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
             continue
 
         if (
-            not is_scalar(i)
+            not isscalar(i)
            and not i is np.newaxis
            and not i is Ellipsis
            and not isinstance(i, slice)
@@ -191,7 +186,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
             items.append(item)
     if new_axes:
         raise IndexError("newaxis is not allowed here")
-    return inp, tensors, items, use_subtensor
+    return inp, tensors, items, use_subtensor, ndim_indexed_scalar == inp.ndim
 
 
 def try_condtake(tensor, index):
@@ -217,11 +212,11 @@ def getitem(tensor, index):
     try_result = try_condtake(tensor, index)
     if len(try_result) == 2:
         return try_result[0]
-    tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
+    tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index)
     for v in tensors:
         if isinstance(v.shape, v.__class__):
             break
-        if v.shape[0] == 0:
+        if len(v.shape) > 0 and v.shape[0] == 0:
             (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)(
                 tensor
             )
@@ -231,6 +226,8 @@ def getitem(tensor, index):
     else:
         op = builtin.IndexingMultiAxisVec(items=items)
     (result,) = apply(op, tensor, *tensors)
+    if ret_scalar:
+        result.__wrapped__._data._isscalar = True
     return result
 
 
@@ -245,9 +242,9 @@ def setitem(tensor, index, value):
     if not isinstance(value, (TensorBase, TensorWrapperBase)):
         op = Const(value, dtype=tensor.dtype, device=tensor.device)
         (value,) = op(tensor)
-    tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
+    tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
     for v in tensors:
-        if v.shape[0] == 0:
+        if len(v.shape) > 0 and v.shape[0] == 0:
             return tensor
     if use_subtensor:
         op = builtin.Subtensor(items=items)
diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py
index 2cb6e136ab4ae9367b14b91d7943a3d28175d2c6..abe5e98c2abe83fa4d82c4fe68daa6fc9ca394e0 100644
--- a/imperative/python/megengine/core/tensor/megbrain_graph.py
+++ b/imperative/python/megengine/core/tensor/megbrain_graph.py
@@ -102,8 +102,9 @@ class Graph(_imperative_rt.ComputingGraph):
 
 
 class VarNode(TensorBase):
-    def __init__(self, node: _imperative_rt.VarNode):
+    def __init__(self, node: _imperative_rt.VarNode, isscalar=False):
         self._node = node
+        self._isscalar = isscalar
         if hasattr(self.graph, "_var_cache"):
             self.graph._var_cache[node] = self
diff --git a/imperative/python/megengine/core/tensor/raw_tensor/__init__.py b/imperative/python/megengine/core/tensor/raw_tensor/__init__.py
index ca62b105be7e76ea827fef4fd04ca3989440beb0..bc3c8f9f8641d0edaea657c88df312da0eec66bb 100644
--- a/imperative/python/megengine/core/tensor/raw_tensor/__init__.py
+++ b/imperative/python/megengine/core/tensor/raw_tensor/__init__.py
@@ -33,8 +33,9 @@ class RawTensor(TensorBase):
     _del_cb = None
     _handle = None
 
-    def __init__(self, handle=None):
+    def __init__(self, handle=None, isscalar=False):
         self._handle = handle
+        self._isscalar = isscalar
         if handle is not None:
             if self._init_cb:
                 self._init_cb()
@@ -49,10 +50,15 @@ class RawTensor(TensorBase):
 
     @property
     def shape(self):
+        if self._isscalar:
+            return ()
         return get_shape(self._handle)
 
     def numpy(self):
-        return get_value(self._handle)
+        ret = get_value(self._handle)
+        if self._isscalar:
+            ret = ret.squeeze()
+        return ret
 
     def _dev_tensor(self):
         return _get_dev_tensor(self._handle)
@@ -102,7 +108,7 @@ def _(array: np.ndarray, dtype=None, device=None):
     device = None if device is None else as_device(device).to_c()
     if 0 in array.strides:
         array = array.squeeze().reshape(array.shape)
-    return RawTensor(put(array, dtype=dtype, device=device))
+    return RawTensor(put(array, dtype=dtype, device=device), isscalar=(array.ndim == 0))
 
 
 @as_raw_tensor.register(RawTensor)
diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py
index 9bf4bd1fda5dcdf376795d825bfd93f919da7797..5d8cdf846c03811dfefee8aa9049d89bfb77fc55 100644
--- a/imperative/python/megengine/core/tensor/tensor_wrapper.py
+++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py
@@ -21,7 +21,9 @@ from .indexing import getitem as _getitem
 from .indexing import setitem as _setitem
 from .raw_tensor import RawTensor, as_raw_tensor
 from .tensor import Tensor
+from .utils import isscalar
 from .utils import make_shape_tuple as _make_shape_tuple
+from .utils import setscalar
 
 _ElwMod = Elemwise.Mode
@@ -39,6 +41,13 @@ def _elwise(*args, mode):
         )
     args = utils.convert_inputs(*args)
     (result,) = apply(op, *args)
+    _isscalar = True
+    for i in args:
+        if not isscalar(i):
+            _isscalar = False
+            break
+    if _isscalar:
+        setscalar(result)
     return result
@@ -153,6 +162,8 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
     param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis))
     op = builtin.AxisAddRemove(param=param)
     (result,) = apply(op, inp)
+    if len(axis) == inp.ndim:
+        setscalar(result)
     return result
@@ -189,6 +200,8 @@ def _reduce(mode):
         if self.dtype == np.bool_:
             if mode in ["MIN", "MAX"]:
                 result = result.astype("bool")
+        if axis is None or self.ndim == 1:
+            setscalar(result)
         return result
 
     return f
@@ -321,9 +334,7 @@ class ArrayMethodMixin(abc.ABC):
     __complex__ = lambda self: complex(self.item())
 
     def __len__(self):
-        shape = self.shape
-        if use_symbolic_shape():
-            shape = shape.numpy()
+        shape = self.__wrapped__.shape
         if shape:
             return int(shape[0])
         raise TypeError("ndim is 0")
@@ -344,18 +355,17 @@ class ArrayMethodMixin(abc.ABC):
 
     @property
     def ndim(self):
-        shape = self.shape
-        if isinstance(shape, self.__class__):
-            # XXX: assume ndim is not changed during trace
-            ndim = shape.__wrapped__.shape[0]
-            return ndim
+        shape = self.__wrapped__.shape
+        if shape is None:
+            raise ValueError("unknown ndim")
         return len(shape)
 
     @property
     def size(self):
-        if use_symbolic_shape():
-            return self.shape.prod()
-        return np.prod(self.shape).item()
+        shape = self.shape
+        if shape.__class__ is tuple:
+            return np.prod(self.shape).item()
+        return shape.prod()
 
     @property
     def T(self):
@@ -416,8 +426,8 @@ class ArrayMethodMixin(abc.ABC):
 
         .. testoutput::
 
-            [2]
-            [10.]
+            2
+            10.
 
         """
         return _reduce("SUM")(self, axis, keepdims)
@@ -444,10 +454,10 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase):
 
     @property
     def shape(self):
-        if use_symbolic_shape():
-            return apply(GetVarShape(), self)[0]
-        else:
-            return self.__wrapped__.shape
+        shape = self.__wrapped__.shape
+        if shape == () or not use_symbolic_shape():
+            return shape
+        return apply(GetVarShape(), self)[0]
 
     @property
     def device(self):
diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py
index 703ad2d2c88e00fb213fa7516f8f8ae4322cd16b..49055d5f8d20919f4861d6727a43389f67e71fa9 100644
--- a/imperative/python/megengine/core/tensor/utils.py
+++ b/imperative/python/megengine/core/tensor/utils.py
@@ -133,7 +133,9 @@ def concatenate(inputs, axis=0, *, device=None):
 
 def astype(x, dtype):
     dtype = np.dtype(dtype)
     if not is_equal(x.dtype, dtype):
+        isscalar = x.__wrapped__._data._isscalar
         (x,) = apply(builtin.TypeCvt(param=dtype), x)
+        x.__wrapped__._data._isscalar = isscalar
     return x
@@ -176,13 +178,29 @@ def result_type(*args):
 
 
 def isscalar(x):
-    try:
-        return x.ndim == 0
-    except:
-        pass
+    if isinstance(x, TensorWrapperBase):
+        x = x.__wrapped__
+
+    if hasattr(x, "_isscalar"):
+        return x._isscalar
+    if isinstance(x, TensorBase):
+        return x._data._isscalar
+
     return np.isscalar(x)
 
 
+def setscalar(x):
+    if isinstance(x, TensorWrapperBase):
+        x = x.__wrapped__
+
+    if hasattr(x, "_isscalar"):
+        x._isscalar = True
+    elif isinstance(x, TensorBase):
+        x._data._isscalar = True
+    else:
+        raise NotImplementedError("Unsupported type {}".format(type(x)))
+
+
 def astensor1d(x, *reference, dtype=None, device=None):
     """
     Convert something to 1D tensor. Support following types
@@ -195,8 +213,8 @@ def astensor1d(x, *reference, dtype=None, device=None):
     except AttributeError:
         pass
     else:
-        if ndim != 1:
-            raise ValueError("ndim != 1: %d" % ndim)
+        if ndim != 0 and ndim != 1:
+            raise ValueError("ndim must be 0 or 1, got: %d" % ndim)
         if not isinstance(x, (TensorBase, TensorWrapperBase)):
             (x,) = Const(x, dtype=dtype, device=device)(*reference)
         return x
@@ -216,7 +234,11 @@ def astensor1d(x, *reference, dtype=None, device=None):
 
 def _expand_int(s, i):
     if isinstance(i, (TensorBase, TensorWrapperBase)):
-        s += list(i.numpy())
+        i_np = i.numpy()
+        if i_np.ndim == 0:
+            s.append(int(i_np))
+        else:
+            s += list(i_np)
         return
     if isinstance(i, Iterable):
         for ii in i:
diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py
index d990cc263422fbcdb4d4575d14440a212ab59570..f8269251e353f51dc476c1e42fa0a37029166448 100644
--- a/imperative/python/megengine/distributed/helper.py
+++ b/imperative/python/megengine/distributed/helper.py
@@ -63,8 +63,12 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
     """
     op = ParamPackSplit()
     op.offsets = offsets
-    op.shapes = shapes
-    return apply(op, inp)
+    op.shapes = [s or (1,) for s in shapes]
+    outputs = apply(op, inp)
+    for s, x in zip(shapes, outputs):
+        if not s:
+            x._isscalar = True
+    return outputs
 
 
 def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py
index 98abebce98ab848cb2f231d31be4a3de6978c622..9ebce0a952c373ad00861da47b8cd8459bd9b455 100644
--- a/imperative/python/megengine/functional/elemwise.py
+++ b/imperative/python/megengine/functional/elemwise.py
@@ -13,6 +13,7 @@ from ..core.ops import builtin
 from ..core.ops.builtin import Elemwise
 from ..core.tensor import megbrain_graph, utils
 from ..core.tensor.core import apply
+from ..core.tensor.utils import isscalar, setscalar
 from ..device import get_default_device
 from ..jit.tracing import is_tracing
 from ..tensor import Tensor
@@ -105,7 +106,14 @@ def _elwise(*args, mode):
     args = utils.convert_inputs(*args)
     if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"):
         args = tuple(map(lambda x: x.astype("float32"), args))
+    _isscalar = True
+    for i in args:
+        if not isscalar(i):
+            _isscalar = False
+            break
     (result,) = apply(op, *args)
+    if _isscalar:
+        setscalar(result)
     return result
diff --git a/imperative/python/megengine/functional/loss.py b/imperative/python/megengine/functional/loss.py
index b5613758a904f3f3e6f9ab515f43afdb6cf54842..9421ec67fe1cc6020fde9c5db8a6a68bf3e93dc5 100644
--- a/imperative/python/megengine/functional/loss.py
+++ b/imperative/python/megengine/functional/loss.py
@@ -63,7 +63,7 @@ def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
 
     .. testoutput::
 
-        [2.75]
+        2.75
 
     """
     diff = pred - label
@@ -115,7 +115,7 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor:
 
     .. testoutput::
 
-        [9.75]
+        9.75
 
     """
     diff = pred - label
@@ -170,7 +170,7 @@ def cross_entropy(
 
     .. testoutput::
 
-        [0.6931]
+        0.6931
 
     """
     n0 = pred.ndim
@@ -226,7 +226,7 @@ def binary_cross_entropy(
 
     .. testoutput::
 
-        [0.6931]
+        0.6931
 
     """
     if not with_logits:
@@ -265,7 +265,7 @@ def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
 
     .. testoutput::
 
-        [1.5]
+        1.5
 
     """
     assert norm in ["L1", "L2"], "norm must be L1 or L2"
diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py
index 60c3c89929dacc468dc87a05c82b06f49157d57c..e89862a1f298907988bfd5b0dcd0a3552c14cec0 100644
--- a/imperative/python/megengine/functional/math.py
+++ b/imperative/python/megengine/functional/math.py
@@ -155,7 +155,7 @@ def sum(
 
     .. testoutput::
 
-        [21]
+        21
 
     """
     return inp.sum(axis=axis, keepdims=keepdims)
@@ -189,7 +189,7 @@ def prod(
 
     .. testoutput::
 
-        [720]
+        720
 
     """
     return inp.prod(axis=axis, keepdims=keepdims)
@@ -226,7 +226,7 @@ def mean(
 
     .. testoutput::
 
-        [3.5]
+        3.5
 
     """
     return inp.mean(axis=axis, keepdims=keepdims)
@@ -263,7 +263,7 @@ def var(
 
    .. testoutput::
 
-        [2.9167]
+        2.9167
 
     """
     if axis is None:
         m = mean(inp, axis=axis, keepdims=False)
@@ -340,7 +340,7 @@ def min(
 
     .. testoutput::
 
-        [1]
+        1
 
     """
     return inp.min(axis=axis, keepdims=keepdims)
@@ -377,7 +377,7 @@ def max(
 
     .. testoutput::
 
-        [6]
+        6
 
     """
     return inp.max(axis=axis, keepdims=keepdims)
@@ -412,7 +412,7 @@ def norm(
 
     .. testoutput::
 
-        [4.3589]
+        4.3589
 
     """
     if axis is None:
@@ -460,7 +460,7 @@ def argmin(
 
     .. testoutput::
 
-        [0]
+        0
 
     """
     if isinstance(axis, collections.abc.Iterable):
@@ -519,7 +519,7 @@ def argmax(
 
     .. testoutput::
 
-        [5]
+        5
 
     """
     if isinstance(axis, collections.abc.Iterable):
diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py
index 60ccf66b4fc983d1b7de8b39f28dc44eaa1e48b3..05c87a8ad1059b3142ca5eeda498acb205269ed2 100644
--- a/imperative/python/megengine/functional/tensor.py
+++ b/imperative/python/megengine/functional/tensor.py
@@ -111,6 +111,8 @@ def full(shape, value, dtype="float32", device=None):
     (x,) = Const(value, dtype=dtype, device=device)(
         Tensor(value, dtype=dtype, device=device)
     )
+    if len(shape) == 0:  # scalar
+        return x
     return broadcast_to(x, shape)
diff --git a/imperative/python/megengine/functional/utils.py b/imperative/python/megengine/functional/utils.py
index fa38e8b1fc7ddd24db0cdb77007b9548e2476529..a8e6b0e02e91e9745147e6904a7bdf1811751a9a 100644
--- a/imperative/python/megengine/functional/utils.py
+++ b/imperative/python/megengine/functional/utils.py
@@ -53,7 +53,7 @@ def topk_accuracy(
 
     .. testoutput::
 
-        [0.] [0.375]
+        0.0 0.375
 
     """
     if isinstance(topk, int):
         topk = (topk,)
diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index ae211a5af346fb03d6b1bd86ac38b0cbc93b2bc6..a6450919bc19c3c05d32d84543907fc697aa3aea 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -168,8 +168,6 @@ class trace:
         self._output_bindings = None
         self._output_names = None
 
-        set_symbolic_shape(self._symbolic_shape)
-
     def _new_handle(self):
         handle = len(self._tinfo)
         info = TensorInfo()
@@ -368,6 +366,7 @@ class trace:
         interrupted = False
 
         def do_enter():
+            self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
             self._set_active(True)
             if self._untraced:
                 self._init_trace(self._symbolic)
@@ -423,6 +422,8 @@ class trace:
             apply.disable(apply_compiled_mode)
             apply.disable(apply_const_compiled_mode)
             self._set_active(False)
+            # restore the previous global symbolic shape setting
+            set_symbolic_shape(self._save_symbolic_shape)
 
         def do_exit():
             if not self._untraced and self._pc != len(self._seq):
@@ -498,7 +499,7 @@ class trace:
                 opnode = info.data_setter = G.InputNode(
                     device=info.device,
                     dtype=info.dtype,
-                    shape=info.shape,
+                    shape=info.shape or (1,),
                     graph=graph,
                     use_static_shape=_input_node_use_static_shape(),
                 )
@@ -544,7 +545,7 @@ class trace:
                     *links,
                     device=info.device,
                     dtype=info.dtype,
-                    shape=info.shape,
+                    shape=info.shape or (1,),
                     graph=graph,
                     use_static_shape=_input_node_use_static_shape(),
                 )
@@ -719,13 +720,13 @@ class trace:
                 h2v[h] = graph.make_h2d(
                     dtype=info.dtype,
                     device=dumped_device,
-                    shape=info.shape,
+                    shape=info.shape or (1,),
                     name=arg_names[i] if arg_names else None,
                 )
             for k, h in self._kwarg_bindings.items():
                 info = self._tinfo[h]
                 h2v[h] = graph.make_h2d(
-                    dtype=info.dtype, device=dumped_device, shape=info.shape, name=k
+                    dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
                 )
 
             for op, ihandles, ohandles in self._seq:
@@ -919,6 +920,7 @@ class CompiledTensorProxy(RawTensor):
 
     def __init__(self, handle):
         self.__handle = handle
+        self._isscalar = False
         self.__info = active_trace._tinfo[handle]
         self.__shape = None
         self.__data = None
@@ -934,6 +936,8 @@ class CompiledTensorProxy(RawTensor):
 
     @property
     def shape(self):
+        if self._isscalar:
+            return ()
         if self.__shape is None:
             if self.__info.shape_read:
                 self.__shape = self.__info.shape_reader.get_value().shape
@@ -951,6 +955,8 @@ class CompiledTensorProxy(RawTensor):
                 self.__value = self._dev_tensor().numpy()
             else:
                 raise TraceMismatchError("value of this tensor is not read in trace")
+        if self._isscalar:
+            self.__value = self.__value.squeeze()
         return self.__value
 
     def _dev_tensor(self):
@@ -970,9 +976,10 @@ class CompiledTensorProxy(RawTensor):
 
 
 class LazyEvalTensor(RawTensor):
-    def __init__(self, varnode):
-        super(LazyEvalTensor, self).__init__()
+    def __init__(self, varnode, isscalar=False):
+        super().__init__()
         self.__varnode = varnode
+        self._isscalar = isscalar
 
     @property
     def dtype(self):
@@ -984,10 +991,15 @@ class LazyEvalTensor(RawTensor):
 
     @property
     def shape(self):
+        if self._isscalar:
+            return ()
         return self.__varnode.shape
 
     def numpy(self):
-        return self.__varnode.value
+        ret = self.__varnode.value
+        if self._isscalar:
+            ret = ret.squeeze()
+        return ret
 
     def _dev_tensor(self):
         raise RuntimeError("cannot access data during symbolic tracing")
@@ -1041,10 +1053,12 @@ class TracedLazyTensor(TraceMixin, LazyEvalTensor):
 
 
 def assign_raw_tensor(lhs, rhs):
     handle = rhs._handle
+    # keep the isscalar flag of lhs across the reassignment
+    isscalar = lhs._isscalar
     rhs.__dict__.clear()
     lhs.__dict__.clear()
     lhs.__class__ = RawTensor
-    lhs.__init__(handle)
+    lhs.__init__(handle, isscalar=isscalar)
 
 
 # this hook turns RawTensor into LazyEvalTensor
@@ -1060,7 +1074,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
             data_setter = G.InputNode(
                 device=x.device,
                 dtype=x.dtype,
-                shape=x.shape,
+                shape=x.shape or (1,),
                 graph=graph,
                 use_static_shape=True,
             )
@@ -1091,7 +1105,9 @@ apply.disable(apply_symbolic_mode)
 @apply.register()
 def apply_const_symbolic_mode(op: Const, *args: RawTensor):
     graph = active_trace._lazy_eval_graph
-    ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
+    ret = LazyEvalTensor(
+        graph.make_const(op.value, dtype=op.dtype, device=op.device), isscalar=True
+    )
     active_trace._lazy_eval_tensors.add(ret)
     return (ret,)
diff --git a/imperative/python/megengine/quantization/observer.py b/imperative/python/megengine/quantization/observer.py
index b6239bafb8cae92d1cb97388a85602b92ca1717d..7d3f452e76ffcb0fb014ed3dac68aaf8795ad697 100644
--- a/imperative/python/megengine/quantization/observer.py
+++ b/imperative/python/megengine/quantization/observer.py
@@ -46,9 +46,9 @@ class Observer(Module):
 
     def get_dtype(self):
         q_dict = self.get_qparams()
-        numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0]
+        numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()
         numpy_zero_point = (
-            None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0]
+            None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()
         )
         return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
diff --git a/imperative/python/test/integration/test_advance_indexing.py b/imperative/python/test/integration/test_advance_indexing.py
index 2785292afc1a63172306034c036ec470507ca5c2..8088d9f9c320efe51ab2bacedadff692fcafd456 100644
--- a/imperative/python/test/integration/test_advance_indexing.py
+++ b/imperative/python/test/integration/test_advance_indexing.py
@@ -18,7 +18,7 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.0, dtype=np.float32)
+        self.a = Parameter([1.0], dtype=np.float32)
 
     def forward(self, x, y):
         x = x[y] * self.a
@@ -28,7 +28,7 @@ class Simple2(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.0, dtype=np.float32)
+        self.a = Parameter([1.0], dtype=np.float32)
 
     def forward(self, x):
         x = x[1, ..., :, 0:4:2, 0:2] * self.a
diff --git a/imperative/python/test/integration/test_ai.py b/imperative/python/test/integration/test_ai.py
index 89c6e86ef849895c65ce0d91182b1b12fd515d44..dd45616cd4c02a06a4ce0d277811adceb99c5f34 100644
--- a/imperative/python/test/integration/test_ai.py
+++ b/imperative/python/test/integration/test_ai.py
@@ -18,7 +18,7 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.0, dtype=np.float32)
+        self.a = Parameter([1.0], dtype=np.float32)
 
     def forward(self, x):
         x = x[:, 0] * self.a
diff --git a/imperative/python/test/integration/test_detach.py b/imperative/python/test/integration/test_detach.py
index 6bd9c890f5a0079cf5f175290dd006a76a801c23..b441425d856566430c0e7f6c3010e1cd18fc4c79 100644
--- a/imperative/python/test/integration/test_detach.py
+++ b/imperative/python/test/integration/test_detach.py
@@ -18,8 +18,8 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.0, dtype=np.float32)
-        self.b = Parameter(1.0, dtype=np.float32)
+        self.a = Parameter([1.0], dtype=np.float32)
+        self.b = Parameter([1.0], dtype=np.float32)
 
     def forward(self, x):
         x = x * self.a
diff --git a/imperative/python/test/integration/test_hello_world.py b/imperative/python/test/integration/test_hello_world.py
index 138518bf886fa029424bfb3367c9a5f3d0850ca5..d49221864cf26476162514109e8a04c7df6b597f 100644
--- a/imperative/python/test/integration/test_hello_world.py
+++ b/imperative/python/test/integration/test_hello_world.py
@@ -21,7 +21,7 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.23, dtype=np.float32)
+        self.a = Parameter([1.23], dtype=np.float32)
 
     def forward(self, x):
         x = x * self.a
diff --git a/imperative/python/test/integration/test_lr_scheduler.py b/imperative/python/test/integration/test_lr_scheduler.py
index a0f788f6887a946de5c12da526cbf32ce3f18e6c..c3f42af98728ed7805f7440d3a07a94d3c3413f7 100644
--- a/imperative/python/test/integration/test_lr_scheduler.py
+++ b/imperative/python/test/integration/test_lr_scheduler.py
@@ -18,7 +18,7 @@ from megengine.optimizer import SGD, MultiStepLR
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.23, dtype=np.float32)
+        self.a = Parameter([1.23], dtype=np.float32)
 
     def forward(self, x):
         x = x * self.a
diff --git a/imperative/python/test/integration/test_optimizer.py b/imperative/python/test/integration/test_optimizer.py
index d7375c31cbedbfb84b24878c0860f0ca7dcb02eb..38b6350e49785c884a18f3d241790632ff90afe6 100644
--- a/imperative/python/test/integration/test_optimizer.py
+++ b/imperative/python/test/integration/test_optimizer.py
@@ -32,7 +32,7 @@ class MLP(Module):
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.23, dtype=np.float32)
+        self.a = Parameter([1.23], dtype=np.float32)
 
     def forward(self, x):
         x = x * self.a
diff --git a/imperative/python/test/integration/test_save_load.py b/imperative/python/test/integration/test_save_load.py
index be93523f9c6af2148cb44d0b6ede69749c27db3c..f18cc8c429f842f41a72b3a040f6f5f291eccae7 100644
--- a/imperative/python/test/integration/test_save_load.py
+++ b/imperative/python/test/integration/test_save_load.py
@@ -11,7 +11,7 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.23, dtype=np.float32)
+        self.a = Parameter([1.23], dtype=np.float32)
 
     def forward(self, x):
         x = x * self.a
diff --git a/imperative/python/test/integration/test_sgd_momentum.py b/imperative/python/test/integration/test_sgd_momentum.py
index 6c8638773f30af763b763ed9c64035f8134b2bff..cf395fd24fddb567118e812dc91c28a45974c146 100644
--- a/imperative/python/test/integration/test_sgd_momentum.py
+++ b/imperative/python/test/integration/test_sgd_momentum.py
@@ -19,7 +19,7 @@ from megengine.module import Module
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter(1.23, dtype=np.float32)
+        self.a = Parameter([1.23], dtype=np.float32)
 
     def forward(self, x):
         x = x * self.a
diff --git a/imperative/python/test/integration/test_trace_dump.py b/imperative/python/test/integration/test_trace_dump.py
index 149148a4a497abd7b1b092fdf0211468c1853f1f..73bad2f28758526ca724a146268c01e44fac33f4 100644
--- a/imperative/python/test/integration/test_trace_dump.py
+++ b/imperative/python/test/integration/test_trace_dump.py
@@ -107,7 +107,7 @@ def test_xornet_trace_dump():
         if step % 50 == 0:
             minibatch = next(val_dataset)
             _, loss = val_fun(data, label)
-            loss = loss.numpy()[0]
+            loss = loss.numpy()
             val_loss.append((step, loss))
             print("Step: {} loss={}".format(step, loss))
         opt.step()
diff --git a/imperative/python/test/unit/core/test_indexing_op.py b/imperative/python/test/unit/core/test_indexing_op.py
index 746d35e8e1baa694d768ed2e2d7de795aedc1b2a..a336a28a45642978e26d077354dd4a5fa0784b12 100644
--- a/imperative/python/test/unit/core/test_indexing_op.py
+++ b/imperative/python/test/unit/core/test_indexing_op.py
@@ -449,7 +449,7 @@ def test_advance_indexing_high_level():
     y = np.array([1, 2])
     yy = Tensor(y)
     np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
-    # np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())  # FIXME
+    np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
     np.testing.assert_equal(x[:, y], xx[:, y].numpy())
     np.testing.assert_equal(x[:, y], xx[:, yy].numpy())
 
@@ -469,10 +469,9 @@ def test_advance_indexing_high_level():
     y = np.array([1])
     yy = Tensor(y)
     np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
-    # np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())  # FIXME
+    np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
     np.testing.assert_equal(x[:, y], xx[:, y].numpy())
 
-    # XXX: no way to tell whether yy is scalar or ndim=1 array
     np.testing.assert_equal(x[:, y], xx[:, yy].numpy())
 
     x = np.arange(9).reshape(3, 3).astype("int32")
diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py
index c6ccf55bfeb26a62d0628ea147e70db2dc93104f..862d98b367c417fb27ae608eec20af51532aabc7 100644
--- a/imperative/python/test/unit/test_tracing.py
+++ b/imperative/python/test/unit/test_tracing.py
@@ -21,6 +21,7 @@ from megengine.core.ops import builtin as ops
 from megengine.core.ops.builtin import Elemwise
 from megengine.core.tensor.core import apply
 from megengine.core.tensor.raw_tensor import as_raw_tensor
+from megengine.core.tensor.utils import isscalar
 from megengine.functional import exp, log
 from megengine.jit import exclude_from_trace, trace
 from megengine.random import normal, uniform
@@ -263,20 +264,21 @@ def test_optimize_for_inference_broadcast():
 
 
 def test_trace_cvt_bool():
-    set_symbolic_shape(True)
     x = tensor([0], dtype=np.int32)
 
     @trace(symbolic=True)
     def f(x):
-        return x.shape[0] == 0
+        a = x.shape
+        b = a[0]
+        assert isscalar(b)
+        return b == 0
 
     for i in range(3):
-        np.testing.assert_equal(f(x).numpy()[0], False)
+        np.testing.assert_equal(f(x).numpy(), False)
 
 
 def test_trace_reshape():
     for symbolic in [False, True]:
-        set_symbolic_shape(True)
         x1 = tensor(np.random.randn(2, 10, 10))
         x2 = tensor(np.random.randn(4, 10, 10))
         x3 = tensor(np.random.randn(8, 10, 10))
@@ -359,7 +361,6 @@ def test_raise_on_trace():
 
 def test_trace_broadcast():
     for symbolic in [False, True]:
-        set_symbolic_shape(True)
         x1 = tensor(np.random.randn(3, 1, 1))
         x2 = tensor(np.random.randn(1, 4, 1))
         x3 = tensor(np.random.randn(1, 1, 5))
@@ -397,7 +398,6 @@ def test_trace_nms():
 
 
 def test_trace_valid_broadcast():
-    set_symbolic_shape(True)
     x1 = tensor(np.random.randn(1, 1))
     x2 = tensor(np.random.randn(1, 2))
     shape = (tensor([2]), tensor([2]))
diff --git a/imperative/python/test/unit/test_zero_dim_tensor.py b/imperative/python/test/unit/test_zero_dim_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..8130ca82b5dfa4f6d7a9786c57d2a28667ad4c79
--- /dev/null
+++ b/imperative/python/test/unit/test_zero_dim_tensor.py
@@ -0,0 +1,52 @@
+import numpy as np
+
+import megengine.functional as F
+from megengine import Tensor
+from megengine.core._trace_option import use_symbolic_shape
+
+
+def test_zero_dim():
+    a = Tensor(1)
+    a_np = np.array(1, dtype=np.int32)
+    np.testing.assert_equal(a, a_np)
+    if use_symbolic_shape():
+        np.testing.assert_equal(a.shape, np.array(a_np.shape))
+    else:
+        np.testing.assert_equal(a.shape, a_np.shape)
+
+
+def test_sum():
+    a = Tensor([1, 2])
+    a = a.reshape((1, 2))
+    assert a.sum().ndim == 0
+    assert a.sum(axis=1).ndim == 1
+
+
+def test_max():
+    a = Tensor([1, 2])
+    a = a.reshape((1, 2))
+    assert a.max().ndim == 0
+    assert a.max(axis=1).ndim == 1
+
+
+def test_reshape():
+    a = Tensor(1)
+    a = a.reshape((1, 1))
+
+
+def test_squeeze():
+    a = Tensor(1)
+    a = a.reshape((1, 1))
+    assert F.squeeze(a).ndim == 0
+
+
+def test_elementwise():
+    a = Tensor(1.0)
+    assert F.exp(a).ndim == 0
+    assert (a + a).ndim == 0
+    assert (a + 1).ndim == 0
+
+
+def test_astype():
+    a = Tensor(1.0)
+    assert a.astype("int32").ndim == 0
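
Note: taken together, these changes implement one pattern. The imperative runtime has no native 0-dim tensors, so scalars travel as shape-(1,) buffers carrying an `_isscalar` flag: `shape` reports `()` and `numpy()` squeezes when the flag is set, and ops propagate it (elementwise results are scalar only if every input is scalar; full reductions and full-axis removals set it via `setscalar`). Below is a minimal standalone sketch of that bookkeeping, using a hypothetical `ScalarBox` class that is not MegEngine API, only an illustration of the mechanism:

import numpy as np

class ScalarBox:
    """Sketch of the _isscalar bookkeeping introduced in this diff."""

    def __init__(self, array):
        array = np.asarray(array)
        # store 0-dim values as shape-(1,) buffers plus a flag,
        # mirroring RawTensor.__init__(handle, isscalar=...)
        self._isscalar = array.ndim == 0
        self._raw = array.reshape(1) if self._isscalar else array

    @property
    def shape(self):
        # report () when flagged, like RawTensor.shape above
        return () if self._isscalar else self._raw.shape

    def numpy(self):
        # squeeze the placeholder axis away on readout
        return self._raw.squeeze() if self._isscalar else self._raw

    def __add__(self, other):
        # elementwise ops keep the flag only if all inputs carry it,
        # matching the loop added to _elwise
        out = ScalarBox(self._raw + other._raw)
        out._isscalar = self._isscalar and other._isscalar
        return out

x, y = ScalarBox(np.array(2.0)), ScalarBox(np.array(3.0))
assert (x + y).shape == () and (x + y).numpy().ndim == 0
assert (x + ScalarBox(np.array([1.0]))).shape == (1,)

The `set_symbolic_shape` change follows the same save/restore discipline: it now returns the previous value, so `do_enter` can stash it in `_save_symbolic_shape` and the exit path can restore it instead of leaking the trace's setting into global state.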