From fe93013a6e53adbb49f74e79428dfff5b6fd266c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 22 Sep 2021 18:23:52 +0800 Subject: [PATCH] feat(mgb/gopt): global layout transform support nchw_nchwxx hybrid mode GitOrigin-RevId: 6d5b55d7fc67b536b25c2fe49457f6a74f9c62b5 --- dnn/src/common/convolution.cpp | 17 +- .../dynamic_programming_solver.cpp | 58 +-- .../layout_transform_context.cpp | 74 ++-- .../layout_transform_pass.cpp | 31 +- .../opr_format_modifier.h | 2 +- .../opr_tensor_formats_config.cpp | 343 ++++++++++++++---- .../profiler_cache.cpp | 2 +- .../global_layout_transform/profiler_impl.cpp | 62 ++-- .../profiling_based_solver.cpp | 3 +- src/gopt/impl/global_layout_transform/utils.h | 43 ++- .../megbrain/gopt/layout_transform_context.h | 62 +++- src/gopt/include/megbrain/gopt/profiler.h | 13 +- src/gopt/include/megbrain/gopt/solver.h | 3 +- src/gopt/test/embed_cache.py | 3 +- src/gopt/test/layout_transform_pass.cpp | 295 +++++++++++---- src/gopt/test/network.cpp | 90 +++++ src/gopt/test/network.h | 16 +- src/gopt/test/profiler.cpp | 21 +- 18 files changed, 866 insertions(+), 272 deletions(-) diff --git a/dnn/src/common/convolution.cpp b/dnn/src/common/convolution.cpp index 0fc9afbbb..b9fa2c8c9 100644 --- a/dnn/src/common/convolution.cpp +++ b/dnn/src/common/convolution.cpp @@ -830,9 +830,9 @@ typename ConvolutionBase::CanonizedFilterMeta ConvolutionBase::CanonizedFilterMeta ConvolutionBase& var2fmts, const OperatorNodeBase* opr, - OprFormat opr_fmt, const Context& ctx); + OprFormatConfigID config_id, const Context& ctx); /*! * \brief compute the distace of two states of the given varnode * \param[in] from the source state @@ -140,28 +141,35 @@ private: TensorFormats DynamicProgrammingSolver::Impl::get_io_formats( ThinHashMap& var2fmts, const OperatorNodeBase* opr, - OprFormat opr_fmt, const Context& ctx) { + OprFormatConfigID config_id, const Context& ctx) { auto&& rst = ctx.rst; auto&& opr_configs = ctx.opr_configs; auto iter = opr_configs.find(opr->dyn_typeinfo()); Maybe fmtcfg = None; + Maybe opr_fmt = None; if (iter != opr_configs.end()) { - fmtcfg = (*iter->second.at(opr_fmt))(opr); + fmtcfg = (*iter->second.at(config_id))(opr); + } else { + opr_fmt = OprTensorFormatsConfiguration::safe_cast_to_opr_format(config_id); } TensorFormats out_fmt; if (fmtcfg.valid()) out_fmt = fmtcfg.val().output_tensor_formats[0]; - else - out_fmt = opr_format_to_tensor_formats(opr_fmt); + else { + mgb_assert(opr_fmt.valid()); + out_fmt = opr_format_to_tensor_formats(opr_fmt.val()); + } for (size_t i = 0; i < opr->input().size(); ++i) { auto&& var = opr->input(i); auto iter = rst.var_record.find(var); if (iter != rst.var_record.end()) { if (fmtcfg.valid()) var2fmts[var] = fmtcfg.val().input_tensor_formats[i]; - else - var2fmts[var] = opr_format_to_tensor_formats(opr_fmt); + else { + mgb_assert(opr_fmt.valid()); + var2fmts[var] = opr_format_to_tensor_formats(opr_fmt.val()); + } } } return out_fmt; @@ -342,13 +350,13 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve( cuts.emplace_back(Cut{}); auto& states = cuts.back().states; for (const auto& record : records) { - auto opr_fmt = record.first; + auto cfg_id = record.first; float opr_time = record.second; ThinHashMap ivar2fmts; - auto out_fmt = get_io_formats(ivar2fmts, opr, opr_fmt, ctx); + auto out_fmt = get_io_formats(ivar2fmts, opr, cfg_id, ctx); const auto& edge = edges[cur]; State state(edge.size(), 0); - Value value{opr, nullptr, opr_fmt, 0.f, cur}; + Value value{opr, nullptr, cfg_id, 0.f, cur}; float 
ovar_time = 0.f; for (size_t i = 0; i < edge.size(); ++i) { auto&& var = edge[i]; @@ -396,16 +404,16 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve( const auto& records = it->second.costs; StateTable states; for (const auto& record : records) { - auto opr_fmt = record.first; + auto cfg_id = record.first; float opr_time = record.second; ThinHashMap ivar2fmts; - auto out_fmt = get_io_formats(ivar2fmts, opr, opr_fmt, ctx); + auto out_fmt = get_io_formats(ivar2fmts, opr, cfg_id, ctx); for (const auto& kv : cuts.back().states) { auto&& prev_state = kv.first; float prev_time = kv.second.time; const auto& edge = edges[cur]; State state(edge.size(), 0); - Value value{opr, &prev_state, opr_fmt, 0.f, cur}; + Value value{opr, &prev_state, cfg_id, 0.f, cur}; float ovar_time = 0.f; for (size_t i = 0; i < edge.size(); ++i) { auto&& var = edge[i]; @@ -482,7 +490,7 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve( /// backward pass to generate the solution float min_time = std::numeric_limits::max(); OperatorNodeBase* cur_opr = nullptr; - OprFormat min_fmt = OprFormat::NCHW; + OprFormatConfigID min_cfg = OprFormatConfigID::NCHW; const State* pstate = nullptr; for (auto&& kv : cuts.back().states) { auto&& v = kv.second; @@ -490,7 +498,7 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve( cur_opr = v.opr; pstate = v.prev; min_time = v.time; - min_fmt = v.opr_fmt; + min_cfg = v.cfg_id; ///! just to check the tensor formats of the output varnode auto&& k = kv.first; size_t opr_idx = v.opr_idx; @@ -505,10 +513,10 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve( } mgb_assert(cur_opr != nullptr); mgb_log_debug( - "opr:%s;format:%s;time:%f", cur_opr->cname(), opr_format_to_string(min_fmt), + "opr:%s;config:%s;time:%f", cur_opr->cname(), config_id_to_string(min_cfg), min_time); - solution.insert({cur_opr, min_fmt}); + solution.insert({cur_opr, min_cfg}); cur = cuts.size() - 2; while (pstate) { auto val = cuts[cur].states[*pstate]; @@ -522,9 +530,9 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve( } } mgb_log_debug( - "opr:%s;format:%s;time:%f", val.opr->cname(), - opr_format_to_string(val.opr_fmt), val.time); - solution.insert({val.opr, val.opr_fmt}); + "opr:%s;cofig:%s;time:%f", val.opr->cname(), + config_id_to_string(val.cfg_id), val.time); + solution.insert({val.opr, val.cfg_id}); pstate = val.prev; cur--; } diff --git a/src/gopt/impl/global_layout_transform/layout_transform_context.cpp b/src/gopt/impl/global_layout_transform/layout_transform_context.cpp index 20bc12f3b..ec5d68798 100644 --- a/src/gopt/impl/global_layout_transform/layout_transform_context.cpp +++ b/src/gopt/impl/global_layout_transform/layout_transform_context.cpp @@ -22,6 +22,7 @@ using namespace gopt; namespace { using OprFormat = LayoutTransformContext::OprFormat; +using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; using OprList = LayoutTransformContext::OprList; using Attribute = LayoutTransformContext::Attribute; using Target = LayoutTransformContext::Target; @@ -43,7 +44,7 @@ const char* target_to_string(Target target) { } std::unique_ptr make_cuda_ctx( - OprFormat base_opr_format, TensorFormats base_tensor_format) { + OprFormatConfigID base_config_id, TensorFormats base_tensor_format) { OprList opr_list = { opr::ConvBiasForward::typeinfo(), opr::ConvolutionForward::typeinfo(), @@ -58,34 +59,38 @@ std::unique_ptr make_cuda_ctx( SmallVector available_tensor_formats = { 
TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4}; - Attribute attribute = { - base_opr_format, base_tensor_format, Target::CUDA, + base_config_id, base_tensor_format, Target::CUDA, LayoutTransformContext::ReformatAttribute::AUTO_PADDING_NHWC}; auto ctx = std::make_unique( std::move(opr_list), std::move(available_tensor_formats), attribute); ctx->add_opr_config( opr::ConvBiasForward::typeinfo(), - {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW32, - OprFormat::NCHW64, OprFormat::CHWN4}) + {OprFormatConfigID::NCHW, OprFormatConfigID::NHWC, + OprFormatConfigID::NCHW4_NCHW32, OprFormatConfigID::NCHW32_NCHW4, + OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, + OprFormatConfigID::NCHW64, OprFormatConfigID::CHWN4}) .add_opr_config( opr::ConvolutionForward::typeinfo(), - {OprFormat::NCHW, OprFormat::NCHW4}) + {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4}) .add_opr_config( opr::ConvolutionBackwardData::typeinfo(), - {OprFormat::NCHW, OprFormat::NCHW4, OprFormat::NHWC}) + {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4, + OprFormatConfigID::NHWC}) .add_opr_config( opr::PoolingForward::typeinfo(), - {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC, - OprFormat::NCHW64, OprFormat::CHWN4}) + {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, + OprFormatConfigID::NHWC, OprFormatConfigID::NCHW64, + OprFormatConfigID::CHWN4}) .add_opr_config( opr::WarpPerspectiveForward::typeinfo(), - {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64}); + {OprFormatConfigID::NHWC, OprFormatConfigID::NCHW4, + OprFormatConfigID::NCHW64}); return ctx; } std::unique_ptr make_arm_ctx( - OprFormat base_opr_format, TensorFormats base_tensor_format) { + OprFormatConfigID base_config_id, TensorFormats base_tensor_format) { OprList opr_list = { opr::ConvBiasForward::typeinfo(), opr::ConvolutionForward::typeinfo(), @@ -101,57 +106,64 @@ std::unique_ptr make_arm_ctx( SmallVector available_tensor_formats = { TensorFormats::NCHW, TensorFormats::NCHWc4, DNN_INC_FLOAT16(TensorFormats::NCHWc8)}; - Attribute attribute = {base_opr_format, base_tensor_format, Target::ARM}; + Attribute attribute = {base_config_id, base_tensor_format, Target::ARM}; auto ctx = std::make_unique( std::move(opr_list), std::move(available_tensor_formats), attribute); ctx->add_opr_config( opr::ConvBiasForward::typeinfo(), - {OprFormat::NCHW, OprFormat::NCHW44, DNN_INC_FLOAT16(OprFormat::NCHW88), - OprFormat::NCHW44_DOT}) + {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44, + OprFormatConfigID::NCHW44_HYBRID, + DNN_INC_FLOAT16(OprFormatConfigID::NCHW88), + DNN_INC_FLOAT16(OprFormatConfigID::NCHW88_HYBRID), + OprFormatConfigID::NCHW44_DOT, OprFormatConfigID::NCHW44_DOT_HYBRID}) .add_opr_config( opr::ConvolutionForward::typeinfo(), - {OprFormat::NCHW, OprFormat::NCHW44, - DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT}) + {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44, + OprFormatConfigID::NCHW44_HYBRID, + DNN_INC_FLOAT16(OprFormatConfigID::NCHW88), + DNN_INC_FLOAT16(OprFormatConfigID::NCHW88_HYBRID), + OprFormatConfigID::NCHW44_DOT, + OprFormatConfigID::NCHW44_DOT_HYBRID}) .add_opr_config( opr::PoolingForward::typeinfo(), - {OprFormat::NCHW, OprFormat::NCHW44, - DNN_INC_FLOAT16(OprFormat::NCHW88)}) + {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44, + DNN_INC_FLOAT16(OprFormatConfigID::NCHW88)}) .add_opr_config( opr::ResizeForward::typeinfo(), - {OprFormat::NCHW, OprFormat::NCHW44, - 
DNN_INC_FLOAT16(OprFormat::NCHW88)}); + {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44, + DNN_INC_FLOAT16(OprFormatConfigID::NCHW88)}); return ctx; } } // namespace /* ================= LayoutTransformContext ==================*/ LayoutTransformContext& LayoutTransformContext::add_opr_config( - Typeinfo* opr, OprFormat opr_format) { + Typeinfo* opr, OprFormatConfigID config_id) { auto& dispatchers = m_opr_configs[opr]; - dispatchers[opr_format] = + dispatchers[config_id] = OprTensorFormatsConfiguration::find_dispatcher_by_type_format( - opr, opr_format); + opr, config_id); return *this; } LayoutTransformContext& LayoutTransformContext::add_opr_config( - Typeinfo* opr, SmallVector opr_formats) { + Typeinfo* opr, SmallVector config_ids) { auto& dispatchers = m_opr_configs[opr]; - for (auto opr_fmt : opr_formats) { - dispatchers[opr_fmt] = - OprTensorFormatsConfiguration::find_dispatcher_by_type_format( - opr, opr_fmt); + for (auto cfg : config_ids) { + dispatchers[cfg] = + OprTensorFormatsConfiguration::find_dispatcher_by_type_format(opr, cfg); } return *this; } std::unique_ptr LayoutTransformContext::make( - Target target, OprFormat base_opr_format, TensorFormats base_tensor_format) { + Target target, OprFormatConfigID base_config_id, + TensorFormats base_tensor_format) { switch (target) { case Target::CUDA: - return make_cuda_ctx(base_opr_format, base_tensor_format); + return make_cuda_ctx(base_config_id, base_tensor_format); case Target::ARM: - return make_arm_ctx(base_opr_format, base_tensor_format); + return make_arm_ctx(base_config_id, base_tensor_format); default: mgb_assert(false, "unsupported target %s\n", target_to_string(target)); } diff --git a/src/gopt/impl/global_layout_transform/layout_transform_pass.cpp b/src/gopt/impl/global_layout_transform/layout_transform_pass.cpp index 206b51e4f..8a198a0c3 100644 --- a/src/gopt/impl/global_layout_transform/layout_transform_pass.cpp +++ b/src/gopt/impl/global_layout_transform/layout_transform_pass.cpp @@ -43,6 +43,7 @@ void LayoutTransformPass::apply(OptState& opt) const { auto partitions = extractor.extract(opt.graph().endpoint_vars()); using Solution = SolverBase::Solution; + using OprFormat = SolverBase::OprFormat; Solution solution; ThinHashSet endpoint_vars; for (auto&& partition : partitions) { @@ -60,7 +61,7 @@ void LayoutTransformPass::apply(OptState& opt) const { auto&& opr_configs = m_ctx->opr_configs(); auto&& base_fmt = m_ctx->attribute().base_tensor_formats; - auto&& base_opr_fmt = m_ctx->attribute().base_opr_format; + auto&& base_cfg_id = m_ctx->attribute().base_config_id; auto&& reformat_attribute = m_ctx->attribute().reformat_attribute; ThinHashMap var2fmts; static ThinHashSet format_aware_oprs = { @@ -69,18 +70,25 @@ void LayoutTransformPass::apply(OptState& opt) const { #undef cb }; auto rewriter = opt.graph().make_rewriter(); - auto on_opr = [&opr_configs, &base_fmt, &base_opr_fmt, &reformat_attribute, + auto on_opr = [&opr_configs, &base_fmt, &base_cfg_id, &reformat_attribute, &rewriter, &solution, &var2fmts, &endpoint_vars](OperatorNodeBase* opr) { auto it = solution.find(opr); if (it != solution.end()) { - auto opr_fmt = it->second; + auto cfg_id = it->second; auto find = opr_configs.find(opr->dyn_typeinfo()); Maybe fmtcfg = None; Maybe basecfg = None; + Maybe opr_fmt = None; if (find != opr_configs.end()) { - fmtcfg = (*find->second.at(opr_fmt))(opr); - basecfg = (*find->second.at(base_opr_fmt))(opr); + fmtcfg = (*find->second.at(cfg_id))(opr); + auto _ = 
OprTensorFormatsConfiguration::find_dispatcher_by_type_format( + opr->dyn_typeinfo(), base_cfg_id); + basecfg = (*_)(opr); + opr_fmt = fmtcfg.val().opr_format; + } else { + opr_fmt = + OprTensorFormatsConfiguration::safe_cast_to_opr_format(cfg_id); } VarNodeArray new_inp; size_t nr_inps = opr->input().size(); @@ -89,7 +97,7 @@ void LayoutTransformPass::apply(OptState& opt) const { nr_inps = std::min(fmtcfg.val().input_tensor_formats.size(), nr_inps); out_fmt = fmtcfg.val().output_tensor_formats[0]; } else { - out_fmt = opr_format_to_tensor_formats(opr_fmt); + out_fmt = opr_format_to_tensor_formats(opr_fmt.val()); } new_inp.resize(nr_inps); for (size_t i = 0; i < nr_inps; ++i) { @@ -103,7 +111,7 @@ void LayoutTransformPass::apply(OptState& opt) const { from = find->second; } auto to = fmtcfg.valid() ? fmtcfg.val().input_tensor_formats[i] - : opr_format_to_tensor_formats(opr_fmt); + : opr_format_to_tensor_formats(opr_fmt.val()); bool is_parameter = fmtcfg.valid() && fmtcfg.val().input_tensor_types[i] == TensorType::WEIGHT; @@ -119,7 +127,7 @@ void LayoutTransformPass::apply(OptState& opt) const { var->dtype().enumv()}; if (is_parameter) { auto aligned_desc = - ReformatManager::make_aligned_desc(base_fmt, out_fmt); + ReformatManager::make_aligned_desc(from, out_fmt); reformat = ReformatManager::instance() .auto_aligned_reformat_weight( var, key, aligned_desc); @@ -134,7 +142,7 @@ void LayoutTransformPass::apply(OptState& opt) const { } VarNode* new_out; if (format_aware_oprs.count(opr->dyn_typeinfo()) > 0) { - new_out = intl::modify_opr_format(opr_fmt, new_inp, opr); + new_out = intl::modify_opr_format(opr_fmt.val(), new_inp, opr); } else { new_out = serialization::copy_opr_shallow(*opr, new_inp, opr->config()) ->output(0); @@ -170,9 +178,8 @@ void LayoutTransformPass::apply(OptState& opt) const { ovar, new_ovar, mgb_cstr_log(ssprintf( "replace opr(%s) to new opr " - "format(%s)", - opr->cname(), - opr_format_to_string(opr_fmt)) + "format config(%s)", + opr->cname(), config_id_to_string(cfg_id)) .c_str())); } } else { diff --git a/src/gopt/impl/global_layout_transform/opr_format_modifier.h b/src/gopt/impl/global_layout_transform/opr_format_modifier.h index c1d200285..11e634f68 100644 --- a/src/gopt/impl/global_layout_transform/opr_format_modifier.h +++ b/src/gopt/impl/global_layout_transform/opr_format_modifier.h @@ -24,7 +24,7 @@ namespace intl { bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr); VarNode* modify_opr_format( - opr::ConvBias::Param::Format opr_format, const VarNodeArray& i, + opr::Convolution::Param::Format opr_format, const VarNodeArray& i, const cg::OperatorNodeBase* opr); } // namespace intl diff --git a/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp b/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp index f5f450f5e..0605a197c 100644 --- a/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp +++ b/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp @@ -25,7 +25,8 @@ MIDOUT_DECL(megbrain_opr_tensor_formats_config) using namespace mgb; using namespace cg; using namespace gopt; -using OprFormat = opr::ConvBias::Param::Format; +using OprFormat = OprTensorFormatsConfiguration::OprFormat; +using OprFormatConfigID = OprTensorFormatsConfiguration::OprFormatConfigID; namespace { template @@ -56,19 +57,22 @@ static bool is_channel_wise_conv(const OperatorNodeBase* opr) { if (format == Opr::Param::Format::NCHW) { ocpg = weight_shp[1], icpg = weight_shp[2]; return ocpg == 1 && icpg == 
1; + } else { + mgb_assert(false, "invalid opr format(%s)", opr_format_to_string(format)); } return false; } -template +template struct OprSingleInOutTensorFormatsDispatcherImpl; template <> -struct OprSingleInOutTensorFormatsDispatcherImpl { +struct OprSingleInOutTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NCHW; + config.config_id = OprFormatConfigID::NCHW; config.input_dtypes = {opr->input(0)->dtype().enumv()}; config.input_tensor_types = {TensorType::FEATURE}; config.output_dtypes = {opr->output(0)->dtype().enumv()}; @@ -79,11 +83,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl { }; template <> -struct OprSingleInOutTensorFormatsDispatcherImpl { +struct OprSingleInOutTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NCHW44; + config.config_id = OprFormatConfigID::NCHW44; bool available = true; available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float32; config.input_dtypes = {opr->input(0)->dtype().enumv()}; @@ -99,11 +104,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl { #if !MEGDNN_DISABLE_FLOAT16 template <> -struct OprSingleInOutTensorFormatsDispatcherImpl { +struct OprSingleInOutTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NCHW88; + config.config_id = OprFormatConfigID::NCHW88; bool available = true; available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float16; config.input_dtypes = {opr->input(0)->dtype().enumv()}; @@ -119,11 +125,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl { #endif template <> -struct OprSingleInOutTensorFormatsDispatcherImpl { +struct OprSingleInOutTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NCHW4; + config.config_id = OprFormatConfigID::NCHW4; bool available = true; available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8; config.input_dtypes = {opr->input(0)->dtype().enumv()}; @@ -139,11 +146,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl { }; template <> -struct OprSingleInOutTensorFormatsDispatcherImpl { +struct OprSingleInOutTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::CHWN4; + config.config_id = OprFormatConfigID::CHWN4; bool available = true; available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8; config.input_dtypes = {opr->input(0)->dtype().enumv()}; @@ -159,11 +167,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl { }; template <> -struct OprSingleInOutTensorFormatsDispatcherImpl { +struct OprSingleInOutTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NCHW32; + config.config_id = OprFormatConfigID::NCHW32; bool available = true; available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8; config.input_dtypes = {opr->input(0)->dtype().enumv()}; @@ -179,11 +188,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl { 
}; template <> -struct OprSingleInOutTensorFormatsDispatcherImpl { +struct OprSingleInOutTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NHWC; + config.config_id = OprFormatConfigID::NHWC; bool available = true; available &= opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm || opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4; @@ -200,11 +210,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl { }; template <> -struct OprSingleInOutTensorFormatsDispatcherImpl { +struct OprSingleInOutTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NCHW64; + config.config_id = OprFormatConfigID::NCHW64; bool available = true; available &= opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm || opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4; @@ -220,16 +231,17 @@ struct OprSingleInOutTensorFormatsDispatcherImpl { } }; -template +template struct ConvTensorFormatsDispatcherImpl; template -struct ConvTensorFormatsDispatcherImpl { +struct ConvTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { const auto& conv = opr->cast_final_safe(); OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NCHW; + config.config_id = OprFormatConfigID::NCHW; // setup dtypes for (size_t i = 0; i < opr->input().size(); ++i) { config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); @@ -260,37 +272,35 @@ struct ConvTensorFormatsDispatcherImpl { }; template -struct ConvTensorFormatsDispatcherImpl { +struct ConvTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { const auto& conv = opr->cast_final_safe(); OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NHWC; + config.config_id = OprFormatConfigID::NHWC; + auto check_dtype = [](const DType& dt) { + bool i4_config = dt.enumv() == DTypeEnum::Quantized4Asymm || + dt.enumv() == DTypeEnum::QuantizedS4; + bool i8_config = dt.enumv() == DTypeEnum::QuantizedS8; + return i4_config || i8_config; + }; bool available = true; for (size_t i = 0; i < opr->input().size(); ++i) { if (i == 2) available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32; else { - bool i4_config = - opr->input(i)->dtype().enumv() == DTypeEnum::Quantized4Asymm || - opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS4; - bool i8_config = - opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8; - available &= (i4_config || i8_config); + available &= check_dtype(opr->input(i)->dtype()); } config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); TensorType tensor_type = i == 1 ? 
TensorType::WEIGHT : TensorType::FEATURE; config.input_tensor_types.emplace_back(tensor_type); } - bool i4_config = - opr->output(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm || - opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS4; - bool i8_config = opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8; - available &= (i4_config || i8_config); + available &= check_dtype(opr->output(0)->dtype()); config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); available &= conv.param().sparse == Opr::Param::Sparse::DENSE; config.input_tensor_formats = { - TensorFormats::NHWC, TensorFormats::NHWC, TensorFormats::NHWC, + TensorFormats::NHWC, TensorFormats::KRSC, TensorFormats::NHWC, TensorFormats::NHWC}; config.output_tensor_formats = {TensorFormats::NHWC}; if (available) @@ -300,12 +310,13 @@ struct ConvTensorFormatsDispatcherImpl { }; template -struct ConvTensorFormatsDispatcherImpl { +struct ConvTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { const auto& conv = opr->cast_final_safe(); OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NCHW4; + config.config_id = OprFormatConfigID::NCHW4; bool available = true; // setup dtypes for (size_t i = 0; i < opr->input().size(); ++i) { @@ -322,7 +333,7 @@ struct ConvTensorFormatsDispatcherImpl { // setup tensor formats if (conv.param().sparse == Opr::Param::Sparse::DENSE) { config.input_tensor_formats = { - TensorFormats::NCHWc4, TensorFormats::NCHWc4, TensorFormats::NCHWc4, + TensorFormats::NCHWc4, TensorFormats::KCRSc4, TensorFormats::NCHWc4, TensorFormats::NCHWc4}; } else { mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); @@ -344,12 +355,75 @@ struct ConvTensorFormatsDispatcherImpl { }; template -struct ConvTensorFormatsDispatcherImpl { +struct ConvTensorFormatsDispatcherImpl { + static Maybe dispatch(const OperatorNodeBase* opr) { + const auto& conv = opr->cast_final_safe(); + OprTensorFormatsConfiguration config; + config.typeinfo = opr->dyn_typeinfo(); + config.opr_format = OprFormat::NCHW4_NCHW32; + config.config_id = OprFormatConfigID::NCHW4_NCHW32; + bool available = true; + for (size_t i = 0; i < opr->input().size(); ++i) { + if (i == 2) + available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32; + else + available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8; + config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); + TensorType tensor_type = i == 1 ? 
TensorType::WEIGHT : TensorType::FEATURE; + config.input_tensor_types.emplace_back(tensor_type); + } + available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8; + config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); + available &= conv.param().sparse == Opr::Param::Sparse::DENSE; + config.input_tensor_formats = { + TensorFormats::NCHWc4, TensorFormats::KCRSc4, TensorFormats::NCHWc32, + TensorFormats::NCHWc32}; + config.output_tensor_formats = {TensorFormats::NCHWc32}; + if (available) + return config; + return None; + } +}; + +template +struct ConvTensorFormatsDispatcherImpl { + static Maybe dispatch(const OperatorNodeBase* opr) { + const auto& conv = opr->cast_final_safe(); + OprTensorFormatsConfiguration config; + config.typeinfo = opr->dyn_typeinfo(); + config.opr_format = OprFormat::NCHW4_NCHW; + config.config_id = OprFormatConfigID::NCHW4_NCHW; + bool available = true; + for (size_t i = 0; i < opr->input().size(); ++i) { + if (i >= 2) + available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32; + else + available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8; + config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); + TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; + config.input_tensor_types.emplace_back(tensor_type); + } + available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32; + config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); + available &= conv.param().sparse == Opr::Param::Sparse::DENSE; + config.input_tensor_formats = { + TensorFormats::NCHWc4, TensorFormats::KCRSc4, TensorFormats::NCHW, + TensorFormats::NCHW}; + config.output_tensor_formats = {TensorFormats::NCHW}; + if (available) + return config; + return None; + } +}; + +template +struct ConvTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { const auto& conv = opr->cast_final_safe(); OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NCHW32; + config.config_id = OprFormatConfigID::NCHW32; bool available = true; for (size_t i = 0; i < opr->input().size(); ++i) { if (i == 2) @@ -364,7 +438,7 @@ struct ConvTensorFormatsDispatcherImpl { config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); available &= conv.param().sparse == Opr::Param::Sparse::DENSE; config.input_tensor_formats = { - TensorFormats::NCHWc32, TensorFormats::NCHWc32, TensorFormats::NCHWc32, + TensorFormats::NCHWc32, TensorFormats::KCRSc32, TensorFormats::NCHWc32, TensorFormats::NCHWc32}; config.output_tensor_formats = {TensorFormats::NCHWc32}; if (available) @@ -374,12 +448,44 @@ struct ConvTensorFormatsDispatcherImpl { }; template -struct ConvTensorFormatsDispatcherImpl { +struct ConvTensorFormatsDispatcherImpl { + static Maybe dispatch(const OperatorNodeBase* opr) { + const auto& conv = opr->cast_final_safe(); + OprTensorFormatsConfiguration config; + config.typeinfo = opr->dyn_typeinfo(); + config.opr_format = OprFormat::NCHW32_NCHW4; + config.config_id = OprFormatConfigID::NCHW32_NCHW4; + bool available = true; + for (size_t i = 0; i < opr->input().size(); ++i) { + if (i == 2) + available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32; + else + available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8; + config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); + TensorType tensor_type = i == 1 ? 
TensorType::WEIGHT : TensorType::FEATURE; + config.input_tensor_types.emplace_back(tensor_type); + } + available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8; + config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); + available &= conv.param().sparse == Opr::Param::Sparse::DENSE; + config.input_tensor_formats = { + TensorFormats::NCHWc32, TensorFormats::KCRSc32, TensorFormats::NCHWc4, + TensorFormats::NCHWc4}; + config.output_tensor_formats = {TensorFormats::NCHWc4}; + if (available) + return config; + return None; + } +}; + +template +struct ConvTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { const auto& conv = opr->cast_final_safe(); OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NCHW64; + config.config_id = OprFormatConfigID::NCHW64; bool available = true; for (size_t i = 0; i < opr->input().size(); ++i) { if (i == 2) @@ -397,7 +503,7 @@ struct ConvTensorFormatsDispatcherImpl { config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); available &= conv.param().sparse == Opr::Param::Sparse::DENSE; config.input_tensor_formats = { - TensorFormats::NCHWc64, TensorFormats::NCHWc64, TensorFormats::NCHWc64, + TensorFormats::NCHWc64, TensorFormats::KCRSc64, TensorFormats::NCHWc64, TensorFormats::NCHWc64}; config.output_tensor_formats = {TensorFormats::NCHWc64}; if (available) @@ -407,12 +513,13 @@ struct ConvTensorFormatsDispatcherImpl { }; template -struct ConvTensorFormatsDispatcherImpl { +struct ConvTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { const auto& conv = opr->cast_final_safe(); OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::CHWN4; + config.config_id = OprFormatConfigID::CHWN4; bool available = true; for (size_t i = 0; i < opr->input().size(); ++i) { if (i == 2) @@ -427,7 +534,7 @@ struct ConvTensorFormatsDispatcherImpl { config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); available &= conv.param().sparse == Opr::Param::Sparse::DENSE; config.input_tensor_formats = { - TensorFormats::CHWNc4, TensorFormats::CHWNc4, TensorFormats::CHWNc4, + TensorFormats::CHWNc4, TensorFormats::CRSKc4, TensorFormats::CHWNc4, TensorFormats::CHWNc4}; config.output_tensor_formats = {TensorFormats::CHWNc4}; if (available) @@ -437,12 +544,13 @@ struct ConvTensorFormatsDispatcherImpl { }; template -struct ConvTensorFormatsDispatcherImpl { +struct ConvTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { const auto& conv = opr->cast_final_safe(); OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NCHW44; + config.config_id = OprFormatConfigID::NCHW44; bool available = true; // setup dtypes for (size_t i = 0; i < opr->input().size(); ++i) { @@ -477,14 +585,44 @@ struct ConvTensorFormatsDispatcherImpl { } }; +template +struct ConvTensorFormatsDispatcherImpl { + static Maybe dispatch(const OperatorNodeBase* opr) { + const auto& conv = opr->cast_final_safe(); + OprTensorFormatsConfiguration config; + config.typeinfo = opr->dyn_typeinfo(); + config.opr_format = OprFormat::NCHW44; + config.config_id = OprFormatConfigID::NCHW44_HYBRID; + bool available = true; + // setup dtypes + for (size_t i = 0; i < opr->input().size(); ++i) { + available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32; + config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); + TensorType tensor_type 
= i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; + config.input_tensor_types.emplace_back(tensor_type); + } + available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32; + config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); + available &= conv.param().sparse == Opr::Param::Sparse::DENSE; + config.input_tensor_formats = { + TensorFormats::NCHW, TensorFormats::KRSCk4, TensorFormats::NCHWc4, + TensorFormats::NCHWc4}; + config.output_tensor_formats = {TensorFormats::NCHWc4}; + if (!available) + return None; + return config; + } +}; + #if !MEGDNN_DISABLE_FLOAT16 template -struct ConvTensorFormatsDispatcherImpl { +struct ConvTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { const auto& conv = opr->cast_final_safe(); OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NCHW88; + config.config_id = OprFormatConfigID::NCHW88; bool available = true; // setup dtypes for (size_t i = 0; i < opr->input().size(); ++i) { @@ -518,15 +656,46 @@ struct ConvTensorFormatsDispatcherImpl { return config; } }; + +template +struct ConvTensorFormatsDispatcherImpl { + static Maybe dispatch(const OperatorNodeBase* opr) { + const auto& conv = opr->cast_final_safe(); + OprTensorFormatsConfiguration config; + config.typeinfo = opr->dyn_typeinfo(); + config.opr_format = OprFormat::NCHW88; + config.config_id = OprFormatConfigID::NCHW88_HYBRID; + bool available = true; + // setup dtypes + for (size_t i = 0; i < opr->input().size(); ++i) { + available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16; + config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); + TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; + config.input_tensor_types.emplace_back(tensor_type); + } + available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16; + config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); + available &= conv.param().sparse == Opr::Param::Sparse::DENSE; + // setup tensor formats + config.input_tensor_formats = { + TensorFormats::NCHW, TensorFormats::KRSCk8, TensorFormats::NCHWc8, + TensorFormats::NCHWc8}; + config.output_tensor_formats = {TensorFormats::NCHWc8}; + if (!available) + return None; + return config; + } +}; #endif template -struct ConvTensorFormatsDispatcherImpl { +struct ConvTensorFormatsDispatcherImpl { static Maybe dispatch(const OperatorNodeBase* opr) { const auto& conv = opr->cast_final_safe(); OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NCHW44_DOT; + config.config_id = OprFormatConfigID::NCHW44_DOT; bool available = true; // setup dtypes for (size_t i = 0; i < opr->input().size(); ++i) { @@ -566,14 +735,53 @@ struct ConvTensorFormatsDispatcherImpl { } }; +template +struct ConvTensorFormatsDispatcherImpl { + static Maybe dispatch(const OperatorNodeBase* opr) { + const auto& conv = opr->cast_final_safe(); + OprTensorFormatsConfiguration config; + config.typeinfo = opr->dyn_typeinfo(); + config.opr_format = OprFormat::NCHW44_DOT; + config.config_id = OprFormatConfigID::NCHW44_DOT_HYBRID; + bool available = true; + // setup dtypes + for (size_t i = 0; i < opr->input().size(); ++i) { + if (i == 2) { + available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32; + } else { + available &= + opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8 || + opr->input(i)->dtype().enumv() == DTypeEnum::Quantized8Asymm; + } + 
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); + TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; + config.input_tensor_types.emplace_back(tensor_type); + } + available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8 || + opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm; + config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); + available &= conv.param().sparse == Opr::Param::Sparse::DENSE; + // setup tensor formats + config.input_tensor_formats = { + TensorFormats::NCHW, TensorFormats::KRSCk4, TensorFormats::NCHWc4, + TensorFormats::NCHWc4}; + config.output_tensor_formats = {TensorFormats::NCHWc4}; + if (!available) + return None; + return config; + } +}; + template <> -struct ConvTensorFormatsDispatcherImpl { +struct ConvTensorFormatsDispatcherImpl< + opr::ConvolutionBackwardData, OprFormatConfigID::NCHW> { using Opr = opr::ConvolutionBackwardData; static Maybe dispatch(const OperatorNodeBase* opr) { const auto& conv = opr->cast_final_safe(); OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NCHW; + config.config_id = OprFormatConfigID::NCHW; // setup dtypes for (size_t i = 0; i < opr->input().size(); ++i) { config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); @@ -584,7 +792,7 @@ struct ConvTensorFormatsDispatcherImpl -struct ConvTensorFormatsDispatcherImpl { +struct ConvTensorFormatsDispatcherImpl< + opr::ConvolutionBackwardData, OprFormatConfigID::NCHW4> { using Opr = opr::ConvolutionBackwardData; static Maybe dispatch(const OperatorNodeBase* opr) { const auto& conv = opr->cast_final_safe(); OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NCHW4; + config.config_id = OprFormatConfigID::NCHW4; bool available = true; for (size_t i = 0; i < opr->input().size(); ++i) { available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8; @@ -622,7 +832,7 @@ struct ConvTensorFormatsDispatcherImploutput(0)->dtype().enumv()); available &= conv.param().sparse == opr::ConvBias::Param::Sparse::DENSE; config.input_tensor_formats = { - TensorFormats::NCHWc4, TensorFormats::NCHWc4, TensorFormats::NCHWc4, + TensorFormats::KCRSc4, TensorFormats::NCHWc4, TensorFormats::NCHWc4, TensorFormats::NCHWc4}; config.output_tensor_formats = {TensorFormats::NCHWc4}; if (available) @@ -632,13 +842,15 @@ struct ConvTensorFormatsDispatcherImpl -struct ConvTensorFormatsDispatcherImpl { +struct ConvTensorFormatsDispatcherImpl< + opr::ConvolutionBackwardData, OprFormatConfigID::NHWC> { using Opr = opr::ConvolutionBackwardData; static Maybe dispatch(const OperatorNodeBase* opr) { const auto& conv = opr->cast_final_safe(); OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); config.opr_format = OprFormat::NHWC; + config.config_id = OprFormatConfigID::NHWC; bool available = true; for (size_t i = 0; i < opr->input().size(); ++i) { available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8; @@ -650,7 +862,7 @@ struct ConvTensorFormatsDispatcherImploutput(0)->dtype().enumv()); available &= conv.param().sparse == opr::ConvBias::Param::Sparse::DENSE; config.input_tensor_formats = { - TensorFormats::NHWC, TensorFormats::NHWC, TensorFormats::NHWC, + TensorFormats::KRSC, TensorFormats::NHWC, TensorFormats::NHWC, TensorFormats::NHWC}; config.output_tensor_formats = {TensorFormats::NHWC}; if (available) @@ -661,7 +873,7 @@ struct ConvTensorFormatsDispatcherImpl& val) const { + size_t 
operator()(const std::pair& val) const { size_t h1 = mgb::hash(val.first); size_t h2 = std::hash()(static_cast(val.second)); return mgb::hash_pair_combine(h1, h2); @@ -670,28 +882,29 @@ struct StaticData { using OprTensorFormatsDispatcher = OprTensorFormatsConfiguration::OprTensorFormatsDispatcher; std::unordered_map< - std::pair, OprTensorFormatsDispatcher, KeyHash> + std::pair, OprTensorFormatsDispatcher, + KeyHash> typefmt2dispatcher; StaticData(); }; StaticData::StaticData() { -#define OPR_TENSOR_FORMATS_CONFIG_REG(_Opr, _fmt) \ - typefmt2dispatcher[{opr::_Opr::typeinfo(), OprFormat::_fmt}] = \ - [](const OperatorNodeBase* opr) { \ - MIDOUT_B(opr::_Opr, midout_iv(OprFormat::_fmt)) \ - return ConvTensorFormatsDispatcherImpl< \ - opr::_Opr, OprFormat::_fmt>::dispatch(opr); \ - MIDOUT_E \ +#define OPR_TENSOR_FORMATS_CONFIG_REG(_Opr, _fmt) \ + typefmt2dispatcher[{opr::_Opr::typeinfo(), OprFormatConfigID::_fmt}] = \ + [](const OperatorNodeBase* opr) { \ + MIDOUT_B(opr::_Opr, midout_iv(OprFormatConfigID::_fmt)) \ + return ConvTensorFormatsDispatcherImpl< \ + opr::_Opr, OprFormatConfigID::_fmt>::dispatch(opr); \ + MIDOUT_E \ } -#define OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(_Opr, _fmt) \ - typefmt2dispatcher[{opr::_Opr::typeinfo(), OprFormat::_fmt}] = \ - [](const OperatorNodeBase* opr) { \ - MIDOUT_B(opr::_Opr, midout_iv(OprFormat::_fmt)) \ - return OprSingleInOutTensorFormatsDispatcherImpl< \ - OprFormat::_fmt>::dispatch(opr); \ - MIDOUT_E \ +#define OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(_Opr, _fmt) \ + typefmt2dispatcher[{opr::_Opr::typeinfo(), OprFormatConfigID::_fmt}] = \ + [](const OperatorNodeBase* opr) { \ + MIDOUT_B(opr::_Opr, midout_iv(OprFormatConfigID::_fmt)) \ + return OprSingleInOutTensorFormatsDispatcherImpl< \ + OprFormatConfigID::_fmt>::dispatch(opr); \ + MIDOUT_E \ } OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW); @@ -703,16 +916,22 @@ StaticData::StaticData() { OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44); #if !MEGDNN_DISABLE_FLOAT16 OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88); + OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88_HYBRID); #endif OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT); + OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_HYBRID); + OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT_HYBRID); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44); #if !MEGDNN_DISABLE_FLOAT16 OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88); + OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88_HYBRID); #endif OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT); + OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_HYBRID); + OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT_HYBRID); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC); @@ -752,14 +971,14 @@ StaticData& static_data() { OprTensorFormatsConfiguration::OprTensorFormatsDispatcher* OprTensorFormatsConfiguration::find_dispatcher_by_type_format( - Typeinfo* type, OprFormat opr_format) { + Typeinfo* type, OprFormatConfigID config_id) { auto&& typefmt2dispatcher = static_data().typefmt2dispatcher; - auto iter = typefmt2dispatcher.find(std::make_pair(type, opr_format)); + auto iter = typefmt2dispatcher.find(std::make_pair(type, config_id)); mgb_assert( iter != typefmt2dispatcher.end(), "cannot find OprTensorFormatsDispatcher for opr type(%s) and " - "opr format(%s)", - 
type->name, opr_format_to_string(opr_format)); + "opr format configuration id(%s)", + type->name, config_id_to_string(config_id)); return &iter->second; } diff --git a/src/gopt/impl/global_layout_transform/profiler_cache.cpp b/src/gopt/impl/global_layout_transform/profiler_cache.cpp index e27dbdcd4..df3a0dd93 100644 --- a/src/gopt/impl/global_layout_transform/profiler_cache.cpp +++ b/src/gopt/impl/global_layout_transform/profiler_cache.cpp @@ -64,7 +64,7 @@ void ProfilerCache::Key::build_blob_from_opr() { // serialize opr_format m_blob_storage.append( - std::to_string(static_cast(m_key_impl.opr_key.opr_format))); + std::to_string(static_cast(m_key_impl.opr_key.config_id))); // serialize extra_attribute m_blob_storage.append( diff --git a/src/gopt/impl/global_layout_transform/profiler_impl.cpp b/src/gopt/impl/global_layout_transform/profiler_impl.cpp index 58c392829..6c7001953 100644 --- a/src/gopt/impl/global_layout_transform/profiler_impl.cpp +++ b/src/gopt/impl/global_layout_transform/profiler_impl.cpp @@ -29,30 +29,6 @@ using namespace gopt; using ReformatKey = ReformatManager::ReformatKey; namespace { -using OprFormat = Problem::OprFormat; -OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) { - switch (tensor_format) { - case TensorFormats::NCHW: - return OprFormat::NCHW; - case TensorFormats::NCHWc4: - return OprFormat::NCHW44; - case TensorFormats::NCHWc8: - return OprFormat::NCHW88; - case TensorFormats::NCHWc32: - return OprFormat::NCHW32; - case TensorFormats::NCHWc64: - return OprFormat::NCHW64; - case TensorFormats::NHWC: - return OprFormat::NHWC; - case TensorFormats::CHWNc4: - return OprFormat::CHWN4; - default: - mgb_throw( - MegBrainError, "tensor format(%u) is not supported", - static_cast(tensor_format)); - } -} - class GraphPartitionProfiler final : public PluginBase { using CompNodeEventPtr = std::unique_ptr; @@ -214,8 +190,8 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( record.opr = opr; auto& costs = record.costs; for (auto&& f : available_tensor_formats) { - auto opr_format = tensor_formats_to_opr_format(f); - costs[opr_format] = profile_operator(opr, base_format, f, extra_attribute); + auto config_id = tensor_formats_to_config_id(f); + costs[config_id] = profile_operator(opr, base_format, f, extra_attribute); } return record; } @@ -261,7 +237,7 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( record.opr = opr; auto& costs = record.costs; for (auto&& i : available_configs) { - costs[i.opr_format] = profile_operator(opr, base_config, i, extra_attribute); + costs[i.config_id] = profile_operator(opr, base_config, i, extra_attribute); } return record; } @@ -316,7 +292,6 @@ float ProfilerImpl::profile_operator( new_inps[i] = imm.node(); } VarNode* y = mgb::gopt::intl::modify_opr_format(config.opr_format, new_inps, opr); -#if 0 static const ThinHashSet multi_algo_oprs = { opr::Convolution::typeinfo(), opr::ConvBiasForward::typeinfo(), @@ -326,7 +301,6 @@ float ProfilerImpl::profile_operator( if (multi_algo_oprs.count(opr->dyn_typeinfo()) && !mgb::gopt::intl::has_available_algo(new_inps, y->owner_opr())) return PROFILE_TIME_OUT; -#endif if (!m_opr_filter(opr, y->owner_opr())) return PROFILE_TIME_OUT; auto mark = MarkInputContiguous::make(SymbolVar(y)); @@ -494,6 +468,30 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(const Problem& problem) cons return profiling_result; } +ProfilerImpl::OprFormatConfigID ProfilerImpl::tensor_formats_to_config_id( + TensorFormats tensor_format) const { + switch (tensor_format) { 
+ case TensorFormats::NCHW: + return OprFormatConfigID::NCHW; + case TensorFormats::NCHWc4: + return OprFormatConfigID::NCHW4; + case TensorFormats::NCHWc8: + return OprFormatConfigID::NCHW8; + case TensorFormats::NCHWc32: + return OprFormatConfigID::NCHW32; + case TensorFormats::NCHWc64: + return OprFormatConfigID::NCHW64; + case TensorFormats::NHWC: + return OprFormatConfigID::NHWC; + case TensorFormats::CHWNc4: + return OprFormatConfigID::CHWN4; + default: + mgb_throw( + MegBrainError, "tensor format(%u) is not supported", + static_cast(tensor_format)); + } +} + /* ================== ProfilerBase =================*/ std::string ProfilerBase::OperatorNodeRecord::to_string() const { auto str = ssprintf( @@ -508,7 +506,7 @@ std::string ProfilerBase::OperatorNodeRecord::to_string() const { opr->output(0)->shape().to_string().c_str()); for (auto&& cpair : costs) { str += ssprintf( - "\tformat: %s; cost:%f", opr_format_to_string(cpair.first), + "\tconfig: %s; cost:%f", config_id_to_string(cpair.first), cpair.second); } return str; @@ -557,7 +555,7 @@ float CachedProfiler::profile_operator( const OperatorNodeBase* opr, TensorFormats base_format, TensorFormats tensor_format, ReformatAttribute extra_attribute) const { ProfilerCache::Key key{ - opr, tensor_formats_to_opr_format(tensor_format), extra_attribute}; + opr, tensor_formats_to_config_id(tensor_format), extra_attribute}; auto ret = ProfilerCache::inst().get(key); if (ret.valid()) return ret.val(); @@ -571,7 +569,7 @@ float CachedProfiler::profile_operator( const OperatorNodeBase* opr, const OprTensorFormatsConfiguration& base_config, const OprTensorFormatsConfiguration& config, ReformatAttribute extra_attribute) const { - ProfilerCache::Key key{opr, config.opr_format, extra_attribute}; + ProfilerCache::Key key{opr, config.config_id, extra_attribute}; auto ret = ProfilerCache::inst().get(key); if (ret.valid()) return ret.val(); diff --git a/src/gopt/impl/global_layout_transform/profiling_based_solver.cpp b/src/gopt/impl/global_layout_transform/profiling_based_solver.cpp index e7039c999..43f39ce11 100644 --- a/src/gopt/impl/global_layout_transform/profiling_based_solver.cpp +++ b/src/gopt/impl/global_layout_transform/profiling_based_solver.cpp @@ -48,7 +48,8 @@ ProfilingBasedSolver::ProfilingBasedSolver(std::unique_ptr profile }; m_problem_filter = [](const Problem& problem) { - auto&& base_opr_format = problem.attribute().base_opr_format; + auto&& base_opr_format = OprTensorFormatsConfiguration::safe_cast_to_opr_format( + problem.attribute().base_config_id); bool has_format_aware_opr = false; for (auto&& opr : problem.graph_partition().all_oprs()) { auto iter = format_aware_opr_validators.find(opr->dyn_typeinfo()); diff --git a/src/gopt/impl/global_layout_transform/utils.h b/src/gopt/impl/global_layout_transform/utils.h index 336f4a9d8..9cc3b82d2 100644 --- a/src/gopt/impl/global_layout_transform/utils.h +++ b/src/gopt/impl/global_layout_transform/utils.h @@ -40,6 +40,37 @@ static inline const char* opr_format_to_string( #undef cb } +static inline const char* config_id_to_string( + OprTensorFormatsConfiguration::OprFormatConfigID config_id) { + using OprFormatConfigID = OprTensorFormatsConfiguration::OprFormatConfigID; +#define cb(_fmt) \ + case OprFormatConfigID::_fmt: \ + return #_fmt + switch (config_id) { + cb(NCHW); + cb(NHWC); + cb(NCHW4); + cb(NCHW8); + cb(NCHW4_NCHW32); + cb(NCHW4_NCHW); + cb(NCHW32); + cb(NCHW32_NCHW4); + cb(NCHW64); + cb(CHWN4); + cb(NCHW44); + cb(NCHW44_HYBRID); + cb(NCHW88); + cb(NCHW88_HYBRID); + 
cb(NCHW44_DOT); + cb(NCHW44_DOT_HYBRID); + default: + mgb_assert( + false, "Invalid config id(got:%u)", + static_cast(config_id)); + } +#undef cb +} + static inline TensorFormats opr_format_to_tensor_formats( OprTensorFormatsConfiguration::OprFormat opr_format) { using OprFormat = OprTensorFormatsConfiguration::OprFormat; @@ -60,6 +91,8 @@ static inline TensorFormats opr_format_to_tensor_formats( return TensorFormats::NCHWc8; case OprFormat::NCHW44: return TensorFormats::NCHWc4; + case OprFormat::NCHW8: + return TensorFormats::NCHWc8; default: mgb_throw( AssertionError, "format(%s) is not supported", @@ -124,9 +157,17 @@ static inline megdnn::NamedTensorShape tensor_formats_to_named_tensor_shape( return {{"G"}, {"K"}, {"C"}, {"R"}, {"S"}}; case TensorFormats::C11RS: return {{"C"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}}; + case TensorFormats::KRSC: + return {{"K"}, {"R"}, {"S"}, {"C"}}; + case TensorFormats::KCRSc32: + return {{"K"}, {"C//32"}, {"R"}, {"S"}, {"C%32"}}; + case TensorFormats::KCRSc64: + return {{"K"}, {"C//64"}, {"R"}, {"S"}, {"C%64"}}; + case TensorFormats::CRSKc4: + return {{"C//4"}, {"R"}, {"S"}, {"K"}, {"C%4"}}; default: mgb_throw( - AssertionError, "invalid tensor formats(%u)", + MegBrainError, "invalid tensor formats(%u)", static_cast(format)); } } diff --git a/src/gopt/include/megbrain/gopt/layout_transform_context.h b/src/gopt/include/megbrain/gopt/layout_transform_context.h index d2e677e34..7dfc907fa 100644 --- a/src/gopt/include/megbrain/gopt/layout_transform_context.h +++ b/src/gopt/include/megbrain/gopt/layout_transform_context.h @@ -26,19 +26,48 @@ namespace gopt { * configuration of the opr format */ struct OprTensorFormatsConfiguration { - using OprFormat = opr::ConvBias::Param::Format; + using OprFormat = opr::Convolution::Param::Format; + static constexpr uint32_t FORMAT_NR_MEMBER = + opr::Convolution::Param::FORMAT_NR_MEMBER; + enum class OprFormatConfigID : uint32_t { +#define cb(fmt_) fmt_ = static_cast(OprFormat::fmt_) + cb(NCHW), + cb(NHWC), + cb(NHWCD4), + cb(NCHW4), + cb(NCHW8), + cb(NCHW32), + cb(NCHW88), + cb(NCHW44), + cb(NCHW44_DOT), + cb(NCHW4_NCHW32), + cb(NCHW32_NCHW4), + cb(NCHW4_NCHW), + cb(NCHW4_NHWC), + cb(CHWN4), + cb(NCHW64), + NCHW44_HYBRID = FORMAT_NR_MEMBER, + NCHW88_HYBRID = FORMAT_NR_MEMBER + 1, + NCHW44_DOT_HYBRID = FORMAT_NR_MEMBER + 2, + }; +#undef cb using OprTensorFormatsDispatcher = thin_function( const cg::OperatorNodeBase*)>; Typeinfo* typeinfo; OprFormat opr_format; + OprFormatConfigID config_id; SmallVector input_dtypes; SmallVector output_dtypes; SmallVector input_tensor_formats; SmallVector input_tensor_types; SmallVector output_tensor_formats; static OprTensorFormatsDispatcher* find_dispatcher_by_type_format( - Typeinfo* type, OprFormat opr_format); + Typeinfo* type, OprFormatConfigID config_id); + static OprFormat safe_cast_to_opr_format(OprFormatConfigID config_id) { + mgb_assert(static_cast(config_id) < FORMAT_NR_MEMBER); + return static_cast(static_cast(config_id)); + } }; /*! 
@@ -48,14 +77,15 @@ class LayoutTransformContext { public: using OprList = SubGraphExtractor::OprList; using OprFormat = OprTensorFormatsConfiguration::OprFormat; + using OprFormatConfigID = OprTensorFormatsConfiguration::OprFormatConfigID; using OprTensorFormatsDispatcher = OprTensorFormatsConfiguration::OprTensorFormatsDispatcher; - using OprConfigTrait = - ThinHashMap>; + using OprConfigTrait = ThinHashMap< + Typeinfo*, ThinHashMap>; using Target = GraphTuningOptions::Target; using ReformatAttribute = ReformatManager::ReformatKey::Attribute; struct Attribute { - OprFormat base_opr_format; /// the base opr format indicates that the + OprFormatConfigID base_config_id; /// the base opr format indicates that the /// network to be optimized is constructed /// in the base opr format, i.e. all the /// format aware operators (conv, conv_bias, @@ -97,21 +127,22 @@ public: /*! * \brief add an op format configuration for a particular operator type * \param opr runtime typeinfo of operator - * \param opr_format op format configuration which to be enabled in the - * layout transform problem + * \param config_id op format configuration id which is going to be enabled + * in the layout transform problem */ - LayoutTransformContext& add_opr_config(Typeinfo* opr, OprFormat opr_format); + LayoutTransformContext& add_opr_config(Typeinfo* opr, OprFormatConfigID config_id); /*! * \brief add a vector of op format configurations for a particular operator * type * \param opr runtime typeinfo of operator - * \param opr_format op format configuration which to be enabled in the - * layout transform problem + * \param config_ids ids of op format configurations which are enabled in + * the layout transform problem */ LayoutTransformContext& add_opr_config( - Typeinfo* opr, SmallVector opr_formats); + Typeinfo* opr, SmallVector config_ids); static std::unique_ptr make( - Target target = Target::UNSPEC, OprFormat base_opr_format = OprFormat::NCHW, + Target target = Target::UNSPEC, + OprFormatConfigID base_config_id = OprFormatConfigID::NCHW, TensorFormats base_tensor_format = TensorFormats::NCHW); private: @@ -130,6 +161,7 @@ private: class Problem { public: using OprFormat = OprTensorFormatsConfiguration::OprFormat; + using OprFormatConfigID = OprTensorFormatsConfiguration::OprFormatConfigID; using OprTensorFormatsDispatcher = OprTensorFormatsConfiguration::OprTensorFormatsDispatcher; using OprConfigTrait = LayoutTransformContext::OprConfigTrait; @@ -152,13 +184,15 @@ public: */ OprTensorFormatsConfiguration base_config(const cg::OperatorNodeBase* opr) const { auto _ = OprTensorFormatsConfiguration::find_dispatcher_by_type_format( - opr->dyn_typeinfo(), m_ctx.attribute().base_opr_format); + opr->dyn_typeinfo(), m_ctx.attribute().base_config_id); auto rst = (*_)(opr); if (rst.valid()) return rst.val(); OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); - config.opr_format = m_ctx.attribute().base_opr_format; + config.config_id = m_ctx.attribute().base_config_id; + config.opr_format = OprTensorFormatsConfiguration::safe_cast_to_opr_format( + config.config_id); for (const auto& i : opr->input()) { config.input_dtypes.emplace_back(i->dtype().enumv()); config.input_tensor_formats.emplace_back(base_format()); diff --git a/src/gopt/include/megbrain/gopt/profiler.h b/src/gopt/include/megbrain/gopt/profiler.h index 373471e88..0a05a011b 100644 --- a/src/gopt/include/megbrain/gopt/profiler.h +++ b/src/gopt/include/megbrain/gopt/profiler.h @@ -33,9 +33,10 @@ class CachedProfiler; class ProfilerBase { 
 public:
     using OprFormat = Problem::OprFormat;
+    using OprFormatConfigID = Problem::OprFormatConfigID;
     struct OperatorNodeRecord {
         const cg::OperatorNodeBase* opr;  ///< pointer to operator node
-        ThinHashMap<OprFormat, float>
+        ThinHashMap<OprFormatConfigID, float>
                 costs;  ///< costs of operator node, i.e. the elapsed device
                         ///< time of the operator node on different opr format
                         ///< (layout configuration).
@@ -199,6 +200,8 @@ protected:
     virtual float profile_var_node(
             const VarNode* var, TensorFormats base_format,
             const ReformatKey& key) const;
+    OprFormatConfigID tensor_formats_to_config_id(TensorFormats tensor_format) const;
+
     OprFootprint m_opr_footprint;
     float m_opr_threshold;  /// a threshold, when the computation of the newly
                             /// created operator that is built in some opr
@@ -224,14 +227,14 @@ class ProfilerCache : public NonCopyableObj {
 public:
     using ReformatKey = ReformatManager::ReformatKey;
     using ReformatAttribute = ReformatKey::Attribute;
-    using OprFormat = ProfilerBase::OprFormat;
+    using OprFormatConfigID = ProfilerBase::OprFormatConfigID;
     class Key final : public NonCopyableObj {
         std::string m_blob_storage;
         std::string m_category;

         struct OprKey {
             const OperatorNodeBase* opr;
-            OprFormat opr_format;
+            OprFormatConfigID config_id;
             ReformatAttribute extra_attribute;
         };

@@ -254,9 +257,9 @@ public:
         void build_category(CompNode cn);

     public:
-        Key(const OperatorNodeBase* opr, OprFormat opr_format,
+        Key(const OperatorNodeBase* opr, OprFormatConfigID config_id,
             ReformatAttribute extra_attribute = ReformatAttribute::DEFAULT) {
-            m_key_impl.opr_key = {opr, opr_format, extra_attribute};
+            m_key_impl.opr_key = {opr, config_id, extra_attribute};
             build_blob_from_opr();
             mgb_assert(
                     opr->node_prop().contain(
diff --git a/src/gopt/include/megbrain/gopt/solver.h b/src/gopt/include/megbrain/gopt/solver.h
index 4fa369601..8e03d8def 100644
--- a/src/gopt/include/megbrain/gopt/solver.h
+++ b/src/gopt/include/megbrain/gopt/solver.h
@@ -28,7 +28,8 @@ class ProfilerBase;
 class SolverBase {
 public:
     using OprFormat = Problem::OprFormat;
-    using Solution = ThinHashMap<cg::OperatorNodeBase*, OprFormat>;
+    using OprFormatConfigID = Problem::OprFormatConfigID;
+    using Solution = ThinHashMap<cg::OperatorNodeBase*, OprFormatConfigID>;
     SolverBase() = default;
     virtual ~SolverBase() = default;
     /*!
diff --git a/src/gopt/test/embed_cache.py b/src/gopt/test/embed_cache.py
index baffd49c4..3da02450f 100644
--- a/src/gopt/test/embed_cache.py
+++ b/src/gopt/test/embed_cache.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# -*- coding: utf-8 -*-
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -95,7 +96,7 @@ static const std::vector {} = {{ if __name__ == '__main__': parser = argparse.ArgumentParser( - description='embed cache into cache header file', + description='embed cubin into cpp source file', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('-o', '--output', help='output source file', required=True) diff --git a/src/gopt/test/layout_transform_pass.cpp b/src/gopt/test/layout_transform_pass.cpp index eef926760..d4e334795 100644 --- a/src/gopt/test/layout_transform_pass.cpp +++ b/src/gopt/test/layout_transform_pass.cpp @@ -23,7 +23,7 @@ #include "megbrain/plugin/profiler.h" #include "megbrain/serialization/serializer.h" -#define MGB_WITH_CACHED_TEST 1 +#define MGB_WITH_CACHED_TEST 0 #if MGB_WITH_CACHED_TEST #include "./cache_data.h" @@ -60,30 +60,6 @@ size_t find_opr_num(SymbolVar endpoint) { return opr_num; } -using OprFormat = Problem::OprFormat; -OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) { - switch (tensor_format) { - case TensorFormats::NCHW: - return OprFormat::NCHW; - case TensorFormats::NCHWc4: - return OprFormat::NCHW4; - case TensorFormats::NCHWc8: - return OprFormat::NCHW8; - case TensorFormats::NCHWc32: - return OprFormat::NCHW32; - case TensorFormats::NCHWc64: - return OprFormat::NCHW64; - case TensorFormats::NHWC: - return OprFormat::NHWC; - case TensorFormats::CHWNc4: - return OprFormat::CHWN4; - default: - mgb_throw( - MegBrainError, "tensor format(%u) is not supported", - static_cast(tensor_format)); - } -} - class ProfilerMock : public ProfilerImpl { public: ProfilerMock(const uint8_t* bin, size_t size) { @@ -105,7 +81,7 @@ private: ReformatAttribute extra_attribute = ReformatAttribute::DEFAULT) const override { ProfilerCache::Key key{ - opr, tensor_formats_to_opr_format(tensor_format), extra_attribute}; + opr, tensor_formats_to_config_id(tensor_format), extra_attribute}; auto ret = ProfilerCache::inst().get(key); if (ret.valid()) return ret.val(); @@ -117,9 +93,7 @@ private: const OprTensorFormatsConfiguration& config, ReformatAttribute extra_attribute = ReformatAttribute::DEFAULT) const override { - ProfilerCache::Key key{opr, config.opr_format, extra_attribute}; - std::string tmp; - tmp.reserve(key.blob().size); + ProfilerCache::Key key{opr, config.config_id, extra_attribute}; auto ret = ProfilerCache::inst().get(key); if (ret.valid()) return ret.val(); @@ -161,7 +135,7 @@ TEST(TestLayoutTransform, Resnet18_QS8) { auto func1 = network.graph->compile({make_callback_copy(output, t1)}); func1->execute(); - using OprFormat = LayoutTransformContext::OprFormat; + using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; using OprList = LayoutTransformContext::OprList; using Target = LayoutTransformContext::Target; using ReformatAttribute = LayoutTransformContext::ReformatAttribute; @@ -175,17 +149,18 @@ TEST(TestLayoutTransform, Resnet18_QS8) { TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, TensorFormats::NCHWc32, TensorFormats::CHWNc4}; Attribute attribute = { - OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC, + OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC, ReformatAttribute::AUTO_PADDING_NHWC}; auto ctx = std::make_unique( std::move(opr_list), std::move(available_tensor_formats), attribute); ctx->add_opr_config( opr::ConvBiasForward::typeinfo(), - {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4, OprFormat::NHWC}) + {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, + OprFormatConfigID::CHWN4, OprFormatConfigID::NHWC}) .add_opr_config( 
opr::PoolingForward::typeinfo(), - {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC, - OprFormat::CHWN4}); + {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, + OprFormatConfigID::NHWC, OprFormatConfigID::CHWN4}); #if MGB_WITH_CACHED_TEST auto profiler = std::make_unique( static_cast(TestLayoutTransform_Resnet18_QS8.data()), @@ -253,7 +228,7 @@ TEST(TestLayoutTransform, Resnet18_QS4) { auto func1 = network.graph->compile({make_callback_copy(output, t1)}); func1->execute(); - using OprFormat = LayoutTransformContext::OprFormat; + using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; using OprList = LayoutTransformContext::OprList; using Attribute = LayoutTransformContext::Attribute; using Target = LayoutTransformContext::Target; @@ -267,18 +242,20 @@ TEST(TestLayoutTransform, Resnet18_QS4) { TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4}; Attribute attribute = { - OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC, + OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC, ReformatAttribute::AUTO_PADDING_NHWC}; auto ctx = std::make_unique( std::move(opr_list), std::move(available_tensor_formats), attribute); ctx->add_opr_config( opr::ConvBiasForward::typeinfo(), - {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4, OprFormat::NHWC, - OprFormat::NCHW64}) + {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, + OprFormatConfigID::CHWN4, OprFormatConfigID::NHWC, + OprFormatConfigID::NCHW64}) .add_opr_config( opr::PoolingForward::typeinfo(), - {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NCHW64, - OprFormat::NHWC, OprFormat::CHWN4}); + {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, + OprFormatConfigID::NCHW64, OprFormatConfigID::NHWC, + OprFormatConfigID::CHWN4}); #if MGB_WITH_CACHED_TEST auto profiler = std::make_unique( static_cast(TestLayoutTransform_Resnet18_QS4.data()), @@ -375,7 +352,7 @@ TEST(TestLayoutTransform, Detection_QS8) { S strategy = S::PROFILE; gopt::modify_opr_algo_strategy_inplace({outputs}, strategy); - using OprFormat = LayoutTransformContext::OprFormat; + using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; using OprList = LayoutTransformContext::OprList; using Attribute = LayoutTransformContext::Attribute; using Target = LayoutTransformContext::Target; @@ -389,18 +366,18 @@ TEST(TestLayoutTransform, Detection_QS8) { TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4}; Attribute attribute = { - OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC, + OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC, ReformatAttribute::AUTO_PADDING_NHWC}; auto ctx = std::make_unique( std::move(opr_list), std::move(available_tensor_formats), attribute); ctx->add_opr_config( opr::ConvBiasForward::typeinfo(), - {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4, OprFormat::NHWC, - OprFormat::NCHW64}) + {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, + OprFormatConfigID::CHWN4, OprFormatConfigID::NHWC, + OprFormatConfigID::NCHW64}) .add_opr_config( - opr::PoolingForward::typeinfo(), - {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NCHW64, - OprFormat::NHWC, OprFormat::CHWN4}); + opr::ConvolutionBackwardData::typeinfo(), + {OprFormatConfigID::NCHW4, OprFormatConfigID::NHWC}); #if MGB_WITH_CACHED_TEST auto profiler = std::make_unique( static_cast(TestLayoutTransform_Detection_QS8.data()), @@ -452,7 +429,7 @@ TEST(TestLayoutTransform, 
Detection_QS4) { S strategy = S::PROFILE; gopt::modify_opr_algo_strategy_inplace({outputs}, strategy); - using OprFormat = LayoutTransformContext::OprFormat; + using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; using OprList = LayoutTransformContext::OprList; using ReformatAttribute = LayoutTransformContext::ReformatAttribute; using Attribute = LayoutTransformContext::Attribute; @@ -466,18 +443,18 @@ TEST(TestLayoutTransform, Detection_QS4) { TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4}; Attribute attribute = { - OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC, + OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC, ReformatAttribute::AUTO_PADDING_NHWC}; auto ctx = std::make_unique( std::move(opr_list), std::move(available_tensor_formats), attribute); ctx->add_opr_config( opr::ConvBiasForward::typeinfo(), - {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4, OprFormat::NHWC, - OprFormat::NCHW64}) + {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, + OprFormatConfigID::CHWN4, OprFormatConfigID::NHWC, + OprFormatConfigID::NCHW64}) .add_opr_config( - opr::PoolingForward::typeinfo(), - {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NCHW64, - OprFormat::NHWC, OprFormat::CHWN4}); + opr::ConvolutionBackwardData::typeinfo(), + {OprFormatConfigID::NCHW4, OprFormatConfigID::NHWC}); #if MGB_WITH_CACHED_TEST auto profiler = std::make_unique( static_cast(TestLayoutTransform_Detection_QS4.data()), @@ -538,7 +515,7 @@ TEST(TestLayoutTransform, Wide) { S strategy = S::PROFILE; gopt::modify_opr_algo_strategy_inplace({y}, strategy); - using OprFormat = LayoutTransformContext::OprFormat; + using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; using OprList = LayoutTransformContext::OprList; using ReformatAttribute = LayoutTransformContext::ReformatAttribute; using Attribute = LayoutTransformContext::Attribute; @@ -550,12 +527,13 @@ TEST(TestLayoutTransform, Wide) { SmallVector available_tensor_formats = { TensorFormats::NCHW, TensorFormats::NHWC}; Attribute attribute = { - OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC, + OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC, ReformatAttribute::DEFAULT}; auto ctx = std::make_unique( std::move(opr_list), std::move(available_tensor_formats), attribute); ctx->add_opr_config( - opr::ConvBiasForward::typeinfo(), {OprFormat::NCHW, OprFormat::NHWC}); + opr::ConvBiasForward::typeinfo(), + {OprFormatConfigID::NCHW, OprFormatConfigID::NHWC}); #if MGB_WITH_CACHED_TEST auto profiler = std::make_unique( static_cast(TestLayoutTransform_Wide.data()), @@ -580,6 +558,8 @@ TEST(TestLayoutTransform, Wide) { auto func = network.graph->compile({{sym_o, {}}}); func->execute(); gprof.to_json_full(func.get())->writeto_fpath(output_file("wide.json")); + /// check global layout transform pass, no dimshuffle + /// disable the following check, to make ci stable. 
auto nr_dimshuffle = find_opr_num(sym_o); ASSERT_EQ(nr_dimshuffle, 0u); auto nr_param_merge = find_opr_num(sym_o); @@ -631,7 +611,7 @@ TEST(TestLayoutTransform, DetectionHead) { S strategy = S::PROFILE; gopt::modify_opr_algo_strategy_inplace({y}, strategy); - using OprFormat = LayoutTransformContext::OprFormat; + using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; using OprList = LayoutTransformContext::OprList; using Attribute = LayoutTransformContext::Attribute; using ReformatAttribute = LayoutTransformContext::ReformatAttribute; @@ -650,27 +630,30 @@ TEST(TestLayoutTransform, DetectionHead) { TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4}; Attribute attribute = { - OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC, + OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC, ReformatAttribute::AUTO_PADDING_NHWC}; auto ctx = std::make_unique( std::move(opr_list), std::move(available_tensor_formats), attribute); ctx->add_opr_config( opr::ConvBiasForward::typeinfo(), - {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW32, - OprFormat::NCHW64, OprFormat::CHWN4}) + {OprFormatConfigID::NCHW, OprFormatConfigID::NHWC, + OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, + OprFormatConfigID::NCHW64, OprFormatConfigID::CHWN4}) .add_opr_config( opr::ConvolutionForward::typeinfo(), - {OprFormat::NCHW, OprFormat::NCHW4}) + {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4}) .add_opr_config( opr::ConvolutionBackwardData::typeinfo(), - {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4}) + {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4}) .add_opr_config( opr::PoolingForward::typeinfo(), - {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC, - OprFormat::NCHW64, OprFormat::CHWN4}) + {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, + OprFormatConfigID::NHWC, OprFormatConfigID::NCHW64, + OprFormatConfigID::CHWN4}) .add_opr_config( opr::WarpPerspectiveForward::typeinfo(), - {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64}); + {OprFormatConfigID::NHWC, OprFormatConfigID::NCHW4, + OprFormatConfigID::NCHW64}); #if MGB_WITH_CACHED_TEST auto profiler = std::make_unique( static_cast(TestLayoutTransform_DetectionHead.data()), @@ -765,4 +748,184 @@ TEST(TestLayoutTransform, CanonicalizeLayoutTransform) { MGB_ASSERT_TENSOR_EQ(t1, t2); } +TEST(TestLayoutTransform, Resnet18_F32) { + auto cn = CompNode::load("cpu0"); + + Network network(cn); + auto output = make_resnet18(network, 1); + + HostTensorND t1; + auto func1 = network.graph->compile({make_callback_copy(output, t1)}); + func1->execute(); + + using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; + using OprList = LayoutTransformContext::OprList; + using Target = LayoutTransformContext::Target; + using Attribute = LayoutTransformContext::Attribute; + OprList opr_list = { + opr::ConvBiasForward::typeinfo(), + opr::ConvolutionForward::typeinfo(), + opr::ElemwiseMultiType::typeinfo(), + opr::Elemwise::typeinfo(), + opr::TypeCvt::typeinfo(), + opr::Concat::typeinfo(), + opr::PoolingForward::typeinfo(), + opr::WarpPerspectiveForward::typeinfo(), + opr::Resize::typeinfo(), + }; + SmallVector available_tensor_formats = { + TensorFormats::NCHW, + TensorFormats::NCHWc4, + TensorFormats::NCHWc8, + }; + Attribute attribute = { + OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC}; + auto ctx = std::make_unique( + std::move(opr_list), std::move(available_tensor_formats), attribute); + 
ctx->add_opr_config( + opr::ConvBiasForward::typeinfo(), + { + OprFormatConfigID::NCHW44, + OprFormatConfigID::NCHW, + OprFormatConfigID::NCHW44_HYBRID, + }) + .add_opr_config( + opr::ConvolutionForward::typeinfo(), + { + OprFormatConfigID::NCHW44, + OprFormatConfigID::NCHW, + OprFormatConfigID::NCHW44_HYBRID, + }) + .add_opr_config( + opr::PoolingForward::typeinfo(), { + OprFormatConfigID::NCHW, + OprFormatConfigID::NCHW44, + }); +#if MGB_WITH_CACHED_TEST + auto profiler = std::make_unique( + static_cast(TestLayoutTransform_Resnet18_F32.data()), + TestLayoutTransform_Resnet18_F32.size()); +#else + auto profiler = ProfilerBase::make_cached_profiler( + "TestLayoutTransform.Resnet18_F32.cache"); +#endif + std::unique_ptr solver{ + new DynamicProgrammingSolver(std::move(profiler))}; + auto new_output = + gopt::GraphOptimizer{} + .add_pass() + .add_pass(std::move(ctx), std::move(solver)) + .add_pass() + .add_pass() + .add_pass() + .apply({{output}}) + .endpoint_vars(); + auto new_out_var = new_output[0]; + /// check global layout transform pass + auto nr_dimshuffle = find_opr_num(new_out_var); + ASSERT_EQ(nr_dimshuffle, 1u); + /// check first conv format + const auto& first_conv = find_opr(new_out_var); + const auto& cast = first_conv.cast_final_safe(); + ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW44); + + GraphProfiler gprof{network.graph.get()}; + HostTensorND t2; + auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)}); + func2->execute(); + gprof.to_json_full(func2.get())->writeto_fpath(output_file("resnet18_f32.json")); + /// check correct + MGB_ASSERT_TENSOR_EQ(t1, t2); +} + +TEST(TestLayoutTransform, MobileNetV2) { + auto cn = CompNode::load("cpu0"); + + Network network(cn); + auto output = make_mobilenet_v2(network, 1); + + HostTensorND t1; + auto func1 = network.graph->compile({make_callback_copy(output, t1)}); + func1->execute(); + + using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; + using OprList = LayoutTransformContext::OprList; + using Target = LayoutTransformContext::Target; + using Attribute = LayoutTransformContext::Attribute; + OprList opr_list = { + opr::ConvBiasForward::typeinfo(), + opr::ConvolutionForward::typeinfo(), + opr::ElemwiseMultiType::typeinfo(), + opr::Elemwise::typeinfo(), + opr::TypeCvt::typeinfo(), + opr::Concat::typeinfo(), + opr::PoolingForward::typeinfo(), + opr::WarpPerspectiveForward::typeinfo(), + opr::Resize::typeinfo(), + }; + SmallVector available_tensor_formats = { + TensorFormats::NCHW, + TensorFormats::NCHWc4, + TensorFormats::NCHWc8, + }; + Attribute attribute = { + OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC}; + auto ctx = std::make_unique( + std::move(opr_list), std::move(available_tensor_formats), attribute); + ctx->add_opr_config( + opr::ConvBiasForward::typeinfo(), + { + OprFormatConfigID::NCHW44, + OprFormatConfigID::NCHW, + OprFormatConfigID::NCHW44_HYBRID, + }) + .add_opr_config( + opr::ConvolutionForward::typeinfo(), + { + OprFormatConfigID::NCHW44, + OprFormatConfigID::NCHW, + OprFormatConfigID::NCHW44_HYBRID, + }) + .add_opr_config( + opr::PoolingForward::typeinfo(), { + OprFormatConfigID::NCHW, + OprFormatConfigID::NCHW44, + }); +#if MGB_WITH_CACHED_TEST + auto profiler = std::make_unique( + static_cast(TestLayoutTransform_MobileNetV2_F32.data()), + TestLayoutTransform_MobileNetV2_F32.size()); +#else + auto profiler = ProfilerBase::make_cached_profiler( + "TestLayoutTransform.MobileNetV2_F32.cache"); +#endif + std::unique_ptr solver{ + new 
DynamicProgrammingSolver(std::move(profiler))}; + auto new_output = + gopt::GraphOptimizer{} + .add_pass() + .add_pass(std::move(ctx), std::move(solver)) + .add_pass() + .add_pass() + .add_pass() + .apply({{output}}) + .endpoint_vars(); + auto new_out_var = new_output[0]; + /// check global layout transform pass + auto nr_dimshuffle = find_opr_num(new_out_var); + ASSERT_EQ(nr_dimshuffle, 1u); + /// check first conv format + const auto& first_conv = find_opr(new_out_var); + const auto& cast = first_conv.cast_final_safe(); + ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW44); + + GraphProfiler gprof{network.graph.get()}; + HostTensorND t2; + auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)}); + func2->execute(); + gprof.to_json_full(func2.get())->writeto_fpath(output_file("mobilenet_v2_f32.json")); + /// check correct + MGB_ASSERT_TENSOR_EQ(t1, t2); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/test/network.cpp b/src/gopt/test/network.cpp index 7c264ecc7..8647adf5f 100644 --- a/src/gopt/test/network.cpp +++ b/src/gopt/test/network.cpp @@ -45,6 +45,36 @@ SymbolVar Network::add_conv( return conv; } +SymbolVar Network::add_group_conv( + SymbolVar f, size_t output_channels, size_t groups, KernSize kern_size, + DType out_dtype, bool has_relu, Stride stride, Padding padding) { + static int weight_idx = 0; + static int bias_idx = 0; + + size_t input_channels = f.node()->shape()[1]; + auto weight = add_cvar( + ssprintf("w%d", weight_idx).c_str(), + {groups, output_channels / groups, input_channels / groups, kern_size[0], + kern_size[1]}); + auto bias = add_cvar(ssprintf("b%d", bias_idx).c_str(), {1, output_channels, 1, 1}); + mgb_assert(out_dtype.category() == DTypeCategory::FLOAT); + opr::ConvBias::Param param; + param.sparse = opr::ConvBias::Param::Sparse::GROUP; + param.stride_h = stride[0], param.stride_w = stride[1]; + param.pad_h = padding[0], param.pad_w = padding[1]; + if (has_relu) { + param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; + } else { + param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY; + } + + auto conv = opr::ConvBias::make( + f, weight, bias, param, {}, OperatorNodeConfig{out_dtype}); + weight_idx++; + bias_idx++; + return conv; +} + SymbolVar Network::add_deconv( SymbolVar f, size_t ratio, size_t output_channels, DType out_dtype) { static int weight_idx = 0; @@ -208,6 +238,7 @@ SymbolVarArray fusion_pyramids_feature( false, {1, 1}, {0, 0}); if (!touch) { x = f; + touch = true; } else { x = network.add_deconv(x, 2, 16, dtype::QuantizedS8{1.f}); x = network.add_elemwise( @@ -236,4 +267,63 @@ SymbolVarArray mgb::make_det(Network& network, size_t batch, DType out_dtype) { return outputs; } +SymbolVar mgb::bottleneck( + Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t, + size_t stride) { + size_t in_channels = f.node()->shape()[1]; + SymbolVar x = f; + if (t != 1) { + x = network.add_conv( + f, input_channels * t, {1, 1}, dtype::Float32(), true, {1, 1}, {0, 0}); + } + x = network.add_group_conv( + x, input_channels * t, input_channels * t, {3, 3}, dtype::Float32(), true, + {stride, stride}, {1, 1}); + x = network.add_conv(x, channels, {1, 1}, dtype::Float32(), false, {1, 1}, {0, 0}); + if (stride == 1 && in_channels == channels) + x = f + x; + return x; +} + +SymbolVar mgb::bottleneck_group( + Network& network, SymbolVar f, size_t input_channels, size_t channels, + size_t stages, size_t s, size_t t) { + SymbolVar x = f; + for (size_t i = 
0; i < stages; ++i) { + size_t stride = i == 0 ? s : 1; + x = bottleneck(network, x, input_channels, channels, t, stride); + input_channels = channels; + } + return x; +} + +namespace { +size_t make_divisible(size_t v, size_t divisor) { + size_t min_value = divisor; + size_t new_v = std::max(min_value, (v + divisor / 2) / divisor * divisor); + if (new_v < 0.9 * v) + new_v += divisor; + return new_v; +} +} // namespace + +SymbolVar mgb::make_mobilenet_v2(Network& network, size_t batch) { + auto data = network.add_var("data", {batch, 3, 224, 224}); + constexpr size_t round_nearest = 8; + auto x = network.add_conv( + data, make_divisible(32, round_nearest), {3, 3}, dtype::Float32(), true, + {2, 2}, {1, 1}); + x = bottleneck(network, x, 32, make_divisible(16, round_nearest), 1, 1); + x = bottleneck_group(network, x, 16, make_divisible(24, round_nearest), 2, 2, 6); + x = bottleneck_group(network, x, 24, make_divisible(32, round_nearest), 3, 2, 6); + x = bottleneck_group(network, x, 32, make_divisible(64, round_nearest), 4, 2, 6); + x = bottleneck_group(network, x, 64, make_divisible(96, round_nearest), 3, 1, 6); + x = bottleneck_group(network, x, 96, make_divisible(160, round_nearest), 3, 2, 6); + x = bottleneck_group(network, x, 160, make_divisible(320, round_nearest), 1, 1, 6); + x = network.add_conv( + x, make_divisible(1280, round_nearest), {1, 1}, dtype::Float32(), true, + {1, 1}, {0, 0}); + return x; +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/test/network.h b/src/gopt/test/network.h index a0018e57a..194a02db1 100644 --- a/src/gopt/test/network.h +++ b/src/gopt/test/network.h @@ -28,7 +28,7 @@ namespace mgb { class Network { private: - HostTensorGenerator<> gen; + HostTensorGenerator gen{-0.01, 0.01}; CompNode cn; public: @@ -49,6 +49,10 @@ public: SymbolVar f, size_t output_channels, KernSize kern_size, DType out_dtype = dtype::Float32(), bool has_relu = true, Stride stride = {1, 1}, Padding padding = {0, 0}); + SymbolVar add_group_conv( + SymbolVar f, size_t output_channels, size_t groups, KernSize kern_size, + DType out_dtype = dtype::Float32(), bool has_relu = true, + Stride stride = {1, 1}, Padding padding = {0, 0}); SymbolVar add_deconv( SymbolVar f, size_t ratio, size_t output_channels, DType out_dtype); SymbolVar add_elemwise( @@ -73,6 +77,16 @@ SymbolVar make_resnet18( SymbolVarArray make_det( Network& network, size_t batch = 16, DType out_dtype = dtype::Float32()); +SymbolVar bottleneck( + Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t, + size_t stride); + +SymbolVar bottleneck_group( + Network& network, SymbolVar f, size_t input_channels, size_t channels, + size_t stages, size_t s, size_t t); + +SymbolVar make_mobilenet_v2(Network& network, size_t batch = 1); + } // namespace mgb // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/test/profiler.cpp b/src/gopt/test/profiler.cpp index b32f3d724..465a4a56a 100644 --- a/src/gopt/test/profiler.cpp +++ b/src/gopt/test/profiler.cpp @@ -26,7 +26,7 @@ using namespace serialization; #if MGB_CUDA namespace { std::unique_ptr make_ctx() { - using OprFormat = LayoutTransformContext::OprFormat; + using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; using OprList = LayoutTransformContext::OprList; using Attribute = LayoutTransformContext::Attribute; using Target = LayoutTransformContext::Target; @@ -44,26 +44,29 @@ std::unique_ptr make_ctx() { SmallVector available_tensor_formats = { TensorFormats::NCHW, 
             TensorFormats::NHWC,    TensorFormats::NCHWc4,
             TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4};
-    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, Target::CUDA};
+    Attribute attribute = {OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::CUDA};
     auto ctx = std::make_unique<LayoutTransformContext>(
             std::move(opr_list), std::move(available_tensor_formats), attribute);
     ctx->add_opr_config(
                opr::ConvBiasForward::typeinfo(),
-               {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW32,
-                OprFormat::NCHW64, OprFormat::CHWN4})
+               {OprFormatConfigID::NCHW, OprFormatConfigID::NHWC,
+                OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
+                OprFormatConfigID::NCHW64, OprFormatConfigID::CHWN4})
            .add_opr_config(
                    opr::ConvolutionForward::typeinfo(),
-                   {OprFormat::NCHW, OprFormat::NCHW4})
+                   {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4})
            .add_opr_config(
                    opr::ConvolutionBackwardData::typeinfo(),
-                   {OprFormat::NCHW, OprFormat::NCHW4})
+                   {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4})
            .add_opr_config(
                    opr::PoolingForward::typeinfo(),
-                   {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
-                    OprFormat::NCHW64, OprFormat::CHWN4})
+                   {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
+                    OprFormatConfigID::NHWC, OprFormatConfigID::NCHW64,
+                    OprFormatConfigID::CHWN4})
            .add_opr_config(
                    opr::WarpPerspectiveForward::typeinfo(),
-                   {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
+                   {OprFormatConfigID::NHWC, OprFormatConfigID::NCHW4,
+                    OprFormatConfigID::NCHW64});
     return ctx;
 }
 } // namespace
--
GitLab
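
Usage sketch (not part of the patch): the snippet below shows how the new OprFormatConfigID-based
context is assembled for the CPU nchw_nchwxx hybrid mode, distilled from the Resnet18_F32 and
MobileNetV2 tests added above. The pass/profiler wiring is omitted and the exact option set is an
assumption, so treat it as an illustration of the API rather than code taken from the patch.

    // Inside a test or tuning routine (minimal sketch).
    using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
    using Attribute = LayoutTransformContext::Attribute;
    using Target = LayoutTransformContext::Target;

    LayoutTransformContext::OprList opr_list = {
            opr::ConvBiasForward::typeinfo(),
            opr::ConvolutionForward::typeinfo(),
            opr::PoolingForward::typeinfo(),
    };
    SmallVector<TensorFormats> available_tensor_formats = {
            TensorFormats::NCHW, TensorFormats::NCHWc4};
    // The network to be optimized is constructed in plain NCHW (the base config).
    Attribute attribute = {
            OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC};
    auto ctx = std::make_unique<LayoutTransformContext>(
            std::move(opr_list), std::move(available_tensor_formats), attribute);
    // NCHW44_HYBRID lets a convolution (typically the first one) consume NCHW input
    // while already producing blocked NCHWc4 output, so no leading dimshuffle is needed.
    ctx->add_opr_config(
               opr::ConvBiasForward::typeinfo(),
               {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
                OprFormatConfigID::NCHW44_HYBRID})
            .add_opr_config(
                    opr::PoolingForward::typeinfo(),
                    {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44});
    // Note: the *_HYBRID ids are deliberately placed at FORMAT_NR_MEMBER and beyond, so
    // safe_cast_to_opr_format() asserts on them; code that needs concrete tensor formats
    // for a hybrid config goes through the per-operator dispatcher instead.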