提交 6c9b3a58 编写于 作者: M Megvii Engine Team

refactor(dnn): remove algorithm cache queries

GitOrigin-RevId: b7a1dc62d8c4fd5aae778f9a2f48314890f5beba
上级 8563f514
...@@ -51,15 +51,6 @@ PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack; ...@@ -51,15 +51,6 @@ PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack;
size_t PoolingImpl::get_workspace_in_bytes( size_t PoolingImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) { const TensorLayout& src, const TensorLayout& dst) {
TensorLayoutArray layouts{src, dst};
AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(),
&this->param(), sizeof(this->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
auto param = make_pooling_kern_szie_param(this, src, dst); auto param = make_pooling_kern_szie_param(this, src, dst);
auto algo = get_algorithm(this, src, dst); auto algo = get_algorithm(this, src, dst);
if (!is_fallback_algo(algo)) { if (!is_fallback_algo(algo)) {
......
...@@ -13,14 +13,6 @@ namespace megdnn { ...@@ -13,14 +13,6 @@ namespace megdnn {
template <class Opr, typename... Args> template <class Opr, typename... Args>
size_t get_dnn_workspace(Opr* opr, Args&&... args) { size_t get_dnn_workspace(Opr* opr, Args&&... args) {
TensorLayoutArray layouts{{args...}};
AlgorithmCache::Key key{opr->handle(), opr->get_opr_type(), layouts.data(),
layouts.size(), &opr->param(), sizeof(opr->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...); typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...);
return get_algorithm(opr, std::forward<Args>(args)...) return get_algorithm(opr, std::forward<Args>(args)...)
->get_workspace_in_bytes(size_args); ->get_workspace_in_bytes(size_args);
...@@ -32,6 +24,7 @@ size_t get_dnn_workspace(Opr* opr, Args&&... args) { ...@@ -32,6 +24,7 @@ size_t get_dnn_workspace(Opr* opr, Args&&... args) {
template <class Opr, typename... Args> template <class Opr, typename... Args>
typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
typename Opr::AlgorithmDesc ret; typename Opr::AlgorithmDesc ret;
// first check self configured algorithm
auto set = opr->execution_policy().algo; auto set = opr->execution_policy().algo;
if (set.valid()) { if (set.valid()) {
ret = set; ret = set;
...@@ -40,10 +33,12 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { ...@@ -40,10 +33,12 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
AlgorithmCache::Key key{opr->handle(), opr->get_opr_type(), AlgorithmCache::Key key{opr->handle(), opr->get_opr_type(),
layouts.data(), layouts.size(), layouts.data(), layouts.size(),
&opr->param(), sizeof(opr->param())}; &opr->param(), sizeof(opr->param())};
// then get from global algorithm cache
auto rst = AlgorithmCache::instance().get(key); auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) { if (rst.policy.algo.valid()) {
ret = rst.policy.algo; ret = rst.policy.algo;
} else { } else {
// finally get pre-defined heuristic algorithm
ret = opr->get_algorithm_info_heuristic( ret = opr->get_algorithm_info_heuristic(
std::forward<Args>(args)..., std::forward<Args>(args)...,
std::numeric_limits<size_t>::max(), AlgoAttribute::DEFAULT, std::numeric_limits<size_t>::max(), AlgoAttribute::DEFAULT,
......
...@@ -44,14 +44,6 @@ WorkspaceBundle BatchConvBiasForwardImpl::get_workspace_bundle( ...@@ -44,14 +44,6 @@ WorkspaceBundle BatchConvBiasForwardImpl::get_workspace_bundle(
//! Total scratch memory (bytes) required by batch conv-bias forward for the
//! given input/filter/bias/z/output layouts.
size_t BatchConvBiasForwardImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& flt, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
    // A null workspace pointer is enough for sizing: the bundle describes the
    // full requirement without needing backing storage.
    auto bundle = get_workspace_bundle(nullptr, src, flt, bias, z, dst);
    return bundle.total_size_in_bytes();
}
......
...@@ -187,15 +187,6 @@ void forward_bias<dt_quint4, dt_qint4, dt_qint32, dt_qint32>( ...@@ -187,15 +187,6 @@ void forward_bias<dt_quint4, dt_qint4, dt_qint32, dt_qint32>(
size_t ConvBiasForwardImpl::get_workspace_in_bytes( size_t ConvBiasForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& flt, const TensorLayout& bias, const TensorLayout& src, const TensorLayout& flt, const TensorLayout& bias,
const TensorLayout& z, const TensorLayout& dst, const PreprocessedFilter*) { const TensorLayout& z, const TensorLayout& dst, const PreprocessedFilter*) {
TensorLayoutArray layouts{src, flt, bias, z, dst};
AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(),
&this->param(), sizeof(this->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
size_t float_workspace_size = 0; size_t float_workspace_size = 0;
if (z.ndim > 0 && z.dtype.category() != DTypeCategory::FLOAT) { if (z.ndim > 0 && z.dtype.category() != DTypeCategory::FLOAT) {
......
...@@ -66,15 +66,6 @@ void ConvolutionForwardImpl::exec( ...@@ -66,15 +66,6 @@ void ConvolutionForwardImpl::exec(
size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) { const TensorLayout& grad) {
TensorLayoutArray layouts{filter, diff, grad};
AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(),
&this->param(), sizeof(this->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
size_t workspace_size = 0; size_t workspace_size = 0;
auto flt_dt = filter.dtype.enumv(); auto flt_dt = filter.dtype.enumv();
auto grad_dt = grad.dtype.enumv(); auto grad_dt = grad.dtype.enumv();
...@@ -178,15 +169,6 @@ size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes( ...@@ -178,15 +169,6 @@ size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) { const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) {
size_t workspace_size = 0; size_t workspace_size = 0;
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
TensorLayoutArray layouts{src, diff, grad};
AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(),
&this->param(), sizeof(this->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
auto src_dt = src.dtype.enumv(); auto src_dt = src.dtype.enumv();
auto grad_dt = grad.dtype.enumv(); auto grad_dt = grad.dtype.enumv();
auto diff_dt = diff.dtype.enumv(); auto diff_dt = diff.dtype.enumv();
......
...@@ -397,14 +397,6 @@ WorkspaceBundle PoolingForwardImpl::get_workspace_bundle( ...@@ -397,14 +397,6 @@ WorkspaceBundle PoolingForwardImpl::get_workspace_bundle(
//! Total scratch memory (bytes) required by pooling forward for the given
//! source/destination layouts.
size_t PoolingForwardImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& dst) {
    // Sizing only: pass a null workspace pointer and query the bundle.
    auto bundle = get_workspace_bundle(nullptr, src, dst);
    return bundle.total_size_in_bytes();
}
namespace { namespace {
...@@ -649,14 +641,6 @@ WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle( ...@@ -649,14 +641,6 @@ WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle(
//! Total scratch memory (bytes) required by pooling backward for the given
//! forward input/output and gradient layouts.
size_t PoolingBackwardImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
        const TensorLayout& grad) {
    // Sizing only: the bundle computed over all four layouts captures the
    // whole requirement; no backing storage is needed here.
    auto bundle = get_workspace_bundle(nullptr, src, dst, diff, grad);
    return bundle.total_size_in_bytes();
}
......
...@@ -104,15 +104,6 @@ std::vector<ConvolutionForwardImpl::Algorithm*> ConvolutionForwardImpl:: ...@@ -104,15 +104,6 @@ std::vector<ConvolutionForwardImpl::Algorithm*> ConvolutionForwardImpl::
//! Scratch memory (bytes) required by convolution forward: delegates to the
//! algorithm chosen for these layouts.
size_t ConvolutionForwardImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        const PreprocessedFilter*) {
    // Pick the algorithm first, then ask it how much workspace it needs.
    auto* algo = get_algorithm(this, src, filter, dst);
    AlgoBase::SizeArgs size_args(this, src, filter, dst);
    return algo->get_workspace_in_bytes(size_args);
}
...@@ -198,15 +189,6 @@ ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl:: ...@@ -198,15 +189,6 @@ ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
//! Scratch memory (bytes) required by convolution backward-data: delegates to
//! the algorithm chosen for these layouts.
size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    // Pick the algorithm first, then ask it how much workspace it needs.
    auto* algo = get_algorithm(this, filter, diff, grad);
    AlgoBase::SizeArgs size_args(this, filter, diff, grad);
    return algo->get_workspace_in_bytes(size_args);
}
...@@ -282,15 +264,6 @@ ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl:: ...@@ -282,15 +264,6 @@ ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl::
//! Scratch memory (bytes) required by convolution backward-filter: delegates
//! to the algorithm chosen for these layouts.
size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) {
    // Pick the algorithm first, then ask it how much workspace it needs.
    auto* algo = get_algorithm(this, src, diff, grad);
    AlgoBase::SizeArgs size_args(this, src, diff, grad);
    return algo->get_workspace_in_bytes(size_args);
}
......
...@@ -35,15 +35,6 @@ WorkspaceBundle megdnn::x86::get_bundle( ...@@ -35,15 +35,6 @@ WorkspaceBundle megdnn::x86::get_bundle(
size_t PoolingImpl::get_workspace_in_bytes( size_t PoolingImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) { const TensorLayout& src, const TensorLayout& dst) {
TensorLayoutArray layouts{src, dst};
AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(),
&this->param(), sizeof(this->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
auto algo = get_algorithm(this, src, dst); auto algo = get_algorithm(this, src, dst);
if (!is_fallback_algo(algo)) { if (!is_fallback_algo(algo)) {
if (is_supported(SIMDType::SSE) && src.dtype == dtype::Float32() && if (is_supported(SIMDType::SSE) && src.dtype == dtype::Float32() &&
......
...@@ -351,7 +351,7 @@ class TimedFuncInvokerImpl final : public TimedFuncInvoker { ...@@ -351,7 +351,7 @@ class TimedFuncInvokerImpl final : public TimedFuncInvoker {
} else { } else {
CHECK_SYS_ERR(cur_recv); CHECK_SYS_ERR(cur_recv);
} }
mgb_assert(cur_recv > 0); mgb_assert(cur_recv >= 0);
dest += cur_recv; dest += cur_recv;
size -= cur_recv; size -= cur_recv;
} }
......
...@@ -950,10 +950,12 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile( ...@@ -950,10 +950,12 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile(
algo.desc.name.c_str(), layouts_str.c_str()); algo.desc.name.c_str(), layouts_str.c_str());
timer.reset(); timer.reset();
MGB_TRY { cur_rst = profile_single_algo(policy, cur_timeout); } MGB_TRY { cur_rst = profile_single_algo(policy, cur_timeout); }
// megbrain catched exception
MGB_CATCH(std::exception & exc, { MGB_CATCH(std::exception & exc, {
mgb_log_warn("caught exception during %s: %s", msg.c_str(), exc.what()); mgb_log_debug("caught exception during %s: %s", msg.c_str(), exc.what());
continue; continue;
}) })
// megbrain uncatched exception
MGB_CATCH(..., { MGB_CATCH(..., {
mgb_log_warn("caught exception during %s", msg.c_str()); mgb_log_warn("caught exception during %s", msg.c_str());
continue; continue;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册