提交 9d5c5c07 编写于 作者: M Megvii Engine Team

feat(dnn/naive): workspacebundle support 2D

GitOrigin-RevId: 4408bb9e1d2ced2f9e16cc6784882f89be4a3cf2
上级 f268e0f8
...@@ -156,15 +156,42 @@ void megdnn::infer_conv_shape2d(size_t ih, size_t iw, size_t fh, size_t fw, ...@@ -156,15 +156,42 @@ void megdnn::infer_conv_shape2d(size_t ih, size_t iw, size_t fh, size_t fw,
WorkspaceBundle::WorkspaceBundle(void* ptr, SmallVector<size_t> sizes_in_bytes, WorkspaceBundle::WorkspaceBundle(void* ptr, SmallVector<size_t> sizes_in_bytes,
size_t align_in_bytes) size_t align_in_bytes)
: m_ptr(ptr), : m_ptr(ptr),
m_sizes(std::move(sizes_in_bytes)),
m_align_in_bytes(align_in_bytes) { m_align_in_bytes(align_in_bytes) {
m_aligned_sizes.reserve(m_sizes.size()); m_aligned_sizes.reserve(m_sizes.size());
for (auto size : m_sizes) { m_sizes.push_back(sizes_in_bytes);
size_t reduce_size = 0_z;
m_reduce_num.push_back(0_z);
for (auto size : m_sizes[0]) {
auto aligned_size = size; auto aligned_size = size;
if (size % m_align_in_bytes != 0) { if (size % m_align_in_bytes != 0) {
aligned_size += m_align_in_bytes - size % m_align_in_bytes; aligned_size += m_align_in_bytes - size % m_align_in_bytes;
} }
m_aligned_sizes.push_back(aligned_size); m_aligned_sizes.push_back(aligned_size);
m_reduce_sizes.push_back(reduce_size);
reduce_size += aligned_size;
}
}
WorkspaceBundle::WorkspaceBundle(
SmallVector<SmallVector<size_t>> vector_sizes_in_bytes, void* ptr,
size_t align_in_bytes)
: m_ptr(ptr),
m_sizes(vector_sizes_in_bytes),
m_align_in_bytes(align_in_bytes) {
size_t nr_workspace = 0_z;
size_t reduce_size = 0_z;
for (auto sizes_in_bytes: vector_sizes_in_bytes) {
m_reduce_num.push_back(nr_workspace);
for (auto size : sizes_in_bytes) {
auto aligned_size = size;
if (size % m_align_in_bytes != 0) {
aligned_size += m_align_in_bytes - size % m_align_in_bytes;
}
m_aligned_sizes.push_back(aligned_size);
m_reduce_sizes.push_back(reduce_size);
reduce_size += aligned_size;
nr_workspace++;
}
} }
} }
...@@ -172,22 +199,39 @@ void* WorkspaceBundle::ptr() const { ...@@ -172,22 +199,39 @@ void* WorkspaceBundle::ptr() const {
return m_ptr; return m_ptr;
} }
void* WorkspaceBundle::get(size_t i) const { void* WorkspaceBundle::get(size_t dim1, size_t dim0) const {
megdnn_assert(dim1 < m_sizes.size(), "dim1 is out of range");
megdnn_assert(dim0 < m_sizes[dim1].size(), "dim0 is out of range");
auto addr = reinterpret_cast<uintptr_t>(m_ptr); auto addr = reinterpret_cast<uintptr_t>(m_ptr);
if (addr % m_align_in_bytes != 0) if (addr % m_align_in_bytes != 0)
addr += m_align_in_bytes - addr % m_align_in_bytes; addr += m_align_in_bytes - addr % m_align_in_bytes;
for (size_t j = 0; j < i; ++j) { size_t index = m_reduce_num[dim1] + dim0;
addr += m_aligned_sizes[j]; addr += m_reduce_sizes[index];
} return reinterpret_cast<void*>(addr);
}
void* WorkspaceBundle::get(size_t dim0) const {
megdnn_assert(dim0 < m_aligned_sizes.size(), "dim0 is out of range");
auto addr = reinterpret_cast<uintptr_t>(m_ptr);
if (addr % m_align_in_bytes != 0)
addr += m_align_in_bytes - addr % m_align_in_bytes;
addr += m_reduce_sizes[dim0];
return reinterpret_cast<void*>(addr); return reinterpret_cast<void*>(addr);
} }
size_t WorkspaceBundle::nr_workspace() const { size_t WorkspaceBundle::nr_workspace() const {
return m_sizes.size(); return m_aligned_sizes.size();
}
size_t WorkspaceBundle::get_size(size_t dim1, size_t dim0) const {
megdnn_assert(dim1 < m_sizes.size(), "dim1 is out of range");
megdnn_assert(dim0 < m_sizes[dim1].size(), "dim0 is out of range");
return m_sizes[dim1][dim0];
} }
size_t WorkspaceBundle::get_size(size_t i) const { size_t WorkspaceBundle::get_size(size_t dim0) const {
return m_sizes[i]; megdnn_assert(dim0 < m_aligned_sizes.size(), "dim0 is out of range");
return m_sizes[0][dim0];
} }
void WorkspaceBundle::set(void* ptr) { void WorkspaceBundle::set(void* ptr) {
......
...@@ -194,8 +194,15 @@ std::unique_ptr<T> make_unique(Args&&... args) { ...@@ -194,8 +194,15 @@ std::unique_ptr<T> make_unique(Args&&... args) {
*/ */
class WorkspaceBundle { class WorkspaceBundle {
public: public:
WorkspaceBundle(void* ptr, SmallVector<size_t> sizes_in_bytes, WorkspaceBundle(void* ptr = nullptr,
SmallVector<size_t> sizes_in_bytes = {},
size_t align_in_bytes = 512); size_t align_in_bytes = 512);
/**
* construct 2D workspace buldle
*/
WorkspaceBundle(SmallVector<SmallVector<size_t>> vector_sizes_in_bytes,
void* ptr, size_t align_in_bytes = 512);
/** /**
* \returns raw workspace ptr. * \returns raw workspace ptr.
* *
...@@ -204,26 +211,45 @@ public: ...@@ -204,26 +211,45 @@ public:
*/ */
void* ptr() const; void* ptr() const;
/** /**
* \returns the i-th workspace ptr (aligned) * \returns the 2D [dim1, dim0] workspace ptr (aligned)
*/ */
void* get(size_t i) const; void* get(size_t dim1, size_t dim0) const;
/**
* \returns the 1D [dim0] workspace ptr (aligned)
*/
void* get(size_t dim0) const;
/** /**
* \returns total size taking into account paddings to solve alignment * \returns total size taking into account paddings to solve alignment
* issue. * issue.
*/ */
size_t total_size_in_bytes() const; size_t total_size_in_bytes() const;
size_t get_size(size_t i) const; /**
* \return the 2D [dim1, dim0] workspace size
*/
size_t get_size(size_t dim1, size_t dim0) const;
/**
* \return the 1D [dim0] workspace size
*/
size_t get_size(size_t dim0) const;
size_t nr_workspace() const; size_t nr_workspace() const;
void set(void* ptr); void set(void* ptr);
Workspace get_workspace(size_t i) const { Workspace get_workspace(size_t dim1, size_t dim0) const {
return {static_cast<dt_byte*>(get(i)), get_size(i)}; return {static_cast<dt_byte*>(get(dim1, dim0)), get_size(dim1, dim0)};
}
Workspace get_workspace(size_t dim0) const {
return {static_cast<dt_byte*>(get(dim0)), get_size(dim0)};
} }
private: private:
void* m_ptr; void* m_ptr;
SmallVector<size_t> m_sizes; SmallVector<SmallVector<size_t>> m_sizes;
SmallVector<size_t> m_aligned_sizes; SmallVector<size_t> m_aligned_sizes;
//! all workspace size prefix sum
SmallVector<size_t> m_reduce_sizes;
//! dim1 workspace number prefix sum
SmallVector<size_t> m_reduce_num;
size_t m_align_in_bytes; size_t m_align_in_bytes;
}; };
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "megdnn/basic_types.h" #include "megdnn/basic_types.h"
#include "megdnn/tensor_format.h" #include "megdnn/tensor_format.h"
#include "src/common/utils.h"
// clang-format off // clang-format off
#include "test/common/utils.h" #include "test/common/utils.h"
...@@ -278,4 +279,21 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) { ...@@ -278,4 +279,21 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) {
} }
} }
TEST(MISC, WORKSPACE_BUNDLE) {
WorkspaceBundle bundle{
{{100, 200}, {435, 234, 143}, {422, 1325, 728}}, nullptr, 64};
bundle.set(reinterpret_cast<void*>(82l));
ASSERT_EQ(bundle.get(0), reinterpret_cast<void*>(128l));
void* dst = reinterpret_cast<void*>(128 + round_up(100, 64));
ASSERT_EQ(bundle.get(0, 1), dst);
dst = reinterpret_cast<void*>(128 + round_up(100, 64) + round_up(200, 64) +
round_up(435, 64));
ASSERT_EQ(bundle.get(1, 1), dst);
dst = reinterpret_cast<void*>(128l + round_up(100, 64) + round_up(200, 64) +
round_up(435, 64) + round_up(234, 64) +
round_up(143, 64) + round_up(422, 64) +
round_up(1325, 64));
ASSERT_EQ(bundle.get(2, 2), dst);
}
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册