提交 88c1eedb 编写于 作者: M Megvii Engine Team

feat(mgb/gopt): enable reduce for nchw44

GitOrigin-RevId: fce59d07625095c4096aa6e4feb346984626e9b4
上级 563239d3
......@@ -1753,6 +1753,24 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
}
};
auto replace_reduce_opr = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& reduce_opr = opr->cast_final_safe<opr::Reduce>();
VarNodeArray temp_inp = new_inp;
if (!opr->input(0)->shape().eq_shape(new_inp[0]->shape())) {
mgb_assert(opr->input(0)->shape().ndim == 4);
mgb_assert(new_inp[0]->shape().ndim == 5);
if (reduce_opr.param().axis != 2 && reduce_opr.param().axis != 3) {
auto new_var =
RelayoutPlaceholder::make(new_inp[0], src_to_nchw_mode);
temp_inp[0] = new_var.node();
}
}
return serialization::copy_opr_shallow(*opr, temp_inp, opr->config());
};
//! When input change and all input can convert to nchwxx, this opr will run
//! in nchwxx mode, else it will run in nchw mode, for example concat and
//! elemwise opr
......@@ -1829,13 +1847,13 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
replace_func[opr::TypeCvt::typeinfo()] = replace_multi_inp_opr;
replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_multi_inp_opr;
replace_func[opr::PowC::typeinfo()] = replace_multi_inp_opr;
replace_func[opr::Reduce::typeinfo()] = replace_reduce_opr;
//! not support yet
replace_func[opr::ConvolutionBackwardData::typeinfo()] =
relayout_inp_to_nchw;
replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::WarpPerspectiveForward::typeinfo()] =
......
......@@ -3476,14 +3476,24 @@ TEST(TestGoptInference, ConvertFormatNCHW88) {
//! group
auto w3 = mkcvar("w3", {1, 8, 8, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}),
conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias);
auto shape_of = opr::GetVarShape::make(conv3);
//! reduce
opr::Reduce::Param param_reduce1;
param_reduce1.axis = 2;
param_reduce1.mode = opr::Reduce::Mode::SUM;
opr::Reduce::Param param_reduce2;
param_reduce2.axis = 0;
param_reduce2.mode = opr::Reduce::Mode::MAX;
auto reduce1 = conv3 + opr::Reduce::make(conv3, param_reduce1) +
opr::Reduce::make(conv3, param_reduce2);
auto shape_of = opr::GetVarShape::make(reduce1);
auto subtensor = opr::Subtensor::make(
shape_of, {opr::Subtensor::AxisIndexer::make_interval(
0, x.make_scalar(2), None, x.make_scalar(1))});
opr::Resize::Param param_resize;
param_resize.format = opr::Resize::Param::Format::NCHW;
auto resize = opr::ResizeForward::make(conv3, subtensor * 2, param_resize);
auto resize =
opr::ResizeForward::make(reduce1, subtensor * 2, param_resize);
auto mat = mkcvar("mat", {2, 3, 3}),
warp = opr::WarpPerspectiveForward::make(
resize, mat, nullptr, cg::var_from_tensor_shape(x, {4, 4}));
......@@ -3586,14 +3596,24 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
//! group
auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}),
conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias);
auto shape_of = opr::GetVarShape::make(conv3);
//! reduce
opr::Reduce::Param param_reduce1;
param_reduce1.axis = 1;
param_reduce1.mode = opr::Reduce::Mode::MIN;
opr::Reduce::Param param_reduce2;
param_reduce2.axis = 3;
param_reduce2.mode = opr::Reduce::Mode::SUM_SQR;
auto reduce1 = conv3 + opr::Reduce::make(conv3, param_reduce1) +
opr::Reduce::make(conv3, param_reduce2);
auto shape_of = opr::GetVarShape::make(reduce1);
auto subtensor = opr::Subtensor::make(
shape_of, {opr::Subtensor::AxisIndexer::make_interval(
0, x.make_scalar(2), None, x.make_scalar(1))});
opr::Resize::Param param_resize;
param_resize.format = opr::Resize::Param::Format::NCHW;
auto resize = opr::ResizeForward::make(conv3, subtensor * 2, param_resize);
auto resize =
opr::ResizeForward::make(reduce1, subtensor * 2, param_resize);
auto mat = mkcvar("mat", {2, 3, 3}),
warp = opr::WarpPerspectiveForward::make(
resize, mat, nullptr, cg::var_from_tensor_shape(x, {4, 4}));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册