diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index d91955442625d8b4e23234a0d5b8e1c5d7e224e0..6c126651c0b740a83499cf229521da573bd9f81e 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -1666,8 +1666,10 @@ void Reduce::perform( mgb_assert(!dest.storage().comp_node_valid() || opr.comp_node() == dest.comp_node()); KernScheduler ksched; + OutTensorShapeExtender extender(input.shape(), target_shape); + auto&& canonized_oshp = extender.get(); ksched.init_shapes(opr.get(), opr.comp_node(), input.layout().dtype, - mode, input.shape(), target_shape, data_type); + mode, input.shape(), canonized_oshp, data_type); if (!ksched.has_actual_computing()) { mgb_assert(target_shape.total_nr_elems() == diff --git a/src/opr/test/basic_arith/reduction.cpp b/src/opr/test/basic_arith/reduction.cpp index 6b6ea09e7f328513d4fd74e2b09638c61ca6bc2b..f3101613a934bd29df2815facf074eac44815fd7 100644 --- a/src/opr/test/basic_arith/reduction.cpp +++ b/src/opr/test/basic_arith/reduction.cpp @@ -844,4 +844,26 @@ TEST(TestBasicArithReduction, CompSeqRecordLevel2) { EXPECT_NO_THROW(func->execute().wait()); } +TEST(TestBasicArithReduction, StaticInferValue) { + HostTensorGenerator<> gen; + auto host_x = gen({2, 3, 4, 5}); + auto graph = ComputingGraph::make(); + using AI = opr::Subtensor::AxisIndexer; + // h2d default param enable value infer + auto x = opr::Host2DeviceCopy::make(*graph, host_x), + x_shape = opr::GetVarShape::make(x), + x_shape_sub = opr::Subtensor::make(x_shape, + {AI::make_interval(0, x.make_scalar(-2), nullptr ,nullptr)}), + y = opr::reduce_sum(x, x_shape_sub); + auto inferred_dev = graph->static_infer_manager().infer_value(y.node()); + HostTensorND expected{host_x->comp_node(), dtype::Float32()}; + // reduce_raw requires the same ndim between src and dest + expected.resize({1, 1, 4, 5}); + reduce_raw(expected, *host_x); + // reshape as {4, 5} + expected.reset(expected.storage(), inferred_dev.layout()); + HostTensorND inferred = HostTensorND::make_proxy(inferred_dev); + MGB_ASSERT_TENSOR_EQ(inferred, expected); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}