提交 a33c3b73 编写于 作者: M Megvii Engine Team 提交者: huangxinda

refactor(mgb/dnn): arm pooling rebase algochooser

GitOrigin-RevId: 21d17e647afdc349929ebc668639406e088e3c68
上级 8dea6b3c
...@@ -28,6 +28,7 @@ public: ...@@ -28,6 +28,7 @@ public:
const char* name() const override { return "ARM_POOLING_STRIDE1"; } const char* name() const override { return "ARM_POOLING_STRIDE1"; }
bool usable(const PoolingKernSizeParam& param) const override; bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override; void exec(const PoolingKernParam& param) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_FilterxModexStride1)
}; };
class PoolingImpl::AlgoFilter2ModexStride2 final : public AlgoBase { class PoolingImpl::AlgoFilter2ModexStride2 final : public AlgoBase {
...@@ -38,6 +39,7 @@ public: ...@@ -38,6 +39,7 @@ public:
const char* name() const override { return "ARM_POOLING_STRIDE2"; } const char* name() const override { return "ARM_POOLING_STRIDE2"; }
bool usable(const PoolingKernSizeParam& param) const override; bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override; void exec(const PoolingKernParam& param) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_Filter2ModexStride2)
}; };
class PoolingImpl::AlgoFilter3MaxStride2 final : public AlgoBase { class PoolingImpl::AlgoFilter3MaxStride2 final : public AlgoBase {
public: public:
...@@ -47,6 +49,7 @@ public: ...@@ -47,6 +49,7 @@ public:
const char* name() const override { return "ARM_POOLING_FILTER3_MAX"; } const char* name() const override { return "ARM_POOLING_FILTER3_MAX"; }
bool usable(const PoolingKernSizeParam& param) const override; bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override; void exec(const PoolingKernParam& param) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_Filter3MaxStride2)
}; };
class PoolingImpl::AlgoFilter3AverageStride2 final : public AlgoBase { class PoolingImpl::AlgoFilter3AverageStride2 final : public AlgoBase {
...@@ -57,6 +60,7 @@ public: ...@@ -57,6 +60,7 @@ public:
const char* name() const override { return "ARM_POOLING_FILTER3_AVERAGE"; } const char* name() const override { return "ARM_POOLING_FILTER3_AVERAGE"; }
bool usable(const PoolingKernSizeParam& param) const override; bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override; void exec(const PoolingKernParam& param) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_Filter3AverageStride2)
}; };
class PoolingImpl::AlgoFilter4MaxStride2 final : public AlgoBase { class PoolingImpl::AlgoFilter4MaxStride2 final : public AlgoBase {
...@@ -67,6 +71,7 @@ public: ...@@ -67,6 +71,7 @@ public:
const char* name() const override { return "ARM_POOLING_FILTER4_MAX"; } const char* name() const override { return "ARM_POOLING_FILTER4_MAX"; }
bool usable(const PoolingKernSizeParam& param) const override; bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override; void exec(const PoolingKernParam& param) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_Filter4MaxStride2)
}; };
class PoolingImpl::AlgoFilter5MaxStride2 final : public AlgoBase { class PoolingImpl::AlgoFilter5MaxStride2 final : public AlgoBase {
...@@ -77,6 +82,7 @@ public: ...@@ -77,6 +82,7 @@ public:
const char* name() const override { return "ARM_POOLING_FILTER5_MAX"; } const char* name() const override { return "ARM_POOLING_FILTER5_MAX"; }
bool usable(const PoolingKernSizeParam& param) const override; bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override; void exec(const PoolingKernParam& param) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_Filter5MaxStride2)
}; };
class PoolingImpl::AlgoInt8Filter2MaxStride2 final : public AlgoBase { class PoolingImpl::AlgoInt8Filter2MaxStride2 final : public AlgoBase {
...@@ -87,6 +93,7 @@ public: ...@@ -87,6 +93,7 @@ public:
const char* name() const override { return "ARM_POOLING_INT8_FILTER2X2"; } const char* name() const override { return "ARM_POOLING_INT8_FILTER2X2"; }
bool usable(const PoolingKernSizeParam& param) const override; bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override; void exec(const PoolingKernParam& param) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_Int8Filter2MaxStride2)
}; };
class PoolingImpl::AlgoInt8Filter3MaxStride2 final : public AlgoBase { class PoolingImpl::AlgoInt8Filter3MaxStride2 final : public AlgoBase {
...@@ -97,6 +104,7 @@ public: ...@@ -97,6 +104,7 @@ public:
const char* name() const override { return "ARM_POOLING_INT8_FILTER3X3"; } const char* name() const override { return "ARM_POOLING_INT8_FILTER3X3"; }
bool usable(const PoolingKernSizeParam& param) const override; bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override; void exec(const PoolingKernParam& param) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_Int8Filter3MaxStride2)
}; };
class PoolingImpl::AlgoFilter3ModexStridexNCHW44 final : public AlgoBase { class PoolingImpl::AlgoFilter3ModexStridexNCHW44 final : public AlgoBase {
...@@ -107,6 +115,7 @@ public: ...@@ -107,6 +115,7 @@ public:
const char* name() const override { return "ARM_POOLING_FILTER3_MODEX_STRIDEX_NCHW44"; } const char* name() const override { return "ARM_POOLING_FILTER3_MODEX_STRIDEX_NCHW44"; }
bool usable(const PoolingKernSizeParam& param) const override; bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override; void exec(const PoolingKernParam& param) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_Filter3ModexStridexNCHW44)
}; };
class PoolingImpl::AlgoFilter2ModexStridexNCHW44 final : public AlgoBase { class PoolingImpl::AlgoFilter2ModexStridexNCHW44 final : public AlgoBase {
...@@ -117,6 +126,7 @@ public: ...@@ -117,6 +126,7 @@ public:
const char* name() const override { return "ARM_POOLING_FILTER2_MODEX_STRIDEX_NCHW44"; } const char* name() const override { return "ARM_POOLING_FILTER2_MODEX_STRIDEX_NCHW44"; }
bool usable(const PoolingKernSizeParam& param) const override; bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override; void exec(const PoolingKernParam& param) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_Filter2ModexStridexNCHW44)
}; };
class PoolingImpl::AlgoFilter4ModexStridexNCHW44 final : public AlgoBase { class PoolingImpl::AlgoFilter4ModexStridexNCHW44 final : public AlgoBase {
...@@ -127,6 +137,7 @@ public: ...@@ -127,6 +137,7 @@ public:
const char* name() const override { return "ARM_POOLING_FILTER4_MODEX_STRIDEX_NCHW44"; } const char* name() const override { return "ARM_POOLING_FILTER4_MODEX_STRIDEX_NCHW44"; }
bool usable(const PoolingKernSizeParam& param) const override; bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override; void exec(const PoolingKernParam& param) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_Filter4ModexStridexNCHW44)
}; };
class PoolingImpl::AlgoFilter5ModexStridexNCHW44 final : public AlgoBase { class PoolingImpl::AlgoFilter5ModexStridexNCHW44 final : public AlgoBase {
...@@ -137,6 +148,7 @@ public: ...@@ -137,6 +148,7 @@ public:
const char* name() const override { return "ARM_POOLING_FILTER5_MODEX_STRIDEX_NCHW44"; } const char* name() const override { return "ARM_POOLING_FILTER5_MODEX_STRIDEX_NCHW44"; }
bool usable(const PoolingKernSizeParam& param) const override; bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override; void exec(const PoolingKernParam& param) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_Filter5ModexStridexNCHW44)
}; };
class PoolingImpl::AlgoFp32ModexStridexNCHW44 final : public AlgoBase { class PoolingImpl::AlgoFp32ModexStridexNCHW44 final : public AlgoBase {
public: public:
...@@ -146,6 +158,17 @@ public: ...@@ -146,6 +158,17 @@ public:
const char* name() const override { return "ARM_POOLING_FP32_MODEX_STRIDEX_NCHW44"; } const char* name() const override { return "ARM_POOLING_FP32_MODEX_STRIDEX_NCHW44"; }
bool usable(const PoolingKernSizeParam& param) const override; bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override; void exec(const PoolingKernParam& param) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_Fp32ModexStridexNCHW44)
};
class PoolingImpl::AlgoFallback final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
};
const char* name() const override { return "FALLBACK_POOLING"; }
bool usable(const PoolingKernSizeParam&) const override { return true; }
void exec(const PoolingKernParam&) const override {}
MEGDNN_DECL_ALGO_TYPE(ARM_Fallback)
}; };
WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param); WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param);
......
...@@ -12,11 +12,14 @@ ...@@ -12,11 +12,14 @@
#include "src/arm_common/pooling/opr_impl.h" #include "src/arm_common/pooling/opr_impl.h"
#include "src/arm_common/pooling/algo.h" #include "src/arm_common/pooling/algo.h"
#include "src/common/metahelper.h" #include "src/common/metahelper.h"
#include "src/common/algo_chooser.h"
using namespace megdnn; using namespace megdnn;
using namespace arm_common; using namespace arm_common;
class PoolingImpl::AlgoPack : NonCopyableObj { class PoolingImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;
AlgoFilterxModexStride1 algo_filterx_modex_stride1; AlgoFilterxModexStride1 algo_filterx_modex_stride1;
AlgoFilter2ModexStride2 algo_filter2_modex_stride2; AlgoFilter2ModexStride2 algo_filter2_modex_stride2;
AlgoFilter3MaxStride2 algo_filter3_max_stride2; AlgoFilter3MaxStride2 algo_filter3_max_stride2;
...@@ -30,6 +33,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj { ...@@ -30,6 +33,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj {
AlgoFilter4ModexStridexNCHW44 algo_filter4_modex_stridex_nchw4; AlgoFilter4ModexStridexNCHW44 algo_filter4_modex_stridex_nchw4;
AlgoFilter5ModexStridexNCHW44 algo_filter5_modex_stridex_nchw4; AlgoFilter5ModexStridexNCHW44 algo_filter5_modex_stridex_nchw4;
AlgoFp32ModexStridexNCHW44 algo_fp32_modex_stridex_nchw44; AlgoFp32ModexStridexNCHW44 algo_fp32_modex_stridex_nchw44;
AlgoFallback algo_fallback;
public: public:
AlgoPack() { AlgoPack() {
...@@ -46,10 +50,18 @@ public: ...@@ -46,10 +50,18 @@ public:
all_algos.emplace_back(&algo_filter4_modex_stridex_nchw4); all_algos.emplace_back(&algo_filter4_modex_stridex_nchw4);
all_algos.emplace_back(&algo_filter5_modex_stridex_nchw4); all_algos.emplace_back(&algo_filter5_modex_stridex_nchw4);
all_algos.emplace_back(&algo_fp32_modex_stridex_nchw44); all_algos.emplace_back(&algo_fp32_modex_stridex_nchw44);
all_algos.emplace_back(&algo_fallback);
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }
SmallVector<AlgoBase*> all_algos; SmallVector<AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };
PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack;
PoolingImpl::PoolingKernSizeParam PoolingImpl::make_pooling_kern_szie_param( PoolingImpl::PoolingKernSizeParam PoolingImpl::make_pooling_kern_szie_param(
fallback::PoolingImpl* opr, const TensorLayout& src, fallback::PoolingImpl* opr, const TensorLayout& src,
const TensorLayout& dst) { const TensorLayout& dst) {
...@@ -89,44 +101,36 @@ PoolingImpl::PoolingKernParam PoolingImpl::make_pooling_kern_param( ...@@ -89,44 +101,36 @@ PoolingImpl::PoolingKernParam PoolingImpl::make_pooling_kern_param(
size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src, size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) { const TensorLayout& dst) {
bool find_algo = false;
static AlgoPack m_algo_pack;
auto param = make_pooling_kern_szie_param(this, src, dst); auto param = make_pooling_kern_szie_param(this, src, dst);
for (auto& m_algo : m_algo_pack.all_algos) { auto algo = get_algorithm(this, src, dst);
if (m_algo->usable(param)) { if (!is_fallback_algo(algo)) {
find_algo = true; size_t arm_common_workspace = 0;
break;
}
}
size_t arm_common_workspace = 0;
//! When multi-thread, every thread has its own workspace
size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
->megcore_dispatcher()
->nr_threads();
if ((param.src_type.category() == DTypeCategory::FLOAT ||
param.src_type == dtype::Int8{} ||
param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Quantized8Asymm) &&
param.filter[0] == param.filter[1] &&
(param.filter[0] == 3 || param.filter[0] == 5) &&
param.format == Param::Format::NCHW &&
(param.mode == Mode::MAX ||
(param.mode == Mode::AVERAGE && param.filter[0] == 3)) &&
param.stride[0] == 2 && param.stride[1] == 2 && param.isz[0] >= 2 &&
param.isz[1] >= 2) {
WorkspaceBundle ws = get_bundle(param);
arm_common_workspace = ws.total_size_in_bytes() * nr_threads;
}
if ((param.src_type.enumv() == DTypeEnum::QuantizedS8 || //! When multi-thread, every thread has its own workspace
param.src_type.enumv() == DTypeEnum::Int8) && size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
(param.format == param::Pooling::Format::NCHW44)) { ->megcore_dispatcher()
WorkspaceBundle ws = get_bundle_nchw44(param); ->nr_threads();
arm_common_workspace = ws.total_size_in_bytes() * nr_threads; if ((param.src_type.category() == DTypeCategory::FLOAT ||
} param.src_type == dtype::Int8{} ||
param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Quantized8Asymm) &&
param.filter[0] == param.filter[1] &&
(param.filter[0] == 3 || param.filter[0] == 5) &&
param.format == Param::Format::NCHW &&
(param.mode == Mode::MAX ||
(param.mode == Mode::AVERAGE && param.filter[0] == 3)) &&
param.stride[0] == 2 && param.stride[1] == 2 && param.isz[0] >= 2 &&
param.isz[1] >= 2) {
WorkspaceBundle ws = get_bundle(param);
arm_common_workspace = ws.total_size_in_bytes() * nr_threads;
}
if (find_algo) { if ((param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
(param.format == param::Pooling::Format::NCHW44)) {
WorkspaceBundle ws = get_bundle_nchw44(param);
arm_common_workspace = ws.total_size_in_bytes() * nr_threads;
}
return arm_common_workspace; return arm_common_workspace;
} else { } else {
auto fallback_worksapce = auto fallback_worksapce =
...@@ -139,14 +143,48 @@ void PoolingImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, ...@@ -139,14 +143,48 @@ void PoolingImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) { _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size); check_exec(src.layout, dst.layout, workspace.size);
auto param = make_pooling_kern_param(this, src, dst, workspace); auto param = make_pooling_kern_param(this, src, dst, workspace);
static AlgoPack m_algo_pack; auto algo = get_algorithm(this, src.layout, dst.layout);
for (auto& m_algo : m_algo_pack.all_algos) { if (!is_fallback_algo(algo)) {
if (m_algo->usable(param)) { algo->exec(param);
m_algo->exec(param); } else {
return; fallback::PoolingImpl::exec(src, dst, workspace);
}
}
MEGDNN_DEF_GET_ALGO_FROM_DESC(PoolingImpl);
std::vector<Algorithm*> PoolingImpl::get_all_algorithms(
const TensorLayout& src, const TensorLayout& dst) {
auto param = make_pooling_kern_szie_param(this, src, dst);
std::vector<Algorithm*> ret;
ret.reserve(algo_pack().all_algos.size());
for (auto i : algo_pack().all_algos) {
if (i->usable(param)) {
ret.push_back(i);
}
}
megdnn_assert(!ret.empty(), "no usable pooling fwd algorithm");
return ret;
}
Algorithm* PoolingImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes);
auto param = make_pooling_kern_szie_param(this, src, dst);
for (auto&& iter : sm_algo_pack.all_algos) {
if (iter->is_available_attribute(param, positive_attr, negative_attr)) {
return iter;
} }
} }
fallback::PoolingImpl::exec(src, dst, workspace); megdnn_throw(
ssprintf("require algorithm with attribute(%s) and without "
"attribute(%s), but can't get suitable algo.\n",
Algorithm::attribute_str(positive_attr).c_str(),
Algorithm::attribute_str(negative_attr).c_str()));
return nullptr;
} }
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -12,11 +12,30 @@ ...@@ -12,11 +12,30 @@
#pragma once #pragma once
#include "megdnn/oprs/base.h" #include "megdnn/oprs/base.h"
#include "src/fallback/pooling/opr_impl.h" #include "src/fallback/pooling/opr_impl.h"
#include <unordered_map>
namespace megdnn { namespace megdnn {
namespace arm_common { namespace arm_common {
class PoolingImpl final : public fallback::PoolingImpl { class PoolingImpl final : public fallback::PoolingImpl {
private:
class AlgoFilterxModexStride1;
class AlgoFilter2ModexStride2;
class AlgoFilter3MaxStride2;
class AlgoFilter3AverageStride2;
class AlgoFilter4MaxStride2;
class AlgoFilter5MaxStride2;
class AlgoInt8Filter2MaxStride2;
class AlgoInt8Filter3MaxStride2;
class AlgoFilter2ModexStridexNCHW44;
class AlgoFilter3ModexStridexNCHW44;
class AlgoFilter4ModexStridexNCHW44;
class AlgoFilter5ModexStridexNCHW44;
class AlgoFp32ModexStridexNCHW44;
class AlgoFallback;
class AlgoPack;
static AlgoPack sm_algo_pack;
public: public:
using fallback::PoolingImpl::PoolingImpl; using fallback::PoolingImpl::PoolingImpl;
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
...@@ -70,28 +89,68 @@ public: ...@@ -70,28 +89,68 @@ public:
_megdnn_workspace workspace); _megdnn_workspace workspace);
class AlgoBase : public detail::Algorithm { class AlgoBase : public detail::Algorithm {
public: public:
enum class AlgoType : uint32_t {
ARM_FilterxModexStride1,
ARM_Filter2ModexStride2,
ARM_Filter3MaxStride2,
ARM_Filter3AverageStride2,
ARM_Filter4MaxStride2,
ARM_Filter5MaxStride2,
ARM_Int8Filter2MaxStride2,
ARM_Int8Filter3MaxStride2,
ARM_Filter2ModexStridexNCHW44,
ARM_Filter3ModexStridexNCHW44,
ARM_Filter4ModexStridexNCHW44,
ARM_Filter5ModexStridexNCHW44,
ARM_Fp32ModexStridexNCHW44,
ARM_Fallback
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ARM_COMMON; }
virtual ~AlgoBase() = default; virtual ~AlgoBase() = default;
virtual bool usable(const PoolingKernSizeParam& param) const = 0; virtual bool usable(const PoolingKernSizeParam& param) const = 0;
virtual void exec(const PoolingKernParam& param) const = 0; virtual void exec(const PoolingKernParam& param) const = 0;
uint32_t type() const override { return INVALID_ALGO_TYPE; }; uint32_t type() const override { return INVALID_ALGO_TYPE; };
bool is_available_attribute(
const PoolingKernSizeParam& param,
const AlgoAttribute& positive_attr =
AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) && usable(param);
}
}; };
private: const char* get_algorithm_set_name() const override {
class AlgoFilterxModexStride1; return "ARM_POOLING_FORWARD";
class AlgoFilter2ModexStride2; }
class AlgoFilter3MaxStride2;
class AlgoFilter3AverageStride2; Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
class AlgoFilter4MaxStride2;
class AlgoFilter5MaxStride2; std::vector<Algorithm*> get_all_algorithms(
class AlgoInt8Filter2MaxStride2; const TensorLayout& src, const TensorLayout& dst) override;
class AlgoInt8Filter3MaxStride2;
class AlgoFilter2ModexStridexNCHW44; Algorithm* get_algorithm_heuristic(
class AlgoFilter3ModexStridexNCHW44; const TensorLayout& src, const TensorLayout& dst,
class AlgoFilter4ModexStridexNCHW44; size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
class AlgoFilter5ModexStridexNCHW44; const AlgoAttribute& negative_attr) override;
class AlgoFp32ModexStridexNCHW44;
class AlgoPack; AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
return get_algorithm_heuristic(src, dst, workspace_limit_in_bytes,
positive_attr, negative_attr)
->info();
}
static const AlgoPack& algo_pack() { return sm_algo_pack; }
bool is_fallback_algo(Algorithm* algo) {
return strcmp(algo->name(), "FALLBACK_POOLING") == 0;
}
}; };
} // namespace arm_common } // namespace arm_common
} // namespace megdnn } // namespace megdnn
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册