static_mem_alloc.cpp 7.6 KB
Newer Older
M
Megvii Engine Team 已提交
1
#include "../impl/graph/var_node_mem_mgr/static_mem_alloc.h"
2 3 4 5 6 7 8 9 10 11
#include "megbrain/test/helper.h"
#include "megbrain/utils/arith_helper.h"
#include "megbrain/utils/timer.h"

#include <random>

using namespace mgb;
using namespace cg;

#ifdef WIN32
M
Megvii Engine Team 已提交
12 13
#pragma message \
        "static_mem_alloc disabled because it causes the program to crash at startup"
14 15
#else

M
Megvii Engine Team 已提交
16
// X-macro over the allocator algorithms under test: expands cb(NAME) once per
// StaticMemAlloc::AllocatorAlgo enumerator, in the order listed. This order is
// also the order in which TestParam::make_values emits parameters per config.
#define ITER_ALGO(cb) cb(INTERVAL_MOVE) cb(BEST_FIT) cb(PUSHDOWN)
17 18 19 20 21 22 23

namespace {

struct TestParam {
    using Algo = StaticMemAlloc::AllocatorAlgo;

    Algo algo;
24
    size_t align, padding, nr_rand_opr, rng_seed;
25 26

    static decltype(auto) make_values(
M
Megvii Engine Team 已提交
27 28
            const std::vector<size_t>& aligns, const std::vector<size_t>& paddings,
            const std::vector<size_t>& nr_rand_opr) {
29 30
        std::vector<TestParam> data;
        std::mt19937_64 rng(next_rand_seed());
M
Megvii Engine Team 已提交
31
        // std::mt19937_64 rng(0);
32

M
Megvii Engine Team 已提交
33
        for (auto nr : nr_rand_opr) {
34
            size_t seed = rng();
M
Megvii Engine Team 已提交
35 36 37
            for (auto align : aligns) {
                for (auto padding : paddings) {
#define itcb(algo)    data.push_back({Algo::algo, align, padding, nr, seed});
38
                    ITER_ALGO(itcb)
39
#undef itcb
40
                }
41 42 43 44 45 46
            }
        }
        return ::testing::ValuesIn(data);
    }
};

M
Megvii Engine Team 已提交
47
std::ostream& operator<<(std::ostream& ostr, const TestParam& p) {
48
    std::string algo;
M
Megvii Engine Team 已提交
49 50 51 52 53
#define itcb(a)                                         \
    do {                                                \
        if (p.algo == StaticMemAlloc::AllocatorAlgo::a) \
            algo = #a;                                  \
    } while (0);
54 55 56
    ITER_ALGO(itcb);
#undef itcb

57
    ostr << "algo=" << algo << " align=" << p.align << " padding=" << p.padding;
58 59 60 61 62
    if (p.nr_rand_opr != 1)
        ostr << " nr_rand_opr=" << p.nr_rand_opr << " rng_seed=" << p.rng_seed;
    return ostr;
}

M
Megvii Engine Team 已提交
63 64 65
//! parameterized fixture owning one allocator configured from TestParam
class BasicCorrectness : public ::testing::TestWithParam<TestParam> {
public:
    //! create the allocator for the requested algorithm and apply the
    //! alignment and padding of the current parameter
    void SetUp() override {
        const TestParam& param = GetParam();
        m_allocator = StaticMemAlloc::make(param.algo);
        m_allocator->alignment(param.align);
        m_allocator->padding(param.padding);
    }

protected:
    std::unique_ptr<cg::StaticMemAlloc> m_allocator;

    //! padding configured for the current parameter
    size_t padding() const { return GetParam().padding; }

    //! round addr up to the configured alignment via get_aligned_power2
    //! (presumably requires a power-of-two alignment; all params used are)
    size_t align(size_t addr) const {
        return get_aligned_power2(addr, GetParam().align);
    }
};

M
Megvii Engine Team 已提交
81
//! fixture for randomized workloads: identical setup to BasicCorrectness, but a
//! distinct type so it can be instantiated with its own parameter set
struct RandomOpr : BasicCorrectness {};
82 83 84 85 86

//! turn an integer into a distinct user key for StaticMemAlloc::add().
//! The key is only used as a tag here; this file never dereferences it.
//! Takes size_t (not int) because callers pass size_t loop indices: this avoids
//! narrowing at the call site and casting an int to a wider pointer type
//! (size_t has pointer width on the platforms this test targets).
decltype(auto) makeuk(size_t v) {
    return reinterpret_cast<cg::StaticMemAlloc::UserKeyType>(v);
}

M
Megvii Engine Team 已提交
87
}  // anonymous namespace
88 89

TEST_P(BasicCorrectness, Alloc) {
    // non-owning handle to the allocator created in SetUp()
    auto* const allocator = m_allocator.get();

    // Intervals are [begin, end): keys 0 and 1 (1 byte each) are live in
    // step 0 only, key 2 (2 bytes) in step 1 only, so the peak is whichever
    // group needs more aligned+padded space.
    allocator->add(0, 1, 1, makeuk(0));
    allocator->add(0, 1, 1, makeuk(1));
    allocator->add(1, 2, 2, makeuk(2));
    allocator->solve();

    // the allocator must hit the lower bound exactly
    ASSERT_EQ(
            std::max(align(2 + padding()), 2 * align(1 + padding())),
            allocator->tot_alloc());
    ASSERT_EQ(
            std::max(align(2 + padding()), 2 * align(1 + padding())),
            allocator->tot_alloc_lower_bound());
}

TEST_P(BasicCorrectness, Overwrite) {
    auto* const allocator = m_allocator.get();

    // A 3-byte chunk live in [0, 2), then two 1-byte chunks declared to
    // overwrite it: the middle one at offset 1 of the base, the tail one at
    // offset 0 of the middle one. Everything fits in one 3-byte region.
    const auto base = allocator->add(0, 2, 3, makeuk(0));
    const auto middle = allocator->add(1, 3, 1, makeuk(1));
    const auto tail = allocator->add(2, 4, 1, makeuk(2));
    allocator->add_overwrite_spec(middle, base, 1);
    allocator->add_overwrite_spec(tail, middle, 0);
    allocator->solve();

    ASSERT_EQ(align(3 + padding()), allocator->tot_alloc());
    ASSERT_EQ(align(3 + padding()), allocator->tot_alloc_lower_bound());
}

TEST_P(BasicCorrectness, OverwriteSameEnd) {
    auto* const allocator = m_allocator.get();

    // Two 1-byte chunks that both end at step 2. The later-starting one (added
    // first) overwrites the earlier one at offset 0, so a single chunk's worth
    // of memory suffices.
    const auto late = allocator->add(1, 2, 1, makeuk(1));
    const auto early = allocator->add(0, 2, 1, makeuk(0));
    allocator->add_overwrite_spec(late, early, 0);
    allocator->solve();

    ASSERT_EQ(align(1 + padding()), allocator->tot_alloc());
    ASSERT_EQ(align(1 + padding()), allocator->tot_alloc_lower_bound());
}

M
Megvii Engine Team 已提交
127
// hand-checked cases: every algorithm x align {1, 2} x padding {1, 2}
INSTANTIATE_TEST_CASE_P(
        TestStaticMemAllocAlgo, BasicCorrectness,
        TestParam::make_values(
                /* aligns */ {1, 2}, /* paddings */ {1, 2}, /* nr_rand_opr */ {1}));
130

M
Megvii Engine Team 已提交
131
// Largest nr_rand_opr for which RandomOpr still runs INTERVAL_MOVE (presumably
// because that algorithm is slow on large inputs). The limit is higher when
// __OPTIMIZE__ is defined (GCC/Clang predefine it for optimized builds).
constexpr size_t INTERVAL_MOVE_MAX_SIZE =
#ifdef __OPTIMIZE__
        600;
#else
        400;
#endif

TEST_P(RandomOpr, Main) {
    cg::StaticMemAlloc* allocator = this->m_allocator.get();
    auto&& param = this->GetParam();
    std::mt19937_64 rng(param.rng_seed);

    // INTERVAL_MOVE is not run on inputs larger than INTERVAL_MOVE_MAX_SIZE
    if (param.algo == TestParam::Algo::INTERVAL_MOVE &&
        param.nr_rand_opr > INTERVAL_MOVE_MAX_SIZE)
        return;

    // exclusive upper bound on the size of a non-overwriting request
    constexpr size_t MAX_SIZE = 4096;

    // Uniform double in [0, 1).
    // Only the top 53 bits of the 64-bit output are used (scaled by 2^-53), so
    // the result is exact and strictly below 1.0. Dividing the full output by
    // 2^64 could round up to exactly 1.0. uniform_i() would then return its
    // exclusive upper bound and index one past the end of a vector below.
    auto uniform = [&]() {
        return static_cast<double>(rng() >> 11) * (1.0 / 9007199254740992.0);
    };

    // Integer in [lo, hi); a single argument means [0, lo). Returns lo when
    // lo == hi. lo is added in integer arithmetic so that floating-point
    // rounding cannot push the result up to hi.
    auto uniform_i = [&](size_t lo, size_t hi = 0) -> size_t {
        if (!hi) {
            hi = lo;
            lo = 0;
        }
        mgb_assert(lo <= hi);
        return lo + static_cast<size_t>((hi - lo) * uniform());
    };

    // begin, end, size, id
    std::vector<std::tuple<size_t, size_t, size_t, size_t>> reqs;

    // indices in reqs of requests that overwrite others
    std::vector<size_t> overwrite_src_idx;

    for (size_t i = 0; i < param.nr_rand_opr; ++i) {
        bool overwrite = false;
        size_t begin = 0, ov_dest = 0, ov_offset = 0, size = 0;
        // With probability ~0.2, try to make this request overwrite an earlier
        // one. Half of those times the target is itself an overwriting request,
        // which yields chains of overwrites.
        if (!reqs.empty() && uniform() <= 0.2) {
            size_t idx;
            if (!overwrite_src_idx.empty() && uniform() <= 0.5)
                idx = overwrite_src_idx[uniform_i(overwrite_src_idx.size())];
            else
                idx = uniform_i(0, reqs.size());
            // start in the last step during which reqs[idx] ([begin, end)) is live
            begin = std::get<1>(reqs[idx]);
            if (begin) {
                --begin;
                auto tot_sz = std::get<2>(reqs[idx]);
                if (tot_sz >= 2) {
                    ov_dest = std::get<3>(reqs[idx]);
                    // place [ov_offset, ov_offset + size) inside the target
                    ov_offset = uniform_i(tot_sz);
                    size = uniform_i(1, tot_sz - ov_offset);
                    overwrite = true;
                }
            }
        }
        if (!overwrite) {
            begin = uniform_i(param.nr_rand_opr);
            size = uniform_i(1, MAX_SIZE);
        }
        auto end = begin + uniform_i(1, param.nr_rand_opr),
             id = allocator->add(begin, end, size, makeuk(i));
        reqs.emplace_back(begin, end, size, id);
        if (overwrite) {
            allocator->add_overwrite_spec(id, ov_dest, ov_offset);
            overwrite_src_idx.push_back(reqs.size() - 1);
        }
    }

    RealTimer timer;
    allocator->solve();
    std::ostringstream ostr;
    ostr << param;
    auto sz_tot = allocator->tot_alloc(), sz_lower = allocator->tot_alloc_lower_bound();
    // cost = relative overhead of the solution over the lower bound
    mgb_log("%s: time=%.3f size=%zu/%zu cost=%.3f", ostr.str().c_str(),
            timer.get_secs(), sz_tot, sz_lower, double(sz_tot) / sz_lower - 1);
}

M
Megvii Engine Team 已提交
210
// Random workloads: every algorithm x align {1, 256} x padding {1, 32} for
// several problem sizes. RandomOpr.Main skips INTERVAL_MOVE above
// INTERVAL_MOVE_MAX_SIZE.
INSTANTIATE_TEST_CASE_P(
        TestStaticMemAllocAlgo, RandomOpr,
        TestParam::make_values(
                /* aligns */ {1, 256}, /* paddings */ {1, 32},
                /* nr_rand_opr */ {10, INTERVAL_MOVE_MAX_SIZE, 1000, 10000}));
214 215

TEST(TestStaticMemAllocAlgo, PushdownChain) {
    auto allocator = StaticMemAlloc::make(StaticMemAlloc::AllocatorAlgo::PUSHDOWN);

    // Request k is live in [k, k + 2) with size k + 1, so at step t only
    // requests t - 1 and t are live (2t + 1 bytes). The peak is at t = NR - 1,
    // i.e. 2 * NR - 1, and PUSHDOWN is expected to reach it exactly.
    constexpr size_t NR = 5;
    for (size_t step = 0; step != NR; ++step)
        allocator->add(step, step + 2, step + 1, makeuk(step));
    allocator->solve();

    ASSERT_EQ(NR + NR - 1, allocator->tot_alloc_lower_bound());
    ASSERT_EQ(NR + NR - 1, allocator->tot_alloc());
}

225 226 227 228 229 230 231 232 233 234 235 236
TEST(TestStaticMemAllocAlgo, PushdownConsistence) {
    // A chain of NR overlapping requests (request i live in [i, i + 2) with
    // size i + 1) is solved NR_RUN times with fresh PUSHDOWN allocators. Every
    // run must reach the same optimum of 2 * NR - 1; the first mismatch aborts
    // the test.
    constexpr size_t NR = 100;
    constexpr size_t NR_RUN = 500;
    for (size_t run = 0; run != NR_RUN; ++run) {
        auto allocator = StaticMemAlloc::make(StaticMemAlloc::AllocatorAlgo::PUSHDOWN);
        for (size_t i = 0; i != NR; ++i)
            allocator->add(i, i + 2, i + 1, makeuk(i));

        allocator->solve();
        ASSERT_EQ(NR + NR - 1, allocator->tot_alloc_lower_bound());
        ASSERT_EQ(NR + NR - 1, allocator->tot_alloc());
    }
}
M
Megvii Engine Team 已提交
237
#endif  // WIN32
238 239

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}