提交 cfc430b8 编写于 作者: M Megvii Engine Team

fix(mbg/opr): fix ModifySubtensorImplHelper infer shape depend input value for valid check

GitOrigin-RevId: 5ad6b9f35cf1463c12e077172fe5cf3dc8ada64c
上级 8764a6c8
......@@ -742,7 +742,8 @@ void ModifySubtensorImplHelper::init_output_static_infer_desc() {
!cg::is_static_var_shape(input(1)))
return false;
for (size_t i = 2; i < input().size(); ++ i) {
if (!cg::is_static_var_value(input(i)))
if (!cg::is_static_var_value(input(i)) ||
!mgr.infer_value_fallible(input(i)))
return false;
}
......
......@@ -1189,6 +1189,43 @@ TEST(TestTensorManip, SetSubtensor) {
run(mkshp({18, 5, 2, 3}), opt);
}
TEST(TestTensorManip, SetSubtensorCheckByShapeInfer) {
HostTensorGenerator<> gen;
HostTensorGenerator<dtype::Int32> gen_int;
auto host_x = gen({12}), host_sub = gen({1}), host_idx = gen_int({1});
host_idx->ptr<int>()[0] = 13;
auto graph = ComputingGraph::make();
using Ad = opr::Subtensor::AxisIndexer;
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
sub = opr::Host2DeviceCopy::make(*graph, host_sub);
auto idx1 = Ad::make_index(0,
opr::ImmutableTensor::make(*graph, *host_idx)),
idx2 = Ad::make_index(0, opr::Host2DeviceCopy::make(*graph, host_idx));
MGB_MARK_USED_VAR(x);
MGB_MARK_USED_VAR(sub);
MGB_MARK_USED_VAR(idx1);
MGB_MARK_USED_VAR(idx2);
ASSERT_THROW(opr::SetSubtensor::make(x, sub, {idx1}), MegBrainError);
ASSERT_THROW(opr::SetSubtensor::make(x, sub, {idx2}), MegBrainError);
}
TEST(TestTensorManip, SetSubtensorShapeInfer) {
HostTensorGenerator<> gen;
HostTensorGenerator<dtype::Int32> gen_int;
auto host_x = gen({12}), host_sub = gen({1}), host_idx = gen_int({1});
host_idx->ptr<int>()[0] = 13;
auto graph = ComputingGraph::make();
auto&& mgr = graph->static_infer_manager();
using Ad = opr::Subtensor::AxisIndexer;
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
sub = opr::Host2DeviceCopy::make(*graph, host_sub),
index = opr::Host2DeviceCopy::make_no_value_infer(*graph, host_idx);
auto rt_static_idx = Ad::make_index(0, index * 2);
auto y = opr::SetSubtensor::make(x, sub, {rt_static_idx});
ASSERT_TRUE(mgr.infer_shape_fallible(y.node()));
}
TEST(TestTensorManip, SetSubtensorDynIdx) {
HostTensorGenerator<> gen;
auto host_x = gen({12}), host_sub = gen({1}),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册