提交 3726f5cc 编写于 作者: M Megvii Engine Team

feat(gopt): merge consecutive relayout and dimshuffle to one relayout to optimize CD4 performance

GitOrigin-RevId: a058776be35b21dfe8be14ac20247a29d6bfb78b
上级 1fead9b6
...@@ -1039,6 +1039,7 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o ...@@ -1039,6 +1039,7 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
'NCHW64_NCHW = 28', 'NCHW64_NCHW = 28',
'NCHW_NHWC = 29', 'NCHW_NHWC = 29',
'NHWC_NCHW = 30', 'NHWC_NCHW = 30',
'NHWCD4I_NHWC = 31',
) )
) )
......
...@@ -114,6 +114,7 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, TensorLayout& ds ...@@ -114,6 +114,7 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, TensorLayout& ds
dst[3] = src[2]; dst[3] = src[2];
dst[4] = 4; dst[4] = 4;
break; break;
case Param::Mode::NHWCD4I_NHWC:
case Param::Mode::NHWCD4_NHWC: case Param::Mode::NHWCD4_NHWC:
megdnn_assert(src.ndim == 5); megdnn_assert(src.ndim == 5);
dst.ndim = 4; dst.ndim = 4;
...@@ -331,6 +332,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { ...@@ -331,6 +332,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
CHECK_SRC(DefaultTensorFormat::make()); CHECK_SRC(DefaultTensorFormat::make());
dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type); dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type);
break; break;
case Param::Mode::NHWCD4I_NHWC:
case Param::Mode::NHWCD4I_NCHW: case Param::Mode::NHWCD4I_NCHW:
CHECK_SRC(Image2DPack4TensorFormat::make_raw(2, align, vendor_type)); CHECK_SRC(Image2DPack4TensorFormat::make_raw(2, align, vendor_type));
dst = DefaultTensorFormat::make(); dst = DefaultTensorFormat::make();
...@@ -594,6 +596,7 @@ void RelayoutFormat::deduce_exec_layout( ...@@ -594,6 +596,7 @@ void RelayoutFormat::deduce_exec_layout(
.dimshuffle({0, 1, 3, 2, 4}); .dimshuffle({0, 1, 3, 2, 4});
exec_dst = dst; exec_dst = dst;
break; break;
case Param::Mode::NHWCD4I_NHWC:
case Param::Mode::NHWCD4_NHWC: case Param::Mode::NHWCD4_NHWC:
// src is {N, H, CB, W, 4} // src is {N, H, CB, W, 4}
// dst is {N, H, W, C}, // dst is {N, H, W, C},
......
...@@ -1002,7 +1002,9 @@ void ConvertFormatPass::apply(OptState& state) const { ...@@ -1002,7 +1002,9 @@ void ConvertFormatPass::apply(OptState& state) const {
rewriter.apply_inplace(); rewriter.apply_inplace();
//! start a second pass that merge consecutive dimshuffle(NHWC->NCHW) + //! start a second pass that merge consecutive dimshuffle(NHWC->NCHW) +
//! relayout_format(NCHW->NHWCD4) to only one relayout_format(NHWC->NHWCD4) //! relayout_format(NCHW->NHWCD4) to only one relayout_format(NHWC->NHWCD4). Merge
//! consecutive relayout_format(NHWCD4 -> NCHW) + dimshuffle(NCHW -> NHWC) to one
//! relayout_format(NHWCD4 -> NHWC).
auto on_opr_merge = [&rewriter](OperatorNodeBase* opr) { auto on_opr_merge = [&rewriter](OperatorNodeBase* opr) {
auto opr_is_relayout = [](OperatorNodeBase* opr) { auto opr_is_relayout = [](OperatorNodeBase* opr) {
return opr->try_cast_final<opr::RelayoutFormat>(); return opr->try_cast_final<opr::RelayoutFormat>();
...@@ -1019,12 +1021,15 @@ void ConvertFormatPass::apply(OptState& state) const { ...@@ -1019,12 +1021,15 @@ void ConvertFormatPass::apply(OptState& state) const {
} }
return false; return false;
}; };
//! dimshuffle + relayout_format
{
auto this_opr_is_relayout = opr_is_relayout(opr); auto this_opr_is_relayout = opr_is_relayout(opr);
auto prev_opr_is_dimshuffle = static_cast<opr::Dimshuffle*>(nullptr); auto prev_opr_is_dimshuffle = static_cast<opr::Dimshuffle*>(nullptr);
if (this_opr_is_relayout) { if (this_opr_is_relayout) {
prev_opr_is_dimshuffle = opr_is_dimshuffle(opr->input(0)->owner_opr()); prev_opr_is_dimshuffle = opr_is_dimshuffle(opr->input(0)->owner_opr());
} }
if (this_opr_is_relayout && prev_opr_is_dimshuffle) { if (this_opr_is_relayout && prev_opr_is_dimshuffle) {
//! megengine only accept NCHW input
if (this_opr_is_relayout->param().mode == if (this_opr_is_relayout->param().mode ==
megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I && megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I &&
match_pattern(prev_opr_is_dimshuffle->param(), {0, 3, 1, 2})) { match_pattern(prev_opr_is_dimshuffle->param(), {0, 3, 1, 2})) {
...@@ -1037,6 +1042,28 @@ void ConvertFormatPass::apply(OptState& state) const { ...@@ -1037,6 +1042,28 @@ void ConvertFormatPass::apply(OptState& state) const {
} else { } else {
rewriter.auto_replace_outputs(opr); rewriter.auto_replace_outputs(opr);
} }
}
//! relayout_format + dimshuffle
{
auto this_opr_is_dimshuffle = opr_is_dimshuffle(opr);
auto prev_opr_is_relayout = static_cast<opr::RelayoutFormat*>(nullptr);
if (this_opr_is_dimshuffle) {
prev_opr_is_relayout = opr_is_relayout(opr->input(0)->owner_opr());
}
if (this_opr_is_dimshuffle && prev_opr_is_relayout) {
if (prev_opr_is_relayout->param().mode ==
megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW &&
match_pattern(this_opr_is_dimshuffle->param(), {0, 2, 3, 1})) {
auto inp = rewriter.get_var(prev_opr_is_relayout->input(0));
auto new_param = megdnn::param::RelayoutFormat();
new_param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NHWC;
auto new_opr = opr::RelayoutFormat::make(inp, new_param);
rewriter.replace_var(opr->output(0), new_opr.node(), nullptr);
}
} else {
rewriter.auto_replace_outputs(opr);
}
}
}; };
state.graph().iter(on_opr_merge); state.graph().iter(on_opr_merge);
rewriter.apply_inplace(); rewriter.apply_inplace();
......
...@@ -1365,6 +1365,71 @@ TEST(TestGoptInference, MergeDimShuffleAndRelayoutFormat) { ...@@ -1365,6 +1365,71 @@ TEST(TestGoptInference, MergeDimShuffleAndRelayoutFormat) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
} }
TEST(TestGoptInference, MergeRelayoutFormatAndDimShuffle) {
    // hwcd4 is only supported in naive handle
    NaiveMegDNNHandleScope naive_megdnn_handle;
    HostTensorGenerator<> gen;
    auto cn = CompNode::load("cpu0");
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    auto make_input = [&](const char* name, const TensorShape& shape) {
        return opr::Host2DeviceCopy::make(*graph, gen(shape, cn)).rename(name);
    };
    auto host_x = gen({2, 8, 16, 32}, cn);
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    auto scale = make_input("a", {1});
    auto bias = make_input("b", {1});
    auto fused = x * scale + bias;
    //! permute NCHW -> NHWC
    auto y = opr::Dimshuffle::make(fused, {0, 2, 3, 1});
    auto options = gopt::OptimizeForInferenceOptions{};
    options.enable_nhwcd4();
    SymbolVar y_opt;
    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
    //! the trailing dimshuffle must have been folded into a relayout_format
    ASSERT_EQ(0, find_opr_num<opr::Dimshuffle>(y_opt));
    auto relayout_modes_ok = [](SymbolVar endpoint) -> bool {
        bool ok = true;
        auto visit = [&](cg::OperatorNodeBase* opr) {
            if (!opr->same_type<opr::RelayoutFormat>())
                return;
            auto mode = opr->try_cast_final<opr::RelayoutFormat>()->param().mode;
            //! only NCHW_NHWCD4I (input side) and the merged NHWCD4I_NHWC
            //! (output side) relayouts should survive the optimization
            ok &= mode == megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I ||
                  mode == megdnn::param::RelayoutFormat::Mode::NHWCD4I_NHWC;
        };
        cg::DepOprIter{visit}.add(endpoint.node()->owner_opr());
        return ok;
    };
    ASSERT_EQ(true, relayout_modes_ok(y_opt));
    graph->compile({{y_opt, {}}})
            ->to_json()
            ->writeto_fpath(output_file(
                    "TestGoptInference.MergeRelayoutFormatAndDimShuffle.json"));
    HostTensorND host_y;
    HostTensorND host_y_opt;
    auto func = graph->compile(
            {make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
    func->execute();
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
    //! rerun with another input shape to check the merged graph is shape-agnostic
    *host_x = *gen({8, 8, 16, 16}, cn);
    func->execute();
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) { TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) {
// hwcd4 is only supported in naive handle // hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope naive_megdnn_handle; NaiveMegDNNHandleScope naive_megdnn_handle;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册