From b82e8f007c06d37c6d7c05c066850652bb9ba20f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 16 Sep 2022 18:41:11 +0800 Subject: [PATCH] refactor(gopt): refact the padding channel opt pass GitOrigin-RevId: ee3f55aa66f21fe2d4a042298aafe4a0a02915f7 --- src/gopt/impl/framework.cpp | 3 +- src/gopt/impl/padding_channel.cpp | 588 ++++++++++----------- src/gopt/include/megbrain/gopt/inference.h | 30 ++ src/gopt/test/inference.cpp | 13 +- 4 files changed, 327 insertions(+), 307 deletions(-) diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp index 15f64b4ba..9b5784a5b 100644 --- a/src/gopt/impl/framework.cpp +++ b/src/gopt/impl/framework.cpp @@ -783,7 +783,8 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( }); cb(nchw64, { add_pass(); - add_pass(); + add_pass(PaddingChannelPass::make( + cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64)); add_pass(); add_pass(EnableNCHW64Pass::make_nchw64_converter()); add_pass(); diff --git a/src/gopt/impl/padding_channel.cpp b/src/gopt/impl/padding_channel.cpp index 7184fce11..d4089a902 100644 --- a/src/gopt/impl/padding_channel.cpp +++ b/src/gopt/impl/padding_channel.cpp @@ -33,6 +33,54 @@ using namespace gopt; using ReformatKey = ReformatManager::ReformatKey; /* ==================== PaddingChannelPass ================= */ +namespace { +size_t padding_int4(size_t in_channel, bool flag) { + static_cast(flag); + if (in_channel <= 32) { + return (8 - (in_channel % 8)) % 8; + } else { + return (64 - (in_channel % 64)) % 64; + } +} + +size_t padding_int8(size_t in_channel, bool flag) { + if (flag) { + if (in_channel <= 16) { + return (4 - (in_channel % 4)) % 4; + } else { + return (32 - (in_channel % 32)) % 32; + } + } else { + return (4 - (in_channel % 4)) % 4; + } +} +size_t padding_4(size_t in_channel, bool) { + return (4 - (in_channel % 4)) % 4; +}; + +} // namespace + +std::unique_ptr PaddingChannelPass::make( + cg::GraphCommonOptimizeOptions::LayoutTransform 
layout_transform) { + MIDOUT_B("PaddingChannelPass::make") + using LayoutTrans = cg::GraphCommonOptimizeOptions::LayoutTransform; + auto ret = std::make_unique(); + auto& alignment_map = ret->m_alignment_map; + if (layout_transform == LayoutTrans::NCHW64) { + alignment_map[DTypeEnum::QuantizedS4] = padding_int4; + alignment_map[DTypeEnum::Quantized4Asymm] = padding_int4; + alignment_map[DTypeEnum::QuantizedS8] = padding_int8; + } else if ( + layout_transform == LayoutTrans::NCHW44 || + layout_transform == LayoutTrans::NCHW44_DOT) { + alignment_map[DTypeEnum::QuantizedS8] = padding_4; + alignment_map[DTypeEnum::Quantized8Asymm] = padding_4; + alignment_map[DTypeEnum::Float32] = padding_4; + } + ret->fill_opr_convert_fun(layout_transform); + return ret; + MIDOUT_E +} const char* PaddingChannelPass::name() const { return mgb_cstr_log("padding output channel to multiple of 4/32"); } @@ -42,267 +90,240 @@ void PaddingChannelPass::apply(OptState& opt) const { // do not check shape opt.set_var_replace_check_flag( VarReplaceCheckFlag::CHECK_ALL ^ VarReplaceCheckFlag::CHECK_SHAPE); - - ThinHashSet padding_oprs; - ThinHashMap< - Typeinfo*, - thin_function> - opr_replace_funcs; - + m_padding_oprs.clear(); auto rewriter = opt.graph().make_rewriter(); - auto pad_in_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* { - mgb_assert(inp->shape().ndim == 4); - mgb_assert( - inp->dtype().enumv() == DTypeEnum::QuantizedS4 || - inp->dtype().enumv() == DTypeEnum::Quantized4Asymm || - inp->dtype().enumv() == DTypeEnum::QuantizedS8 || - inp->dtype().enumv() == DTypeEnum::QuantizedS32); - TensorShape shape{ - inp->shape()[0], pad_channels, inp->shape()[2], inp->shape()[3]}; - std::shared_ptr host_val = - std::make_shared(inp->comp_node(), inp->dtype()); - host_val->resize(shape); - auto ptr = host_val->raw_ptr(); - size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte(); - std::memset(ptr, 0, size_bytes); - auto padding = 
opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); - auto out = opr::Concat::make({inp, padding}, 1); - return out.node(); - }; - - auto pad_out_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* { - mgb_assert(inp->shape().ndim == 4); - mgb_assert( - inp->dtype().enumv() == DTypeEnum::QuantizedS4 || - inp->dtype().enumv() == DTypeEnum::Quantized4Asymm || - inp->dtype().enumv() == DTypeEnum::QuantizedS8 || - inp->dtype().enumv() == DTypeEnum::QuantizedS32); - TensorShape shape{ - pad_channels, inp->shape()[1], inp->shape()[2], inp->shape()[3]}; - std::shared_ptr host_val = - std::make_shared(inp->comp_node(), inp->dtype()); - host_val->resize(shape); - auto ptr = host_val->raw_ptr(); - size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte(); - std::memset(ptr, 0, size_bytes); - auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); - auto out = opr::Concat::make({inp, padding}, 0); - return out.node(); - }; - - auto extract_subtensor = [](VarNode* inp, - const TensorShape& orig_shape) -> VarNode* { - mgb_assert(inp->shape().ndim == 4); - mgb_assert(inp->shape()[0] == orig_shape[0]); - mgb_assert(inp->shape()[2] == orig_shape[2]); - mgb_assert(inp->shape()[3] == orig_shape[3]); - size_t orig_channels = orig_shape[1]; - auto x = SymbolVar(inp); - auto cv = [&x](int v) { return x.make_scalar(v); }; - using AIdx = opr::Subtensor::AxisIndexer; - auto sub = opr::Subtensor::make( - x, {AIdx::make_interval(0, None, None, cv(1)), - AIdx::make_interval(1, None, cv(orig_channels), None), - AIdx::make_interval(2, None, None, cv(1)), - AIdx::make_interval(3, None, None, cv(1))}); - return sub.node(); - }; - - // padding policy for conv bias with data type qint8 - auto padding_policy_qint8 = [&padding_oprs, &pad_in_channels, &pad_out_channels]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { - mgb_assert(opr->input().size() == new_inp.size()); - mgb_assert(new_inp.size() == 3); - 
mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); - auto inps = new_inp; - size_t out_channels = opr->input(1)->shape()[0]; - size_t in_channels = opr->input(1)->shape()[1]; - size_t new_in_channels = new_inp[0]->shape()[1]; - // pad input channels - if (padding_oprs.count(opr->input(0)->owner_opr())) { - size_t pad_channels = new_in_channels - in_channels; - inps[1] = pad_in_channels(new_inp[1], pad_channels); - } else { - size_t pad_channels = 0; - mgb_assert(new_in_channels == in_channels); - if (in_channels <= 16) { - if (in_channels % 4) - pad_channels = 4 - (in_channels % 4); // pad to use dp4a - } else { - if (in_channels % 32) - pad_channels = 32 - (in_channels % 32); // pad to use tensorcore + auto on_opr = [this, &opt, &rewriter](OperatorNodeBase* opr) { + auto it = m_opr_replace_funcs.find(opr->dyn_typeinfo()); + if (it != m_opr_replace_funcs.end()) { + VarNodeArray new_inp; + new_inp.reserve(opr->input().size()); + for (auto&& inp : opr->input()) { + new_inp.push_back(rewriter.get_var(inp)); } - if (pad_channels > 0) { - inps[0] = pad_in_channels(new_inp[0], pad_channels); - inps[1] = pad_in_channels(new_inp[1], pad_channels); + auto new_opr = (it->second)(opr, new_inp); + auto &&out0 = opr->output(), &&out1 = new_opr->output(); + mgb_assert( + out0.size() == out1.size(), + "bad opr replace: src=%s{%s} dst=%s{%s}, " + "src.size=%zu " + "dst.size=%zu", + opr->cname(), opr->dyn_typeinfo()->name, new_opr->cname(), + new_opr->dyn_typeinfo()->name, out0.size(), out1.size()); + for (size_t i = 0; i < out0.size(); ++i) { + if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { + mgb_assert(!out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)); + auto src = out0[i]; + auto dst = out1[i]; + if (opt.graph().endpoint_contain(src) && + !src->shape().eq_shape(dst->shape())) { + dst = extract_subtensor(dst, src->shape()); + } + rewriter.replace_var(src, dst, nullptr); + } } - } - out_channels = inps[1]->shape()[0]; - in_channels = 
inps[1]->shape()[1]; - size_t pad_channels = 0; - if (out_channels <= 16) { - if (out_channels % 4) - pad_channels = 4 - (out_channels % 4); } else { - if (out_channels % 32) - pad_channels = 32 - (out_channels % 32); - } - if (pad_channels > 0) { - inps[1] = pad_out_channels(inps[1], pad_channels); - inps[2] = pad_in_channels(inps[2], pad_channels); - padding_oprs.insert(opr); + rewriter.auto_replace_outputs(opr); } - return serialization::copy_opr_shallow(*opr, inps, opr->config()); }; + opt.graph().iter(on_opr); + rewriter.apply_inplace(); - // padding policy for conv bias with data type qint4 and quint4 - auto padding_policy_int4 = [&padding_oprs, &pad_in_channels, &pad_out_channels]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { - mgb_assert(opr->input().size() == new_inp.size()); - mgb_assert(new_inp.size() == 3); - mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); - auto inps = new_inp; - size_t out_channels = opr->input(1)->shape()[0]; - size_t in_channels = opr->input(1)->shape()[1]; - size_t new_in_channels = new_inp[0]->shape()[1]; - // pad input channels - if (padding_oprs.count(opr->input(0)->owner_opr())) { - if (new_in_channels <= 32) { - if (new_in_channels % 8 == 0) { - size_t pad_channels = new_in_channels - in_channels; - inps[1] = pad_in_channels(new_inp[1], pad_channels); - } else { - size_t pad_channels_0 = 8 - (new_in_channels % 8); - size_t pad_channels_1 = 8 - (in_channels % 8); - inps[0] = pad_in_channels(new_inp[0], pad_channels_0); - inps[1] = pad_in_channels(new_inp[1], pad_channels_1); - } - } else { - if (new_in_channels % 64 == 0) { - size_t pad_channels = new_in_channels - in_channels; - inps[1] = pad_in_channels(new_inp[1], pad_channels); - } else { - size_t pad_channels_0 = 64 - (new_in_channels % 64); - size_t pad_channels_1 = 64 - (in_channels % 64); - inps[0] = pad_in_channels(new_inp[0], pad_channels_0); - inps[1] = pad_in_channels(new_inp[1], pad_channels_1); - } - } + MIDOUT_E +} + +VarNode* 
PaddingChannelPass::extract_subtensor( + VarNode* inp, const TensorShape& orig_shape) const { + mgb_assert(inp->shape().ndim == 4); + mgb_assert(inp->shape()[0] == orig_shape[0]); + mgb_assert(inp->shape()[2] == orig_shape[2]); + mgb_assert(inp->shape()[3] == orig_shape[3]); + size_t orig_channels = orig_shape[1]; + auto x = SymbolVar(inp); + auto cv = [&x](int v) { return x.make_scalar(v); }; + using AIdx = opr::Subtensor::AxisIndexer; + auto sub = opr::Subtensor::make( + x, {AIdx::make_interval(0, None, None, cv(1)), + AIdx::make_interval(1, None, cv(orig_channels), None), + AIdx::make_interval(2, None, None, cv(1)), + AIdx::make_interval(3, None, None, cv(1))}); + return sub.node(); +}; + +VarNode* PaddingChannelPass::pad_in_channels(VarNode* inp, size_t pad_channels) { + mgb_assert(inp->shape().ndim == 4); + TensorShape shape{inp->shape()[0], pad_channels, inp->shape()[2], inp->shape()[3]}; + std::shared_ptr host_val = + std::make_shared(inp->comp_node(), inp->dtype()); + host_val->resize(shape); + auto ptr = host_val->raw_ptr(); + size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte(); + std::memset(ptr, 0, size_bytes); + auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); + auto out = opr::Concat::make({inp, padding}, 1); + return out.node(); +}; + +VarNode* PaddingChannelPass::pad_out_channels(VarNode* inp, size_t pad_channels) { + mgb_assert(inp->shape().ndim == 4); + TensorShape shape{pad_channels, inp->shape()[1], inp->shape()[2], inp->shape()[3]}; + std::shared_ptr host_val = + std::make_shared(inp->comp_node(), inp->dtype()); + host_val->resize(shape); + auto ptr = host_val->raw_ptr(); + size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte(); + std::memset(ptr, 0, size_bytes); + auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); + auto out = opr::Concat::make({inp, padding}, 0); + return out.node(); +}; + +// padding policy for conv bias: the channel alignment is chosen per input dtype through m_alignment_map
+OperatorNodeBase* PaddingChannelPass::padding_policy( + OperatorNodeBase* opr, const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + mgb_assert(new_inp.size() == 3); + //! the new weights and the old weights must have the same shape + mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); + auto inps = new_inp; + size_t out_channels = opr->input(1)->shape()[0]; + size_t in_channels = opr->input(1)->shape()[1]; + size_t new_in_channels = new_inp[0]->shape()[1]; + auto it = m_alignment_map.find(opr->input(0)->dtype().enumv()); + if (it != m_alignment_map.end()) { + mgb_assert(it->second); + } else { + return serialization::copy_opr_shallow(*opr, inps, opr->config()); + } + // pad input channels + if (m_padding_oprs.count(opr->input(0)->owner_opr())) { + //! the opr producing the input var has been padded, but its input and output + //! dtypes may differ, so its channel alignment may not match the one needed here + size_t pad_channels_0 = it->second(new_in_channels, true); + size_t pad_channels_1 = it->second(in_channels, true); + if (pad_channels_0) { + inps[0] = pad_in_channels(new_inp[0], pad_channels_0); + } else { - size_t pad_channels = 0; - mgb_assert(new_in_channels == in_channels); - if (in_channels <= 32) { - if (in_channels % 8) - pad_channels = 8 - (in_channels % 8); - } else { - if (in_channels % 64) - pad_channels = 64 - (in_channels % 64); - } - if (pad_channels > 0) { - inps[0] = pad_in_channels(new_inp[0], pad_channels); - inps[1] = pad_in_channels(new_inp[1], pad_channels); - } + pad_channels_1 = new_in_channels - in_channels; } - out_channels = inps[1]->shape()[0]; - in_channels = inps[1]->shape()[1]; - size_t pad_channels = 0; - if (out_channels <= 32) { - if (out_channels % 8) - pad_channels = 8 - (out_channels % 8); - } else { - if (out_channels % 64) - pad_channels = 64 - (out_channels % 64); + if (pad_channels_1) { + inps[1] = pad_in_channels(new_inp[1], pad_channels_1); } + } else { + mgb_assert(new_in_channels == in_channels); +
size_t pad_channels = it->second(in_channels, true); if (pad_channels > 0) { - inps[1] = pad_out_channels(inps[1], pad_channels); - inps[2] = pad_in_channels(inps[2], pad_channels); - padding_oprs.insert(opr); + inps[0] = pad_in_channels(new_inp[0], pad_channels); + inps[1] = pad_in_channels(new_inp[1], pad_channels); } - return serialization::copy_opr_shallow(*opr, inps, opr->config()); - }; + } + out_channels = inps[1]->shape()[0]; + size_t pad_channels = it->second(out_channels, true); + if (pad_channels > 0) { + inps[1] = pad_out_channels(inps[1], pad_channels); + inps[2] = pad_in_channels(inps[2], pad_channels); + m_padding_oprs.insert(opr); + } + return serialization::copy_opr_shallow(*opr, inps, opr->config()); +}; - opr_replace_funcs[opr::ConvBiasForward::typeinfo()] = - [&padding_oprs, &padding_policy_qint8, &padding_policy_int4]( - OperatorNodeBase* opr, const VarNodeArray& new_inp) { - if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) { - return padding_policy_qint8(opr, new_inp); - } else if ( - opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 || - opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm) { - return padding_policy_int4(opr, new_inp); - } else { - mgb_assert( - padding_oprs.count(opr->input(0)->owner_opr()) == 0, - "conv bias operator for data type(%s) cannot be " - "padded channel. 
" - "consumer(%s), producer(%s)", - opr->input(0)->dtype().name(), opr->cname(), - opr->input(0)->owner_opr()->cname()); - return serialization::copy_opr_shallow( - *opr, new_inp, opr->config()); - } - }; - opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] = - [&padding_oprs, &pad_in_channels, &pad_out_channels]( - OperatorNodeBase* opr, const VarNodeArray& new_inp) { - if (opr->input(1)->dtype().enumv() != DTypeEnum::QuantizedS8) { +void PaddingChannelPass::fill_opr_convert_fun(LayoutTrans layout_trans) { + add_convbias_replace_func(layout_trans); + add_conv_backward_data_replace_func(layout_trans); + add_format_aware_opr_replace_func(layout_trans); + add_elemwise_like_opr_replace_func(layout_trans); + add_nonpadding_oprs_replace_func(layout_trans); +} + +void PaddingChannelPass::add_convbias_replace_func(LayoutTrans layout_trans) { + if (layout_trans == LayoutTrans::NCHW64) { + m_opr_replace_funcs[opr::ConvBiasForward::typeinfo()] = + [this](OperatorNodeBase* opr, const VarNodeArray& new_inp) { + if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) { + return padding_policy(opr, new_inp); + } else if ( + opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 || + opr->input(0)->dtype().enumv() == + DTypeEnum::Quantized4Asymm) { + return padding_policy(opr, new_inp); + } else { + mgb_assert( + m_padding_oprs.count(opr->input(0)->owner_opr()) == 0, + "conv bias operator for data type(%s) cannot be " + "padded channel. 
" + "consumer(%s), producer(%s)", + opr->input(0)->dtype().name(), opr->cname(), + opr->input(0)->owner_opr()->cname()); + return serialization::copy_opr_shallow( + *opr, new_inp, opr->config()); + } + }; + } else if (layout_trans == LayoutTrans::NCHW44 || /* make() gives NCHW44_DOT the same alignment map as NCHW44, so it must register the same conv-bias handler */ layout_trans == LayoutTrans::NCHW44_DOT) { + m_opr_replace_funcs[opr::ConvBiasForward::typeinfo()] = + [this](OperatorNodeBase* opr, const VarNodeArray& new_inp) { + return padding_policy(opr, new_inp); + }; + } +} + +void PaddingChannelPass::add_conv_backward_data_replace_func(LayoutTrans layout_trans) { + if (layout_trans == LayoutTrans::NCHW64) { + m_opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] = + [this](OperatorNodeBase* opr, const VarNodeArray& new_inp) { + if (opr->input(1)->dtype().enumv() != DTypeEnum::QuantizedS8) { + mgb_assert( + m_padding_oprs.count(opr->input(0)->owner_opr()) == 0, + "conv bwd data operator for data type(%s) cannot " + "be " + "padded channel. " + "consumer(%s), producer(%s)", + opr->input(0)->dtype().name(), opr->cname(), + opr->input(0)->owner_opr()->cname()); + return serialization::copy_opr_shallow( + *opr, new_inp, opr->config()); + } + mgb_assert(opr->input().size() == new_inp.size()); mgb_assert( - padding_oprs.count(opr->input(0)->owner_opr()) == 0, - "conv bwd data operator for data type(%s) cannot " - "be " - "padded channel. 
" - "consumer(%s), producer(%s)", - opr->input(0)->dtype().name(), opr->cname(), - opr->input(0)->owner_opr()->cname()); - return serialization::copy_opr_shallow( - *opr, new_inp, opr->config()); - } - mgb_assert(opr->input().size() == new_inp.size()); - mgb_assert( - new_inp.size() == 2, - "deconv (conv bwd data) operator for inference can " - "only have 2 input vars(got:%zu)", - new_inp.size()); - mgb_assert(opr->input(0)->shape().eq_shape(new_inp[0]->shape())); - auto inps = new_inp; - size_t out_channels = opr->input(0)->shape()[0]; - size_t in_channels = opr->input(0)->shape()[1]; - size_t new_out_channels = new_inp[1]->shape()[1]; - // pad output channels - if (padding_oprs.count(opr->input(1)->owner_opr())) { - size_t pad_channels = new_out_channels - out_channels; - inps[0] = pad_out_channels(new_inp[0], pad_channels); - } else { - size_t pad_channels = 0; - if (out_channels % 4) - pad_channels = 4 - (out_channels % 4); - if (pad_channels > 0) { + new_inp.size() == 2, + "deconv (conv bwd data) operator for inference can " + "only have 2 input vars(got:%zu)", + new_inp.size()); + mgb_assert(opr->input(0)->shape().eq_shape(new_inp[0]->shape())); + auto inps = new_inp; + size_t out_channels = opr->input(0)->shape()[0]; + size_t in_channels = opr->input(0)->shape()[1]; + size_t new_out_channels = new_inp[1]->shape()[1]; + auto it = m_alignment_map.find(opr->input(1)->dtype().enumv()); + // pad output channels + if (m_padding_oprs.count(opr->input(1)->owner_opr())) { + size_t pad_channels = new_out_channels - out_channels; inps[0] = pad_out_channels(new_inp[0], pad_channels); - inps[1] = pad_in_channels(new_inp[1], pad_channels); + } else { + size_t pad_channels = it->second(out_channels, false); + if (pad_channels > 0) { + inps[0] = pad_out_channels(new_inp[0], pad_channels); + inps[1] = pad_in_channels(new_inp[1], pad_channels); + } } - } - out_channels = inps[0]->shape()[0]; - in_channels = inps[0]->shape()[1]; - // pad input channels - size_t pad_channels = 
0; - if (in_channels % 4) - pad_channels = 4 - (in_channels % 4); - if (pad_channels > 0) { - inps[0] = pad_in_channels(inps[0], pad_channels); - padding_oprs.insert(opr); - } - return serialization::copy_opr_shallow(*opr, inps, opr->config()); - }; - auto replace_format_aware_opr = [&padding_oprs]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { + out_channels = inps[0]->shape()[0]; + // pad input channels + size_t pad_channels = it->second(in_channels, false); + if (pad_channels > 0) { + inps[0] = pad_in_channels(inps[0], pad_channels); + m_padding_oprs.insert(opr); + } + return serialization::copy_opr_shallow(*opr, inps, opr->config()); + }; + } +} + +void PaddingChannelPass::add_format_aware_opr_replace_func(LayoutTrans) { + auto replace_format_aware_opr = [this](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { if (opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 && opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS4 && opr->input(0)->dtype().enumv() != DTypeEnum::Quantized4Asymm) { mgb_assert( - padding_oprs.count(opr->input(0)->owner_opr()) == 0, + m_padding_oprs.count(opr->input(0)->owner_opr()) == 0, "operator(type:%s,name:%s) for data type(%s) cannot be " "padded channel. 
extra info:" "consumer(%s), producer(%s)", @@ -312,18 +333,19 @@ void PaddingChannelPass::apply(OptState& opt) const { return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); } mgb_assert(opr->input().size() == new_inp.size()); - if (padding_oprs.count(opr->input(0)->owner_opr())) { - padding_oprs.insert(opr); + if (m_padding_oprs.count(opr->input(0)->owner_opr())) { + m_padding_oprs.insert(opr); } return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); }; - opr_replace_funcs[opr::PoolingForward::typeinfo()] = replace_format_aware_opr; - opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] = + m_opr_replace_funcs[opr::PoolingForward::typeinfo()] = replace_format_aware_opr; + m_opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] = replace_format_aware_opr; +} - auto replace_elemwise_like_opr = [&padding_oprs, &extract_subtensor]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { +void PaddingChannelPass::add_elemwise_like_opr_replace_func(LayoutTrans) { + auto replace_elemwise_like_opr = [this](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { mgb_assert(opr->input().size() == new_inp.size()); bool have_padding_inp = false; bool padding_all_inps = true; @@ -331,7 +353,7 @@ void PaddingChannelPass::apply(OptState& opt) const { size_t channels_after_padding = 0; size_t i = 0; for (auto&& cur_inp : opr->input()) { - bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0; + bool padding_cur_inp = m_padding_oprs.count(cur_inp->owner_opr()) > 0; if (padding_cur_inp) { if (!have_padding_inp) have_padding_inp = true; @@ -349,7 +371,7 @@ void PaddingChannelPass::apply(OptState& opt) const { auto inps = new_inp; for (size_t i = 0; i < new_inp.size(); ++i) { auto cur_inp = opr->input(i); - bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0; + bool padding_cur_inp = m_padding_oprs.count(cur_inp->owner_opr()) > 0; if (padding_cur_inp) { inps[i] = extract_subtensor(inps[i], cur_inp->shape()); } 
@@ -357,72 +379,34 @@ void PaddingChannelPass::apply(OptState& opt) const { return serialization::copy_opr_shallow(*opr, inps, opr->config()); } if (padding_all_inps) { - padding_oprs.insert(opr); + m_padding_oprs.insert(opr); } return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); }; - opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_like_opr; - opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr; - opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr; + m_opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_like_opr; + m_opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr; + m_opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr; +} - auto replace_nonpadding_oprs = [&padding_oprs, &extract_subtensor]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { +void PaddingChannelPass::add_nonpadding_oprs_replace_func(LayoutTrans) { + auto replace_nonpadding_oprs = [this](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { mgb_assert(opr->input().size() == new_inp.size()); auto inps = new_inp; for (size_t i = 0; i < new_inp.size(); ++i) { auto cur_inp = opr->input(i); - bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0; + bool padding_cur_inp = m_padding_oprs.count(cur_inp->owner_opr()) > 0; if (padding_cur_inp) { inps[i] = extract_subtensor(inps[i], cur_inp->shape()); } } return serialization::copy_opr_shallow(*opr, inps, opr->config()); }; - opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs; - opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs; - opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs; - opr_replace_funcs[opr::Reduce::typeinfo()] = replace_nonpadding_oprs; - opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_nonpadding_oprs; - - auto on_opr = [&opt, &rewriter, &opr_replace_funcs, - &extract_subtensor](OperatorNodeBase* opr) { - 
auto it = opr_replace_funcs.find(opr->dyn_typeinfo()); - if (it != opr_replace_funcs.end()) { - VarNodeArray new_inp; - new_inp.reserve(opr->input().size()); - for (auto&& inp : opr->input()) { - new_inp.push_back(rewriter.get_var(inp)); - } - auto new_opr = (it->second)(opr, new_inp); - auto &&out0 = opr->output(), &&out1 = new_opr->output(); - mgb_assert( - out0.size() == out1.size(), - "bad opr replace: src=%s{%s} dst=%s{%s}, " - "src.size=%zu " - "dst.size=%zu", - opr->cname(), opr->dyn_typeinfo()->name, new_opr->cname(), - new_opr->dyn_typeinfo()->name, out0.size(), out1.size()); - for (size_t i = 0; i < out0.size(); ++i) { - if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { - mgb_assert(!out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)); - auto src = out0[i]; - auto dst = out1[i]; - if (opt.graph().endpoint_contain(src) && - !src->shape().eq_shape(dst->shape())) { - dst = extract_subtensor(dst, src->shape()); - } - rewriter.replace_var(src, dst, nullptr); - } - } - } else { - rewriter.auto_replace_outputs(opr); - } - }; - opt.graph().iter(on_opr); - rewriter.apply_inplace(); - - MIDOUT_E + m_opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs; + m_opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs; + m_opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs; + m_opr_replace_funcs[opr::Reduce::typeinfo()] = replace_nonpadding_oprs; + m_opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_nonpadding_oprs; } // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h index b3cd9702f..6f12dfba9 100644 --- a/src/gopt/include/megbrain/gopt/inference.h +++ b/src/gopt/include/megbrain/gopt/inference.h @@ -509,8 +509,38 @@ public: */ class PaddingChannelPass final : public Pass { public: + using ChannelAlignmentMap = + ThinHashMap>; + using LayoutTrans = 
cg::GraphCommonOptimizeOptions::LayoutTransform; + const char* name() const override; void apply(OptState& opt) const override; + + void fill_opr_convert_fun(LayoutTrans layout_trans); + + using ReplaceFuncs = ThinHashMap< + Typeinfo*, + thin_function>; + + //! make channel padding opt pass with given tensor format + static std::unique_ptr make(LayoutTrans layout_transform); + +private: + VarNode* extract_subtensor(VarNode* inp, const TensorShape& orig_shape) const; + VarNode* pad_in_channels(VarNode* inp, size_t pad_channels); + VarNode* pad_out_channels(VarNode* inp, size_t pad_channels); + OperatorNodeBase* padding_policy( + OperatorNodeBase* opr, const VarNodeArray& new_inp); + + void add_convbias_replace_func(LayoutTrans layout_transform); + void add_conv_backward_data_replace_func(LayoutTrans layout_transform); + void add_format_aware_opr_replace_func(LayoutTrans layout_transform); + void add_elemwise_like_opr_replace_func(LayoutTrans layout_transform); + void add_nonpadding_oprs_replace_func(LayoutTrans layout_transform); + + ChannelAlignmentMap m_alignment_map; + ReplaceFuncs m_opr_replace_funcs; + mutable ThinHashSet m_padding_oprs; }; /*! 
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 45c43f84d..750b5fa78 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -1,3 +1,4 @@ +#include "megbrain/graph/cg.h" #include "megbrain/opr/dnn/local.h" #include "megbrain/test/helper.h" @@ -5037,7 +5038,8 @@ TEST(TestGoptInference, PaddingChannels) { SymbolVar y3_pad; unpack_vector( gopt::GraphOptimizer{} - .add_pass() + .add_pass(gopt::PaddingChannelPass::make( + cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64)) .apply({{y3}}) .endpoint_vars(), y3_pad); @@ -5101,7 +5103,8 @@ TEST(TestGoptInference, ConcatAfterPaddingChannels) { SymbolVar y2_pad; unpack_vector( gopt::GraphOptimizer{} - .add_pass() + .add_pass(gopt::PaddingChannelPass::make( + cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64)) .apply({{y2}}) .endpoint_vars(), y2_pad); @@ -5166,7 +5169,8 @@ TEST(TestGoptInference, PaddingChannelsWithPooling) { SymbolVar y1_pad; unpack_vector( gopt::GraphOptimizer{} - .add_pass() + .add_pass(gopt::PaddingChannelPass::make( + cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64)) .apply({{y1}}) .endpoint_vars(), y1_pad); @@ -5232,7 +5236,8 @@ TEST(TestGoptInference, PaddingChannelsWithWarpPerspective) { SymbolVar y1_pad; unpack_vector( gopt::GraphOptimizer{} - .add_pass() + .add_pass(gopt::PaddingChannelPass::make( + cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64)) .apply({{y1}}) .endpoint_vars(), y1_pad); -- GitLab