提交 f4ead788 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(mgb): static allocation with given padding

GitOrigin-RevId: fdf2de8ad6f767bf3d0c4f3ae9287bbd77b70c16
上级 575a6dca
...@@ -552,6 +552,10 @@ std::unique_ptr<CompNodeSeqRecorder> CompNode::ImplBase::create_seq_recorder( ...@@ -552,6 +552,10 @@ std::unique_ptr<CompNodeSeqRecorder> CompNode::ImplBase::create_seq_recorder(
return {}; return {};
} }
/*!
 * \brief default amount of padding to reserve at the end of a memory chunk
 *
 * The base implementation asks for no padding at all.
 *
 * \return always 0
 */
size_t CompNode::ImplBase::get_mem_padding() {
    constexpr size_t no_padding = 0;
    return no_padding;
}
void CompNode::ImplBase::add_callback(megdnn::thin_function<void()>&&) { void CompNode::ImplBase::add_callback(megdnn::thin_function<void()>&&) {
mgb_throw(MegBrainError, mgb_throw(MegBrainError,
"Unsupported add callback to " "Unsupported add callback to "
......
...@@ -160,7 +160,9 @@ bool SeqMemOptimizer::plan_chunk_allocation() { ...@@ -160,7 +160,9 @@ bool SeqMemOptimizer::plan_chunk_allocation() {
if (chunk->owner_var == var) { if (chunk->owner_var == var) {
size_t& usage = cn2usage[var->comp_node()]; size_t& usage = cn2usage[var->comp_node()];
size_t offset = usage; size_t offset = usage;
usage += chunk->size(); usage += get_aligned_power2(
chunk->size() + var->comp_node().get_mem_padding(),
var->comp_node().get_mem_addr_alignment());
chunk->mem_alloc_status.set_static_offset(offset); chunk->mem_alloc_status.set_static_offset(offset);
} }
} }
...@@ -299,6 +301,7 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node( ...@@ -299,6 +301,7 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node(
auto allocator = StaticMemAlloc::make( auto allocator = StaticMemAlloc::make(
StaticMemAlloc::AllocatorAlgo::PUSHDOWN); StaticMemAlloc::AllocatorAlgo::PUSHDOWN);
allocator->alignment(comp_node.get_mem_addr_alignment()); allocator->alignment(comp_node.get_mem_addr_alignment());
allocator->padding(comp_node.get_mem_padding());
#if MGB_ENABLE_DEBUG_UTIL #if MGB_ENABLE_DEBUG_UTIL
allocator->dbg_key2varnode = [](StaticMemAlloc::UserKeyType key) { allocator->dbg_key2varnode = [](StaticMemAlloc::UserKeyType key) {
return static_cast<const MemChunkLifeInterval*>(key)->chunk->owner_var; return static_cast<const MemChunkLifeInterval*>(key)->chunk->owner_var;
......
...@@ -89,6 +89,15 @@ class StaticMemAlloc { ...@@ -89,6 +89,15 @@ class StaticMemAlloc {
*/ */
virtual StaticMemAlloc& alignment(size_t alignment) = 0; virtual StaticMemAlloc& alignment(size_t alignment) = 0;
/*!
 * \brief set the padding placed at the end of each interval (except for
 *      overwriters)
 *
 * Must be called before calling add(): the helper implementation visible
 * in this change adds the padding to the interval size inside add(), so
 * a value set afterwards does not affect intervals that were already
 * added.
 *
 * \param padding interval padding
 */
virtual StaticMemAlloc& padding(size_t padding) = 0;
#if MGB_ENABLE_DEBUG_UTIL #if MGB_ENABLE_DEBUG_UTIL
//! set by the caller to convert key to VarNode* for debug logging //! set by the caller to convert key to VarNode* for debug logging
VarNode* (*dbg_key2varnode)(UserKeyType) = nullptr; VarNode* (*dbg_key2varnode)(UserKeyType) = nullptr;
......
...@@ -84,7 +84,7 @@ size_t StaticMemAllocImplHelper::add(size_t begin, size_t end, size_t size, ...@@ -84,7 +84,7 @@ size_t StaticMemAllocImplHelper::add(size_t begin, size_t end, size_t size,
mgb_assert(begin < end); mgb_assert(begin < end);
auto id = m_interval_storage.size(); auto id = m_interval_storage.size();
m_interval_storage.push_back({begin, end, size, key, id}); m_interval_storage.push_back({begin, end, size + m_padding, key, id});
return id; return id;
} }
......
...@@ -45,6 +45,11 @@ class StaticMemAllocImplHelper: public StaticMemAlloc { ...@@ -45,6 +45,11 @@ class StaticMemAllocImplHelper: public StaticMemAlloc {
return *this; return *this;
} }
//! store the interval padding; returns this allocator so that calls can
//! be chained
StaticMemAlloc& padding(size_t padding) override final {
    this->m_padding = padding;
    return *this;
}
size_t tot_alloc_lower_bound() const override final { size_t tot_alloc_lower_bound() const override final {
return m_peak_lower_bound; return m_peak_lower_bound;
} }
...@@ -69,7 +74,7 @@ class StaticMemAllocImplHelper: public StaticMemAlloc { ...@@ -69,7 +74,7 @@ class StaticMemAllocImplHelper: public StaticMemAlloc {
} }
private: private:
size_t m_alignment = 1, m_peak_lower_bound = 0; size_t m_alignment = 1, m_padding = 0, m_peak_lower_bound = 0;
//! original interval storage //! original interval storage
std::vector<Interval> m_interval_storage; std::vector<Interval> m_interval_storage;
......
...@@ -288,6 +288,17 @@ class CompNode { ...@@ -288,6 +288,17 @@ class CompNode {
return m_impl->get_mem_addr_alignment(); return m_impl->get_mem_addr_alignment();
} }
/*!
 * \brief get the size of the padding which must be reserved at the
 *      end of a memory chunk
 *
 * The value reported by the implementation is checked before it is
 * returned: it must have at most one bit set, i.e. it is either 0 or a
 * power of 2.
 *
 * \return the padding size reported by the implementation
 */
size_t get_mem_padding() const {
size_t padding = m_impl->get_mem_padding();
// padding & (padding - 1) clears the lowest set bit, so it is zero only
// when at most one bit is set; note that this also accepts padding == 0
mgb_assert(!(padding & (padding - 1)),
"mem padding should be power of 2");
return padding;
}
/*! /*!
* \brief release consecutive free chunks on all devices to defragment; * \brief release consecutive free chunks on all devices to defragment;
* see DevMemAlloc::try_coalesce_free * see DevMemAlloc::try_coalesce_free
...@@ -510,6 +521,7 @@ class CompNode { ...@@ -510,6 +521,7 @@ class CompNode {
const void *src, size_t size) = 0; const void *src, size_t size) = 0;
virtual size_t get_mem_addr_alignment() = 0; virtual size_t get_mem_addr_alignment() = 0;
//! padding to reserve at the end of a memory chunk; unlike
//! get_mem_addr_alignment() this is not pure virtual, and the base
//! implementation returns 0
virtual size_t get_mem_padding();
virtual std::unique_ptr<Event> create_event(size_t flags) = 0; virtual std::unique_ptr<Event> create_event(size_t flags) = 0;
......
...@@ -34,10 +34,11 @@ struct TestParam { ...@@ -34,10 +34,11 @@ struct TestParam {
using Algo = StaticMemAlloc::AllocatorAlgo; using Algo = StaticMemAlloc::AllocatorAlgo;
Algo algo; Algo algo;
size_t align, nr_rand_opr, rng_seed; size_t align, padding, nr_rand_opr, rng_seed;
static decltype(auto) make_values( static decltype(auto) make_values(
const std::vector<size_t> &aligns, const std::vector<size_t> &aligns,
const std::vector<size_t> &paddings,
const std::vector<size_t> &nr_rand_opr) { const std::vector<size_t> &nr_rand_opr) {
std::vector<TestParam> data; std::vector<TestParam> data;
std::mt19937_64 rng(next_rand_seed()); std::mt19937_64 rng(next_rand_seed());
...@@ -46,11 +47,13 @@ struct TestParam { ...@@ -46,11 +47,13 @@ struct TestParam {
for (auto nr: nr_rand_opr) { for (auto nr: nr_rand_opr) {
size_t seed = rng(); size_t seed = rng();
for (auto align: aligns) { for (auto align: aligns) {
#define itcb(algo) data.push_back({Algo::algo, align, nr, seed}); for (auto padding: paddings) {
#define itcb(algo) data.push_back({Algo::algo, align, padding, nr, seed});
ITER_ALGO(itcb) ITER_ALGO(itcb)
#undef itcb #undef itcb
} }
} }
}
return ::testing::ValuesIn(data); return ::testing::ValuesIn(data);
} }
}; };
...@@ -65,7 +68,7 @@ std::ostream& operator << (std::ostream &ostr, const TestParam &p) { ...@@ -65,7 +68,7 @@ std::ostream& operator << (std::ostream &ostr, const TestParam &p) {
ITER_ALGO(itcb); ITER_ALGO(itcb);
#undef itcb #undef itcb
ostr << "algo=" << algo << " align=" << p.align; ostr << "algo=" << algo << " align=" << p.align << " padding=" << p.padding;
if (p.nr_rand_opr != 1) if (p.nr_rand_opr != 1)
ostr << " nr_rand_opr=" << p.nr_rand_opr << " rng_seed=" << p.rng_seed; ostr << " nr_rand_opr=" << p.nr_rand_opr << " rng_seed=" << p.rng_seed;
return ostr; return ostr;
...@@ -75,6 +78,10 @@ class BasicCorrectness: public ::testing::TestWithParam<TestParam> { ...@@ -75,6 +78,10 @@ class BasicCorrectness: public ::testing::TestWithParam<TestParam> {
protected: protected:
std::unique_ptr<cg::StaticMemAlloc> m_allocator; std::unique_ptr<cg::StaticMemAlloc> m_allocator;
//! padding value carried by the current test parameter
size_t padding() const {
    const auto& param = GetParam();
    return param.padding;
}
size_t align(size_t addr) const { size_t align(size_t addr) const {
return get_aligned_power2(addr, GetParam().align); return get_aligned_power2(addr, GetParam().align);
} }
...@@ -84,6 +91,7 @@ class BasicCorrectness: public ::testing::TestWithParam<TestParam> { ...@@ -84,6 +91,7 @@ class BasicCorrectness: public ::testing::TestWithParam<TestParam> {
void SetUp() override { void SetUp() override {
m_allocator = StaticMemAlloc::make(GetParam().algo); m_allocator = StaticMemAlloc::make(GetParam().algo);
m_allocator->alignment(GetParam().align); m_allocator->alignment(GetParam().align);
m_allocator->padding(GetParam().padding);
} }
}; };
...@@ -102,8 +110,9 @@ TEST_P(BasicCorrectness, Alloc) { ...@@ -102,8 +110,9 @@ TEST_P(BasicCorrectness, Alloc) {
allocator->add(0, 1, 1, makeuk(1)); allocator->add(0, 1, 1, makeuk(1));
allocator->add(1, 2, 2, makeuk(2)); allocator->add(1, 2, 2, makeuk(2));
allocator->solve(); allocator->solve();
ASSERT_EQ(std::max(align(2), 2 * align(1)), allocator->tot_alloc()); ASSERT_EQ(std::max(align(2 + padding()), 2 * align(1 + padding())),
ASSERT_EQ(std::max(align(2), 2 * align(1)), allocator->tot_alloc());
ASSERT_EQ(std::max(align(2 + padding()), 2 * align(1 + padding())),
allocator->tot_alloc_lower_bound()); allocator->tot_alloc_lower_bound());
} }
...@@ -116,8 +125,8 @@ TEST_P(BasicCorrectness, Overwrite) { ...@@ -116,8 +125,8 @@ TEST_P(BasicCorrectness, Overwrite) {
allocator->add_overwrite_spec(id2, id1, 0); allocator->add_overwrite_spec(id2, id1, 0);
allocator->solve(); allocator->solve();
ASSERT_EQ(align(3), allocator->tot_alloc()); ASSERT_EQ(align(3 + padding()), allocator->tot_alloc());
ASSERT_EQ(align(3), allocator->tot_alloc_lower_bound()); ASSERT_EQ(align(3 + padding()), allocator->tot_alloc_lower_bound());
} }
TEST_P(BasicCorrectness, OverwriteSameEnd) { TEST_P(BasicCorrectness, OverwriteSameEnd) {
...@@ -127,12 +136,12 @@ TEST_P(BasicCorrectness, OverwriteSameEnd) { ...@@ -127,12 +136,12 @@ TEST_P(BasicCorrectness, OverwriteSameEnd) {
allocator->add_overwrite_spec(id1, id0, 0); allocator->add_overwrite_spec(id1, id0, 0);
allocator->solve(); allocator->solve();
ASSERT_EQ(align(1), allocator->tot_alloc()); ASSERT_EQ(align(1 + padding()), allocator->tot_alloc());
ASSERT_EQ(align(1), allocator->tot_alloc_lower_bound()); ASSERT_EQ(align(1 + padding()), allocator->tot_alloc_lower_bound());
} }
INSTANTIATE_TEST_CASE_P(TestStaticMemAllocAlgo, INSTANTIATE_TEST_CASE_P(TestStaticMemAllocAlgo,
BasicCorrectness, TestParam::make_values({1, 2}, {1})); BasicCorrectness, TestParam::make_values({1, 2}, {1, 2}, {1}));
#ifdef __OPTIMIZE__ #ifdef __OPTIMIZE__
...@@ -220,7 +229,7 @@ TEST_P(RandomOpr, Main) { ...@@ -220,7 +229,7 @@ TEST_P(RandomOpr, Main) {
} }
INSTANTIATE_TEST_CASE_P(TestStaticMemAllocAlgo, INSTANTIATE_TEST_CASE_P(TestStaticMemAllocAlgo,
RandomOpr, TestParam::make_values({1, 256}, { RandomOpr, TestParam::make_values({1, 256}, {1, 32}, {
10, INTERVAL_MOVE_MAX_SIZE, 1000, 10000})); 10, INTERVAL_MOVE_MAX_SIZE, 1000, 10000}));
TEST(TestStaticMemAllocAlgo, PushdownChain) { TEST(TestStaticMemAllocAlgo, PushdownChain) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册