From 2dbe8194ade5d122c11bb425fd1d026bc431eb67 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 27 Aug 2020 04:31:48 +0800 Subject: [PATCH] fix(mge/opr): fix reduction static infer value GitOrigin-RevId: 6623d4719a83e887ab44ec1bee1235b8746aca57 --- src/opr/impl/basic_arith.cpp | 4 +++- src/opr/test/basic_arith/reduction.cpp | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index d91955442..6c126651c 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -1666,8 +1666,10 @@ void Reduce::perform( mgb_assert(!dest.storage().comp_node_valid() || opr.comp_node() == dest.comp_node()); KernScheduler ksched; + OutTensorShapeExtender extender(input.shape(), target_shape); + auto&& canonized_oshp = extender.get(); ksched.init_shapes(opr.get(), opr.comp_node(), input.layout().dtype, - mode, input.shape(), target_shape, data_type); + mode, input.shape(), canonized_oshp, data_type); if (!ksched.has_actual_computing()) { mgb_assert(target_shape.total_nr_elems() == diff --git a/src/opr/test/basic_arith/reduction.cpp b/src/opr/test/basic_arith/reduction.cpp index 6b6ea09e7..f3101613a 100644 --- a/src/opr/test/basic_arith/reduction.cpp +++ b/src/opr/test/basic_arith/reduction.cpp @@ -844,4 +844,26 @@ TEST(TestBasicArithReduction, CompSeqRecordLevel2) { EXPECT_NO_THROW(func->execute().wait()); } +TEST(TestBasicArithReduction, StaticInferValue) { + HostTensorGenerator<> gen; + auto host_x = gen({2, 3, 4, 5}); + auto graph = ComputingGraph::make(); + using AI = opr::Subtensor::AxisIndexer; + // h2d default param enable value infer + auto x = opr::Host2DeviceCopy::make(*graph, host_x), + x_shape = opr::GetVarShape::make(x), + x_shape_sub = opr::Subtensor::make(x_shape, + {AI::make_interval(0, x.make_scalar(-2), nullptr ,nullptr)}), + y = opr::reduce_sum(x, x_shape_sub); + auto inferred_dev = graph->static_infer_manager().infer_value(y.node()); + HostTensorND expected{host_x->comp_node(), dtype::Float32()}; + // reduce_raw requires the same ndim between src and dest + expected.resize({1, 1, 4, 5}); + reduce_raw(expected, *host_x); + // reshape as {4, 5} + expected.reset(expected.storage(), inferred_dev.layout()); + HostTensorND inferred = HostTensorND::make_proxy(inferred_dev); + MGB_ASSERT_TENSOR_EQ(inferred, expected); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} -- GitLab