From c33126ab5c1ec80fa771db245783ec08fa3ee8ea Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 16 Jul 2021 18:54:14 +0800
Subject: [PATCH] feat(mgb/gopt): add reformat manager

GitOrigin-RevId: b9791b131a996f4f7c55173294abf0715fac97fc
---
 dnn/include/megdnn/named_tensor.h             |   1 +
 dnn/src/common/named_tensor.cpp               |   8 +-
 dnn/test/common/named_tensor.cpp              |   2 +-
 src/gopt/impl/reformat_emitter.cpp            |  15 +-
 src/gopt/impl/reformat_manager.cpp            | 399 ++++++++++++++++++
 .../include/megbrain/gopt/reformat_manager.h  | 103 +++++
 6 files changed, 515 insertions(+), 13 deletions(-)
 create mode 100644 src/gopt/impl/reformat_manager.cpp
 create mode 100644 src/gopt/include/megbrain/gopt/reformat_manager.h

diff --git a/dnn/include/megdnn/named_tensor.h b/dnn/include/megdnn/named_tensor.h
index 2a6ab0139..8b04730b2 100644
--- a/dnn/include/megdnn/named_tensor.h
+++ b/dnn/include/megdnn/named_tensor.h
@@ -30,6 +30,7 @@ public:
         C = 'C',  // input channel
         H = 'H',  // input height
         W = 'W',  // input width
+        G = 'G',  // group
         K = 'K',  // output channel
         R = 'R',  // filter height
         S = 'S',  // filter width
diff --git a/dnn/src/common/named_tensor.cpp b/dnn/src/common/named_tensor.cpp
index afa6027f8..cee817322 100644
--- a/dnn/src/common/named_tensor.cpp
+++ b/dnn/src/common/named_tensor.cpp
@@ -18,8 +18,9 @@ using namespace megdnn;
 /* ===================== Dimension ============================ */
 const Dimension::Name Dimension::NAME_ALL[] = {
         Dimension::Name::N, Dimension::Name::C, Dimension::Name::H,
-        Dimension::Name::W, Dimension::Name::K, Dimension::Name::R,
-        Dimension::Name::S, Dimension::Name::P, Dimension::Name::Q,
+        Dimension::Name::W, Dimension::Name::G, Dimension::Name::K,
+        Dimension::Name::R, Dimension::Name::S, Dimension::Name::P,
+        Dimension::Name::Q,
 };
 const int Dimension::NR_NAMES = sizeof(Dimension::NAME_ALL);
 Dimension::Dimension(const std::string& expr) {
@@ -92,6 +93,9 @@ bool Dimension::operator<(const Dimension& rhs) const {
     if (m_name != rhs.m_name) {
         return static_cast<char>(m_name) < static_cast<char>(rhs.m_name);
     }
+    if (m_stride == rhs.m_stride) {
+        return m_extent > rhs.m_extent;
+    }
     return m_stride > rhs.m_stride;
 }
 
diff --git a/dnn/test/common/named_tensor.cpp b/dnn/test/common/named_tensor.cpp
index 4c790a133..89b9eae11 100644
--- a/dnn/test/common/named_tensor.cpp
+++ b/dnn/test/common/named_tensor.cpp
@@ -21,7 +21,7 @@ using namespace megdnn;
 using megdnn::test::MegDNNError;
 
 TEST(NAMED_TENSOR, NAMED_TENSOR_SHAPE_BASIC) {
-    ASSERT_EQ(Dimension::NR_NAMES, 9);
+    ASSERT_EQ(Dimension::NR_NAMES, 10);
     Dimension dim0 = {"C"}, dim1 = {"C//32"}, dim2 = {"C//4"}, dim3 = {"C%32"},
               dim4 = {"C%4"}, dim5 = {"C//4%8"};
     ASSERT_TRUE(dim0 == dim1 * dim3);
diff --git a/src/gopt/impl/reformat_emitter.cpp b/src/gopt/impl/reformat_emitter.cpp
index 9b5b37d84..7f3d41169 100644
--- a/src/gopt/impl/reformat_emitter.cpp
+++ b/src/gopt/impl/reformat_emitter.cpp
@@ -182,7 +182,6 @@ ReformatEmitter::UnderlyingBuilders ReformatEmitter::analyze() const {
     };
     std::sort(src_dims.begin(), src_dims.end(), compare);
     std::sort(dest_dims.begin(), dest_dims.end(), compare);
-
     auto src_iter = src_dims.begin();
     auto dest_iter = dest_dims.begin();
     for (; src_iter != src_dims.end() && dest_iter != dest_dims.end();) {
@@ -233,18 +232,14 @@ ReformatEmitter::UnderlyingBuilders ReformatEmitter::analyze() const {
     }
     UnderlyingBuilders builders;
     if (!m_src.eq_shape(i1)) {
-        builders.make_shape1 =
-                std::move(std::get<0>(MakeShapeEmitter(m_src, i1).emit()));
-        builders.reshape1 =
-                std::move(std::get<0>(ReshapeEmitter(m_src, i1).emit()));
+        builders.make_shape1 = std::get<0>(MakeShapeEmitter(m_src, i1).emit());
+        builders.reshape1 = std::get<0>(ReshapeEmitter(m_src, i1).emit());
     }
-    builders.dimshuffle =
-            std::move(std::get<0>(DimshuffleEmitter(permute).emit()));
+    builders.dimshuffle = std::get<0>(DimshuffleEmitter(permute).emit());
     if (!m_dest.eq_shape(i2)) {
         builders.make_shape2 =
-                std::move(std::get<0>(MakeShapeEmitter(m_src, m_dest).emit()));
-        builders.reshape2 =
-                std::move(std::get<0>(ReshapeEmitter(i2, m_dest).emit()));
+                std::get<0>(MakeShapeEmitter(m_src, m_dest).emit());
+        builders.reshape2 = std::get<0>(ReshapeEmitter(i2, m_dest).emit());
     }
     return builders;
 }
diff --git a/src/gopt/impl/reformat_manager.cpp b/src/gopt/impl/reformat_manager.cpp
new file mode 100644
index 000000000..8fa6e7604
--- /dev/null
+++ b/src/gopt/impl/reformat_manager.cpp
@@ -0,0 +1,399 @@
+/**
+ * \file src/gopt/impl/reformat_manager.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "megbrain/gopt/reformat_manager.h"
+#include
+#include "megbrain/opr/tensor_manip.h"
+
+using namespace mgb;
+using namespace gopt;
+using NamedTensorShape = megdnn::NamedTensorShape;
+
+namespace {
+NamedTensorShape tensor_formats_to_named_tensor_shape(TensorFormats format) {
+    switch (format) {
+        case TensorFormats::NCHW:
+            return {{"N"}, {"C"}, {"H"}, {"W"}};
+        case TensorFormats::NHWC:
+            return {{"N"}, {"H"}, {"W"}, {"C"}};
+        case TensorFormats::NCHWc4:
+            return {{"N"}, {"C//4"}, {"H"}, {"W"}, {"C%4"}};
+        case TensorFormats::NCHWc8:
+            return {{"N"}, {"C//8"}, {"H"}, {"W"}, {"C%8"}};
+        case TensorFormats::NCHWc32:
+            return {{"N"}, {"C//32"}, {"H"}, {"W"}, {"C%32"}};
+        case TensorFormats::NCHWc64:
+            return {{"N"}, {"C//64"}, {"H"}, {"W"}, {"C%64"}};
+        case TensorFormats::CHWNc4:
+            return {{"C//4"}, {"H"}, {"W"}, {"N"}, {"C%4"}};
+        case TensorFormats::NHCWc4:
+            return {{"N"}, {"H"}, {"C//4"}, {"W"}, {"C%4"}};
+        case TensorFormats::KRSCk4:
+            return {{"K//4"}, {"R"}, {"S"}, {"C"}, {"K%4"}};
+        case TensorFormats::GKRSCk4:
+            return {{"G"}, {"K//4"}, {"R"}, {"S"}, {"C"}, {"K%4"}};
+        case TensorFormats::C1RSc4:
+            return {{"C//4"}, {"C%1"}, {"R"}, {"S"}, {"C%4"}};
+        case TensorFormats::KRSCk4c4:
+            return {{"K//4"}, {"R"}, {"S"}, {"C//4"}, {"K%4"}, {"C%4"}};
+        case TensorFormats::GKRSCk4c4:
+            return {{"G"}, {"K//4"}, {"R"}, {"S"}, {"C//4"}, {"K%4"}, {"C%4"}};
+        case TensorFormats::KCRSk4c4:
+            return {{"K//4"}, {"C//4"}, {"R"}, {"S"}, {"K%4"}, {"C%4"}};
+        case TensorFormats::GKCRSk4c4:
+            return {{"G"}, {"K//4"}, {"C//4"}, {"R"}, {"S"}, {"K%4"}, {"C%4"}};
+        case TensorFormats::KCRSc4k4:
+            return {{"K//4"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}, {"K%4"}};
+        case TensorFormats::GKCRSc4k4:
+            return {{"G"}, {"K//4"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}, {"K%4"}};
+        case TensorFormats::C11RSc4:
+            return {{"C//4"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}, {"C%4"}};
+        case TensorFormats::KCRSc8k8:
+            return {{"K//8"}, {"C//8"}, {"R"}, {"S"}, {"C%8"}, {"K%8"}};
+        case TensorFormats::GKCRSc8k8:
+            return {{"G"}, {"K//8"}, {"C//8"}, {"R"}, {"S"}, {"C%8"}, {"K%8"}};
+        case TensorFormats::C11RSc8:
+            return {{"C//8"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}, {"C%8"}};
+        case TensorFormats::KRSCk8:
+            return {{"K//8"}, {"R"}, {"S"}, {"C"}, {"K%8"}};
+        case TensorFormats::KCRS:
+            return {{"K"}, {"C"}, {"R"}, {"S"}};
+        case TensorFormats::GKCRS:
+            return {{"G"}, {"K"}, {"C"}, {"R"}, {"S"}};
+        case TensorFormats::C11RS:
+            return {{"C"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}};
+        default:
+            mgb_throw(AssertionError, "invalid tensor formats(%u)",
+                      static_cast<uint32_t>(format));
+    }
+}
+};  // namespace
+
+// =================== ReformatManager::ReformatKey ====================*/
+std::string ReformatManager::ReformatKey::to_string() const {
+    auto&& i = tensor_formats_to_named_tensor_shape(input_format);
+    auto&& o = tensor_formats_to_named_tensor_shape(output_format);
+    std::string input_name, output_name;
+
+#define cb(_name)                              \
+    if (input_dtype == DTypeEnum::_name) {     \
+        input_name = #_name;                   \
+    } else
+    MEGDNN_FOREACH_DTYPE_NAME(cb)
+    MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb) {
+        mgb_throw(MegBrainError, "invalid input dtype enum(%u)",
+                  static_cast<uint32_t>(input_dtype));
+    }
+#undef cb
+#define cb(_name)                              \
+    if (output_dtype == DTypeEnum::_name) {    \
+        output_name = #_name;                  \
+    } else
+    MEGDNN_FOREACH_DTYPE_NAME(cb)
+    MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb) {
+        mgb_throw(MegBrainError, "invalid output dtype enum(%u)",
+                  static_cast<uint32_t>(output_dtype));
+    }
+#undef cb
+    return ssprintf("%s;%s;%s;%s;%s", i.to_string().c_str(),
+                    o.to_string().c_str(),
+                    std::to_string(static_cast<uint32_t>(attribute)).c_str(),
+                    input_name.c_str(), output_name.c_str());
+}
+
+size_t ReformatManager::ReformatKey::Hash::operator()(
+        const ReformatKey& key) const {
+    auto enumhash = mgb::enumhash();
+    size_t h = enumhash(key.input_format);
+    h = mgb::hash_pair_combine(h, enumhash(key.output_format));
+    h = mgb::hash_pair_combine(h, enumhash(key.attribute));
+    h = mgb::hash_pair_combine(h, enumhash(key.input_dtype));
+    h = mgb::hash_pair_combine(h, enumhash(key.output_dtype));
+    return h;
+}
+
+bool ReformatManager::ReformatKey::Equal::operator()(
+        const ReformatKey& lhs, const ReformatKey& rhs) const {
+    return lhs.input_format == rhs.input_format &&
+           lhs.output_format == rhs.output_format &&
+           lhs.input_dtype == rhs.input_dtype &&
+           lhs.output_dtype == rhs.output_dtype &&
+           lhs.attribute == rhs.attribute;
+}
+
+// =================== ReformatManager ====================*/
+#define FOREACH_FEATURE_TENSOR_FORMATS(cb)                                     \
+    cb(NCHW) cb(NHWC) cb(NCHWc4) cb(NCHWc8) cb(NCHWc32) cb(NCHWc64) cb(CHWNc4) \
+            cb(NHCWc4)
+#define FOREACH_WEIGHT_TENSOR_FORMATS(cb)                                      \
+    cb(KRSCk4) cb(KRSCk4c4) cb(KCRSk4c4) cb(KCRSc4k4) cb(KCRSc8k8) cb(KRSCk8)  \
+            cb(GKRSCk4) cb(GKRSCk4c4) cb(GKCRSc4k4) cb(GKCRSk4c4)              \
+                    cb(GKCRSc8k8) cb(C11RSc4) cb(C11RSc8)
+ReformatManager::ReformatManager() {
+    static constexpr TensorFormats feature_tensor_formats[] = {
+#define cb(_fmt) TensorFormats::_fmt,
+            FOREACH_FEATURE_TENSOR_FORMATS(cb)
+#undef cb
+    };
+    static constexpr int nr_feature_tensor_formats =
+            sizeof(feature_tensor_formats) / sizeof(TensorFormats);
+    for (int i = 0; i < nr_feature_tensor_formats; ++i) {
+        for (int o = 0; o < nr_feature_tensor_formats; ++o) {
+            if (i == o)
+                continue;
+            NamedTensorShape input_shape = tensor_formats_to_named_tensor_shape(
+                    feature_tensor_formats[i]);
+            NamedTensorShape output_shape =
+                    tensor_formats_to_named_tensor_shape(
+                            feature_tensor_formats[o]);
+            auto impl = std::get<0>(
+                    ReformatEmitter{input_shape, output_shape}.emit());
+            m_cache.emplace(ReformatKey{feature_tensor_formats[i],
+                                        feature_tensor_formats[o]},
+                            impl);
+        }
+    }
+    static constexpr TensorFormats default_weight_tensor_formats =
+            TensorFormats::KCRS;
+    static constexpr TensorFormats default_group_conv_weight_tensor_formats =
+            TensorFormats::GKCRS;
+    static constexpr TensorFormats default_chan_conv_weight_tensor_formats =
+            TensorFormats::C11RS;
+    static constexpr TensorFormats weight_tensor_formats[] = {
+#define cb(_fmt) TensorFormats::_fmt,
+            FOREACH_WEIGHT_TENSOR_FORMATS(cb)
+#undef cb
+    };
+    static constexpr int nr_weight_tensor_formats =
+            sizeof(weight_tensor_formats) / sizeof(TensorFormats);
+    using Name = megdnn::Dimension::Name;
+    for (int o = 0; o < nr_weight_tensor_formats; ++o) {
+        NamedTensorShape output_shape =
+                tensor_formats_to_named_tensor_shape(weight_tensor_formats[o]);
+        TensorFormats input_format;
+        if (output_shape[0].name() == Name::G) {
+            input_format = default_group_conv_weight_tensor_formats;
+        } else if (output_shape[0].name() == Name::C) {
+            input_format = default_chan_conv_weight_tensor_formats;
+        } else {
+            mgb_assert(output_shape[0].name() == Name::K);
+            input_format = default_weight_tensor_formats;
+        }
+        NamedTensorShape input_shape =
+                tensor_formats_to_named_tensor_shape(input_format);
+        auto impl =
+                std::get<0>(ReformatEmitter{input_shape, output_shape}.emit());
+        m_cache.emplace(ReformatKey{input_format, weight_tensor_formats[o]},
+                        impl);
+    }
+    {
+        auto i = TensorFormats::NCHW, o = TensorFormats::NCHWc4;
+        auto&& impl = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(vars[0],
+                                             megdnn::param::RelayoutFormat::
+                                                     Mode::NCHW_NCHW4_IC_SMALL)
+                    .node();
+        };
+        m_cache.emplace(ReformatKey{i, o, Attribute::IC_SMALL}, impl);
+    }
+    {
+        auto i = TensorFormats::KCRS, o = TensorFormats::KCRSc4k4;
+        auto&& impl = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(
+                           vars[0],
+                           megdnn::param::RelayoutFormat::Mode::
+                                   NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT)
+                    .node();
+        };
+        m_cache.emplace(ReformatKey{i, o, Attribute::IC_SMALL}, impl);
+    }
+    {
+        auto i = TensorFormats::NCHW, o = TensorFormats::NCHWc64;
+        auto&& impl = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(
+                           vars[0],
+                           megdnn::param::RelayoutFormat::Mode::NCHW_NCHW64)
+                    .node();
+        };
+        m_cache.emplace(
+                ReformatKey{i, o, Attribute::DEFAULT, DTypeEnum::QuantizedS4,
+                            DTypeEnum::QuantizedS4},
+                impl);
+        m_cache.emplace(ReformatKey{i, o, Attribute::DEFAULT,
+                                    DTypeEnum::Quantized4Asymm,
+                                    DTypeEnum::Quantized4Asymm},
+                        impl);
+    }
+    {
+        auto i = TensorFormats::NCHWc64, o = TensorFormats::NCHW;
+        auto&& impl = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(
+                           vars[0],
+                           megdnn::param::RelayoutFormat::Mode::NCHW_NCHW64)
+                    .node();
+        };
+        m_cache.emplace(
+                ReformatKey{i, o, Attribute::DEFAULT, DTypeEnum::QuantizedS4,
+                            DTypeEnum::QuantizedS4},
+                impl);
+        m_cache.emplace(ReformatKey{i, o, Attribute::DEFAULT,
+                                    DTypeEnum::Quantized4Asymm,
+                                    DTypeEnum::Quantized4Asymm},
+                        impl);
+    }
+    {
+        auto i = TensorFormats::NCHW, o = TensorFormats::NHWC;
+        auto&& impl = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(
+                           vars[0],
+                           megdnn::param::RelayoutFormat::Mode::NCHW_NHWC)
+                    .node();
+        };
+        m_cache.emplace(
+                ReformatKey{i, o, Attribute::DEFAULT, DTypeEnum::QuantizedS4,
+                            DTypeEnum::QuantizedS4},
+                impl);
+        m_cache.emplace(ReformatKey{i, o, Attribute::DEFAULT,
+                                    DTypeEnum::Quantized4Asymm,
+                                    DTypeEnum::Quantized4Asymm},
+                        impl);
+    }
+    {
+        auto i = TensorFormats::NHWC, o = TensorFormats::NCHW;
+        auto&& impl = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(
+                           vars[0],
+                           megdnn::param::RelayoutFormat::Mode::NCHW_NHWC)
+                    .node();
+        };
+        m_cache.emplace(
+                ReformatKey{i, o, Attribute::DEFAULT, DTypeEnum::QuantizedS4,
+                            DTypeEnum::QuantizedS4},
+                impl);
+        m_cache.emplace(ReformatKey{i, o, Attribute::DEFAULT,
+                                    DTypeEnum::Quantized4Asymm,
+                                    DTypeEnum::Quantized4Asymm},
+                        impl);
+    }
+    // nhcw4
+    {
+        auto i = TensorFormats::KCRS, o = TensorFormats::KRSCk4;
+        auto&& impl = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(vars[0],
+                                             megdnn::param::RelayoutFormat::
+                                                     Mode::INTER_WEIGHT_DENSEI)
+                    .node();
+        };
+        m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl);
+    }
+    {
+        auto i = TensorFormats::KCRS, o = TensorFormats::GKRSCk4;
+        auto&& impl = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(vars[0],
+                                             megdnn::param::RelayoutFormat::
+                                                     Mode::INTER_WEIGHT_GROUPI)
+                    .node();
+        };
+        m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl);
+    }
+    {
+        auto i = TensorFormats::KCRS, o = TensorFormats::C1RSc4;
+        auto&& impl = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(vars[0],
+                                             megdnn::param::RelayoutFormat::
+                                                     Mode::INTER_WEIGHT_CHANI)
+                    .node();
+        };
+        m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl);
+    }
+    {
+        auto i = TensorFormats::NCHW, o = TensorFormats::NHCWc4;
+        auto&& impl = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(
+                           vars[0],
+                           megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I)
+                    .node();
+        };
+        m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl);
+    }
+    {
+        auto i = TensorFormats::NHCWc4, o = TensorFormats::NCHW;
+        auto&& impl = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(
+                           vars[0],
+                           megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I)
+                    .node();
+        };
+        m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl);
+    }
+    // nhcw4-dot
+    {
+        auto i = TensorFormats::KCRS, o = TensorFormats::KRSCk4c4;
+        auto&& impl = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(
+                           vars[0], megdnn::param::RelayoutFormat::Mode::
+                                            INTER_WEIGHT_DENSEI_DOT)
+                    .node();
+        };
+        m_cache.emplace(
+                ReformatKey{i, o, Attribute::IMAGE2D, DTypeEnum::QuantizedS8,
+                            DTypeEnum::QuantizedS8},
+                impl);
+        m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D,
+                                    DTypeEnum::Quantized8Asymm,
+                                    DTypeEnum::Quantized8Asymm},
+                        impl);
+    }
+    {
+        auto i = TensorFormats::GKCRS, o = TensorFormats::GKRSCk4c4;
+        auto&& impl = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(
+                           vars[0], megdnn::param::RelayoutFormat::Mode::
+                                            INTER_WEIGHT_GROUPI_DOT)
+                    .node();
+        };
+        m_cache.emplace(
+                ReformatKey{i, o, Attribute::IMAGE2D, DTypeEnum::QuantizedS8,
+                            DTypeEnum::QuantizedS8},
+                impl);
+        m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D,
+                                    DTypeEnum::Quantized8Asymm,
+                                    DTypeEnum::Quantized8Asymm},
+                        impl);
+    }
+}
+#undef FOREACH_FEATURE_TENSOR_FORMATS
+#undef FOREACH_WEIGHT_TENSOR_FORMATS
+
+const ReformatManager::ReformatImpl& ReformatManager::get(
+        const ReformatKey& key) const {
+    MGB_TRY {
+        auto&& impl = m_cache.at(key);
+        return impl;
+    }
+    MGB_CATCH(std::exception & exc, {
+        mgb_log_error(
+                "cannot find ReformatImpl for ReformatKey(%s), extra "
+                "message(%s)",
+                key.to_string().c_str(), exc.what());
+        throw;
+    })
+}
+
+const ReformatManager& ReformatManager::instance() {
+    static ReformatManager* inst = nullptr;
+    if (inst == nullptr) {
+        inst = new ReformatManager();
+    }
+    return *inst;
+}
+// vim: syntax=cpp.doxygen
diff --git a/src/gopt/include/megbrain/gopt/reformat_manager.h b/src/gopt/include/megbrain/gopt/reformat_manager.h
new file mode 100644
index 000000000..547fab19f
--- /dev/null
+++ b/src/gopt/include/megbrain/gopt/reformat_manager.h
@@ -0,0 +1,103 @@
+/**
+ * \file src/gopt/include/megbrain/gopt/reformat_manager.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+#include "megbrain/gopt/reformat_emitter.h"
+#include "megbrain/graph.h"
+
+namespace mgb {
+namespace gopt {
+
+enum class TensorFormats : uint32_t {
+    // input tensor formats
+    NCHW = 0,     ///< [N, C, H, W]
+    NHWC = 1,     ///< [N, H, W, C]
+    NCHWc4 = 2,   ///< [N, C/4, H, W, C%4]
+    NCHWc8 = 3,   ///< [N, C/8, H, W, C%8]
+    NCHWc32 = 4,  ///< [N, C/32, H, W, C%32]
+    NCHWc64 = 5,  ///< [N, C/64, H, W, C%64]
+    CHWNc4 = 6,   ///< [C/4, H, W, N, C%4]
+    NHCWc4 = 7,   ///< [N, H, C/4, W, C%4]
+    // weight tensor formats
+    // NHWCD4
+    KRSCk4 = 8,   ///< [K/4, R, S, C, K%4] [dense conv]
+    GKRSCk4 = 9,  ///< [G, K/4, R, S, C, K%4] [group conv]
+    C1RSc4 = 10,  ///< [C/4, 1, R, S, C%4] [channel wise conv]
+    // NHWCD4-dot
+    KRSCk4c4 = 11,   ///< [K/4, R, S, C/4, K%4, C%4] [dense conv]
+    GKRSCk4c4 = 12,  ///< [G, K/4, R, S, C/4, K%4, C%4] [group conv]
+    // NCHW44-dot
+    KCRSk4c4 = 13,   ///< [K/4, C/4, R, S, K%4, C%4] [dense conv]
+    GKCRSk4c4 = 14,  ///< [G, K/4, C/4, R, S, K%4, C%4] [group conv]
+    // NCHW44
+    KCRSc4k4 = 15,   ///< [K/4, C/4, R, S, C%4, K%4] [dense conv]
+    GKCRSc4k4 = 16,  ///< [G, K/4, C/4, R, S, C%4, K%4] [group conv]
+    C11RSc4 = 17,    ///< [C/4, 1, 1, R, S, C%4] [channel wise conv]
+    // NCHW88
+    KCRSc8k8 = 18,   ///< [K/8, C/8, R, S, C%8, K%8] [dense conv]
+    GKCRSc8k8 = 19,  ///< [G, K/8, C/8, R, S, C%8, K%8] [group conv]
+    C11RSc8 = 20,    ///< [C/8, 1, 1, R, S, C%8] [channel wise conv]
+
+    KRSCk8 = 21,  ///< [K/8, R, S, C, K%8]
+
+    // default weight format
+    KCRS = 22,   ///< [K, C, R, S]
+    GKCRS = 23,  ///< [G, K, C, R, S]
+    C11RS = 24,  ///< [C, 1, 1, R, S]
+};
+
+class ReformatManager : public NonCopyableObj {
+    ReformatManager();
+
+public:
+    using ReformatImpl = thin_function<VarNode*(const VarNodeArray&)>;
+    enum class Attribute : uint32_t {
+        DEFAULT = 0,
+        IMAGE2D = 1 << 0,
+        IC_SMALL = 1 << 1,
+    };
+    struct ReformatKey {
+        TensorFormats input_format, output_format;
+        DTypeEnum input_dtype, output_dtype;
+        Attribute attribute;
+        std::string to_string() const;
+        ReformatKey(TensorFormats input_format_, TensorFormats output_format_,
+                    Attribute attribute_ = Attribute::DEFAULT,
+                    DTypeEnum input_dtype_ = DTypeEnum::Float32,
+                    DTypeEnum output_dtype_ = DTypeEnum::Float32)
+                : input_format{input_format_},
+                  output_format{output_format_},
+                  input_dtype{input_dtype_},
+                  output_dtype{output_dtype_},
+                  attribute{attribute_} {}
+        struct Hash {
+            size_t operator()(const ReformatKey& key) const;
+        };
+        struct Equal {
+            bool operator()(const ReformatKey& lhs,
+                            const ReformatKey& rhs) const;
+        };
+    };
+    using ReformatCache =
+            std::unordered_map<ReformatKey, ReformatImpl, ReformatKey::Hash,
+                               ReformatKey::Equal>;
+    const ReformatImpl& get(const ReformatKey& key) const;
+    static const ReformatManager& instance();
+
+private:
+    ReformatCache m_cache;
+};
+
+}  // namespace gopt
+}  // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
-- 
GitLab
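
Usage sketch (not part of the patch): the snippet below shows how a layout-transform pass could look up and apply one of the cached reformat builders registered by the constructor above. The function name reformat_to_nchw4 and the incoming VarNode* src are assumptions made for illustration; ReformatManager, ReformatKey, TensorFormats and Attribute are the APIs introduced in reformat_manager.h, and the NCHW -> NCHWc4 key with default attribute and Float32 dtypes is one of the feature-format pairs the constructor registers.

    // Minimal sketch, assuming an existing VarNode* `src` holding an NCHW
    // float32 feature map inside a computing graph.
    #include "megbrain/gopt/reformat_manager.h"

    using namespace mgb;
    using namespace gopt;

    VarNode* reformat_to_nchw4(VarNode* src) {
        // Build the cache key: NCHW -> NCHWc4, Attribute::DEFAULT and Float32
        // on both sides (the defaults declared by ReformatKey's constructor).
        ReformatManager::ReformatKey key{TensorFormats::NCHW,
                                         TensorFormats::NCHWc4};
        // get() returns the cached ReformatImpl; it logs and rethrows if the
        // requested conversion was never registered.
        const auto& reformat = ReformatManager::instance().get(key);
        // The builder consumes the source vars and returns the var that holds
        // the reformatted tensor, ready to be wired into the rewritten graph.
        VarNodeArray inputs;
        inputs.push_back(src);
        return reformat(inputs);
    }

The same lookup-then-apply pattern should work for any registered key, e.g. weight layouts such as KCRS -> KCRSc4k4 with Attribute::IC_SMALL, provided the dtypes in the key match one of the entries populated in ReformatManager's constructor.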