diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index 1a97ecf09b13920f5308fe8ce20d72eae2bf2480..56897fc5f41a5a74e895c4073ef46a0468515cad 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -1501,8 +1501,9 @@ void Reduce::init_output_static_infer_desc() { auto infer_value = [this](DeviceTensorND &dest, const InpVal &inp) { DeviceTensorND workspace; auto sopr = static_infer_opr.lock(); - perform(m_param.mode, dest, workspace, - inp.val[0].value(), inp.val.at(1).shape(), sopr(), m_param.data_type); + perform(m_param.mode, dest, workspace, inp.val[0].value(), + output(0)->dtype(), inp.val.at(1).shape(), sopr(), + m_param.data_type); return true; }; @@ -1632,6 +1633,7 @@ void Reduce::perform( Mode mode, DeviceTensorND &dest, DeviceTensorND &workspace, const DeviceTensorND &input, + const DType &target_dtype, const TensorShape &target_shape, intl::UniqPtrWithCN<megdnn::Reduce> &opr, const Param::DataType data_type) { @@ -1674,7 +1676,7 @@ void Reduce::perform( } opr.comp_node().activate(); - dest.comp_node(opr.comp_node()).dtype(input.dtype()).resize(target_shape); + dest.comp_node(opr.comp_node()).dtype(target_dtype).resize(target_shape); ksched.update_ptr(*input_contig, dest, workspace); ksched.execute(opr.get(), *input_contig, dest); } diff --git a/src/opr/include/megbrain/opr/basic_arith.h b/src/opr/include/megbrain/opr/basic_arith.h index 2ef8e8d638a0a53eb64225f9f4341f07fc6148dc..e541e553c6dbe800837db1d9e7272a14c9323fbe 100644 --- a/src/opr/include/megbrain/opr/basic_arith.h +++ b/src/opr/include/megbrain/opr/basic_arith.h @@ -304,6 +304,7 @@ MGB_DEFINE_OPR_CLASS(Reduce, intl::DynamicOutputIfInputDynamic< static void perform(Mode mode, DeviceTensorND& dest, DeviceTensorND& workspace, const DeviceTensorND& input, + const DType& target_dtype, const TensorShape& target_shape, intl::UniqPtrWithCN<megdnn::Reduce>& opr, const Param::DataType data_type=Param::DataType::DEFAULT); diff --git a/src/opr/test/basic_arith/reduction.cpp 
b/src/opr/test/basic_arith/reduction.cpp index f3101613a934bd29df2815facf074eac44815fd7..0e74f4764c6462790e6184beae2cae409836ce1f 100644 --- a/src/opr/test/basic_arith/reduction.cpp +++ b/src/opr/test/basic_arith/reduction.cpp @@ -298,7 +298,8 @@ namespace { static_calc_x.copy_from(*host_x); opr::Reduce::perform( Mode::SUM, static_calc_y, static_calc_workspace, - static_calc_x, oshp, static_calc_opr); + static_calc_x, dtype::Float32(), oshp, + static_calc_opr); host_y.ptr<float>()[0] ++; host_y.copy_from(static_calc_y); MGB_ASSERT_TENSOR_NEAR(expected, host_y, 1e-5); @@ -468,7 +469,8 @@ TEST(TestBasicArithReduction, NonContPerform) { for (auto &&tshp: TensorShapeArray{{5, 1}, {1, 5}, {1, 1}, {1}, {5, 5}}) { - opr::Reduce::perform(mode, y, workspace, x, tshp, opr); + opr::Reduce::perform(mode, y, workspace, x, dtype::Float32(), tshp, + opr); ASSERT_TRUE(y.layout().is_contiguous()); ASSERT_EQ(tshp, y.shape()); size_t nr = tshp.total_nr_elems(); @@ -866,4 +868,36 @@ TEST(TestBasicArithReduction, StaticInferValue) { MGB_ASSERT_TENSOR_EQ(inferred, expected); } +TEST(TestBasicArithReduction, StaticInferValueDType) { + using ParamType = opr::Reduce::Param::DataType; + DType F32 = dtype::Float32(), F16 = dtype::Float16(); + + auto run_test = [](const DType& itype, const DType& expected_otype, + ParamType param_dtype) { + HostTensorGenerator<> gen; + auto host_x = gen({2, 3, 4, 5}); + auto host_tshp = std::make_shared<HostTensorND>(host_x->comp_node(), + dtype::Int32()); + host_tshp->resize({1}); + host_tshp->ptr<int>()[0] = 1; + + auto graph = ComputingGraph::make(); + auto x_f32 = opr::Host2DeviceCopy::make(*graph, host_x), + x = opr::TypeCvt::make(x_f32, itype), + tshp = opr::Host2DeviceCopy::make(*graph, host_tshp), + y = opr::Reduce::make( + x, {opr::Reduce::Mode::SUM, MEGDNN_MAX_NDIM, param_dtype}, + tshp); + auto inferred = graph->static_infer_manager().infer_value(y.node()); + ASSERT_EQ(inferred.layout().dtype, expected_otype); + }; + + run_test(F32, F32, ParamType::DEFAULT); + run_test(F16, 
F16, ParamType::DEFAULT); + run_test(F32, F32, ParamType::FLOAT_O32xC32); + run_test(F16, F32, ParamType::FLOAT_O32xC32); + run_test(F32, F16, ParamType::FLOAT_O16xC32); + run_test(F16, F16, ParamType::FLOAT_O16xC32); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}