提交 f4ead788 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(mgb): static allocation with given padding

GitOrigin-RevId: fdf2de8ad6f767bf3d0c4f3ae9287bbd77b70c16
上级 575a6dca
......@@ -552,6 +552,10 @@ std::unique_ptr<CompNodeSeqRecorder> CompNode::ImplBase::create_seq_recorder(
return {};
}
//! Base implementation of the virtual get_mem_padding() hook: reports that
//! no padding is needed (returns 0).
size_t CompNode::ImplBase::get_mem_padding() {
return 0;
}
void CompNode::ImplBase::add_callback(megdnn::thin_function<void()>&&) {
mgb_throw(MegBrainError,
"Unsupported add callback to "
......
......@@ -160,7 +160,9 @@ bool SeqMemOptimizer::plan_chunk_allocation() {
if (chunk->owner_var == var) {
size_t& usage = cn2usage[var->comp_node()];
size_t offset = usage;
usage += chunk->size();
usage += get_aligned_power2(
chunk->size() + var->comp_node().get_mem_padding(),
var->comp_node().get_mem_addr_alignment());
chunk->mem_alloc_status.set_static_offset(offset);
}
}
......@@ -299,6 +301,7 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node(
auto allocator = StaticMemAlloc::make(
StaticMemAlloc::AllocatorAlgo::PUSHDOWN);
allocator->alignment(comp_node.get_mem_addr_alignment());
allocator->padding(comp_node.get_mem_padding());
#if MGB_ENABLE_DEBUG_UTIL
allocator->dbg_key2varnode = [](StaticMemAlloc::UserKeyType key) {
return static_cast<const MemChunkLifeInterval*>(key)->chunk->owner_var;
......
......@@ -89,6 +89,15 @@ class StaticMemAlloc {
*/
virtual StaticMemAlloc& alignment(size_t alignment) = 0;
/*!
* \brief set the padding added at the end of each interval (except for
* overwriters)
*
* Must be called before calling add()
*
* \param padding interval padding
* \return reference to a StaticMemAlloc (presumably *this, to allow
* chaining -- confirm against the implementation)
*/
virtual StaticMemAlloc& padding(size_t padding) = 0;
#if MGB_ENABLE_DEBUG_UTIL
//! set by the caller to convert key to VarNode* for debug logging
VarNode* (*dbg_key2varnode)(UserKeyType) = nullptr;
......
......@@ -84,7 +84,7 @@ size_t StaticMemAllocImplHelper::add(size_t begin, size_t end, size_t size,
mgb_assert(begin < end);
auto id = m_interval_storage.size();
m_interval_storage.push_back({begin, end, size, key, id});
m_interval_storage.push_back({begin, end, size + m_padding, key, id});
return id;
}
......
......@@ -45,6 +45,11 @@ class StaticMemAllocImplHelper: public StaticMemAlloc {
return *this;
}
//! Record the interval padding in m_padding and return *this so that calls
//! can be chained.
StaticMemAlloc& padding(size_t padding) override final {
m_padding = padding;
return *this;
}
//! Return the value currently stored in m_peak_lower_bound.
size_t tot_alloc_lower_bound() const override final {
return m_peak_lower_bound;
}
......@@ -69,7 +74,7 @@ class StaticMemAllocImplHelper: public StaticMemAlloc {
}
private:
size_t m_alignment = 1, m_peak_lower_bound = 0;
size_t m_alignment = 1, m_padding = 0, m_peak_lower_bound = 0;
//! original interval storage
std::vector<Interval> m_interval_storage;
......
......@@ -288,6 +288,17 @@ class CompNode {
return m_impl->get_mem_addr_alignment();
}
/*!
* \brief get the size of the padding which must be reserved at the
* end of a memory chunk
*
* The value reported by the implementation is checked with mgb_assert;
* the check passes when the value has at most one bit set, i.e. when it
* is either zero or a power of 2.
*/
size_t get_mem_padding() const {
size_t padding = m_impl->get_mem_padding();
// padding & (padding - 1) clears the lowest set bit, so it is zero exactly
// when padding is 0 or a power of 2 (unsigned wrap-around makes 0 pass)
mgb_assert(!(padding & (padding - 1)),
"mem padding should be power of 2");
return padding;
}
/*!
* \brief release consecutive free chunks on all devices to defragment;
* see DevMemAlloc::try_coalesce_free
......@@ -510,6 +521,7 @@ class CompNode {
const void *src, size_t size) = 0;
virtual size_t get_mem_addr_alignment() = 0;
//! padding to reserve at the end of a memory chunk; not pure virtual, the
//! base implementation returns 0
virtual size_t get_mem_padding();
virtual std::unique_ptr<Event> create_event(size_t flags) = 0;
......
......@@ -34,10 +34,11 @@ struct TestParam {
using Algo = StaticMemAlloc::AllocatorAlgo;
Algo algo;
size_t align, nr_rand_opr, rng_seed;
size_t align, padding, nr_rand_opr, rng_seed;
static decltype(auto) make_values(
const std::vector<size_t> &aligns,
const std::vector<size_t> &paddings,
const std::vector<size_t> &nr_rand_opr) {
std::vector<TestParam> data;
std::mt19937_64 rng(next_rand_seed());
......@@ -46,9 +47,11 @@ struct TestParam {
for (auto nr: nr_rand_opr) {
size_t seed = rng();
for (auto align: aligns) {
#define itcb(algo) data.push_back({Algo::algo, align, nr, seed});
ITER_ALGO(itcb)
for (auto padding: paddings) {
#define itcb(algo) data.push_back({Algo::algo, align, padding, nr, seed});
ITER_ALGO(itcb)
#undef itcb
}
}
}
return ::testing::ValuesIn(data);
......@@ -65,7 +68,7 @@ std::ostream& operator << (std::ostream &ostr, const TestParam &p) {
ITER_ALGO(itcb);
#undef itcb
ostr << "algo=" << algo << " align=" << p.align;
ostr << "algo=" << algo << " align=" << p.align << " padding=" << p.padding;
if (p.nr_rand_opr != 1)
ostr << " nr_rand_opr=" << p.nr_rand_opr << " rng_seed=" << p.rng_seed;
return ostr;
......@@ -75,6 +78,10 @@ class BasicCorrectness: public ::testing::TestWithParam<TestParam> {
protected:
std::unique_ptr<cg::StaticMemAlloc> m_allocator;
//! padding value carried by the current test parameter
size_t padding() const {
    const auto& param = GetParam();
    return param.padding;
}
//! apply get_aligned_power2 to \p addr using the alignment of the current
//! test parameter
size_t align(size_t addr) const {
    const size_t alignment = GetParam().align;
    return get_aligned_power2(addr, alignment);
}
......@@ -84,6 +91,7 @@ class BasicCorrectness: public ::testing::TestWithParam<TestParam> {
//! build the allocator for the parameterized algorithm, then configure its
//! alignment and padding from the same test parameter
void SetUp() override {
    const auto& param = GetParam();
    m_allocator = StaticMemAlloc::make(param.algo);
    m_allocator->alignment(param.align);
    m_allocator->padding(param.padding);
}
};
......@@ -102,8 +110,9 @@ TEST_P(BasicCorrectness, Alloc) {
allocator->add(0, 1, 1, makeuk(1));
allocator->add(1, 2, 2, makeuk(2));
allocator->solve();
ASSERT_EQ(std::max(align(2), 2 * align(1)), allocator->tot_alloc());
ASSERT_EQ(std::max(align(2), 2 * align(1)),
ASSERT_EQ(std::max(align(2 + padding()), 2 * align(1 + padding())),
allocator->tot_alloc());
ASSERT_EQ(std::max(align(2 + padding()), 2 * align(1 + padding())),
allocator->tot_alloc_lower_bound());
}
......@@ -116,8 +125,8 @@ TEST_P(BasicCorrectness, Overwrite) {
allocator->add_overwrite_spec(id2, id1, 0);
allocator->solve();
ASSERT_EQ(align(3), allocator->tot_alloc());
ASSERT_EQ(align(3), allocator->tot_alloc_lower_bound());
ASSERT_EQ(align(3 + padding()), allocator->tot_alloc());
ASSERT_EQ(align(3 + padding()), allocator->tot_alloc_lower_bound());
}
TEST_P(BasicCorrectness, OverwriteSameEnd) {
......@@ -127,12 +136,12 @@ TEST_P(BasicCorrectness, OverwriteSameEnd) {
allocator->add_overwrite_spec(id1, id0, 0);
allocator->solve();
ASSERT_EQ(align(1), allocator->tot_alloc());
ASSERT_EQ(align(1), allocator->tot_alloc_lower_bound());
ASSERT_EQ(align(1 + padding()), allocator->tot_alloc());
ASSERT_EQ(align(1 + padding()), allocator->tot_alloc_lower_bound());
}
INSTANTIATE_TEST_CASE_P(TestStaticMemAllocAlgo,
BasicCorrectness, TestParam::make_values({1, 2}, {1}));
BasicCorrectness, TestParam::make_values({1, 2}, {1, 2}, {1}));
#ifdef __OPTIMIZE__
......@@ -220,7 +229,7 @@ TEST_P(RandomOpr, Main) {
}
INSTANTIATE_TEST_CASE_P(TestStaticMemAllocAlgo,
RandomOpr, TestParam::make_values({1, 256}, {
RandomOpr, TestParam::make_values({1, 256}, {1, 32}, {
10, INTERVAL_MOVE_MAX_SIZE, 1000, 10000}));
TEST(TestStaticMemAllocAlgo, PushdownChain) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册