提交 b2f15a24 编写于 作者: M Megvii Engine Team 提交者: wenjuan

test(trace): test subtensor on unknown shape

GitOrigin-RevId: 1b5cfa4e0ac098b54e7f4544433310d9ae90c99e
上级 6a2348f4
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections import collections
import platform
from tempfile import NamedTemporaryFile
import numpy as np import numpy as np
import pytest import pytest
...@@ -16,6 +18,8 @@ import megengine ...@@ -16,6 +18,8 @@ import megengine
import megengine.core.tensor.megbrain_graph as G import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F import megengine.functional as F
import megengine.jit as jit import megengine.jit as jit
import megengine.random as rand
import megengine.utils.comp_graph_tools as cgtools
from megengine.core._imperative_rt.core2 import apply from megengine.core._imperative_rt.core2 import apply
from megengine.core._trace_option import use_symbolic_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.core.ops import builtin from megengine.core.ops import builtin
...@@ -724,3 +728,26 @@ def test_nd_int_indexing(symbolic): ...@@ -724,3 +728,26 @@ def test_nd_int_indexing(symbolic):
np.testing.assert_equal(out.numpy(), npy_out) np.testing.assert_equal(out.numpy(), npy_out)
run_test([inp, idx], lambda inp, idx: inp[idx]) run_test([inp, idx], lambda inp, idx: inp[idx])
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows temp file issue, fixme later"
)
def test_subtensor_when_shape_invalid():
@jit.trace(symbolic=True, capture_as_const=True)
def fun(inp):
shape = inp.shape
H = shape[-1]
NH = H * 8 + 4
arr = F.arange(4, NH, 8)
arr_shape = arr.shape
return arr_shape[0]
inp = rand.uniform(size=[1, 3, 224, 224])
fun(inp)
with NamedTemporaryFile() as f:
fun.dump(f.name, arg_names=["data"], optimize_for_inference=True)
inp = rand.uniform(size=[1, 3, 512, 512])
net = cgtools.GraphInference(f.name)
net.run(inp_dict={"data": inp})
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册