From b2f15a242eaf5cebbc66345473316328f9a993a6 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 28 Jan 2022 14:10:31 +0800 Subject: [PATCH] test(trace): test subtensor on unknown shape GitOrigin-RevId: 1b5cfa4e0ac098b54e7f4544433310d9ae90c99e --- .../python/test/unit/core/test_indexing_op.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/imperative/python/test/unit/core/test_indexing_op.py b/imperative/python/test/unit/core/test_indexing_op.py index a2c127a1..9adf12f6 100644 --- a/imperative/python/test/unit/core/test_indexing_op.py +++ b/imperative/python/test/unit/core/test_indexing_op.py @@ -7,6 +7,8 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import collections +import platform +from tempfile import NamedTemporaryFile import numpy as np import pytest @@ -16,6 +18,8 @@ import megengine import megengine.core.tensor.megbrain_graph as G import megengine.functional as F import megengine.jit as jit +import megengine.random as rand +import megengine.utils.comp_graph_tools as cgtools from megengine.core._imperative_rt.core2 import apply from megengine.core._trace_option import use_symbolic_shape from megengine.core.ops import builtin @@ -724,3 +728,26 @@ def test_nd_int_indexing(symbolic): np.testing.assert_equal(out.numpy(), npy_out) run_test([inp, idx], lambda inp, idx: inp[idx]) + + +@pytest.mark.skipif( + platform.system() == "Windows", reason="windows temp file issue, fixme later" +) +def test_subtensor_when_shape_invalid(): + @jit.trace(symbolic=True, capture_as_const=True) + def fun(inp): + shape = inp.shape + H = shape[-1] + NH = H * 8 + 4 + arr = F.arange(4, NH, 8) + arr_shape = arr.shape + return arr_shape[0] + + inp = rand.uniform(size=[1, 3, 224, 224]) + fun(inp) + + with NamedTemporaryFile() as f: + fun.dump(f.name, arg_names=["data"], optimize_for_inference=True) + inp = rand.uniform(size=[1, 3, 512, 512]) + net = cgtools.GraphInference(f.name) + net.run(inp_dict={"data": inp}) -- GitLab