提交 6c9b3a58 编写于 作者: M Megvii Engine Team

refactor(dnn): remove algorithm cache queries

GitOrigin-RevId: b7a1dc62d8c4fd5aae778f9a2f48314890f5beba
上级 8563f514
......@@ -51,15 +51,6 @@ PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack;
size_t PoolingImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) {
TensorLayoutArray layouts{src, dst};
AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(),
&this->param(), sizeof(this->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
auto param = make_pooling_kern_szie_param(this, src, dst);
auto algo = get_algorithm(this, src, dst);
if (!is_fallback_algo(algo)) {
......
......@@ -13,14 +13,6 @@ namespace megdnn {
template <class Opr, typename... Args>
size_t get_dnn_workspace(Opr* opr, Args&&... args) {
TensorLayoutArray layouts{{args...}};
AlgorithmCache::Key key{opr->handle(), opr->get_opr_type(), layouts.data(),
layouts.size(), &opr->param(), sizeof(opr->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...);
return get_algorithm(opr, std::forward<Args>(args)...)
->get_workspace_in_bytes(size_args);
......@@ -32,6 +24,7 @@ size_t get_dnn_workspace(Opr* opr, Args&&... args) {
template <class Opr, typename... Args>
typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
typename Opr::AlgorithmDesc ret;
// first check self configured algorithm
auto set = opr->execution_policy().algo;
if (set.valid()) {
ret = set;
......@@ -40,10 +33,12 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
AlgorithmCache::Key key{opr->handle(), opr->get_opr_type(),
layouts.data(), layouts.size(),
&opr->param(), sizeof(opr->param())};
// then get from global algorithm cache
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
ret = rst.policy.algo;
} else {
// finally get pre-defined heuristic algorithm
ret = opr->get_algorithm_info_heuristic(
std::forward<Args>(args)...,
std::numeric_limits<size_t>::max(), AlgoAttribute::DEFAULT,
......
......@@ -44,14 +44,6 @@ WorkspaceBundle BatchConvBiasForwardImpl::get_workspace_bundle(
size_t BatchConvBiasForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& flt, const TensorLayout& bias,
const TensorLayout& z, const TensorLayout& dst) {
TensorLayoutArray layouts{src, flt, bias, z, dst};
AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(),
&this->param(), sizeof(this->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
return get_workspace_bundle(nullptr, src, flt, bias, z, dst).total_size_in_bytes();
}
......
......@@ -187,15 +187,6 @@ void forward_bias<dt_quint4, dt_qint4, dt_qint32, dt_qint32>(
size_t ConvBiasForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& flt, const TensorLayout& bias,
const TensorLayout& z, const TensorLayout& dst, const PreprocessedFilter*) {
TensorLayoutArray layouts{src, flt, bias, z, dst};
AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(),
&this->param(), sizeof(this->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
size_t float_workspace_size = 0;
if (z.ndim > 0 && z.dtype.category() != DTypeCategory::FLOAT) {
......
......@@ -66,15 +66,6 @@ void ConvolutionForwardImpl::exec(
size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) {
TensorLayoutArray layouts{filter, diff, grad};
AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(),
&this->param(), sizeof(this->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
size_t workspace_size = 0;
auto flt_dt = filter.dtype.enumv();
auto grad_dt = grad.dtype.enumv();
......@@ -178,15 +169,6 @@ size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) {
size_t workspace_size = 0;
#if !MEGDNN_DISABLE_FLOAT16
TensorLayoutArray layouts{src, diff, grad};
AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(),
&this->param(), sizeof(this->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
auto src_dt = src.dtype.enumv();
auto grad_dt = grad.dtype.enumv();
auto diff_dt = diff.dtype.enumv();
......
......@@ -397,14 +397,6 @@ WorkspaceBundle PoolingForwardImpl::get_workspace_bundle(
size_t PoolingForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) {
TensorLayoutArray layouts{src, dst};
AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(),
&this->param(), sizeof(this->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
return get_workspace_bundle(nullptr, src, dst).total_size_in_bytes();
}
namespace {
......@@ -649,14 +641,6 @@ WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle(
size_t PoolingBackwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
const TensorLayout& grad) {
TensorLayoutArray layouts{src, dst, diff, grad};
AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(),
&this->param(), sizeof(this->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
return get_workspace_bundle(nullptr, src, dst, diff, grad).total_size_in_bytes();
}
......
......@@ -104,15 +104,6 @@ std::vector<ConvolutionForwardImpl::Algorithm*> ConvolutionForwardImpl::
size_t ConvolutionForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
const PreprocessedFilter*) {
TensorLayoutArray layouts{src, filter, dst};
AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(),
&this->param(), sizeof(this->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
AlgoBase::SizeArgs args(this, src, filter, dst);
return get_algorithm(this, src, filter, dst)->get_workspace_in_bytes(args);
}
......@@ -198,15 +189,6 @@ ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) {
TensorLayoutArray layouts{filter, diff, grad};
AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(),
&this->param(), sizeof(this->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
AlgoBase::SizeArgs args(this, filter, diff, grad);
return get_algorithm(this, filter, diff, grad)->get_workspace_in_bytes(args);
}
......@@ -282,15 +264,6 @@ ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl::
size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) {
TensorLayoutArray layouts{src, diff, grad};
AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(),
&this->param(), sizeof(this->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
AlgoBase::SizeArgs args(this, src, diff, grad);
return get_algorithm(this, src, diff, grad)->get_workspace_in_bytes(args);
}
......
......@@ -35,15 +35,6 @@ WorkspaceBundle megdnn::x86::get_bundle(
size_t PoolingImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) {
TensorLayoutArray layouts{src, dst};
AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(),
&this->param(), sizeof(this->param())};
auto rst = AlgorithmCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
auto algo = get_algorithm(this, src, dst);
if (!is_fallback_algo(algo)) {
if (is_supported(SIMDType::SSE) && src.dtype == dtype::Float32() &&
......
......@@ -351,7 +351,7 @@ class TimedFuncInvokerImpl final : public TimedFuncInvoker {
} else {
CHECK_SYS_ERR(cur_recv);
}
mgb_assert(cur_recv > 0);
mgb_assert(cur_recv >= 0);
dest += cur_recv;
size -= cur_recv;
}
......
......@@ -950,10 +950,12 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile(
algo.desc.name.c_str(), layouts_str.c_str());
timer.reset();
MGB_TRY { cur_rst = profile_single_algo(policy, cur_timeout); }
// megbrain catched exception
MGB_CATCH(std::exception & exc, {
mgb_log_warn("caught exception during %s: %s", msg.c_str(), exc.what());
mgb_log_debug("caught exception during %s: %s", msg.c_str(), exc.what());
continue;
})
// megbrain uncatched exception
MGB_CATCH(..., {
mgb_log_warn("caught exception during %s", msg.c_str());
continue;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册