From b111baf1b62f93c167bb3b79986a12068cf2b98f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 14 Sep 2020 23:08:43 +0800 Subject: [PATCH] fix(mge/tensor): fix const target shape in reshape GitOrigin-RevId: 7c04a9efbabc7c7e75d87c67a611a08a7371fe49 --- .../megengine/core/tensor/tensor_wrapper.py | 4 +--- imperative/python/test/unit/test_tracing.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index 946c6f961..05f2d7091 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -74,9 +74,7 @@ def _reshape(x, shape): raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) unspec_axis = i - if not isinstance(shape, (TensorBase, TensorWrapperBase)): - # TODO: device should be None (cpu) - (shape,) = Const(shape, dtype=np.int32, device=x.device)(x) + shape = utils.astensor1d(shape, x, dtype="int32", device=x.device) if unspec_axis is None: op = builtin.Reshape() diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index ceed89fb6..78bf6c7c4 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -266,3 +266,20 @@ def test_trace_cvt_bool(): for i in range(3): np.testing.assert_equal(f(x).numpy()[0], False) + + +def test_trace_reshape(): + for symbolic in [False, True]: + set_tensor_shape(True) + x1 = tensor(np.random.randn(2, 10, 10)) + x2 = tensor(np.random.randn(4, 10, 10)) + x3 = tensor(np.random.randn(8, 10, 10)) + + @trace(symbolic=symbolic, capture_as_const=True) + def f(x): + y = x.reshape(x.shape[0], 100) + return y + + f(x1) + f(x2) + f(x3) -- GitLab