persistent_cache.cpp 8.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/**
 * \file imperative/src/impl/persistent_cache.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include <fstream>
#include <string>
#include <vector>

16 17
#include "cpp_redis/cpp_redis"

18 19 20 21 22 23
#include "megbrain/imperative/persistent_cache.h"
#include "megbrain/imperative/utils/base64.h"
#include "megbrain/utils/infile_persistent_cache.h"

namespace mgb::imperative::persistent_cache {

24 25 26 27 28 29
class RedisCache final : public ExtendedPersistentCache {
public:
    RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) {
        m_local = std::make_shared<mgb::InMemoryPersistentCache>();
    }

30 31 32 33
    void connect(std::string ip, size_t port, std::optional<std::string> password) {
        if (password) {
            m_client.auth(*password);
        }
34 35 36 37 38 39 40 41 42 43 44
        m_client.connect(
                ip, port,
                [](const std::string& host, std::size_t port,
                   cpp_redis::connect_state status) {
                    if (status == cpp_redis::connect_state::dropped) {
                        mgb_log("client disconnected from %s.", host.c_str());
                        mgb_log("Redis server connect to %s :%zu failed.", host.c_str(),
                                port);
                    }
                },
                std::uint32_t(200));
45
        mgb_assert(m_client.is_connected(), "connect failed");
46 47
        auto flag = m_client.get("mgb-cache-flag");
        sync();
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
        auto is_valid = [](const cpp_redis::reply& reply) {
            switch (reply.get_type()) {
                case cpp_redis::reply::type::error:
                case cpp_redis::reply::type::null:
                    return false;
                case cpp_redis::reply::type::integer:
                    return reply.as_integer() != 0;
                case cpp_redis::reply::type::simple_string:
                case cpp_redis::reply::type::bulk_string:
                    return !reply.as_string().empty();
                case cpp_redis::reply::type::array:
                    return !reply.as_array().empty();
                default:
                    mgb_assert(false, "unknown reply type %d", (int)reply.get_type());
            }
        };
        mgb_assert(is_valid(flag.get()), "invalid mgb-cache-flag");
65 66 67 68
    }

    bool valid() const override { return m_client.is_connected(); }

69 70
    void flush() override {}

71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
    mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
        MGB_LOCK_GUARD(m_mtx);
        auto mem_result = m_local->get(category, key);
        if (mem_result.valid())
            return mem_result;

        std::string key_str(static_cast<const char*>(key.ptr), key.size);
        std::string redis_key_str;
        encode(category + '@' + key_str, redis_key_str, 24);
        auto result = m_client.get(redis_key_str);
        sync();
        auto content = result.get();
        if (content.is_null())
            return mgb::None;
        std::string decode_content;
        decode(content.as_string(), decode_content);
        m_local->put(category, key, {decode_content.data(), decode_content.length()});

        return m_local->get(category, key);
    }

    void put(const std::string& category, const Blob& key, const Blob& value) override {
        MGB_LOCK_GUARD(m_mtx);
        std::string key_str(static_cast<const char*>(key.ptr), key.size);
        std::string redis_key_str;
96
        encode(category + '@' + key_str, redis_key_str, 24);
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
        std::string value_str(static_cast<const char*>(value.ptr), value.size);
        std::string redis_value_str;
        encode(value_str, redis_value_str);

        auto result = m_client.set(redis_key_str, redis_value_str);
        m_local->put(category, key, value);
        sync();
    }

    std::optional<size_t> clear() override {
        size_t cursor = 0, nr_deleted = 0;
        std::string pattern = m_prefix + "@*";
        do {
            auto reply = m_client.scan(cursor, pattern).share();
            sync();
            auto keys = reply.get().as_array();
            std::vector<std::string> string_keys;
            for (auto&& key : keys) {
                string_keys.push_back(key.as_string());
            }
            m_client.del(string_keys);
            nr_deleted += string_keys.size();
            cursor = reply.get().as_array()[0].as_integer();
        } while (cursor != 0);
        return nr_deleted;
    }

private:
    std::shared_ptr<mgb::PersistentCache> m_local;
    std::mutex m_mtx;
    cpp_redis::client m_client;
    std::string m_prefix;
    uint64_t m_timeout;

    void sync() {
        m_client.sync_commit<double, std::milli>(std::chrono::milliseconds(m_timeout));
        mgb_assert(valid());
    }
};

137 138
class ExtendedInFilePersistentCache final : public ExtendedPersistentCache {
private:
139
    std::optional<std::string> m_path;
140 141 142 143 144
    std::unique_ptr<mgb::InFilePersistentCache> m_impl;

public:
    ExtendedInFilePersistentCache() = default;

145
    void open(std::string path) {
146 147
        std::fstream file;
        file.open(path, std::ios::in | std::ios::binary);
148
        mgb_assert(file.is_open(), "can't open file in %s", path.c_str());
149 150 151 152 153 154 155 156 157 158 159
        std::vector<char> bytes = {
                std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()};
        if (bytes.size()) {
            m_impl = std::make_unique<mgb::InFilePersistentCache>(
                    (const uint8_t*)bytes.data(), bytes.size());
        } else {
            m_impl = std::make_unique<mgb::InFilePersistentCache>();
        }
        m_path = path;
    }

160 161 162
    void open() { m_impl = std::make_unique<mgb::InFilePersistentCache>(); }

    ~ExtendedInFilePersistentCache() { flush(); }
163 164 165 166 167 168 169 170 171 172

    mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
        return m_impl->get(category, key);
    }

    void put(const std::string& category, const Blob& key, const Blob& value) override {
        return m_impl->put(category, key, value);
    }

    std::optional<size_t> clear() override {
173 174 175 176 177 178
        if (m_impl) {
            m_impl = std::make_unique<mgb::InFilePersistentCache>();
            if (m_path) {
                m_impl->dump_cache(m_path->c_str());
            }
        }
179 180 181 182 183
        return {};
    }

    bool valid() const override { return m_impl != nullptr; }

184 185 186 187
    void flush() override {
        if (m_impl && m_path) {
            m_impl->dump_cache(m_path->c_str());
        }
188
    }
189
};
190

191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228
std::shared_ptr<ExtendedPersistentCache> ExtendedPersistentCache::make_from_config(
        std::string type, std::unordered_map<std::string, std::string> args,
        std::string& err_msg) {
    try {
        if (type == "redis") {
            std::string prefix = args.at("prefix");
            std::optional<std::string> password = args.count("password")
                                                        ? args.at("password")
                                                        : std::optional<std::string>();
            auto cache = std::make_shared<RedisCache>(prefix, 100);
            if (args.count("unixsocket")) {
                std::string unixsocket = args.at("unixsocket");
                cache->connect(unixsocket, 0, password);
            } else {
                std::string ip = args.at("hostname");
                int port = atoi(args.at("port").c_str());
                std::optional<std::string> password =
                        args.count("password") ? args.at("password")
                                               : std::optional<std::string>();
                cache->connect(ip, port, password);
            }
            return cache;
        } else if (type == "in-file") {
            std::string path = args.at("path");
            auto cache = std::make_shared<ExtendedInFilePersistentCache>();
            cache->open(path);
            return cache;
        } else if (type == "in-memory") {
            auto cache = std::make_shared<ExtendedInFilePersistentCache>();
            cache->open();
            return cache;
        } else {
            mgb_assert(false, "persistent cache type %s unsupported", type.c_str());
        }
    } catch (const std::exception& exc) {
        err_msg = exc.what();
    } catch (...) {
        err_msg = "unknown exception";
229
    }
230
    return nullptr;
231 232 233 234 235
}

}  // namespace mgb::imperative::persistent_cache

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}