提交 798ae5e5 编写于 作者: M Megvii Engine Team

fix(imperative/opr): close SCALAR_IDX warning of IndexingMultiAxisVec for proxy_graph

GitOrigin-RevId: 6e0d888f8583a8d5b91e2c8f6d7eab5b65b33dd4
上级 1670a8cf
......@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import inspect
import io
import itertools
from tempfile import mkstemp
......@@ -492,3 +493,44 @@ def test_random(shape_mode):
run_test(uniform)
run_test(normal)
@pytest.mark.parametrize("shape_mode", [False, True])
def test_trace_advance_indexing(shape_mode):
funcs = [
lambda x, i: x[i],
# lambda x, i, j: x[i, j], # FIXME
lambda x, i, j: x[i, :, j, ...],
# lambda x, start, end: x[start:end], # FIXME
lambda x, start, end: x[:, 0, start:end, ..., 1],
lambda x, vec: x[vec],
lambda x, vec: x[vec, ..., 0, 1:3],
lambda x, vec: x[vec, vec[0], vec[1]],
# lambda x, i, start, end, vec: x[i, ..., :, vec, start:end], # FIXME
lambda x, mask: x[mask],
]
inputs = {
"x": np.random.randn(5, 5, 5, 5, 5).astype("float32"),
"i": 0,
"j": 2,
"start": 1,
"end": 3,
"vec": [1, 2, 3],
"mask": np.random.randn(5, 5, 5, 5, 5) >= 0,
}
for f in funcs:
sig = inspect.signature(f)
param_names = list(sig._parameters.keys())
params = {}
params_np = {}
f_traced = trace(f, symbolic=False, symbolic_shape=shape_mode)
for name in param_names:
params[name] = tensor(inputs[name])
params_np[name] = inputs[name]
expected = f(**params_np)
result_imperative = f(**params)
np.testing.assert_equal(expected, result_imperative.numpy())
for _ in range(3):
result_trace = f_traced(**params)
np.testing.assert_equal(expected, result_trace.numpy())
......@@ -248,7 +248,8 @@ intl::MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc(
}
}
if (!m_scalar_idx_warn_printed && warn_all_scalar) {
if (!m_scalar_idx_warn_printed && warn_all_scalar &&
!this->owner_graph()->options().imperative_proxy_graph) {
bool all_scalar = true;
for (auto &&i: index) {
if (!i.vec.layout.is_scalar()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册