Commit 2dbe8194 authored by Megvii Engine Team

fix(mge/opr): fix reduction static infer value

GitOrigin-RevId: 6623d4719a83e887ab44ec1bee1235b8746aca57
Parent c20d4cc6
@@ -1666,8 +1666,10 @@ void Reduce::perform(
     mgb_assert(!dest.storage().comp_node_valid() ||
                opr.comp_node() == dest.comp_node());
     KernScheduler ksched;
+    OutTensorShapeExtender extender(input.shape(), target_shape);
+    auto&& canonized_oshp = extender.get();
     ksched.init_shapes(opr.get(), opr.comp_node(), input.layout().dtype,
-                       mode, input.shape(), target_shape, data_type);
+                       mode, input.shape(), canonized_oshp, data_type);
     if (!ksched.has_actual_computing()) {
         mgb_assert(target_shape.total_nr_elems() ==
......
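Note on the hunk above: the reduce kernel scheduler (like reduce_raw in the test below) expects the destination shape to carry the same ndim as the source, while a statically inferred target shape such as {4, 5} for a {2, 3, 4, 5} input has fewer axes. The sketch below illustrates the kind of canonicalization OutTensorShapeExtender appears to perform, assuming it simply left-pads the target shape with 1s; the helper signature and logic here are illustrative, not the actual MegEngine implementation.

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative only: left-pad `target` with 1s so it has the same ndim as
// `src`, e.g. a reduce target of {4, 5} against a {2, 3, 4, 5} source becomes
// {1, 1, 4, 5}.  This mirrors what OutTensorShapeExtender appears to do in
// the hunk above; the real MegEngine helper may differ in details.
std::vector<std::size_t> canonize_reduce_oshp(
        const std::vector<std::size_t>& src,
        const std::vector<std::size_t>& target) {
    assert(target.size() <= src.size());
    std::vector<std::size_t> out(src.size(), 1);
    // copy the given target dims into the trailing positions
    std::copy(target.begin(), target.end(),
              out.begin() + (src.size() - target.size()));
    return out;
}
// canonize_reduce_oshp({2, 3, 4, 5}, {4, 5}) yields {1, 1, 4, 5}

With a canonized shape like this, init_shapes sees a destination of the same rank as the input and can schedule the reduction normally, while the assertion on target_shape.total_nr_elems() still checks the caller-supplied shape.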
@@ -844,4 +844,26 @@ TEST(TestBasicArithReduction, CompSeqRecordLevel2) {
     EXPECT_NO_THROW(func->execute().wait());
 }
 
+TEST(TestBasicArithReduction, StaticInferValue) {
+    HostTensorGenerator<> gen;
+    auto host_x = gen({2, 3, 4, 5});
+    auto graph = ComputingGraph::make();
+    using AI = opr::Subtensor::AxisIndexer;
+    // h2d with default param enables static value inference
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+         x_shape = opr::GetVarShape::make(x),
+         x_shape_sub = opr::Subtensor::make(x_shape,
+                 {AI::make_interval(0, x.make_scalar(-2), nullptr, nullptr)}),
+         y = opr::reduce_sum(x, x_shape_sub);
+    auto inferred_dev = graph->static_infer_manager().infer_value(y.node());
+    HostTensorND expected{host_x->comp_node(), dtype::Float32()};
+    // reduce_raw requires src and dest to have the same ndim
+    expected.resize({1, 1, 4, 5});
+    reduce_raw<Mode::SUM, float>(expected, *host_x);
+    // reinterpret the result with the inferred layout, i.e. as {4, 5}
+    expected.reset(expected.storage(), inferred_dev.layout());
+    HostTensorND inferred = HostTensorND::make_proxy(inferred_dev);
+    MGB_ASSERT_TENSOR_EQ(inferred, expected);
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
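For reference, the reduction the new test verifies can be reproduced without the MegEngine test helpers: summing a {2, 3, 4, 5} tensor to target shape {4, 5} collapses the first two axes, so every output element is the sum of 2 * 3 = 6 inputs. The snippet below is a standalone illustration of that semantics, not code from the repository.

#include <cstdio>

// Standalone illustration (not MegEngine code) of the reduction checked by
// TestBasicArithReduction.StaticInferValue: reducing a {2, 3, 4, 5} tensor
// to target shape {4, 5} sums over the leading two axes.
int main() {
    float src[2][3][4][5];
    float val = 0.f;
    for (auto& a : src)
        for (auto& b : a)
            for (auto& c : b)
                for (auto& d : c)
                    d = val++;  // fill with 0, 1, 2, ...

    float dst[4][5] = {};  // zero-initialized accumulator of the target shape
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 3; ++j)
            for (int k = 0; k < 4; ++k)
                for (int l = 0; l < 5; ++l)
                    dst[k][l] += src[i][j][k][l];

    // each output element sums 6 inputs; here 0+20+40+60+80+100 = 300
    std::printf("dst[0][0] = %g\n", dst[0][0]);
    return 0;
}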