提交 694aa1bd 编写于 作者: M Megvii Engine Team

feat(dnn): add heuristic cache

GitOrigin-RevId: 35e942b5e39c60d9d4e0ffe41497103e8c0a8822
上级 ca4374b6
/**
* \file dnn/include/megdnn/heuristic_cache.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/basic_types.h"
#include "megdnn/oprs/base.h"
#include <mutex>
#include <string>
#include <unordered_map>
namespace megdnn {
/*!
 * \brief process-wide cache mapping (platform, device, opr type, input
 * layouts, opr param) to the execution policy picked by the heuristic and
 * the workspace size it needs, so repeated queries can skip the heuristic
 * search.
 */
class HeuristicCache {
private:
// singleton: construct only through instance()
HeuristicCache() = default;
public:
static HeuristicCache& instance();
// materialized string key used for hashing/equality in the map
struct KeyStorage {
std::string category;  // platform / device / opr-type description
std::string input;     // serialized input layouts + raw opr param bytes
bool operator==(const KeyStorage& k) const {
return category == k.category && input == k.input;
}
};
/*!
 * \brief lazily-serialized lookup key; stores only pointers, so the layout
 * array and param must outlive the Key (it is meant to be built and used
 * within one call).
 */
class Key {
Handle* m_handle;
uint32_t m_opr_type;
const TensorLayout* m_inp_layouts_ptr;
size_t m_inp_layouts_size;
const void* m_param_ptr;
size_t m_param_size;
// filled by build_key_storage(); mutable so a const Key can memoize
mutable std::string m_category;
mutable std::string m_input;
public:
Key(Handle* opr_handle, Algorithm::OprType opr_type, const TensorLayout* inp_layouts_ptr,
size_t inp_layouts_size, const void* param_ptr = nullptr, size_t param_size = 0)
: m_handle{opr_handle},
m_opr_type{static_cast<uint32_t>(opr_type)},
m_inp_layouts_ptr{inp_layouts_ptr},
m_inp_layouts_size{inp_layouts_size},
m_param_ptr{param_ptr},
m_param_size{param_size} {}
// serialize this key into the strings used for map lookup
KeyStorage build_key_storage() const;
};
// cached decision: the chosen policy plus its required workspace size
struct Result {
ExecutionPolicy policy;
size_t workspace;
};
// store \p result under \p key; results with an invalid algo are ignored
void put(const Key& key, Result& result);
// fetch cached result; returns a default Result (invalid algo) on miss
Result get(const Key& key);
// drop every cached entry
void clear();
private:
struct Hash {
size_t operator()(const KeyStorage& k) const {
size_t h1 = std::hash<std::string>{}(k.category);
size_t h2 = std::hash<std::string>{}(k.input);
// boost-style hash_combine to mix the two string hashes
h1 ^= h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2);
return h1;
}
};
std::unordered_map<KeyStorage, Result, Hash> m_heuristic_cache;
#if __DEPLOY_ON_XP_SP2__
// no usable std::mutex on XP SP2 deployments; placeholder member keeps the
// MEGDNN_LOCK_GUARD call sites compilable there
size_t m_mtx;
#else
std::mutex m_mtx;
#endif
};
} // namespace megdnn
......@@ -42,6 +42,10 @@ public:
const TensorLayout& B,
const TensorLayout& C) = 0;
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD;
}
protected:
void check_exec(const TensorLayout& A, const TensorLayout& B,
const TensorLayout& C, size_t workspace_in_bytes);
......@@ -76,6 +80,11 @@ public:
const TensorLayout& C) = 0;
static size_t pack_size (const Param::Format format);
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::MATRIX_MUL_FORWARD;
}
protected:
void check_exec(const TensorLayout& A, const TensorLayout& B,
const TensorLayout& C, size_t workspace_in_bytes);
......
......@@ -275,6 +275,10 @@ public:
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) = 0;
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::CONVOLUTION_FORWARD;
}
protected:
CanonizedFilterMeta check_exec(
const TensorLayout& src, const TensorLayout& filter,
......@@ -309,6 +313,10 @@ public:
void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
TensorLayout& grad);
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::CONVOLUTION_BACKWARD_DATA;
}
protected:
CanonizedFilterMeta check_exec(const TensorLayout& filter,
const TensorLayout& diff,
......@@ -338,6 +346,10 @@ public:
const TensorLayout& diff,
const TensorLayout& grad) = 0;
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::CONVOLUTION_BACKWARD_FILTER;
}
protected:
CanonizedFilterMeta check_exec(const TensorLayout& src,
const TensorLayout& diff,
......@@ -505,6 +517,10 @@ public:
const ConvBiasForward::BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode);
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::CONVBIAS_FORWARD;
}
protected:
CanonizedFilterMeta check_exec(
const TensorLayout& src, const TensorLayout& filter,
......@@ -775,6 +791,10 @@ public:
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::POOLING_FORWARD;
}
protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
......@@ -801,6 +821,10 @@ public:
const TensorLayout& diff,
const TensorLayout& grad) = 0;
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::POOLING_BACKWARD;
}
protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad,
......@@ -1216,6 +1240,10 @@ public:
const TensorLayout& filter,
const TensorLayout& dst) = 0;
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::CONVOLUTION3D_FORWARD;
}
protected:
CanonizedFilterMeta check_exec(const TensorLayout& src,
const TensorLayout& filter,
......@@ -1244,6 +1272,10 @@ public:
void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
TensorLayout& grad);
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::CONVOLUTION3D_BACKWARD_DATA;
}
protected:
CanonizedFilterMeta check_exec(const TensorLayout& filter,
const TensorLayout& diff,
......@@ -1268,6 +1300,10 @@ public:
const TensorLayout& diff,
const TensorLayout& grad) = 0;
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::CONVOLUTION3D_BACKWARD_FILTER;
}
protected:
CanonizedFilterMeta check_exec(const TensorLayout& src,
const TensorLayout& diff,
......@@ -1308,6 +1344,10 @@ public:
const TensorLayout& filter,
const TensorLayout& dst) = 0;
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::LOCAL_SHARE_FORWARD;
}
protected:
void check_exec(const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_in_bytes);
......@@ -1334,6 +1374,10 @@ public:
void deduce_layout(const TensorLayout& filter, const TensorLayout& diff,
TensorLayout& grad);
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::LOCAL_SHARE_BACKWARD_DATA;
}
protected:
void check_exec(const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_in_bytes);
......@@ -1358,6 +1402,10 @@ public:
const TensorLayout& diff,
const TensorLayout& grad) = 0;
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::LOCAL_SHARE_BACKWARD_FILTER;
}
protected:
void check_exec(const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_in_bytes);
......@@ -1479,6 +1527,10 @@ public:
const TensorLayout& mask,
const TensorLayout& dst) = 0;
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::DEFORMABLE_CONV_FORWARD;
}
protected:
CanonizedFilterMeta check_exec(const TensorLayout& im,
const TensorLayout& filter,
......@@ -1520,6 +1572,10 @@ public:
const TensorLayout& mask, const TensorLayout& out_grad,
TensorLayout& filter_grad);
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::DEFORMABLE_CONV_BACKWARD_FILTER;
}
protected:
CanonizedFilterMeta check_exec(const TensorLayout& im,
const TensorLayout& offset,
......@@ -1566,6 +1622,10 @@ public:
const TensorLayout& out_grad, TensorLayout& im_grad,
TensorLayout& offset_grad, TensorLayout& mask_grad);
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::DEFORMABLE_CONV_BACKWARD_DATA;
}
protected:
CanonizedFilterMeta check_exec(
const TensorLayout& im, const TensorLayout& filter,
......@@ -1677,6 +1737,10 @@ public:
const TensorLayout& z,
const TensorLayout& dst) = 0;
//! opr type tag used as part of the HeuristicCache key
static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::BATCH_CONV_FORWARD;
}
protected:
CanonizedFilterMeta check_exec(const TensorLayout& src,
const TensorLayout& filter,
......
......@@ -101,6 +101,15 @@ PoolingImpl::PoolingKernParam PoolingImpl::make_pooling_kern_param(
size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) {
TensorLayoutArray layouts{src, dst};
HeuristicCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(), &this->param(),
sizeof(this->param())};
auto rst = HeuristicCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
auto param = make_pooling_kern_szie_param(this, src, dst);
auto algo = get_algorithm(this, src, dst);
if (!is_fallback_algo(algo)) {
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
......@@ -17,10 +18,28 @@
#include <vector>
#include "megdnn/common.h"
#include "megdnn/heuristic_cache.h"
#include "utils.h"
namespace megdnn {
/*!
 * \brief get the workspace size an operator needs for the given layouts,
 * consulting HeuristicCache first and falling back to the (possibly
 * expensive) algorithm lookup on a cache miss.
 *
 * All \p args are expected to be TensorLayout lvalues, since they are
 * collected into a TensorLayoutArray for the cache key.
 */
template <class Opr, typename... Args>
size_t get_dnn_workspace(Opr* opr, Args&&... args) {
TensorLayoutArray layouts{{args...}};
HeuristicCache::Key key{opr->handle(), opr->get_opr_type(),
layouts.data(), layouts.size(), &opr->param(),
sizeof(opr->param())};
auto rst = HeuristicCache::instance().get(key);
// cache hit: a valid cached algo implies the workspace size is valid too
if (rst.policy.algo.valid()) {
return rst.workspace;
}
// NOTE(review): args is expanded three times (layouts, size_args, and the
// get_algorithm call); this is only safe because the arguments are lvalue
// layout references, which std::forward does not move from -- confirm no
// rvalue arguments are ever passed here
typename Opr::AlgoBase::SizeArgs size_args(opr,
std::forward<Args>(args)...);
return get_algorithm(opr, std::forward<Args>(args)...)
->get_workspace_in_bytes(size_args);
}
/*!
* \brief get user-configured algorithm, or heuristic algorithm
*/
......@@ -31,9 +50,20 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
if (set.valid()) {
ret = set;
} else {
ret = opr->get_algorithm_info_heuristic(
std::forward<Args>(args)..., std::numeric_limits<size_t>::max(),
AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT).desc;
TensorLayoutArray layouts{{args...}};
HeuristicCache::Key key{opr->handle(), opr->get_opr_type(),
layouts.data(), layouts.size(), &opr->param(),
sizeof(opr->param())};
auto rst = HeuristicCache::instance().get(key);
if (rst.policy.algo.valid()) {
ret = rst.policy.algo;
} else {
ret = opr->get_algorithm_info_heuristic(
std::forward<Args>(args)...,
std::numeric_limits<size_t>::max(),
AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT)
.desc;
}
}
return static_cast<typename Opr::AlgoBase*>(
opr->get_algorithm_from_desc(ret));
......
......@@ -250,13 +250,9 @@ CanonizedFilterMeta DeformableConvBackwardData::check_exec(
megdnn_assert_eq_dtype(im, mask_grad);
// check layout
megdnn_assert(im.shape == im_grad.shape, "invalid im_grad shape: %s",
megdnn_layout_msg(im_grad).c_str());
megdnn_assert(offset.shape == offset_grad.shape,
"invalid offset_grad shape: %s",
megdnn_layout_msg(offset_grad).c_str());
megdnn_assert(mask.shape == mask_grad.shape, "invalid mask_grad shape: %s",
megdnn_layout_msg(mask_grad).c_str());
megdnn_assert_eq_shape(im, im_grad);
megdnn_assert_eq_shape(offset, offset_grad);
megdnn_assert_eq_shape(mask, mask_grad);
auto ret = make_canonized_filter_meta(im.ndim, filter, offset);
auto required_workspace_in_bytes =
......
/**
* \file dnn/src/common/heuristic_cache.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/heuristic_cache.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#if MEGDNN_WITH_CUDA
#include "src/cuda/utils.h"
#endif
#if MEGDNN_WITH_ROCM
#include "hcc_detail/hcc_defs_prologue.h"
#include "megcore_rocm.h"
#include "src/rocm/utils.h"
#endif
using namespace megdnn;
//! Meyers singleton accessor; initialization is thread-safe since C++11
HeuristicCache& HeuristicCache::instance() {
static HeuristicCache ins;
return ins;
}
/*!
 * \brief serialize this Key into the KeyStorage strings used for map lookup.
 *
 * \return category string ("plat:<handle-type>;<device info>;<opr-type>")
 *         and input string (per-layout "shape;stride;dtype;format|" records
 *         followed by the raw opr param bytes).
 *
 * The result is memoized in the mutable members so a Key queried twice
 * (e.g. by get() and then put()) serializes only once.
 */
HeuristicCache::KeyStorage HeuristicCache::Key::build_key_storage() const {
    auto&& ctg = m_category;
    auto&& inp = m_input;
    // Memoization flag: after a build, ctg is always non-empty (it starts
    // with "plat:"), whereas inp may be legitimately empty (no layouts and
    // no param). The previous check `!ctg.empty() && !inp.empty()` forced a
    // rebuild in that case, appending duplicate data to the already-built
    // category string on every call; testing ctg alone fixes that.
    if (!ctg.empty())
        return {ctg, inp};
    // build both strings from scratch to stay correct even if one of them
    // was somehow left partially filled
    inp.clear();
    inp.reserve(sizeof(TensorLayout) * 3 * m_inp_layouts_size + m_param_size);
    for (size_t i = 0; i < m_inp_layouts_size; i++) {
        auto&& ly = m_inp_layouts_ptr[i];
        // comma-separated shape
        for (size_t j = 0; j < ly.ndim; j++) {
            if (j)
                inp.push_back(',');
            inp.append(std::to_string(ly.shape[j]));
        }
        inp.push_back(';');
        // comma-separated stride
        for (size_t j = 0; j < ly.ndim; j++) {
            if (j)
                inp.push_back(',');
            inp.append(std::to_string(ly.stride[j]));
        }
        inp.push_back(';');
        inp.append(ly.dtype.name());
        inp.push_back(';');
        inp.append(ly.format.to_string().c_str());
        inp.push_back('|');
    }
    if (m_param_size) {
        // raw param bytes; assumes the param struct has deterministic
        // content (no uninitialized padding) -- TODO confirm
        inp.append(reinterpret_cast<const char*>(m_param_ptr), m_param_size);
    }
    ctg = "plat:";
    ctg.append(std::to_string(static_cast<uint32_t>(m_handle->type())));
    // append per-backend device/runtime identification so caches are not
    // shared across incompatible devices or runtime versions
    switch (m_handle->type()) {
#if MEGDNN_WITH_CUDA
        case Handle::HandleType::CUDA: {
            int cuda_rt = -1;
            cuda_check(cudaRuntimeGetVersion(&cuda_rt));
            cuda_rt /= 1000;  // keep only the major runtime version
            auto&& handle = static_cast<megdnn::cuda::HandleImpl*>(m_handle);
            auto&& prop = handle->device_prop();
            ctg.append(ssprintf(";dev=%s;cap=%d.%d;runtime=%d;",
                                prop.name, prop.major, prop.minor, cuda_rt));
            break;
        }
#endif
#if MEGDNN_WITH_ROCM
        case Handle::HandleType::ROCM: {
            auto&& handle = static_cast<megdnn::rocm::HandleImpl*>(m_handle);
            auto&& prop = handle->device_prop();
            int drv = -1, hip_rt = -1;
            hip_check(hipDriverGetVersion(&drv));
            hip_check(hipRuntimeGetVersion(&hip_rt));
            ctg.append(ssprintf(";dev=%s;cap=%d.%d,drv=%d;runtime=%d;",
                                prop.name, prop.major, prop.minor, drv, hip_rt));
            break;
        }
#endif
        case Handle::HandleType::FALLBACK:
#if MEGDNN_X86
        case Handle::HandleType::X86:
#endif
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
        case Handle::HandleType::ARM_COMMON:
#endif
#if MEGDNN_AARCH64
        case Handle::HandleType::AARCH64:
#endif
#if MEGDNN_ARMV7
        case Handle::HandleType::ARMV7:
#endif
        {
            // CPU backends: the chosen algorithm may depend on the thread
            // count, so it is part of the key
            size_t nr_threads =
                    static_cast<megdnn::naive::HandleImpl*>(m_handle)
                            ->megcore_dispatcher()
                            ->nr_threads();
            ctg.append(";");
            ctg.append(std::to_string(nr_threads));
            ctg.append(";");
            break;
        }
        default:
            ctg.append(";");
    }
    ctg.append(std::to_string(m_opr_type));
    return {ctg, inp};
}
/*!
 * \brief insert \p result for \p key; results whose algorithm is invalid
 * are silently dropped so the cache never stores a failed heuristic
 */
void HeuristicCache::put(const Key& key, Result& result) {
    MEGDNN_LOCK_GUARD(m_mtx);
    if (!result.policy.algo.valid())
        return;
    m_heuristic_cache[key.build_key_storage()] = result;
}
/*!
 * \brief look up \p key; returns a value-initialized Result (invalid algo,
 * zero workspace) when the key is not cached
 */
HeuristicCache::Result HeuristicCache::get(const Key& key) {
    MEGDNN_LOCK_GUARD(m_mtx);
    auto storage = key.build_key_storage();
    auto it = m_heuristic_cache.find(storage);
    if (it != m_heuristic_cache.end())
        return it->second;
    return {};
}
//! remove every cached entry
void HeuristicCache::clear() {
MEGDNN_LOCK_GUARD(m_mtx);
m_heuristic_cache.clear();
}
\ No newline at end of file
......@@ -56,9 +56,7 @@ size_t BatchConvBiasForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) {
AlgoBase::SizeArgs args(this, src, filter, bias, z, dst);
return get_algorithm(this, src, filter, bias, z, dst)
->get_workspace_in_bytes(args);
return get_dnn_workspace(this, src, filter, bias, z, dst);
}
void BatchConvBiasForwardImpl::exec(_megdnn_tensor_in src,
......@@ -66,10 +64,12 @@ void BatchConvBiasForwardImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_in bias, _megdnn_tensor_in z,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, filter.layout, bias.layout, z.layout, dst.layout,
workspace.size);
AlgoBase::ExecArgs args(this, src, filter, bias, z, dst, workspace);
auto algo = get_algorithm(this, src.layout, filter.layout, bias.layout,
z.layout, dst.layout);
algo->check_workspace(args, workspace).exec(args);
algo->exec(args);
}
const char* BatchConvBiasForwardImpl::get_algorithm_set_name() const {
......
......@@ -33,13 +33,12 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
AlgoBase::ExecArgs args(this, A, B, C, workspace);
check_exec(A.layout, B.layout, C.layout, workspace.size);
auto&& algo = megdnn::get_algorithm(this, A.layout, B.layout, C.layout);
algo->check_workspace(args, workspace).exec(args);
algo->exec(args);
}
size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) {
AlgoBase::SizeArgs args(this, A, B, C);
return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args);
return get_dnn_workspace(this, A, B, C);
}
std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms(
......
......@@ -36,7 +36,7 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
preprocessed_filter);
auto algo = get_algorithm(this, src.layout, filter.layout, bias.layout,
z.layout, dst.layout);
algo->check_workspace(args, workspace).exec(args);
algo->exec(args);
};
std::vector<ConvBiasForward::Algorithm*>
......@@ -228,6 +228,15 @@ size_t ConvBiasForwardImpl::get_workspace_in_bytes(
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst,
const PreprocessedFilter* preprocessed_filter) {
TensorLayoutArray layouts{src, filter, bias, z, dst};
HeuristicCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(), &this->param(),
sizeof(this->param())};
auto rst = HeuristicCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
AlgoBase::SizeArgs args{
this, src, filter, bias, z, dst, preprocessed_filter};
return get_algorithm(this, src, filter, bias, z, dst)
......
......@@ -58,9 +58,7 @@ size_t ConvolutionForwardImpl::get_workspace_in_bytes(
const TensorLayout& dst,
const PreprocessedFilter* preprocessed_filter) {
MEGDNN_MARK_USED_VAR(preprocessed_filter);
AlgoBase::SizeArgs args{this, src, filter, dst};
return megdnn::get_algorithm(this, src, filter, dst)
->get_workspace_in_bytes(args);
return get_dnn_workspace(this, src, filter, dst);
}
void ConvolutionForwardImpl::exec(_megdnn_tensor_in src,
......@@ -72,7 +70,7 @@ void ConvolutionForwardImpl::exec(_megdnn_tensor_in src,
preprocessed_filter);
AlgoBase::ExecArgs args(this, src, filter, dst, workspace);
auto&& algo = get_algorithm(this, src.layout, filter.layout, dst.layout);
algo->check_workspace(args, workspace).exec(args);
algo->exec(args);
}
const char* ConvolutionForwardImpl::get_algorithm_set_name() const {
......@@ -85,9 +83,10 @@ void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(filter.layout, diff.layout, grad.layout, workspace.size);
AlgoBase::ExecArgs args(this, filter, diff, grad, workspace);
auto algo = get_algorithm(this, filter.layout, diff.layout, grad.layout);
algo->check_workspace(args, workspace).exec(args);
algo->exec(args);
}
std::vector<ConvolutionBackwardDataImpl::Algorithm*>
......@@ -196,9 +195,7 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) {
AlgoBase::SizeArgs args(this, filter, diff, grad);
return get_algorithm(this, filter, diff, grad)
->get_workspace_in_bytes(args);
return get_dnn_workspace(this, filter, diff, grad);
}
const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const {
......@@ -211,9 +208,10 @@ void ConvolutionBackwardFilterImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(src.layout, diff.layout, grad.layout, workspace.size);
AlgoBase::ExecArgs args(this, src, diff, grad, workspace);
auto algo = get_algorithm(this, src.layout, diff.layout, grad.layout);
algo->check_workspace(args, workspace).exec(args);
algo->exec(args);
}
std::vector<ConvolutionBackwardFilterImpl::Algorithm*>
......@@ -324,9 +322,7 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) {
AlgoBase::SizeArgs args(this, src, diff, grad);
return get_algorithm(this, src, diff, grad)
->get_workspace_in_bytes(args);
return get_dnn_workspace(this, src, diff, grad);
}
const char* ConvolutionBackwardFilterImpl::get_algorithm_set_name() const {
......
......@@ -111,18 +111,17 @@ Convolution3DForwardImpl::get_all_algorithms(const TensorLayout& src,
size_t Convolution3DForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) {
AlgoBase::SizeArgs args(this, src, filter, dst);
return get_algorithm(this, src, filter, dst)
->get_workspace_in_bytes(args);
return get_dnn_workspace(this, src, filter, dst);
}
void Convolution3DForwardImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_in filter,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, filter.layout, dst.layout, workspace.size);
AlgoBase::ExecArgs args(this, src, filter, dst, workspace);
auto algo = get_algorithm(this, src.layout, filter.layout, dst.layout);
algo->check_workspace(args, workspace).exec(args);
algo->exec(args);
}
const char* Convolution3DForwardImpl::get_algorithm_set_name() const {
......@@ -133,9 +132,10 @@ void Convolution3DBackwardDataImpl::exec(_megdnn_tensor_in filter,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(filter.layout, diff.layout, grad.layout, workspace.size);
AlgoBase::ExecArgs args(this, filter, diff, grad, workspace);
auto algo = get_algorithm(this, filter.layout, diff.layout, grad.layout);
algo->check_workspace(args, workspace).exec(args);
algo->exec(args);
}
std::vector<Convolution3DBackwardDataImpl::Algorithm*>
......@@ -200,9 +200,7 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic(
size_t Convolution3DBackwardDataImpl::get_workspace_in_bytes(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) {
AlgoBase::SizeArgs args(this, filter, diff, grad);
return get_algorithm(this, filter, diff, grad)
->get_workspace_in_bytes(args);
return get_dnn_workspace(this, filter, diff, grad);
}
const char* Convolution3DBackwardDataImpl::get_algorithm_set_name() const {
......@@ -213,10 +211,11 @@ void Convolution3DBackwardFilterImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(src.layout, diff.layout, grad.layout, workspace.size);
AlgoBase::ExecArgs args(this, src, diff, grad, workspace);
auto algo =
get_algorithm(this, src.layout, diff.layout, grad.layout);
algo->check_workspace(args, workspace).exec(args);
algo->exec(args);
}
std::vector<Convolution3DBackwardFilterImpl::Algorithm*>
......@@ -281,9 +280,7 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
size_t Convolution3DBackwardFilterImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) {
AlgoBase::SizeArgs args(this, src, diff, grad);
return get_algorithm(this, src, diff, grad)
->get_workspace_in_bytes(args);
return get_dnn_workspace(this, src, diff , grad);
}
const char* Convolution3DBackwardFilterImpl::get_algorithm_set_name() const {
......
......@@ -36,8 +36,7 @@ size_t Fwd::get_workspace_in_bytes(const TensorLayout& im,
const TensorLayout& offset,
const TensorLayout& mask,
const TensorLayout& dst) {
auto algo = get_algorithm(this, im, filter, offset, mask, dst);
return algo->get_workspace_in_bytes({this, im, filter, offset, mask, dst});
return get_dnn_workspace(this, im, filter, offset, mask, dst);
}
std::vector<AlgoFwd*> Fwd::get_all_algorithms(const TensorLayout& /* im */,
......@@ -96,13 +95,13 @@ const char* Fwd::get_algorithm_set_name() const {
void Fwd::exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
_megdnn_tensor_in offset, _megdnn_tensor_in mask,
_megdnn_tensor_out out, _megdnn_workspace workspace) {
check_exec(im.layout, filter.layout, offset.layout, mask.layout, out.layout,
workspace.size);
auto algo = get_algorithm(this, im.layout, filter.layout, offset.layout,
mask.layout, out.layout);
AlgoBase::ExecArgs args(this, im, filter, offset, mask, out, workspace);
algo->check_workspace(args, workspace).exec(args);
return;
algo->exec(args);
}
/* ============== BwdFlt Implementation ============== */
......@@ -152,21 +151,23 @@ AlgoBwdFlt* BwdFlt::get_algorithm_heuristic(
size_t BwdFlt::get_workspace_in_bytes(
const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& filter_grad) {
auto algo = get_algorithm(this, im, offset, mask, out_grad, filter_grad);
return algo->get_workspace_in_bytes({this, im, offset, mask, out_grad, filter_grad});
return get_dnn_workspace(this, im, offset, mask, out_grad, filter_grad);
}
const char* BwdFlt::get_algorithm_set_name() const {
return "DEFORMABLE_CONV_BWD_FILTER_CUDA";
};
void BwdFlt::exec(_megdnn_tensor_in im, _megdnn_tensor_in offset, _megdnn_tensor_in mask,
_megdnn_tensor_in out_grad, _megdnn_tensor_out filter_grad,
_megdnn_workspace workspace) {
AlgoBase::ExecArgs args(this, im, offset, mask, out_grad, filter_grad, workspace);
auto algo = get_algorithm(this, im.layout, offset.layout, mask.layout, out_grad.layout,
filter_grad.layout);
algo->check_workspace(args, workspace).exec(args);
void BwdFlt::exec(_megdnn_tensor_in im, _megdnn_tensor_in offset,
_megdnn_tensor_in mask, _megdnn_tensor_in out_grad,
_megdnn_tensor_out filter_grad, _megdnn_workspace workspace) {
check_exec(im.layout, offset.layout, mask.layout, out_grad.layout,
filter_grad.layout, workspace.size);
AlgoBase::ExecArgs args(this, im, offset, mask, out_grad, filter_grad,
workspace);
auto algo = get_algorithm(this, im.layout, offset.layout, mask.layout,
out_grad.layout, filter_grad.layout);
algo->exec(args);
}
/* ============== BwdData Implementation ============== */
......@@ -222,10 +223,8 @@ size_t BwdData::get_workspace_in_bytes(
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& im_grad,
const TensorLayout& offset_grad, const TensorLayout& mask_grad) {
auto algo = get_algorithm(this, im, filter, offset, mask, out_grad,
im_grad, offset_grad, mask_grad);
return algo->get_workspace_in_bytes({this, im, filter, offset, mask, out_grad,
im_grad, offset_grad, mask_grad});
return get_dnn_workspace(this, im, filter, offset, mask, out_grad, im_grad,
offset_grad, mask_grad);
}
const char* BwdData::get_algorithm_set_name() const {
......@@ -233,16 +232,19 @@ const char* BwdData::get_algorithm_set_name() const {
};
void BwdData::exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
_megdnn_tensor_in offset, _megdnn_tensor_in mask,
_megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad,
_megdnn_tensor_out offset_grad, _megdnn_tensor_out mask_grad,
_megdnn_workspace workspace) {
_megdnn_tensor_in offset, _megdnn_tensor_in mask,
_megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad,
_megdnn_tensor_out offset_grad, _megdnn_tensor_out mask_grad,
_megdnn_workspace workspace) {
check_exec(im.layout, filter.layout, offset.layout, mask.layout,
out_grad.layout, im_grad.layout, offset_grad.layout,
mask_grad.layout, workspace.size);
AlgoBase::ExecArgs args(this, im, filter, offset, mask, out_grad, im_grad,
offset_grad, mask_grad, workspace);
auto algo = get_algorithm(this, im.layout, filter.layout, offset.layout,
mask.layout, out_grad.layout, im_grad.layout,
offset_grad.layout, mask_grad.layout);
algo->check_workspace(args, workspace).exec(args);
algo->exec(args);
}
// vim: syntax=cpp.doxygen
......@@ -59,17 +59,17 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout& src,
size_t LocalShareForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst) {
AlgoBase::SizeArgs args(this, src, filter, dst);
return get_algorithm(this, src, filter, dst)->get_workspace_in_bytes(args);
return get_dnn_workspace(this, src, filter, dst);
}
void LocalShareForwardImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_in filter,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, filter.layout, dst.layout, workspace.size);
AlgoBase::ExecArgs args(this, src, filter, dst, workspace);
auto algo = get_algorithm(this, src.layout, filter.layout, dst.layout);
algo->check_workspace(args, workspace).exec(args);
algo->exec(args);
}
const char* LocalShareForwardImpl::get_algorithm_set_name() const {
......@@ -112,8 +112,7 @@ LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
size_t LocalShareBackwardDataImpl::get_workspace_in_bytes(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) {
AlgoBase::SizeArgs args(this, filter, diff, grad);
return get_algorithm(this, filter, diff, grad)->get_workspace_in_bytes(args);
return get_dnn_workspace(this, filter, diff, grad);
}
void LocalShareBackwardDataImpl::exec(_megdnn_tensor_in filter,
......@@ -166,8 +165,7 @@ LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout& src,
size_t LocalShareBackwardFilterImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad) {
AlgoBase::SizeArgs args(this, src, diff, grad);
return get_algorithm(this, src, diff, grad)->get_workspace_in_bytes(args);
return get_dnn_workspace(this, src, diff, grad);
}
void LocalShareBackwardFilterImpl::exec(_megdnn_tensor_in src,
......
......@@ -59,8 +59,7 @@ MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) {
AlgoBase::SizeArgs args{this, A, B, C};
return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args);
return get_dnn_workspace(this, A, B, C);
}
void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
......@@ -69,7 +68,7 @@ void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
check_exec(A.layout, B.layout, C.layout, workspace.size);
AlgoBase::ExecArgs args(this, A, B, C, workspace);
auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout);
algo->check_workspace(args, workspace).exec(args);
algo->exec(args);
}
} // namespace cuda
......
......@@ -21,8 +21,7 @@ namespace cuda {
size_t PoolingForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) {
AlgoBase::SizeArgs args(this, src, dst);
return get_algorithm(this, src, dst)->get_workspace_in_bytes(args);
return get_dnn_workspace(this, src, dst);
}
const char* PoolingForwardImpl::get_algorithm_set_name() const {
......@@ -117,9 +116,7 @@ size_t PoolingBackwardImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& diff,
const TensorLayout& grad) {
AlgoBase::SizeArgs args(this, src, dst, diff, grad);
return get_algorithm(this, src, dst, diff, grad)
->get_workspace_in_bytes(args);
return get_dnn_workspace(this, src, dst, diff, grad);
}
} // namespace cuda
......
......@@ -44,8 +44,7 @@ BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) {
AlgoBase::SizeArgs args{this, A, B, C};
return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args);
return get_dnn_workspace(this, A, B, C);
}
void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
......@@ -54,7 +53,7 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
check_exec(A.layout, B.layout, C.layout, workspace.size);
AlgoBase::ExecArgs args(this, A, B, C, workspace);
auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout);
algo->check_workspace(args, workspace).exec(args);
algo->exec(args);
}
// vim: syntax=cpp.doxygen
......@@ -224,6 +224,15 @@ size_t ConvBiasImpl::get_workspace_in_bytes(
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst,
const PreprocessedFilter* preprocessed_filter) {
TensorLayoutArray layouts{src, filter, bias, z, dst};
HeuristicCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(), &this->param(),
sizeof(this->param())};
auto rst = HeuristicCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
auto fparam = make_ncb_kern_size_param(src, filter, bias, dst,
preprocessed_filter);
auto&& algo = get_algorithm(fparam);
......
......@@ -146,6 +146,15 @@ size_t ConvolutionImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst,
const PreprocessedFilter* preprocessed_filter) {
TensorLayoutArray layouts{src, filter, dst};
HeuristicCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(), &this->param(),
sizeof(this->param())};
auto rst = HeuristicCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
auto fparam =
make_ncb_kern_size_param(src, filter, dst, preprocessed_filter);
auto&& algo = get_algorithm(fparam);
......@@ -494,6 +503,15 @@ void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter,
size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) {
TensorLayoutArray layouts{filter, diff, grad};
HeuristicCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(), &this->param(),
sizeof(this->param())};
auto rst = HeuristicCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
if (param().format == param::Convolution::Format::NHWCD4 ||
param().format == param::Convolution::Format::NCHW4 ||
(param().format == param::Convolution::Format::NCHW &&
......
......@@ -219,6 +219,15 @@ MatrixMulImpl::KernParam MatrixMulImpl::make_kern_param(
size_t MatrixMulImpl::get_workspace_in_bytes(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) {
TensorLayoutArray layouts{A, B, C};
HeuristicCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(), &this->param(),
sizeof(this->param())};
auto rst = HeuristicCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
if (auto algo = get_algorithm_heuristic(
A, B, C, std::numeric_limits<size_t>::max(),
AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT)) {
......
......@@ -15,6 +15,7 @@
#include "src/naive/convolution/helper.h"
#include <cstring>
#include "megdnn/heuristic_cache.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
......@@ -56,6 +57,14 @@ size_t BatchConvBiasForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& flt,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) {
TensorLayoutArray layouts{src, flt, bias, z, dst};
HeuristicCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(), &this->param(),
sizeof(this->param())};
auto rst = HeuristicCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
return get_workspace_bundle(nullptr, src, flt, bias, z, dst)
.total_size_in_bytes();
}
......
......@@ -13,6 +13,7 @@
#include "src/naive/convolution/helper.h"
#include <cstring>
#include "megdnn/heuristic_cache.h"
#include "megdnn/dtype.h"
#include "src/common/conv_bias.h"
#include "src/common/opr_delegate.h"
......@@ -201,6 +202,15 @@ size_t ConvBiasForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& z,
const TensorLayout& dst,
const PreprocessedFilter*) {
TensorLayoutArray layouts{src, flt, bias, z, dst};
HeuristicCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(), &this->param(),
sizeof(this->param())};
auto rst = HeuristicCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
size_t float_workspace_size = 0;
if (z.ndim > 0 && z.dtype.category() != DTypeCategory::FLOAT) {
......
......@@ -11,7 +11,7 @@
#include "./opr_impl.h"
#include "./helper.h"
#include "src/naive/handle.h"
#include "megdnn/heuristic_cache.h"
#include "src/naive/handle.h"
#include "src/common/utils.h"
#include "megdnn/dtype.h"
......@@ -78,6 +78,15 @@ void ConvolutionForwardImpl::exec(_megdnn_tensor_in src,
size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) {
TensorLayoutArray layouts{filter, diff, grad};
HeuristicCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(), &this->param(),
sizeof(this->param())};
auto rst = HeuristicCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
size_t workspace_size = 0;
auto flt_dt = filter.dtype.enumv();
auto grad_dt = grad.dtype.enumv();
......@@ -191,6 +200,15 @@ size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes(
const TensorLayout& grad) {
size_t workspace_size = 0;
#if !MEGDNN_DISABLE_FLOAT16
TensorLayoutArray layouts{src, diff, grad};
HeuristicCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(), &this->param(),
sizeof(this->param())};
auto rst = HeuristicCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
auto src_dt = src.dtype.enumv();
auto grad_dt = grad.dtype.enumv();
auto diff_dt = diff.dtype.enumv();
......
......@@ -12,6 +12,7 @@
#include "src/naive/pooling/opr_impl.h"
#include <cstring>
#include "megdnn/heuristic_cache.h"
#include "megdnn/dtype.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
......@@ -402,6 +403,14 @@ WorkspaceBundle PoolingForwardImpl::get_workspace_bundle(
/**
 * \brief Return the workspace size (in bytes) needed by pooling forward.
 *
 * Consults the global HeuristicCache first: if a heuristic run already
 * recorded a valid execution policy for this handle / operator type /
 * layouts / param combination, its stored workspace size is returned
 * directly. On a cache miss the size is computed from the workspace
 * bundle for (src, dst).
 */
size_t PoolingForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
                                                  const TensorLayout& dst) {
    TensorLayoutArray layouts{src, dst};
    HeuristicCache::Key cache_key{this->handle(), this->get_opr_type(),
                                  layouts.data(), layouts.size(),
                                  &this->param(), sizeof(this->param())};
    auto cached = HeuristicCache::instance().get(cache_key);
    if (cached.policy.algo.valid()) {
        // Cache hit: reuse the workspace size stored with the policy.
        return cached.workspace;
    }
    // Cache miss: compute the bundle size directly.
    return get_workspace_bundle(nullptr, src, dst).total_size_in_bytes();
}
namespace {
......@@ -652,6 +661,14 @@ WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle(
/**
 * \brief Return the workspace size (in bytes) needed by pooling backward.
 *
 * Checks the global HeuristicCache for a previously recorded execution
 * policy keyed by handle / operator type / input layouts / param; a valid
 * hit short-circuits with the cached workspace size. Otherwise the size
 * is derived from the workspace bundle for (src, dst, diff, grad).
 */
size_t PoolingBackwardImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& dst,
        const TensorLayout& diff, const TensorLayout& grad) {
    TensorLayoutArray layouts{src, dst, diff, grad};
    HeuristicCache::Key cache_key{this->handle(), this->get_opr_type(),
                                  layouts.data(), layouts.size(),
                                  &this->param(), sizeof(this->param())};
    auto cached = HeuristicCache::instance().get(cache_key);
    if (cached.policy.algo.valid()) {
        // Cache hit: reuse the workspace size stored with the policy.
        return cached.workspace;
    }
    // Cache miss: compute the bundle size directly.
    return get_workspace_bundle(nullptr, src, dst, diff, grad)
            .total_size_in_bytes();
}
......
......@@ -47,8 +47,7 @@ BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) {
AlgoBase::SizeArgs args{this, A, B, C};
return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args);
return get_dnn_workspace(this, A, B, C);
}
void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
......@@ -57,7 +56,7 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
check_exec(A.layout, B.layout, C.layout, workspace.size);
AlgoBase::ExecArgs args(this, A, B, C, workspace);
auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout);
algo->check_workspace(args, workspace).exec(args);
algo->exec(args);
}
// vim: syntax=cpp.doxygen
......@@ -112,19 +112,30 @@ ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src,
size_t ConvolutionForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, const PreprocessedFilter*) {
TensorLayoutArray layouts{src, filter, dst};
HeuristicCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(), &this->param(),
sizeof(this->param())};
auto rst = HeuristicCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
AlgoBase::SizeArgs args(this, src, filter, dst);
return get_algorithm(this, src, args.filter_meta, dst)
return get_algorithm(this, src, filter, dst)
->get_workspace_in_bytes(args);
}
void ConvolutionForwardImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_in filter,
_megdnn_tensor_out dst,
const PreprocessedFilter*,
const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) {
check_exec(src.layout, filter.layout, dst.layout, workspace.size,
preprocessed_filter);
AlgoBase::ExecArgs args(this, src, filter, dst, workspace);
auto algo = get_algorithm(this, src.layout, args.filter_meta, dst.layout);
algo->check_workspace(args, workspace).exec(args);
auto algo = get_algorithm(this, src.layout, filter.layout, dst.layout);
algo->exec(args);
}
const char* ConvolutionForwardImpl::get_algorithm_set_name() const {
......@@ -137,9 +148,10 @@ void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(filter.layout, diff.layout, grad.layout, workspace.size);
AlgoBase::ExecArgs args(this, filter, diff, grad, workspace);
auto algo = get_algorithm(this, args.filter_meta, diff.layout, grad.layout);
algo->check_workspace(args, workspace).exec(args);
auto algo = get_algorithm(this, filter.layout, diff.layout, grad.layout);
algo->exec(args);
}
std::vector<ConvolutionBackwardDataImpl::Algorithm*>
......@@ -192,8 +204,17 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) {
TensorLayoutArray layouts{filter, diff, grad};
HeuristicCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(), &this->param(),
sizeof(this->param())};
auto rst = HeuristicCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
AlgoBase::SizeArgs args(this, filter, diff, grad);
return get_algorithm(this, args.filter_meta, diff, grad)
return get_algorithm(this, filter, diff, grad)
->get_workspace_in_bytes(args);
}
......@@ -207,10 +228,11 @@ void ConvolutionBackwardFilterImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(src.layout, diff.layout, grad.layout, workspace.size);
AlgoBase::ExecArgs args(this, src, diff, grad, workspace);
auto algo =
get_algorithm(this, src.layout, diff.layout, args.grad_filter_meta);
algo->check_workspace(args, workspace).exec(args);
get_algorithm(this, src.layout, diff.layout, grad.layout);
algo->exec(args);
}
std::vector<ConvolutionBackwardFilterImpl::Algorithm*>
......@@ -264,8 +286,17 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) {
TensorLayoutArray layouts{src, diff, grad};
HeuristicCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(), &this->param(),
sizeof(this->param())};
auto rst = HeuristicCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
AlgoBase::SizeArgs args(this, src, diff, grad);
return get_algorithm(this, src, diff, args.grad_filter_meta)
return get_algorithm(this, src, diff, grad)
->get_workspace_in_bytes(args);
}
......
......@@ -24,7 +24,7 @@ public:
const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) override;
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& src, const CanonizedFilterMeta& filter,
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
......@@ -95,7 +95,7 @@ public:
void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
AlgorithmInfo get_algorithm_info_heuristic(
const CanonizedFilterMeta& filter, const TensorLayout& diff,
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
......@@ -145,7 +145,7 @@ public:
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
return get_algorithm_heuristic(src, diff, grad,
......
......@@ -44,8 +44,7 @@ MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) {
AlgoBase::SizeArgs args{this, A, B, C};
return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args);
return get_dnn_workspace(this, A, B, C);
}
void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
......@@ -54,7 +53,7 @@ void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
check_exec(A.layout, B.layout, C.layout, workspace.size);
AlgoBase::ExecArgs args(this, A, B, C, workspace);
auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout);
algo->check_workspace(args, workspace).exec(args);
algo->exec(args);
}
// vim: syntax=cpp.doxygen
......@@ -19,8 +19,7 @@ namespace rocm {
size_t PoolingForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) {
AlgoBase::SizeArgs args(this, src, dst);
return get_algorithm(this, src, dst)->get_workspace_in_bytes(args);
return get_dnn_workspace(this, src, dst);
}
const char* PoolingForwardImpl::get_algorithm_set_name() const {
......@@ -69,9 +68,7 @@ size_t PoolingBackwardImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& diff,
const TensorLayout& grad) {
AlgoBase::SizeArgs args(this, src, dst, diff, grad);
return get_algorithm(this, src, dst, diff, grad)
->get_workspace_in_bytes(args);
return get_dnn_workspace(this, src, dst, diff, grad);
};
const char* PoolingBackwardImpl::get_algorithm_set_name() const {
......
......@@ -46,6 +46,15 @@ WorkspaceBundle megdnn::x86::get_bundle(const TensorLayout& src,
size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) {
TensorLayoutArray layouts{src, dst};
HeuristicCache::Key key{this->handle(), this->get_opr_type(),
layouts.data(), layouts.size(), &this->param(),
sizeof(this->param())};
auto rst = HeuristicCache::instance().get(key);
if (rst.policy.algo.valid()) {
return rst.workspace;
}
auto algo = get_algorithm(this, src, dst);
if (!is_fallback_algo(algo)) {
if (is_supported(SIMDType::SSE) && src.dtype == dtype::Float32() &&
......
......@@ -29,6 +29,7 @@
#include "megbrain/plugin/profiler.h"
#include "megbrain/test/helper.h"
#include "megdnn/heuristic_cache.h"
#include "megdnn/oprs/base.h"
#include <atomic>
......@@ -2075,10 +2076,12 @@ void test_free_memory_in_weight_preprocess(int record_level, CompNode cn) {
TEST(TestGraph, FreeMemoryInWeightPreprocess) {
test_free_memory_in_weight_preprocess(0, CompNode::load("xpu0"));
megdnn::HeuristicCache::instance().clear();
}
TEST(TestGraph, RecordFreeMemoryInWeightPreprocess) {
test_free_memory_in_weight_preprocess(1, CompNode::load("cpu0"));
megdnn::HeuristicCache::instance().clear();
}
namespace {
......@@ -2157,6 +2160,7 @@ TEST(TestGraph, FreeMemoryInWeightPreprocessWithValueInfer) {
->cast_final_safe<opr::SharedDeviceTensor>()
.get_dev_tensor()
.empty());
megdnn::HeuristicCache::instance().clear();
}
TEST(TestGraph, FreeMemoryInWeightPreprocessWithMultiReader) {
......@@ -2200,6 +2204,7 @@ TEST(TestGraph, FreeMemoryInWeightPreprocessWithMultiReader) {
->cast_final_safe<opr::SharedDeviceTensor>()
.get_dev_tensor()
.empty());
megdnn::HeuristicCache::instance().clear();
}
TEST(TestGraph, FreeBias) {
......
......@@ -24,6 +24,7 @@
//! TODO: here has to be know some megdnn::opr when there is produced midout.h
//! fix it if there is another graceful way.
#include "megdnn/heuristic_cache.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "megdnn/oprs/base.h"
......@@ -1156,6 +1157,15 @@ template <typename Opr>
size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts,
Opr* megdnn_opr, const MGBOpr* mgb_opr,
bool allow_weight_preprocess) {
HeuristicCache::Key cache_key(
megdnn_opr->handle(), megdnn_opr->get_opr_type(), layouts.data(),
layouts.size(), &megdnn_opr->param(), sizeof(megdnn_opr->param()));
auto rst = HeuristicCache::instance().get(cache_key);
if (rst.policy.algo.valid()) {
megdnn_opr->execution_policy() = rst.policy;
return rst.workspace;
}
if (WorkspaceLimitGetter::is_prealloc_run(mgb_opr->owner_graph())) {
return 0;
}
......@@ -1192,6 +1202,11 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts,
mgb_log_debug("%s", ret.c_str());
megdnn_opr->execution_policy() = policy;
if (mgb_opr->execution_policy().strategy & ExecutionStrategy::HEURISTIC) {
HeuristicCache::Result cache_result{policy, workspace};
HeuristicCache::instance().put(cache_key, cache_result);
}
return workspace;
}
......
......@@ -22,6 +22,7 @@
#include "megbrain/opr/tensor_manip.h"
#include "megdnn/oprs/base.h"
#include "megdnn/dtype.h"
#include "megdnn/heuristic_cache.h"
#include <cmath>
#include <random>
......@@ -337,6 +338,7 @@ void test_no_profiling_on_shape_change(const TensorShapeArray& inps0,
TEST(TestOprDNN, FastrunNoProfilingOnShapeChange) {
REQUIRE_GPU(1);
megdnn::HeuristicCache::instance().clear();
test_no_profiling_on_shape_change<opr::Convolution>(
{{12, 3, 36, 36}, {4, 3, 3, 3}}, {{32, 3, 28, 28}, {4, 3, 3, 3}});
......
......@@ -21,6 +21,7 @@
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/tensor_manip.h"
#include "megdnn/dtype.h"
#include "megdnn/heuristic_cache.h"
#include "megdnn/oprs/base.h"
#include <gmock/gmock.h>
......@@ -396,6 +397,7 @@ TEST(TestOprDNN, ConvBiasExePolicy) {
#endif
run(strategy);
}
megdnn::HeuristicCache::instance().clear();
ASSERT_THROW(run(S::OPTIMIZED | S::PROFILE), MegBrainError);
PersistentCache::set_impl(orig_impl);
}
......@@ -460,6 +462,7 @@ TEST(TestOprDNN, ConvolutionExePolicy) {
for (auto strategy :
SmallVector<S>{S::HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif
megdnn::HeuristicCache::instance().clear();
using Checker = AutoOprChecker<2, 1>;
auto make_graph = [&](const Checker::SymInpArray& inputs)
......@@ -489,6 +492,7 @@ TEST(TestOprDNN, ConvolutionExePolicy) {
} else {
ASSERT_LT(0, nr_get);
}
megdnn::HeuristicCache::instance().clear();
}
}
......@@ -544,6 +548,7 @@ TEST(TestOprDNN, ConvolutionBackwardDataBfloat16ExePolicy) {
#else
for (auto strategy: {S::HEURISTIC, S(S::PROFILE | S::HEURISTIC)}) {
#endif
megdnn::HeuristicCache::instance().clear();
using Checker = AutoOprChecker<2, 1>;
auto make_graph = [&](const Checker::SymInpArray& inputs)
......@@ -1835,6 +1840,7 @@ TEST(TestOprDNN, LocalShareForwardExecPolicy) {
auto run_with_param = [&](size_t fh = 3, size_t fw = 3, size_t sh = 1,
size_t sw = 1, size_t sgh = 3,
size_t sgw = 3) {
megdnn::HeuristicCache::instance().clear();
size_t ph = fh / 2, pw = fw / 2;
param.pad_h = ph, param.pad_w = pw;
param.stride_h = sh, param.stride_w = sw,
......@@ -2289,6 +2295,7 @@ TEST(TestOprDNN, HeuristicReproducible) {
}
algo_name0 = palgo->name();
}
megdnn::HeuristicCache::instance().clear();
{
Checker checker(make_graph, fwd);
checker.run(inp_tensor(2, 3, 4, 9, 8, 3, 3), opt)
......@@ -2306,6 +2313,7 @@ TEST(TestOprDNN, HeuristicReproducible) {
algo_name1 = palgo->name();
}
EXPECT_TRUE(algo_name0 == algo_name1);
megdnn::HeuristicCache::instance().clear();
}
#undef inp_tensor
#undef get_shp
......@@ -2585,6 +2593,7 @@ TEST_F(TestWeightPreprocess, NoPreprocessNeeded) {
}
TEST_F(TestWeightPreprocess, PreprocessCalledOnlyOnce) {
megdnn::HeuristicCache::instance().clear();
using ::testing::_;
using ::testing::Return;
using ::testing::Field;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册