diff --git a/src/opr/impl/tensor_manip.cpp b/src/opr/impl/tensor_manip.cpp index d49bd82cced7958ee04cf3fc2e37de02f8ad5d83..250a83827658223847f4758ea309643b013ae1ea 100644 --- a/src/opr/impl/tensor_manip.cpp +++ b/src/opr/impl/tensor_manip.cpp @@ -742,7 +742,8 @@ void ModifySubtensorImplHelper::init_output_static_infer_desc() { !cg::is_static_var_shape(input(1))) return false; for (size_t i = 2; i < input().size(); ++ i) { - if (!cg::is_static_var_value(input(i))) + if (!cg::is_static_var_value(input(i)) || + !mgr.infer_value_fallible(input(i))) return false; } diff --git a/src/opr/test/tensor_manip.cpp b/src/opr/test/tensor_manip.cpp index 3fd5dd1b7cdc45def04275de78c627cad74d2b05..edf7271c96ad501793586e4fccfaad6a52cdbfb4 100644 --- a/src/opr/test/tensor_manip.cpp +++ b/src/opr/test/tensor_manip.cpp @@ -1189,6 +1189,43 @@ TEST(TestTensorManip, SetSubtensor) { run(mkshp({18, 5, 2, 3}), opt); } +TEST(TestTensorManip, SetSubtensorCheckByShapeInfer) { + HostTensorGenerator<> gen; + HostTensorGenerator gen_int; + auto host_x = gen({12}), host_sub = gen({1}), host_idx = gen_int({1}); + host_idx->ptr()[0] = 13; + auto graph = ComputingGraph::make(); + using Ad = opr::Subtensor::AxisIndexer; + auto x = opr::Host2DeviceCopy::make(*graph, host_x), + sub = opr::Host2DeviceCopy::make(*graph, host_sub); + auto idx1 = Ad::make_index(0, + opr::ImmutableTensor::make(*graph, *host_idx)), + idx2 = Ad::make_index(0, opr::Host2DeviceCopy::make(*graph, host_idx)); + + MGB_MARK_USED_VAR(x); + MGB_MARK_USED_VAR(sub); + MGB_MARK_USED_VAR(idx1); + MGB_MARK_USED_VAR(idx2); + ASSERT_THROW(opr::SetSubtensor::make(x, sub, {idx1}), MegBrainError); + ASSERT_THROW(opr::SetSubtensor::make(x, sub, {idx2}), MegBrainError); +} + +TEST(TestTensorManip, SetSubtensorShapeInfer) { + HostTensorGenerator<> gen; + HostTensorGenerator gen_int; + auto host_x = gen({12}), host_sub = gen({1}), host_idx = gen_int({1}); + host_idx->ptr()[0] = 13; + auto graph = ComputingGraph::make(); + auto&& mgr = graph->static_infer_manager(); + using Ad = opr::Subtensor::AxisIndexer; + auto x = opr::Host2DeviceCopy::make(*graph, host_x), + sub = opr::Host2DeviceCopy::make(*graph, host_sub), + index = opr::Host2DeviceCopy::make_no_value_infer(*graph, host_idx); + auto rt_static_idx = Ad::make_index(0, index * 2); + auto y = opr::SetSubtensor::make(x, sub, {rt_static_idx}); + ASSERT_TRUE(mgr.infer_shape_fallible(y.node())); +} + TEST(TestTensorManip, SetSubtensorDynIdx) { HostTensorGenerator<> gen; auto host_x = gen({12}), host_sub = gen({1}),