diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index 946c6f9612364248e5ec63c62c1ac3a48c01ff74..05f2d7091f187e9f3993f1d6f5e79ac56a6fd018 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -74,9 +74,7 @@ def _reshape(x, shape): raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) unspec_axis = i - if not isinstance(shape, (TensorBase, TensorWrapperBase)): - # TODO: device should be None (cpu) - (shape,) = Const(shape, dtype=np.int32, device=x.device)(x) + shape = utils.astensor1d(shape, x, dtype="int32", device=x.device) if unspec_axis is None: op = builtin.Reshape() diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index ceed89fb668a6fd6dcabdfc8dc84fbf3f188f8af..78bf6c7c461dfb625461eced244958c5fde7f72b 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -266,3 +266,20 @@ def test_trace_cvt_bool(): for i in range(3): np.testing.assert_equal(f(x).numpy()[0], False) + + +def test_trace_reshape(): + for symbolic in [False, True]: + set_tensor_shape(True) + x1 = tensor(np.random.randn(2, 10, 10)) + x2 = tensor(np.random.randn(4, 10, 10)) + x3 = tensor(np.random.randn(8, 10, 10)) + + @trace(symbolic=symbolic, capture_as_const=True) + def f(x): + y = x.reshape(x.shape[0], 100) + return y + + f(x1) + f(x2) + f(x3)