// persistent_cache.cpp
#include "megbrain/utils/persistent_cache.h"
#include "megbrain/comp_node_env.h"

#include <cstdio>
#include <cstring>

#ifdef WIN32
#define snprintf _snprintf
#endif

#if MGB_CUDA
#include <cuda_runtime_api.h>
#endif

using namespace mgb;

17
// ================= PersistentCache ======================
18
std::shared_ptr<PersistentCache> PersistentCache::sm_impl =
19
        std::make_shared<InMemoryPersistentCache>();
20 21 22 23 24 25 26 27 28 29 30 31 32

//! Install \p impl as the active cache implementation and return the
//! implementation it replaces. A null \p impl is rejected by assertion.
std::shared_ptr<PersistentCache> PersistentCache::set_impl(
        std::shared_ptr<PersistentCache> impl) {
    mgb_assert(impl);
    auto previous = std::move(sm_impl);
    sm_impl = std::move(impl);
    return previous;
}

//! Build the cache-category string for the device behind \p comp_node.
//! CUDA: device name, compute capability and major runtime version.
//! ROCm: device name, compute capability, driver and runtime versions.
//! CPU: just the platform. Other device types throw MegBrainError.
std::string PersistentCache::make_category_from_comp_node(CompNode comp_node) {
    auto&& env = CompNodeEnv::from_comp_node(comp_node);
    switch (env.property().type) {
#if MGB_CUDA
        case CompNode::DeviceType::CUDA: {
            int cuda_rt = -1;
            MGB_CUDA_CHECK(cudaRuntimeGetVersion(&cuda_rt));
            // cudaRuntimeGetVersion reports 1000 * major + 10 * minor; only the
            // major version becomes part of the category.
            int cuda_rt_major = cuda_rt / 1000;
            auto&& prop = env.cuda_env().device_prop;
            // note: we do not contain library versions such as cudnn here. They
            // are handled by opr impls in MegDNN
            return ssprintf(
                    "plat=cuda;dev=%s;cap=%d.%d;runtime=%d", prop.name, prop.major,
                    prop.minor, cuda_rt_major);
        }
#endif
#if MGB_ROCM
        case CompNode::DeviceType::ROCM: {
            int drv = -1, hip_rt = -1;
            MGB_ROCM_CHECK(hipDriverGetVersion(&drv));
            MGB_ROCM_CHECK(hipRuntimeGetVersion(&hip_rt));
            auto&& prop = env.rocm_env().device_prop;
            // NOTE(review): "cap=%d.%d,drv=%d" uses ',' where every other field
            // is separated by ';'. Left unchanged because altering it would
            // change the category string and orphan already-persisted entries.
            return ssprintf(
                    "plat=rocm;dev=%s;cap=%d.%d,drv=%d;runtime=%d", prop.name,
                    prop.major, prop.minor, drv, hip_rt);
        }
#endif
        case CompNode::DeviceType::CPU:
            return "plat=cpu";
        default:
            mgb_throw(
                    MegBrainError,
                    "unsupported comp node for persistent cache category");
    }
}

66 67
// ================= InMemoryPersistentCache ==================
using Blob = PersistentCache::Blob;
M
Megvii Engine Team 已提交
68 69
//! Make this storage own a private copy of \p b's bytes and point ptr/size
//! at that copy. The copy is always followed by a NUL byte.
InMemoryPersistentCache::BlobStorage& InMemoryPersistentCache::BlobStorage::
        init_data_ref(const Blob& b) {
    // One extra byte so the owned data is always NUL-terminated.
    data_refhold = std::make_unique<uint8_t[]>(b.size + 1);
    // memcpy from a null pointer is undefined even for zero bytes, and an
    // empty Blob may carry ptr == nullptr, so skip the copy when empty.
    if (b.size) {
        memcpy(data_refhold.get(), b.ptr, b.size);
    }
    data_refhold.get()[b.size] = 0;  // for C-string safety
    ptr = data_refhold.get();
    size = b.size;
    return *this;
}

M
Megvii Engine Team 已提交
78 79
InMemoryPersistentCache::BlobStorage& InMemoryPersistentCache::BlobStorage::
        init_hash() {
80 81 82 83
    hash = XXHash{}.update(ptr, size).digest();
    return *this;
}

M
Megvii Engine Team 已提交
84
bool InMemoryPersistentCache::BlobStorage::operator==(const BlobStorage& rhs) const {
85 86 87
    return size == rhs.size && !memcmp(ptr, rhs.ptr, size);
}

M
Megvii Engine Team 已提交
88
Maybe<Blob> InMemoryPersistentCache::get(const std::string& category, const Blob& key) {
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
    decltype(m_cache.begin()) iter0;
    {
        MGB_LOCK_GUARD(m_mtx);
        iter0 = m_cache.find(category);
        if (iter0 == m_cache.end())
            return None;
    }

    BlobStorage key_storage;
    key_storage.Blob::operator=(key);
    key_storage.init_hash();

    MGB_LOCK_GUARD(m_mtx);

    auto iter1 = iter0->second.find(key_storage);
    if (iter1 == iter0->second.end())
        return None;
    return iter1->second;
}

M
Megvii Engine Team 已提交
109 110
void InMemoryPersistentCache::put(
        const std::string& category, const Blob& key, const Blob& value) {
111 112 113 114 115 116 117 118 119 120 121 122
    BlobStorage key_storage;
    key_storage.init_data_ref(key).init_hash();

    MGB_LOCK_GUARD(m_mtx);
    auto size0 = m_cache.size();
    m_cache[category][std::move(key_storage)].init_data_ref(value);
    if (m_cache.size() > size0) {
        mgb_log_debug("new cache category: %s", category.c_str());
    }
}

// ================= AlgoChooserProfileCache ==================
M
Megvii Engine Team 已提交
123
AlgoChooserProfileCache::AlgoChooserProfileCache(CompNode cn, const char* opr_type) {
124 125 126 127 128 129 130 131
    m_category = "profile:";
    m_category.append(PersistentCache::make_category_from_comp_node(cn));
    m_category.append(":");
    m_category.append(opr_type);
}

//! printf/scanf format of one serialized profile entry:
//! ":<attribute>;<time>;<workspace>:". As a macro it expands to a string
//! literal, so compilers can check it against the snprintf/sscanf arguments.
#define ENTRY_FMT ":%d;%lg;%zu:"

M
Megvii Engine Team 已提交
132
Maybe<AlgoChooserProfileCache::Result> AlgoChooserProfileCache::get(const Key& key) {
133
    auto raw_buf = PersistentCache::inst().get(m_category, key.build_blob());
M
Megvii Engine Team 已提交
134
    if (!raw_buf.valid())
135
        return None;
M
Megvii Engine Team 已提交
136 137 138 139 140 141 142 143 144
    mgb_assert(
            raw_buf->size <= 1024 * 1024,
            "buf size too large, maybe corrupted data: %p %zu", raw_buf->ptr,
            raw_buf->size);
    auto buf = static_cast<const uint8_t*>(raw_buf->ptr), buf_end = buf + raw_buf->size;
    mgb_assert(
            buf && buf < buf_end,
            "PersistentCache returned invalid value: ptr=%p size=%zu", raw_buf->ptr,
            raw_buf->size);
145 146 147 148 149 150 151 152 153
    auto read_uint32 = [&]() {
        auto next = buf + sizeof(uint32_t);
        mgb_assert(next <= buf_end);
        auto ret = *reinterpret_cast<const uint32_t*>(buf);
        buf = next;
        return ret;
    };

    auto ret_size = read_uint32();
M
Megvii Engine Team 已提交
154 155 156
    mgb_assert(
            static_cast<ptrdiff_t>(ret_size) < buf_end - buf,
            "result size too large (%u), maybe corrupted data", ret_size);
157
    Result ret(ret_size);
M
Megvii Engine Team 已提交
158
    for (auto&& i : ret) {
159 160 161 162 163 164 165 166 167
        // read algo name
        auto size = read_uint32();
        i.algo.resize(size);
        mgb_assert(buf + size < buf_end);
        memcpy(&i.algo[0], buf, size);
        buf += size;

        auto entry_len = read_uint32();
        mgb_assert(buf + entry_len <= buf_end);
M
Megvii Engine Team 已提交
168 169 170
        auto nr =
                sscanf(reinterpret_cast<const char*>(buf), ENTRY_FMT, &i.attribute,
                       &i.time, &i.workspace);
171 172 173 174 175 176 177
        mgb_assert(nr == 3);
        buf += entry_len;
    }
    mgb_assert(buf == buf_end);
    return ret;
}

M
Megvii Engine Team 已提交
178
void AlgoChooserProfileCache::put(const Key& key, Result& result) {
179
    mgb_assert(!result.empty());
M
Megvii Engine Team 已提交
180 181
    auto result_cmp = [](const ResultEntry& a, const ResultEntry& b) {
        return a.time < b.time || (a.time == b.time && a.workspace < b.workspace);
182 183 184 185
    };
    small_sort(result.begin(), result.end(), result_cmp);

    // remove algos that run slower but use more workspace
M
Megvii Engine Team 已提交
186 187 188
    for (size_t i = 1; i < result.size();) {
        auto&& prev = result[i - 1];
        auto&& cur = result[i];
189

M
Megvii Engine Team 已提交
190
        if (prev.workspace <= cur.workspace && prev.attribute == cur.attribute) {
191 192
            result.erase(result.begin() + i);
        } else {
193
            ++i;
194 195 196 197 198 199 200 201 202 203
        }
    }

    std::string val;
    val.reserve((sizeof(ResultEntry) - sizeof(std::string)) * 2 * result.size());
    auto write_uint32 = [&](uint32_t v) {
        val.append(reinterpret_cast<const char*>(&v), sizeof(v));
    };
    write_uint32(result.size());
    constexpr int SPR_SIZE = 100;
M
Megvii Engine Team 已提交
204
    for (auto&& i : result) {
205 206 207 208 209 210 211 212 213 214
        // write algo
        write_uint32(i.algo.size());
        auto pos = val.size();
        val.resize(pos + i.algo.size());
        memcpy(&val[pos], i.algo.data(), i.algo.size());

        // write others
        write_uint32(0);
        pos = val.size();
        val.resize(pos + SPR_SIZE);
M
Megvii Engine Team 已提交
215 216
        uint32_t nr = snprintf(
                &val[pos], SPR_SIZE, ENTRY_FMT, i.attribute, i.time, i.workspace);
217 218
        //! for memory boundary failed, snprintf ret do not contain \0
        nr += 1;
219 220 221 222 223
        mgb_assert(nr < SPR_SIZE);
        memcpy(&val[pos - sizeof(uint32_t)], &nr, sizeof(nr));
        val.resize(pos + nr);
    }

M
Megvii Engine Team 已提交
224
    PersistentCache::inst().put(m_category, key.build_blob(), {val.data(), val.size()});
225 226 227
}

//! Serialize this key into a byte blob used as the persistent-cache key.
//!
//! For each input layout the output contains:
//!   - the comma-separated shape;
//!   - ";" plus the comma-separated strides, if the layout is non-contiguous;
//!   - ";" plus the dtype name, then "|".
//! The raw param bytes follow the layouts.
//! The result is memoized in m_blob_storage, and the returned Blob points into
//! it. NOTE(review): memoization writes m_blob_storage from a const method
//! (presumably declared mutable) without locking — confirm that concurrent
//! build_blob() calls on one Key do not happen.
PersistentCache::Blob AlgoChooserProfileCache::Key::build_blob() const {
    if (!m_blob_storage.empty())
        return {m_blob_storage.data(), m_blob_storage.size()};

    // Build into a local string and publish it only when complete. If the
    // format assertion below throws, no partial key is left memoized for
    // later calls to return.
    std::string ret;
    ret.reserve(sizeof(TensorLayout) * 3 * m_inp_layouts_size + m_param_size);
    for (size_t i = 0; i < m_inp_layouts_size; ++i) {
        auto&& ly = m_inp_layouts_ptr[i];
        for (size_t j = 0; j < ly.ndim; ++j) {
            if (j)
                ret.push_back(',');
            ret.append(std::to_string(ly.shape[j]));
        }
        if (!ly.is_contiguous()) {
            ret.push_back(';');
            for (size_t j = 0; j < ly.ndim; ++j) {
                if (j)
                    ret.push_back(',');
                ret.append(std::to_string(ly.stride[j]));
            }
        }
        ret.push_back(';');
        ret.append(ly.dtype.name());
        ret.push_back('|');
        mgb_assert(
                ly.format.is_default() ||
                        (ly.format.is_lowbit_aligned() && ly.dtype.is_low_bit()),
                "currently only default format is supported");
    }
    if (m_param_size) {
        ret.append(reinterpret_cast<const char*>(m_param), m_param_size);
    }
    m_blob_storage = std::move(ret);
    return {m_blob_storage.data(), m_blob_storage.size()};
}

// Drop the file-local helper macros so they do not leak past this file.
#undef ENTRY_FMT

#ifdef WIN32
#undef snprintf
#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}