diff --git a/dnn/include/megdnn/common.h b/dnn/include/megdnn/common.h index 142b8f4c52991f004da564df8a772b4b8f74c211..f0073c59b84a9e6fcc6b450c56a2053596ab39c0 100644 --- a/dnn/include/megdnn/common.h +++ b/dnn/include/megdnn/common.h @@ -32,13 +32,8 @@ namespace megdnn { */ template bool has_available_algo(Opr* opr, Args&&... args) { - const typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward(args)...); - for (auto i : Opr::algo_pack().all_algos) { - if (i->is_available(size_args)) { - return true; - } - } - return false; + auto&& all_algos = opr->get_all_algorithms_info(std::forward(args)...); + return !all_algos.empty(); } } // namespace megdnn diff --git a/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp b/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp index 1684cb67a65d55d8018dc89b202dfb6435ef4a66..89fe19e58f8d6218dbf2cfb81abc76bd9532ba85 100644 --- a/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp +++ b/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp @@ -157,7 +157,6 @@ struct ConvMaker MakeConvCaller4, megdnn::param::BatchConvBias> {}; -#if 0 #include "../../opr/impl/internal/invoke.h" template struct MultiAlgoOprTrait; @@ -202,7 +201,6 @@ INST(ConvolutionBackwardData) INST(PoolingForward) #undef APPLY #undef INST -#endif } // namespace namespace mgb { @@ -291,9 +289,7 @@ VarNode* modify_opr_format( #undef cb } -#if 0 -bool has_available_algo(const VarNodeArray& i, - const cg::OperatorNodeBase* opr) { +bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr) { #define cb(_Opr) \ if (opr->dyn_typeinfo() == _Opr::typeinfo()) { \ MGB_MARK_USED_VAR(MultiAlgoOprTrait<_Opr>::has_algo); \ @@ -301,13 +297,12 @@ bool has_available_algo(const VarNodeArray& i, _.emplace_back(opr->output(0)); \ return MultiAlgoOprTrait<_Opr>::has_available_algo(_, opr); \ } else - cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) - cb(PoolingForward) { - mgb_throw(InternalError, 
"invalid multi-algo operator(got:%s)", - opr->dyn_typeinfo()->name); + cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) cb(PoolingForward) { + mgb_throw( + InternalError, "invalid multi-algo operator(got:%s)", + opr->dyn_typeinfo()->name); } } -#endif } // namespace intl } // namespace gopt diff --git a/src/gopt/impl/global_layout_transform/opr_format_modifier.h b/src/gopt/impl/global_layout_transform/opr_format_modifier.h index 9ef3700caca4d95fd4677182f292ae1a8b837431..c1d2002859de0593e60e48982420b049bd8b0cfc 100644 --- a/src/gopt/impl/global_layout_transform/opr_format_modifier.h +++ b/src/gopt/impl/global_layout_transform/opr_format_modifier.h @@ -21,9 +21,7 @@ namespace intl { #define FOREACH_FORMAT_AWARE_OPR(cb) \ cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) cb(PoolingForward) \ cb(WarpPerspective) cb(Resize) -#if 0 bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr); -#endif VarNode* modify_opr_format( opr::ConvBias::Param::Format opr_format, const VarNodeArray& i, diff --git a/src/gopt/impl/global_layout_transform/reformat_manager.cpp b/src/gopt/impl/global_layout_transform/reformat_manager.cpp index 7e5c6ef70cd67a94c1fb27d2702066d9bc3c89ef..82b9a8629adab6d22ef8a15ba4242608e5ad309d 100644 --- a/src/gopt/impl/global_layout_transform/reformat_manager.cpp +++ b/src/gopt/impl/global_layout_transform/reformat_manager.cpp @@ -43,7 +43,8 @@ static inline size_t extra_alignment( size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8; size_t extra_alignment = alignment_in_bits >= dtype_bits ? alignment_in_bits / dtype_bits : 1; - if (target_formats == TensorFormats::NHWC) + if (target_formats == TensorFormats::NHWC || + target_formats == TensorFormats::KRSC) channel_alignment = extra_alignment * channel_alignment / gcd(channel_alignment, extra_alignment); return channel_alignment; @@ -60,10 +61,12 @@ static inline std::tuple extra_alignment( size_t dtype_bits = dt.is_low_bit() ? 
dt.low_bit() : dt.size(1) * 8; size_t extra_alignment = alignment_in_bits >= dtype_bits ? alignment_in_bits / dtype_bits : 1; - if (key.input_format == TensorFormats::NHWC) + if (key.input_format == TensorFormats::NHWC || + key.input_format == TensorFormats::KRSC) input_channel_alignment = input_channel_alignment * extra_alignment / gcd(input_channel_alignment, extra_alignment); - if (key.output_format == TensorFormats::NHWC) + if (key.output_format == TensorFormats::NHWC || + key.output_format == TensorFormats::KRSC) output_channel_alignment = output_channel_alignment * extra_alignment / gcd(output_channel_alignment, extra_alignment); return std::make_tuple(input_channel_alignment, output_channel_alignment); diff --git a/src/gopt/include/megbrain/gopt/reformat_manager.h b/src/gopt/include/megbrain/gopt/reformat_manager.h index 0dea76b5a3d4852af154001829229f641e605c18..7426a568da3f6f82cc5287157c689adc5964109b 100644 --- a/src/gopt/include/megbrain/gopt/reformat_manager.h +++ b/src/gopt/include/megbrain/gopt/reformat_manager.h @@ -62,6 +62,16 @@ enum class TensorFormats : uint32_t { KCRS = 24, ///< [K, C, R, S] GKCRS = 25, ///< [G, K, C, R, S] C11RS = 26, ///< [C, 1, 1, R, S] + + // NHWC + KRSC = 27, ///< [K, R, S, C] + + // NCHW32 + KCRSc32 = 28, ///< [K, C/32, R, S, C%32] + // NCHW64 + KCRSc64 = 29, ///< [K, C/64, R, S, C%64] + // CHWN4 + CRSKc4 = 30, ///< [C/4, R, S, K, C%4] }; class ReformatManager : public NonCopyableObj {