Commit a57317f4 authored by Megvii Engine Team

refactor(mgb): move profile cache out of mgb opr and update CACHE_KEY_VERSION

GitOrigin-RevId: db433164d297c9d6438fc45203e57e72df3f9b0b
Parent 278b2baa
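What changes, in a nutshell: each convolution operator used to lazily create and own an `AlgoChooserProfileCache`, guarded by a per-class `init_profile_cache()` override; after this commit the algo chooser builds the cache on demand from the operator's comp node and a name derived from its typeinfo. A minimal standalone sketch of the two ownership shapes, using simplified stand-ins (`ProfileCache`, `CompNode`) rather than the real MegBrain interfaces:

```cpp
#include <memory>
#include <string>

// Simplified stand-ins; AlgoChooserProfileCache and mgb::CompNode are richer.
struct CompNode {};
struct ProfileCache {
    ProfileCache(CompNode, const char* /*category*/) {}
};

// Old shape: mutable cache member on every conv opr, plus a virtual init
// hook that each subclass had to override with its own name string.
struct OprWithOwnedCache {
    std::unique_ptr<ProfileCache> m_profile_cache;
    ProfileCache& profile_cache(CompNode cn) {
        if (!m_profile_cache)
            m_profile_cache =
                    std::make_unique<ProfileCache>(cn, "conv_fwd" "v2");
        return *m_profile_cache;
    }
};

// New shape: no opr-side state; the chooser constructs a cache at the
// point of use from the comp node and a per-opr-type name.
ProfileCache make_cache(CompNode cn, const std::string& profile_name) {
    return ProfileCache(cn, profile_name.c_str());
}
```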
@@ -36,8 +36,6 @@ using namespace opr;
 using namespace cg::static_infer;
 using intl::WorkspaceLimitGetter;
 
-#define CACHE_KEY_VERSION "v2"
-
 /* ==================== misc impl ==================== */
 
 mixin::Convolution::~Convolution() = default;
@@ -103,26 +101,12 @@ void mixin::Convolution::init_output_static_infer_desc_for_bwd_data(
             {SourceType::DEP, inp_deps, infer_wk});
 }
 
-#define IMPL_CONV(_cls, _prof_name)                                   \
-    void _cls::init_profile_cache() {                                 \
-        std::string name(_prof_name CACHE_KEY_VERSION);               \
-        name.append(megdnn_opr()->get_algorithm_set_name());          \
-        m_profile_cache = std::make_unique<AlgoChooserProfileCache>(  \
-                comp_node(), name.c_str());                           \
-    }                                                                 \
+#define IMPL_CONV(_cls)                                               \
     std::pair<const void*, size_t> _cls::param_blob() const {         \
         return {&param(), sizeof(Param)};                             \
     }                                                                 \
     MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls)
 
-AlgoChooserProfileCache& mixin::Convolution::profile_cache() const {
-    if (!m_profile_cache) {
-        const_cast<Convolution*>(this)->init_profile_cache();
-        mgb_assert(m_profile_cache);
-    }
-    return *m_profile_cache;
-}
-
 class mixin::WeightPreprocessExecutor::PreprocessedFilterExecDep final
         : public cg::GraphExecutable::ExecDependency {
     std::unique_ptr<PreprocessedFilter> m_pf;
@@ -209,7 +193,7 @@ bool mixin::WeightPreprocessExecutor::mixin_allow_weight_preprocess(
 
 /* ==================== ConvolutionForward ==================== */
 
-IMPL_CONV(ConvolutionForward, "conv_fwd");
+IMPL_CONV(ConvolutionForward);
 
 ConvolutionForward::ConvolutionForward(VarNode* src, VarNode* filter,
                                        const Param& param,
@@ -335,7 +319,7 @@ void ConvolutionForward::scn_do_execute_preprocess() {
 }
 
 /* ==================== ConvolutionBackwardData ==================== */
-IMPL_CONV(ConvolutionBackwardData, "conv_bwd_data");
+IMPL_CONV(ConvolutionBackwardData);
 
 ConvolutionBackwardData::ConvolutionBackwardData(
         VarNode* filter, VarNode* diff, VarNode* src_for_shp,
@@ -426,7 +410,7 @@ MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) {
 #endif
 
 /* ==================== ConvolutionBackwardFilter ==================== */
-IMPL_CONV(ConvolutionBackwardFilter, "conv_bwd_filter");
+IMPL_CONV(ConvolutionBackwardFilter);
 
 ConvolutionBackwardFilter::ConvolutionBackwardFilter(
         VarNode* src, VarNode* diff, VarNode* filter, const Param& param,
@@ -480,7 +464,7 @@ MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) {
 #endif
 
 /* ==================== Convolution3DForward ==================== */
-IMPL_CONV(Convolution3DForward, "conv3d_fwd");
+IMPL_CONV(Convolution3DForward);
 
 Convolution3DForward::Convolution3DForward(VarNode* src, VarNode* filter,
                                            const Param& param,
@@ -553,7 +537,7 @@ size_t Convolution3DForward::get_workspace_size_bytes(
 }
 
 /* ==================== Convolution3DBackwardData ==================== */
-IMPL_CONV(Convolution3DBackwardData, "conv3d_bwd_data");
+IMPL_CONV(Convolution3DBackwardData);
 
 Convolution3DBackwardData::Convolution3DBackwardData(
         VarNode* filter, VarNode* diff, VarNode* src_for_shp,
@@ -631,7 +615,7 @@ MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) {
 #endif
 
 /* ==================== Convolution3DBackwardFilter ==================== */
-IMPL_CONV(Convolution3DBackwardFilter, "conv3d_bwd_filter");
+IMPL_CONV(Convolution3DBackwardFilter);
 
 Convolution3DBackwardFilter::Convolution3DBackwardFilter(
         VarNode* src, VarNode* diff, VarNode* filter, const Param& param,
@@ -719,7 +703,7 @@ SymbolVar MaskPropagate::make(SymbolVar src, const Param& param,
 }
 
 /* ==================== ConvBiasForward ==================== */
-IMPL_CONV(ConvBiasForward, "conv_bias_fwd");
+IMPL_CONV(ConvBiasForward);
 
 ConvBiasForward::ConvBiasForward(VarNode* src, VarNode* filter,
                                  const Param& param,
@@ -1005,7 +989,7 @@ void ConvBiasForward::scn_do_execute_preprocess() {
 
 /* ===================== LocalShareForward ==================== */
 
-IMPL_CONV(LocalShareForward, "local_share");
+IMPL_CONV(LocalShareForward);
 
 LocalShareForward::LocalShareForward(VarNode* src, VarNode* filter,
                                      const Param& param,
@@ -1073,7 +1057,7 @@ MGB_IMPL_OPR_GRAD(LocalShareForward) {
 
 /* ===================== LocalShareBackwardData ==================== */
 
-IMPL_CONV(LocalShareBackwardData, "local_share_bwd_data");
+IMPL_CONV(LocalShareBackwardData);
 
 LocalShareBackwardData::LocalShareBackwardData(VarNode* filter, VarNode* diff,
                                                VarNode* src_for_shp,
@@ -1153,7 +1137,7 @@ MGB_IMPL_OPR_GRAD(LocalShareBackwardData) {
 
 /* ==================== LocalShareBackwardFilter ==================== */
 
-IMPL_CONV(LocalShareBackwardFilter, "local_share_bwd_filter");
+IMPL_CONV(LocalShareBackwardFilter);
 
 LocalShareBackwardFilter::LocalShareBackwardFilter(
         VarNode* src, VarNode* diff, VarNode* filter, const Param& param,
@@ -1208,7 +1192,7 @@ MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) {
 
 /* ===================== DeformableConvForward ==================== */
 
-IMPL_CONV(DeformableConvForward, "deformable_conv");
+IMPL_CONV(DeformableConvForward);
 
 DeformableConvForward::DeformableConvForward(VarNode* src, VarNode* filter,
                                              VarNode* offset, VarNode* mask,
@@ -1293,7 +1277,7 @@ MGB_IMPL_OPR_GRAD(DeformableConvForward) {
 
 /* ==================== DeformableConvBackwardData ==================== */
 
-IMPL_CONV(DeformableConvBackwardData, "deformalbe_conv_backward_data");
+IMPL_CONV(DeformableConvBackwardData);
 
 DeformableConvBackwardData::DeformableConvBackwardData(
         VarNode* src, VarNode* filter, VarNode* offset, VarNode* mask,
@@ -1425,7 +1409,7 @@ void DeformableConvBackwardData::init_output_static_infer_desc() {
 
 /* ==================== DeformableConvBackwardFilter ==================== */
 
-IMPL_CONV(DeformableConvBackwardFilter, "deformalbe_conv_backward_filter");
+IMPL_CONV(DeformableConvBackwardFilter);
 
 DeformableConvBackwardFilter::DeformableConvBackwardFilter(
         VarNode* src, VarNode* filter, VarNode* offset, VarNode* mask,
@@ -1484,7 +1468,7 @@ size_t DeformableConvBackwardFilter::get_workspace_size_bytes(
 }
 
 /* ==================== BatchConvBiasForward ==================== */
-IMPL_CONV(BatchConvBiasForward, "batch_conv_bias_fwd");
+IMPL_CONV(BatchConvBiasForward);
 
 BatchConvBiasForward::BatchConvBiasForward(VarNode* src, VarNode* filter,
                                            const Param& param,
...
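With profile-cache setup gone, `IMPL_CONV` no longer needs a per-opr profile-name string; it only stamps out `param_blob()` and the RTTI registration. A compilable miniature of the same X-macro pattern (`Param`, `ConvLikeOpr`, and `main` are illustrative stand-ins, not MegBrain code):

```cpp
#include <cstddef>
#include <utility>

struct Param { int pad_h = 0, pad_w = 0; };  // stand-in for megdnn params

// Same shape as the slimmed-down IMPL_CONV: one macro invocation per class
// defines the boilerplate param_blob() accessor out of line.
#define IMPL_PARAM_BLOB(_cls)                                 \
    std::pair<const void*, size_t> _cls::param_blob() const { \
        return {&m_param, sizeof(Param)};                     \
    }

class ConvLikeOpr {
public:
    std::pair<const void*, size_t> param_blob() const;

private:
    Param m_param;
};

IMPL_PARAM_BLOB(ConvLikeOpr)

int main() {
    ConvLikeOpr opr;
    return opr.param_blob().second == sizeof(Param) ? 0 : 1;
}
```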
@@ -36,15 +36,29 @@ using mgb::opr::intl::WorkspaceLimitGetter;
 // timeout delta to be added with fastest known algorithm for new algos
 constexpr double TIMEOUT_TOLERANCE = 2;
 
+#define CACHE_KEY_VERSION "v3"
+
+namespace {
+template <typename Opr>
+std::string profile_name(Opr* opr) {
+    std::string ret =
+            std::string(MegDNNOpr2MGBOpr<Opr>::MGBOpr::typeinfo()->name) +
+            CACHE_KEY_VERSION;
+    ret.append(opr->get_algorithm_set_name());
+    return ret;
+}
+}  // namespace
+
 namespace mgb {
 namespace opr {
 
 template <typename Opr>
 AlgoChooserProfileCache::Result AlgoChooser<Opr>::get_profile_result(
         ExeContext& ctx, bool enable_update) {
-    AlgoChooserProfileCache& cache = ctx.mgb_opr()->profile_cache();
+    AlgoChooserProfileCache cache(ctx.mgb_opr()->comp_node(),
+                                  profile_name(ctx.megdnn_opr()).c_str());
 
-    ConvTensorLayouts origin_layouts = ctx.layouts();
+    TensorLayoutArray origin_layouts = ctx.layouts();
     typename Opr::Param origin_param = ctx.mgb_opr()->param();
     AlgoChooserProfileCache::Key cache_key{origin_layouts.data(),
                                            origin_layouts.size(), &origin_param,
@@ -131,12 +145,12 @@ typename AlgoChooser<Opr>::ImplAlgo AlgoChooser<Opr>::choose_by_profile(
                         "profiling result but not in algo_map; please "
                         "report this "
                         "bug; opr: %s{%s}, shapes: %s %s %s",
+                        i.algo.c_str(),
                         ctx.mgb_opr()->cname(),
                         ctx.mgb_opr()->dyn_typeinfo()->name,
                         ctx.layouts()[0].TensorShape::to_string().c_str(),
                         ctx.layouts()[1].TensorShape::to_string().c_str(),
-                        ctx.layouts()[2].TensorShape::to_string().c_str(),
-                        i.algo.c_str());
+                        ctx.layouts()[2].TensorShape::to_string().c_str());
                 return iter->second;
             }
         }
@@ -153,7 +167,7 @@ typename AlgoChooser<Opr>::ImplAlgo AlgoChooser<Opr>::choose_by_profile(
 }
 
 template <typename Opr>
-size_t AlgoChooser<Opr>::setup_algo(const ConvTensorLayouts& layouts,
+size_t AlgoChooser<Opr>::setup_algo(const TensorLayoutArray& layouts,
                                     Opr* megdnn_opr, const MGBOpr* mgb_opr,
                                     bool allow_weight_preprocess) {
     if (WorkspaceLimitGetter::is_prealloc_run(mgb_opr->owner_graph())) {
@@ -220,7 +234,7 @@ typename AlgoChooser<Opr>::ImplAlgo AlgoChooser<Opr>::get_algo(
             AlgoChooser<megdnn::Opr>::choose_by_profile(                     \
                     ExeContext& ctx, bool require_reproducible, bool enable_update); \
     template size_t AlgoChooser<megdnn::Opr>::setup_algo(                   \
-            const ConvTensorLayouts& layouts, megdnn::Opr* megdnn_opr,      \
+            const TensorLayoutArray& layouts, megdnn::Opr* megdnn_opr,      \
             const MGBOpr* mgb_opr, bool allow_weight_preprocess);
 
 MGB_FOREACH_FASTRUN_OPR(INST)
...
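The per-opr name strings ("conv_fwd", "conv_bwd_data", ...) are replaced by one generic builder: the MGB operator's typeinfo name, then `CACHE_KEY_VERSION`, then megdnn's algorithm-set name. Bumping the version from "v2" to "v3" keys a fresh cache category, so profiling results recorded under the old scheme are never mixed with the new one. A standalone sketch of the composition (the argument values in `main` are illustrative, not taken from the diff):

```cpp
#include <iostream>
#include <string>

#define CACHE_KEY_VERSION "v3"

// Mirrors the anonymous-namespace profile_name() above: the real code takes
// the first component from MGBOpr::typeinfo()->name and the second from
// megdnn's get_algorithm_set_name().
std::string profile_name(const std::string& mgb_opr_type_name,
                         const std::string& algorithm_set_name) {
    std::string ret = mgb_opr_type_name + CACHE_KEY_VERSION;
    ret.append(algorithm_set_name);
    return ret;
}

int main() {
    // One cache category per opr type, key version, and algorithm set.
    std::cout << profile_name("ConvolutionForward", "cudnn") << '\n';
}
```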
@@ -74,12 +74,8 @@ class Convolution {
         mutable bool m_policy_accessed = false;
         ExecutionPolicy m_policy;
 
-        std::unique_ptr<AlgoChooserProfileCache> m_profile_cache;
         AlgoChooserHook m_algo_chooser;
 
-        virtual void init_profile_cache() = 0;
-
         //! init output desc for conv backward data oprs; it handles both grad
         //! usage and deconv usage
         template <class MgbOpr, class MegDNNOpr>
@@ -159,7 +155,6 @@ class ConvolutionTestingPeer;
 MGB_DEFINE_OPR_CLASS(ConvolutionForward,
         intl::ConvolutionForwardBase, public mixin::Convolution) // {
-    void init_profile_cache() override;
     void init_output_dtype() override;
     size_t get_workspace_size_bytes(
             const TensorShapeArray &input_shapes,
@@ -245,7 +240,6 @@ public:
             const ExecutionPolicy& policy = {},
             const OperatorNodeConfig& config = {});
 
-    void init_profile_cache() override;
     std::pair<const void*, size_t> param_blob() const override;
 
     static void check_winograd_param_valid(
@@ -268,7 +262,6 @@ MGB_DEFINE_OPR_CLASS(ConvolutionBackwardData,
     void init_output_format() override;
     void add_input_layout_constraint() override;
-    void init_profile_cache() override;
     void scn_do_execute() override;
     NodeProp *do_make_node_prop() const override;
@@ -310,7 +303,6 @@ MGB_DEFINE_OPR_CLASS(ConvolutionBackwardFilter,
         intl::MegDNNOprWrapperBwd<megdnn::ConvolutionBackwardFilter>,
         public mixin::Convolution ) // {
-    void init_profile_cache() override final;
     size_t get_workspace_size_bytes(
             const TensorShapeArray &input_shapes,
@@ -360,7 +352,6 @@ MGB_DEFINE_OPR_CLASS(Convolution3DForward,
         intl::MegDNNOprWrapperFwd<megdnn::Convolution3DForward>,
         public mixin::Convolution) // {
-    void init_profile_cache() override;
     void init_output_dtype() override;
     size_t get_workspace_size_bytes(
             const TensorShapeArray &input_shapes,
@@ -391,7 +382,6 @@ MGB_DEFINE_OPR_CLASS(Convolution3DBackwardData,
     void init_output_static_infer_desc() override;
     void add_input_layout_constraint() override;
-    void init_profile_cache() override;
     void scn_do_execute() override;
     NodeProp *do_make_node_prop() const override;
@@ -433,8 +423,6 @@ MGB_DEFINE_OPR_CLASS(Convolution3DBackwardFilter,
         intl::MegDNNOprWrapperBwd<megdnn::Convolution3DBackwardFilter>,
         public mixin::Convolution) // {
-    void init_profile_cache() override final;
-
     size_t get_workspace_size_bytes(
             const TensorShapeArray &input_shapes,
             const TensorShapeArray &output_shapes) const override final;
@@ -455,7 +443,6 @@ MGB_DEFINE_OPR_CLASS(Convolution3DBackwardFilter,
 MGB_DEFINE_OPR_CLASS(LocalShareForward,
         intl::MegDNNOprWrapperFwd<megdnn::LocalShareForward>,
         public mixin::Convolution) // {
-    void init_profile_cache() override final;
     void init_output_dtype() override;
     void init_output_format() override;
@@ -483,7 +470,6 @@ MGB_DEFINE_OPR_CLASS(
     void init_output_dtype() override;
     void add_input_layout_constraint() override;
-    void init_profile_cache() override;
     void scn_do_execute() override;
     NodeProp* do_make_node_prop() const override;
@@ -506,7 +492,6 @@ MGB_DEFINE_OPR_CLASS(
         LocalShareBackwardFilter,
         intl::MegDNNOprWrapperBwd<megdnn::LocalShareBackwardFilter>,
         public mixin::Convolution) // {
-    void init_profile_cache() override final;
     size_t get_workspace_size_bytes(
             const TensorShapeArray& input_shapes,
@@ -542,7 +527,6 @@ MGB_DEFINE_OPR_CLASS(DeformableConvForward,
     std::pair<const void*, size_t> param_blob() const override;
 
 private:
-    void init_profile_cache() override;
     void init_output_dtype() override;
     void init_output_format() override;
     size_t get_workspace_size_bytes(
@@ -589,7 +573,6 @@ private:
     void add_input_layout_constraint() override {
         mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
     }
-    void init_profile_cache() override;
 };
 
 MGB_DEFINE_OPR_CLASS(
@@ -612,7 +595,6 @@ public:
     std::pair<const void*, size_t> param_blob() const override;
 
 private:
-    void init_profile_cache() override;
     size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes,
                                     const TensorShapeArray& output_shapes)
             const override final;
@@ -668,7 +650,6 @@ public:
             const ExecutionPolicy& policy = {},
             const OperatorNodeConfig& config = {});
 
-    void init_profile_cache() override;
     std::pair<const void*, size_t> param_blob() const override;
 };
 
 using BatchConvBias = BatchConvBiasForward;
...
@@ -48,16 +48,16 @@ class AlgoChooser {
     using ImplAlgo = typename Opr::AlgorithmInfo;
     using MGBOpr = typename MegDNNOpr2MGBOpr<Opr>::MGBOpr;
-    using ConvTensorLayouts = std::array<TensorLayout, arity>;
+    using TensorLayoutArray = std::array<TensorLayout, arity>;
 
     class ExeContext {
-        const ConvTensorLayouts& m_layouts;
+        const TensorLayoutArray& m_layouts;
         Opr* m_megdnn_opr;
         const MGBOpr* m_mgb_opr;
         bool m_allow_weight_preprocess;
 
     public:
-        ExeContext(const ConvTensorLayouts& layouts, Opr* megdnn_opr,
+        ExeContext(const TensorLayoutArray& layouts, Opr* megdnn_opr,
                    const MGBOpr* mgb_opr, bool allow_weight_preprocess)
                 : m_layouts{layouts},
                   m_megdnn_opr{megdnn_opr},
@@ -65,9 +65,9 @@ class AlgoChooser {
                   m_allow_weight_preprocess{allow_weight_preprocess} {
             mgb_assert(m_layouts.size() == layouts.size());
             static_assert(
-                    std::tuple_size<ConvTensorLayouts>::value == 3 ||
-                            std::tuple_size<ConvTensorLayouts>::value == 5 ||
-                            std::tuple_size<ConvTensorLayouts>::value == 8,
+                    std::tuple_size<TensorLayoutArray>::value == 3 ||
+                            std::tuple_size<TensorLayoutArray>::value == 5 ||
+                            std::tuple_size<TensorLayoutArray>::value == 8,
                     "Convolution AlgoChooser assumes arity = 3 , 5 or 8 (for "
                     "deformable conv)");
         }
@@ -80,7 +80,7 @@ class AlgoChooser {
             return m_layouts[idx];
         }
 
-        const ConvTensorLayouts& layouts() const { return m_layouts; }
+        const TensorLayoutArray& layouts() const { return m_layouts; }
 
         ImplAlgo choose_by_heuristic(bool reproducible = false) const;
@@ -125,7 +125,7 @@ public:
     /*!
     * \brief setup algorithm and return workspace size
     */
-    static size_t setup_algo(const ConvTensorLayouts& layouts, Opr* megdnn_opr,
+    static size_t setup_algo(const TensorLayoutArray& layouts, Opr* megdnn_opr,
                              const MGBOpr* mgb_opr,
                              bool allow_weight_preprocess = false);
 };
...
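The header side of the rename: `ConvTensorLayouts` becomes `TensorLayoutArray`, still a fixed-arity `std::array` whose size is checked at compile time. A standalone sketch of the alias and the arity `static_assert` (`TensorLayout` and `ExeContextSketch` are stand-ins; the real `ExeContext` holds a const reference and more state):

```cpp
#include <array>
#include <cstddef>

struct TensorLayout { std::size_t ndim = 0; };  // stand-in for megdnn::TensorLayout

template <std::size_t arity>
class ExeContextSketch {
public:
    using TensorLayoutArray = std::array<TensorLayout, arity>;

    explicit ExeContextSketch(const TensorLayoutArray& layouts)
            : m_layouts{layouts} {
        static_assert(std::tuple_size<TensorLayoutArray>::value == 3 ||
                              std::tuple_size<TensorLayoutArray>::value == 5 ||
                              std::tuple_size<TensorLayoutArray>::value == 8,
                      "Convolution AlgoChooser assumes arity = 3, 5 or 8 "
                      "(for deformable conv)");
    }

    const TensorLayoutArray& layouts() const { return m_layouts; }

private:
    TensorLayoutArray m_layouts;  // the real ExeContext stores a const reference
};

int main() {
    ExeContextSketch<3>::TensorLayoutArray layouts{};  // e.g. src, filter, dst
    ExeContextSketch<3> ctx{layouts};
    return ctx.layouts().size() == 3 ? 0 : 1;
}
```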