diff --git a/imperative/src/impl/profiler/states.h b/imperative/src/impl/profiler/states.h index 0af1ac4208cceb44c836efa61517aa8c5d8fa189..1f452cf3996d0207b940a5488bf2afd9dd62f89d 100644 --- a/imperative/src/impl/profiler/states.h +++ b/imperative/src/impl/profiler/states.h @@ -160,7 +160,7 @@ private: template void register_converter() { m_table[typeid(TItem)] = [](const any_t& input) { - return variant_t(*input.as()); + return variant_t(input.cast()); }; } diff --git a/imperative/src/include/megbrain/imperative/profiler.h b/imperative/src/include/megbrain/imperative/profiler.h index b94da5260b7d4f15ad5757e9cb467420ad5ce6ab..1084e73c432ae76b49f0f9bf54676dbc9238be85 100644 --- a/imperative/src/include/megbrain/imperative/profiler.h +++ b/imperative/src/include/megbrain/imperative/profiler.h @@ -11,7 +11,6 @@ #pragma once -#include #include #include #include @@ -28,6 +27,7 @@ #include "megbrain/imperative/op_def.h" #include "megbrain/imperative/physical_tensor.h" +#include "megbrain/imperative/utils/any.h" namespace mgb { namespace imperative { @@ -51,48 +51,6 @@ public: static std::shared_ptr record_device(CompNode device); }; -class AnyPtr { -public: - struct Deleter { - void* object; - void (*method)(void*, void*); - void operator()(void* ptr) { method(object, ptr); } - }; - -private: - using holder_t = std::unique_ptr; - - const std::type_info* m_type = nullptr; - holder_t m_holder = nullptr; - -public: - AnyPtr() = default; - template < - typename T, - typename = std::enable_if_t, AnyPtr>>> - explicit AnyPtr(T* value, Deleter deleter) { - m_type = &typeid(T); - m_holder = {value, deleter}; - } - template - T* as() { - mgb_assert(is_exactly(), "type mismatch"); - return reinterpret_cast(m_holder.get()); - } - template - const T* as() const { - mgb_assert(is_exactly(), "type mismatch"); - return reinterpret_cast(m_holder.get()); - } - template - bool is_exactly() const { - return std::type_index{typeid(T)} == std::type_index{*m_type}; - } - const std::type_info& type() const { return *m_type; } - bool operator==(std::nullptr_t nptr) const { return m_holder == nullptr; } - operator bool() const { return m_holder != nullptr; } -}; - class Profiler { public: struct Record { @@ -128,7 +86,6 @@ private: std::thread::id m_thread_id; std::vector m_records; std::atomic m_status = Running; - std::unordered_map m_mem_pools; static std::vector sm_records; static options_t sm_profile_options; @@ -161,42 +118,21 @@ public: return *tm_profiler; } - template - static MemPool& get_mem_pool() { - thread_local MemPool* t_pool = nullptr; - if (t_pool == nullptr) { - auto& pool = get_instance().m_mem_pools[typeid(MemPool)]; - if (pool == nullptr) { - pool = - AnyPtr(new MemPool(), - {nullptr, [](void*, void* ptr) { - delete reinterpret_cast*>(ptr); - }}); - } - t_pool = pool.as>(); - } - return *t_pool; - } - static uint64_t next_id() { return sm_last_id++; } template static uint64_t record(TArgs&&... args) { auto& profiler = get_instance(); - auto& mem_pool = get_mem_pool(); + // auto& mem_pool = get_mem_pool(); if constexpr (sm_debug) { Status expected = Running; mgb_assert(profiler.m_status.compare_exchange_strong(expected, Recording)); } uint64_t id = next_id(); profiler::Time time = sm_timer.record_host(); - auto deleter = [](void* obj, void* ptr) { - reinterpret_cast*>(obj)->free(reinterpret_cast(ptr)); - }; profiler.m_records.emplace_back( id, profiler.m_thread_id, time, - AnyPtr{mem_pool.alloc(T{std::forward(args)...}), - {&mem_pool, deleter}}); + AnyPtr::make(T{std::forward(args)...})); if constexpr (sm_debug) { Status expected = Recording; mgb_assert(profiler.m_status.compare_exchange_strong(expected, Running)); @@ -241,7 +177,7 @@ public: bundle.options = get_options(); bundle.start_at = sm_start_at; bundle.thread_dict = get_thread_dict(); - return std::move(bundle); + return bundle; } static option_t get_option(std::string key, option_t default_val) { diff --git a/imperative/src/include/megbrain/imperative/utils/allocator.h b/imperative/src/include/megbrain/imperative/utils/allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..4bc9cca9b9708fd1249d1a3f119c7723cc409dba --- /dev/null +++ b/imperative/src/include/megbrain/imperative/utils/allocator.h @@ -0,0 +1,71 @@ +/** + * \file imperative/src/include/megbrain/imperative/utils/allocator.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include + +#include "megbrain/utils/mempool.h" +#include "megbrain/utils/metahelper.h" + +namespace mgb::imperative { + +template +class Allocator { +public: + using pointer = T*; + using const_pointer = const T*; + using void_pointer = void*; + using const_void_pointer = const void*; + using value_type = T; + using size_type = std::size_t; + using diffenence_type = std::ptrdiff_t; + using pool_type = MemPoolStorage; + +private: + pool_type* m_pool = nullptr; + +public: + Allocator(pool_type* pool) : m_pool(pool) {} + + T* allocate(size_type n) { + mgb_assert(n == 1); + return m_pool->alloc(sizeof(T)); + } + + void deallocate(pointer* p, size_type n) { + mgb_assert(n == 1); + m_pool->free(p); + } + + bool operator==(const Allocator& rhs) const { return m_pool == rhs.m_pool; } + + bool operator!=(const Allocator& rhs) const { return m_pool != rhs.m_pool; } +}; + +template +class ThreadLocalAllocatorAdapter { +public: + using value_type = T; + using size_type = std::size_t; + using pointer = T*; + +public: + T* allocate(size_type n) { mgb_assert(false); } + + void deallocate(pointer* p, size_type n) { mgb_assert(false); } + + bool operator==(const ThreadLocalAllocatorAdapter& rhs) const { return true; } + + bool operator!=(const ThreadLocalAllocatorAdapter& rhs) const { return false; } +}; + +} // namespace mgb::imperative \ No newline at end of file diff --git a/imperative/src/include/megbrain/imperative/utils/any.h b/imperative/src/include/megbrain/imperative/utils/any.h new file mode 100644 index 0000000000000000000000000000000000000000..a7223743ecc2d2859fd187f01249def5167e1957 --- /dev/null +++ b/imperative/src/include/megbrain/imperative/utils/any.h @@ -0,0 +1,70 @@ +/** + * \file imperative/src/include/megbrain/imperative/utils/any.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include + +#include "megbrain/imperative/utils/local_ptr.h" + +namespace mgb::imperative { + +class AnyMixinBase { +private: + const std::type_info* m_type = nullptr; + +public: + AnyMixinBase() = default; + + const std::type_info& type() const { return *m_type; } + + friend class AnyPtr; +}; + +template +class AnyMixin : public AnyMixinBase, public T { +public: + AnyMixin(T&& val) : T(std::move(val)) {} +}; + +class AnyPtr { +public: + using storage_t = LocalPtr; + +private: + storage_t m_storage; + +public: + const std::type_info& type() const { return m_storage->type(); } + template + const T& cast() const { + mgb_assert(is_exactly(), "type mismatch"); + return *static_cast*>(m_storage.get()); + } + template + bool is_exactly() const { + return std::type_index{typeid(T)} == std::type_index{type()}; + } + bool operator==(std::nullptr_t nptr) const { return m_storage == nullptr; } + bool operator!=(std::nullptr_t nptr) const { return m_storage != nullptr; } + operator bool() const { return m_storage != nullptr; } + + template + static AnyPtr make(TArgs&&... args) { + AnyPtr ret; + ret.m_storage = LocalPtr::make>( + std::forward(args)...); + ret.m_storage->m_type = &typeid(T); + return ret; + } +}; + +} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/utils/box.h b/imperative/src/include/megbrain/imperative/utils/box.h new file mode 100644 index 0000000000000000000000000000000000000000..e3fa6625d67b441ba2626a7ab3f575d0204aa52d --- /dev/null +++ b/imperative/src/include/megbrain/imperative/utils/box.h @@ -0,0 +1,96 @@ +/** + * \file imperative/src/include/megbrain/imperative/utils/visit.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include +#include +#include + +#include "megbrain/utils/metahelper.h" +#include "megbrain/utils/small_vector.h" + +namespace mgb::imperative { + +class BoxBase : public NonCopyableObj { +public: + virtual void reset() = 0; + virtual void set_exception(std::exception_ptr exc) = 0; + virtual bool try_set_exception(std::exception_ptr exc) = 0; +}; + +/** + * \brief An reusable promise + * + * \tparam T type of value + */ +template +class Box final : public BoxBase { +private: + std::promise m_promise; + std::shared_future m_future; + std::mutex m_mutex; + bool m_value_set; + bool m_exception_set; + +public: + Box() { reset(); } + const T& get_value() { return m_future.get(); } + T take_value() { + T value = m_future.get(); + reset(); + return value; + } + void set_value(T value) { + MGB_LOCK_GUARD(m_mutex); + m_promise.set_value(std::move(value)); + m_value_set = true; + } + bool try_set_value(T value) { + MGB_LOCK_GUARD(m_mutex); + if (m_exception_set) { + return false; + } + m_promise.set_value(std::move(value)); + m_value_set = true; + return true; + } + void set_exception(std::exception_ptr exc) override { + MGB_LOCK_GUARD(m_mutex); + m_promise.set_exception(exc); + m_exception_set = true; + } + bool try_set_exception(std::exception_ptr exc) override { + MGB_LOCK_GUARD(m_mutex); + if (m_value_set) { + return false; + } + m_promise.set_exception(exc); + m_exception_set = true; + return true; + } + void reset() override { + MGB_LOCK_GUARD(m_mutex); + m_promise = {}; + m_future = m_promise.get_future(); + m_value_set = false; + m_exception_set = false; + } + + /** + * \brief make an empty box + * + * \return std::shared_ptr + */ + static std::shared_ptr make() { return std::make_shared(); } +}; + +} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/utils/helper.h b/imperative/src/include/megbrain/imperative/utils/helper.h new file mode 100644 index 0000000000000000000000000000000000000000..36c325e7f0396e52db0bd85152ca920ca860ee63 --- /dev/null +++ b/imperative/src/include/megbrain/imperative/utils/helper.h @@ -0,0 +1,40 @@ +/** + * \file imperative/src/include/megbrain/imperative/utils/span.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include +#include +#include + +namespace mgb { + +namespace imperative { + +template +class CleanupGuard { +private: + T m_callback; + +public: + explicit CleanupGuard(T cb) : m_callback{std::move(cb)} {} + ~CleanupGuard() { m_callback(); } +}; + +inline std::string quoted(std::string str) { + std::stringstream ss; + ss << std::quoted(str); + return ss.str(); +} + +} // namespace imperative + +} // namespace mgb \ No newline at end of file diff --git a/imperative/src/include/megbrain/imperative/utils/intrusive_list.h b/imperative/src/include/megbrain/imperative/utils/intrusive_list.h new file mode 100644 index 0000000000000000000000000000000000000000..1b87d92b9b49287eb0c18b218c7194510d80435f --- /dev/null +++ b/imperative/src/include/megbrain/imperative/utils/intrusive_list.h @@ -0,0 +1,245 @@ +/** + * \file imperative/src/include/megbrain/imperative/utils/intrusive_list.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/utils/metahelper.h" + +namespace mgb::imperative::utils::intrusive_list { + +// copy policy +struct after_t {}; +struct before_t {}; +struct disable_t {}; + +template +struct Tail; + +// invariant: next->prev == this +template +struct Head { + Tail* next; + + Head(Tail* node = nullptr) : next(node) {} + Head(const Head&) = delete; + Head& operator=(const Head&) = delete; + Head(Head&& rhs) : next(rhs.next) { + rhs.next = nullptr; + if (next) { + next->prev = this; + } + } + Head& operator=(Head&& rhs) { + mgb_assert(!next); + next = rhs.next; + rhs.next = nullptr; + if (next) { + next->prev = this; + } + return *this; + } + ~Head() { + if (next) { + next->prev = nullptr; + } + } +}; + +// invariant: prev->next == this +template +struct Tail { + Head* prev; + + Tail(Head* node = nullptr) : prev(node) {} + Tail(const Tail&) = delete; + Tail& operator=(const Tail&) = delete; + Tail(Tail&& rhs) : prev(rhs.prev) { + rhs.prev = nullptr; + if (prev) { + prev->next = this; + } + } + Tail& operator=(Tail&& rhs) { + mgb_assert(!prev); + prev = rhs.prev; + rhs.prev = nullptr; + if (prev) { + prev->next = this; + } + return *this; + } + ~Tail() { + if (prev) { + prev->next = nullptr; + } + } +}; + +template +struct Node; + +template +class Iterator { + T* ptr; + + void inc() { ptr = static_cast(ptr->Head::next); } + void dec() { ptr = static_cast(ptr->Head::prev); } + +public: + Iterator(Head& head) : ptr(static_cast(head.next)) {} + Iterator(Tail& tail) : ptr(static_cast(tail.prev)) {} + + template + Iterator(Node& node) : ptr(static_cast(&node)) {} + + T& operator*() { return *static_cast(ptr); } + T* operator->() { return static_cast(ptr); } + + operator bool() { return ptr; } + bool operator==(const Iterator& rhs) { return ptr == rhs.ptr; } + + Iterator& operator++() { + inc(); + return *this; + } + Iterator& operator--() { + dec(); + return *this; + } + Iterator operator++(int) { + auto ret = *this; + inc(); + return ret; + } + Iterator operator--(int) { + auto ret = *this; + dec(); + return ret; + } +}; + +// Node in a doubly linked list. Unlike std::list, nodes are not owned by a container. +// Instead, nodes may join or leave a list freely. +// NOTE: Derived classes have to explicitly declare copy / assignment as default, +// otherwise the compiler generated version would use the const T& signature, +// which is deleted. +template +struct Node : Tail, Node, T>>, + Head, Node, T>> { +private: + using this_t = Node; + using U = std::conditional_t, this_t, T>; + +public: + using head_t = Head; + using tail_t = Tail; + using head_t::next; + using tail_t::prev; + + Node() = default; + Node(const this_t&) = delete; + this_t& operator=(const this_t&) = delete; + + //! constructed node is inserted after the input node + Node(after_t, head_t& node) : tail_t(&node), head_t(node.next) { + node.next = this; + if (next) { + next->prev = this; + } + } + + //! constructed node is inserted before the input node + Node(before_t, tail_t& node) : head_t(&node), tail_t(node.prev) { + node.prev = this; + if (prev) { + prev->next = this; + } + } + + Node(this_t&& rhs) : tail_t(rhs.prev), head_t(rhs.next) { + rhs.prev = nullptr; + rhs.next = nullptr; + if (prev) { + prev->next = this; + } + if (next) { + next->prev = this; + } + } + + Node& operator=(this_t&& rhs) { + unlink(); + prev = rhs.prev; + next = rhs.next; + rhs.prev = nullptr; + rhs.next = nullptr; + if (prev) { + prev->next = this; + } + if (next) { + next->prev = this; + } + return *this; + } + + template < + typename p = policy, + typename = std::enable_if_t< + std::is_same_v || std::is_same_v, void>> + Node(this_t& rhs) : Node(policy{}, rhs) {} + + template < + typename p = policy, + typename = std::enable_if_t< + std::is_same_v || std::is_same_v, void>> + this_t& operator=(this_t& rhs) { + insert(policy{}, rhs); + return *this; + } + + void unlink() { + if (prev) { + prev->next = next; + } + if (next) { + next->prev = prev; + } + prev = nullptr; + next = nullptr; + } + + //! this node is unlinked from its list and inserted after the input node + void insert(after_t, head_t& node) { + unlink(); + prev = &node; + next = node.next; + node.next = this; + if (next) { + next->prev = this; + } + } + + //! this node is unlinked from its list and inserted before the input node + void insert(before_t, tail_t& node) { + unlink(); + next = &node; + prev = node.prev; + node.prev = this; + if (prev) { + prev->next = this; + } + } + + void insert_before(tail_t& node) { insert(before_t{}, node); } + void insert_after(head_t& node) { insert(after_t{}, node); } + + ~Node() { unlink(); } +}; + +} // namespace mgb::imperative::utils::intrusive_list diff --git a/imperative/src/include/megbrain/imperative/utils/local_ptr.h b/imperative/src/include/megbrain/imperative/utils/local_ptr.h new file mode 100644 index 0000000000000000000000000000000000000000..cc87fd29136a506eb79c10044044c3ab01c6b136 --- /dev/null +++ b/imperative/src/include/megbrain/imperative/utils/local_ptr.h @@ -0,0 +1,285 @@ +/** + * \file imperative/src/include/megbrain/imperative/utils/local_ptr.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include + +#include "megbrain/imperative/utils/mempool.h" +#include "megbrain/utils/metahelper.h" + +namespace mgb::imperative { + +template +class LocalPtrStorage : public NonCopyableObj { +private: + size_t m_ref_count = 0; + size_t m_weak_count = 0; + T* m_pointer = nullptr; + void (*reset)(LocalPtrStorage*) = nullptr; + void (*free)(LocalPtrStorage*) = nullptr; + + void inc_ref() { m_ref_count++; } + + void dec_ref() { + m_ref_count--; + if (m_ref_count == 0) { + reset(this); + m_pointer = nullptr; + reset = nullptr; + if (m_weak_count == 0) { + free(this); + // dead + } + } + } + + void inc_weak_ref() { m_weak_count++; } + + void dec_weak_ref() { + m_weak_count--; + if ((m_weak_count + m_ref_count) == 0) { + free(this); + // dead + } + } + + template + friend class LocalPtr; + + template + friend class LocalWeakPtr; + +public: +}; + +template +class LocalPtrStorgeImpl : public LocalPtrStorage { +private: + std::optional m_value; + void* m_pool = nullptr; + + template + friend class LocalPtr; + + template + friend class LocalWeakPtr; +}; + +template +class LocalWeakPtr; + +/** + * \brief thread-unsafe smart pointer + * + * \tparam T type of value + */ +template +class LocalPtr { +public: + using storage_t = LocalPtrStorage; + using pool_t = MemPool; + using weak_type = LocalWeakPtr; + +private: + storage_t* m_storage = nullptr; + + void emplace(storage_t* ptr) { + if (ptr) { + ptr->inc_ref(); + m_storage = ptr; + } + } + + LocalPtr(storage_t* ptr) { emplace(ptr); } + +public: + LocalPtr() = default; + LocalPtr(const LocalPtr& rhs) { (*this) = rhs; } + LocalPtr(LocalPtr&& rhs) { (*this) = std::move(rhs); } + LocalPtr& operator=(const LocalPtr& rhs) { + if (this == &rhs) { + return *this; + } + auto storage = rhs.m_storage; + if (storage) { + storage->inc_ref(); + } + if (m_storage) { + m_storage->dec_ref(); + // rhs.m_storage may be invalid here + } + m_storage = storage; + return *this; + } + LocalPtr& operator=(LocalPtr&& rhs) { + if (this == &rhs) { + return *this; + } + std::swap(m_storage, rhs.m_storage); + rhs.reset(); + return *this; + } + bool operator==(const LocalPtr& rhs) const { return m_storage == rhs.m_storage; } + bool operator!=(const LocalPtr& rhs) const { return m_storage != rhs.m_storage; } + size_t hash() const { return reinterpret_cast(m_storage); } + + ~LocalPtr() { reset(); } + + /** + * \brief Construct an instance of TDerived and return an LocalPtr + * + * There is an memory pool for each (T, TDerived) pair + * + * \tparam TDerived type of concrete instance, should be subclass of T + * \tparam TArgs + * \param args constructor arguments + * \return LocalPtr points to the instance + */ + template + static LocalPtr make(TArgs&&... args) { + static_assert(std::is_base_of_v); + using storage_impl_t = LocalPtrStorgeImpl; + constexpr auto normalize_size = [](size_t size) { + size_t normalized_size = 64; + while (normalized_size < size) { + normalized_size *= 2; + } + return normalized_size; + }; + using raw_storage_t = + std::aligned_storage_t; + static_assert(alignof(raw_storage_t) % alignof(storage_impl_t) == 0); + static_assert(sizeof(raw_storage_t) >= sizeof(storage_impl_t)); + using pool_t = MemPool; + pool_t& pool = MemPoolUtils::get_thread_local(); + auto* raw_storage = pool.alloc_raw(); + auto* storage = reinterpret_cast(raw_storage); + new (storage) storage_impl_t(); + storage->m_value.emplace(std::forward(args)...); + storage->m_pointer = &*storage->m_value; + storage->reset = [](storage_t* storage) { + auto* storage_impl = static_cast(storage); + storage_impl->m_value.reset(); + storage_impl->m_pointer = nullptr; + }; + storage->free = [](storage_t* storage_base) { + auto* storage = static_cast(storage_base); + auto* pool = reinterpret_cast(storage->m_pool); + storage->m_pool = nullptr; + storage->~storage_impl_t(); + auto* raw_storage = reinterpret_cast(storage); + pool->free_raw(raw_storage); + }; + storage->m_pool = &pool; + return {(storage_t*)storage}; + } + + T& operator*() const { return *get(); } + + T* get() const { + if ((!m_storage) || !m_storage->m_pointer) { + return nullptr; + } + return m_storage->m_pointer; + } + + T* operator->() const { return get(); } + + size_t ref_count() const { return m_storage->m_ref_count; } + + bool unique() const { return ref_count() == 1; } + + void reset() { + if (m_storage) { + m_storage->dec_ref(); + m_storage = nullptr; + } + } + + operator bool() const { return bool(m_storage); } + bool operator==(std::nullptr_t nptr) const { return m_storage == nullptr; } + bool operator!=(std::nullptr_t nptr) const { return m_storage != nullptr; } + + template + friend class LocalWeakPtr; +}; + +template +class LocalWeakPtr { +public: + using storage_t = LocalPtrStorage; + +private: + storage_t* m_storage = nullptr; + + void emplace(storage_t* ptr) { + if (ptr) { + ptr->inc_weak_ref(); + m_storage = ptr; + } + } + +public: + LocalWeakPtr() = default; + LocalWeakPtr(const LocalPtr& rhs) { emplace(rhs.m_storage); } + LocalWeakPtr(const LocalWeakPtr& rhs) { (*this) = rhs; } + LocalWeakPtr(LocalWeakPtr&& rhs) { (*this) = std::move(rhs); } + LocalWeakPtr& operator=(const LocalWeakPtr& rhs) { + if (this == &rhs) { + return *this; + } + reset(); + emplace(rhs.m_storage); + return *this; + } + LocalWeakPtr& operator=(LocalWeakPtr&& rhs) { + if (this == &rhs) { + return *this; + } + std::swap(m_storage, rhs.m_storage); + rhs.reset(); + return *this; + } + + ~LocalWeakPtr() { reset(); } + + void reset() { + if (m_storage) { + m_storage->dec_weak_ref(); + m_storage = nullptr; + } + } + + LocalPtr lock() const { + if (m_storage && m_storage->m_ref_count) { + return {m_storage}; + } + return {}; + } + + bool operator==(const LocalWeakPtr& rhs) const { + return m_storage == rhs.m_storage; + } + + bool operator!=(const LocalWeakPtr& rhs) const { + return m_storage != rhs.m_storage; + } + + size_t hash() const { return reinterpret_cast(m_storage); } +}; + +template +LocalPtr make_local(TArgs&&... args) { + return LocalPtr::template make(std::forward(args)...); +} + +} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/utils/map.h b/imperative/src/include/megbrain/imperative/utils/map.h new file mode 100644 index 0000000000000000000000000000000000000000..f096db5fea29db57b8e27af95134b60691e60408 --- /dev/null +++ b/imperative/src/include/megbrain/imperative/utils/map.h @@ -0,0 +1,157 @@ +/** + * \file imperative/src/include/megbrain/imperative/utils/map.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include + +#include "megbrain/utils/metahelper.h" + +namespace mgb::imperative { + +/** + * \brief an hash map optimized for weak pointer as key + * + * Keys were scanned automatically, so values referenced by invalid keys whould be + * released soon + * + * \tparam TKey key type, requires(bool(key.lock())) + * \tparam TValue value type + */ +template +class WeakKeyMap : public NonCopyableObj { +public: + using storage_t = std::unordered_map; + +private: + storage_t m_storage; + typename storage_t::iterator m_cursor = m_storage.begin(); + + /** + * \brief select a key and verify that whether it is invalid. If yes, erase it + * + */ + void _step() { + if (m_cursor == m_storage.end()) { + m_cursor = m_storage.begin(); + return; + } + auto key = m_cursor->first; + if (!key.lock()) { + m_cursor = m_storage.erase(m_cursor); + } else { + ++m_cursor; + } + } + +public: + size_t count(TKey key) { + _step(); + _step(); + return m_storage.count(key); + } + + TValue& at(TKey key) const { return m_storage.at(key); } + + TValue& at(TKey key) { + _step(); + _step(); + return m_storage.at(key); + } + + TValue& operator[](TKey key) { + _step(); + _step(); + if (m_storage.count(key)) { + return m_storage.at(key); + } else { + size_t bucket_count = m_storage.bucket_count(); + TValue& result = m_storage[key]; + if (bucket_count != m_storage.bucket_count()) { + m_cursor = m_storage.begin(); + } + return result; + } + } + + std::optional try_get(TKey key) const { + auto iter = m_storage.find(key); + if (iter == m_storage.end()) { + return {}; + } + return {iter->second}; + } + + std::optional try_get(TKey key) { + _step(); + _step(); + return ((const WeakKeyMap*)this)->try_get(std::move(key)); + } +}; + +template +class WeakValueMap : public NonCopyableObj { +public: + using storage_t = std::unordered_map; + +private: + storage_t m_storage; + typename storage_t::iterator m_cursor = m_storage.begin(); + + /** + * \brief select a key and verify that whether it is invalid. If yes, erase it + * + */ + void _step() { + if (m_cursor == m_storage.end()) { + m_cursor = m_storage.begin(); + return; + } + auto value = m_cursor->second; + if (!value.lock()) { + m_cursor = m_storage.erase(m_cursor); + } else { + ++m_cursor; + } + } + +public: + size_t count(TKey key) { + _step(); + _step(); + return m_storage.count(key); + } + + TValue& at(TKey key) const { return m_storage.at(key); } + + TValue& at(TKey key) { + _step(); + _step(); + return m_storage.at(key); + } + + TValue& operator[](TKey key) { + _step(); + _step(); + if (m_storage.count(key)) { + return m_storage.at(key); + } else { + size_t bucket_count = m_storage.bucket_count(); + TValue& result = m_storage[key]; + if (bucket_count != m_storage.bucket_count()) { + m_cursor = m_storage.begin(); + } + return result; + } + } +}; + +} // namespace mgb::imperative \ No newline at end of file diff --git a/imperative/src/include/megbrain/imperative/utils/mempool.h b/imperative/src/include/megbrain/imperative/utils/mempool.h new file mode 100644 index 0000000000000000000000000000000000000000..ca3b4778248eceefa1b3bec26468f7c1881f2480 --- /dev/null +++ b/imperative/src/include/megbrain/imperative/utils/mempool.h @@ -0,0 +1,70 @@ +/** + * \file imperative/src/include/megbrain/imperative/utils/mempool.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include +#include +#include + +#include "megbrain/utils/mempool.h" +#include "megbrain/utils/metahelper.h" + +namespace mgb::imperative { + +template +class MemPoolUtils { +private: + static std::mutex sm_mutex; + static std::unordered_map>> + sm_instances; + static thread_local MemPool* tm_instance; + static MemPool* sm_instance; + +public: + static MemPool& get_thread_local() { + if (!tm_instance) { + MGB_LOCK_GUARD(sm_mutex); + auto& instance = sm_instances[std::this_thread::get_id()]; + if (!instance) { // thread id may be duplicated + instance = std::make_unique>(); + } + tm_instance = instance.get(); + } + return *tm_instance; + } + static MemPool& get_static() { + if (!sm_instance) { + MGB_LOCK_GUARD(sm_mutex); + auto& instance = sm_instances[{}]; + if (!instance) { // double check + instance = std::make_unique>(); + sm_instance = instance.get(); + } + mgb_assert(sm_instance); + } + } +}; + +template +std::mutex MemPoolUtils::sm_mutex; + +template +std::unordered_map>> + MemPoolUtils::sm_instances; + +template +thread_local MemPool* MemPoolUtils::tm_instance; + +template +MemPool* MemPoolUtils::sm_instance; + +} // namespace mgb::imperative \ No newline at end of file diff --git a/imperative/src/include/megbrain/imperative/utils/span.h b/imperative/src/include/megbrain/imperative/utils/span.h new file mode 100644 index 0000000000000000000000000000000000000000..bb5efe52ac3b82f0cafc6b9be1ee4ae5abbb8f2f --- /dev/null +++ b/imperative/src/include/megbrain/imperative/utils/span.h @@ -0,0 +1,69 @@ +/** + * \file imperative/src/include/megbrain/imperative/utils/span.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include +#include + +#include "megbrain/utils/small_vector.h" + +namespace mgb::imperative { + +/** + * \brief wrapper for c-style array + * + * \tparam T value type + */ +template +class Span { +private: + const T* m_begin = nullptr; + const T* m_end = nullptr; + +public: + Span() {} + Span(const T* begin, const T* end) : m_begin{begin}, m_end{end} {} + Span(const T* begin, size_t size) : Span(begin, begin + size) {} + template + Span(TContainer& container) : Span(container.data(), container.size()) {} + const T* begin() const { return m_begin; } + const T* end() const { return m_end; } + const T* data() const { return m_begin; } + size_t size() const { return m_end - m_begin; } + template + TContainer copy_into() { + return TContainer(m_begin, m_end); + } + const T& operator[](size_t idx) const { return m_begin[idx]; } + const T& at(size_t idx) const { return m_begin[idx]; } + const T& item() const { + mgb_assert( + m_end - m_begin == 1, "size mismatch: %zu vs %zu", (m_end - m_begin), + (size_t)1); + return m_begin[0]; + } + + template + const std::array& as_array() { + mgb_assert( + m_end - m_begin == N, "size mismatch: %zu vs %zu", (m_end - m_begin), + N); + return *reinterpret_cast*>(m_begin); + } + + Span sub(size_t begin, size_t length) { + mgb_assert(begin + length <= m_end - m_begin); + return {m_begin + begin, length}; + } +}; + +} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/utils/to_string.h b/imperative/src/include/megbrain/imperative/utils/to_string.h index d0ca998d75b8a098dff1f3c5506f098a0ddbbfde..2dd47dfbd14f11f8bc1dffa26dcb277f521a0b2a 100644 --- a/imperative/src/include/megbrain/imperative/utils/to_string.h +++ b/imperative/src/include/megbrain/imperative/utils/to_string.h @@ -16,6 +16,7 @@ #include #include +#include "megbrain/imperative/utils/span.h" #include "megbrain/tensor.h" #include "megbrain/utils/small_vector.h" @@ -59,6 +60,22 @@ struct ToStringTrait> { } }; +template +struct ToStringTrait> { + std::string operator()(const std::vector& v) const { + if (v.empty()) { + return "[]"; + } + std::string result = "["; + result += to_string(v[0]); + for (size_t i = 1; i < v.size(); ++i) { + result += ", "; + result += to_string(v[i]); + } + return result + "]"; + } +}; + template struct ToStringTrait> { std::string operator()(const std::shared_ptr& sp) const { @@ -115,4 +132,36 @@ struct ToStringTrait { std::string operator()(CompNode device) const { return device.to_string(); } }; +inline std::string string_join(Span span, char delimiter = ',') { + std::string buffer = "["; + for (size_t i = 1; i < span.size(); ++i) { + if (i) { + buffer.push_back(delimiter); + } + buffer.append(span[0]); + } + return buffer + "]"; +} + +template +struct ToStringTrait> { + std::string operator()(Span span) const { + if (span.size() == 0) { + return "[]"; + } + std::string result = "["; + result += to_string(span[0]); + for (size_t i = 1; i < span.size(); ++i) { + result += ", "; + result += to_string(span[i]); + } + return result + "]"; + } +}; + +template <> +struct ToStringTrait { + std::string operator()(const std::type_info& info) const { return info.name(); } +}; + } // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/utils/value_shape.h b/imperative/src/include/megbrain/imperative/utils/value_shape.h new file mode 100644 index 0000000000000000000000000000000000000000..ecf911480c1b73dd382620c7cc5e78e0858ae34c --- /dev/null +++ b/imperative/src/include/megbrain/imperative/utils/value_shape.h @@ -0,0 +1,104 @@ +/** + * \file imperative/src/include/megbrain/imperative/utils/visit.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include + +#include "megbrain/imperative/utils/span.h" +#include "megbrain/tensor.h" + +namespace mgb::imperative { + +/** + * \brief like TensorShape, but allow real scalar shape. + * + */ +struct ValueShape { + size_t shape[TensorShape::MAX_NDIM]; + int ndim = 0; + + ValueShape() = default; + ValueShape(std::initializer_list dims) { + for (auto&& dim : dims) { + shape[ndim++] = dim; + } + } + ValueShape(Span dims) { + for (auto&& dim : dims) { + shape[ndim++] = dim; + } + } + + size_t& operator[](int axis) { return shape[axis]; } + + size_t operator[](int axis) const { return shape[axis]; } + + size_t at(int axis) const { + mgb_assert(axis < ndim); + return shape[axis]; + } + + size_t total_nr_elems() const { + size_t prod = 1; + for (int i = 0; i < ndim; ++i) { + prod *= shape[i]; + } + return prod; + } + + bool is_scalar() const { return ndim == 0; } + + std::string to_string() const { + std::string buffer = "{"; + for (size_t i = 0; i < ndim; ++i) { + if (i) { + buffer.append(","); + } + buffer.append(std::to_string(shape[i])); + } + buffer.append("}"); + return buffer; + } + + static ValueShape from(TensorShape tensor_shape) { + mgb_assert(tensor_shape.ndim); + return Span{tensor_shape.shape, tensor_shape.ndim}; + } + + TensorShape as_tensor_shape() const { + mgb_assert(ndim != 0); + TensorShape ret; + for (size_t i = 0; i < ndim; ++i) { + ret.shape[i] = shape[i]; + } + ret.ndim = ndim; + return ret; + } + + bool operator==(const ValueShape& rhs) const { + if (ndim != rhs.ndim) { + return false; + } + for (size_t i = 0; i < ndim; ++i) { + if (shape[i] != rhs.shape[i]) { + return false; + } + } + return true; + } +}; + +static_assert(sizeof(size_t) >= sizeof(int)); +static_assert(TensorShape::MAX_NDIM == 7); +static_assert(sizeof(ValueShape) <= sizeof(size_t) * 8); + +} // namespace mgb::imperative \ No newline at end of file diff --git a/imperative/src/include/megbrain/imperative/utils/visit.h b/imperative/src/include/megbrain/imperative/utils/visit.h new file mode 100644 index 0000000000000000000000000000000000000000..a2189475a04f7e71e70db9fa6a90baa45f68af5a --- /dev/null +++ b/imperative/src/include/megbrain/imperative/utils/visit.h @@ -0,0 +1,26 @@ +/** + * \file imperative/src/include/megbrain/imperative/utils/visit.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include + +#include "megbrain/utils/small_vector.h" + +namespace mgb::imperative { + +template +class Visitor : public TVisitors... { +public: + using TVisitors::operator()...; +}; + +} // namespace mgb::imperative diff --git a/imperative/src/test/profiler.cpp b/imperative/src/test/profiler.cpp index 9d54ab073c771fdb234ce119ca0c9f4e50700614..7f0dd76a0680a6f277b58b10df4a37b8224d1b89 100644 --- a/imperative/src/test/profiler.cpp +++ b/imperative/src/test/profiler.cpp @@ -28,10 +28,10 @@ TEST(TestProfiler, ImperativeLogProfile) { auto results = imperative::Profiler::collect(); imperative::Profiler::stop_profile(); mgb_assert(results.entries.size() == 2); - auto* event_start = results.entries[0].data.as(); - auto* event_finish = results.entries[1].data.as(); - mgb_assert(event_start && event_start->title == "XXX"); - mgb_assert(event_finish && event_finish->title == "XXX"); + auto& event_start = results.entries[0].data.cast(); + auto& event_finish = results.entries[1].data.cast(); + mgb_assert(event_start.title == "XXX"); + mgb_assert(event_finish.title == "XXX"); mgb_assert(results.entries[0].time < results.entries[1].time); mgb_assert(results.entries[0].id < results.entries[1].id); }