Commit d195fdec authored by Megvii Engine Team

refactor(mgb): refactor has-usable-algo function for global optimizer

GitOrigin-RevId: 66105166505ba81d2ef984bbc8085af5333a1c9b
Parent: d6e50b2c
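The refactored helper now constructs `SizeArgs` internally from a forwarded argument pack, so each call site passes the sub-operator pointer and tensor layouts directly instead of assembling an `AlgoBase::SizeArgs` by hand. Below is a minimal self-contained sketch of this calling convention, using hypothetical mock types (MockOpr, MockLayout, SmallAlgo) rather than real megdnn operators:

#include <cstddef>
#include <utility>
#include <vector>

// Same shape as the helper introduced by this commit: SizeArgs is built
// once inside the function from the forwarded arguments.
template <class Opr, typename... Args>
bool has_available_algo(Opr* opr, Args&&... args) {
    const typename Opr::AlgoBase::SizeArgs size_args(
            opr, std::forward<Args>(args)...);
    for (auto i : Opr::algo_pack().all_algos) {
        if (i->is_available(size_args)) {
            return true;
        }
    }
    return false;
}

// Hypothetical layout/operator types exposing the minimal surface the
// template relies on: a nested AlgoBase::SizeArgs constructible from
// (opr, layouts...) and a static algo_pack() whose all_algos is iterable.
struct MockLayout {
    std::size_t nr_elems;
};

struct MockOpr {
    struct AlgoBase {
        struct SizeArgs {
            MockOpr* opr;
            MockLayout src, dst;
            SizeArgs(MockOpr* o, const MockLayout& s, const MockLayout& d)
                    : opr(o), src(s), dst(d) {}
        };
        virtual ~AlgoBase() = default;
        virtual bool is_available(const SizeArgs& args) const = 0;
    };

    // A toy algorithm that only accepts small problems.
    struct SmallAlgo final : AlgoBase {
        bool is_available(const SizeArgs& args) const override {
            return args.src.nr_elems <= 1024;
        }
    };

    struct AlgoPack {
        SmallAlgo small;
        std::vector<AlgoBase*> all_algos{&small};
    };

    static const AlgoPack& algo_pack() {
        static AlgoPack pack;
        return pack;
    }
};

int main() {
    MockOpr opr;
    // Call sites pass raw layouts; no SizeArgs boilerplate at the caller.
    bool small_ok = has_available_algo(&opr, MockLayout{64}, MockLayout{64});
    bool large_ok =
            has_available_algo(&opr, MockLayout{1 << 20}, MockLayout{1 << 20});
    return (small_ok && !large_ok) ? 0 : 1;
}

Centralizing SizeArgs construction in one place is why each of the seven call sites in the diff below collapses to a single function call.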
@@ -24,5 +24,24 @@
 #define setenv(name,value,overwrite) _putenv_s(name,value)
 #endif
 
+namespace megdnn {
+
+/*!
+ * \brief whether there is an algorithm from algo_pack() that is available for
+ *        current size
+ */
+template <class Opr, typename... Args>
+bool has_available_algo(Opr* opr, Args&&... args) {
+    const typename Opr::AlgoBase::SizeArgs size_args(
+            opr, std::forward<Args>(args)...);
+    for (auto i : Opr::algo_pack().all_algos) {
+        if (i->is_available(size_args)) {
+            return true;
+        }
+    }
+    return false;
+}
+
+}  // namespace megdnn
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -16,6 +16,7 @@
 #include <utility>
 #include <vector>
 
+#include "megdnn/common.h"
 #include "utils.h"
 
 namespace megdnn {
@@ -74,21 +75,6 @@ std::vector<typename Opr::Algorithm*> get_all_algorithms(
     return ret;
 }
 
-/*!
- * \brief whether there is an algorithm from algo_pack() that is available for
- *        current size
- */
-template <class Opr>
-bool has_available_algo(
-        const typename Opr::AlgoBase::SizeArgs& args) {
-    for (auto i : Opr::algo_pack().all_algos) {
-        if (i->is_available(args)) {
-            return true;
-        }
-    }
-    return false;
-}
-
 /*!
  * \brief a helper function to get an algorithm match attribute. If require a
  *        algorithm with specified attribute, and the given algorithm match that
...
@@ -134,13 +134,6 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
     auto config = prepare_sub_opr(args);
 
-    AlgoBase::SizeArgs sub_args{
-            static_cast<ConvBiasForwardImpl*>(config.second.get()),
-            config.first[0],
-            config.first[1],
-            config.first[2],
-            config.first[3],
-            config.first[4]};
     bool is_relayout_ok = true;
     if (args.dst_layout->dtype.enumv() != DTypeEnum::Float32) {
         is_relayout_ok = relayout_format::RelayoutFormatFast::usable(
@@ -148,7 +141,11 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
                 RelayoutFormat::Param::Mode::NCHW4_NCHW);
     }
-    return is_relayout_ok && has_available_algo<ConvBiasForwardImpl>(sub_args);
+    return is_relayout_ok &&
+           has_available_algo<ConvBiasForwardImpl>(
+                   static_cast<ConvBiasForwardImpl*>(config.second.get()),
+                   config.first[0], config.first[1], config.first[2],
+                   config.first[3], config.first[4]);
 }
 
 WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_workspace_bundle(
...
@@ -107,16 +107,12 @@ bool ConvBiasForwardImpl::AlgoGroupConvGeneral::is_available(
     auto conv_args = args;
     conv_args.dst_layout = &dst_layout;
     auto config = prepare_sub_opr(conv_args);
-    AlgoBase::SizeArgs sub_args{
-            static_cast<ConvBiasForwardImpl*>(config.second.get()),
-            config.first[0],
-            config.first[1],
-            config.first[2],
-            config.first[3],
-            config.first[4]};
-    bool ret = has_available_algo<ConvBiasForwardImpl>(sub_args);
+    bool ret = has_available_algo<ConvBiasForwardImpl>(
+            static_cast<ConvBiasForwardImpl*>(config.second.get()),
+            config.first[0], config.first[1], config.first[2], config.first[3],
+            config.first[4]);
     return ret;
 }
 
 WorkspaceBundle ConvBiasForwardImpl::AlgoGroupConvGeneral::get_workspace_bundle(
...
@@ -82,11 +82,10 @@ bool ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::is_available(
     }
 
     auto config = prepare_sub_opr(args);
-    AlgoBase::SizeArgs sub_args{
-            static_cast<ConvolutionBackwardDataImpl*>(config.second.get()),
-            config.first[0], config.first[1], config.first[2]};
-    return has_available_algo<ConvolutionBackwardDataImpl>(sub_args);
+    return has_available_algo<ConvolutionBackwardDataImpl>(
+            static_cast<ConvolutionBackwardDataImpl*>(config.second.get()),
+            config.first[0], config.first[1], config.first[2]);
 }
 
 WorkspaceBundle
...
@@ -78,11 +78,10 @@ bool ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::is_available(
     }
 
     auto config = prepare_sub_opr(args);
-    AlgoBase::SizeArgs sub_args{
-            static_cast<ConvolutionBackwardFilterImpl*>(config.second.get()),
-            config.first[0], config.first[1], config.first[2]};
-    return has_available_algo<ConvolutionBackwardFilterImpl>(sub_args);
+    return has_available_algo<ConvolutionBackwardFilterImpl>(
+            static_cast<ConvolutionBackwardFilterImpl*>(config.second.get()),
+            config.first[0], config.first[1], config.first[2]);
 }
 
 WorkspaceBundle
...
@@ -73,11 +73,10 @@ bool Convolution3DBackwardDataImpl::AlgoGroupConvGeneral::is_available(
     }
 
    auto config = prepare_sub_opr(args);
-    AlgoBase::SizeArgs sub_args{
-            static_cast<Convolution3DBackwardDataImpl*>(config.second.get()),
-            config.first[0], config.first[1], config.first[2]};
-    return has_available_algo<Convolution3DBackwardDataImpl>(sub_args);
+    return has_available_algo<Convolution3DBackwardDataImpl>(
+            static_cast<Convolution3DBackwardDataImpl*>(config.second.get()),
+            config.first[0], config.first[1], config.first[2]);
 }
 
 WorkspaceBundle
...
@@ -77,11 +77,10 @@ bool Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral::is_available(
     }
 
     auto config = prepare_sub_opr(args);
-    AlgoBase::SizeArgs sub_args{
-            static_cast<Convolution3DBackwardFilterImpl*>(config.second.get()),
-            config.first[0], config.first[1], config.first[2]};
-    return has_available_algo<Convolution3DBackwardFilterImpl>(sub_args);
+    return has_available_algo<Convolution3DBackwardFilterImpl>(
+            static_cast<Convolution3DBackwardFilterImpl*>(config.second.get()),
+            config.first[0], config.first[1], config.first[2]);
 }
 
 WorkspaceBundle
...
@@ -80,11 +80,10 @@ bool Convolution3DForwardImpl::AlgoGroupConvGeneral::is_available(
     }
 
     auto config = prepare_sub_opr(args);
-    AlgoBase::SizeArgs sub_args{
-            static_cast<Convolution3DForwardImpl*>(config.second.get()),
-            config.first[0], config.first[1], config.first[2]};
-    return has_available_algo<Convolution3DForwardImpl>(sub_args);
+    return has_available_algo<Convolution3DForwardImpl>(
+            static_cast<Convolution3DForwardImpl*>(config.second.get()),
+            config.first[0], config.first[1], config.first[2]);
 }
 
 WorkspaceBundle
...