From 9d5c5c078831e4f21b1184a7c90e422eca9be349 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 19 Jun 2020 21:51:22 +0800 Subject: [PATCH] feat(dnn/naive): workspacebundle support 2D GitOrigin-RevId: 4408bb9e1d2ced2f9e16cc6784882f89be4a3cf2 --- dnn/src/common/utils.cpp | 62 ++++++++++++++++++++++++---- dnn/src/common/utils.h | 40 ++++++++++++++---- dnn/test/common/test_basic_types.cpp | 18 ++++++++ 3 files changed, 104 insertions(+), 16 deletions(-) diff --git a/dnn/src/common/utils.cpp b/dnn/src/common/utils.cpp index 2dd2ca703..a11f340ae 100644 --- a/dnn/src/common/utils.cpp +++ b/dnn/src/common/utils.cpp @@ -156,15 +156,42 @@ void megdnn::infer_conv_shape2d(size_t ih, size_t iw, size_t fh, size_t fw, WorkspaceBundle::WorkspaceBundle(void* ptr, SmallVector sizes_in_bytes, size_t align_in_bytes) : m_ptr(ptr), - m_sizes(std::move(sizes_in_bytes)), m_align_in_bytes(align_in_bytes) { m_aligned_sizes.reserve(m_sizes.size()); - for (auto size : m_sizes) { + m_sizes.push_back(sizes_in_bytes); + size_t reduce_size = 0_z; + m_reduce_num.push_back(0_z); + for (auto size : m_sizes[0]) { auto aligned_size = size; if (size % m_align_in_bytes != 0) { aligned_size += m_align_in_bytes - size % m_align_in_bytes; } m_aligned_sizes.push_back(aligned_size); + m_reduce_sizes.push_back(reduce_size); + reduce_size += aligned_size; + } +} + +WorkspaceBundle::WorkspaceBundle( + SmallVector> vector_sizes_in_bytes, void* ptr, + size_t align_in_bytes) + : m_ptr(ptr), + m_sizes(vector_sizes_in_bytes), + m_align_in_bytes(align_in_bytes) { + size_t nr_workspace = 0_z; + size_t reduce_size = 0_z; + for (auto sizes_in_bytes: vector_sizes_in_bytes) { + m_reduce_num.push_back(nr_workspace); + for (auto size : sizes_in_bytes) { + auto aligned_size = size; + if (size % m_align_in_bytes != 0) { + aligned_size += m_align_in_bytes - size % m_align_in_bytes; + } + m_aligned_sizes.push_back(aligned_size); + m_reduce_sizes.push_back(reduce_size); + reduce_size += aligned_size; + 
nr_workspace++; + } } } @@ -172,22 +199,39 @@ void* WorkspaceBundle::ptr() const { return m_ptr; } -void* WorkspaceBundle::get(size_t i) const { +void* WorkspaceBundle::get(size_t dim1, size_t dim0) const { + megdnn_assert(dim1 < m_sizes.size(), "dim1 is out of range"); + megdnn_assert(dim0 < m_sizes[dim1].size(), "dim0 is out of range"); auto addr = reinterpret_cast(m_ptr); if (addr % m_align_in_bytes != 0) addr += m_align_in_bytes - addr % m_align_in_bytes; - for (size_t j = 0; j < i; ++j) { - addr += m_aligned_sizes[j]; - } + size_t index = m_reduce_num[dim1] + dim0; + addr += m_reduce_sizes[index]; + return reinterpret_cast(addr); +} + +void* WorkspaceBundle::get(size_t dim0) const { + megdnn_assert(dim0 < m_aligned_sizes.size(), "dim0 is out of range"); + auto addr = reinterpret_cast(m_ptr); + if (addr % m_align_in_bytes != 0) + addr += m_align_in_bytes - addr % m_align_in_bytes; + addr += m_reduce_sizes[dim0]; return reinterpret_cast(addr); } size_t WorkspaceBundle::nr_workspace() const { - return m_sizes.size(); + return m_aligned_sizes.size(); +} + +size_t WorkspaceBundle::get_size(size_t dim1, size_t dim0) const { + megdnn_assert(dim1 < m_sizes.size(), "dim1 is out of range"); + megdnn_assert(dim0 < m_sizes[dim1].size(), "dim0 is out of range"); + return m_sizes[dim1][dim0]; } -size_t WorkspaceBundle::get_size(size_t i) const { - return m_sizes[i]; +size_t WorkspaceBundle::get_size(size_t dim0) const { + megdnn_assert(dim0 < m_aligned_sizes.size(), "dim0 is out of range"); + return m_sizes[0][dim0]; } void WorkspaceBundle::set(void* ptr) { diff --git a/dnn/src/common/utils.h b/dnn/src/common/utils.h index 449c9b04e..d73a01e72 100644 --- a/dnn/src/common/utils.h +++ b/dnn/src/common/utils.h @@ -194,8 +194,15 @@ std::unique_ptr make_unique(Args&&... 
args) { */ class WorkspaceBundle { public: - WorkspaceBundle(void* ptr, SmallVector sizes_in_bytes, + WorkspaceBundle(void* ptr = nullptr, + SmallVector sizes_in_bytes = {}, size_t align_in_bytes = 512); + + /** + * construct 2D workspace bundle + */ + WorkspaceBundle(SmallVector> vector_sizes_in_bytes, + void* ptr, size_t align_in_bytes = 512); /** * \returns raw workspace ptr. * @@ -204,26 +211,45 @@ public: */ void* ptr() const; /** - * \returns the i-th workspace ptr (aligned) + * \returns the 2D [dim1, dim0] workspace ptr (aligned) */ - void* get(size_t i) const; + void* get(size_t dim1, size_t dim0) const; + /** + * \returns the 1D [dim0] workspace ptr (aligned) + */ + void* get(size_t dim0) const; /** * \returns total size taking into account paddings to solve alignment * issue. */ size_t total_size_in_bytes() const; - size_t get_size(size_t i) const; + /** + * \return the 2D [dim1, dim0] workspace size + */ + size_t get_size(size_t dim1, size_t dim0) const; + + /** + * \return the 1D [dim0] workspace size + */ + size_t get_size(size_t dim0) const; size_t nr_workspace() const; void set(void* ptr); - Workspace get_workspace(size_t i) const { - return {static_cast(get(i)), get_size(i)}; + Workspace get_workspace(size_t dim1, size_t dim0) const { + return {static_cast(get(dim1, dim0)), get_size(dim1, dim0)}; + } + Workspace get_workspace(size_t dim0) const { + return {static_cast(get(dim0)), get_size(dim0)}; } private: void* m_ptr; - SmallVector m_sizes; + SmallVector> m_sizes; SmallVector m_aligned_sizes; + //! all workspace size prefix sum + SmallVector m_reduce_sizes; + //!
dim1 workspace number prefix sum + SmallVector m_reduce_num; size_t m_align_in_bytes; }; diff --git a/dnn/test/common/test_basic_types.cpp b/dnn/test/common/test_basic_types.cpp index 31a6c2e28..899dc7d2e 100644 --- a/dnn/test/common/test_basic_types.cpp +++ b/dnn/test/common/test_basic_types.cpp @@ -11,6 +11,7 @@ #include "megdnn/basic_types.h" #include "megdnn/tensor_format.h" +#include "src/common/utils.h" // clang-format off #include "test/common/utils.h" @@ -278,4 +279,21 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) { } } +TEST(MISC, WORKSPACE_BUNDLE) { + WorkspaceBundle bundle{ + {{100, 200}, {435, 234, 143}, {422, 1325, 728}}, nullptr, 64}; + bundle.set(reinterpret_cast(82l)); + ASSERT_EQ(bundle.get(0), reinterpret_cast(128l)); + void* dst = reinterpret_cast(128 + round_up(100, 64)); + ASSERT_EQ(bundle.get(0, 1), dst); + dst = reinterpret_cast(128 + round_up(100, 64) + round_up(200, 64) + + round_up(435, 64)); + ASSERT_EQ(bundle.get(1, 1), dst); + dst = reinterpret_cast(128l + round_up(100, 64) + round_up(200, 64) + + round_up(435, 64) + round_up(234, 64) + + round_up(143, 64) + round_up(422, 64) + + round_up(1325, 64)); + ASSERT_EQ(bundle.get(2, 2), dst); +} + // vim: syntax=cpp.doxygen -- GitLab