提交 bd42cb06 编写于 作者: M Megvii Engine Team

refactor(mgb/lite): refactor lite InfilePersistentCache with core impl

GitOrigin-RevId: 64b7825c34e5e8f271452444098cc2b35b9a17e3
上级 676b205b
......@@ -24,7 +24,7 @@
#include "megbrain/comp_node.h"
#include "megbrain/serialization/extern_c_opr.h"
#include "megbrain/version.h"
#include "mge/algo_cache/file_cache.h"
#include "megbrain/utils/infile_persistent_cache.h"
#include "mge/common.h"
#if MGB_ENABLE_TENSOR_RT
#include "megbrain/tensorrt/tensorrt_engine_cache.h"
......@@ -170,8 +170,8 @@ void lite::set_persistent_cache(const std::string& cache_path, bool always_sync)
"it now may cause unknow error!!");
}
cache_control.config_algo_times++;
mgb::PersistentCache::set_impl(
std::make_shared<InFilePersistentCache>(cache_path.c_str(), always_sync));
mgb::PersistentCache::set_impl(std::make_shared<mgb::InFilePersistentCache>(
cache_path.c_str(), always_sync));
}
void lite::dump_persistent_cache(const std::string& cache_path) {
......@@ -179,7 +179,7 @@ void lite::dump_persistent_cache(const std::string& cache_path) {
LITE_ASSERT(
cache_control.cache_type == "file",
"now cache type not correct, it can't be dumped.");
static_cast<InFilePersistentCache&>(mgb::PersistentCache::inst())
static_cast<mgb::InFilePersistentCache&>(mgb::PersistentCache::inst())
.dump_cache(cache_path.c_str());
}
......
/**
* \file lite/src/mge/algo_cache/file_cache.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "lite_build_config.h"
#if LITE_BUILD_WITH_MGE
#include "../common.h"
#include "file_cache.h"
using namespace lite;
//////////////////////// InFilePersistentCache::InputMemory ///////////////
/*!
 * \brief sequential reader over a cache image held in memory
 *
 * The buffer is borrowed, not owned. Every read asserts that enough unread
 * bytes remain before copying.
 */
class InFilePersistentCache::InputMemory {
    const uint8_t* m_ptr;  //!< next unread byte
    size_t m_offset = 0;   //!< number of bytes consumed so far
    size_t m_size;         //!< total size of the buffer in bytes

    //! assert that \p nbytes unread bytes remain
    void check_remaining(size_t nbytes) const {
        // m_offset never exceeds m_size, so the subtraction cannot wrap;
        // the additive form (m_offset + nbytes) could wrap when size_t is
        // 32 bits wide and a corrupt length field is read
        LITE_ASSERT(nbytes <= m_size - m_offset);
    }

    //! move the cursor past \p nbytes bytes that have just been copied out
    void advance(size_t nbytes) {
        m_offset += nbytes;
        m_ptr += nbytes;
    }

public:
    InputMemory(const uint8_t* bin, size_t size) : m_ptr{bin}, m_size{size} {}

    //! read one trivially copyable value
    template <typename T>
    void read(T& val) {
        static_assert(
                std::is_trivially_copyable<T>::value,
                "only support trivially copyable type");
        check_remaining(sizeof(T));
        memcpy(&val, m_ptr, sizeof(T));
        advance(sizeof(T));
    }

    //! read \p size raw bytes into \p buf
    template <typename T>
    void read(T* buf, size_t size) {
        static_assert(
                std::is_trivially_copyable<T>::value && sizeof(T) == 1,
                "only support read bytes");
        check_remaining(size);
        memcpy(buf, m_ptr, size);
        advance(size);
    }
};
//////////////////////// InFilePersistentCache::InputFile ///////////////
/*!
 * \brief sequential reader over a cache file opened in binary mode
 *
 * The file is opened in the constructor and closed in the destructor. All
 * reads assert on short reads.
 */
class InFilePersistentCache::InputFile {
    FILE* m_fp;

public:
    //! open \p path for reading; asserts if the file can not be opened
    explicit InputFile(const char* path) : m_fp{fopen(path, "rb")} {
        LITE_ASSERT(m_fp, "failed to open %s: %s", path, strerror(errno));
    }

    // the destructor closes the handle, so a copy would close it twice
    InputFile(const InputFile&) = delete;
    InputFile& operator=(const InputFile&) = delete;

    ~InputFile() {
        if (m_fp) {
            fclose(m_fp);
        }
    }

    //! read one trivially copyable value
    template <typename T>
    void read(T& val) {
        static_assert(
                std::is_trivially_copyable<T>::value,
                "only support trivially copyable type");
        auto ret = fread(&val, sizeof(T), 1, m_fp);
        LITE_ASSERT(ret == 1);
    }

    //! read \p size raw bytes into \p buf; a zero-length read is a no-op
    template <typename T>
    void read(T* buf, size_t size) {
        static_assert(
                std::is_trivially_copyable<T>::value && sizeof(T) == 1,
                "only support read bytes");
        if (size == 0) {
            // fread() returns 0 when the element size is 0, which would be
            // mistaken for a short read by the assert below
            return;
        }
        auto ret = fread(buf, size, 1, m_fp);
        LITE_ASSERT(ret == 1);
    }
};
//////////////////////// InFilePersistentCache::OutputFile ///////////////
/*!
 * \brief sequential writer over a cache file opened with mode "wb"
 *
 * The file is opened (and therefore truncated) in the constructor and closed
 * in the destructor. All writes assert on short writes.
 */
class InFilePersistentCache::OutputFile {
    FILE* m_fp;

public:
    //! open \p path for writing; asserts if the file can not be opened
    explicit OutputFile(const char* path) : m_fp{fopen(path, "wb")} {
        LITE_ASSERT(m_fp, "failed to open %s: %s", path, strerror(errno));
    }

    // the destructor closes the handle, so a copy would close it twice
    OutputFile(const OutputFile&) = delete;
    OutputFile& operator=(const OutputFile&) = delete;

    ~OutputFile() {
        if (m_fp) {
            fclose(m_fp);
        }
    }

    //! write the object representation of one trivially copyable value
    template <typename T>
    void write(T val) {
        static_assert(
                std::is_trivially_copyable<T>::value,
                "only support trivially copyable type");
        auto ret = fwrite(&val, sizeof(T), 1, m_fp);
        LITE_ASSERT(ret == 1);
    }

    //! write \p size raw bytes from \p buf; a zero-length write is a no-op
    template <typename T>
    void write(const T* buf, size_t size) {
        static_assert(sizeof(T) == 1, "only support write bytes");
        if (size == 0) {
            // fwrite() returns 0 when the element size is 0, which would be
            // mistaken for a short write by the assert below
            return;
        }
        auto ret = fwrite(buf, size, 1, m_fp);
        LITE_ASSERT(ret == 1);
    }

    //! push buffered data to the underlying file
    void flush() { fflush(m_fp); }

    //! move the write position back to the beginning of the file
    void set_head() { fseek(m_fp, 0, SEEK_SET); }
};
//////////////////////// InFilePersistentCache::BlobStorage ///////////////
/*!
 * \brief load one length-prefixed blob from \p inp into owned storage
 *
 * Reads a uint32_t byte count followed by that many bytes. The owned buffer
 * gets one extra zero byte after the payload, like init_data_ref(), so a blob
 * reloaded from a dump is NUL-terminated just as one stored in-process is.
 *
 * \param inp reader providing read(T&) and read(T*, size_t)
 * \return *this, so that init_hash() can be chained
 */
template <typename Input>
InFilePersistentCache::BlobStorage& InFilePersistentCache::BlobStorage::init_from_input(
        Input& inp) {
    uint32_t data_size;
    inp.read(data_size);
    size = data_size;
    // guard the "+ 1" below against wrapping when size_t is 32 bits wide
    LITE_ASSERT(size + 1 > size, "blob size %zu is too large", size);
    // make_unique value-initializes the array, so the spare byte is already 0
    data_refhold = std::make_unique<uint8_t[]>(size + 1);
    if (size) {
        // skip empty payloads: a file reader can not tell a zero-length
        // read from a failed one
        inp.read(data_refhold.get(), size);
    }
    ptr = data_refhold.get();
    return *this;
}
/*!
 * \brief append this blob to \p out_file as a uint32_t byte count followed by
 *      the payload bytes
 *
 * Asserts if the blob is too large for the 32-bit length field instead of
 * silently writing a truncated length.
 */
void InFilePersistentCache::BlobStorage::write_to_file(OutputFile& out_file) const {
    uint32_t u_size = size;
    LITE_ASSERT(
            u_size == size, "blob of %zu bytes does not fit the cache file format",
            size);
    out_file.write(u_size);
    if (u_size) {
        // skip empty payloads: a zero-length fwrite() reports 0 items written,
        // which the writer would treat as a failure
        out_file.write(data_refhold.get(), u_size);
    }
}
/*!
 * \brief take a private copy of the bytes referenced by \p b
 *
 * The owned buffer is one byte longer than the payload and that byte is set
 * to 0, so the stored data can also be read as a C string.
 *
 * \return *this, so that init_hash() can be chained
 */
InFilePersistentCache::BlobStorage& InFilePersistentCache::BlobStorage::init_data_ref(
        const Blob& b) {
    const size_t nbytes = b.size;
    data_refhold = std::make_unique<uint8_t[]>(nbytes + 1);
    uint8_t* const dst = data_refhold.get();
    memcpy(dst, b.ptr, nbytes);
    dst[nbytes] = 0;  // trailing NUL for C-string safety
    ptr = dst;
    size = nbytes;
    return *this;
}
//////////////////////// InFilePersistentCache //////////////////////
/*!
 * \brief populate m_cache from a serialized cache image
 *
 * Layout consumed: a uint32_t category count; then per category a uint32_t
 * name length, the name bytes, a uint32_t blob count and that many
 * (key blob, value blob) pairs.
 *
 * The category name is read into a string sized from the length field, so a
 * long or corrupt length can no longer overrun a fixed-size stack buffer.
 *
 * \param inp reader providing read(T&) and read(T*, size_t)
 */
template <typename Input>
void InFilePersistentCache::read_cache(Input& inp) {
    uint32_t nr_category;
    inp.read(nr_category);
    for (uint32_t i = 0; i < nr_category; i++) {
        uint32_t category_size;
        inp.read(category_size);
        // keep all category_size bytes so the name round-trips exactly as
        // dump_cache() wrote it
        std::string category(category_size, '\0');
        if (category_size) {
            inp.read(&category[0], category_size);
        }
        mgb_log_debug("load new category: %s", category.c_str());

        // read blobs
        uint32_t nr_bobs;
        inp.read(nr_bobs);
        for (uint32_t j = 0; j < nr_bobs; j++) {
            BlobStorage key_storage;
            key_storage.init_from_input(inp).init_hash();
            mgb_log_debug("read key: %zu", key_storage.hash);
            m_cache[category][std::move(key_storage)].init_from_input(inp);
        }
    }
}
/*!
 * \brief load the cache stored at \p path, if that file exists
 *
 * \param path cache file location
 * \param always_open when true, keep \p path open for writing so that every
 *      put() rewrites the file
 */
InFilePersistentCache::InFilePersistentCache(const char* path, bool always_open) {
    if (!access(path, F_OK)) {
        mgb_log_debug("use fastrun cache: %s", path);
        InputFile inp(path);
        read_cache<InputFile>(inp);
    }
    if (always_open) {
        m_always_open_file = std::make_shared<OutputFile>(path);
        // Opening with "wb" has just truncated the file. Write the loaded
        // entries back right away; otherwise a run that never calls put()
        // leaves an empty file behind and the next load asserts on it.
        dump_cache(m_always_open_file.get());
        m_always_open_file->flush();
    }
}
/*!
 * \brief load the cache from a serialized image held in memory
 *
 * \param bin start of the image; must not be null
 * \param size length of the image in bytes
 */
InFilePersistentCache::InFilePersistentCache(const uint8_t* bin, size_t size) {
    LITE_ASSERT(bin);
    InputMemory reader{bin, size};
    read_cache(reader);
}
void InFilePersistentCache::dump_cache(const char* path) {
OutputFile out_file(path);
dump_cache(&out_file);
}
/*!
 * \brief serialize every category and its blobs to \p out_file
 *
 * Emits the category count, then for each category its name length, name
 * bytes, blob count and the (key, value) blobs. This function takes no lock
 * itself.
 */
void InFilePersistentCache::dump_cache(OutputFile* out_file) {
    OutputFile& out = *out_file;

    const uint32_t category_count = m_cache.size();
    out.write(category_count);

    for (auto cat_it = m_cache.begin(); cat_it != m_cache.end(); ++cat_it) {
        const std::string& cat_name = cat_it->first;
        const auto& blobs = cat_it->second;

        const uint32_t name_len = cat_name.size();
        out.write(name_len);
        out.write(cat_name.data(), name_len);
        mgb_log_debug("write new category: %s", cat_name.c_str());

        const uint32_t blob_count = blobs.size();
        out.write(blob_count);
        for (const auto& kv : blobs) {
            mgb_log_debug("dump key: %zu", kv.first.hash);
            kv.first.write_to_file(out);
            kv.second.write_to_file(out);
        }
    }
}
/*!
 * \brief look up the value stored for \p key in \p category
 *
 * \return mgb::None when the category or the key is unknown; otherwise the
 *      stored blob, which refers to memory owned by the cache entry
 */
mgb::Maybe<InFilePersistentCache::Blob> InFilePersistentCache::get(
        const std::string& category, const Blob& key) {
    // Hash the key before locking so the critical section stays short.
    // key_storage only borrows the caller's bytes; nothing is copied.
    BlobStorage key_storage;
    key_storage.Blob::operator=(key);
    key_storage.init_hash();

    // Both lookups run under one lock. An iterator into m_cache kept across
    // an unlock could be invalidated by a concurrent put() that inserts a new
    // category and triggers a rehash.
    MGB_LOCK_GUARD(m_mtx);
    auto iter0 = m_cache.find(category);
    if (iter0 == m_cache.end())
        return mgb::None;
    auto iter1 = iter0->second.find(key_storage);
    if (iter1 == iter0->second.end())
        return mgb::None;
    return iter1->second;
}
void InFilePersistentCache::put(
const std::string& category, const Blob& key, const Blob& value) {
BlobStorage key_storage;
key_storage.init_data_ref(key).init_hash();
MGB_LOCK_GUARD(m_mtx);
auto size0 = m_cache.size();
m_cache[category][std::move(key_storage)].init_data_ref(value);
if (m_cache.size() > size0) {
mgb_log_debug("new cache category: %s", category.c_str());
}
if (m_always_open_file) {
m_always_open_file->set_head();
dump_cache(m_always_open_file.get());
m_always_open_file->flush();
}
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file lite/src/mge/algo_cache/file_cache.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "lite_build_config.h"
#if LITE_BUILD_WITH_MGE
#include "megbrain/utils/persistent_cache.h"
namespace lite {
/**
* dump format:
*
* all integers in local endian (effectively little endian as I can see)
*
* <nr_category|uint32_t>
* [<category_size|uint32_t><category|uint8_t*>
*  <nr_blob|uint32_t>
*  [<key_size|uint32_t><key|uint8_t*><data_size|uint32_t><data|uint8_t*>]*
* ]*
*
* i.e. the category record is repeated nr_category times and, inside each
* category, the key/data pair is repeated nr_blob times.
*/
//! TODO: fix one thread set cache when other threads is using old cache
//! Persistent cache kept in memory as category -> (key blob -> value blob),
//! which can be loaded from a file or a memory image and dumped to a file.
class InFilePersistentCache final : public mgb::PersistentCache {
// readers/writers used for (de)serialization; defined in the .cpp file
class InputFile;
class InputMemory;
class OutputFile;
//! a Blob together with the buffer that backs it and a cached hash
struct BlobStorage : public Blob {
//! buffer owned by this storage; set by init_from_input()/init_data_ref()
std::unique_ptr<uint8_t[]> data_refhold;
//! hash of the blob bytes; only meaningful after init_hash() was called
size_t hash = 0;
//! read a length-prefixed blob from \p inp into an owned buffer
template <typename Input>
BlobStorage& init_from_input(Input& inp);
//! write this blob, length-prefixed, to \p out_file
void write_to_file(OutputFile& out_file) const;
//! copy the bytes referenced by \p b into an owned buffer
BlobStorage& init_data_ref(const Blob& b);
//! compute \c hash from the current (ptr, size) using XXHash
BlobStorage& init_hash() {
hash = mgb::XXHash{}.update(ptr, size).digest();
return *this;
}
//! byte-wise equality; the cached hash is not compared
bool operator==(const BlobStorage& rhs) const {
return size == rhs.size && !memcmp(ptr, rhs.ptr, size);
}
//! hasher for unordered_map: returns the hash cached by init_hash()
struct Hash {
size_t operator()(const BlobStorage& b) const { return b.hash; }
};
};
//! category name -> (key -> value)
std::unordered_map<
std::string,
std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>>
m_cache;
//! locked by get() and put() around their accesses to m_cache
LITE_MUTEX m_mtx;
//! when set, put() rewrites the whole cache into this file on every call
std::shared_ptr<OutputFile> m_always_open_file;
//! fill m_cache from a serialized image provided by \p inp
template <typename Input>
void read_cache(Input& inp);
public:
InFilePersistentCache() = default;
//! load the cache from \p path if the file exists; when \p always_open is
//! true, additionally keep \p path open for writing
InFilePersistentCache(const char* path, bool always_open = false);
//! load the cache from a serialized image of \p size bytes at \p bin
InFilePersistentCache(const uint8_t* bin, size_t size);
/**
* \warning You should invoke \c dump_cache manually to save the cache
* file.
*/
void dump_cache(const char* path);
//! serialize the whole cache into an already opened output file
void dump_cache(OutputFile* out_file);
//! return the value stored for (\p category, \p key), or None if absent
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override;
//! store a copy of \p value under (\p category, \p key)
void put(const std::string& category, const Blob& key, const Blob& value) override;
};
} // namespace lite
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -112,6 +112,9 @@ public:
auto ret = fwrite(buf, size, 1, m_fp);
mgb_assert(ret == 1);
}
void flush() { fflush(m_fp); }
void set_head() { fseek(m_fp, 0, SEEK_SET); }
};
//////////////////////// InFilePersistentCache::BlobStorage ///////////////
......@@ -172,12 +175,15 @@ void InFilePersistentCache::read_cache(Input& inp) {
}
}
InFilePersistentCache::InFilePersistentCache(const char* path) {
InFilePersistentCache::InFilePersistentCache(const char* path, bool always_open) {
if (!access(path, F_OK)) {
mgb_log_debug("use fastrun cache: %s", path);
InputFile inp(path);
read_cache<InputFile>(inp);
}
if (always_open) {
m_always_open_file = std::make_shared<OutputFile>(path);
}
}
InFilePersistentCache::InFilePersistentCache(const uint8_t* bin, size_t size) {
......@@ -188,25 +194,28 @@ InFilePersistentCache::InFilePersistentCache(const uint8_t* bin, size_t size) {
void InFilePersistentCache::dump_cache(const char* path) {
OutputFile out_file(path);
dump_cache(&out_file);
}
void InFilePersistentCache::dump_cache(OutputFile* out_file) {
uint32_t nr_category = m_cache.size();
out_file.write(nr_category);
out_file->write(nr_category);
for (const auto& cached_category : m_cache) {
uint32_t category_size = cached_category.first.size();
out_file.write(category_size);
out_file.write(cached_category.first.data(), category_size);
out_file->write(category_size);
out_file->write(cached_category.first.data(), category_size);
mgb_log_debug("write new category: %s", cached_category.first.c_str());
uint32_t nr_bobs = cached_category.second.size();
out_file.write(nr_bobs);
out_file->write(nr_bobs);
for (const auto& item : cached_category.second) {
mgb_log_debug("dump key: %zu", item.first.hash);
item.first.write_to_file(out_file);
item.second.write_to_file(out_file);
item.first.write_to_file(*out_file);
item.second.write_to_file(*out_file);
}
}
}
Maybe<InFilePersistentCache::Blob> InFilePersistentCache::get(
const std::string& category, const Blob& key) {
decltype(m_cache.begin()) iter0;
......@@ -240,6 +249,11 @@ void InFilePersistentCache::put(
if (m_cache.size() > size0) {
mgb_log_debug("new cache category: %s", category.c_str());
}
if (m_always_open_file) {
m_always_open_file->set_head();
dump_cache(m_always_open_file.get());
m_always_open_file->flush();
}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -54,13 +54,14 @@ class InFilePersistentCache final : public PersistentCache {
std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>>
m_cache;
MGB_MUTEX m_mtx;
std::shared_ptr<OutputFile> m_always_open_file;
template <typename Input>
void read_cache(Input& inp);
public:
InFilePersistentCache() = default;
InFilePersistentCache(const char* path);
InFilePersistentCache(const char* path, bool always_open = false);
InFilePersistentCache(const uint8_t* bin, size_t size);
/**
......@@ -68,6 +69,7 @@ public:
* file.
*/
void dump_cache(const char* path);
void dump_cache(OutputFile* out_file);
Maybe<Blob> get(const std::string& category, const Blob& key) override;
void put(const std::string& category, const Blob& key, const Blob& value) override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册