提交 9d5c5c07 编写于 作者: M Megvii Engine Team

feat(dnn/naive): workspacebundle support 2D

GitOrigin-RevId: 4408bb9e1d2ced2f9e16cc6784882f89be4a3cf2
上级 f268e0f8
......@@ -156,15 +156,42 @@ void megdnn::infer_conv_shape2d(size_t ih, size_t iw, size_t fh, size_t fw,
WorkspaceBundle::WorkspaceBundle(void* ptr, SmallVector<size_t> sizes_in_bytes,
size_t align_in_bytes)
: m_ptr(ptr),
m_sizes(std::move(sizes_in_bytes)),
m_align_in_bytes(align_in_bytes) {
m_aligned_sizes.reserve(m_sizes.size());
for (auto size : m_sizes) {
m_sizes.push_back(sizes_in_bytes);
size_t reduce_size = 0_z;
m_reduce_num.push_back(0_z);
for (auto size : m_sizes[0]) {
auto aligned_size = size;
if (size % m_align_in_bytes != 0) {
aligned_size += m_align_in_bytes - size % m_align_in_bytes;
}
m_aligned_sizes.push_back(aligned_size);
m_reduce_sizes.push_back(reduce_size);
reduce_size += aligned_size;
}
}
WorkspaceBundle::WorkspaceBundle(
SmallVector<SmallVector<size_t>> vector_sizes_in_bytes, void* ptr,
size_t align_in_bytes)
: m_ptr(ptr),
m_sizes(vector_sizes_in_bytes),
m_align_in_bytes(align_in_bytes) {
size_t nr_workspace = 0_z;
size_t reduce_size = 0_z;
for (auto sizes_in_bytes: vector_sizes_in_bytes) {
m_reduce_num.push_back(nr_workspace);
for (auto size : sizes_in_bytes) {
auto aligned_size = size;
if (size % m_align_in_bytes != 0) {
aligned_size += m_align_in_bytes - size % m_align_in_bytes;
}
m_aligned_sizes.push_back(aligned_size);
m_reduce_sizes.push_back(reduce_size);
reduce_size += aligned_size;
nr_workspace++;
}
}
}
......@@ -172,22 +199,39 @@ void* WorkspaceBundle::ptr() const {
return m_ptr;
}
void* WorkspaceBundle::get(size_t i) const {
void* WorkspaceBundle::get(size_t dim1, size_t dim0) const {
megdnn_assert(dim1 < m_sizes.size(), "dim1 is out of range");
megdnn_assert(dim0 < m_sizes[dim1].size(), "dim0 is out of range");
auto addr = reinterpret_cast<uintptr_t>(m_ptr);
if (addr % m_align_in_bytes != 0)
addr += m_align_in_bytes - addr % m_align_in_bytes;
for (size_t j = 0; j < i; ++j) {
addr += m_aligned_sizes[j];
}
size_t index = m_reduce_num[dim1] + dim0;
addr += m_reduce_sizes[index];
return reinterpret_cast<void*>(addr);
}
void* WorkspaceBundle::get(size_t dim0) const {
megdnn_assert(dim0 < m_aligned_sizes.size(), "dim0 is out of range");
auto addr = reinterpret_cast<uintptr_t>(m_ptr);
if (addr % m_align_in_bytes != 0)
addr += m_align_in_bytes - addr % m_align_in_bytes;
addr += m_reduce_sizes[dim0];
return reinterpret_cast<void*>(addr);
}
size_t WorkspaceBundle::nr_workspace() const {
return m_sizes.size();
return m_aligned_sizes.size();
}
size_t WorkspaceBundle::get_size(size_t dim1, size_t dim0) const {
megdnn_assert(dim1 < m_sizes.size(), "dim1 is out of range");
megdnn_assert(dim0 < m_sizes[dim1].size(), "dim0 is out of range");
return m_sizes[dim1][dim0];
}
size_t WorkspaceBundle::get_size(size_t i) const {
return m_sizes[i];
size_t WorkspaceBundle::get_size(size_t dim0) const {
megdnn_assert(dim0 < m_aligned_sizes.size(), "dim0 is out of range");
return m_sizes[0][dim0];
}
void WorkspaceBundle::set(void* ptr) {
......
......@@ -194,8 +194,15 @@ std::unique_ptr<T> make_unique(Args&&... args) {
*/
class WorkspaceBundle {
public:
WorkspaceBundle(void* ptr, SmallVector<size_t> sizes_in_bytes,
WorkspaceBundle(void* ptr = nullptr,
SmallVector<size_t> sizes_in_bytes = {},
size_t align_in_bytes = 512);
/**
* construct 2D workspace buldle
*/
WorkspaceBundle(SmallVector<SmallVector<size_t>> vector_sizes_in_bytes,
void* ptr, size_t align_in_bytes = 512);
/**
* \returns raw workspace ptr.
*
......@@ -204,26 +211,45 @@ public:
*/
void* ptr() const;
/**
* \returns the i-th workspace ptr (aligned)
* \returns the 2D [dim1, dim0] workspace ptr (aligned)
*/
void* get(size_t i) const;
void* get(size_t dim1, size_t dim0) const;
/**
* \returns the 1D [dim0] workspace ptr (aligned)
*/
void* get(size_t dim0) const;
/**
* \returns total size taking into account paddings to solve alignment
* issue.
*/
size_t total_size_in_bytes() const;
size_t get_size(size_t i) const;
/**
* \return the 2D [dim1, dim0] workspace size
*/
size_t get_size(size_t dim1, size_t dim0) const;
/**
* \return the 1D [dim0] workspace size
*/
size_t get_size(size_t dim0) const;
size_t nr_workspace() const;
void set(void* ptr);
Workspace get_workspace(size_t i) const {
return {static_cast<dt_byte*>(get(i)), get_size(i)};
Workspace get_workspace(size_t dim1, size_t dim0) const {
return {static_cast<dt_byte*>(get(dim1, dim0)), get_size(dim1, dim0)};
}
Workspace get_workspace(size_t dim0) const {
return {static_cast<dt_byte*>(get(dim0)), get_size(dim0)};
}
private:
void* m_ptr;
SmallVector<size_t> m_sizes;
SmallVector<SmallVector<size_t>> m_sizes;
SmallVector<size_t> m_aligned_sizes;
//! all workspace size prefix sum
SmallVector<size_t> m_reduce_sizes;
//! dim1 workspace number prefix sum
SmallVector<size_t> m_reduce_num;
size_t m_align_in_bytes;
};
......
......@@ -11,6 +11,7 @@
#include "megdnn/basic_types.h"
#include "megdnn/tensor_format.h"
#include "src/common/utils.h"
// clang-format off
#include "test/common/utils.h"
......@@ -278,4 +279,21 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) {
}
}
TEST(MISC, WORKSPACE_BUNDLE) {
WorkspaceBundle bundle{
{{100, 200}, {435, 234, 143}, {422, 1325, 728}}, nullptr, 64};
bundle.set(reinterpret_cast<void*>(82l));
ASSERT_EQ(bundle.get(0), reinterpret_cast<void*>(128l));
void* dst = reinterpret_cast<void*>(128 + round_up(100, 64));
ASSERT_EQ(bundle.get(0, 1), dst);
dst = reinterpret_cast<void*>(128 + round_up(100, 64) + round_up(200, 64) +
round_up(435, 64));
ASSERT_EQ(bundle.get(1, 1), dst);
dst = reinterpret_cast<void*>(128l + round_up(100, 64) + round_up(200, 64) +
round_up(435, 64) + round_up(234, 64) +
round_up(143, 64) + round_up(422, 64) +
round_up(1325, 64));
ASSERT_EQ(bundle.get(2, 2), dst);
}
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册