提交 1add4517 编写于 作者: M Megvii Engine Team 提交者: “wenjuan”

test(trace): test subtensor on unknown shape

GitOrigin-RevId: 1b5cfa4e0ac098b54e7f4544433310d9ae90c99e
上级 54eef558
......@@ -7,6 +7,8 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import platform
from tempfile import NamedTemporaryFile
import numpy as np
import pytest
......@@ -16,6 +18,8 @@ import megengine
import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
import megengine.jit as jit
import megengine.random as rand
import megengine.utils.comp_graph_tools as cgtools
from megengine.core._imperative_rt.core2 import apply
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.ops import builtin
......@@ -724,3 +728,26 @@ def test_nd_int_indexing(symbolic):
np.testing.assert_equal(out.numpy(), npy_out)
run_test([inp, idx], lambda inp, idx: inp[idx])
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows temp file issue, fixme later"
)
def test_subtensor_when_shape_invalid():
@jit.trace(symbolic=True, capture_as_const=True)
def fun(inp):
shape = inp.shape
H = shape[-1]
NH = H * 8 + 4
arr = F.arange(4, NH, 8)
arr_shape = arr.shape
return arr_shape[0]
inp = rand.uniform(size=[1, 3, 224, 224])
fun(inp)
with NamedTemporaryFile() as f:
fun.dump(f.name, arg_names=["data"], optimize_for_inference=True)
inp = rand.uniform(size=[1, 3, 512, 512])
net = cgtools.GraphInference(f.name)
net.run(inp_dict={"data": inp})
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册