From a6230ba95add1c8fac92f94654823dac238cd2bc Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 24 Sep 2021 14:54:41 +0800 Subject: [PATCH] feat(mgb/gopt): global layout transform support arm GitOrigin-RevId: db50b33c112b99ab6f34cd81d9cf62790fc87c6e --- src/gopt/impl/framework.cpp | 15 +- .../layout_transform_context.cpp | 41 ++++ .../layout_transform_pass.cpp | 12 +- .../opr_tensor_formats_config.cpp | 203 +++++++++++++++++- .../global_layout_transform/profiler_impl.cpp | 8 +- .../reformat_manager.cpp | 27 ++- .../subgraph_extractor.cpp | 32 +-- src/gopt/impl/global_layout_transform/utils.h | 7 + 8 files changed, 314 insertions(+), 31 deletions(-) diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp index b5ca65ce7..25af6d58d 100644 --- a/src/gopt/impl/framework.cpp +++ b/src/gopt/impl/framework.cpp @@ -820,23 +820,26 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options( _passes need_param_fuse = true; \ } + using Target = GraphTuningOptions::Target; cb(layout_transform, { add_pass(); - add_pass(); + if (options.target == Target::CUDA) + add_pass(); add_pass(LayoutTransformPass::make(options.target)); add_pass(); - add_pass(FuseNCHW4Int8Preprocess::make()); - add_pass(); + if (options.target == Target::CUDA) { + add_pass(FuseNCHW4Int8Preprocess::make()); + add_pass(); #if CUDA_VERSION >= 10020 - add_pass(); - add_pass(); + add_pass(); + add_pass(); #endif + } }); #undef cb if (need_param_fuse) { add_pass(); - add_pass(); } return *this; } diff --git a/src/gopt/impl/global_layout_transform/layout_transform_context.cpp b/src/gopt/impl/global_layout_transform/layout_transform_context.cpp index 13c27d8f7..5034c7dcd 100644 --- a/src/gopt/impl/global_layout_transform/layout_transform_context.cpp +++ b/src/gopt/impl/global_layout_transform/layout_transform_context.cpp @@ -15,6 +15,7 @@ #include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/imgproc.h" #include "megbrain/opr/nn_int.h" +#include "megbrain/opr/tensor_manip.h" using namespace mgb; using namespace gopt; @@ -82,6 +83,44 @@ std::unique_ptr make_cuda_ctx( {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64}); return ctx; } + +std::unique_ptr make_arm_ctx( + OprFormat base_opr_format, TensorFormats base_tensor_format) { + OprList opr_list = { + opr::ConvBiasForward::typeinfo(), + opr::ConvolutionForward::typeinfo(), + opr::ElemwiseMultiType::typeinfo(), + opr::Elemwise::typeinfo(), + opr::TypeCvt::typeinfo(), + opr::PoolingForward::typeinfo(), + opr::Resize::typeinfo(), + opr::PowC::typeinfo(), + opr::Concat::typeinfo(), + }; + + SmallVector available_tensor_formats = { + TensorFormats::NCHW, TensorFormats::NCHWc4, + DNN_INC_FLOAT16(TensorFormats::NCHWc8)}; + Attribute attribute = {base_opr_format, base_tensor_format, Target::ARM}; + auto ctx = std::make_unique( + std::move(opr_list), std::move(available_tensor_formats), + attribute); + ctx->add_opr_config( + opr::ConvBiasForward::typeinfo(), + {OprFormat::NCHW, OprFormat::NCHW44, + DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT}) + .add_opr_config( + opr::ConvolutionForward::typeinfo(), + {OprFormat::NCHW, OprFormat::NCHW44, + DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT}) + .add_opr_config(opr::PoolingForward::typeinfo(), + {OprFormat::NCHW, OprFormat::NCHW44, + DNN_INC_FLOAT16(OprFormat::NCHW88)}) + .add_opr_config(opr::ResizeForward::typeinfo(), + {OprFormat::NCHW, OprFormat::NCHW44, + DNN_INC_FLOAT16(OprFormat::NCHW88)}); + return ctx; +} } // namespace /* ================= LayoutTransformContext 
==================*/ @@ -110,6 +149,8 @@ std::unique_ptr LayoutTransformContext::make( switch (target) { case Target::CUDA: return make_cuda_ctx(base_opr_format, base_tensor_format); + case Target::ARM: + return make_arm_ctx(base_opr_format, base_tensor_format); default: mgb_assert(false, "unsupported target %s\n", target_to_string(target)); } diff --git a/src/gopt/impl/global_layout_transform/layout_transform_pass.cpp b/src/gopt/impl/global_layout_transform/layout_transform_pass.cpp index 6f99431da..206b51e4f 100644 --- a/src/gopt/impl/global_layout_transform/layout_transform_pass.cpp +++ b/src/gopt/impl/global_layout_transform/layout_transform_pass.cpp @@ -60,6 +60,7 @@ void LayoutTransformPass::apply(OptState& opt) const { auto&& opr_configs = m_ctx->opr_configs(); auto&& base_fmt = m_ctx->attribute().base_tensor_formats; + auto&& base_opr_fmt = m_ctx->attribute().base_opr_format; auto&& reformat_attribute = m_ctx->attribute().reformat_attribute; ThinHashMap var2fmts; static ThinHashSet format_aware_oprs = { @@ -68,15 +69,18 @@ void LayoutTransformPass::apply(OptState& opt) const { #undef cb }; auto rewriter = opt.graph().make_rewriter(); - auto on_opr = [&opr_configs, &base_fmt, &reformat_attribute, &rewriter, &solution, - &var2fmts, &endpoint_vars](OperatorNodeBase* opr) { + auto on_opr = [&opr_configs, &base_fmt, &base_opr_fmt, &reformat_attribute, + &rewriter, &solution, &var2fmts, + &endpoint_vars](OperatorNodeBase* opr) { auto it = solution.find(opr); if (it != solution.end()) { auto opr_fmt = it->second; auto find = opr_configs.find(opr->dyn_typeinfo()); Maybe fmtcfg = None; + Maybe basecfg = None; if (find != opr_configs.end()) { fmtcfg = (*find->second.at(opr_fmt))(opr); + basecfg = (*find->second.at(base_opr_fmt))(opr); } VarNodeArray new_inp; size_t nr_inps = opr->input().size(); @@ -103,6 +107,10 @@ void LayoutTransformPass::apply(OptState& opt) const { bool is_parameter = fmtcfg.valid() && fmtcfg.val().input_tensor_types[i] == TensorType::WEIGHT; + if (is_parameter) { + mgb_assert(basecfg.valid()); + from = basecfg.val().input_tensor_formats[i]; + } // need relayout if (from != to && !new_var->shape().is_scalar()) { ReformatManager::ReformatImpl reformat; diff --git a/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp b/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp index 3b6b7b2d2..c3a327303 100644 --- a/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp +++ b/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp @@ -78,6 +78,48 @@ struct OprSingleInOutTensorFormatsDispatcherImpl { } }; +template <> +struct OprSingleInOutTensorFormatsDispatcherImpl { + static Maybe dispatch( + const OperatorNodeBase* opr) { + OprTensorFormatsConfiguration config; + config.typeinfo = opr->dyn_typeinfo(); + config.opr_format = OprFormat::NCHW44; + bool available = true; + available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float32; + config.input_dtypes = {opr->input(0)->dtype().enumv()}; + config.input_tensor_types = {TensorType::FEATURE}; + config.output_dtypes = {opr->output(0)->dtype().enumv()}; + config.input_tensor_formats = {TensorFormats::NCHWc4}; + config.output_tensor_formats = {TensorFormats::NCHWc4}; + if (!available) + return None; + return config; + } +}; + +#if !MEGDNN_DISABLE_FLOAT16 +template <> +struct OprSingleInOutTensorFormatsDispatcherImpl { + static Maybe dispatch( + const OperatorNodeBase* opr) { + OprTensorFormatsConfiguration config; + config.typeinfo = opr->dyn_typeinfo(); + 
config.opr_format = OprFormat::NCHW88; + bool available = true; + available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float16; + config.input_dtypes = {opr->input(0)->dtype().enumv()}; + config.input_tensor_types = {TensorType::FEATURE}; + config.output_dtypes = {opr->output(0)->dtype().enumv()}; + config.input_tensor_formats = {TensorFormats::NCHWc8}; + config.output_tensor_formats = {TensorFormats::NCHWc8}; + if (!available) + return None; + return config; + } +}; +#endif + template <> struct OprSingleInOutTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { @@ -200,7 +242,7 @@ struct ConvTensorFormatsDispatcherImpl { // setup tensor formats if (conv.param().sparse == Opr::Param::Sparse::DENSE) { config.input_tensor_formats = { - TensorFormats::NCHW, TensorFormats::NCHW, TensorFormats::NCHW, + TensorFormats::NCHW, TensorFormats::KCRS, TensorFormats::NCHW, TensorFormats::NCHW}; } else { mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); @@ -396,6 +438,145 @@ struct ConvTensorFormatsDispatcherImpl { } }; +template +struct ConvTensorFormatsDispatcherImpl { + static Maybe dispatch( + const OperatorNodeBase* opr) { + const auto& conv = opr->cast_final_safe(); + OprTensorFormatsConfiguration config; + config.typeinfo = opr->dyn_typeinfo(); + config.opr_format = OprFormat::NCHW44; + bool available = true; + // setup dtypes + for (size_t i = 0; i < opr->input().size(); ++i) { + available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32; + config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); + TensorType tensor_type = + i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; + config.input_tensor_types.emplace_back(tensor_type); + } + available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32; + config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); + // setup tensor formats + if (conv.param().sparse == Opr::Param::Sparse::DENSE) { + config.input_tensor_formats = { + TensorFormats::NCHWc4, TensorFormats::KCRSc4k4, + TensorFormats::NCHWc4, TensorFormats::NCHWc4}; + } else { + mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); + if (is_channel_wise_conv(opr)) { + config.input_tensor_formats = { + TensorFormats::NCHWc4, TensorFormats::C11RSc4, + TensorFormats::NCHWc4, TensorFormats::NCHWc4}; + } else { + config.input_tensor_formats = { + TensorFormats::NCHWc4, TensorFormats::GKCRSc4k4, + TensorFormats::NCHWc4, TensorFormats::NCHWc4}; + } + } + config.output_tensor_formats = {TensorFormats::NCHWc4}; + if (!available) + return None; + return config; + } +}; + +#if !MEGDNN_DISABLE_FLOAT16 +template +struct ConvTensorFormatsDispatcherImpl { + static Maybe dispatch( + const OperatorNodeBase* opr) { + const auto& conv = opr->cast_final_safe(); + OprTensorFormatsConfiguration config; + config.typeinfo = opr->dyn_typeinfo(); + config.opr_format = OprFormat::NCHW88; + bool available = true; + // setup dtypes + for (size_t i = 0; i < opr->input().size(); ++i) { + available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16; + config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); + TensorType tensor_type = + i == 1 ? 
TensorType::WEIGHT : TensorType::FEATURE; + config.input_tensor_types.emplace_back(tensor_type); + } + available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16; + config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); + // setup tensor formats + if (conv.param().sparse == Opr::Param::Sparse::DENSE) { + config.input_tensor_formats = { + TensorFormats::NCHWc8, TensorFormats::KCRSc8k8, + TensorFormats::NCHWc8, TensorFormats::NCHWc8}; + } else { + mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); + if (is_channel_wise_conv(opr)) { + config.input_tensor_formats = { + TensorFormats::NCHWc8, TensorFormats::C11RSc8, + TensorFormats::NCHWc8, TensorFormats::NCHWc8}; + } else { + config.input_tensor_formats = { + TensorFormats::NCHWc8, TensorFormats::GKCRSc8k8, + TensorFormats::NCHWc8, TensorFormats::NCHWc8}; + } + } + config.output_tensor_formats = {TensorFormats::NCHWc8}; + if (!available) + return None; + return config; + } +}; +#endif + +template +struct ConvTensorFormatsDispatcherImpl { + static Maybe dispatch( + const OperatorNodeBase* opr) { + const auto& conv = opr->cast_final_safe(); + OprTensorFormatsConfiguration config; + config.typeinfo = opr->dyn_typeinfo(); + config.opr_format = OprFormat::NCHW44_DOT; + bool available = true; + // setup dtypes + for (size_t i = 0; i < opr->input().size(); ++i) { + if (i == 2) { + available &= opr->input(i)->dtype().enumv() == + DTypeEnum::QuantizedS32; + } else { + available &= opr->input(i)->dtype().enumv() == + DTypeEnum::QuantizedS8 || + opr->input(i)->dtype().enumv() == + DTypeEnum::Quantized8Asymm; + } + config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); + TensorType tensor_type = + i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; + config.input_tensor_types.emplace_back(tensor_type); + } + available &= + opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8 || + opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm; + config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); + // setup tensor formats + if (conv.param().sparse == Opr::Param::Sparse::DENSE) { + config.input_tensor_formats = { + TensorFormats::NCHWc4, TensorFormats::KCRSk4c4, + TensorFormats::NCHWc4, TensorFormats::NCHWc4}; + } else { + mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); + if (is_channel_wise_conv(opr)) { + available = false; + } else { + config.input_tensor_formats = { + TensorFormats::NCHWc4, TensorFormats::GKCRSk4c4, + TensorFormats::NCHWc4, TensorFormats::NCHWc4}; + } + } + config.output_tensor_formats = {TensorFormats::NCHWc4}; + if (!available) + return None; + return config; + } +}; + template <> struct ConvTensorFormatsDispatcherImpl { using Opr = opr::ConvolutionBackwardData; @@ -530,9 +711,19 @@ StaticData::StaticData() { OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, CHWN4); OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32); OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64); + OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44); +#if !MEGDNN_DISABLE_FLOAT16 + OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88); +#endif + OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4); + OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44); +#if !MEGDNN_DISABLE_FLOAT16 + OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88); +#endif + OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW); 
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC);
@@ -549,6 +740,16 @@ StaticData::StaticData() {
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, CHWN4);
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW32);
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64);
+    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW44);
+#if !MEGDNN_DISABLE_FLOAT16
+    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW88);
+#endif
+
+    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW);
+    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44);
+#if !MEGDNN_DISABLE_FLOAT16
+    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88);
+#endif

 #undef OPR_TENSOR_FORMATS_CONFIG_REG
 #undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
diff --git a/src/gopt/impl/global_layout_transform/profiler_impl.cpp b/src/gopt/impl/global_layout_transform/profiler_impl.cpp
index 0ea951e90..58c392829 100644
--- a/src/gopt/impl/global_layout_transform/profiler_impl.cpp
+++ b/src/gopt/impl/global_layout_transform/profiler_impl.cpp
@@ -35,9 +35,9 @@ OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) {
         case TensorFormats::NCHW:
             return OprFormat::NCHW;
         case TensorFormats::NCHWc4:
-            return OprFormat::NCHW4;
+            return OprFormat::NCHW44;
         case TensorFormats::NCHWc8:
-            return OprFormat::NCHW8;
+            return OprFormat::NCHW88;
         case TensorFormats::NCHWc32:
             return OprFormat::NCHW32;
         case TensorFormats::NCHWc64:
@@ -424,11 +424,11 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(const Problem& problem) cons
             skip &= problem.graph_partition().input().count(i) > 0 ||
                     skip_oprs.count(i->owner_opr()) > 0;
         }
-        skip &= skip_opr_types.count(opr->dyn_typeinfo());
+        auto find = format_aware_input_tensors.find(opr->dyn_typeinfo());
+        skip &= find == format_aware_input_tensors.end();
         if (skip)
             skip_oprs.insert(opr);
         oprs.insert(opr);
-        auto find = format_aware_input_tensors.find(opr->dyn_typeinfo());
         if (find == format_aware_input_tensors.end()) {
             for (auto&& i : opr->input()) {
                 if (!cvprop.is_const(i)) {
diff --git a/src/gopt/impl/global_layout_transform/reformat_manager.cpp b/src/gopt/impl/global_layout_transform/reformat_manager.cpp
index 7d1d46533..7e5c6ef70 100644
--- a/src/gopt/impl/global_layout_transform/reformat_manager.cpp
+++ b/src/gopt/impl/global_layout_transform/reformat_manager.cpp
@@ -470,9 +470,9 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight(
                 input_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) {
             in_channels = orig_var->shape()[i] * input_shape[i].stride();
             input_channel_idx = i;
-            // mgb_assert(input_shape[i].stride() == 1,
-            //            "unsupport weight format(got:%s)",
-            //            input_shape.to_string().c_str());
+            mgb_assert(
+                    input_shape[i].stride() == 1, "unsupported weight format(got:%s)",
+                    input_shape.to_string().c_str());
         } else if (
                 (input_shape[i].name() == Dimension::Name::K ||
                  input_shape[i].name() == Dimension::Name::N) &&
@@ -485,13 +485,23 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight(
                     input_shape.to_string().c_str());
         }
     }
+    /* \note: FIXME this is a hack. The weight layout of a channelwise convolution
+     * has no output channel dimension, so we manually modify out_channel_name and
+     * output_channel_idx to bypass the following assertion statements. */
+    bool is_channelwise = key.input_format == TensorFormats::C11RS;
+    if (is_channelwise) {
+        out_channel_name = Dimension::Name::K;
+        out_channels = in_channels;
+        output_channel_idx = input_channel_idx;
+    }
     mgb_assert(
             out_channel_name == Dimension::Name::K ||
                     out_channel_name == Dimension::Name::N,
             "invalid out channel(shp:%s)", input_shape.to_string().c_str());
     mgb_assert(
-            input_channel_idx < input_shape.ndim &&
-                    output_channel_idx < input_shape.ndim,
+            (input_channel_idx < input_shape.ndim &&
+             output_channel_idx < input_shape.ndim) ||
+                    (is_channelwise && output_channel_idx == input_channel_idx),
             "invalid channel idx(in_channel:%zu, out_channel:%zu, shp:%s)",
             input_channel_idx, output_channel_idx, input_shape.to_string().c_str());
     size_t in_channel_alignment = 0, out_channel_alignment = 0;
@@ -506,6 +516,13 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight(
             out_channel_alignment = output_shape[i].stride();
         }
     }
+    /* \note: FIXME this is a hack. The weight layout of a channelwise convolution
+     * has no output channel dimension, so we manually modify out_channel_alignment
+     * to bypass the following assertion statements. */
+    if (is_channelwise) {
+        mgb_assert(out_channel_alignment == 0);
+        out_channel_alignment = 1;
+    }
     mgb_assert(
             in_channel_alignment > 0 && out_channel_alignment > 0,
             "invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)",
diff --git a/src/gopt/impl/global_layout_transform/subgraph_extractor.cpp b/src/gopt/impl/global_layout_transform/subgraph_extractor.cpp
index c18837bfc..9384f5cff 100644
--- a/src/gopt/impl/global_layout_transform/subgraph_extractor.cpp
+++ b/src/gopt/impl/global_layout_transform/subgraph_extractor.cpp
@@ -263,20 +263,9 @@ std::vector SubGraphExtractor::extract(
     std::vector partitions;
     partitions.reserve(topo.size());
     ThinHashMap roots;
+    /// backward pass
     for (const auto& opr : reverse_adaptor(topo)) {
-        if (m_opr_list.count(opr->dyn_typeinfo()) == 0) {
-            for (const auto& i : opr->input()) {
-                if (m_opr_list.count(i->owner_opr()->dyn_typeinfo())) {
-                    auto root = union_find(i->owner_opr());
-                    GraphPartition* partition;
-                    auto find = roots.find(root);
-                    if (find != roots.end()) {
-                        partition = find->second;
-                        partition->output().insert(i);
-                    }
-                }
-            }
-        } else {
+        if (m_opr_list.count(opr->dyn_typeinfo()) > 0) {
             auto root = union_find(opr);
             auto find = roots.find(root);
             GraphPartition* partition = nullptr;
@@ -304,6 +293,23 @@ std::vector SubGraphExtractor::extract(
                 partition->input().insert(i);
             }
         }
+    /// forward pass
+    for (auto&& opr : topo) {
+        if (m_opr_list.count(opr->dyn_typeinfo()) == 0) {
+            for (const auto& i : opr->input()) {
+                if (m_opr_list.count(i->owner_opr()->dyn_typeinfo())) {
+                    auto root = union_find(i->owner_opr());
+                    GraphPartition* partition;
+                    auto find = roots.find(root);
+                    if (find != roots.end()) {
+                        partition = find->second;
+                        partition->output().insert(i);
+                    }
+                }
+            }
+        }
+    }
+
     for (auto&& partition : partitions) {
         auto& all_oprs = partition.all_oprs();
         std::reverse(all_oprs.begin(), all_oprs.end());
diff --git a/src/gopt/impl/global_layout_transform/utils.h b/src/gopt/impl/global_layout_transform/utils.h
index 5b108ec7e..336f4a9d8 100644
--- a/src/gopt/impl/global_layout_transform/utils.h
+++ b/src/gopt/impl/global_layout_transform/utils.h
@@ -29,6 +29,9 @@ static inline const char* opr_format_to_string(
         cb(NCHW32);
         cb(NCHW64);
         cb(CHWN4);
+        cb(NCHW44);
+        cb(NCHW88);
+        cb(NCHW44_DOT);
         default:
             mgb_assert(
                     false, "Invalid opr format(got:%u)",
@@ -53,6 +56,10 @@ static inline TensorFormats opr_format_to_tensor_formats(
             return TensorFormats::NCHWc64;
         case OprFormat::CHWN4:
             return TensorFormats::CHWNc4;
+        case OprFormat::NCHW88:
+            return TensorFormats::NCHWc8;
+        case OprFormat::NCHW44:
+            return TensorFormats::NCHWc4;
         default:
             mgb_throw(
                     AssertionError, "format(%s) is not supported",
-- 
GitLab
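
The NCHW44/NCHW88 opr formats added by this patch map to the blocked tensor formats NCHWc4/NCHWc8, and dense convolution weights are expected in the matching blocked layout (KCRSc4k4, or KCRSc8k8 for float16). As a rough illustration of what that weight blocking means, the standalone sketch below repacks a plain KCRS weight into KCRSc4k4, reading the format name as dst[k/4][c/4][r][s][c%4][k%4] = src[k][c][r][s]; it is not MegEngine code, and the function name plus the assumption that K and C are multiples of 4 are only for the example.

    // Standalone illustration: repack a dense convolution weight from plain
    // KCRS into the blocked KCRSc4k4 layout used together with NCHW44.
    // Assumes K (output channels) and C (input channels) are multiples of 4.
    #include <cstddef>
    #include <vector>

    std::vector<float> repack_kcrs_to_kcrsc4k4(
            const std::vector<float>& src, size_t K, size_t C, size_t R, size_t S) {
        std::vector<float> dst(src.size());
        for (size_t k = 0; k < K; ++k)
            for (size_t c = 0; c < C; ++c)
                for (size_t r = 0; r < R; ++r)
                    for (size_t s = 0; s < S; ++s) {
                        // source: contiguous KCRS indexing
                        size_t src_idx = ((k * C + c) * R + r) * S + s;
                        // destination: [K/4][C/4][R][S][c%4][k%4], 16 elements
                        // per (k/4, c/4, r, s) block
                        size_t dst_idx =
                                ((((k / 4) * (C / 4) + c / 4) * R + r) * S + s) * 16 +
                                (c % 4) * 4 + (k % 4);
                        dst[dst_idx] = src[src_idx];
                    }
        return dst;
    }

Packing four input and four output channels into the innermost dimensions is what lets the ARM NCHW44 kernels load contiguous 4-element vectors for both the activation (NCHWc4) and the weight.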