diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index 3742d6100e60393d3e0528c3371af29163b3b69b..58f07207334be6da0d65a1d7f34b4db65811a825 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -6,6 +6,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import inspect import io import itertools from tempfile import mkstemp @@ -492,3 +493,44 @@ def test_random(shape_mode): run_test(uniform) run_test(normal) + + +@pytest.mark.parametrize("shape_mode", [False, True]) +def test_trace_advance_indexing(shape_mode): + funcs = [ + lambda x, i: x[i], + # lambda x, i, j: x[i, j], # FIXME + lambda x, i, j: x[i, :, j, ...], + # lambda x, start, end: x[start:end], # FIXME + lambda x, start, end: x[:, 0, start:end, ..., 1], + lambda x, vec: x[vec], + lambda x, vec: x[vec, ..., 0, 1:3], + lambda x, vec: x[vec, vec[0], vec[1]], + # lambda x, i, start, end, vec: x[i, ..., :, vec, start:end], # FIXME + lambda x, mask: x[mask], + ] + + inputs = { + "x": np.random.randn(5, 5, 5, 5, 5).astype("float32"), + "i": 0, + "j": 2, + "start": 1, + "end": 3, + "vec": [1, 2, 3], + "mask": np.random.randn(5, 5, 5, 5, 5) >= 0, + } + for f in funcs: + sig = inspect.signature(f) + param_names = list(sig._parameters.keys()) + params = {} + params_np = {} + f_traced = trace(f, symbolic=False, symbolic_shape=shape_mode) + for name in param_names: + params[name] = tensor(inputs[name]) + params_np[name] = inputs[name] + expected = f(**params_np) + result_imperative = f(**params) + np.testing.assert_equal(expected, result_imperative.numpy()) + for _ in range(3): + result_trace = f_traced(**params) + np.testing.assert_equal(expected, result_trace.numpy()) diff --git a/src/opr/impl/indexing.cpp b/src/opr/impl/indexing.cpp index 1715b99d2811c85c6aef0231b07f1deb81584aad..b6e4e41f1554b50baa3b103996bd5d021d0a42ab 100644 --- a/src/opr/impl/indexing.cpp +++ b/src/opr/impl/indexing.cpp @@ -248,7 +248,8 @@ intl::MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc( } } - if (!m_scalar_idx_warn_printed && warn_all_scalar) { + if (!m_scalar_idx_warn_printed && warn_all_scalar && + !this->owner_graph()->options().imperative_proxy_graph) { bool all_scalar = true; for (auto &&i: index) { if (!i.vec.layout.is_scalar()) {