提交 5f171298 编写于 作者: M Megvii Engine Team

feat(mgb/gopt): add AxisAddRemove opr support for cd4 opt pass

GitOrigin-RevId: 85218ee0c4f3103d451479304fa3787f82a4fa72
上级 93f4977c
......@@ -1635,6 +1635,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
replace_func[opr::WarpPerspectiveForward::typeinfo()] =
replace_warp_perspective_opr;
......
......@@ -1171,16 +1171,22 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) {
opr::Elemwise::Param::Mode::RELU);
param.pad_h = param.pad_w = 1;
auto w2 = mkcvar("w2", {4, 4, 3, 3}),
y = opr::Convolution::make(elem, w2, param);
y = opr::Convolution::make(elem, w2, param),
z = opr::AxisAddRemove::make(y, {opr::AxisAddRemove::AxisDesc::make_add(0)});
SymbolVar y_opt;
SymbolVar y_opt, z_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nhwcd4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
unpack_vector(gopt::optimize_for_inference({z}, options), z_opt);
ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4,
find_opr<opr::Convolution>(y_opt).param().format);
ASSERT_EQ(TensorFormat::Type::DEFAULT,
find_opr<opr::AxisAddRemove>(z_opt).input(0)->format().type());
ASSERT_EQ(4, find_opr<opr::AxisAddRemove>(z_opt).input(0)->shape().ndim);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册