diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index 5981b2f581dd8f7f23d8b1c99573a04bd972eece..b700c1cdb9b78b1bea784d1f2c89cd8d680cdaa9 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -31,7 +31,9 @@ def dtype_promotion(raw_inputs): ] inputs = [i for i in raw_inputs if hasattr(i, "dtype")] assert len(scalar_inputs + inputs) > 0 - dtype = np.result_type(*inputs) + dtype = None + if len(inputs) > 0: + dtype = np.result_type(*inputs) dtype_all = np.result_type(*(inputs + scalar_inputs)) assert ( dtype != np.float64 and dtype != np.int64 diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index f3a43733785c82a07e22a5ce976afbb46a341d60..bc7c68f8630bbfb9b85fdc4ccf594d96b19399dc 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -10,8 +10,9 @@ import functools from ..core.ops import builtin -from ..core.tensor import utils +from ..core.tensor import megbrain_graph, utils from ..core.tensor.core import apply +from ..device import get_default_device from ..tensor import Tensor __all__ = [ @@ -76,11 +77,17 @@ __all__ = [ def _elwise(*args, mode): op = builtin.Elemwise(mode=mode) + tensor_args = list( + filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) + ) + if len(tensor_args) == 0: + dtype = utils.dtype_promotion(args) + first_arg = Tensor(args[0], dtype=dtype, device=get_default_device()) + args = utils.convert_inputs(first_arg, *args[1:]) + else: + args = utils.convert_inputs(*args) if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"): - args = tuple( - map(lambda x: x.astype("float32") if hasattr(x, "dtype") else x, args) - ) - args = utils.convert_inputs(*args) + args = tuple(map(lambda x: x.astype("float32"), args)) (result,) = apply(op, *args) return result diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index f7163cd466326b7bcb7754f944ae2c1c5b41bf55..38b4b16dc25cafcb89ab7bf774bde1341424367c 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1126,11 +1126,8 @@ def interpolate( if mode == "LINEAR": inp = add_axis(inp, 3) - if not isinstance(inp.shape, inp.__class__): - if len(inp.shape) != 4: - raise ValueError( - "shape of input tensor must correspond to the operartion mode" - ) + if inp.ndim != 4: + raise ValueError("shape of input tensor must correspond to the operartion mode") if size is None: if scale_factor is None: diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index d1cd11104e720041a3ff324d6efbd83245e5f2b3..ef1e3a76e89512631a02543ef9ef120ccbf2bafa 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -317,7 +317,7 @@ def split(inp, nsplits_or_sections, axis=0): def swapaxis(inp, src, dst): if src == dst: return inp - shape = [i for i in range(len(inp.shape))] + shape = [i for i in range(inp.ndim)] shape[src] = dst shape[dst] = src return inp.transpose(shape) @@ -325,9 +325,11 @@ def split(inp, nsplits_or_sections, axis=0): inp = swapaxis(inp, 0, axis) if isinstance(nsplits_or_sections, int): - incr_step = math.ceil(inp.shape[0] / nsplits_or_sections) - while incr_step < inp.shape[0]: - sections.append(incr_step) + incr_step = ceil(inp.shape[0] / nsplits_or_sections) + nsplits = nsplits_or_sections + while nsplits > 0: + nsplits -= 1 + sections.append(incr_step.astype("int32")) incr_step += nsplits_or_sections else: sections = nsplits_or_sections diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py index 75d6874dbb6a74617716701d7543af4cdda57b44..683103fd84bbcb3e99dc029623a501144a4410cf 100644 --- a/imperative/python/test/unit/functional/test_elemwise.py +++ b/imperative/python/test/unit/functional/test_elemwise.py @@ -19,13 +19,13 @@ def test_abs(): np.abs(np.array([-3.0, -4.0, -5.0], dtype=np.float32)), ) - # assertTensorClose(F.abs(-3.0), np.abs(np.float32(-3.0))) + assertTensorClose(F.abs(-3.0).numpy(), np.abs(np.float32(-3.0))) def test_multiply(): - # assertTensorClose( - # F.mul(-3.0, -4.0), np.multiply(np.float32(-3.0), np.float32(-4.0)) - # ) + assertTensorClose( + F.mul(-3.0, -4.0).numpy(), np.multiply(np.float32(-3.0), np.float32(-4.0)) + ) assertTensorClose( F.mul(tensor([3.0, 4.0]), 4.0).numpy(), diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 04d9e72460806cb2b9e5064477fcf0773ee37483..58a582d0542742826ca32df7067ea2bf3cffc83b 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -194,9 +194,6 @@ def test_matmul(): def test_interpolate(): - if use_tensor_shape(): # XXX: please fix me - return - def linear_interpolate(): inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2)) diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 8fe8cb8d06f9c179025ddf50a76957f417b6d77c..72e1fb73bb0448690d3e35fbc088c4a27e170b12 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -125,8 +125,6 @@ def test_stack(): def test_split(): - if use_tensor_shape(): # XXX: please fix me - return data = np.random.random((2, 3, 4, 5)).astype(np.float32) mge_out1 = F.split(tensor(data), 2, axis=3) mge_out2 = F.split(tensor(data), [3, 5], axis=3)