提交 3d45d352 编写于 作者: M Megvii Engine Team

feat(mgb/gopt): profiler support checking algo availability

GitOrigin-RevId: 39cad612ccb0484e5ccc38200293387a131a96d0
上级 5f15f759
......@@ -32,13 +32,8 @@ namespace megdnn {
*/
template <class Opr, typename... Args>
bool has_available_algo(Opr* opr, Args&&... args) {
const typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...);
for (auto i : Opr::algo_pack().all_algos) {
if (i->is_available(size_args)) {
return true;
}
}
return false;
auto&& all_algos = opr->get_all_algorithms_info(std::forward<Args>(args)...);
return !all_algos.empty();
}
} // namespace megdnn
......
......@@ -157,7 +157,6 @@ struct ConvMaker<opr::BatchConvBiasForward>
MakeConvCaller4<megdnn::BatchConvBiasForward>,
megdnn::param::BatchConvBias> {};
#if 0
#include "../../opr/impl/internal/invoke.h"
template <typename Opr>
struct MultiAlgoOprTrait;
......@@ -202,7 +201,6 @@ INST(ConvolutionBackwardData)
INST(PoolingForward)
#undef APPLY
#undef INST
#endif
} // namespace
namespace mgb {
......@@ -291,9 +289,7 @@ VarNode* modify_opr_format(
#undef cb
}
#if 0
bool has_available_algo(const VarNodeArray& i,
const cg::OperatorNodeBase* opr) {
bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr) {
#define cb(_Opr) \
if (opr->dyn_typeinfo() == _Opr::typeinfo()) { \
MGB_MARK_USED_VAR(MultiAlgoOprTrait<_Opr>::has_algo); \
......@@ -301,13 +297,12 @@ bool has_available_algo(const VarNodeArray& i,
_.emplace_back(opr->output(0)); \
return MultiAlgoOprTrait<_Opr>::has_available_algo(_, opr); \
} else
cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData)
cb(PoolingForward) {
mgb_throw(InternalError, "invalid multi-algo operator(got:%s)",
opr->dyn_typeinfo()->name);
cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) cb(PoolingForward) {
mgb_throw(
InternalError, "invalid multi-algo operator(got:%s)",
opr->dyn_typeinfo()->name);
}
}
#endif
} // namespace intl
} // namespace gopt
......
......@@ -21,9 +21,7 @@ namespace intl {
#define FOREACH_FORMAT_AWARE_OPR(cb) \
cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) cb(PoolingForward) \
cb(WarpPerspective) cb(Resize)
#if 0
bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr);
#endif
VarNode* modify_opr_format(
opr::ConvBias::Param::Format opr_format, const VarNodeArray& i,
......
......@@ -43,7 +43,8 @@ static inline size_t extra_alignment(
size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8;
size_t extra_alignment =
alignment_in_bits >= dtype_bits ? alignment_in_bits / dtype_bits : 1;
if (target_formats == TensorFormats::NHWC)
if (target_formats == TensorFormats::NHWC ||
target_formats == TensorFormats::KRSC)
channel_alignment = extra_alignment * channel_alignment /
gcd(channel_alignment, extra_alignment);
return channel_alignment;
......@@ -60,10 +61,12 @@ static inline std::tuple<size_t, size_t> extra_alignment(
size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8;
size_t extra_alignment =
alignment_in_bits >= dtype_bits ? alignment_in_bits / dtype_bits : 1;
if (key.input_format == TensorFormats::NHWC)
if (key.input_format == TensorFormats::NHWC ||
key.input_format == TensorFormats::KRSC)
input_channel_alignment = input_channel_alignment * extra_alignment /
gcd(input_channel_alignment, extra_alignment);
if (key.output_format == TensorFormats::NHWC)
if (key.output_format == TensorFormats::NHWC ||
key.output_format == TensorFormats::KRSC)
output_channel_alignment = output_channel_alignment * extra_alignment /
gcd(output_channel_alignment, extra_alignment);
return std::make_tuple(input_channel_alignment, output_channel_alignment);
......
......@@ -62,6 +62,16 @@ enum class TensorFormats : uint32_t {
KCRS = 24, ///< [K, C, R, S]
GKCRS = 25, ///< [G, K, C, R, S]
C11RS = 26, ///< [C, 1, 1, R, S]
// NHWC
KRSC = 27, /// < [K, R, S, C]
// NCHW32
KCRSc32 = 28, ///<[K, C/32, R, S, C%32]
// NCHW64
KCRSc64 = 29, ///<[K, C/64, R, S, C%64]
// CHWN4
CRSKc4 = 30, ///< [C/4, R, S, K, C%4]
};
class ReformatManager : public NonCopyableObj {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册