提交 99b17623 编写于 作者: M Megvii Engine Team

fix(mgb/opr): fix Reduce static value inference

GitOrigin-RevId: 5e5c56064c48eff306f7449f34e7e221510b954b
上级 a3caa5d3
...@@ -1501,8 +1501,9 @@ void Reduce::init_output_static_infer_desc() { ...@@ -1501,8 +1501,9 @@ void Reduce::init_output_static_infer_desc() {
auto infer_value = [this](DeviceTensorND &dest, const InpVal &inp) { auto infer_value = [this](DeviceTensorND &dest, const InpVal &inp) {
DeviceTensorND workspace; DeviceTensorND workspace;
auto sopr = static_infer_opr.lock(); auto sopr = static_infer_opr.lock();
perform(m_param.mode, dest, workspace, perform(m_param.mode, dest, workspace, inp.val[0].value(),
inp.val[0].value(), inp.val.at(1).shape(), sopr(), m_param.data_type); output(0)->dtype(), inp.val.at(1).shape(), sopr(),
m_param.data_type);
return true; return true;
}; };
...@@ -1632,6 +1633,7 @@ void Reduce::perform( ...@@ -1632,6 +1633,7 @@ void Reduce::perform(
Mode mode, Mode mode,
DeviceTensorND &dest, DeviceTensorND &workspace, DeviceTensorND &dest, DeviceTensorND &workspace,
const DeviceTensorND &input, const DeviceTensorND &input,
const DType &target_dtype,
const TensorShape &target_shape, const TensorShape &target_shape,
intl::UniqPtrWithCN<megdnn::Reduce> &opr, const Param::DataType data_type) { intl::UniqPtrWithCN<megdnn::Reduce> &opr, const Param::DataType data_type) {
...@@ -1674,7 +1676,7 @@ void Reduce::perform( ...@@ -1674,7 +1676,7 @@ void Reduce::perform(
} }
opr.comp_node().activate(); opr.comp_node().activate();
dest.comp_node(opr.comp_node()).dtype(input.dtype()).resize(target_shape); dest.comp_node(opr.comp_node()).dtype(target_dtype).resize(target_shape);
ksched.update_ptr(*input_contig, dest, workspace); ksched.update_ptr(*input_contig, dest, workspace);
ksched.execute(opr.get(), *input_contig, dest); ksched.execute(opr.get(), *input_contig, dest);
} }
......
...@@ -304,6 +304,7 @@ MGB_DEFINE_OPR_CLASS(Reduce, intl::DynamicOutputIfInputDynamic< ...@@ -304,6 +304,7 @@ MGB_DEFINE_OPR_CLASS(Reduce, intl::DynamicOutputIfInputDynamic<
static void perform(Mode mode, DeviceTensorND& dest, static void perform(Mode mode, DeviceTensorND& dest,
DeviceTensorND& workspace, DeviceTensorND& workspace,
const DeviceTensorND& input, const DeviceTensorND& input,
const DType& target_dtype,
const TensorShape& target_shape, const TensorShape& target_shape,
intl::UniqPtrWithCN<megdnn::Reduce>& opr, intl::UniqPtrWithCN<megdnn::Reduce>& opr,
const Param::DataType data_type=Param::DataType::DEFAULT); const Param::DataType data_type=Param::DataType::DEFAULT);
......
...@@ -298,7 +298,8 @@ namespace { ...@@ -298,7 +298,8 @@ namespace {
static_calc_x.copy_from(*host_x); static_calc_x.copy_from(*host_x);
opr::Reduce::perform( opr::Reduce::perform(
Mode::SUM, static_calc_y, static_calc_workspace, Mode::SUM, static_calc_y, static_calc_workspace,
static_calc_x, oshp, static_calc_opr); static_calc_x, dtype::Float32(), oshp,
static_calc_opr);
host_y.ptr<float>()[0] ++; host_y.ptr<float>()[0] ++;
host_y.copy_from(static_calc_y); host_y.copy_from(static_calc_y);
MGB_ASSERT_TENSOR_NEAR(expected, host_y, 1e-5); MGB_ASSERT_TENSOR_NEAR(expected, host_y, 1e-5);
...@@ -468,7 +469,8 @@ TEST(TestBasicArithReduction, NonContPerform) { ...@@ -468,7 +469,8 @@ TEST(TestBasicArithReduction, NonContPerform) {
for (auto &&tshp: for (auto &&tshp:
TensorShapeArray{{5, 1}, {1, 5}, {1, 1}, {1}, {5, 5}}) { TensorShapeArray{{5, 1}, {1, 5}, {1, 1}, {1}, {5, 5}}) {
opr::Reduce::perform(mode, y, workspace, x, tshp, opr); opr::Reduce::perform(mode, y, workspace, x, dtype::Float32(), tshp,
opr);
ASSERT_TRUE(y.layout().is_contiguous()); ASSERT_TRUE(y.layout().is_contiguous());
ASSERT_EQ(tshp, y.shape()); ASSERT_EQ(tshp, y.shape());
size_t nr = tshp.total_nr_elems(); size_t nr = tshp.total_nr_elems();
...@@ -866,4 +868,36 @@ TEST(TestBasicArithReduction, StaticInferValue) { ...@@ -866,4 +868,36 @@ TEST(TestBasicArithReduction, StaticInferValue) {
MGB_ASSERT_TENSOR_EQ(inferred, expected); MGB_ASSERT_TENSOR_EQ(inferred, expected);
} }
TEST(TestBasicArithReduction, StaticInferValueDType) {
using ParamType = opr::Reduce::Param::DataType;
DType F32 = dtype::Float32(), F16 = dtype::Float16();
auto run_test = [](const DType& itype, const DType& expected_otype,
ParamType param_dtype) {
HostTensorGenerator<> gen;
auto host_x = gen({2, 3, 4, 5});
auto host_tshp = std::make_shared<HostTensorND>(host_x->comp_node(),
dtype::Int32());
host_tshp->resize({1});
host_tshp->ptr<int>()[0] = 1;
auto graph = ComputingGraph::make();
auto x_f32 = opr::Host2DeviceCopy::make(*graph, host_x),
x = opr::TypeCvt::make(x_f32, itype),
tshp = opr::Host2DeviceCopy::make(*graph, host_tshp),
y = opr::Reduce::make(
x, {opr::Reduce::Mode::SUM, MEGDNN_MAX_NDIM, param_dtype},
tshp);
auto inferred = graph->static_infer_manager().infer_value(y.node());
ASSERT_EQ(inferred.layout().dtype, expected_otype);
};
run_test(F32, F32, ParamType::DEFAULT);
run_test(F16, F16, ParamType::DEFAULT);
run_test(F32, F32, ParamType::FLOAT_O32xC32);
run_test(F16, F32, ParamType::FLOAT_O32xC32);
run_test(F32, F16, ParamType::FLOAT_O16xC32);
run_test(F16, F16, ParamType::FLOAT_O16xC32);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册