/**
 * \file imperative/src/impl/persistent_cache.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include <chrono>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "cpp_redis/cpp_redis"

#include "megbrain/imperative/persistent_cache.h"
#include "megbrain/imperative/utils/base64.h"
#include "megbrain/utils/infile_persistent_cache.h"

namespace mgb::imperative::persistent_cache {

24 25 26 27 28 29
class RedisCache final : public ExtendedPersistentCache {
public:
    RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) {
        m_local = std::make_shared<mgb::InMemoryPersistentCache>();
    }

30 31 32 33
    void connect(std::string ip, size_t port, std::optional<std::string> password) {
        if (password) {
            m_client.auth(*password);
        }
34 35 36 37 38 39 40 41 42 43 44
        m_client.connect(
                ip, port,
                [](const std::string& host, std::size_t port,
                   cpp_redis::connect_state status) {
                    if (status == cpp_redis::connect_state::dropped) {
                        mgb_log("client disconnected from %s.", host.c_str());
                        mgb_log("Redis server connect to %s :%zu failed.", host.c_str(),
                                port);
                    }
                },
                std::uint32_t(200));
45
        mgb_assert(m_client.is_connected(), "connect failed");
46 47
        auto flag = m_client.get("mgb-cache-flag");
        sync();
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
        auto is_valid = [](const cpp_redis::reply& reply) {
            switch (reply.get_type()) {
                case cpp_redis::reply::type::error:
                case cpp_redis::reply::type::null:
                    return false;
                case cpp_redis::reply::type::integer:
                    return reply.as_integer() != 0;
                case cpp_redis::reply::type::simple_string:
                case cpp_redis::reply::type::bulk_string:
                    return !reply.as_string().empty();
                case cpp_redis::reply::type::array:
                    return !reply.as_array().empty();
                default:
                    mgb_assert(false, "unknown reply type %d", (int)reply.get_type());
            }
        };
        mgb_assert(is_valid(flag.get()), "invalid mgb-cache-flag");
65 66 67 68
    }

    bool valid() const override { return m_client.is_connected(); }

69 70
    void flush() override {}

71 72 73
    mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
        MGB_LOCK_GUARD(m_mtx);
        auto mem_result = m_local->get(category, key);
74
        if (mem_result.valid()) {
75
            return mem_result;
76
        }
77 78 79
        std::string key_str(static_cast<const char*>(key.ptr), key.size);
        std::string redis_key_str;
        encode(category + '@' + key_str, redis_key_str, 24);
80
        auto result = m_client.get(m_prefix + redis_key_str);
81 82
        sync();
        auto content = result.get();
83 84 85
        if (content.is_null()) {
            return None;
        }
86 87 88 89 90 91 92 93 94 95
        std::string decode_content;
        decode(content.as_string(), decode_content);
        m_local->put(category, key, {decode_content.data(), decode_content.length()});
        return m_local->get(category, key);
    }

    void put(const std::string& category, const Blob& key, const Blob& value) override {
        MGB_LOCK_GUARD(m_mtx);
        std::string key_str(static_cast<const char*>(key.ptr), key.size);
        std::string redis_key_str;
96
        encode(category + '@' + key_str, redis_key_str, 24);
97 98 99
        std::string value_str(static_cast<const char*>(value.ptr), value.size);
        std::string redis_value_str;
        encode(value_str, redis_value_str);
100
        auto result = m_client.set(m_prefix + redis_key_str, redis_value_str);
101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
        m_local->put(category, key, value);
        sync();
    }

    std::optional<size_t> clear() override {
        size_t cursor = 0, nr_deleted = 0;
        std::string pattern = m_prefix + "@*";
        do {
            auto reply = m_client.scan(cursor, pattern).share();
            sync();
            auto keys = reply.get().as_array();
            std::vector<std::string> string_keys;
            for (auto&& key : keys) {
                string_keys.push_back(key.as_string());
            }
            m_client.del(string_keys);
            nr_deleted += string_keys.size();
            cursor = reply.get().as_array()[0].as_integer();
        } while (cursor != 0);
        return nr_deleted;
    }

private:
    std::shared_ptr<mgb::PersistentCache> m_local;
    std::mutex m_mtx;
    cpp_redis::client m_client;
    std::string m_prefix;
    uint64_t m_timeout;

    void sync() {
        m_client.sync_commit<double, std::milli>(std::chrono::milliseconds(m_timeout));
        mgb_assert(valid());
    }
};

136 137
class ExtendedInFilePersistentCache final : public ExtendedPersistentCache {
private:
138
    std::optional<std::string> m_path;
139 140 141 142 143
    std::unique_ptr<mgb::InFilePersistentCache> m_impl;

public:
    ExtendedInFilePersistentCache() = default;

144
    void open(std::string path) {
145 146
        std::fstream file;
        file.open(path, std::ios::in | std::ios::binary);
147
        mgb_assert(file.is_open(), "can't open file in %s", path.c_str());
148 149 150 151 152 153 154 155 156 157 158
        std::vector<char> bytes = {
                std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()};
        if (bytes.size()) {
            m_impl = std::make_unique<mgb::InFilePersistentCache>(
                    (const uint8_t*)bytes.data(), bytes.size());
        } else {
            m_impl = std::make_unique<mgb::InFilePersistentCache>();
        }
        m_path = path;
    }

159 160 161
    void open() { m_impl = std::make_unique<mgb::InFilePersistentCache>(); }

    ~ExtendedInFilePersistentCache() { flush(); }
162 163 164 165 166 167 168 169 170 171

    mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
        return m_impl->get(category, key);
    }

    void put(const std::string& category, const Blob& key, const Blob& value) override {
        return m_impl->put(category, key, value);
    }

    std::optional<size_t> clear() override {
172 173 174 175 176 177
        if (m_impl) {
            m_impl = std::make_unique<mgb::InFilePersistentCache>();
            if (m_path) {
                m_impl->dump_cache(m_path->c_str());
            }
        }
178 179 180 181 182
        return {};
    }

    bool valid() const override { return m_impl != nullptr; }

183 184 185 186
    void flush() override {
        if (m_impl && m_path) {
            m_impl->dump_cache(m_path->c_str());
        }
187
    }
188
};
189

190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227
std::shared_ptr<ExtendedPersistentCache> ExtendedPersistentCache::make_from_config(
        std::string type, std::unordered_map<std::string, std::string> args,
        std::string& err_msg) {
    try {
        if (type == "redis") {
            std::string prefix = args.at("prefix");
            std::optional<std::string> password = args.count("password")
                                                        ? args.at("password")
                                                        : std::optional<std::string>();
            auto cache = std::make_shared<RedisCache>(prefix, 100);
            if (args.count("unixsocket")) {
                std::string unixsocket = args.at("unixsocket");
                cache->connect(unixsocket, 0, password);
            } else {
                std::string ip = args.at("hostname");
                int port = atoi(args.at("port").c_str());
                std::optional<std::string> password =
                        args.count("password") ? args.at("password")
                                               : std::optional<std::string>();
                cache->connect(ip, port, password);
            }
            return cache;
        } else if (type == "in-file") {
            std::string path = args.at("path");
            auto cache = std::make_shared<ExtendedInFilePersistentCache>();
            cache->open(path);
            return cache;
        } else if (type == "in-memory") {
            auto cache = std::make_shared<ExtendedInFilePersistentCache>();
            cache->open();
            return cache;
        } else {
            mgb_assert(false, "persistent cache type %s unsupported", type.c_str());
        }
    } catch (const std::exception& exc) {
        err_msg = exc.what();
    } catch (...) {
        err_msg = "unknown exception";
228
    }
229
    return nullptr;
230 231 232 233 234
}

}  // namespace mgb::imperative::persistent_cache

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}