From 8a3eb05a1bb034406ed3ebf5c79f7977eef086c6 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 21 Jul 2021 17:10:46 +0800 Subject: [PATCH] refactor(mgb/gopt): refactor tensor reformat opt pass GitOrigin-RevId: a1b1e89b76e4fbdca4f481156bb8af6cae8fe4d8 --- dnn/src/common/named_tensor.cpp | 4 - src/gopt/impl/folding_conv_dimshuffle.cpp | 431 +++ src/gopt/impl/padding_channel.cpp | 451 +++ src/gopt/impl/reformat_manager.cpp | 131 +- src/gopt/impl/tensor_reformat.cpp | 2985 +++-------------- src/gopt/include/megbrain/gopt/inference.h | 1 + .../include/megbrain/gopt/reformat_manager.h | 28 +- src/gopt/test/reformat_manager.cpp | 171 + 8 files changed, 1505 insertions(+), 2697 deletions(-) create mode 100644 src/gopt/impl/folding_conv_dimshuffle.cpp create mode 100644 src/gopt/impl/padding_channel.cpp create mode 100644 src/gopt/test/reformat_manager.cpp diff --git a/dnn/src/common/named_tensor.cpp b/dnn/src/common/named_tensor.cpp index cee817322..0071f9d65 100644 --- a/dnn/src/common/named_tensor.cpp +++ b/dnn/src/common/named_tensor.cpp @@ -120,10 +120,6 @@ Dimension Dimension::operator/(const Dimension& rhs) const { static_cast(m_name), static_cast(rhs.m_name)); if (operator==(rhs)) return Dimension(m_name, 1, 1); - megdnn_assert( - !(*this < rhs), - "Divisor must be smaller than dividend(dividend:%s, divisor:%s)", - to_string().c_str(), rhs.to_string().c_str()); if (m_stride == rhs.m_stride) { if (m_extent == UNDETERMINED_EXTENT) { megdnn_assert(rhs.m_extent != UNDETERMINED_EXTENT, diff --git a/src/gopt/impl/folding_conv_dimshuffle.cpp b/src/gopt/impl/folding_conv_dimshuffle.cpp new file mode 100644 index 000000000..dac9f0757 --- /dev/null +++ b/src/gopt/impl/folding_conv_dimshuffle.cpp @@ -0,0 +1,431 @@ +/** + * \file src/gopt/impl/folding_conv_dimshuffle.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain/gopt/inference.h" +#include "megbrain/opr/basic_arith.h" +#include "megbrain/opr/dnn/convolution.h" +#include "megbrain/opr/tensor_manip.h" +#include "megbrain/opr/utility.h" +#include "megbrain/serialization/opr_shallow_copy.h" + +#include "megdnn/opr_param_defs.h" + +#include "megbrain/opr/internal/megdnn_opr_wrapper.h" + +#include "megbrain/utils/hash_ct.h" + +#include "midout.h" + +#include "megbrain/gopt/reformat_manager.h" + +#if CUDA_VERSION >= 10020 +MIDOUT_DECL(megbrain_folding_conv_dimshuffle) +#define MIDOUT_B(tag) \ + MIDOUT_BEGIN(megbrain_folding_conv_dimshuffle, \ + midout_iv(MGB_HASH_STR(tag))) { +#define MIDOUT_E \ + } \ + MIDOUT_END(); + +using namespace mgb; +using namespace gopt; +using ReformatKey = ReformatManager::ReformatKey; + +/* ==================== FoldingConvBiasDimshufflePass ================= */ +const char* FoldingConvBiasDimshufflePass::name() const { + return mgb_cstr_log("folding conv bias dimshuffle pass"); +} + +void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { + MIDOUT_B("FoldingConvBiasDimshufflePass::apply"); + using DepType = cg::OperatorNodeProp::DepType; + ThinHashMap>> + readers; + static const ThinHashSet opr_type_list = { + opr::TypeCvt::typeinfo(), opr::Dimshuffle::typeinfo(), + opr::Reshape::typeinfo(), opr::ConvBias::typeinfo()}; + opt.graph().iter([&readers](OperatorNodeBase* opr) { + for (auto&& i : opr->node_prop().dep_map()) { + if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) { + readers[i.first->owner_opr()].emplace_back(opr, i.second); + } + } + }); + + auto rewriter = opt.graph().make_rewriter(); + auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers]( + OperatorNodeBase* opr) { + ThinHashSet opr_set; + ThinHashSet reader_set; + // check typecvt + auto typecvt = try_cast_as_op(opr); + if (typecvt == nullptr) + return false; + auto inp_dtype = typecvt->input(0)->dtype(), + out_dtype = typecvt->output(0)->dtype(); + bool is_s82f32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && + out_dtype.enumv() == DTypeEnum::Float32; + if (!is_s82f32) + return false; + opr_set.insert(opr); + + // check reshape + auto reshape = + try_cast_as_op(typecvt->input(0)->owner_opr()); + if (reshape == nullptr) + return false; + opr_set.insert(reshape); + for (auto&& i : readers[reshape]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + + // check shuffle + auto shuffle = + try_cast_as_op(reshape->input(0)->owner_opr()); + if (shuffle == nullptr) + return false; + auto&& param = shuffle->param(); + if (param.pattern_len != 5) + return false; + bool is_nchw42nchw = param.pattern[0] == 0 && param.pattern[1] == 1 && + param.pattern[2] == 4 && param.pattern[3] == 2 && + param.pattern[4] == 3 && + shuffle->input(0)->shape()[4] == 4; + if (!is_nchw42nchw) + return false; + opr_set.insert(shuffle); + for (auto&& i : readers[shuffle]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + + // check conv bias + auto conv_bias = + try_cast_as_op(shuffle->input(0)->owner_opr()); + if (conv_bias == nullptr) + return false; + inp_dtype = conv_bias->input(0)->dtype(); + bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && + conv_bias->param().format == + megdnn::param::ConvBias::Format::NCHW4; + if (!is_s8nchw4) + return false; + if (conv_bias->input().size() != 3) + return false; + opr_set.insert(conv_bias); + for (auto&& i : readers[conv_bias]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + for (auto reader : reader_set) { + if (opr_set.count(reader) <= 0) { + return false; + } + } + auto src = rewriter.get_var(conv_bias->input(0)), + filter = rewriter.get_var(conv_bias->input(1)), + bias = rewriter.get_var(conv_bias->input(2)); + auto new_bias = ReformatManager::instance().get(ReformatKey{ + TensorFormats::NCHWc4, TensorFormats::NCHW})({bias}); + new_bias = opr::TypeCvt::make(new_bias, dtype::Float32()).node(); + auto new_param = conv_bias->param(); + new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW; + auto conv_bias_shuffle = opr::ConvBias::make( + src, filter, new_bias, new_param, conv_bias->execution_policy(), + OperatorNodeConfig{dtype::Float32()}); + rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(), + mgb_cstr_log("replace conv_bias + typecvt + " + "dimshuffle + " + "reshape to conv_bias(NCHW4_NCHW)")); + return true; + }; + + auto try_conv_reformat_nchw42nchw32 = [&rewriter, + &readers](OperatorNodeBase* opr) { + ThinHashSet opr_set; + ThinHashSet reader_set; + // check reshape + auto reshape1 = try_cast_as_op(opr); + if (reshape1 == nullptr) + return false; + opr_set.insert(opr); + // check dimshuffle + auto shuffle = try_cast_as_op( + reshape1->input(0)->owner_opr()); + if (shuffle == nullptr) + return false; + auto&& param = shuffle->param(); + if (param.pattern_len != 6) + return false; + bool is_nchw42nchw32 = param.pattern[0] == 0 && param.pattern[1] == 1 && + param.pattern[2] == 3 && param.pattern[3] == 4 && + param.pattern[4] == 2 && param.pattern[5] == 5 && + shuffle->output(0)->shape()[5] == 4 && + shuffle->output(0)->shape()[4] == 8; + if (!is_nchw42nchw32) + return false; + opr_set.insert(shuffle); + for (auto&& i : readers[shuffle]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + // check reshape + auto reshape2 = + try_cast_as_op(shuffle->input(0)->owner_opr()); + if (reshape2 == nullptr) + return false; + opr_set.insert(reshape2); + for (auto&& i : readers[reshape2]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + // check conv bias + auto conv_bias = + try_cast_as_op(reshape2->input(0)->owner_opr()); + if (conv_bias == nullptr) + return false; + auto inp_dtype = conv_bias->input(0)->dtype(); + bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && + conv_bias->param().format == + megdnn::param::ConvBias::Format::NCHW4; + if (!is_s8nchw4) + return false; + if (conv_bias->input().size() != 3) + return false; + opr_set.insert(conv_bias); + for (auto&& i : readers[conv_bias]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + for (auto reader : reader_set) { + if (opr_set.count(reader) <= 0) { + return false; + } + } + auto src = rewriter.get_var(conv_bias->input(0)), + filter = rewriter.get_var(conv_bias->input(1)), + bias = rewriter.get_var(conv_bias->input(2)); + auto new_bias = ReformatManager::instance().get(ReformatKey{ + TensorFormats::NCHWc4, TensorFormats::NCHWc32})({bias}); + auto new_param = conv_bias->param(); + new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW32; + auto conv_bias_shuffle = opr::ConvBias::make( + src, filter, new_bias, new_param, conv_bias->execution_policy(), + conv_bias->config()); + rewriter.replace_var( + opr->output(0), conv_bias_shuffle.node(), + mgb_cstr_log("replace conv_bias + " + "reformat to conv_bias(NCHW4_NCHW32)")); + return true; + }; + + auto try_conv_reformat_nchw42nhwc = [&rewriter, + &readers](OperatorNodeBase* opr) { + ThinHashSet opr_set; + ThinHashSet reader_set; + // check reshape + auto reshape = try_cast_as_op(opr); + if (reshape == nullptr) + return false; + opr_set.insert(opr); + + // check dimshuffle + auto shuffle = + try_cast_as_op(reshape->input(0)->owner_opr()); + if (shuffle == nullptr) + return false; + auto&& param = shuffle->param(); + if (param.pattern_len != 5) + return false; + bool is_nchw42nhwc = param.pattern[0] == 0 && param.pattern[1] == 2 && + param.pattern[2] == 3 && param.pattern[3] == 1 && + param.pattern[4] == 4 && + shuffle->output(0)->shape()[4] == 4; + if (!is_nchw42nhwc) + return false; + opr_set.insert(shuffle); + for (auto&& i : readers[shuffle]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + + auto typecvt = + try_cast_as_op(shuffle->input(0)->owner_opr()); + if (typecvt == nullptr) + return false; + auto in_dtype = typecvt->input(0)->dtype(), + out_dtype = typecvt->output(0)->dtype(); + bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 && + (out_dtype.enumv() == DTypeEnum::QuantizedS4 || + out_dtype.enumv() == DTypeEnum::Quantized4Asymm); + if (!is_s82s4) + return false; + opr_set.insert(typecvt); + for (auto&& i : readers[typecvt]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + + // check conv bias + auto conv_bias = + try_cast_as_op(typecvt->input(0)->owner_opr()); + if (conv_bias == nullptr) + return false; + auto inp_dtype = conv_bias->input(0)->dtype(); + bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && + conv_bias->param().format == + megdnn::param::ConvBias::Format::NCHW4; + if (!is_s8nchw4) + return false; + if (conv_bias->input().size() != 3) + return false; + opr_set.insert(conv_bias); + for (auto&& i : readers[conv_bias]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + for (auto reader : reader_set) { + if (opr_set.count(reader) <= 0) { + return false; + } + } + auto src = rewriter.get_var(conv_bias->input(0)), + filter = rewriter.get_var(conv_bias->input(1)), + bias = rewriter.get_var(conv_bias->input(2)); + auto new_bias = ReformatManager::instance().get(ReformatKey{ + TensorFormats::NCHWc4, TensorFormats::NHWC})({bias}); + auto new_param = conv_bias->param(); + new_param.format = megdnn::param::ConvBias::Format::NCHW4_NHWC; + auto conv_bias_shuffle = opr::ConvBias::make( + src, filter, new_bias, new_param, conv_bias->execution_policy(), + OperatorNodeConfig{out_dtype}); + rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(), + mgb_cstr_log("replace conv_bias + " + "reformat to conv_bias(NCHW4_NHWC)")); + return true; + }; + + auto try_conv_reformat_nchw322nchw4 = [&rewriter, + &readers](OperatorNodeBase* opr) { + ThinHashSet opr_set; + ThinHashSet reader_set; + // check reshape + auto reshape1 = try_cast_as_op(opr); + if (reshape1 == nullptr) + return false; + opr_set.insert(opr); + // check dimshuffle + auto shuffle = try_cast_as_op( + reshape1->input(0)->owner_opr()); + if (shuffle == nullptr) + return false; + auto&& param = shuffle->param(); + if (param.pattern_len != 6) + return false; + bool is_nchw322nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 && + param.pattern[2] == 4 && param.pattern[3] == 2 && + param.pattern[4] == 3 && param.pattern[5] == 5 && + shuffle->input(0)->shape()[5] == 4 && + shuffle->input(0)->shape()[4] == 8; + if (!is_nchw322nchw4) + return false; + opr_set.insert(shuffle); + for (auto&& i : readers[shuffle]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + // check reshape + auto reshape2 = + try_cast_as_op(shuffle->input(0)->owner_opr()); + if (reshape2 == nullptr) + return false; + opr_set.insert(reshape2); + for (auto&& i : readers[reshape2]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + // check conv bias + auto conv_bias = + try_cast_as_op(reshape2->input(0)->owner_opr()); + if (conv_bias == nullptr) + return false; + auto inp_dtype = conv_bias->input(0)->dtype(); + bool is_s8nchw32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && + conv_bias->param().format == + megdnn::param::ConvBias::Format::NCHW32; + if (!is_s8nchw32) + return false; + if (conv_bias->input().size() != 3) + return false; + opr_set.insert(conv_bias); + for (auto&& i : readers[conv_bias]) { + if (i.second & DepType::DEV_VALUE) { + reader_set.insert(i.first); + } + } + for (auto reader : reader_set) { + if (opr_set.count(reader) <= 0) { + return false; + } + } + auto src = rewriter.get_var(conv_bias->input(0)), + filter = rewriter.get_var(conv_bias->input(1)), + bias = rewriter.get_var(conv_bias->input(2)); + auto new_bias = ReformatManager::instance().get(ReformatKey{ + TensorFormats::NCHWc32, TensorFormats::NCHWc4})({bias}); + auto new_param = conv_bias->param(); + new_param.format = megdnn::param::ConvBias::Format::NCHW32_NCHW4; + auto conv_bias_shuffle = opr::ConvBias::make( + src, filter, new_bias, new_param, conv_bias->execution_policy(), + conv_bias->config()); + rewriter.replace_var( + opr->output(0), conv_bias_shuffle.node(), + mgb_cstr_log("replace conv_bias + " + "reformat to conv_bias(NCHW32_NCHW4)")); + return true; + }; + MGB_MARK_USED_VAR(try_conv_reformat_nchw322nchw4); + MGB_MARK_USED_VAR(try_conv_reformat_nchw42nchw32); + + auto on_opr = [&try_conv_dimshuffle_reshape_typecvt, + &try_conv_reformat_nchw42nchw32, + &try_conv_reformat_nchw42nhwc, + &try_conv_reformat_nchw322nchw4, + &rewriter](OperatorNodeBase* opr) { + if (!try_conv_dimshuffle_reshape_typecvt(opr) && + !try_conv_reformat_nchw42nchw32(opr) && + !try_conv_reformat_nchw42nhwc(opr) && + !try_conv_reformat_nchw322nchw4(opr)) { + rewriter.auto_replace_outputs(opr); + } + }; + opt.graph().iter(on_opr); + rewriter.apply_inplace(); + + MIDOUT_E +} +#endif + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/impl/padding_channel.cpp b/src/gopt/impl/padding_channel.cpp new file mode 100644 index 000000000..feec04cc5 --- /dev/null +++ b/src/gopt/impl/padding_channel.cpp @@ -0,0 +1,451 @@ +/** + * \file src/gopt/impl/padding_channel.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain/gopt/inference.h" +#include "megbrain/opr/basic_arith.h" +#include "megbrain/opr/dnn/convolution.h" +#include "megbrain/opr/dnn/pooling.h" +#include "megbrain/opr/imgproc.h" +#include "megbrain/opr/misc.h" +#include "megbrain/opr/nn_int.h" +#include "megbrain/opr/tensor_manip.h" +#include "megbrain/opr/utility.h" +#include "megbrain/serialization/opr_shallow_copy.h" + +#include "megdnn/opr_param_defs.h" +#include "megdnn/tensor_format.h" + +#include "megbrain/opr/internal/megdnn_opr_wrapper.h" + +#include "megbrain/gopt/misc.h" +#include "megbrain/utils/hash_ct.h" + +#include "midout.h" + +#include "megbrain/gopt/reformat_manager.h" + +MIDOUT_DECL(megbrain_padding_channel) +#define MIDOUT_B(tag) \ + MIDOUT_BEGIN(megbrain_padding_channel, midout_iv(MGB_HASH_STR(tag))) { +#define MIDOUT_E \ + } \ + MIDOUT_END(); + +using namespace mgb; +using namespace gopt; +using ReformatKey = ReformatManager::ReformatKey; + +/* ==================== PaddingChannelPass ================= */ +const char* PaddingChannelPass::name() const { + return mgb_cstr_log("padding output channel to multiple of 4/32"); +} + +void PaddingChannelPass::apply(OptState& opt) const { + MIDOUT_B("PaddingChannelPass::apply"); + // do not check shape + opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^ + VarReplaceCheckFlag::CHECK_SHAPE); + + ThinHashSet padding_oprs; + ThinHashMap> + opr_replace_funcs; + + auto rewriter = opt.graph().make_rewriter(); + auto pad_in_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* { + mgb_assert(inp->shape().ndim == 4); + mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 || + inp->dtype().enumv() == DTypeEnum::Quantized4Asymm || + inp->dtype().enumv() == DTypeEnum::QuantizedS8 || + inp->dtype().enumv() == DTypeEnum::QuantizedS32); + TensorShape shape{inp->shape()[0], pad_channels, inp->shape()[2], + inp->shape()[3]}; + std::shared_ptr host_val = + std::make_shared(inp->comp_node(), inp->dtype()); + host_val->resize(shape); + auto ptr = host_val->raw_ptr(); + size_t size_bytes = + TensorLayout{shape, inp->dtype()}.span().dist_byte(); + std::memset(ptr, 0, size_bytes); + auto padding = + opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); + auto out = opr::Concat::make({inp, padding}, 1); + return out.node(); + }; + + auto pad_out_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* { + mgb_assert(inp->shape().ndim == 4); + mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 || + inp->dtype().enumv() == DTypeEnum::Quantized4Asymm || + inp->dtype().enumv() == DTypeEnum::QuantizedS8 || + inp->dtype().enumv() == DTypeEnum::QuantizedS32); + TensorShape shape{pad_channels, inp->shape()[1], inp->shape()[2], + inp->shape()[3]}; + std::shared_ptr host_val = + std::make_shared(inp->comp_node(), inp->dtype()); + host_val->resize(shape); + auto ptr = host_val->raw_ptr(); + size_t size_bytes = + TensorLayout{shape, inp->dtype()}.span().dist_byte(); + std::memset(ptr, 0, size_bytes); + auto padding = + opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); + auto out = opr::Concat::make({inp, padding}, 0); + return out.node(); + }; + + auto extract_subtensor = [](VarNode* inp, + const TensorShape& orig_shape) -> VarNode* { + mgb_assert(inp->shape().ndim == 4); + mgb_assert(inp->shape()[0] == orig_shape[0]); + mgb_assert(inp->shape()[2] == orig_shape[2]); + mgb_assert(inp->shape()[3] == orig_shape[3]); + size_t orig_channels = orig_shape[1]; + auto x = SymbolVar(inp); + auto cv = [&x](int v) { return x.make_scalar(v); }; + using AIdx = opr::Subtensor::AxisIndexer; + auto sub = opr::Subtensor::make( + x, {AIdx::make_interval(0, None, None, cv(1)), + AIdx::make_interval(1, None, cv(orig_channels), None), + AIdx::make_interval(2, None, None, cv(1)), + AIdx::make_interval(3, None, None, cv(1))}); + return sub.node(); + }; + + // padding policy for conv bias with data type qint8 + auto padding_policy_qint8 = [&padding_oprs, &pad_in_channels, + &pad_out_channels]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + mgb_assert(new_inp.size() == 3); + mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); + auto inps = new_inp; + size_t out_channels = opr->input(1)->shape()[0]; + size_t in_channels = opr->input(1)->shape()[1]; + size_t new_in_channels = new_inp[0]->shape()[1]; + // pad input channels + if (padding_oprs.count(opr->input(0)->owner_opr())) { + size_t pad_channels = new_in_channels - in_channels; + inps[1] = pad_in_channels(new_inp[1], pad_channels); + } else { + size_t pad_channels = 0; + mgb_assert(new_in_channels == in_channels); + if (in_channels <= 16) { + if (in_channels % 4) + pad_channels = 4 - (in_channels % 4); // pad to use dp4a + } else { + if (in_channels % 32) + pad_channels = + 32 - (in_channels % 32); // pad to use tensorcore + } + if (pad_channels > 0) { + inps[0] = pad_in_channels(new_inp[0], pad_channels); + inps[1] = pad_in_channels(new_inp[1], pad_channels); + } + } + out_channels = inps[1]->shape()[0]; + in_channels = inps[1]->shape()[1]; + size_t pad_channels = 0; + if (out_channels <= 16) { + if (out_channels % 4) + pad_channels = 4 - (out_channels % 4); + } else { + if (out_channels % 32) + pad_channels = 32 - (out_channels % 32); + } + if (pad_channels > 0) { + inps[1] = pad_out_channels(inps[1], pad_channels); + inps[2] = pad_in_channels(inps[2], pad_channels); + padding_oprs.insert(opr); + } + return serialization::copy_opr_shallow(*opr, inps, opr->config()); + }; + + // padding policy for conv bias with data type qint4 and quint4 + auto padding_policy_int4 = [&padding_oprs, &pad_in_channels, + &pad_out_channels]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + mgb_assert(new_inp.size() == 3); + mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); + auto inps = new_inp; + size_t out_channels = opr->input(1)->shape()[0]; + size_t in_channels = opr->input(1)->shape()[1]; + size_t new_in_channels = new_inp[0]->shape()[1]; + // pad input channels + if (padding_oprs.count(opr->input(0)->owner_opr())) { + if (new_in_channels <= 32) { + if (new_in_channels % 8 == 0) { + size_t pad_channels = new_in_channels - in_channels; + inps[1] = pad_in_channels(new_inp[1], pad_channels); + } else { + size_t pad_channels_0 = 8 - (new_in_channels % 8); + size_t pad_channels_1 = 8 - (in_channels % 8); + inps[0] = pad_in_channels(new_inp[0], pad_channels_0); + inps[1] = pad_in_channels(new_inp[1], pad_channels_1); + } + } else { + if (new_in_channels % 64 == 0) { + size_t pad_channels = new_in_channels - in_channels; + inps[1] = pad_in_channels(new_inp[1], pad_channels); + } else { + size_t pad_channels_0 = 64 - (new_in_channels % 64); + size_t pad_channels_1 = 64 - (in_channels % 64); + inps[0] = pad_in_channels(new_inp[0], pad_channels_0); + inps[1] = pad_in_channels(new_inp[1], pad_channels_1); + } + } + } else { + size_t pad_channels = 0; + mgb_assert(new_in_channels == in_channels); + if (in_channels <= 32) { + if (in_channels % 8) + pad_channels = 8 - (in_channels % 8); + } else { + if (in_channels % 64) + pad_channels = 64 - (in_channels % 64); + } + if (pad_channels > 0) { + inps[0] = pad_in_channels(new_inp[0], pad_channels); + inps[1] = pad_in_channels(new_inp[1], pad_channels); + } + } + out_channels = inps[1]->shape()[0]; + in_channels = inps[1]->shape()[1]; + size_t pad_channels = 0; + if (out_channels <= 32) { + if (out_channels % 8) + pad_channels = 8 - (out_channels % 8); + } else { + if (out_channels % 64) + pad_channels = 64 - (out_channels % 64); + } + if (pad_channels > 0) { + inps[1] = pad_out_channels(inps[1], pad_channels); + inps[2] = pad_in_channels(inps[2], pad_channels); + padding_oprs.insert(opr); + } + return serialization::copy_opr_shallow(*opr, inps, opr->config()); + }; + + opr_replace_funcs[opr::ConvBiasForward::typeinfo()] = + [&padding_oprs, &padding_policy_qint8, &padding_policy_int4]( + OperatorNodeBase* opr, const VarNodeArray& new_inp) { + if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) { + return padding_policy_qint8(opr, new_inp); + } else if (opr->input(0)->dtype().enumv() == + DTypeEnum::QuantizedS4 || + opr->input(0)->dtype().enumv() == + DTypeEnum::Quantized4Asymm) { + return padding_policy_int4(opr, new_inp); + } else { + mgb_assert( + padding_oprs.count(opr->input(0)->owner_opr()) == 0, + "conv bias operator for data type(%s) cannot be " + "padded channel. " + "consumer(%s), producer(%s)", + opr->input(0)->dtype().name(), opr->cname(), + opr->input(0)->owner_opr()->cname()); + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } + }; + opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] = + [&padding_oprs, &pad_in_channels, &pad_out_channels]( + OperatorNodeBase* opr, const VarNodeArray& new_inp) { + if (opr->input(1)->dtype().enumv() != DTypeEnum::QuantizedS8) { + mgb_assert( + padding_oprs.count(opr->input(0)->owner_opr()) == 0, + "conv bwd data operator for data type(%s) cannot " + "be " + "padded channel. " + "consumer(%s), producer(%s)", + opr->input(0)->dtype().name(), opr->cname(), + opr->input(0)->owner_opr()->cname()); + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } + mgb_assert(opr->input().size() == new_inp.size()); + mgb_assert(new_inp.size() == 2, + "deconv (conv bwd data) operator for inference can " + "only have 2 input vars(got:%zu)", + new_inp.size()); + mgb_assert( + opr->input(0)->shape().eq_shape(new_inp[0]->shape())); + auto inps = new_inp; + size_t out_channels = opr->input(0)->shape()[0]; + size_t in_channels = opr->input(0)->shape()[1]; + size_t new_out_channels = new_inp[1]->shape()[1]; + // pad output channels + if (padding_oprs.count(opr->input(1)->owner_opr())) { + size_t pad_channels = new_out_channels - out_channels; + inps[0] = pad_out_channels(new_inp[0], pad_channels); + } else { + size_t pad_channels = 0; + if (out_channels % 4) + pad_channels = 4 - (out_channels % 4); + if (pad_channels > 0) { + inps[0] = pad_out_channels(new_inp[0], pad_channels); + inps[1] = pad_in_channels(new_inp[1], pad_channels); + } + } + out_channels = inps[0]->shape()[0]; + in_channels = inps[0]->shape()[1]; + // pad input channels + size_t pad_channels = 0; + if (in_channels % 4) + pad_channels = 4 - (in_channels % 4); + if (pad_channels > 0) { + inps[0] = pad_in_channels(inps[0], pad_channels); + padding_oprs.insert(opr); + } + return serialization::copy_opr_shallow(*opr, inps, + opr->config()); + }; + auto replace_format_aware_opr = [&padding_oprs]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + if (opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 && + opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS4 && + opr->input(0)->dtype().enumv() != DTypeEnum::Quantized4Asymm) { + mgb_assert(padding_oprs.count(opr->input(0)->owner_opr()) == 0, + "operator(type:%s,name:%s) for data type(%s) cannot be " + "padded channel. extra info:" + "consumer(%s), producer(%s)", + opr->dyn_typeinfo()->name, opr->cname(), + opr->input(0)->dtype().name(), opr->cname(), + opr->input(0)->owner_opr()->cname()); + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } + mgb_assert(opr->input().size() == new_inp.size()); + if (padding_oprs.count(opr->input(0)->owner_opr())) { + padding_oprs.insert(opr); + } + return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); + }; + opr_replace_funcs[opr::PoolingForward::typeinfo()] = + replace_format_aware_opr; + opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] = + replace_format_aware_opr; + + auto replace_elemwise_like_opr = [&padding_oprs, &extract_subtensor]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + bool have_padding_inp = false; + bool padding_all_inps = true; + bool same_padding = true; + size_t channels_after_padding = 0; + size_t i = 0; + for (auto&& cur_inp : opr->input()) { + bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0; + if (padding_cur_inp) { + if (!have_padding_inp) + have_padding_inp = true; + if (channels_after_padding == 0) { + channels_after_padding = new_inp[i]->shape()[1]; + } else { + same_padding = + channels_after_padding == new_inp[i]->shape()[1]; + } + } + if (padding_all_inps && (!padding_cur_inp || !same_padding)) + padding_all_inps = false; + ++i; + } + if (have_padding_inp && !padding_all_inps) { + auto inps = new_inp; + for (size_t i = 0; i < new_inp.size(); ++i) { + auto cur_inp = opr->input(i); + bool padding_cur_inp = + padding_oprs.count(cur_inp->owner_opr()) > 0; + if (padding_cur_inp) { + inps[i] = extract_subtensor(inps[i], cur_inp->shape()); + } + } + return serialization::copy_opr_shallow(*opr, inps, opr->config()); + } + if (padding_all_inps) { + padding_oprs.insert(opr); + } + return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); + }; + opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] = + replace_elemwise_like_opr; + opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr; + opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr; + + auto replace_nonpadding_oprs = [&padding_oprs, &extract_subtensor]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + auto inps = new_inp; + for (size_t i = 0; i < new_inp.size(); ++i) { + auto cur_inp = opr->input(i); + bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0; + if (padding_cur_inp) { + inps[i] = extract_subtensor(inps[i], cur_inp->shape()); + } + } + return serialization::copy_opr_shallow(*opr, inps, opr->config()); + }; + opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs; + opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs; + opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs; + opr_replace_funcs[opr::Reduce::typeinfo()] = replace_nonpadding_oprs; + opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_nonpadding_oprs; + + auto on_opr = [&opt, &rewriter, &opr_replace_funcs, + &extract_subtensor](OperatorNodeBase* opr) { + auto it = opr_replace_funcs.find(opr->dyn_typeinfo()); + if (it != opr_replace_funcs.end()) { + VarNodeArray new_inp; + new_inp.reserve(opr->input().size()); + for (auto&& inp : opr->input()) { + new_inp.push_back(rewriter.get_var(inp)); + } + auto new_opr = (it->second)(opr, new_inp); + auto &&out0 = opr->output(), &&out1 = new_opr->output(); + mgb_assert(out0.size() == out1.size(), + "bad opr replace: src=%s{%s} dst=%s{%s}, " + "src.size=%zu " + "dst.size=%zu", + opr->cname(), opr->dyn_typeinfo()->name, + new_opr->cname(), new_opr->dyn_typeinfo()->name, + out0.size(), out1.size()); + for (size_t i = 0; i < out0.size(); ++i) { + if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { + mgb_assert(!out1[i]->contain_flag( + VarNode::Flag::VOLATILE_CONTENT)); + auto src = out0[i]; + auto dst = out1[i]; + if (opt.graph().endpoint_contain(src) && + !src->shape().eq_shape(dst->shape())) { + dst = extract_subtensor(dst, src->shape()); + } + rewriter.replace_var(src, dst, nullptr); + } + } + } else { + rewriter.auto_replace_outputs(opr); + } + }; + opt.graph().iter(on_opr); + rewriter.apply_inplace(); + + MIDOUT_E +} + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/impl/reformat_manager.cpp b/src/gopt/impl/reformat_manager.cpp index 8fa6e7604..4df3bfc19 100644 --- a/src/gopt/impl/reformat_manager.cpp +++ b/src/gopt/impl/reformat_manager.cpp @@ -11,7 +11,6 @@ */ #include "megbrain/gopt/reformat_manager.h" -#include #include "megbrain/opr/tensor_manip.h" using namespace mgb; @@ -65,6 +64,10 @@ NamedTensorShape tensor_formats_to_named_tensor_shape(TensorFormats format) { return {{"C//8"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}, {"C%8"}}; case TensorFormats::KRSCk8: return {{"K//8"}, {"R"}, {"S"}, {"C"}, {"K%8"}}; + case TensorFormats::KCRSc4: + return {{"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}}; + case TensorFormats::GKCRSc4: + return {{"G"}, {"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}}; case TensorFormats::KCRS: return {{"K"}, {"C"}, {"R"}, {"S"}}; case TensorFormats::GKCRS: @@ -130,70 +133,40 @@ bool ReformatManager::ReformatKey::Equal::operator()( lhs.attribute == rhs.attribute; } +ReformatManager::ReformatKey& +ReformatManager::ReformatKey::deduce_reformat_dtype_enum(const DType& dt) { + static const ThinHashSet> set = { + {TensorFormats::NCHW, TensorFormats::NCHWc64}, + {TensorFormats::NCHWc64, TensorFormats::NCHW}, + {TensorFormats::NCHW, TensorFormats::NHWC}, + {TensorFormats::NHWC, TensorFormats::NCHW}}; + if (set.count({input_format, output_format}) > 0 && + (dt.enumv() == DTypeEnum::QuantizedS4 || + dt.enumv() == DTypeEnum::Quantized4Asymm)) { + input_dtype = output_dtype = dt.enumv(); + } + return *this; +} + // =================== ReformatManager ====================*/ -#define FOREACH_FEATURE_TENSOR_FORMATS(cb) \ - cb(NCHW) cb(NHWC) cb(NCHWc4) cb(NCHWc8) cb(NCHWc32) cb(NCHWc64) cb(CHWNc4) \ - cb(NHCWc4) -#define FOREACH_WEIGHT_TENSOR_FORMATS(cb) \ - cb(KRSCk4) cb(KRSCk4c4) cb(KCRSk4c4) cb(KCRSc4k4) cb(KCRSc8k8) cb(KRSCk8) \ - cb(GKRSCk4) cb(GKRSCk4c4) cb(GKCRSc4k4) cb(GKCRSk4c4) \ - cb(GKCRSc8k8) cb(C11RSc4) cb(C11RSc8) ReformatManager::ReformatManager() { - static constexpr TensorFormats feature_tensor_formats[] = { -#define cb(_fmt) TensorFormats::_fmt, - FOREACH_FEATURE_TENSOR_FORMATS(cb) -#undef cb - }; - static constexpr int nr_feature_tensor_formats = - sizeof(feature_tensor_formats) / sizeof(TensorFormats); - for (int i = 0; i < nr_feature_tensor_formats; ++i) { - for (int o = 0; o < nr_feature_tensor_formats; ++o) { - if (i == o) - continue; - NamedTensorShape input_shape = tensor_formats_to_named_tensor_shape( - feature_tensor_formats[i]); - NamedTensorShape output_shape = - tensor_formats_to_named_tensor_shape( - feature_tensor_formats[o]); - auto impl = std::get<0>( - ReformatEmitter{input_shape, output_shape}.emit()); - m_cache.emplace(ReformatKey{feature_tensor_formats[i], - feature_tensor_formats[o]}, - impl); - } - } - static constexpr TensorFormats default_weight_tensor_formats = - TensorFormats::KCRS; - static constexpr TensorFormats default_group_conv_weight_tensor_formats = - TensorFormats::GKCRS; - static constexpr TensorFormats default_chan_conv_weight_tensor_formats = - TensorFormats::C11RS; - static constexpr TensorFormats weight_tensor_formats[] = { -#define cb(_fmt) TensorFormats::_fmt, - FOREACH_WEIGHT_TENSOR_FORMATS(cb) -#undef cb - }; - static constexpr int nr_weight_tensor_formats = - sizeof(weight_tensor_formats) / sizeof(TensorFormats); - using Name = megdnn::Dimension::Name; - for (int o = 0; o < nr_weight_tensor_formats; ++o) { - NamedTensorShape output_shape = - tensor_formats_to_named_tensor_shape(weight_tensor_formats[o]); - TensorFormats input_format; - if (output_shape[0].name() == Name::G) { - input_format = default_group_conv_weight_tensor_formats; - } else if (output_shape[0].name() == Name::C) { - input_format = default_chan_conv_weight_tensor_formats; - } else { - mgb_assert(output_shape[0].name() == Name::K); - input_format = default_weight_tensor_formats; - } - NamedTensorShape input_shape = - tensor_formats_to_named_tensor_shape(input_format); - auto impl = - std::get<0>(ReformatEmitter{input_shape, output_shape}.emit()); - m_cache.emplace(ReformatKey{input_format, weight_tensor_formats[o]}, - impl); + using Attribute = ReformatKey::Attribute; + { + auto i = TensorFormats::NCHWc4, o = TensorFormats::CHWNc4; + auto&& impl1 = [](const VarNodeArray& vars) { + return opr::RelayoutFormat::make( + vars[0], + megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4) + .node(); + }; + m_cache.emplace(ReformatKey{i, o}, impl1); + auto&& impl2 = [](const VarNodeArray& vars) { + return opr::RelayoutFormat::make( + vars[0], + megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4) + .node(); + }; + m_cache.emplace(ReformatKey{o, i}, impl2); } { auto i = TensorFormats::NCHW, o = TensorFormats::NCHWc4; @@ -206,7 +179,7 @@ ReformatManager::ReformatManager() { m_cache.emplace(ReformatKey{i, o, Attribute::IC_SMALL}, impl); } { - auto i = TensorFormats::KCRS, o = TensorFormats::KCRSc4k4; + auto i = TensorFormats::KCRS, o = TensorFormats::KCRSc4; auto&& impl = [](const VarNodeArray& vars) { return opr::RelayoutFormat::make( vars[0], @@ -238,7 +211,7 @@ ReformatManager::ReformatManager() { auto&& impl = [](const VarNodeArray& vars) { return opr::RelayoutFormat::make( vars[0], - megdnn::param::RelayoutFormat::Mode::NCHW_NCHW64) + megdnn::param::RelayoutFormat::Mode::NCHW64_NCHW) .node(); }; m_cache.emplace( @@ -272,7 +245,7 @@ ReformatManager::ReformatManager() { auto&& impl = [](const VarNodeArray& vars) { return opr::RelayoutFormat::make( vars[0], - megdnn::param::RelayoutFormat::Mode::NCHW_NHWC) + megdnn::param::RelayoutFormat::Mode::NHWC_NCHW) .node(); }; m_cache.emplace( @@ -371,14 +344,23 @@ ReformatManager::ReformatManager() { impl); } } -#undef FOREACH_FEATURE_TENSOR_FORMATS -#undef FOREACH_WEIGHT_TENSOR_FORMATS -const ReformatManager::ReformatImpl& ReformatManager::get( +ReformatManager::ReformatImpl ReformatManager::get( const ReformatKey& key) const { + using Attribute = ReformatKey::Attribute; MGB_TRY { - auto&& impl = m_cache.at(key); - return impl; + auto find = m_cache.find(key); + if (find != m_cache.end()) { + auto rst = find->second; + return rst; + } + mgb_assert(key.attribute == Attribute::DEFAULT); + auto&& i = key.input_format; + auto&& o = key.output_format; + auto ishp = tensor_formats_to_named_tensor_shape(i); + auto oshp = tensor_formats_to_named_tensor_shape(o); + auto builder = std::get<0>(ReformatEmitter{ishp, oshp}.emit()); + return builder; } MGB_CATCH(std::exception & exc, { mgb_log_error( @@ -390,10 +372,7 @@ const ReformatManager::ReformatImpl& ReformatManager::get( } const ReformatManager& ReformatManager::instance() { - static ReformatManager* inst = nullptr; - if (inst == nullptr) { - inst = new ReformatManager(); - } - return *inst; + static ReformatManager inst; + return inst; } // vim: syntax=cpp.doxygen diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp index a6be2bd5c..87e4d4f3e 100644 --- a/src/gopt/impl/tensor_reformat.cpp +++ b/src/gopt/impl/tensor_reformat.cpp @@ -42,6 +42,10 @@ #include "midout.h" +#include "megbrain/gopt/reformat_manager.h" + +#include "./global_layout_transform/utils.h" + MIDOUT_DECL(megbrain_tensor_reformat) #define MIDOUT_B(tag) \ MIDOUT_BEGIN(megbrain_tensor_reformat, midout_iv(MGB_HASH_STR(tag))) { @@ -51,6 +55,7 @@ MIDOUT_DECL(megbrain_tensor_reformat) using namespace mgb; using namespace gopt; +using ReformatKey = ReformatManager::ReformatKey; /* ================ TensorReformatPass =============== */ /*! @@ -67,99 +72,40 @@ using namespace gopt; * representations before being translated to MegBrain oprs, so the * oprs should not get involved in any actual computing. */ +// clang-format off MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder, - cg::SingleCNOperatorNodeBase) // { + cg::SingleCNOperatorNodeBase) // { public: -//! relayout type of this opr -enum class LayoutType { - NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout - NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout - NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout - CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout - NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout - NCHW_TO_NCHW4_IC_SMALL_CONV, ///< from nchw layout to nchw4 whose - ///< channel size less than 4 - NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout - NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout - NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout - - WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4 - //!< layout - WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to - //!< nchw4 layout - WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV, //!< weight from nchw layout - //!< to nchw4 layout whose - //! channel size less than 4 - - WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88 - //!< layout - WEIGHT_NCHW_TO_NCHW88_GROUP, //!< group weight from nchw layout to - //!< nchw88 layout - WEIGHT_NCHW_TO_NCHW88_CHAN, //!< channel wise weight from nchw layout - //!< to nchw88 layout - //!< the weight layout of input is nchw output is nchw88, special for - //!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8} - WEIGHT_HYBIRD_NCHW_NCHW88, - WEIGHT_NCHW_TO_NCHW44_DENSE, //!< weight from nchw layout to nchw44 - //!< layout - WEIGHT_NCHW_TO_NCHW44_GROUP, //!< group weight from nchw layout to - //!< nchw44 layout - WEIGHT_NCHW_TO_NCHW44_CHAN, //!< channel wise weight from nchw layout - //!< to nchw44 layout - //!< the weight layout of input is nchw output is nchw44, special for - //!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4} - WEIGHT_HYBIRD_NCHW_NCHW44, - WEIGHT_NCHW_TO_NCHW44_DOT_DENSE, //!< weight from NCHW44 layout to - //!< NCHW44_DOT layout dense - WEIGHT_NCHW_TO_NCHW44_DOT_GROUP, //!< weight from NCHW44 layout to - //!< NCHW44_DOT layout group - NCHW32_TO_NCHW, //! owner_graph(), {}, "RelayoutPlaceholder", {src_var}), - m_layout_type{layout_type} { + m_key{key} { add_input({src_var}); - add_equivalence_component>(m_layout_type); + add_equivalence_component>(&m_key); + m_output = ReformatManager::instance().get(m_key)({src_var}); add_output(None)->dtype(src_var->dtype()); } @@ -174,360 +120,13 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_comp_node() { void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() { using namespace cg::static_infer; auto&& mgr = owner_graph()->static_infer_manager(); - DepVal deps; - for (auto i : input()) - deps.push_back({i, DepType::SHAPE}); - auto infer_shape = [this](TensorShape& dst, const InpVal& inp) { - TensorShape inp_shape = inp.val[0].shape(); - dst = inp_shape; - if (layout_type() == RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] / 8; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = inp_shape[4] * 8; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32); - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] * 8; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = inp_shape[4] / 8; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); - dst[0] = inp_shape[1]; - dst[1] = inp_shape[2]; - dst[2] = inp_shape[3]; - dst[3] = inp_shape[0]; - dst[4] = inp_shape[4]; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); - dst[0] = inp_shape[3]; - dst[1] = inp_shape[0]; - dst[2] = inp_shape[1]; - dst[3] = inp_shape[2]; - dst[4] = inp_shape[4]; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4 || - layout_type() == RelayoutPlaceholder::LayoutType:: - NCHW_TO_NCHW4_IC_SMALL_CONV) { - if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) { - mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0, - "src shape %s", inp_shape.to_string().c_str()); - } else { - mgb_assert(layout_type() == - RelayoutPlaceholder::LayoutType:: - NCHW_TO_NCHW4_IC_SMALL_CONV); - mgb_assert(inp_shape.ndim == 4 && inp_shape[1] < 4); - } - dst.ndim = 5; - dst[0] = inp_shape[0]; - dst[1] = (inp_shape[1] + 4 - 1) / 4; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = 4; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); - dst.ndim = 4; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] * 4; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - } else if (layout_type() == RelayoutPlaceholder::LayoutType:: - WEIGHT_NCHW_TO_NCHW4_DENSE || - layout_type() == - RelayoutPlaceholder::LayoutType:: - WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV) { - if (layout_type() == - RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE) { - mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0); - } else { - mgb_assert(layout_type() == - RelayoutPlaceholder::LayoutType:: - WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV); - mgb_assert(inp_shape.ndim == 4 && inp_shape[1] < 4); - } - - dst.ndim = 5; - dst[0] = inp_shape[0]; - dst[1] = (inp_shape[1] + 4 - 1) / 4; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = 4; - } else if (layout_type() == RelayoutPlaceholder::LayoutType:: - WEIGHT_NCHW_TO_NCHW4_GROUP) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[2] % 4 == 0); - dst.ndim = 6; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1]; - dst[2] = inp_shape[2] / 4; - dst[3] = inp_shape[3]; - dst[4] = inp_shape[4]; - dst[5] = 4; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW88) { - mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 8 == 0); - dst.ndim = 5; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] / 8; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = 8; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 8); - dst.ndim = 4; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] * 8; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - } else if (layout_type() == RelayoutPlaceholder::LayoutType:: - WEIGHT_NCHW_TO_NCHW88_DENSE) { - mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0 && - inp_shape[1] % 8 == 0); - dst.ndim = 6; - dst[0] = inp_shape[0] / 8; - dst[1] = inp_shape[1] / 8; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = 8; - dst[5] = 8; - } else if (layout_type() == RelayoutPlaceholder::LayoutType:: - WEIGHT_NCHW_TO_NCHW88_GROUP) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 8 == 0 && - inp_shape[2] % 8 == 0); - dst.ndim = 7; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] / 8; - dst[2] = inp_shape[2] / 8; - dst[3] = inp_shape[3]; - dst[4] = inp_shape[4]; - dst[5] = 8; - dst[6] = 8; - } else if (layout_type() == RelayoutPlaceholder::LayoutType:: - WEIGHT_NCHW_TO_NCHW88_CHAN) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[1] == 1 && - inp_shape[2] == 1 && inp_shape[0] % 8 == 0); - dst.ndim = 6; - dst[0] = inp_shape[0] / 8; - dst[1] = inp_shape[1]; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = inp_shape[4]; - dst[5] = 8; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88) { - mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0); - dst.ndim = 5; - dst[0] = inp_shape[0] / 8; - dst[1] = inp_shape[2]; - dst[2] = inp_shape[3]; - dst[3] = inp_shape[1]; - dst[4] = 8; - } else if (layout_type() == RelayoutPlaceholder::LayoutType:: - WEIGHT_NCHW_TO_NCHW44_DENSE || - layout_type() == RelayoutPlaceholder::LayoutType:: - WEIGHT_NCHW_TO_NCHW44_DOT_DENSE) { - mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0 && - inp_shape[1] % 4 == 0); - dst.ndim = 6; - dst[0] = inp_shape[0] / 4; - dst[1] = inp_shape[1] / 4; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = 4; - dst[5] = 4; - } else if (layout_type() == RelayoutPlaceholder::LayoutType:: - WEIGHT_NCHW_TO_NCHW44_GROUP || - layout_type() == RelayoutPlaceholder::LayoutType:: - WEIGHT_NCHW_TO_NCHW44_DOT_GROUP) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 4 == 0 && - inp_shape[2] % 4 == 0); - dst.ndim = 7; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] / 4; - dst[2] = inp_shape[2] / 4; - dst[3] = inp_shape[3]; - dst[4] = inp_shape[4]; - dst[5] = 4; - dst[6] = 4; - } else if (layout_type() == RelayoutPlaceholder::LayoutType:: - WEIGHT_NCHW_TO_NCHW44_CHAN) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[1] == 1 && - inp_shape[2] == 1 && inp_shape[0] % 4 == 0); - dst.ndim = 6; - dst[0] = inp_shape[0] / 4; - dst[1] = inp_shape[1]; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = inp_shape[4]; - dst[5] = 4; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW44) { - mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0); - dst.ndim = 5; - dst[0] = inp_shape[0] / 4; - dst[1] = inp_shape[2]; - dst[2] = inp_shape[3]; - dst[3] = inp_shape[1]; - dst[4] = 4; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32); - dst.ndim = 4; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] * 32; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW64) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 2 == 0 && - inp_shape[4] == 32); - dst.ndim = 5; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] / 2; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = 64; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 64); - dst.ndim = 4; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] * 64; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW4) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 64); - dst.ndim = 5; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] * 16; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = 4; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW32) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 64); - dst.ndim = 5; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] * 2; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = 32; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW64) { - mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 64 == 0, "%s", - inp_shape.to_string().c_str()); - dst.ndim = 5; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] / 64; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = 64; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW32) { - mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 32 == 0); - dst.ndim = 5; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] / 32; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = 32; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 16 == 0); - dst.ndim = 5; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] / 16; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = 64; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW_TO_NHWC) { - mgb_assert(inp_shape.ndim == 4); - dst.ndim = 4; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[2]; - dst[2] = inp_shape[3]; - dst[3] = inp_shape[1]; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW4_TO_NHWC) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); - dst.ndim = 4; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[2]; - dst[2] = inp_shape[3]; - dst[3] = inp_shape[1] * 4; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW32_TO_NHWC) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32); - dst.ndim = 4; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[2]; - dst[2] = inp_shape[3]; - dst[3] = inp_shape[1] * 32; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW64_TO_NHWC) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 64); - dst.ndim = 4; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[2]; - dst[2] = inp_shape[3]; - dst[3] = inp_shape[1] * 64; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW) { - mgb_assert(inp_shape.ndim == 4); - dst.ndim = 4; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[3]; - dst[2] = inp_shape[1]; - dst[3] = inp_shape[2]; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW4) { - mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 4 == 0); - dst.ndim = 5; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[3] / 4; - dst[2] = inp_shape[1]; - dst[3] = inp_shape[2]; - dst[4] = 4; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW32) { - mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 32 == 0); - dst.ndim = 5; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[3] / 32; - dst[2] = inp_shape[1]; - dst[3] = inp_shape[2]; - dst[4] = 32; - } else { - mgb_assert(layout_type() == - RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64); - mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 64 == 0); - dst.ndim = 5; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[3] / 64; - dst[2] = inp_shape[1]; - dst[3] = inp_shape[2]; - dst[4] = 64; - } - return true; - }; - mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_shape}); + mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(m_output)); } SymbolVar TensorReformatPass::RelayoutPlaceholder::make( - VarNode* src_var, LayoutType layout_type) { + VarNode* src_var, const ReformatKey& key) { return src_var->owner_graph() - ->insert_opr( - std::make_unique(src_var, layout_type)) + ->insert_opr(std::make_unique(src_var, key)) ->output(0); } @@ -576,541 +175,13 @@ void TensorReformatPass::insert_pass(OptState& opt) const { } void TensorReformatPass::translate_pass(OptState& opt) const { - ThinHashMap> - reformat; - using LayoutType = RelayoutPlaceholder::LayoutType; - reformat[LayoutType::NCHW4_TO_CHWN4] = [](VarNode* inp) -> VarNode* { - megdnn::param::RelayoutFormat param; - param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4; - auto reformat = opr::RelayoutFormat::make(inp, param); - return reformat.node(); - }; - reformat[LayoutType::CHWN4_TO_NCHW4] = [](VarNode* inp) -> VarNode* { - megdnn::param::RelayoutFormat param; - param.mode = megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4; - auto reformat = opr::RelayoutFormat::make(inp, param); - return reformat.node(); - }; - reformat[LayoutType::NCHW4_TO_NCHW32] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)}, 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::NCHW32_TO_NCHW4] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8}, 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - - reformat[LayoutType::NCHW_TO_NCHW4_IC_SMALL_CONV] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto y = opr::RelayoutFormat::make( - x, megdnn::param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL); - return y.node(); - }; - reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto y = opr::RelayoutFormat::make( - x, megdnn::param::RelayoutFormat::Mode:: - NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT); - return y.node(); - }; - - reformat[LayoutType::NCHW_TO_NCHW4] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); - return y1.node(); - }; - reformat[LayoutType::NCHW4_TO_NCHW] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0); - auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); - auto y1 = opr::Reshape::make(y0, tshp0); - return y1.node(); - }; - reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1), sub(2) / 4, cv(4), sub(3), sub(4)}, 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1), sub(2) / 4, sub(3), sub(4), cv(4)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 2, 4, 5, 3}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::NCHW_TO_NCHW88] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1) / 8, cv(8), sub(2), sub(3)}, 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) / 8, sub(2), sub(3), cv(8)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::NCHW88_TO_NCHW] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make({sub(0), sub(1) * 8, sub(2), sub(3)}, 0); - auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); - auto y1 = opr::Reshape::make(y0, tshp0); - return y1.node(); - }; - reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_DENSE] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0) / 8, cv(8), sub(1) / 8, cv(8), sub(2), sub(3)}, 0), - tshp1 = opr::Concat::make( - {sub(0) / 8, sub(1) / 8, sub(2), sub(3), cv(8), cv(8)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 3, 1}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_GROUP] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make({sub(0), sub(1) / 8, cv(8), sub(2) / 8, - cv(8), sub(3), sub(4)}, - 0), - tshp1 = opr::Concat::make({sub(0), sub(1) / 8, sub(2) / 8, sub(3), - sub(4), cv(8), cv(8)}, - 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 4, 2}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_CHAN] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0) / 8, cv(8), sub(1), sub(2), sub(3), sub(4)}, 0), - tshp1 = opr::Concat::make( - {sub(0) / 8, sub(1), sub(2), sub(3), sub(4), cv(8)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 2, 3, 4, 5, 1}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0) / 8, cv(8), sub(1), sub(2), sub(3)}, 0), - tshp1 = opr::Concat::make( - {sub(0) / 8, sub(2), sub(3), sub(1), cv(8)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 3, 4, 2, 1}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DENSE] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0) / 4, cv(4), sub(1) / 4, cv(4), sub(2), sub(3)}, 0), - tshp1 = opr::Concat::make( - {sub(0) / 4, sub(1) / 4, sub(2), sub(3), cv(4), cv(4)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 3, 1}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_GROUP] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4, - cv(4), sub(3), sub(4)}, - 0), - tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3), - sub(4), cv(4), cv(4)}, - 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 4, 2}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_CHAN] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0) / 4, cv(4), sub(1), sub(2), sub(3), sub(4)}, 0), - tshp1 = opr::Concat::make( - {sub(0) / 4, sub(1), sub(2), sub(3), sub(4), cv(4)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 2, 3, 4, 5, 1}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::WEIGHT_HYBIRD_NCHW_NCHW44] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0) / 4, cv(4), sub(1), sub(2), sub(3)}, 0), - tshp1 = opr::Concat::make( - {sub(0) / 4, sub(2), sub(3), sub(1), cv(4)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 3, 4, 2, 1}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0) / 4, cv(4), sub(1) / 4, cv(4), sub(2), sub(3)}, 0), - tshp1 = opr::Concat::make( - {sub(0) / 4, sub(1) / 4, sub(2), sub(3), cv(4), cv(4)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 1, 3}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4, - cv(4), sub(3), sub(4)}, - 0), - tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3), - sub(4), cv(4), cv(4)}, - 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 2, 4}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::NCHW32_TO_NCHW] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = - opr::Concat::make({sub(0), sub(1) * 32, sub(2), sub(3)}, 0); - auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); - auto y1 = opr::Reshape::make(y0, tshp0); - return y1.node(); - }; - reformat[LayoutType::NCHW32_TO_NCHW64] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1) / 2, cv(2), sub(2), sub(3), sub(4)}, 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) / 2, sub(2), sub(3), sub(4) * 2}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::NCHW64_TO_NCHW] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = - opr::Concat::make({sub(0), sub(1) * 64, sub(2), sub(3)}, 0); - auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); - auto y1 = opr::Reshape::make(y0, tshp0); - return y1.node(); - }; - reformat[LayoutType::NCHW64_TO_NCHW4] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1), sub(2), sub(3), sub(4) / 4, cv(4)}, 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) * 16, sub(2), sub(3), cv(4)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::NCHW64_TO_NCHW32] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1), sub(2), sub(3), sub(4) / 32, cv(32)}, 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) * 2, sub(2), sub(3), cv(32)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::NCHW_TO_NCHW64] = [](VarNode* inp) -> VarNode* { - megdnn::param::RelayoutFormat param; - param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NCHW64; - auto reformat = opr::RelayoutFormat::make(inp, param); - return reformat.node(); - }; - reformat[LayoutType::NCHW_TO_NCHW32] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1) / 32, cv(32), sub(2), sub(3)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); - return y1.node(); - }; - reformat[LayoutType::NCHW4_TO_NCHW64] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1) / 16, cv(16), sub(2), sub(3), sub(4)}, 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) / 16, sub(2), sub(3), sub(4) * 16}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::NCHW_TO_NHWC] = [](VarNode* inp) -> VarNode* { - megdnn::param::RelayoutFormat param; - param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWC; - auto reformat = opr::RelayoutFormat::make(inp, param); - return reformat.node(); - }; - reformat[LayoutType::NCHW4_TO_NHWC] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0); - auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4}); - auto y1 = opr::Reshape::make(y0, tshp0); - return y1.node(); - }; - reformat[LayoutType::NCHW32_TO_NHWC] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = - opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 32}, 0); - auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4}); - auto y1 = opr::Reshape::make(y0, tshp0); - return y1.node(); - }; - reformat[LayoutType::NCHW64_TO_NHWC] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = - opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 64}, 0); - auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4}); - auto y1 = opr::Reshape::make(y0, tshp0); - return y1.node(); - }; - reformat[LayoutType::NHWC_TO_NCHW] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto y = opr::Dimshuffle::make(x, {0, 3, 1, 2}); - return y.node(); - }; - reformat[LayoutType::NHWC_TO_NCHW4] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1), sub(2), sub(3) / 4, cv(4)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4}); - return y1.node(); - }; - reformat[LayoutType::NHWC_TO_NCHW32] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1), sub(2), sub(3) / 32, cv(32)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4}); - return y1.node(); - }; - reformat[LayoutType::NHWC_TO_NCHW64] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4}); - return y1.node(); - }; - auto rewriter = opt.graph().make_rewriter(); - auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) { + auto on_opr = [&rewriter](OperatorNodeBase* opr) { if (opr->same_type()) { auto ph = try_cast_as_op(opr); auto new_inp = rewriter.get_var(opr->input(0)); - mgb_assert(reformat.count(ph->layout_type()), - "no replace rule can be found for layout_type(%u)", - static_cast(ph->layout_type())); - auto new_var = reformat[ph->layout_type()](new_inp); + auto new_var = + ReformatManager::instance().get(ph->key())({new_inp}); rewriter.replace_var(opr->output(0), new_var, mgb_cstr_log("replace relayout placeholder")); return; @@ -1132,9 +203,9 @@ void TensorReformatPass::apply(OptState& opt) const { VarNode* EnableTensorCorePass::on_graph_endpoint_var(VarNode* new_var, VarNode* orig_var) const { if (!orig_var->shape().eq_shape(new_var->shape())) { - return RelayoutPlaceholder::make( - new_var, - RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4) + return RelayoutPlaceholder::make(new_var, + ReformatKey{TensorFormats::NCHWc32, + TensorFormats::NCHWc4}) .node(); } return new_var; @@ -1204,8 +275,8 @@ EnableTensorCorePass::make_tensorcore_converter() { if (group == 1 && ocpg % 32 == 0 && icpg % 32 == 0 && ih >= 3 && iw >= 3) { auto symvar = RelayoutPlaceholder::make( - new_inp[0], - RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32); + new_inp[0], ReformatKey{TensorFormats::NCHWc4, + TensorFormats::NCHWc32}); src = symvar.node(); can_replace_nchw32 = true; } else { @@ -1232,8 +303,8 @@ EnableTensorCorePass::make_tensorcore_converter() { src = new_inp[0]; } else { auto symvar = RelayoutPlaceholder::make( - new_inp[0], - RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4); + new_inp[0], ReformatKey{TensorFormats::NCHWc32, + TensorFormats::NCHWc4}); src = symvar.node(); } } @@ -1241,7 +312,7 @@ EnableTensorCorePass::make_tensorcore_converter() { if (can_replace_nchw32) { auto symvar = RelayoutPlaceholder::make( new_inp[1], - RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32); + ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHWc32}); weight = symvar.node(); } else { weight = new_inp[1]; @@ -1265,8 +336,8 @@ EnableTensorCorePass::make_tensorcore_converter() { if (can_replace_nchw32) { if (is_nchw4(inp->shape())) { auto symvar = RelayoutPlaceholder::make( - inp, - RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32); + inp, ReformatKey{TensorFormats::NCHWc4, + TensorFormats::NCHWc32}); return symvar.node(); } else { mgb_assert(is_nchw32(inp->shape())); @@ -1278,8 +349,8 @@ EnableTensorCorePass::make_tensorcore_converter() { } else { mgb_assert(is_nchw32(inp->shape())); auto symvar = RelayoutPlaceholder::make( - inp, - RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4); + inp, ReformatKey{TensorFormats::NCHWc32, + TensorFormats::NCHWc4}); return symvar.node(); } } @@ -1335,8 +406,9 @@ EnableTensorCorePass::make_tensorcore_converter() { for (size_t i = 0; i < nr_inps; ++i) { if (opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { auto symvar = RelayoutPlaceholder::make( - new_inp[i], RelayoutPlaceholder::LayoutType:: - NCHW4_TO_NCHW32); + new_inp[i], + ReformatKey{TensorFormats::NCHWc4, + TensorFormats::NCHWc32}); inps[i] = symvar.node(); } } @@ -1344,8 +416,8 @@ EnableTensorCorePass::make_tensorcore_converter() { for (size_t i = 0; i < nr_inps; ++i) { if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { auto symvar = RelayoutPlaceholder::make( - new_inp[i], RelayoutPlaceholder::LayoutType:: - NCHW32_TO_NCHW4); + new_inp[i], ReformatKey{TensorFormats::NCHWc32, + TensorFormats::NCHWc4}); inps[i] = symvar.node(); } } @@ -1366,8 +438,8 @@ EnableTensorCorePass::make_tensorcore_converter() { mgb_assert(new_inp[i]->shape().ndim == 5 && new_inp[i]->shape()[4] == 32); auto symvar = RelayoutPlaceholder::make( - new_inp[i], - RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4); + new_inp[i], ReformatKey{TensorFormats::NCHWc32, + TensorFormats::NCHWc4}); inps[i] = symvar.node(); } } @@ -1447,8 +519,8 @@ EnableTensorCorePass::make_tensorcore_converter() { if (opr->input(0)->shape().eq_shape(new_inp[0]->shape())) { new_inp_var = RelayoutPlaceholder::make( - new_inp[0], RelayoutPlaceholder::LayoutType:: - NCHW4_TO_NCHW32) + new_inp[0], ReformatKey{TensorFormats::NCHWc4, + TensorFormats::NCHWc32}) .node(); } else { mgb_assert(opr->input(0)->shape().ndim == 5 && @@ -1497,8 +569,9 @@ EnableTensorCorePass::make_tensorcore_converter() { VarNode* EnableCHWN4Pass::on_graph_endpoint_var(VarNode* new_var, VarNode* /* orig_var */) const { if (m_varshape_changed.count(new_var)) { - return RelayoutPlaceholder::make( - new_var, RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4) + return RelayoutPlaceholder::make(new_var, + ReformatKey{TensorFormats::CHWNc4, + TensorFormats::NCHWc4}) .node(); } return new_var; @@ -1547,7 +620,7 @@ std::unique_ptr EnableCHWN4Pass::make_chwn4_converter() { // currently not support group conv auto symvar = RelayoutPlaceholder::make( new_inp[0], - RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4); + ReformatKey{TensorFormats::NCHWc4, TensorFormats::CHWNc4}); src = symvar.node(); } else { // new input is NCHW32 layout src = new_inp[0]; @@ -1556,7 +629,7 @@ std::unique_ptr EnableCHWN4Pass::make_chwn4_converter() { { auto symvar = RelayoutPlaceholder::make( new_inp[1], - RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4); + ReformatKey{TensorFormats::NCHWc4, TensorFormats::CHWNc4}); weight = symvar.node(); } if (new_inp.size() == 2) { @@ -1571,7 +644,8 @@ std::unique_ptr EnableCHWN4Pass::make_chwn4_converter() { auto process_inp = [&](VarNode* inp) -> VarNode* { if (varshape_changed.count(inp) == 0) { auto symvar = RelayoutPlaceholder::make( - inp, RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4); + inp, ReformatKey{TensorFormats::NCHWc4, + TensorFormats::CHWNc4}); return symvar.node(); } else { return inp; @@ -1618,8 +692,8 @@ std::unique_ptr EnableCHWN4Pass::make_chwn4_converter() { for (size_t i = 0; i < nr_inps; ++i) { if (varshape_changed.count(new_inp[i]) == 0) { auto symvar = RelayoutPlaceholder::make( - new_inp[i], RelayoutPlaceholder::LayoutType:: - NCHW4_TO_CHWN4); + new_inp[i], ReformatKey{TensorFormats::NCHWc4, + TensorFormats::CHWNc4}); inps[i] = symvar.node(); } } @@ -1631,8 +705,8 @@ std::unique_ptr EnableCHWN4Pass::make_chwn4_converter() { for (size_t i = 0; i < nr_inps; ++i) { if (varshape_changed.count(new_inp[i])) { auto symvar = RelayoutPlaceholder::make( - new_inp[i], RelayoutPlaceholder::LayoutType:: - CHWN4_TO_NCHW4); + new_inp[i], ReformatKey{TensorFormats::CHWNc4, + TensorFormats::NCHWc4}); inps[i] = symvar.node(); } } @@ -1651,8 +725,8 @@ std::unique_ptr EnableCHWN4Pass::make_chwn4_converter() { for (size_t i = 0; i < opr->input().size(); ++i) { if (varshape_changed.count(new_inp[i])) { auto symvar = RelayoutPlaceholder::make( - new_inp[i], - RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4); + new_inp[i], ReformatKey{TensorFormats::CHWNc4, + TensorFormats::NCHWc4}); inps[i] = symvar.node(); } } @@ -1770,7 +844,8 @@ VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var, VarNode* orig_var) const { if (!orig_var->shape().eq_shape(new_var->shape())) { return RelayoutPlaceholder::make( - new_var, RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW) + new_var, + ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHW}) .node(); } return new_var; @@ -1781,7 +856,6 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { MIDOUT_B("EnableNCHW4Pass::make") auto ret = std::make_unique(); ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); - using RelayoutMode = RelayoutPlaceholder::LayoutType; megdnn::param::Convolution::Format conv_format = megdnn::param::Convolution::Format::NCHW4; megdnn::param::ConvBias::Format conv_bias_format = @@ -1790,16 +864,16 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { megdnn::param::ConvBias::Format::NCHW4_NCHW; megdnn::param::BatchConvBias::Format batch_conv_bias_format = megdnn::param::BatchConvBias::Format::NCHW4; - RelayoutMode src_to_nchw4_mode = RelayoutMode::NCHW_TO_NCHW4; - RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW; - RelayoutMode weight_to_nchw4_mode_dense = - RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE; - RelayoutMode weight_to_nchw4_mode_group = - RelayoutMode::WEIGHT_NCHW_TO_NCHW4_GROUP; + ReformatKey src_to_nchw4_mode{TensorFormats::NCHW, TensorFormats::NCHWc4}; + ReformatKey src_to_nchw_mode{TensorFormats::NCHWc4, TensorFormats::NCHW}; + ReformatKey weight_to_nchw4_mode_dense{TensorFormats::KCRS, + TensorFormats::KCRSc4}; + ReformatKey weight_to_nchw4_mode_group{TensorFormats::GKCRS, + TensorFormats::GKCRSc4}; struct ConvMode { - RelayoutMode weight; - RelayoutMode src; + ReformatKey weight; + ReformatKey src; }; auto trans_nchw4 = @@ -1812,8 +886,11 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { "The origin filter is not NCHW mode"); size_t IC = filter->shape()[1]; if (IC < 4) { - return {RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV, - RelayoutMode::NCHW_TO_NCHW4_IC_SMALL_CONV}; + ReformatKey weight{TensorFormats::KCRS, TensorFormats::KCRSc4, + ReformatKey::Attribute::IC_SMALL}; + ReformatKey src{TensorFormats::NCHW, TensorFormats::NCHWc4, + ReformatKey::Attribute::IC_SMALL}; + return {weight, src}; } else { return {weight_to_nchw4_mode_dense, src_to_nchw4_mode}; } @@ -1869,8 +946,8 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { }; auto replace_deconv_opr = [trans_nchw4, conv_format]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { + OperatorNodeBase* opr, + const VarNodeArray& new_inp) { if (new_inp[1]->dtype().enumv() == DTypeEnum::Float32) { return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); @@ -1885,8 +962,7 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { opr->config()); } VarNode *deconv_src = new_inp[1], *deconv_filter = new_inp[0]; - auto deconv_mode = - trans_nchw4(deconv_opr.param().sparse, deconv_filter); + auto deconv_mode = trans_nchw4(deconv_opr.param().sparse, deconv_filter); // src: NCHW --> NCWH4 if (deconv_src->shape().ndim != 5) { mgb_assert(deconv_src->shape().ndim == 4); @@ -1940,8 +1016,7 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { } // weight: BNCHW --> BNCHW4 // only support dense mode, which is similar with conv->group. - auto weight_mode = - RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP; + ReformatKey weight_mode{TensorFormats::GKCRS, TensorFormats::GKCRSc4}; auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode); filter = new_filter.node(); // format: NCHW --> NCHW4 @@ -2033,10 +1108,10 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { conv_bias_src, conv_bias_filter, new_param, conv_bias_opr.execution_policy(), conv_bias_opr.config()); OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); - mgb_assert(new_conv_bias_opr.node()->dtype().enumv() == - DTypeEnum::Float32 || - new_conv_bias_opr.shape().ndim == 5, - "The conv_bias dst dim is not trans to nchw4"); + mgb_assert( + new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 || + new_conv_bias_opr.shape().ndim == 5, + "The conv_bias dst dim is not trans to nchw4"); return new_opr; } // bias: NCHW --> NCHW4 when bias_dtype is not Float32 @@ -2052,10 +1127,10 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { conv_bias_src, conv_bias_filter, conv_bias_bias, new_param, conv_bias_opr.execution_policy(), conv_bias_opr.config()); OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); - mgb_assert(new_conv_bias_opr.node()->dtype().enumv() == - DTypeEnum::Float32 || - new_conv_bias_opr.shape().ndim == 5, - "The conv_bias dst dim is not trans to nchw4"); + mgb_assert( + new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 || + new_conv_bias_opr.shape().ndim == 5, + "The conv_bias dst dim is not trans to nchw4"); return new_opr; } // z_inp: NCHW --> NCHW4 when bias_dtype is not Float32 @@ -2071,10 +1146,10 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { new_param, conv_bias_opr.execution_policy(), conv_bias_opr.config()); OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); - mgb_assert(new_conv_bias_opr.node()->dtype().enumv() == - DTypeEnum::Float32 || - new_conv_bias_opr.shape().ndim == 5, - "The conv_bias dst dim is not trans to nchw4"); + mgb_assert( + new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 || + new_conv_bias_opr.shape().ndim == 5, + "The conv_bias dst dim is not trans to nchw4"); return new_opr; }; auto replace_elemwise_opr = [=](OperatorNodeBase* opr, @@ -2215,7 +1290,8 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { auto&& replace_func = ret->m_opr_replace_func; //! supportted nchw4 replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; - replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr; + replace_func[opr::ConvolutionBackwardData::typeinfo()] = + replace_deconv_opr; replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; replace_func[opr::BatchConvBias::typeinfo()] = replace_batch_conv_bias_opr; replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; @@ -2244,14 +1320,14 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var, VarNode* orig_var) const { if (!orig_var->shape().eq_shape(new_var->shape())) { if (m_pack_c_size == 8) { - return RelayoutPlaceholder::make( - new_var, - RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW) + return RelayoutPlaceholder::make(new_var, + ReformatKey{TensorFormats::NCHWc8, + TensorFormats::NCHW}) .node(); } else if (m_pack_c_size == 4) { - return RelayoutPlaceholder::make( - new_var, - RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW) + return RelayoutPlaceholder::make(new_var, + ReformatKey{TensorFormats::NCHWc4, + TensorFormats::NCHW}) .node(); } } @@ -2335,17 +1411,16 @@ static inline bool nchw_nchwxx_valid( } void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { - using RelayoutMode = RelayoutPlaceholder::LayoutType; - using TestFilterResult = std::pair; - RelayoutMode weight_to_nchwxx_mode_dense = - RelayoutMode::WEIGHT_NCHW_TO_NCHW88_DENSE; - RelayoutMode weight_to_nchwxx_mode_group = - RelayoutMode::WEIGHT_NCHW_TO_NCHW88_GROUP; - RelayoutMode weight_to_nchwxx_mode_chan = - RelayoutMode::WEIGHT_NCHW_TO_NCHW88_CHAN; - RelayoutMode hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW88; - RelayoutMode src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW88; - RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW88_TO_NCHW; + using TestFilterResult = std::pair; + ReformatKey weight_to_nchwxx_mode_dense{TensorFormats::KCRS, + TensorFormats::KCRSc8k8}; + ReformatKey weight_to_nchwxx_mode_group{TensorFormats::GKCRS, + TensorFormats::GKCRSc8k8}; + ReformatKey weight_to_nchwxx_mode_chan{TensorFormats::C11RS, + TensorFormats::C11RSc8}; + ReformatKey hybrid_nchw_nchwxx{TensorFormats::KCRS, TensorFormats::KRSCk8}; + ReformatKey src_to_nchwxx_mode{TensorFormats::NCHW, TensorFormats::NCHWc8}; + ReformatKey src_to_nchw_mode{TensorFormats::NCHWc8, TensorFormats::NCHW}; megdnn::param::ConvBias::Format conv_bias_format = megdnn::param::ConvBias::Format::NCHW88; megdnn::param::Convolution::Format conv_format = @@ -2357,12 +1432,12 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { std::string convter_pass_name = "conv_format_nchw88"; if (pack_c_size == 4) { - weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE; - weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP; - weight_to_nchwxx_mode_chan = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN; - hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44; - src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW4; - src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW; + weight_to_nchwxx_mode_dense.output_format = TensorFormats::KCRSc4k4; + weight_to_nchwxx_mode_group.output_format = TensorFormats::GKCRSc4k4; + weight_to_nchwxx_mode_chan.output_format = TensorFormats::C11RSc4; + hybrid_nchw_nchwxx.output_format = TensorFormats::KRSCk4; + src_to_nchwxx_mode.output_format = TensorFormats::NCHWc4; + src_to_nchw_mode.input_format = TensorFormats::NCHWc4; conv_bias_format = megdnn::param::ConvBias::Format::NCHW44; conv_format = megdnn::param::Convolution::Format::NCHW44; pooling_format = megdnn::param::Pooling::Format::NCHW44; @@ -2641,7 +1716,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { return new_opr; } }; - + auto replace_resize_opr = [=](OperatorNodeBase* opr, const VarNodeArray& new_inp) { mgb_assert(opr->input().size() == new_inp.size()); @@ -2793,7 +1868,8 @@ VarNode* EnableNchw44DotPass::on_graph_endpoint_var(VarNode* new_var, VarNode* orig_var) const { if (!orig_var->shape().eq_shape(new_var->shape())) { return RelayoutPlaceholder::make( - new_var, RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW) + new_var, + ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHW}) .node(); } return new_var; @@ -2807,10 +1883,9 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { //! First is whether the conv can trans to nchwxx, second is the filter //! trans mode - using RelayoutMode = RelayoutPlaceholder::LayoutType; struct TestTransResult { TransType trans_type; - RelayoutMode relayout_mod; + ReformatKey relayout_mod; megdnn::param::Convolution::Format conv_format; }; constexpr size_t pack_c_size = 4_z; @@ -2828,18 +1903,19 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { ret.trans_type = TransType::TRANS_PURE_NCHWXX; if (is_int8) { - ret.relayout_mod = - RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE; + ret.relayout_mod = ReformatKey{TensorFormats::KCRS, + TensorFormats::KCRSk4c4}; ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; } else { - ret.relayout_mod = - RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE; + ret.relayout_mod = ReformatKey{TensorFormats::KCRS, + TensorFormats::KCRSc4k4}; ret.conv_format = megdnn::param::ConvBias::Format::NCHW44; } } else if (valid_nchw_nchw44) { ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX; - ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44; + ret.relayout_mod = + ReformatKey{TensorFormats::KCRS, TensorFormats::KRSCk4}; if (is_int8) { ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; @@ -2855,18 +1931,19 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { size_t icpg = filter->shape()[2]; if (icpg == 1 && ocpg == 1 && (group % pack_c_size == 0)) { ret.trans_type = TransType::TRANS_PURE_NCHWXX; - ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN; + ret.relayout_mod = ReformatKey{TensorFormats::C11RS, + TensorFormats::C11RSc4}; ret.conv_format = megdnn::param::ConvBias::Format::NCHW44; } else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) { ret.trans_type = TransType::TRANS_PURE_NCHWXX; if (is_int8) { - ret.relayout_mod = - RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP; + ret.relayout_mod = ReformatKey{TensorFormats::GKCRS, + TensorFormats::GKCRSk4c4}; ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; } else { - ret.relayout_mod = - RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP; + ret.relayout_mod = ReformatKey{TensorFormats::GKCRS, + TensorFormats::GKCRSc4k4}; ret.conv_format = megdnn::param::ConvBias::Format::NCHW44; } } @@ -2898,7 +1975,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { //! if src is nchwxx, should RelayoutPlaceholder to nchw if (temp_inp[0]->shape().ndim == 5) { auto new_src = RelayoutPlaceholder::make( - new_inp[0], RelayoutMode::NCHW4_TO_NCHW); + new_inp[0], ReformatKey{TensorFormats::NCHWc4, + TensorFormats::NCHW}); temp_inp[0] = new_src.node(); } auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp, @@ -2917,7 +1995,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { if (new_inp[0]->shape().ndim != 5) { mgb_assert(new_inp[0]->shape().ndim == 4); auto new_src = RelayoutPlaceholder::make( - new_inp[0], RelayoutMode::NCHW_TO_NCHW4); + new_inp[0], ReformatKey{TensorFormats::NCHW, + TensorFormats::NCHWc4}); conv_src = new_src.node(); } auto new_param = conv_opr.param(); @@ -2986,14 +2065,16 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { //! if src is nchwxx, should RelayoutPlaceholder to nchw if (temp_inp[0]->shape().ndim == 5) { auto new_src = RelayoutPlaceholder::make( - new_inp[0], RelayoutMode::NCHW4_TO_NCHW); + new_inp[0], + ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHW}); temp_inp[0] = new_src.node(); } //! the bias is nchwxx if (new_inp.size() > 2 && temp_inp[2]->shape().ndim == 5) { auto new_bias = RelayoutPlaceholder::make( - new_inp[2], RelayoutMode::NCHW4_TO_NCHW); + new_inp[2], ReformatKey{TensorFormats::NCHWc4, + TensorFormats::NCHW}); temp_inp[2] = new_bias.node(); } auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp, @@ -3013,14 +2094,16 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { if (new_inp[0]->shape().ndim != 5) { mgb_assert(new_inp[0]->shape().ndim == 4); auto new_src = RelayoutPlaceholder::make( - new_inp[0], RelayoutMode::NCHW_TO_NCHW4); + new_inp[0], ReformatKey{TensorFormats::NCHW, + TensorFormats::NCHWc4}); conv_bias_src = new_src.node(); } //! bias trans to nchwxx mode if (new_inp.size() > 2) { if (new_inp[2]->shape().ndim == 4) { auto new_bias = RelayoutPlaceholder::make( - new_inp[2], RelayoutMode::NCHW_TO_NCHW4); + new_inp[2], ReformatKey{TensorFormats::NCHW, + TensorFormats::NCHWc4}); conv_bias_bias = new_bias.node(); } else { mgb_assert(new_inp[2]->shape().ndim == 5); @@ -3059,7 +2142,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { if (new_inp.size() > 2) { if (new_inp[2]->shape().ndim == 4) { auto new_bias = RelayoutPlaceholder::make( - new_inp[2], RelayoutMode::NCHW_TO_NCHW4); + new_inp[2], ReformatKey{TensorFormats::NCHW, + TensorFormats::NCHWc4}); conv_bias_bias = new_bias.node(); } else { mgb_assert(new_inp[2]->shape().ndim == 5); @@ -3099,303 +2183,21 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { /* ==================== ShuffleShuffleRemovePass ================= */ class ShuffleShuffleRemovePass::Impl { - using TensorFormat = opr::ConvBias::Param::Format; + using Format = opr::ConvBias::Param::Format; OptState& m_opt_state; - ThinHashMap, - thin_function> - m_reformat; - - class AbstractShuffleOpr; + using AbstractShuffleOpr = TensorReformatPass::RelayoutPlaceholder; void detect_shuffle_operations(); void do_replace(); public: Impl(OptState& opt_state) : m_opt_state{opt_state} { - m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::NCHW4)] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp = opr::Concat::make( - {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0); - auto y0 = opr::Reshape::make(x, tshp); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); - return y1.node(); - }; - - m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::NCHW32)] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp = opr::Concat::make( - {sub(0), sub(1) / 32, cv(32), sub(2), sub(3)}, 0); - auto y0 = opr::Reshape::make(x, tshp); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); - return y1.node(); - }; - - m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::NCHW)] = - [](VarNode* inp) -> VarNode* { - mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4); - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp = - opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0); - auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); - auto y1 = opr::Reshape::make(y0, tshp); - return y1.node(); - }; - - m_reformat[std::make_pair(TensorFormat::NCHW32, TensorFormat::NCHW)] = - [](VarNode* inp) -> VarNode* { - mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32); - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp = - opr::Concat::make({sub(0), sub(1) * 32, sub(2), sub(3)}, 0); - auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); - auto y1 = opr::Reshape::make(y0, tshp); - return y1.node(); - }; - - m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::NCHW32)] = - [](VarNode* inp) -> VarNode* { - mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4); - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)}, - 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - - m_reformat[std::make_pair(TensorFormat::NCHW32, TensorFormat::NCHW4)] = - [](VarNode* inp) -> VarNode* { - mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32); - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8}, - 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - - m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::CHWN4)] = - [](VarNode* inp) -> VarNode* { - megdnn::param::RelayoutFormat param; - param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4; - auto reformat = opr::RelayoutFormat::make(inp, param); - return reformat.node(); - }; - - m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW4)] = - [](VarNode* inp) -> VarNode* { - megdnn::param::RelayoutFormat param; - param.mode = megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4; - auto reformat = opr::RelayoutFormat::make(inp, param); - return reformat.node(); - }; - - m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::CHWN4)] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp = opr::Concat::make( - {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0); - auto y0 = opr::Reshape::make(x, tshp); - auto y1 = opr::Dimshuffle::make(y0, {1, 3, 4, 0, 2}); - return y1.node(); - }; - - m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW)] = - [](VarNode* inp) -> VarNode* { - mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4); - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp = - opr::Concat::make({sub(3), sub(0) * 4, sub(1), sub(2)}, 0); - auto y0 = opr::Dimshuffle::make(x, {3, 0, 4, 1, 2}); - auto y1 = opr::Reshape::make(y0, tshp); - return y1.node(); - }; detect_shuffle_operations(); do_replace(); } }; -/*! - * \brief abstract operator representation of shuffle operation - */ -MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr, - cg::SingleCNOperatorNodeBase) // { -public: -AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format, - TensorFormat out_format); - -static SymbolVar make(VarNode* inpvar, TensorFormat inp_format, - TensorFormat out_format); - -TensorFormat inp_format() const { - return m_inp_format; -} - -TensorFormat out_format() const { - return m_out_format; -} - -private: -void init_output_static_infer_desc() override; -void scn_do_execute() override; -const TensorFormat m_inp_format; -const TensorFormat m_out_format; -} -; - -MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr); - -void ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::scn_do_execute() { - mgb_throw(InternalError, "AbstractShuffleOpr cannot be executed"); -} - -void ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr:: - init_output_static_infer_desc() { - using namespace cg::static_infer; - auto&& mgr = owner_graph()->static_infer_manager(); - DepVal deps; - for (auto i : input()) - deps.push_back({i, DepType::SHAPE}); - auto infer_shape = [this](TensorShape& dst, const InpVal& inp) { - TensorShape inp_shape = inp.val[0].shape(); - if (m_inp_format == TensorFormat::NCHW4 && - m_out_format == TensorFormat::NCHW32) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); - dst = inp_shape; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] / 8; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = inp_shape[4] * 8; - } else if (m_inp_format == TensorFormat::NCHW32 && - m_out_format == TensorFormat::NCHW4) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32); - dst = inp_shape; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] * 8; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = inp_shape[4] / 8; - } else if (m_inp_format == TensorFormat::NCHW && - m_out_format == TensorFormat::NCHW4) { - mgb_assert(inp_shape.ndim == 4); - dst.ndim = 5; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] / 4; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = 4; - } else if (m_inp_format == TensorFormat::NCHW4 && - m_out_format == TensorFormat::NCHW) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); - dst.ndim = 4; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] * 4; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - } else if (m_inp_format == TensorFormat::NCHW4 && - m_out_format == TensorFormat::CHWN4) { - dst.ndim = 5; - dst[0] = inp_shape[1]; - dst[1] = inp_shape[2]; - dst[2] = inp_shape[3]; - dst[3] = inp_shape[0]; - dst[4] = inp_shape[4]; - } else if (m_inp_format == TensorFormat::CHWN4 && - m_out_format == TensorFormat::NCHW4) { - dst.ndim = 5; - dst[0] = inp_shape[3]; - dst[1] = inp_shape[0]; - dst[2] = inp_shape[1]; - dst[3] = inp_shape[2]; - dst[4] = inp_shape[4]; - } else { - mgb_throw(InternalError, - "Unsupported input format and output format."); - } - return true; - }; - mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_shape}); -} - -ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::AbstractShuffleOpr( - VarNode* inpvar, TensorFormat inp_format, TensorFormat out_format) - : Super(inpvar->owner_graph(), {}, "AbstractShuffleOpr", {inpvar}), - m_inp_format{inp_format}, - m_out_format{out_format} { - add_input({inpvar}); - add_equivalence_component>(m_inp_format); - add_equivalence_component>(m_out_format); - add_output(None)->dtype(inpvar->dtype()); -} - -SymbolVar ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::make( - VarNode* inpvar, TensorFormat inp_format, TensorFormat out_format) { - return inpvar->owner_graph() - ->insert_opr(std::make_unique( - inpvar, inp_format, out_format)) - ->output(0); -} - void ShuffleShuffleRemovePass::Impl::detect_shuffle_operations() { auto rewriter = m_opt_state.graph().make_rewriter(); auto uniq_reader_check = UniqReaderCheck{m_opt_state.graph()}; @@ -3423,7 +2225,8 @@ void ShuffleShuffleRemovePass::Impl::detect_shuffle_operations() { return false; auto inp_var = rewriter.get_var(reshape->input(0)); auto abstract_shuffle = AbstractShuffleOpr::make( - inp_var, TensorFormat::NCHW, TensorFormat::NCHW4); + inp_var, + ReformatKey{TensorFormats::NCHW, TensorFormats::NCHWc4}); rewriter.replace_var( opr->output(0), abstract_shuffle.node(), mgb_cstr_log("replace reformat(nchw -> nchw4) to " @@ -3469,12 +2272,11 @@ void ShuffleShuffleRemovePass::Impl::detect_shuffle_operations() { if (reshape2 == nullptr) return false; auto inp_var = rewriter.get_var(reshape2->input(0)); - TensorFormat inp_format = is_nchw42nchw32 ? TensorFormat::NCHW4 - : TensorFormat::NCHW32, - out_format = is_nchw42nchw32 ? TensorFormat::NCHW32 - : TensorFormat::NCHW4; - auto abstract_shuffle = - AbstractShuffleOpr::make(inp_var, inp_format, out_format); + Format inp_format = is_nchw42nchw32 ? Format::NCHW4 : Format::NCHW32, + out_format = is_nchw42nchw32 ? Format::NCHW32 : Format::NCHW4; + auto abstract_shuffle = AbstractShuffleOpr::make( + inp_var, ReformatKey{opr_format_to_tensor_formats(inp_format), + opr_format_to_tensor_formats(out_format)}); std::string reformat_type = is_nchw42nchw32 ? "nchw4 -> nchw32" : "nchw32 -> nchw4"; rewriter.replace_var(opr->output(0), abstract_shuffle.node(), @@ -3507,11 +2309,22 @@ void ShuffleShuffleRemovePass::Impl::detect_shuffle_operations() { param.pattern[2] == 4 && param.pattern[3] == 2 && param.pattern[4] == 3 && shuffle->input(0)->shape()[4] == 4; - if (!is_nchw42nchw) + bool is_nchw42nhwc = param.pattern[0] == 0 && param.pattern[1] == 2 && + param.pattern[2] == 3 && param.pattern[3] == 1 && + param.pattern[4] == 4 && + shuffle->input(0)->shape()[4] == 4; + if (!is_nchw42nchw && !is_nchw42nhwc) return false; auto inp_var = rewriter.get_var(shuffle->input(0)); - auto abstract_shuffle = AbstractShuffleOpr::make( - inp_var, TensorFormat::NCHW4, TensorFormat::NCHW); + ReformatKey key; + key.input_format = TensorFormats::NCHWc4; + if (is_nchw42nchw) { + key.output_format = TensorFormats::NCHW; + } else { + mgb_assert(is_nchw42nhwc); + key.output_format = TensorFormats::NHWC; + } + auto abstract_shuffle = AbstractShuffleOpr::make(inp_var, key); rewriter.replace_var( opr->output(0), abstract_shuffle.node(), mgb_cstr_log("replace reformat(nchw4 -> nchw) to " @@ -3532,10 +2345,12 @@ void ShuffleShuffleRemovePass::Impl::detect_shuffle_operations() { cg::SymbolVar abstract_shuffle; if (param.mode == opr::RelayoutFormat::Param::Mode::NCHW4_CHWN4) { abstract_shuffle = AbstractShuffleOpr::make( - inp_var, TensorFormat::NCHW4, TensorFormat::CHWN4); + inp_var, + ReformatKey{TensorFormats::NCHWc4, TensorFormats::CHWNc4}); } else { abstract_shuffle = AbstractShuffleOpr::make( - inp_var, TensorFormat::CHWN4, TensorFormat::NCHW4); + inp_var, + ReformatKey{TensorFormats::CHWNc4, TensorFormats::NCHWc4}); } rewriter.replace_var( opr->output(0), abstract_shuffle.node(), @@ -3591,7 +2406,7 @@ void ShuffleShuffleRemovePass::Impl::do_replace() { } } - auto on_opr = [this, &rewriter, &uniq_reader_check, &trt_opr_inps, + auto on_opr = [&rewriter, &uniq_reader_check, &trt_opr_inps, &root](OperatorNodeBase* opr) { MGB_MARK_USED_VAR(trt_opr_inps); bool cond_opr = opr->same_type() || @@ -3606,17 +2421,17 @@ void ShuffleShuffleRemovePass::Impl::do_replace() { bool force_folding_typecvt = false; bool first_shuffle = false; // initialize inp_format and out_format - TensorFormat out_format = TensorFormat::NCHW, - inp_format = out_format; + TensorFormats out_format = TensorFormats::NCHW, + inp_format = out_format; megdnn::DType inp_dtype = cur->input(0)->dtype(), out_dtype = cur->output(0)->dtype(); SmallVector out_dtype_vec; while (cond_opr) { if (cur->same_type()) { auto shuffle = try_cast_as_op(cur); - inp_format = shuffle->inp_format(); + inp_format = shuffle->key().input_format; if (!first_shuffle) { - out_format = shuffle->out_format(); + out_format = shuffle->key().output_format; first_shuffle = true; } } else { @@ -3639,11 +2454,8 @@ void ShuffleShuffleRemovePass::Impl::do_replace() { #endif auto new_var = rewriter.get_var(inp_var); if (inp_format != out_format) { - mgb_assert(m_reformat.find(std::make_pair( - inp_format, out_format)) != m_reformat.end(), - "Unsupported shuffle shuffle remove pass"); - new_var = m_reformat[std::make_pair(inp_format, out_format)]( - new_var); + new_var = ReformatManager::instance().get( + ReformatKey{inp_format, out_format})({new_var}); } if (force_folding_typecvt) { inp_dtype = inp_var->dtype(); @@ -3683,992 +2495,111 @@ void ShuffleShuffleRemovePass::apply(OptState& opt) const { MIDOUT_E } -#if CUDA_VERSION >= 10020 -/* ==================== FoldingConvBiasDimshufflePass ================= */ -const char* FoldingConvBiasDimshufflePass::name() const { - return mgb_cstr_log("folding conv bias dimshuffle pass"); +/* ================ EnableNCHW64Pass =============== */ +VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var, + VarNode* orig_var) const { + if (!orig_var->shape().eq_shape(new_var->shape())) { + auto iter = m_opr_format_map.find(new_var->owner_opr()); + mgb_assert(iter != m_opr_format_map.end(), + "cannot find opr(type:%s,name:%s) information, related " + "output var node(name:%s)", + new_var->owner_opr()->dyn_typeinfo()->name, + new_var->owner_opr()->cname(), new_var->cname()); + const auto& fmt = iter->second; + ReformatKey key; + MGB_TRY { + key.input_format = opr_format_to_tensor_formats(fmt); + key.output_format = TensorFormats::NCHW; + key.input_dtype = new_var->dtype().enumv(); + key.output_dtype = new_var->dtype().enumv(); + } + MGB_CATCH(AssertionError & err, { + mgb_log_error("%s, related var node(name:%s)", err.what(), + orig_var->cname()); + throw; + }) + return RelayoutPlaceholder::make(new_var, key).node(); + } + return new_var; } -void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { - MIDOUT_B("FoldingConvBiasDimshufflePass::apply"); - using DepType = cg::OperatorNodeProp::DepType; - ThinHashMap>> - readers; - static const ThinHashSet opr_type_list = { - opr::TypeCvt::typeinfo(), opr::Dimshuffle::typeinfo(), - opr::Reshape::typeinfo(), opr::ConvBias::typeinfo()}; - opt.graph().iter([&readers](OperatorNodeBase* opr) { - for (auto&& i : opr->node_prop().dep_map()) { - if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) { - readers[i.first->owner_opr()].emplace_back(opr, i.second); - } +std::unique_ptr +EnableNCHW64Pass::make_nchw64_converter() { + MIDOUT_B("EnableNCHW64Pass::make") + auto ret = std::make_unique(); + ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^ + VarReplaceCheckFlag::CHECK_SHAPE); + auto& replace_func = ret->m_opr_replace_func; + auto& format_map = ret->m_opr_format_map; + auto make_new_conv = [](const VarNodeArray& inps, + const opr::ConvBiasForward* orig_conv, + Format format) { + auto param = orig_conv->param(); + // change format + param.format = format; + if (inps.size() == 2) { + auto new_conv = opr::ConvBiasForward::make( + inps[0], inps[1], param, orig_conv->execution_policy(), + orig_conv->config()); + return new_conv.node(); + } else if (inps.size() == 3) { + auto new_conv = opr::ConvBiasForward::make( + inps[0], inps[1], inps[2], param, + orig_conv->execution_policy(), orig_conv->config()); + return new_conv.node(); + } else { + mgb_assert(inps.size() == 4); + auto new_conv = opr::ConvBiasForward::make( + inps[0], inps[1], inps[2], inps[3], param, + orig_conv->execution_policy(), orig_conv->config()); + return new_conv.node(); } - }); - - auto rewriter = opt.graph().make_rewriter(); - auto nchw42nchw = [](VarNode* inp) -> VarNode* { - mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4); - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0); - auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); - auto y1 = opr::Reshape::make(y0, tshp); - auto y2 = opr::TypeCvt::make(y1, dtype::Float32()); - return y2.node(); }; - - auto nchw42nchw32 = [](VarNode* inp) -> VarNode* { - mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4); - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + auto try_transform_to_nchw = + [&format_map](OperatorNodeBase* opr, + const VarNodeArray& new_inp) -> VarNode* { + mgb_assert(opr->input().size() == new_inp.size()); + bool check_dtype = new_inp[0]->dtype().enumv() == DTypeEnum::Float32 && + new_inp[1]->dtype().enumv() == DTypeEnum::Float32; + if (opr->input().size() >= 3) + check_dtype &= new_inp[2]->dtype().enumv() == DTypeEnum::Float32; + if (opr->input().size() >= 4) + check_dtype &= new_inp[3]->dtype().enumv() == DTypeEnum::Float32; + if (!check_dtype) + return nullptr; + auto inps = new_inp; + auto process = [&](size_t i) -> VarNode* { + auto iter = format_map.find(new_inp[i]->owner_opr()); + if (iter == format_map.end()) { + return inps[i]; + } else { + const auto& fmt = iter->second; + ReformatKey key; + key.input_format = opr_format_to_tensor_formats(fmt); + key.output_format = TensorFormats::NCHW; + return RelayoutPlaceholder::make(inps[i], key).node(); + } }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)}, 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - - auto nchw322nchw4 = [](VarNode* inp) -> VarNode* { - mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32); - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8}, 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - - auto nchw42nhwc = [](VarNode* inp) -> VarNode* { - mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4); - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0); - auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4}); - auto y1 = opr::Reshape::make(y0, tshp); - return y1.node(); - }; - - auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers, - &nchw42nchw]( - OperatorNodeBase* opr) { - ThinHashSet opr_set; - ThinHashSet reader_set; - // check typecvt - auto typecvt = try_cast_as_op(opr); - if (typecvt == nullptr) - return false; - auto inp_dtype = typecvt->input(0)->dtype(), - out_dtype = typecvt->output(0)->dtype(); - bool is_s82f32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && - out_dtype.enumv() == DTypeEnum::Float32; - if (!is_s82f32) - return false; - opr_set.insert(opr); - - // check reshape - auto reshape = - try_cast_as_op(typecvt->input(0)->owner_opr()); - if (reshape == nullptr) - return false; - opr_set.insert(reshape); - for (auto&& i : readers[reshape]) { - if (i.second & DepType::DEV_VALUE) { - reader_set.insert(i.first); - } - } - - // check shuffle - auto shuffle = - try_cast_as_op(reshape->input(0)->owner_opr()); - if (shuffle == nullptr) - return false; - auto&& param = shuffle->param(); - if (param.pattern_len != 5) - return false; - bool is_nchw42nchw = param.pattern[0] == 0 && param.pattern[1] == 1 && - param.pattern[2] == 4 && param.pattern[3] == 2 && - param.pattern[4] == 3 && - shuffle->input(0)->shape()[4] == 4; - if (!is_nchw42nchw) - return false; - opr_set.insert(shuffle); - for (auto&& i : readers[shuffle]) { - if (i.second & DepType::DEV_VALUE) { - reader_set.insert(i.first); - } - } - - // check conv bias - auto conv_bias = - try_cast_as_op(shuffle->input(0)->owner_opr()); - if (conv_bias == nullptr) - return false; - inp_dtype = conv_bias->input(0)->dtype(); - bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && - conv_bias->param().format == - megdnn::param::ConvBias::Format::NCHW4; - if (!is_s8nchw4) - return false; - if (conv_bias->input().size() != 3) - return false; - opr_set.insert(conv_bias); - for (auto&& i : readers[conv_bias]) { - if (i.second & DepType::DEV_VALUE) { - reader_set.insert(i.first); - } - } - for (auto reader : reader_set) { - if (opr_set.count(reader) <= 0) { - return false; - } - } - auto src = rewriter.get_var(conv_bias->input(0)), - filter = rewriter.get_var(conv_bias->input(1)), - bias = rewriter.get_var(conv_bias->input(2)); - auto new_bias = nchw42nchw(bias); - auto new_param = conv_bias->param(); - new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW; - auto conv_bias_shuffle = opr::ConvBias::make( - src, filter, new_bias, new_param, conv_bias->execution_policy(), - OperatorNodeConfig{dtype::Float32()}); - rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(), - mgb_cstr_log("replace conv_bias + typecvt + " - "dimshuffle + " - "reshape to conv_bias(NCHW4_NCHW)")); - return true; - }; - - auto try_conv_reformat_nchw42nchw32 = [&rewriter, &nchw42nchw32, - &readers](OperatorNodeBase* opr) { - ThinHashSet opr_set; - ThinHashSet reader_set; - // check reshape - auto reshape1 = try_cast_as_op(opr); - if (reshape1 == nullptr) - return false; - opr_set.insert(opr); - // check dimshuffle - auto shuffle = try_cast_as_op( - reshape1->input(0)->owner_opr()); - if (shuffle == nullptr) - return false; - auto&& param = shuffle->param(); - if (param.pattern_len != 6) - return false; - bool is_nchw42nchw32 = param.pattern[0] == 0 && param.pattern[1] == 1 && - param.pattern[2] == 3 && param.pattern[3] == 4 && - param.pattern[4] == 2 && param.pattern[5] == 5 && - shuffle->output(0)->shape()[5] == 4 && - shuffle->output(0)->shape()[4] == 8; - if (!is_nchw42nchw32) - return false; - opr_set.insert(shuffle); - for (auto&& i : readers[shuffle]) { - if (i.second & DepType::DEV_VALUE) { - reader_set.insert(i.first); - } - } - // check reshape - auto reshape2 = - try_cast_as_op(shuffle->input(0)->owner_opr()); - if (reshape2 == nullptr) - return false; - opr_set.insert(reshape2); - for (auto&& i : readers[reshape2]) { - if (i.second & DepType::DEV_VALUE) { - reader_set.insert(i.first); - } - } - // check conv bias - auto conv_bias = - try_cast_as_op(reshape2->input(0)->owner_opr()); - if (conv_bias == nullptr) - return false; - auto inp_dtype = conv_bias->input(0)->dtype(); - bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && - conv_bias->param().format == - megdnn::param::ConvBias::Format::NCHW4; - if (!is_s8nchw4) - return false; - if (conv_bias->input().size() != 3) - return false; - opr_set.insert(conv_bias); - for (auto&& i : readers[conv_bias]) { - if (i.second & DepType::DEV_VALUE) { - reader_set.insert(i.first); - } - } - for (auto reader : reader_set) { - if (opr_set.count(reader) <= 0) { - return false; - } - } - auto src = rewriter.get_var(conv_bias->input(0)), - filter = rewriter.get_var(conv_bias->input(1)), - bias = rewriter.get_var(conv_bias->input(2)); - auto new_bias = nchw42nchw32(bias); - auto new_param = conv_bias->param(); - new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW32; - auto conv_bias_shuffle = opr::ConvBias::make( - src, filter, new_bias, new_param, conv_bias->execution_policy(), - conv_bias->config()); - rewriter.replace_var( - opr->output(0), conv_bias_shuffle.node(), - mgb_cstr_log("replace conv_bias + " - "reformat to conv_bias(NCHW4_NCHW32)")); - return true; - }; - - auto try_conv_reformat_nchw42nhwc = [&rewriter, &nchw42nhwc, - &readers](OperatorNodeBase* opr) { - ThinHashSet opr_set; - ThinHashSet reader_set; - // check reshape - auto reshape = try_cast_as_op(opr); - if (reshape == nullptr) - return false; - opr_set.insert(opr); - - // check dimshuffle - auto shuffle = - try_cast_as_op(reshape->input(0)->owner_opr()); - if (shuffle == nullptr) - return false; - auto&& param = shuffle->param(); - if (param.pattern_len != 5) - return false; - bool is_nchw42nhwc = param.pattern[0] == 0 && param.pattern[1] == 2 && - param.pattern[2] == 3 && param.pattern[3] == 1 && - param.pattern[4] == 4 && - shuffle->output(0)->shape()[4] == 4; - if (!is_nchw42nhwc) - return false; - opr_set.insert(shuffle); - for (auto&& i : readers[shuffle]) { - if (i.second & DepType::DEV_VALUE) { - reader_set.insert(i.first); - } - } - - auto typecvt = - try_cast_as_op(shuffle->input(0)->owner_opr()); - if (typecvt == nullptr) - return false; - auto in_dtype = typecvt->input(0)->dtype(), - out_dtype = typecvt->output(0)->dtype(); - bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 && - (out_dtype.enumv() == DTypeEnum::QuantizedS4 || - out_dtype.enumv() == DTypeEnum::Quantized4Asymm); - if (!is_s82s4) - return false; - opr_set.insert(typecvt); - for (auto&& i : readers[typecvt]) { - if (i.second & DepType::DEV_VALUE) { - reader_set.insert(i.first); - } - } - - // check conv bias - auto conv_bias = - try_cast_as_op(typecvt->input(0)->owner_opr()); - if (conv_bias == nullptr) - return false; - auto inp_dtype = conv_bias->input(0)->dtype(); - bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && - conv_bias->param().format == - megdnn::param::ConvBias::Format::NCHW4; - if (!is_s8nchw4) - return false; - if (conv_bias->input().size() != 3) - return false; - opr_set.insert(conv_bias); - for (auto&& i : readers[conv_bias]) { - if (i.second & DepType::DEV_VALUE) { - reader_set.insert(i.first); - } - } - for (auto reader : reader_set) { - if (opr_set.count(reader) <= 0) { - return false; - } - } - auto src = rewriter.get_var(conv_bias->input(0)), - filter = rewriter.get_var(conv_bias->input(1)), - bias = rewriter.get_var(conv_bias->input(2)); - auto new_bias = nchw42nhwc(bias); - auto new_param = conv_bias->param(); - new_param.format = megdnn::param::ConvBias::Format::NCHW4_NHWC; - auto conv_bias_shuffle = opr::ConvBias::make( - src, filter, new_bias, new_param, conv_bias->execution_policy(), - OperatorNodeConfig{out_dtype}); - rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(), - mgb_cstr_log("replace conv_bias + " - "reformat to conv_bias(NCHW4_NHWC)")); - return true; - }; - - auto try_conv_reformat_nchw322nchw4 = [&rewriter, &readers, &nchw322nchw4]( - OperatorNodeBase* opr) { - ThinHashSet opr_set; - ThinHashSet reader_set; - // check reshape - auto reshape1 = try_cast_as_op(opr); - if (reshape1 == nullptr) - return false; - opr_set.insert(opr); - // check dimshuffle - auto shuffle = try_cast_as_op( - reshape1->input(0)->owner_opr()); - if (shuffle == nullptr) - return false; - auto&& param = shuffle->param(); - if (param.pattern_len != 6) - return false; - bool is_nchw322nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 && - param.pattern[2] == 4 && param.pattern[3] == 2 && - param.pattern[4] == 3 && param.pattern[5] == 5 && - shuffle->input(0)->shape()[5] == 4 && - shuffle->input(0)->shape()[4] == 8; - if (!is_nchw322nchw4) - return false; - opr_set.insert(shuffle); - for (auto&& i : readers[shuffle]) { - if (i.second & DepType::DEV_VALUE) { - reader_set.insert(i.first); - } - } - // check reshape - auto reshape2 = - try_cast_as_op(shuffle->input(0)->owner_opr()); - if (reshape2 == nullptr) - return false; - opr_set.insert(reshape2); - for (auto&& i : readers[reshape2]) { - if (i.second & DepType::DEV_VALUE) { - reader_set.insert(i.first); - } - } - // check conv bias - auto conv_bias = - try_cast_as_op(reshape2->input(0)->owner_opr()); - if (conv_bias == nullptr) - return false; - auto inp_dtype = conv_bias->input(0)->dtype(); - bool is_s8nchw32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && - conv_bias->param().format == - megdnn::param::ConvBias::Format::NCHW32; - if (!is_s8nchw32) - return false; - if (conv_bias->input().size() != 3) - return false; - opr_set.insert(conv_bias); - for (auto&& i : readers[conv_bias]) { - if (i.second & DepType::DEV_VALUE) { - reader_set.insert(i.first); - } - } - for (auto reader : reader_set) { - if (opr_set.count(reader) <= 0) { - return false; - } - } - auto src = rewriter.get_var(conv_bias->input(0)), - filter = rewriter.get_var(conv_bias->input(1)), - bias = rewriter.get_var(conv_bias->input(2)); - auto new_bias = nchw322nchw4(bias); - auto new_param = conv_bias->param(); - new_param.format = megdnn::param::ConvBias::Format::NCHW32_NCHW4; - auto conv_bias_shuffle = opr::ConvBias::make( - src, filter, new_bias, new_param, conv_bias->execution_policy(), - conv_bias->config()); - rewriter.replace_var( - opr->output(0), conv_bias_shuffle.node(), - mgb_cstr_log("replace conv_bias + " - "reformat to conv_bias(NCHW32_NCHW4)")); - return true; - }; - MGB_MARK_USED_VAR(try_conv_reformat_nchw322nchw4); - MGB_MARK_USED_VAR(try_conv_reformat_nchw42nchw32); - - auto on_opr = [&try_conv_dimshuffle_reshape_typecvt, - &try_conv_reformat_nchw42nchw32, - &try_conv_reformat_nchw42nhwc, - &try_conv_reformat_nchw322nchw4, - &rewriter](OperatorNodeBase* opr) { - if (!try_conv_dimshuffle_reshape_typecvt(opr) && - !try_conv_reformat_nchw42nchw32(opr) && - !try_conv_reformat_nchw42nhwc(opr) && - !try_conv_reformat_nchw322nchw4(opr)) { - rewriter.auto_replace_outputs(opr); - } - }; - opt.graph().iter(on_opr); - rewriter.apply_inplace(); - - MIDOUT_E -} -#endif - -/* ==================== PaddingChannelPass ================= */ -const char* PaddingChannelPass::name() const { - return mgb_cstr_log("padding output channel to multiple of 4/32"); -} - -void PaddingChannelPass::apply(OptState& opt) const { - MIDOUT_B("PaddingChannelPass::apply"); - // do not check shape - opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^ - VarReplaceCheckFlag::CHECK_SHAPE); - - ThinHashSet padding_oprs; - ThinHashMap> - opr_replace_funcs; - - auto rewriter = opt.graph().make_rewriter(); - auto pad_in_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* { - mgb_assert(inp->shape().ndim == 4); - mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 || - inp->dtype().enumv() == DTypeEnum::Quantized4Asymm || - inp->dtype().enumv() == DTypeEnum::QuantizedS8 || - inp->dtype().enumv() == DTypeEnum::QuantizedS32); - TensorShape shape{inp->shape()[0], pad_channels, inp->shape()[2], - inp->shape()[3]}; - std::shared_ptr host_val = - std::make_shared(inp->comp_node(), inp->dtype()); - host_val->resize(shape); - auto ptr = host_val->raw_ptr(); - size_t size_bytes = - TensorLayout{shape, inp->dtype()}.span().dist_byte(); - std::memset(ptr, 0, size_bytes); - auto padding = - opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); - auto out = opr::Concat::make({inp, padding}, 1); - return out.node(); - }; - - auto pad_out_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* { - mgb_assert(inp->shape().ndim == 4); - mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 || - inp->dtype().enumv() == DTypeEnum::Quantized4Asymm || - inp->dtype().enumv() == DTypeEnum::QuantizedS8 || - inp->dtype().enumv() == DTypeEnum::QuantizedS32); - TensorShape shape{pad_channels, inp->shape()[1], inp->shape()[2], - inp->shape()[3]}; - std::shared_ptr host_val = - std::make_shared(inp->comp_node(), inp->dtype()); - host_val->resize(shape); - auto ptr = host_val->raw_ptr(); - size_t size_bytes = - TensorLayout{shape, inp->dtype()}.span().dist_byte(); - std::memset(ptr, 0, size_bytes); - auto padding = - opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); - auto out = opr::Concat::make({inp, padding}, 0); - return out.node(); - }; - - auto extract_subtensor = [](VarNode* inp, - const TensorShape& orig_shape) -> VarNode* { - mgb_assert(inp->shape().ndim == 4); - mgb_assert(inp->shape()[0] == orig_shape[0]); - mgb_assert(inp->shape()[2] == orig_shape[2]); - mgb_assert(inp->shape()[3] == orig_shape[3]); - size_t orig_channels = orig_shape[1]; - auto x = SymbolVar(inp); - auto cv = [&x](int v) { return x.make_scalar(v); }; - using AIdx = opr::Subtensor::AxisIndexer; - auto sub = opr::Subtensor::make( - x, {AIdx::make_interval(0, None, None, cv(1)), - AIdx::make_interval(1, None, cv(orig_channels), None), - AIdx::make_interval(2, None, None, cv(1)), - AIdx::make_interval(3, None, None, cv(1))}); - return sub.node(); - }; - - // padding policy for conv bias with data type qint8 - auto padding_policy_qint8 = [&padding_oprs, &pad_in_channels, - &pad_out_channels]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { - mgb_assert(opr->input().size() == new_inp.size()); - mgb_assert(new_inp.size() == 3); - mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); - auto inps = new_inp; - size_t out_channels = opr->input(1)->shape()[0]; - size_t in_channels = opr->input(1)->shape()[1]; - size_t new_in_channels = new_inp[0]->shape()[1]; - // pad input channels - if (padding_oprs.count(opr->input(0)->owner_opr())) { - size_t pad_channels = new_in_channels - in_channels; - inps[1] = pad_in_channels(new_inp[1], pad_channels); - } else { - size_t pad_channels = 0; - mgb_assert(new_in_channels == in_channels); - if (in_channels <= 16) { - if (in_channels % 4) - pad_channels = 4 - (in_channels % 4); // pad to use dp4a - } else { - if (in_channels % 32) - pad_channels = - 32 - (in_channels % 32); // pad to use tensorcore - } - if (pad_channels > 0) { - inps[0] = pad_in_channels(new_inp[0], pad_channels); - inps[1] = pad_in_channels(new_inp[1], pad_channels); - } - } - out_channels = inps[1]->shape()[0]; - in_channels = inps[1]->shape()[1]; - size_t pad_channels = 0; - if (out_channels <= 16) { - if (out_channels % 4) - pad_channels = 4 - (out_channels % 4); - } else { - if (out_channels % 32) - pad_channels = 32 - (out_channels % 32); - } - if (pad_channels > 0) { - inps[1] = pad_out_channels(inps[1], pad_channels); - inps[2] = pad_in_channels(inps[2], pad_channels); - padding_oprs.insert(opr); - } - return serialization::copy_opr_shallow(*opr, inps, opr->config()); - }; - - // padding policy for conv bias with data type qint4 and quint4 - auto padding_policy_int4 = [&padding_oprs, &pad_in_channels, - &pad_out_channels]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { - mgb_assert(opr->input().size() == new_inp.size()); - mgb_assert(new_inp.size() == 3); - mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); - auto inps = new_inp; - size_t out_channels = opr->input(1)->shape()[0]; - size_t in_channels = opr->input(1)->shape()[1]; - size_t new_in_channels = new_inp[0]->shape()[1]; - // pad input channels - if (padding_oprs.count(opr->input(0)->owner_opr())) { - if (new_in_channels <= 32) { - if (new_in_channels % 8 == 0) { - size_t pad_channels = new_in_channels - in_channels; - inps[1] = pad_in_channels(new_inp[1], pad_channels); - } else { - size_t pad_channels_0 = 8 - (new_in_channels % 8); - size_t pad_channels_1 = 8 - (in_channels % 8); - inps[0] = pad_in_channels(new_inp[0], pad_channels_0); - inps[1] = pad_in_channels(new_inp[1], pad_channels_1); - } - } else { - if (new_in_channels % 64 == 0) { - size_t pad_channels = new_in_channels - in_channels; - inps[1] = pad_in_channels(new_inp[1], pad_channels); - } else { - size_t pad_channels_0 = 64 - (new_in_channels % 64); - size_t pad_channels_1 = 64 - (in_channels % 64); - inps[0] = pad_in_channels(new_inp[0], pad_channels_0); - inps[1] = pad_in_channels(new_inp[1], pad_channels_1); - } - } - } else { - size_t pad_channels = 0; - mgb_assert(new_in_channels == in_channels); - if (in_channels <= 32) { - if (in_channels % 8) - pad_channels = 8 - (in_channels % 8); - } else { - if (in_channels % 64) - pad_channels = 64 - (in_channels % 64); - } - if (pad_channels > 0) { - inps[0] = pad_in_channels(new_inp[0], pad_channels); - inps[1] = pad_in_channels(new_inp[1], pad_channels); - } - } - out_channels = inps[1]->shape()[0]; - in_channels = inps[1]->shape()[1]; - size_t pad_channels = 0; - if (out_channels <= 32) { - if (out_channels % 8) - pad_channels = 8 - (out_channels % 8); - } else { - if (out_channels % 64) - pad_channels = 64 - (out_channels % 64); - } - if (pad_channels > 0) { - inps[1] = pad_out_channels(inps[1], pad_channels); - inps[2] = pad_in_channels(inps[2], pad_channels); - padding_oprs.insert(opr); - } - return serialization::copy_opr_shallow(*opr, inps, opr->config()); - }; - - opr_replace_funcs[opr::ConvBiasForward::typeinfo()] = - [&padding_oprs, &padding_policy_qint8, &padding_policy_int4]( - OperatorNodeBase* opr, const VarNodeArray& new_inp) { - if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) { - return padding_policy_qint8(opr, new_inp); - } else if (opr->input(0)->dtype().enumv() == - DTypeEnum::QuantizedS4 || - opr->input(0)->dtype().enumv() == - DTypeEnum::Quantized4Asymm) { - return padding_policy_int4(opr, new_inp); - } else { - mgb_assert( - padding_oprs.count(opr->input(0)->owner_opr()) == 0, - "conv bias operator for data type(%s) cannot be " - "padded channel. " - "consumer(%s), producer(%s)", - opr->input(0)->dtype().name(), opr->cname(), - opr->input(0)->owner_opr()->cname()); - return serialization::copy_opr_shallow(*opr, new_inp, - opr->config()); - } - }; - opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] = - [&padding_oprs, &pad_in_channels, &pad_out_channels]( - OperatorNodeBase* opr, const VarNodeArray& new_inp) { - if (opr->input(1)->dtype().enumv() != DTypeEnum::QuantizedS8) { - mgb_assert( - padding_oprs.count(opr->input(0)->owner_opr()) == 0, - "conv bwd data operator for data type(%s) cannot " - "be " - "padded channel. " - "consumer(%s), producer(%s)", - opr->input(0)->dtype().name(), opr->cname(), - opr->input(0)->owner_opr()->cname()); - return serialization::copy_opr_shallow(*opr, new_inp, - opr->config()); - } - mgb_assert(opr->input().size() == new_inp.size()); - mgb_assert(new_inp.size() == 2, - "deconv (conv bwd data) operator for inference can " - "only have 2 input vars(got:%zu)", - new_inp.size()); - mgb_assert( - opr->input(0)->shape().eq_shape(new_inp[0]->shape())); - auto inps = new_inp; - size_t out_channels = opr->input(0)->shape()[0]; - size_t in_channels = opr->input(0)->shape()[1]; - size_t new_out_channels = new_inp[1]->shape()[1]; - // pad output channels - if (padding_oprs.count(opr->input(1)->owner_opr())) { - size_t pad_channels = new_out_channels - out_channels; - inps[0] = pad_out_channels(new_inp[0], pad_channels); - } else { - size_t pad_channels = 0; - if (out_channels % 4) - pad_channels = 4 - (out_channels % 4); - if (pad_channels > 0) { - inps[0] = pad_out_channels(new_inp[0], pad_channels); - inps[1] = pad_in_channels(new_inp[1], pad_channels); - } - } - out_channels = inps[0]->shape()[0]; - in_channels = inps[0]->shape()[1]; - // pad input channels - size_t pad_channels = 0; - if (in_channels % 4) - pad_channels = 4 - (in_channels % 4); - if (pad_channels > 0) { - inps[0] = pad_in_channels(inps[0], pad_channels); - padding_oprs.insert(opr); - } - return serialization::copy_opr_shallow(*opr, inps, - opr->config()); - }; - auto replace_format_aware_opr = [&padding_oprs]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { - if (opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 && - opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS4 && - opr->input(0)->dtype().enumv() != DTypeEnum::Quantized4Asymm) { - mgb_assert(padding_oprs.count(opr->input(0)->owner_opr()) == 0, - "operator(type:%s,name:%s) for data type(%s) cannot be " - "padded channel. extra info:" - "consumer(%s), producer(%s)", - opr->dyn_typeinfo()->name, opr->cname(), - opr->input(0)->dtype().name(), opr->cname(), - opr->input(0)->owner_opr()->cname()); - return serialization::copy_opr_shallow(*opr, new_inp, - opr->config()); - } - mgb_assert(opr->input().size() == new_inp.size()); - if (padding_oprs.count(opr->input(0)->owner_opr())) { - padding_oprs.insert(opr); - } - return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); - }; - opr_replace_funcs[opr::PoolingForward::typeinfo()] = - replace_format_aware_opr; - opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] = - replace_format_aware_opr; - - auto replace_elemwise_like_opr = [&padding_oprs, &extract_subtensor]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { - mgb_assert(opr->input().size() == new_inp.size()); - bool have_padding_inp = false; - bool padding_all_inps = true; - bool same_padding = true; - size_t channels_after_padding = 0; - size_t i = 0; - for (auto&& cur_inp : opr->input()) { - bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0; - if (padding_cur_inp) { - if (!have_padding_inp) - have_padding_inp = true; - if (channels_after_padding == 0) { - channels_after_padding = new_inp[i]->shape()[1]; - } else { - same_padding = - channels_after_padding == new_inp[i]->shape()[1]; - } - } - if (padding_all_inps && (!padding_cur_inp || !same_padding)) - padding_all_inps = false; - ++i; - } - if (have_padding_inp && !padding_all_inps) { - auto inps = new_inp; - for (size_t i = 0; i < new_inp.size(); ++i) { - auto cur_inp = opr->input(i); - bool padding_cur_inp = - padding_oprs.count(cur_inp->owner_opr()) > 0; - if (padding_cur_inp) { - inps[i] = extract_subtensor(inps[i], cur_inp->shape()); - } - } - return serialization::copy_opr_shallow(*opr, inps, opr->config()); - } - if (padding_all_inps) { - padding_oprs.insert(opr); - } - return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); - }; - opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] = - replace_elemwise_like_opr; - opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr; - opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr; - - auto replace_nonpadding_oprs = [&padding_oprs, &extract_subtensor]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { - mgb_assert(opr->input().size() == new_inp.size()); - auto inps = new_inp; - for (size_t i = 0; i < new_inp.size(); ++i) { - auto cur_inp = opr->input(i); - bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0; - if (padding_cur_inp) { - inps[i] = extract_subtensor(inps[i], cur_inp->shape()); - } - } - return serialization::copy_opr_shallow(*opr, inps, opr->config()); - }; - opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs; - opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs; - opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs; - opr_replace_funcs[opr::Reduce::typeinfo()] = replace_nonpadding_oprs; - opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_nonpadding_oprs; - - auto on_opr = [&opt, &rewriter, &opr_replace_funcs, - &extract_subtensor](OperatorNodeBase* opr) { - auto it = opr_replace_funcs.find(opr->dyn_typeinfo()); - if (it != opr_replace_funcs.end()) { - VarNodeArray new_inp; - new_inp.reserve(opr->input().size()); - for (auto&& inp : opr->input()) { - new_inp.push_back(rewriter.get_var(inp)); - } - auto new_opr = (it->second)(opr, new_inp); - auto &&out0 = opr->output(), &&out1 = new_opr->output(); - mgb_assert(out0.size() == out1.size(), - "bad opr replace: src=%s{%s} dst=%s{%s}, " - "src.size=%zu " - "dst.size=%zu", - opr->cname(), opr->dyn_typeinfo()->name, - new_opr->cname(), new_opr->dyn_typeinfo()->name, - out0.size(), out1.size()); - for (size_t i = 0; i < out0.size(); ++i) { - if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { - mgb_assert(!out1[i]->contain_flag( - VarNode::Flag::VOLATILE_CONTENT)); - auto src = out0[i]; - auto dst = out1[i]; - if (opt.graph().endpoint_contain(src) && - !src->shape().eq_shape(dst->shape())) { - dst = extract_subtensor(dst, src->shape()); - } - rewriter.replace_var(src, dst, nullptr); - } - } - } else { - rewriter.auto_replace_outputs(opr); - } - }; - opt.graph().iter(on_opr); - rewriter.apply_inplace(); - - MIDOUT_E -} - -/* ================ EnableNCHW64Pass =============== */ -VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var, - VarNode* orig_var) const { - if (!orig_var->shape().eq_shape(new_var->shape())) { - auto iter = m_opr_format_map.find(new_var->owner_opr()); - mgb_assert(iter != m_opr_format_map.end(), - "cannot find opr(type:%s,name:%s) information, related " - "output var node(name:%s)", - new_var->owner_opr()->dyn_typeinfo()->name, - new_var->owner_opr()->cname(), new_var->cname()); - const auto& fmt = iter->second; - using LayoutType = RelayoutPlaceholder::LayoutType; - LayoutType type; - switch (fmt) { - case Format::NCHW4: - type = LayoutType::NCHW4_TO_NCHW; - break; - case Format::NCHW32: - type = LayoutType::NCHW32_TO_NCHW; - break; - case Format::NCHW64: - type = LayoutType::NCHW64_TO_NCHW; - break; - case Format::NHWC: - type = LayoutType::NHWC_TO_NCHW; - break; - default: - mgb_throw(AssertionError, - "format(%d) is not supported, related var " - "node(name:%s)", - static_cast(fmt), orig_var->cname()); - }; - return RelayoutPlaceholder::make(new_var, type).node(); - } - return new_var; -} - -std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { - MIDOUT_B("EnableNCHW64Pass::make") - auto ret = std::make_unique(); - ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^ - VarReplaceCheckFlag::CHECK_SHAPE); - auto& replace_func = ret->m_opr_replace_func; - auto& format_map = ret->m_opr_format_map; - auto make_new_conv = [](const VarNodeArray& inps, - const opr::ConvBiasForward* orig_conv, - Format format) { - auto param = orig_conv->param(); - // change format - param.format = format; - if (inps.size() == 2) { - auto new_conv = opr::ConvBiasForward::make( - inps[0], inps[1], param, orig_conv->execution_policy(), - orig_conv->config()); - return new_conv.node(); - } else if (inps.size() == 3) { - auto new_conv = opr::ConvBiasForward::make( - inps[0], inps[1], inps[2], param, - orig_conv->execution_policy(), orig_conv->config()); - return new_conv.node(); - } else { - mgb_assert(inps.size() == 4); - auto new_conv = opr::ConvBiasForward::make( - inps[0], inps[1], inps[2], inps[3], param, - orig_conv->execution_policy(), orig_conv->config()); - return new_conv.node(); - } - }; - auto try_transform_to_nchw = - [&format_map](OperatorNodeBase* opr, - const VarNodeArray& new_inp) -> VarNode* { - mgb_assert(opr->input().size() == new_inp.size()); - bool check_dtype = new_inp[0]->dtype().enumv() == DTypeEnum::Float32 && - new_inp[1]->dtype().enumv() == DTypeEnum::Float32; - if (opr->input().size() >= 3) - check_dtype &= new_inp[2]->dtype().enumv() == DTypeEnum::Float32; - if (opr->input().size() >= 4) - check_dtype &= new_inp[3]->dtype().enumv() == DTypeEnum::Float32; - if (!check_dtype) - return nullptr; - auto inps = new_inp; - auto process = [&](size_t i) -> VarNode* { - auto iter = format_map.find(new_inp[i]->owner_opr()); - if (iter == format_map.end()) { - return inps[i]; - } else { - const auto& fmt = iter->second; - if (fmt == Format::NCHW32) { - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW); - return ovar.node(); - } else if (fmt == Format::NCHW4) { - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW); - return ovar.node(); - } else if (fmt == Format::NHWC) { - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW); - return ovar.node(); - } else { - mgb_assert(fmt == Format::NCHW64); - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW); - return ovar.node(); - } - } - }; - for (size_t i = 0; i < inps.size(); ++i) { - inps[i] = process(i); - } - auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config()); - return ret->output()[0]; + for (size_t i = 0; i < inps.size(); ++i) { + inps[i] = process(i); + } + auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config()); + return ret->output()[0]; }; auto try_transform_to_nchw4 = [make_new_conv, &format_map]( OperatorNodeBase* opr, const VarNodeArray& new_inp) -> VarNode* { - mgb_assert(opr->input().size() == new_inp.size()); + mgb_assert(opr->input().size()==new_inp.size()); bool check_dtype = new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 && new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8; mgb_assert(opr->output().size() > 0); bool dst_float = opr->output(0)->dtype().enumv() == DTypeEnum::Float32; if (opr->input().size() >= 3) { - auto dtype_expect = - dst_float ? DTypeEnum::Float32 : DTypeEnum::QuantizedS32; + auto dtype_expect = dst_float ? DTypeEnum::Float32 + : DTypeEnum::QuantizedS32; check_dtype &= new_inp[2]->dtype().enumv() == dtype_expect; } if (opr->input().size() >= 4) { @@ -4689,29 +2620,18 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { auto iter = format_map.find(new_inp[i]->owner_opr()); if (iter == format_map.end()) { auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4); + inps[i], ReformatKey{TensorFormats::NCHW, + TensorFormats::NCHWc4}); return ovar.node(); } else { const auto& fmt = iter->second; if (fmt == Format::NCHW4) { return inps[i]; - } else if (fmt == Format::NCHW32) { - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4); - return ovar.node(); - } else if (fmt == Format::NHWC) { - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW4); - return ovar.node(); } else { - mgb_assert(fmt == Format::NCHW64); - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW4); - return ovar.node(); + ReformatKey key; + key.input_format = opr_format_to_tensor_formats(fmt); + key.output_format = TensorFormats::NCHWc4; + return RelayoutPlaceholder::make(inps[i], key).node(); } } }; @@ -4719,13 +2639,12 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { for (size_t i = 0; i < inps.size(); ++i) { // do not format bias and z when dst_float is true bool skip = dst_float && i >= 2; - if (!skip) - inps[i] = process(i); + if (!skip) inps[i] = process(i); } auto& conv_bias = opr->cast_final_safe(); - auto ret = - make_new_conv(inps, &conv_bias, - dst_float ? Format::NCHW4_NCHW : Format::NCHW4); + auto ret = make_new_conv( + inps, &conv_bias, + dst_float ? Format::NCHW4_NCHW : Format::NCHW4); if (!dst_float) format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4)); return ret; @@ -4735,7 +2654,7 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { [make_new_conv, &format_map]( OperatorNodeBase* opr, const VarNodeArray& new_inp) -> VarNode* { - mgb_assert(opr->input().size() == new_inp.size()); + mgb_assert(opr->input().size()==new_inp.size()); bool check_dtype = new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 && new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8; @@ -4755,31 +2674,18 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { auto inps = new_inp; auto process = [&](size_t i) -> VarNode* { auto iter = format_map.find(new_inp[i]->owner_opr()); + ReformatKey key; + key.output_format = TensorFormats::NCHWc32; if (iter == format_map.end()) { - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW32); - return ovar.node(); + key.input_format = TensorFormats::NCHW; + return RelayoutPlaceholder::make(inps[i], key).node(); } else { const auto& fmt = iter->second; if (fmt == Format::NCHW32) { return inps[i]; - } else if (fmt == Format::NCHW4) { - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32); - return ovar.node(); - } else if (fmt == Format::NHWC) { - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW32); - return ovar.node(); } else { - mgb_assert(fmt == Format::NCHW64); - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW32); - return ovar.node(); + key.input_format = opr_format_to_tensor_formats(fmt); + return RelayoutPlaceholder::make(inps[i], key).node(); } } }; @@ -4797,17 +2703,18 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { OperatorNodeBase* opr, const VarNodeArray& new_inp) -> VarNode* { // fint4XWint4 and fuint4XWint4 - mgb_assert(opr->input().size() == new_inp.size()); + mgb_assert(opr->input().size()==new_inp.size()); bool check_dtype = (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || - new_inp[0]->dtype().enumv() == DTypeEnum::Quantized4Asymm) && + new_inp[0]->dtype().enumv() == + DTypeEnum::Quantized4Asymm) && new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4; if (opr->input().size() >= 3) check_dtype &= new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32; if (opr->input().size() >= 4) - check_dtype &= - new_inp[3]->dtype().enumv() == new_inp[0]->dtype().enumv(); + check_dtype &= new_inp[3]->dtype().enumv() == + new_inp[0]->dtype().enumv(); if (!check_dtype) return nullptr; size_t out_channels = opr->input(1)->shape()[0]; @@ -4818,31 +2725,20 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { auto inps = new_inp; auto process = [&](size_t i) -> VarNode* { auto iter = format_map.find(new_inp[i]->owner_opr()); + ReformatKey key; + key.output_format = TensorFormats::NCHWc64; if (iter == format_map.end()) { - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW64); - return ovar.node(); + key.input_format = TensorFormats::NCHW; + key.input_dtype = key.output_dtype = inps[i]->dtype().enumv(); + return RelayoutPlaceholder::make(inps[i], key).node(); } else { const auto& fmt = iter->second; if (fmt == Format::NCHW64) { return inps[i]; - } else if (fmt == Format::NCHW4) { - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64); - return ovar.node(); - } else if (fmt == Format::NHWC) { - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64); - return ovar.node(); } else { - mgb_assert(fmt == Format::NCHW32); - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW64); - return ovar.node(); + key.input_format = opr_format_to_tensor_formats(fmt); + key.input_dtype = key.output_dtype = inps[i]->dtype().enumv(); + return RelayoutPlaceholder::make(inps[i], key).node(); } } }; @@ -4860,17 +2756,18 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { OperatorNodeBase* opr, const VarNodeArray& new_inp) -> VarNode* { // fint4XWint4 and fuint4XWint4 - mgb_assert(opr->input().size() == new_inp.size()); + mgb_assert(opr->input().size()==new_inp.size()); bool check_dtype = (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || - new_inp[0]->dtype().enumv() == DTypeEnum::Quantized4Asymm) && + new_inp[0]->dtype().enumv() == + DTypeEnum::Quantized4Asymm) && new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4; if (opr->input().size() >= 3) check_dtype &= new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32; if (opr->input().size() >= 4) - check_dtype &= - new_inp[3]->dtype().enumv() == new_inp[0]->dtype().enumv(); + check_dtype &= new_inp[3]->dtype().enumv() == + new_inp[0]->dtype().enumv(); if (!check_dtype) return nullptr; size_t out_channels = opr->input(1)->shape()[0]; @@ -4881,30 +2778,19 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { auto inps = new_inp; auto process = [&](size_t i) -> VarNode* { auto iter = format_map.find(new_inp[i]->owner_opr()); + ReformatKey key; + key.output_format = TensorFormats::NHWC; + key.input_dtype = key.output_dtype = inps[i]->dtype().enumv(); if (iter == format_map.end()) { - auto ovar = RelayoutPlaceholder::make( - inps[i], RelayoutPlaceholder::LayoutType::NCHW_TO_NHWC); - return ovar.node(); + key.input_format = TensorFormats::NCHW; + return RelayoutPlaceholder::make(inps[i], key).node(); } else { const auto& fmt = iter->second; if (fmt == Format::NHWC) { return inps[i]; - } else if (fmt == Format::NCHW4) { - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NCHW4_TO_NHWC); - return ovar.node(); - } else if (fmt == Format::NCHW32) { - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NCHW32_TO_NHWC); - return ovar.node(); } else { - mgb_assert(fmt == Format::NCHW64); - auto ovar = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType::NCHW64_TO_NHWC); - return ovar.node(); + key.input_format = opr_format_to_tensor_formats(fmt); + return RelayoutPlaceholder::make(inps[i], key).node(); } } }; @@ -4982,38 +2868,17 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { } auto inps = new_inp; inps[0] = RelayoutPlaceholder::make( - inps[0], - RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) + inps[0], ReformatKey{TensorFormats::NCHW, + TensorFormats::NCHWc4}) .node(); - switch (cur) { - case Format::NCHW: - inps[1] = RelayoutPlaceholder::make( - inps[1], RelayoutPlaceholder::LayoutType:: - NCHW_TO_NCHW4) - .node(); - break; - case Format::NHWC: - inps[1] = RelayoutPlaceholder::make( - inps[1], RelayoutPlaceholder::LayoutType:: - NHWC_TO_NCHW4) - .node(); - break; - case Format::NCHW32: - inps[1] = RelayoutPlaceholder::make( - inps[1], RelayoutPlaceholder::LayoutType:: - NCHW32_TO_NCHW4) - .node(); - break; - case Format::NCHW64: - inps[1] = RelayoutPlaceholder::make( - inps[1], RelayoutPlaceholder::LayoutType:: - NCHW64_TO_NCHW4) - .node(); - break; - default: - mgb_assert(cur == Format::NCHW4); - } - + if (cur != Format::NCHW4) { + inps[1] = RelayoutPlaceholder::make( + inps[1], + ReformatKey{opr_format_to_tensor_formats(cur), + TensorFormats::NCHWc4}) + .node(); + } + auto param = deconv.param(); param.format = Format::NCHW4; auto new_deconv = opr::ConvolutionBackwardData::make( @@ -5030,7 +2895,7 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { break; } } - mgb_assert(!shape_changed, + mgb_assert(!shape_changed, "EnableNCHW64Pass won't change format of output tensor " "of non quantized deconv operator(name:%s)", opr->cname()); @@ -5040,9 +2905,8 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { }; // replace rule for elemwise like opr - auto replace_elemwise_like_opr = [&format_map]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { + auto replace_elemwise_like_opr = [&format_map](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { mgb_assert(opr->input().size() == new_inp.size()); ThinHashMap format_size; bool same_format = true; @@ -5082,28 +2946,6 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { max_size = item.second; } } - static const ThinHashMap, - thin_function> - map = { -#define cb(_fmt1, _fmt2) \ - { \ - std::make_pair(Format::_fmt1, Format::_fmt2), \ - [](VarNode* in) -> VarNode* { \ - return RelayoutPlaceholder::make( \ - in, RelayoutPlaceholder::LayoutType:: \ - _fmt1##_TO_##_fmt2) \ - .node(); \ - } \ - } - cb(NCHW, NCHW4), cb(NCHW, NCHW32), cb(NCHW, NCHW64), - cb(NCHW4, NCHW), cb(NCHW4, NCHW32), cb(NCHW4, NCHW64), - cb(NCHW32, NCHW), cb(NCHW32, NCHW4), cb(NCHW32, NCHW64), - cb(NCHW32, NCHW), cb(NCHW32, NCHW4), cb(NCHW32, NCHW64), - cb(NCHW, NHWC), cb(NCHW4, NHWC), cb(NCHW32, NHWC), - cb(NCHW64, NHWC), cb(NHWC, NCHW), cb(NHWC, NCHW4), - cb(NHWC, NCHW32), cb(NHWC, NCHW64), -#undef cb - }; auto inps = new_inp; for (size_t i = 0; i < opr->input().size(); ++i) { auto iter = format_map.find(new_inp[i]->owner_opr()); @@ -5114,7 +2956,10 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { cur = Format::NCHW; } if (cur != max_format) { - inps[i] = map.at(std::make_pair(cur, max_format))(inps[i]); + ReformatKey key{opr_format_to_tensor_formats(cur), + opr_format_to_tensor_formats(max_format)}; + key.input_dtype = key.output_dtype = inps[i]->dtype().enumv(); + inps[i] = RelayoutPlaceholder::make(inps[i], key).node(); } } auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config()); @@ -5144,27 +2989,11 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { cur = iter->second; } auto inps = new_inp; - switch (cur) { - case Format::NCHW: - inps[0] = RelayoutPlaceholder::make( - inps[0], RelayoutPlaceholder::LayoutType:: - NCHW_TO_NHWC) - .node(); - break; - case Format::NCHW4: - inps[0] = RelayoutPlaceholder::make( - inps[0], RelayoutPlaceholder::LayoutType:: - NCHW4_TO_NHWC) - .node(); - break; - case Format::NCHW32: - inps[0] = RelayoutPlaceholder::make( - inps[0], RelayoutPlaceholder::LayoutType:: - NCHW32_TO_NHWC) - .node(); - break; - default: - mgb_assert(cur == Format::NCHW64 || cur == Format::NHWC); + if (cur != Format::NCHW64 && cur != Format::NHWC) { + ReformatKey key{opr_format_to_tensor_formats(cur), + TensorFormats::NHWC, inps[0]->dtype().enumv(), + inps[0]->dtype().enumv()}; + inps[0] = RelayoutPlaceholder::make(inps[0], key).node(); } auto target_format = cur == Format::NCHW64 ? cur : Format::NHWC; auto param = warp.param(); @@ -5172,7 +3001,8 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { SymbolVar new_warp; if (inps.size() == 3) { new_warp = opr::WarpPerspectiveForward::make( - inps[0], inps[1], inps[2], param, warp.config()); + inps[0], inps[1], inps[2], param, + warp.config()); } else { mgb_assert(inps.size() == 4); new_warp = opr::WarpPerspectiveForward::make( @@ -5191,41 +3021,20 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { cur = iter->second; } auto inps = new_inp; - switch (cur) { - case Format::NCHW: - inps[0] = RelayoutPlaceholder::make( - inps[0], RelayoutPlaceholder::LayoutType:: - NCHW_TO_NCHW4) - .node(); - break; - case Format::NHWC: - inps[0] = RelayoutPlaceholder::make( - inps[0], RelayoutPlaceholder::LayoutType:: - NHWC_TO_NCHW4) - .node(); - break; - case Format::NCHW32: - inps[0] = RelayoutPlaceholder::make( - inps[0], RelayoutPlaceholder::LayoutType:: - NCHW32_TO_NCHW4) - .node(); - break; - case Format::NCHW64: - inps[0] = RelayoutPlaceholder::make( - inps[0], RelayoutPlaceholder::LayoutType:: - NCHW64_TO_NCHW4) - .node(); - break; - default: - mgb_assert(cur == Format::NCHW4); + if (cur != Format::NCHW4) { + inps[0] = RelayoutPlaceholder::make( + inps[0], + ReformatKey{opr_format_to_tensor_formats(cur), + TensorFormats::NCHWc4}) + .node(); } - auto param = warp.param(); param.format = Format::NCHW4; SymbolVar new_warp; if (inps.size() == 3) { new_warp = opr::WarpPerspectiveForward::make( - inps[0], inps[1], inps[2], param, warp.config()); + inps[0], inps[1], inps[2], param, + warp.config()); } else { mgb_assert(inps.size() == 4); new_warp = opr::WarpPerspectiveForward::make( @@ -5243,7 +3052,7 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { break; } } - mgb_assert(!shape_changed, + mgb_assert(!shape_changed, "EnableNCHW64Pass won't change format of output tensor " "of non quantized warp perspective operator(name:%s)", opr->cname()); @@ -5251,8 +3060,9 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { opr->config()); } }; - auto replace_pooling_opr = [&format_map](OperatorNodeBase* opr, - const VarNodeArray& new_inp) { + auto replace_pooling_opr = [&format_map]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) { mgb_assert(opr->input().size() == new_inp.size()); auto& pooling = opr->cast_final_safe(); if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || @@ -5265,27 +3075,11 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { cur = iter->second; } auto inps = new_inp; - switch (cur) { - case Format::NCHW: - inps[0] = RelayoutPlaceholder::make( - inps[0], RelayoutPlaceholder::LayoutType:: - NCHW_TO_NHWC) - .node(); - break; - case Format::NCHW4: - inps[0] = RelayoutPlaceholder::make( - inps[0], RelayoutPlaceholder::LayoutType:: - NCHW4_TO_NHWC) - .node(); - break; - case Format::NCHW32: - inps[0] = RelayoutPlaceholder::make( - inps[0], RelayoutPlaceholder::LayoutType:: - NCHW32_TO_NHWC) - .node(); - break; - default: - mgb_assert(cur == Format::NCHW64 || cur == Format::NHWC); + if (cur != Format::NCHW64 && cur != Format::NHWC) { + ReformatKey key{opr_format_to_tensor_formats(cur), + TensorFormats::NHWC, inps[0]->dtype().enumv(), + inps[0]->dtype().enumv()}; + inps[0] = RelayoutPlaceholder::make(inps[0], key).node(); } auto target_format = cur == Format::NCHW64 ? cur : Format::NHWC; auto param = pooling.param(); @@ -5305,30 +3099,31 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { } bool use_nchw32 = false; auto inps = new_inp; - using LayoutType = RelayoutPlaceholder::LayoutType; + ReformatKey key; switch (cur) { case Format::NCHW: { size_t in_channels = new_inp[0]->shape()[1]; use_nchw32 = in_channels % 32 == 0; - auto layout_type = use_nchw32 ? LayoutType::NCHW_TO_NCHW32 - : LayoutType::NCHW_TO_NCHW4; - inps[0] = RelayoutPlaceholder::make(inps[0], layout_type) - .node(); + key.input_format = TensorFormats::NCHW; + key.output_format = use_nchw32 ? TensorFormats::NCHWc32 + : TensorFormats::NCHWc4; + inps[0] = RelayoutPlaceholder::make(inps[0], key).node(); break; } case Format::NHWC: { size_t in_channels = new_inp[0]->shape()[3]; use_nchw32 = in_channels % 32 == 0; - auto layout_type = use_nchw32 ? LayoutType::NHWC_TO_NCHW32 - : LayoutType::NHWC_TO_NCHW4; - inps[0] = RelayoutPlaceholder::make(inps[0], layout_type) - .node(); + key.input_format = TensorFormats::NHWC; + key.output_format = use_nchw32 ? TensorFormats::NCHWc32 + : TensorFormats::NCHWc4; + inps[0] = RelayoutPlaceholder::make(inps[0], key).node(); break; } case Format::NCHW64: inps[0] = RelayoutPlaceholder::make( - inps[0], RelayoutPlaceholder::LayoutType:: - NCHW64_TO_NCHW32) + inps[0], + ReformatKey{TensorFormats::NCHWc64, + TensorFormats::NCHWc32}) .node(); break; case Format::NCHW32: @@ -5338,7 +3133,7 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { mgb_assert(cur == Format::NCHW4); } Format out_format = use_nchw32 ? Format::NCHW32 : Format::NCHW4; - + auto param = pooling.param(); param.format = out_format; auto new_pool = @@ -5374,39 +3169,12 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { auto inps = new_inp; for (size_t i = 0; i < opr->input().size(); ++i) { auto iter = format_map.find(new_inp[i]->owner_opr()); - auto fmt = iter != format_map.end() ? iter->second : Format::NCHW; + auto fmt = iter != format_map.end()?iter->second:Format::NCHW; if (iter != format_map.end()) { - switch (fmt) { - case Format::NHWC: - inps[i] = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType:: - NHWC_TO_NCHW) - .node(); - break; - case Format::NCHW4: - inps[i] = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType:: - NCHW4_TO_NCHW) - .node(); - break; - case Format::NCHW32: - inps[i] = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType:: - NCHW32_TO_NCHW) - .node(); - break; - default: - mgb_assert(fmt == Format::NCHW64); - inps[i] = RelayoutPlaceholder::make( - inps[i], - RelayoutPlaceholder::LayoutType:: - NCHW64_TO_NCHW) - .node(); - break; - } + ReformatKey key{opr_format_to_tensor_formats(fmt), + TensorFormats::NCHW, inps[i]->dtype().enumv(), + inps[i]->dtype().enumv()}; + inps[i] = RelayoutPlaceholder::make(inps[i], key).node(); } } auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config()); @@ -5422,4 +3190,5 @@ std::unique_ptr EnableNCHW64Pass::make_nchw64_converter() { return ret; MIDOUT_E } + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h index 76e831ea3..45f2f1875 100644 --- a/src/gopt/include/megbrain/gopt/inference.h +++ b/src/gopt/include/megbrain/gopt/inference.h @@ -227,6 +227,7 @@ namespace gopt { VarReplaceCheckFlag m_var_replace_check_flag = VarReplaceCheckFlag::CHECK_ALL; class RelayoutPlaceholder; + friend class ShuffleShuffleRemovePass; public: TensorReformatPass& set_var_replace_check_flag(VarReplaceCheckFlag flag) { diff --git a/src/gopt/include/megbrain/gopt/reformat_manager.h b/src/gopt/include/megbrain/gopt/reformat_manager.h index 547fab19f..b6180ad60 100644 --- a/src/gopt/include/megbrain/gopt/reformat_manager.h +++ b/src/gopt/include/megbrain/gopt/reformat_manager.h @@ -49,10 +49,14 @@ enum class TensorFormats : uint32_t { KRSCk8 = 21, ///< [K/8, R, S, C, K%8] + // NCHW4 + KCRSc4 = 22, ///< [K, C/4, R, S, C%4] + GKCRSc4 = 23, ///< [G, K, C/4, R, S, C%4] + // default weight format - KCRS = 22, ///< [K, C, R, S] - GKCRS = 23, ///< [G, K, C, R, S] - C11RS = 24, ///< [C, 1, 1, R, S] + KCRS = 24, ///< [K, C, R, S] + GKCRS = 25, ///< [G, K, C, R, S] + C11RS = 26, ///< [C, 1, 1, R, S] }; class ReformatManager : public NonCopyableObj { @@ -60,16 +64,20 @@ class ReformatManager : public NonCopyableObj { public: using ReformatImpl = thin_function; - enum class Attribute : uint32_t { - DEFAULT = 0, - IMAGE2D = 1 << 0, - IC_SMALL = 1 << 1, - }; struct ReformatKey { + enum class Attribute : uint32_t { + DEFAULT = 0, + IMAGE2D = 1 << 0, + IC_SMALL = 1 << 1, + }; TensorFormats input_format, output_format; DTypeEnum input_dtype, output_dtype; Attribute attribute; std::string to_string() const; + ReformatKey() + : input_dtype{DTypeEnum::Float32}, + output_dtype{DTypeEnum::Float32}, + attribute{Attribute::DEFAULT} {} ReformatKey(TensorFormats input_format_, TensorFormats output_format_, Attribute attribute_ = Attribute::DEFAULT, DTypeEnum input_dtype_ = DTypeEnum::Float32, @@ -86,11 +94,13 @@ public: bool operator()(const ReformatKey& lhs, const ReformatKey& rhs) const; }; + ReformatKey& deduce_reformat_dtype_enum(const DType& dt); }; using ReformatCache = std::unordered_map; - const ReformatImpl& get(const ReformatKey& key) const; + ReformatImpl get(const ReformatKey& key) const; + ReformatImpl get(ReformatKey&& key) const { return get(key); } static const ReformatManager& instance(); private: diff --git a/src/gopt/test/reformat_manager.cpp b/src/gopt/test/reformat_manager.cpp new file mode 100644 index 000000000..047069538 --- /dev/null +++ b/src/gopt/test/reformat_manager.cpp @@ -0,0 +1,171 @@ +/** + * \file src/gopt/test/reformat_manager.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "./helper.h" + +#include "megbrain/gopt/reformat_manager.h" +#include "megbrain/opr/tensor_manip.h" + +using namespace mgb; +using namespace gopt; + +TEST(TestReformatManager, Feature) { + constexpr size_t N = 16, C = 128, H = 7, W = 7; + HostTensorGenerator<> gen; + using ReformatKey = ReformatManager::ReformatKey; + auto src_format = TensorFormats::NHWC, dst_format = TensorFormats::NCHWc64; + ReformatKey key{src_format, dst_format}; + auto reformat = ReformatManager::instance().get(key); + + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + + auto r = [](VarNode* inp) { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4}); + return y1; + }; + + auto mkvar = [&](const char* name, const TensorShape& shp) { + return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); + }; + auto x = mkvar("x", {N, H, W, C}); + auto y1 = SymbolVar(reformat({x.node()})); + auto y2 = r(x.node()); + size_t nr_shapeof = 0; + size_t nr_reshape = 0; + cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) { + if (o->same_type()) + nr_shapeof++; + if (o->same_type()) + nr_reshape++; + }} + .add(y1.node()->owner_opr()); + ASSERT_EQ(nr_shapeof, 1); + ASSERT_EQ(nr_reshape, 1); + HostTensorND t1, t2; + auto func1 = graph->compile({make_callback_copy(y1, t1)}); + func1->execute(); + auto func2 = graph->compile({make_callback_copy(y2, t2)}); + func2->execute(); + MGB_ASSERT_TENSOR_EQ(t1, t2); +} + +TEST(TestReformatManager, Weight) { + constexpr size_t G = 8, K = 128, C = 128, R = 3, S = 3; + HostTensorGenerator<> gen; + using ReformatKey = ReformatManager::ReformatKey; + auto src_format = TensorFormats::GKCRS, + dst_format = TensorFormats::GKCRSk4c4; + ReformatKey key{src_format, dst_format}; + auto reformat = ReformatManager::instance().get(key); + + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + + auto r = [](VarNode* inp) { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4, + cv(4), sub(3), sub(4)}, + 0), + tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3), + sub(4), cv(4), cv(4)}, + 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 2, 4}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2; + }; + + auto mkvar = [&](const char* name, const TensorShape& shp) { + return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); + }; + auto w = mkvar("w", {G, K / G, C / G, R, S}); + auto y1 = SymbolVar(reformat({w.node()})); + auto y2 = r(w.node()); + size_t nr_shapeof = 0; + size_t nr_reshape = 0; + cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) { + if (o->same_type()) + nr_shapeof++; + if (o->same_type()) + nr_reshape++; + }} + .add(y1.node()->owner_opr()); + ASSERT_EQ(nr_shapeof, 1); + ASSERT_EQ(nr_reshape, 1); + HostTensorND t1, t2; + auto func1 = graph->compile({make_callback_copy(y1, t1)}); + func1->execute(); + auto func2 = graph->compile({make_callback_copy(y2, t2)}); + func2->execute(); + MGB_ASSERT_TENSOR_EQ(t1, t2); +} + +TEST(TestReformatManager, InvalidKey) { + using ReformatKey = ReformatManager::ReformatKey; + using Attribute = ReformatKey::Attribute; + auto src_format = TensorFormats::GKCRS, + dst_format = TensorFormats::GKCRSk4c4; + Attribute attribute = Attribute::IMAGE2D; + ReformatKey key{src_format, dst_format, attribute}; + ASSERT_THROW(ReformatManager::instance().get(key), AssertionError); +} + +TEST(TestReformatManager, InputChannelSmall) { + constexpr size_t N = 16, C = 3, H = 224, W = 224; + auto cn = CompNode::load("cpux"); + HostTensorGenerator<> gen; + using ReformatKey = ReformatManager::ReformatKey; + using Attribute = ReformatKey::Attribute; + auto src_format = TensorFormats::NCHW, dst_format = TensorFormats::NCHWc4; + ReformatKey key{src_format, dst_format, Attribute::IC_SMALL}; + auto reformat = ReformatManager::instance().get(key); + + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + + auto r = [](VarNode* inp) { + auto x = SymbolVar(inp); + auto y = opr::RelayoutFormat::make( + x, megdnn::param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL); + return y; + }; + + auto mkvar = [&](const char* name, const TensorShape& shp) { + return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name); + }; + auto x = mkvar("x", {N, C, H, W}); + auto y1 = SymbolVar(reformat({x.node()})); + auto y2 = r(x.node()); + HostTensorND t1, t2; + auto func1 = graph->compile({make_callback_copy(y1, t1)}); + func1->execute(); + auto func2 = graph->compile({make_callback_copy(y2, t2)}); + func2->execute(); + MGB_ASSERT_TENSOR_EQ(t1, t2); +} + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} -- GitLab