/**
 * \file src/gopt/impl/opr_tensor_formats_config.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */

#include "./utils.h"
#include "megbrain/gopt/layout_transform_context.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"

#include "midout.h"
MIDOUT_DECL(megbrain_opr_tensor_formats_config)
#define MIDOUT_B(...) \
    MIDOUT_BEGIN(megbrain_opr_tensor_formats_config, __VA_ARGS__) {
#define MIDOUT_E \
    }            \
    MIDOUT_END();

using namespace mgb;
using namespace cg;
using namespace gopt;
using OprFormat = opr::ConvBias::Param::Format;

namespace {
template <typename Opr>
struct ConvParamTrait;

#define INST(_conv, _weight_idx, _bias_idx, _has_bias) \
    template <>                                        \
    struct ConvParamTrait<opr::_conv> {                \
        static constexpr int weight_idx = _weight_idx; \
        static constexpr int bias_idx = _bias_idx;     \
        static constexpr bool has_bias = _has_bias;    \
    }

INST(ConvBias, 1, 2, true);
INST(ConvolutionForward, 1, 0, false);
INST(ConvolutionBackwardData, 0, 0, false);

template <typename Opr, size_t weight_idx = ConvParamTrait<Opr>::weight_idx>
static bool is_channel_wise_conv(const OperatorNodeBase* opr) {
    MGB_MARK_USED_VAR(ConvParamTrait<Opr>::has_bias);
    MGB_MARK_USED_VAR(ConvParamTrait<Opr>::bias_idx);
    auto&& conv = opr->cast_final_safe<Opr>();
    auto format = conv.param().format;
    auto weight = opr->input(weight_idx);
    auto weight_shp = weight->shape();
    if (conv.param().sparse == Opr::Param::Sparse::DENSE)
        return false;
    size_t ocpg, icpg;
    if (format == Opr::Param::Format::NCHW) {
        ocpg = weight_shp[1], icpg = weight_shp[2];
        return ocpg == 1 && icpg == 1;
    }
    return false;
}
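// A worked example of the check above (shapes are hypothetical, for
// illustration only): a channel-wise convolution with group == 32 carries an
// NCHW group-conv weight of shape {32, 1, 1, FH, FW}, so
// ocpg == weight_shp[1] == 1 and icpg == weight_shp[2] == 1 and the helper
// returns true; an ordinary group conv with weight {32, 4, 4, FH, FW} fails
// the ocpg == 1 && icpg == 1 test and is handled as a generic group conv.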
template <OprFormat opr_format_>
struct OprSingleInOutTensorFormatsDispatcherImpl;

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW;
        config.input_dtypes = {opr->input(0)->dtype().enumv()};
        config.input_tensor_types = {TensorType::FEATURE};
        config.output_dtypes = {opr->output(0)->dtype().enumv()};
        config.input_tensor_formats = {TensorFormats::NCHW};
        config.output_tensor_formats = {TensorFormats::NCHW};
        return config;
    }
};

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW4> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW4;
        bool available = true;
        available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        config.input_dtypes = {opr->input(0)->dtype().enumv()};
        config.input_tensor_types = {TensorType::FEATURE};
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        config.output_dtypes = {opr->output(0)->dtype().enumv()};
        config.input_tensor_formats = {TensorFormats::NCHWc4};
        config.output_tensor_formats = {TensorFormats::NCHWc4};
        if (available)
            return config;
        return None;
    }
};

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::CHWN4> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::CHWN4;
        bool available = true;
        available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        config.input_dtypes = {opr->input(0)->dtype().enumv()};
        config.input_tensor_types = {TensorType::FEATURE};
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        config.output_dtypes = {opr->output(0)->dtype().enumv()};
        config.input_tensor_formats = {TensorFormats::CHWNc4};
        config.output_tensor_formats = {TensorFormats::CHWNc4};
        if (available)
            return config;
        return None;
    }
};

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW32> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW32;
        bool available = true;
        available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        config.input_dtypes = {opr->input(0)->dtype().enumv()};
        config.input_tensor_types = {TensorType::FEATURE};
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        config.output_dtypes = {opr->output(0)->dtype().enumv()};
        config.input_tensor_formats = {TensorFormats::NCHWc32};
        config.output_tensor_formats = {TensorFormats::NCHWc32};
        if (available)
            return config;
        return None;
    }
};

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NHWC> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NHWC;
        bool available = true;
        available &=
                opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
                opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
        config.input_dtypes = {opr->input(0)->dtype().enumv()};
        config.input_tensor_types = {TensorType::FEATURE};
        available &=
                opr->output(0)->dtype().enumv() == opr->input(0)->dtype().enumv();
        config.output_dtypes = {opr->output(0)->dtype().enumv()};
        config.input_tensor_formats = {TensorFormats::NHWC};
        config.output_tensor_formats = {TensorFormats::NHWC};
        if (available)
            return config;
        return None;
    }
};

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW64> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW64;
        bool available = true;
        available &=
                opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
                opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
        config.input_dtypes = {opr->input(0)->dtype().enumv()};
        config.input_tensor_types = {TensorType::FEATURE};
        available &=
                opr->output(0)->dtype().enumv() == opr->input(0)->dtype().enumv();
        config.output_dtypes = {opr->output(0)->dtype().enumv()};
        config.input_tensor_formats = {TensorFormats::NCHWc64};
        config.output_tensor_formats = {TensorFormats::NCHWc64};
        if (available)
            return config;
        return None;
    }
};
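// The dispatchers below handle convolution-like oprs, which take multiple
// inputs. For ConvBias and ConvolutionForward the weight is input 1 (and the
// QuantizedS32 bias of ConvBias is input 2); every other input is treated as
// a feature map. Most specializations reject the opr (return None) when its
// dtypes or sparsity do not match what the target format supports.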
template <typename Opr, OprFormat opr_format_>
struct ConvTensorFormatsDispatcherImpl;

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW;
        // setup dtypes
        for (size_t i = 0; i < opr->input().size(); ++i) {
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type =
                    i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        // setup tensor formats
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            config.input_tensor_formats = {
                    TensorFormats::NCHW, TensorFormats::NCHW, TensorFormats::NCHW,
                    TensorFormats::NCHW};
        } else {
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
            if (is_channel_wise_conv<Opr>(opr)) {
                config.input_tensor_formats = {
                        TensorFormats::NCHW, TensorFormats::C11RS,
                        TensorFormats::NCHW, TensorFormats::NCHW};
            } else {
                config.input_tensor_formats = {
                        TensorFormats::NCHW, TensorFormats::GKCRS,
                        TensorFormats::NCHW, TensorFormats::NCHW};
            }
        }
        config.output_tensor_formats = {TensorFormats::NCHW};
        return config;
    }
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NHWC> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NHWC;
        bool available = true;
        for (size_t i = 0; i < opr->input().size(); ++i) {
            if (i == 2)
                available &=
                        opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32;
            else {
                bool i4_config =
                        opr->input(i)->dtype().enumv() ==
                                DTypeEnum::Quantized4Asymm ||
                        opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS4;
                bool i8_config =
                        opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
                available &= (i4_config || i8_config);
            }
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type =
                    i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        bool i4_config =
                opr->output(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
                opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
        bool i8_config = opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        available &= (i4_config || i8_config);
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
        config.input_tensor_formats = {
                TensorFormats::NHWC, TensorFormats::NHWC, TensorFormats::NHWC,
                TensorFormats::NHWC};
        config.output_tensor_formats = {TensorFormats::NHWC};
        if (available)
            return config;
        return None;
    }
};
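// Note on the NHWC dispatcher above: it admits both the int4 dtypes
// (Quantized4Asymm / QuantizedS4) and int8 (QuantizedS8) for features and
// weights, while the bias (input 2), when present, must be QuantizedS32, and
// only dense convolutions pass the availability check.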
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW4> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW4;
        bool available = true;
        // setup dtypes
        for (size_t i = 0; i < opr->input().size(); ++i) {
            if (i == 2)
                available &=
                        opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32;
            else
                available &=
                        opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type =
                    i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        // setup tensor formats
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            config.input_tensor_formats = {
                    TensorFormats::NCHWc4, TensorFormats::NCHWc4,
                    TensorFormats::NCHWc4, TensorFormats::NCHWc4};
        } else {
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
            if (is_channel_wise_conv<Opr>(opr)) {
                config.input_tensor_formats = {
                        TensorFormats::NCHWc4, TensorFormats::C11RSc4,
                        TensorFormats::NCHWc4, TensorFormats::NCHWc4};
            } else {
                config.input_tensor_formats = {
                        TensorFormats::NCHWc4, TensorFormats::GKCRSc4,
                        TensorFormats::NCHWc4, TensorFormats::NCHWc4};
            }
        }
        config.output_tensor_formats = {TensorFormats::NCHWc4};
        if (available)
            return config;
        return None;
    }
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW32> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW32;
        bool available = true;
        for (size_t i = 0; i < opr->input().size(); ++i) {
            if (i == 2)
                available &=
                        opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32;
            else
                available &=
                        opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type =
                    i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
        config.input_tensor_formats = {
                TensorFormats::NCHWc32, TensorFormats::NCHWc32,
                TensorFormats::NCHWc32, TensorFormats::NCHWc32};
        config.output_tensor_formats = {TensorFormats::NCHWc32};
        if (available)
            return config;
        return None;
    }
};
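// NCHW4, NCHW32, NCHW64 and CHWN4 are channel-packed layouts (4, 32 or 64
// channels per packed block). Of these, only the NCHW4 dispatcher above
// provides group and channel-wise variants; the wider packed formats are
// admitted for dense convolutions exclusively, as the sparse == DENSE checks
// enforce (the same holds for the NCHW64 and CHWN4 dispatchers that follow).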
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW64> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW64;
        bool available = true;
        for (size_t i = 0; i < opr->input().size(); ++i) {
            if (i == 2)
                available &=
                        opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32;
            else
                available &=
                        opr->input(i)->dtype().enumv() ==
                                DTypeEnum::Quantized4Asymm ||
                        opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS4;
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type =
                    i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &=
                opr->output(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
                opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
        config.input_tensor_formats = {
                TensorFormats::NCHWc64, TensorFormats::NCHWc64,
                TensorFormats::NCHWc64, TensorFormats::NCHWc64};
        config.output_tensor_formats = {TensorFormats::NCHWc64};
        if (available)
            return config;
        return None;
    }
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::CHWN4> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::CHWN4;
        bool available = true;
        for (size_t i = 0; i < opr->input().size(); ++i) {
            if (i == 2)
                available &=
                        opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32;
            else
                available &=
                        opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type =
                    i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
        config.input_tensor_formats = {
                TensorFormats::CHWNc4, TensorFormats::CHWNc4,
                TensorFormats::CHWNc4, TensorFormats::CHWNc4};
        config.output_tensor_formats = {TensorFormats::CHWNc4};
        if (available)
            return config;
        return None;
    }
};

template <>
struct ConvTensorFormatsDispatcherImpl<
        opr::ConvolutionBackwardData, OprFormat::NCHW> {
    using Opr = opr::ConvolutionBackwardData;
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW;
        // setup dtypes
        for (size_t i = 0; i < opr->input().size(); ++i) {
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type =
                    i == 0 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        // setup tensor formats
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            config.input_tensor_formats = {
                    TensorFormats::NCHW, TensorFormats::NCHW, TensorFormats::NCHW,
                    TensorFormats::NCHW};
        } else {
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
            if (is_channel_wise_conv<Opr>(opr)) {
                config.input_tensor_formats = {
                        TensorFormats::C11RS, TensorFormats::NCHW,
                        TensorFormats::NCHW, TensorFormats::NCHW};
            } else {
                config.input_tensor_formats = {
                        TensorFormats::GKCRS, TensorFormats::NCHW,
                        TensorFormats::NCHW, TensorFormats::NCHW};
            }
        }
        config.output_tensor_formats = {TensorFormats::NCHW};
        return config;
    }
};
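// For ConvolutionBackwardData the operand order is reversed with respect to
// the forward convolutions above: the weight is input 0 and the incoming
// feature/gradient is input 1 (see ConvParamTrait). That is why the
// tensor_type selection and the weight formats (C11RS / GKCRS) move to the
// first slot in the dispatcher above and the two that follow.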
template <>
struct ConvTensorFormatsDispatcherImpl<
        opr::ConvolutionBackwardData, OprFormat::NCHW4> {
    using Opr = opr::ConvolutionBackwardData;
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW4;
        bool available = true;
        for (size_t i = 0; i < opr->input().size(); ++i) {
            available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type =
                    i == 0 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        available &= conv.param().sparse == opr::ConvBias::Param::Sparse::DENSE;
        config.input_tensor_formats = {
                TensorFormats::NCHWc4, TensorFormats::NCHWc4,
                TensorFormats::NCHWc4, TensorFormats::NCHWc4};
        config.output_tensor_formats = {TensorFormats::NCHWc4};
        if (available)
            return config;
        return None;
    }
};

template <>
struct ConvTensorFormatsDispatcherImpl<
        opr::ConvolutionBackwardData, OprFormat::NHWC> {
    using Opr = opr::ConvolutionBackwardData;
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NHWC;
        bool available = true;
        for (size_t i = 0; i < opr->input().size(); ++i) {
            available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type =
                    i == 0 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        available &= conv.param().sparse == opr::ConvBias::Param::Sparse::DENSE;
        config.input_tensor_formats = {
                TensorFormats::NHWC, TensorFormats::NHWC, TensorFormats::NHWC,
                TensorFormats::NHWC};
        config.output_tensor_formats = {TensorFormats::NHWC};
        if (available)
            return config;
        return None;
    }
};

struct StaticData {
    struct KeyHash {
        size_t operator()(const std::pair<Typeinfo*, OprFormat>& val) const {
            size_t h1 = mgb::hash(val.first);
            size_t h2 =
                    std::hash<uint32_t>()(static_cast<uint32_t>(val.second));
            return mgb::hash_pair_combine(h1, h2);
        }
    };
    using OprTensorFormatsDispatcher =
            OprTensorFormatsConfiguration::OprTensorFormatsDispatcher;
    std::unordered_map<
            std::pair<Typeinfo*, OprFormat>, OprTensorFormatsDispatcher, KeyHash>
            typefmt2dispatcher;
    StaticData();
};

StaticData::StaticData() {
#define OPR_TENSOR_FORMATS_CONFIG_REG(_Opr, _fmt)                      \
    typefmt2dispatcher[{opr::_Opr::typeinfo(), OprFormat::_fmt}] =     \
            [](const OperatorNodeBase* opr) {                          \
                MIDOUT_B(opr::_Opr, midout_iv(OprFormat::_fmt))        \
                return ConvTensorFormatsDispatcherImpl<                \
                        opr::_Opr, OprFormat::_fmt>::dispatch(opr);    \
                MIDOUT_E                                               \
            }

#define OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(_Opr, _fmt)        \
    typefmt2dispatcher[{opr::_Opr::typeinfo(), OprFormat::_fmt}] =     \
            [](const OperatorNodeBase* opr) {                          \
                MIDOUT_B(opr::_Opr, midout_iv(OprFormat::_fmt))        \
                return OprSingleInOutTensorFormatsDispatcherImpl<      \
                        OprFormat::_fmt>::dispatch(opr);               \
                MIDOUT_E                                               \
            }

    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NHWC);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW4);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, CHWN4);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64);

    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4);

    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW4);

    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NHWC);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW4);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW64);

    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NHWC);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW4);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, CHWN4);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW32);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64);

#undef OPR_TENSOR_FORMATS_CONFIG_REG
#undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
}

StaticData& static_data() {
    static StaticData inst;
    return inst;
}
}  // namespace
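// Usage sketch for the lookup below (a hypothetical caller, for illustration
// only; `some_opr` is assumed to be a const OperatorNodeBase* pointing at a
// quantized int8 ConvBias):
//
//     auto* dispatcher =
//             OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
//                     opr::ConvBias::typeinfo(), OprFormat::NCHW4);
//     auto config = (*dispatcher)(some_opr);
//     if (config.valid()) { /* the opr can be rewritten to NCHW4 */ }
//
// The assert below fires for (type, format) pairs that were never registered
// in StaticData.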
OprTensorFormatsConfiguration::OprTensorFormatsDispatcher*
OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
        Typeinfo* type, OprFormat opr_format) {
    auto&& typefmt2dispatcher = static_data().typefmt2dispatcher;
    auto iter = typefmt2dispatcher.find(std::make_pair(type, opr_format));
    mgb_assert(
            iter != typefmt2dispatcher.end(),
            "cannot find OprTensorFormatsDispatcher for opr type(%s) and "
            "opr format(%s)",
            type->name, opr_format_to_string(opr_format));
    return &iter->second;
}

// vim: syntax=cpp.doxygen