提交 8585aa61 编写于 作者: M Megvii Engine Team

fix(mgb): fix fast run crash when profile heuristic strategy

GitOrigin-RevId: 6046a2db0c532b33c78f20c4aa1aa7f7df1af0e4
上级 ef9aa800
...@@ -25,79 +25,9 @@ ...@@ -25,79 +25,9 @@
using namespace mgb; using namespace mgb;
namespace { // ================= PersistentCache ======================
//! Thread-safe PersistentCache implementation that keeps all entries in
//! process memory (nothing survives across runs).
class InMemoryPersistentCache final: public PersistentCache {
    //! a Blob that owns a deep copy of its bytes and memoizes their hash
    struct BlobStorage: public Blob {
        std::unique_ptr<uint8_t[]> data_refhold;  // owning copy of the bytes
        size_t hash = 0;

        //! deep-copy \p b into owned storage; a trailing NUL is appended so
        //! the bytes can also be consumed as a C string
        BlobStorage& init_data_ref(const Blob &b) {
            data_refhold = std::make_unique<uint8_t[]>(b.size + 1);
            memcpy(data_refhold.get(), b.ptr, b.size);
            data_refhold.get()[b.size] = 0; // for C-string safety
            ptr = data_refhold.get();
            size = b.size;
            return *this;
        }

        //! compute and cache the hash of the current (ptr, size) range
        BlobStorage& init_hash() {
            hash = XXHash{}.update(ptr, size).digest();
            return *this;
        }

        bool operator == (const BlobStorage &rhs) const {
            return size == rhs.size && !memcmp(ptr, rhs.ptr, size);
        }

        //! hasher for unordered_map; reuses the precomputed hash
        struct Hash {
            size_t operator() (const BlobStorage &b) const {
                return b.hash;
            }
        };
    };
    //! category -> (key -> value); guarded by m_mtx
    std::unordered_map<std::string,
            std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>>
            m_cache;
    std::mutex m_mtx;

    //! look up \p key under \p category; None if either is absent
    Maybe<Blob> get(const std::string& category, const Blob& key) override {
        // hash the caller's buffer before taking the lock: it reads no
        // shared state, so the critical section stays minimal
        BlobStorage key_storage;
        key_storage.Blob::operator=(key);
        key_storage.init_hash();

        // FIX: hold the lock across BOTH lookups. The previous version
        // released m_mtx after finding the category and re-acquired it for
        // the key lookup; a concurrent put() could rehash m_cache in
        // between and invalidate the category iterator (undefined
        // behavior per unordered_map iterator-invalidation rules).
        MGB_LOCK_GUARD(m_mtx);
        auto iter0 = m_cache.find(category);
        if (iter0 == m_cache.end())
            return None;
        auto iter1 = iter0->second.find(key_storage);
        if (iter1 == iter0->second.end())
            return None;
        return iter1->second;
    }

    //! insert or overwrite \p key -> \p value under \p category
    void put(const std::string& category, const Blob& key,
            const Blob& value) override {
        // copy + hash the key outside the lock
        BlobStorage key_storage;
        key_storage.init_data_ref(key).init_hash();
        MGB_LOCK_GUARD(m_mtx);
        auto size0 = m_cache.size();
        m_cache[category][std::move(key_storage)].init_data_ref(value);
        if (m_cache.size() > size0) {
            mgb_log_debug("new cache category: %s", category.c_str());
        }
    }
};
}
std::shared_ptr<PersistentCache> PersistentCache::sm_impl = std::shared_ptr<PersistentCache> PersistentCache::sm_impl =
std::make_shared<InMemoryPersistentCache>(); std::make_shared<InMemoryPersistentCache>();
std::shared_ptr<PersistentCache> PersistentCache::set_impl( std::shared_ptr<PersistentCache> PersistentCache::set_impl(
std::shared_ptr<PersistentCache> impl) { std::shared_ptr<PersistentCache> impl) {
...@@ -141,6 +71,65 @@ std::string PersistentCache::make_category_from_comp_node(CompNode comp_node) { ...@@ -141,6 +71,65 @@ std::string PersistentCache::make_category_from_comp_node(CompNode comp_node) {
} }
} }
// ================= InMemoryPersistentCache ==================
using Blob = PersistentCache::Blob;

//! Take a deep, owned copy of \p b and point this Blob at it. The buffer
//! is one byte larger than the payload and NUL-terminated so the contents
//! can also be read as a C string.
InMemoryPersistentCache::BlobStorage&
InMemoryPersistentCache::BlobStorage::init_data_ref(const Blob& b) {
    auto buf = std::make_unique<uint8_t[]>(b.size + 1);
    memcpy(buf.get(), b.ptr, b.size);
    buf[b.size] = 0;  // NUL terminator for C-string safety
    data_refhold = std::move(buf);
    ptr = data_refhold.get();
    size = b.size;
    return *this;
}
//! Compute the hash of the current (ptr, size) range and memoize it in
//! \c hash so unordered_map lookups do not rehash the payload.
InMemoryPersistentCache::BlobStorage&
InMemoryPersistentCache::BlobStorage::init_hash() {
    XXHash hasher;
    hasher.update(ptr, size);
    hash = hasher.digest();
    return *this;
}
//! Byte-wise equality: equal sizes and identical contents.
bool InMemoryPersistentCache::BlobStorage::operator==(
        const BlobStorage& rhs) const {
    if (size != rhs.size)
        return false;
    return memcmp(ptr, rhs.ptr, size) == 0;
}
/*!
 * \brief look up \p key under \p category
 * \return the cached value, or None if the category or the key is absent
 */
Maybe<Blob> InMemoryPersistentCache::get(const std::string& category,
                                         const Blob& key) {
    // hash the caller's buffer before taking the lock: it reads no shared
    // state, so the critical section stays minimal
    BlobStorage key_storage;
    key_storage.Blob::operator=(key);
    key_storage.init_hash();

    // FIX: hold the lock across BOTH lookups. The original released m_mtx
    // after finding the category and re-acquired it for the key lookup; a
    // concurrent put() could rehash m_cache in between and invalidate the
    // category iterator (undefined behavior per unordered_map
    // iterator-invalidation rules).
    MGB_LOCK_GUARD(m_mtx);
    auto iter0 = m_cache.find(category);
    if (iter0 == m_cache.end())
        return None;
    auto iter1 = iter0->second.find(key_storage);
    if (iter1 == iter0->second.end())
        return None;
    return iter1->second;
}
//! Insert or overwrite the entry \p key -> \p value under \p category.
//! Both key and value are deep-copied into the cache.
void InMemoryPersistentCache::put(const std::string& category, const Blob& key,
                                  const Blob& value) {
    // copy and hash the key outside the critical section
    BlobStorage key_storage;
    key_storage.init_data_ref(key).init_hash();

    MGB_LOCK_GUARD(m_mtx);
    auto nr_category_before = m_cache.size();
    m_cache[category][std::move(key_storage)].init_data_ref(value);
    // operator[] may have created a brand-new category map
    if (m_cache.size() != nr_category_before) {
        mgb_log_debug("new cache category: %s", category.c_str());
    }
}
// ================= AlgoChooserProfileCache ==================
AlgoChooserProfileCache::AlgoChooserProfileCache( AlgoChooserProfileCache::AlgoChooserProfileCache(
CompNode cn, const char *opr_type) { CompNode cn, const char *opr_type) {
m_category = "profile:"; m_category = "profile:";
......
...@@ -55,6 +55,37 @@ namespace mgb { ...@@ -55,6 +55,37 @@ namespace mgb {
static std::string make_category_from_comp_node(CompNode comp_node); static std::string make_category_from_comp_node(CompNode comp_node);
}; };
/*!
* \brief persistent cache that keep in memory
* The implementation is thread safe.
*/
class InMemoryPersistentCache final : public PersistentCache {
    //! cache entry: a Blob that owns a deep copy of its bytes and
    //! memoizes the hash of its contents
    struct BlobStorage : public PersistentCache::Blob {
        std::unique_ptr<uint8_t[]> data_refhold;  // owning copy of the bytes
        size_t hash = 0;

        //! deep-copy \p b into owned (NUL-terminated) storage and point
        //! ptr/size at the copy
        BlobStorage& init_data_ref(const Blob& b);

        //! compute and memoize the hash of the current (ptr, size) range
        BlobStorage& init_hash();

        bool operator==(const BlobStorage& rhs) const;

        //! hasher for unordered_map; reuses the precomputed hash
        struct Hash {
            size_t operator()(const BlobStorage& b) const { return b.hash; }
        };
    };

    Maybe<Blob> get(const std::string& category, const Blob& key) override;
    void put(const std::string& category, const Blob& key,
             const Blob& value) override;

    //! category -> (key -> value); guarded by m_mtx
    std::unordered_map<
            std::string,
            std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>>
            m_cache;
    std::mutex m_mtx;
};
/*! /*!
* \brief proxy PersistentCache to be better suited for managing profiling * \brief proxy PersistentCache to be better suited for managing profiling
* results of operator impl algorithms * results of operator impl algorithms
......
...@@ -68,7 +68,6 @@ std::string format_fixlayouts( ...@@ -68,7 +68,6 @@ std::string format_fixlayouts(
ret.append(", "); ret.append(", ");
} }
ret.append(layouts[i].to_string() + " "); ret.append(layouts[i].to_string() + " ");
ret.append(layouts[i].dtype.name());
} }
ret.append(") -> ("); ret.append(") -> (");
for (size_t i = 0; i < arity_out; ++i) { for (size_t i = 0; i < arity_out; ++i) {
...@@ -76,7 +75,6 @@ std::string format_fixlayouts( ...@@ -76,7 +75,6 @@ std::string format_fixlayouts(
ret.append(", "); ret.append(", ");
} }
ret.append(layouts[i + arity_in].to_string() + " "); ret.append(layouts[i + arity_in].to_string() + " ");
ret.append(layouts[i + arity_in].dtype.name());
} }
return ret; return ret;
} }
...@@ -420,6 +418,7 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, ...@@ -420,6 +418,7 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx,
AlgoChooser<_Opr>::profile(sub_ctx, selected_strategy); AlgoChooser<_Opr>::profile(sub_ctx, selected_strategy);
}); });
} }
typename AlgoChooser<Opr>::ImplExecutionPolicy policy; typename AlgoChooser<Opr>::ImplExecutionPolicy policy;
ctx.construct_execution_policy(selected_strategy, policy); ctx.construct_execution_policy(selected_strategy, policy);
return policy; return policy;
...@@ -660,8 +659,28 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( ...@@ -660,8 +659,28 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
bool retrive_from_cache) const { bool retrive_from_cache) const {
if (!policy.algo.valid()) { if (!policy.algo.valid()) {
if (retrive_from_cache) { if (retrive_from_cache) {
policy.algo = policy.algo = get_profile_result_from_cache(selected_strategy).desc;
get_profile_result_from_cache(selected_strategy).desc; if (!policy.algo.valid()) {
auto target_attr =
extract_algo_attribute_from_execution_strategy(
selected_strategy);
std::string layouts_str =
format_fixlayouts<Opr>(m_layouts, arity_in, arity_out);
std::string msg = ssprintf(
"(mbg_opr : %s, layouts %s, with attribute(%s) and "
"without attribute(%s)",
m_base_mgb_opr->dyn_typeinfo()->name,
layouts_str.c_str(),
Algorithm::attribute_str(target_attr.first).c_str(),
Algorithm::attribute_str(target_attr.second).c_str());
mgb_log_warn(
"No algo get from cache for %s. This may caused by "
"mismatch with model and cache file. ex. profiling "
"with version1, but inferencing on version2 or "
"profiling modelA but inferencing modelB",
msg.c_str());
return;
}
} else { } else {
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit); owner_graph(), m_cn, m_execution_policy.workspace_limit);
...@@ -673,10 +692,12 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( ...@@ -673,10 +692,12 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
attr.second), attr.second),
m_layouts) m_layouts)
.desc; .desc;
mgb_assert(policy.algo.valid(),
"No algo found from heuristic with strategy %u and "
"workspace limit %zu",
static_cast<uint32_t>(selected_strategy),
workspace_limit);
} }
mgb_assert(policy.algo.valid(),
"No algo found from cache or heuristic, maybe some error "
"occured");
} }
Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo);
...@@ -697,9 +718,13 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( ...@@ -697,9 +718,13 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
sub_ctx.construct_execution_policy(selected_strategy, sub_ctx.construct_execution_policy(selected_strategy,
policy.sub_policy.back(), policy.sub_policy.back(),
retrive_from_cache); retrive_from_cache);
if (!policy.sub_policy.back().algo.valid()) {
// means sub_ctx.construct_execution_policy fails. clean up
// policy.algo and return
policy = {};
return;
}
}); });
return;
} }
template <typename Opr> template <typename Opr>
......
...@@ -140,9 +140,10 @@ public: ...@@ -140,9 +140,10 @@ public:
* \brief construct execution policy from cache or heuristic. * \brief construct execution policy from cache or heuristic.
* *
* \param selected_strategy select algo which matched this strategy * \param selected_strategy select algo which matched this strategy
* \param policy execution policy * \param [out] policy execution policy
* \param retrive_from_cache retrive algo from cache if set True, get * \param retrive_from_cache retrive algo from cache if set True, get
* from heuristic otherwise. * from heuristic otherwise.
* \note When contruction fail, the policy will be cleaned.
*/ */
void construct_execution_policy(ExecutionStrategy selected_strategy, void construct_execution_policy(ExecutionStrategy selected_strategy,
ImplExecutionPolicy& policy, ImplExecutionPolicy& policy,
...@@ -152,14 +153,13 @@ public: ...@@ -152,14 +153,13 @@ public:
Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const; Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const;
}; };
template<typename U> template <typename U>
friend class AlgoChooser; friend class AlgoChooser;
private: private:
//! entrance for getting algorithm according to execution strategy //! entrance for getting algorithm according to execution strategy
static ImplExecutionPolicy get_policy(ExeContext& ctx); static ImplExecutionPolicy get_policy(ExeContext& ctx);
//! profile and save to cache //! profile and save to cache
static void profile(ExeContext& ctx, ExecutionStrategy selected_strategy); static void profile(ExeContext& ctx, ExecutionStrategy selected_strategy);
......
...@@ -30,7 +30,6 @@ ...@@ -30,7 +30,6 @@
#include <random> #include <random>
using namespace mgb; using namespace mgb;
namespace { namespace {
using Param = opr::Convolution::Param; using Param = opr::Convolution::Param;
...@@ -354,21 +353,26 @@ TEST(TestOprDNN, ConvBiasExePolicy) { ...@@ -354,21 +353,26 @@ TEST(TestOprDNN, ConvBiasExePolicy) {
auto cn = CompNode::load("cpux"); auto cn = CompNode::load("cpux");
auto orig_impl = PersistentCache::set_impl(
std::make_shared<InMemoryPersistentCache>());
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
for (auto strategy : for (auto strategy :
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) { S::PROFILE | S::HEURISTIC}) {
#else #else
for (auto strategy : for (auto strategy :
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) { SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif #endif
auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();
HostTensorGenerator<> gen; HostTensorGenerator<> gen;
auto mkvar = [&](const char* name, const TensorShape& shp, auto mkvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) { const DType& dtype) {
return opr::TypeCvt::make( return opr::TypeCvt::make(
opr::Host2DeviceCopy::make(*graph, gen(shp), cn).rename(name), opr::Host2DeviceCopy::make(*graph, gen(shp), cn)
.rename(name),
dtype); dtype);
}; };
...@@ -388,7 +392,11 @@ TEST(TestOprDNN, ConvBiasExePolicy) { ...@@ -388,7 +392,11 @@ TEST(TestOprDNN, ConvBiasExePolicy) {
HostTensorND host_y; HostTensorND host_y;
auto func = graph->compile({make_callback_copy(conv_bias, host_y)}); auto func = graph->compile({make_callback_copy(conv_bias, host_y)});
func->execute(); func->execute();
//! set a new cache
PersistentCache::set_impl(std::make_shared<InMemoryPersistentCache>());
} }
PersistentCache::set_impl(orig_impl);
} }
TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) { TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) {
...@@ -401,19 +409,21 @@ TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) { ...@@ -401,19 +409,21 @@ TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) {
for (auto strategy : for (auto strategy :
SmallVector<S>{S::PROFILE, S::PROFILE | S::REPRODUCIBLE}) { SmallVector<S>{S::PROFILE, S::PROFILE | S::REPRODUCIBLE}) {
auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();
HostTensorGenerator<> gen; HostTensorGenerator<> gen;
auto mkvar = [&](const char* name, const TensorShape& shp, auto mkvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) { const DType& dtype) {
return opr::TypeCvt::make( return opr::TypeCvt::make(
opr::Host2DeviceCopy::make(*graph, gen(shp), cn).rename(name), opr::Host2DeviceCopy::make(*graph, gen(shp), cn)
.rename(name),
dtype); dtype);
}; };
auto x = mkvar("x", {20, 50, 50, 16}, dtype::Quantized8Asymm(2.5f, static_cast<uint8_t>(0))); auto x = mkvar("x", {20, 50, 50, 16},
auto w = mkvar("w", {24, 3, 3, 16}, dtype::Quantized8Asymm(2.5f, static_cast<uint8_t>(0))); dtype::Quantized8Asymm(2.5f, static_cast<uint8_t>(0)));
auto w = mkvar("w", {24, 3, 3, 16},
dtype::Quantized8Asymm(2.5f, static_cast<uint8_t>(0)));
auto bias = mkvar("bias", {1, 1, 1, 24}, dtype::QuantizedS32(6.25f)); auto bias = mkvar("bias", {1, 1, 1, 24}, dtype::QuantizedS32(6.25f));
param.nonlineMode = Param::NonlineMode::RELU; param.nonlineMode = Param::NonlineMode::RELU;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册