提交 cfed86f9 编写于 作者: M Megvii Engine Team

feat(persistentcache): change file persistent cache with append model

GitOrigin-RevId: 7a427bdab4198a48190f0ecd7a45f4c0c889be8c
上级 d19fc2c1
...@@ -120,22 +120,13 @@ InFilePersistentCache::BlobStorage& InFilePersistentCache::BlobStorage::init_fro ...@@ -120,22 +120,13 @@ InFilePersistentCache::BlobStorage& InFilePersistentCache::BlobStorage::init_fro
return *this; return *this;
} }
template <typename OutputFile>
void InFilePersistentCache::BlobStorage::write_to_file(OutputFile& out_file) const { void InFilePersistentCache::BlobStorage::write_to_file(OutputFile& out_file) const {
uint32_t u_size = size; uint32_t u_size = size;
out_file.write(u_size); out_file.write(u_size);
out_file.write(data_refhold.get(), u_size); out_file.write(data_refhold.get(), u_size);
} }
/*!
 * \brief make this storage own a private copy of the bytes described by \p b
 *
 * The copy is one byte longer than \p b and ends with a 0 byte, so the data
 * can also be read as a C string. ptr and size are updated to describe the
 * copy.
 *
 * \param b blob to copy from; it may be empty and may alias this storage
 * \return *this, to allow call chaining
 */
InFilePersistentCache::BlobStorage& InFilePersistentCache::BlobStorage::init_data_ref(
        const Blob& b) {
    const size_t nr_bytes = b.size;
    // build the copy in a separate buffer first: the previous buffer must stay
    // alive while it may still be the source of the copy
    auto fresh = std::make_unique<uint8_t[]>(nr_bytes + 1);
    // memcpy must not be handed a null source pointer, even for zero bytes
    if (nr_bytes) {
        memcpy(fresh.get(), b.ptr, nr_bytes);
    }
    fresh[nr_bytes] = 0;  // for C-string safety
    data_refhold.swap(fresh);
    ptr = data_refhold.get();
    size = nr_bytes;
    return *this;
}
//////////////////////// InFilePersistentCache ////////////////////// //////////////////////// InFilePersistentCache //////////////////////
template <typename Input> template <typename Input>
......
...@@ -21,10 +21,62 @@ std::shared_ptr<PersistentCache> PersistentCache::sm_impl = ...@@ -21,10 +21,62 @@ std::shared_ptr<PersistentCache> PersistentCache::sm_impl =
std::shared_ptr<PersistentCache> PersistentCache::set_impl( std::shared_ptr<PersistentCache> PersistentCache::set_impl(
std::shared_ptr<PersistentCache> impl) { std::shared_ptr<PersistentCache> impl) {
mgb_assert(impl); mgb_assert(impl);
merge_old_cache(impl);
sm_impl.swap(impl); sm_impl.swap(impl);
return impl; return impl;
} }
void PersistentCache::merge_old_cache(std::shared_ptr<PersistentCache> impl) {
MGB_LOCK_GUARD(PersistentCache::inst().m_mtx);
if (sm_impl) {
auto& old_cache = sm_impl->m_cache;
if (old_cache.size() > 0) {
mgb_log_debug("find old persistent cache, now append to it!!");
auto& new_cache = impl->m_cache;
CacheMap tmp_cache;
//! CacheMap do not imp deepcopy and = operator, so we insert manually
auto insert = [](CacheMap& dst, CacheMap& in) {
for (auto& x : in) {
auto category = x.first;
for (auto& y : x.second) {
auto& key = y.first;
auto& value = y.second;
BlobStorage key_storage;
key_storage.init_data_ref(key).init_hash();
dst[category][std::move(key_storage)].init_data_ref(value);
}
}
};
insert(tmp_cache, old_cache);
insert(tmp_cache, new_cache);
impl->m_cache = std::move(tmp_cache);
} else {
mgb_log_debug("do not find any old persistent cache");
}
}
}
/*!
 * \brief make this storage own a private copy of the bytes described by \p b
 *
 * The copy is one byte longer than \p b and ends with a 0 byte, so the data
 * can also be read as a C string. ptr and size are updated to describe the
 * copy.
 *
 * \param b blob to copy from; it may be empty and may alias this storage
 * \return *this, to allow call chaining
 */
PersistentCache::BlobStorage& PersistentCache::BlobStorage::init_data_ref(
        const Blob& b) {
    const size_t nr_bytes = b.size;
    // build the copy in a separate buffer first: the previous buffer must stay
    // alive while it may still be the source of the copy
    auto fresh = std::make_unique<uint8_t[]>(nr_bytes + 1);
    // memcpy must not be handed a null source pointer, even for zero bytes
    if (nr_bytes) {
        memcpy(fresh.get(), b.ptr, nr_bytes);
    }
    fresh[nr_bytes] = 0;  // for C-string safety
    data_refhold.swap(fresh);
    ptr = data_refhold.get();
    size = nr_bytes;
    return *this;
}
//! Store the XXHash digest of the size bytes at ptr in hash and return *this
//! so that calls can be chained.
PersistentCache::BlobStorage& PersistentCache::BlobStorage::init_hash() {
    const auto digest = XXHash{}.update(ptr, size).digest();
    hash = digest;
    return *this;
}
//! Two storages are equal when they have the same size and identical bytes;
//! the cached hash takes no part in the comparison.
bool PersistentCache::BlobStorage::operator==(const BlobStorage& rhs) const {
    if (size != rhs.size) {
        return false;
    }
    return memcmp(ptr, rhs.ptr, size) == 0;
}
std::string PersistentCache::make_category_from_comp_node(CompNode comp_node) { std::string PersistentCache::make_category_from_comp_node(CompNode comp_node) {
auto&& env = CompNodeEnv::from_comp_node(comp_node); auto&& env = CompNodeEnv::from_comp_node(comp_node);
switch (env.property().type) { switch (env.property().type) {
...@@ -65,26 +117,6 @@ std::string PersistentCache::make_category_from_comp_node(CompNode comp_node) { ...@@ -65,26 +117,6 @@ std::string PersistentCache::make_category_from_comp_node(CompNode comp_node) {
// ================= InMemoryPersistentCache ================== // ================= InMemoryPersistentCache ==================
using Blob = PersistentCache::Blob; using Blob = PersistentCache::Blob;
/*!
 * \brief make this storage own a private copy of the bytes described by \p b
 *
 * The copy is one byte longer than \p b and ends with a 0 byte, so the data
 * can also be read as a C string. ptr and size are updated to describe the
 * copy.
 *
 * \param b blob to copy from; it may be empty and may alias this storage
 * \return *this, to allow call chaining
 */
InMemoryPersistentCache::BlobStorage& InMemoryPersistentCache::BlobStorage::
        init_data_ref(const Blob& b) {
    const size_t nr_bytes = b.size;
    // build the copy in a separate buffer first: the previous buffer must stay
    // alive while it may still be the source of the copy
    auto fresh = std::make_unique<uint8_t[]>(nr_bytes + 1);
    // memcpy must not be handed a null source pointer, even for zero bytes
    if (nr_bytes) {
        memcpy(fresh.get(), b.ptr, nr_bytes);
    }
    fresh[nr_bytes] = 0;  // for C-string safety
    data_refhold.swap(fresh);
    ptr = data_refhold.get();
    size = nr_bytes;
    return *this;
}
//! Store the XXHash digest of the size bytes at ptr in hash and return *this
//! so that calls can be chained.
InMemoryPersistentCache::BlobStorage& InMemoryPersistentCache::BlobStorage::
        init_hash() {
    const auto digest = XXHash{}.update(ptr, size).digest();
    hash = digest;
    return *this;
}
//! Two storages are equal when they have the same size and identical bytes;
//! the cached hash takes no part in the comparison.
bool InMemoryPersistentCache::BlobStorage::operator==(const BlobStorage& rhs) const {
    if (size != rhs.size) {
        return false;
    }
    return memcmp(ptr, rhs.ptr, size) == 0;
}
Maybe<Blob> InMemoryPersistentCache::get(const std::string& category, const Blob& key) { Maybe<Blob> InMemoryPersistentCache::get(const std::string& category, const Blob& key) {
decltype(m_cache.begin()) iter0; decltype(m_cache.begin()) iter0;
{ {
......
...@@ -17,33 +17,6 @@ class InFilePersistentCache final : public PersistentCache { ...@@ -17,33 +17,6 @@ class InFilePersistentCache final : public PersistentCache {
class InputFile; class InputFile;
class InputMemory; class InputMemory;
class OutputFile; class OutputFile;
struct BlobStorage : public Blob {
std::unique_ptr<uint8_t[]> data_refhold;
size_t hash = 0;
template <typename Input>
BlobStorage& init_from_input(Input& inp);
void write_to_file(OutputFile& out_file) const;
BlobStorage& init_data_ref(const Blob& b);
BlobStorage& init_hash() {
hash = XXHash{}.update(ptr, size).digest();
return *this;
}
bool operator==(const BlobStorage& rhs) const {
return size == rhs.size && !memcmp(ptr, rhs.ptr, size);
}
struct Hash {
size_t operator()(const BlobStorage& b) const { return b.hash; }
};
};
//! cache table: category name -> (key storage -> value storage)
std::unordered_map<
std::string,
std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>>
m_cache;
//! NOTE(review): presumably guards m_cache in get()/put() — confirm against
//! the implementation file
MGB_MUTEX m_mtx;
std::shared_ptr<OutputFile> m_always_open_file; std::shared_ptr<OutputFile> m_always_open_file;
template <typename Input> template <typename Input>
...@@ -68,13 +41,6 @@ public: ...@@ -68,13 +41,6 @@ public:
MGE_WIN_DECLSPEC_FUC void put( MGE_WIN_DECLSPEC_FUC void put(
const std::string& category, const Blob& key, const Blob& value) override; const std::string& category, const Blob& key, const Blob& value) override;
bool support_dump_cache() override { return true; } bool support_dump_cache() override { return true; }
//! Hand the whole cache table over to the caller and leave m_cache empty.
std::unordered_map<
        std::string,
        std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>>
get_cache() {
    auto taken = std::move(m_cache);
    // a moved-from container is only guaranteed to be in a valid state, so
    // empty it explicitly instead of relying on the library implementation
    m_cache.clear();
    return taken;
}
}; };
} // namespace mgb } // namespace mgb
......
...@@ -23,6 +23,41 @@ public: ...@@ -23,6 +23,41 @@ public:
size_t size; size_t size;
}; };
//! A Blob extended with an owning byte buffer and a cached hash, usable as
//! both key and value type of CacheMap.
struct BlobStorage : public Blob {
    //! owning buffer; init_data_ref() points ptr at it
    std::unique_ptr<uint8_t[]> data_refhold;
    //! value used by Hash; 0 until init_hash() has been called
    size_t hash = 0;

    //! take a private, 0-terminated copy of the bytes described by src
    BlobStorage& init_data_ref(const Blob& src);

    //! store the XXHash digest of the referenced bytes in hash
    BlobStorage& init_hash();

    //! equal means same size and identical bytes
    bool operator==(const BlobStorage& other) const;

    //! hasher for unordered containers: hands back the precomputed hash
    struct Hash {
        size_t operator()(const BlobStorage& storage) const { return storage.hash; }
    };

    template <typename Input>
    BlobStorage& init_from_input(Input& inp);

    template <typename OutputFile>
    void write_to_file(OutputFile& out_file) const;
};
//! category name -> (key storage -> value storage)
using CacheMap = std::unordered_map<
        std::string, std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>>;

//! the cache table of this implementation
CacheMap m_cache;
//! mutex stored next to m_cache; merge_old_cache() takes it while merging
MGB_MUTEX m_mtx;
//! Hand the whole cache table over to the caller and leave m_cache empty.
CacheMap get_cache() {
    CacheMap taken = std::move(m_cache);
    // a moved-from container is only guaranteed to be in a valid state, so
    // empty it explicitly instead of relying on the library implementation
    m_cache.clear();
    return taken;
}
//! Remove every entry from m_cache.
//! NOTE(review): no lock is taken here — confirm callers do not race with
//! get()/put()
MGE_WIN_DECLSPEC_FUC void clear_cache() { m_cache.clear(); }
virtual Maybe<Blob> get(const std::string& category, const Blob& key) = 0; virtual Maybe<Blob> get(const std::string& category, const Blob& key) = 0;
virtual void put( virtual void put(
...@@ -34,6 +69,9 @@ public: ...@@ -34,6 +69,9 @@ public:
MGE_WIN_DECLSPEC_FUC static std::shared_ptr<PersistentCache> set_impl( MGE_WIN_DECLSPEC_FUC static std::shared_ptr<PersistentCache> set_impl(
std::shared_ptr<PersistentCache> impl); std::shared_ptr<PersistentCache> impl);
//! Merge the m_cache of the currently installed sm_impl into the given
//! implementation. set_impl() calls this before swapping the new
//! implementation in, so existing cache entries are appended, not dropped.
MGE_WIN_DECLSPEC_FUC static void merge_old_cache(std::shared_ptr<PersistentCache>);
//! get the instance; the default implementation just caches in //! get the instance; the default implementation just caches in
//! memory //! memory
static PersistentCache& inst() { return *sm_impl; } static PersistentCache& inst() { return *sm_impl; }
...@@ -48,32 +86,11 @@ public: ...@@ -48,32 +86,11 @@ public:
* The implementation is thread safe. * The implementation is thread safe.
*/ */
class InMemoryPersistentCache final : public PersistentCache { class InMemoryPersistentCache final : public PersistentCache {
//! A Blob extended with an owning byte buffer and a cached hash, usable as
//! both key and value type of the cache map.
struct BlobStorage : public PersistentCache::Blob {
    //! owning buffer; init_data_ref() points ptr at it
    std::unique_ptr<uint8_t[]> data_refhold;
    //! value used by Hash; 0 until init_hash() has been called
    size_t hash = 0;

    //! take a private, 0-terminated copy of the bytes described by src
    BlobStorage& init_data_ref(const Blob& src);

    //! store the XXHash digest of the referenced bytes in hash
    BlobStorage& init_hash();

    //! equal means same size and identical bytes
    bool operator==(const BlobStorage& other) const;

    //! hasher for unordered containers: hands back the precomputed hash
    struct Hash {
        size_t operator()(const BlobStorage& storage) const { return storage.hash; }
    };
};
MGE_WIN_DECLSPEC_FUC Maybe<Blob> get( MGE_WIN_DECLSPEC_FUC Maybe<Blob> get(
const std::string& category, const Blob& key) override; const std::string& category, const Blob& key) override;
MGE_WIN_DECLSPEC_FUC void put( MGE_WIN_DECLSPEC_FUC void put(
const std::string& category, const Blob& key, const Blob& value) override; const std::string& category, const Blob& key, const Blob& value) override;
//! cache table: category name -> (key storage -> value storage)
std::unordered_map<
std::string,
std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>>
m_cache;
//! NOTE(review): presumably guards m_cache in get()/put() — confirm against
//! the implementation file
MGB_MUTEX m_mtx;
public: public:
MGE_WIN_DECLSPEC_FUC InMemoryPersistentCache() = default; MGE_WIN_DECLSPEC_FUC InMemoryPersistentCache() = default;
}; };
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "megbrain/test/autocheck.h" #include "megbrain/test/autocheck.h"
#include "megbrain/test/helper.h" #include "megbrain/test/helper.h"
#include "megbrain/test/megdnn_helper.h" #include "megbrain/test/megdnn_helper.h"
#include "megbrain/utils/infile_persistent_cache.h"
#include "megdnn/algorithm_cache.h" #include "megdnn/algorithm_cache.h"
#include "megdnn/dtype.h" #include "megdnn/dtype.h"
#include "megdnn/oprs/base.h" #include "megdnn/oprs/base.h"
...@@ -354,6 +355,10 @@ TEST(TestOprDNN, ConvBiasExePolicy) { ...@@ -354,6 +355,10 @@ TEST(TestOprDNN, ConvBiasExePolicy) {
HostTensorND host_y; HostTensorND host_y;
auto func = graph->compile({make_callback_copy(conv_bias, host_y)}); auto func = graph->compile({make_callback_copy(conv_bias, host_y)});
func->execute(); func->execute();
//! force clear all PersistentCache by get_cache
PersistentCache::inst().clear_cache();
size_t old_size = PersistentCache::inst().get_cache().size();
ASSERT_EQ(old_size, 0);
//! set a new cache //! set a new cache
PersistentCache::set_impl(std::make_shared<InMemoryPersistentCache>()); PersistentCache::set_impl(std::make_shared<InMemoryPersistentCache>());
}; };
...@@ -372,6 +377,64 @@ TEST(TestOprDNN, ConvBiasExePolicy) { ...@@ -372,6 +377,64 @@ TEST(TestOprDNN, ConvBiasExePolicy) {
PersistentCache::set_impl(orig_impl); PersistentCache::set_impl(orig_impl);
} }
//! Exercise the append behaviour of PersistentCache::set_impl(): entries held
//! by the outgoing implementation are merged into the incoming one.
//! get_cache() moves m_cache out of the instance, so every size query made
//! through it also empties the cache that was inspected.
TEST(TestOprDNN, PersistentCacheAppend) {
    PersistentCache::inst().clear_cache();
    auto orig_impl =
            PersistentCache::set_impl(std::make_shared<InMemoryPersistentCache>());
    auto orig_impl_size = orig_impl->get_cache().size();

    auto category_a = "test_category_a";
    std::vector<int8_t> blob_key{1, 2, 3, 4, 5, 6, 7, 8};
    std::vector<int8_t> blob_value{-1, -2, -3, -4, -5, -6, -7, -8};
    //! plain member assignment: designated initializers are only standard
    //! since C++20
    PersistentCache::Blob key{};
    key.ptr = blob_key.data();
    key.size = blob_key.size();
    PersistentCache::Blob value{};
    value.ptr = blob_value.data();
    value.size = blob_value.size();

    //! put into the InMemoryPersistentCache installed above
    PersistentCache::inst().put(category_a, key, value);
    auto now_size = PersistentCache::inst().get_cache().size();
    //! the put must have added exactly one category
    ASSERT_EQ(orig_impl_size + 1, now_size);

    //! install a file cache; the in-memory cache was emptied by get_cache()
    //! above, so nothing is left to merge into it
    PersistentCache::set_impl(std::make_shared<InFilePersistentCache>());
    auto size_after_restore = PersistentCache::inst().get_cache().size();
    ASSERT_EQ(size_after_restore, orig_impl_size);

    auto t_file_imp = std::make_shared<InFilePersistentCache>();
    auto category_b = "test_category_b";
    //! put into an InFilePersistentCache that is not installed yet
    t_file_imp->put(category_b, key, value);
    //! install it
    PersistentCache::set_impl(t_file_imp);
    //! replace it by another file cache; this triggers the merge of
    //! t_file_imp's entries into the new instance
    auto old_cache =
            PersistentCache::set_impl(std::make_shared<InFilePersistentCache>());
    //! set_impl must return the previous implementation with its cache intact
    ASSERT_EQ(old_cache->m_cache.size(), now_size);

    //! the merged entry must be readable through the new instance
    auto get_value = PersistentCache::inst().get(category_b, key);
    ASSERT_TRUE(
            !memcmp(get_value.val().ptr, blob_value.data(),
                    blob_value.size() * sizeof(int8_t)));
    size_after_restore = PersistentCache::inst().get_cache().size();
    //! the new instance holds the merged category
    ASSERT_EQ(size_after_restore, now_size);

    //! restore the original implementation (in-memory or file based)
    PersistentCache::set_impl(orig_impl);
    size_after_restore = PersistentCache::inst().get_cache().size();
    //! the key is not in orig_impl: every cache involved was emptied by an
    //! earlier get_cache() call
    ASSERT_EQ(size_after_restore + 1, now_size);
}
TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) { TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) {
using Param = opr::ConvBias::Param; using Param = opr::ConvBias::Param;
Param param; Param param;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册