提交 2be6ceda 编写于 作者: M Megvii Engine Team

feat(imperative/utils): add serval utils

GitOrigin-RevId: f401663ae3641d8a6467cf4d10cba17a1d3f4553
上级 e7c2ed11
......@@ -160,7 +160,7 @@ private:
template <typename TItem>
void register_converter() {
m_table[typeid(TItem)] = [](const any_t& input) {
return variant_t(*input.as<TItem>());
return variant_t(input.cast<TItem>());
};
}
......
......@@ -11,7 +11,6 @@
#pragma once
#include <any>
#include <bitset>
#include <chrono>
#include <deque>
......@@ -28,6 +27,7 @@
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/physical_tensor.h"
#include "megbrain/imperative/utils/any.h"
namespace mgb {
namespace imperative {
......@@ -51,48 +51,6 @@ public:
static std::shared_ptr<CompNode::Event> record_device(CompNode device);
};
class AnyPtr {
public:
struct Deleter {
void* object;
void (*method)(void*, void*);
void operator()(void* ptr) { method(object, ptr); }
};
private:
using holder_t = std::unique_ptr<void, Deleter>;
const std::type_info* m_type = nullptr;
holder_t m_holder = nullptr;
public:
AnyPtr() = default;
template <
typename T,
typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, AnyPtr>>>
explicit AnyPtr(T* value, Deleter deleter) {
m_type = &typeid(T);
m_holder = {value, deleter};
}
template <typename T>
T* as() {
mgb_assert(is_exactly<T>(), "type mismatch");
return reinterpret_cast<T*>(m_holder.get());
}
template <typename T>
const T* as() const {
mgb_assert(is_exactly<T>(), "type mismatch");
return reinterpret_cast<const T*>(m_holder.get());
}
template <typename T>
bool is_exactly() const {
return std::type_index{typeid(T)} == std::type_index{*m_type};
}
const std::type_info& type() const { return *m_type; }
bool operator==(std::nullptr_t nptr) const { return m_holder == nullptr; }
operator bool() const { return m_holder != nullptr; }
};
class Profiler {
public:
struct Record {
......@@ -128,7 +86,6 @@ private:
std::thread::id m_thread_id;
std::vector<Record> m_records;
std::atomic<Status> m_status = Running;
std::unordered_map<std::type_index, AnyPtr> m_mem_pools;
static std::vector<entry_t> sm_records;
static options_t sm_profile_options;
......@@ -161,42 +118,21 @@ public:
return *tm_profiler;
}
template <typename T>
static MemPool<T>& get_mem_pool() {
thread_local MemPool<T>* t_pool = nullptr;
if (t_pool == nullptr) {
auto& pool = get_instance().m_mem_pools[typeid(MemPool<T>)];
if (pool == nullptr) {
pool =
AnyPtr(new MemPool<T>(),
{nullptr, [](void*, void* ptr) {
delete reinterpret_cast<MemPool<T>*>(ptr);
}});
}
t_pool = pool.as<MemPool<T>>();
}
return *t_pool;
}
static uint64_t next_id() { return sm_last_id++; }
template <typename T, typename... TArgs>
static uint64_t record(TArgs&&... args) {
auto& profiler = get_instance();
auto& mem_pool = get_mem_pool<T>();
// auto& mem_pool = get_mem_pool<T>();
if constexpr (sm_debug) {
Status expected = Running;
mgb_assert(profiler.m_status.compare_exchange_strong(expected, Recording));
}
uint64_t id = next_id();
profiler::Time time = sm_timer.record_host();
auto deleter = [](void* obj, void* ptr) {
reinterpret_cast<MemPool<T>*>(obj)->free(reinterpret_cast<T*>(ptr));
};
profiler.m_records.emplace_back(
id, profiler.m_thread_id, time,
AnyPtr{mem_pool.alloc(T{std::forward<TArgs>(args)...}),
{&mem_pool, deleter}});
AnyPtr::make<T>(T{std::forward<TArgs&&>(args)...}));
if constexpr (sm_debug) {
Status expected = Recording;
mgb_assert(profiler.m_status.compare_exchange_strong(expected, Running));
......@@ -241,7 +177,7 @@ public:
bundle.options = get_options();
bundle.start_at = sm_start_at;
bundle.thread_dict = get_thread_dict();
return std::move(bundle);
return bundle;
}
static option_t get_option(std::string key, option_t default_val) {
......
/**
* \file imperative/src/include/megbrain/imperative/utils/allocator.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <typeindex>
#include "megbrain/utils/mempool.h"
#include "megbrain/utils/metahelper.h"
namespace mgb::imperative {
template <typename T>
class Allocator {
public:
using pointer = T*;
using const_pointer = const T*;
using void_pointer = void*;
using const_void_pointer = const void*;
using value_type = T;
using size_type = std::size_t;
using diffenence_type = std::ptrdiff_t;
using pool_type = MemPoolStorage;
private:
pool_type* m_pool = nullptr;
public:
Allocator(pool_type* pool) : m_pool(pool) {}
T* allocate(size_type n) {
mgb_assert(n == 1);
return m_pool->alloc(sizeof(T));
}
void deallocate(pointer* p, size_type n) {
mgb_assert(n == 1);
m_pool->free(p);
}
bool operator==(const Allocator& rhs) const { return m_pool == rhs.m_pool; }
bool operator!=(const Allocator& rhs) const { return m_pool != rhs.m_pool; }
};
template <typename T>
class ThreadLocalAllocatorAdapter {
public:
using value_type = T;
using size_type = std::size_t;
using pointer = T*;
public:
T* allocate(size_type n) { mgb_assert(false); }
void deallocate(pointer* p, size_type n) { mgb_assert(false); }
bool operator==(const ThreadLocalAllocatorAdapter& rhs) const { return true; }
bool operator!=(const ThreadLocalAllocatorAdapter& rhs) const { return false; }
};
} // namespace mgb::imperative
\ No newline at end of file
/**
* \file imperative/src/include/megbrain/imperative/utils/any.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <typeindex>
#include "megbrain/imperative/utils/local_ptr.h"
namespace mgb::imperative {
class AnyMixinBase {
private:
const std::type_info* m_type = nullptr;
public:
AnyMixinBase() = default;
const std::type_info& type() const { return *m_type; }
friend class AnyPtr;
};
template <typename T>
class AnyMixin : public AnyMixinBase, public T {
public:
AnyMixin(T&& val) : T(std::move(val)) {}
};
class AnyPtr {
public:
using storage_t = LocalPtr<AnyMixinBase>;
private:
storage_t m_storage;
public:
const std::type_info& type() const { return m_storage->type(); }
template <typename T>
const T& cast() const {
mgb_assert(is_exactly<T>(), "type mismatch");
return *static_cast<const AnyMixin<T>*>(m_storage.get());
}
template <typename T>
bool is_exactly() const {
return std::type_index{typeid(T)} == std::type_index{type()};
}
bool operator==(std::nullptr_t nptr) const { return m_storage == nullptr; }
bool operator!=(std::nullptr_t nptr) const { return m_storage != nullptr; }
operator bool() const { return m_storage != nullptr; }
template <typename T, typename... TArgs>
static AnyPtr make(TArgs&&... args) {
AnyPtr ret;
ret.m_storage = LocalPtr<AnyMixinBase>::make<AnyMixin<T>>(
std::forward<TArgs&&>(args)...);
ret.m_storage->m_type = &typeid(T);
return ret;
}
};
} // namespace mgb::imperative
/**
* \file imperative/src/include/megbrain/imperative/utils/visit.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <chrono>
#include <future>
#include <vector>
#include "megbrain/utils/metahelper.h"
#include "megbrain/utils/small_vector.h"
namespace mgb::imperative {
class BoxBase : public NonCopyableObj {
public:
virtual void reset() = 0;
virtual void set_exception(std::exception_ptr exc) = 0;
virtual bool try_set_exception(std::exception_ptr exc) = 0;
};
/**
* \brief An reusable promise
*
* \tparam T type of value
*/
template <typename T>
class Box final : public BoxBase {
private:
std::promise<T> m_promise;
std::shared_future<T> m_future;
std::mutex m_mutex;
bool m_value_set;
bool m_exception_set;
public:
Box() { reset(); }
const T& get_value() { return m_future.get(); }
T take_value() {
T value = m_future.get();
reset();
return value;
}
void set_value(T value) {
MGB_LOCK_GUARD(m_mutex);
m_promise.set_value(std::move(value));
m_value_set = true;
}
bool try_set_value(T value) {
MGB_LOCK_GUARD(m_mutex);
if (m_exception_set) {
return false;
}
m_promise.set_value(std::move(value));
m_value_set = true;
return true;
}
void set_exception(std::exception_ptr exc) override {
MGB_LOCK_GUARD(m_mutex);
m_promise.set_exception(exc);
m_exception_set = true;
}
bool try_set_exception(std::exception_ptr exc) override {
MGB_LOCK_GUARD(m_mutex);
if (m_value_set) {
return false;
}
m_promise.set_exception(exc);
m_exception_set = true;
return true;
}
void reset() override {
MGB_LOCK_GUARD(m_mutex);
m_promise = {};
m_future = m_promise.get_future();
m_value_set = false;
m_exception_set = false;
}
/**
* \brief make an empty box
*
* \return std::shared_ptr<Box>
*/
static std::shared_ptr<Box> make() { return std::make_shared<Box>(); }
};
} // namespace mgb::imperative
/**
* \file imperative/src/include/megbrain/imperative/utils/span.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <iomanip>
#include <memory>
#include <sstream>
namespace mgb {
namespace imperative {
template <typename T>
class CleanupGuard {
private:
T m_callback;
public:
explicit CleanupGuard(T cb) : m_callback{std::move(cb)} {}
~CleanupGuard() { m_callback(); }
};
inline std::string quoted(std::string str) {
std::stringstream ss;
ss << std::quoted(str);
return ss.str();
}
} // namespace imperative
} // namespace mgb
\ No newline at end of file
/**
* \file imperative/src/include/megbrain/imperative/utils/intrusive_list.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/utils/metahelper.h"
namespace mgb::imperative::utils::intrusive_list {
// copy policy
struct after_t {};
struct before_t {};
struct disable_t {};
template <typename T>
struct Tail;
// invariant: next->prev == this
template <typename T>
struct Head {
Tail<T>* next;
Head(Tail<T>* node = nullptr) : next(node) {}
Head(const Head<T>&) = delete;
Head<T>& operator=(const Head<T>&) = delete;
Head(Head<T>&& rhs) : next(rhs.next) {
rhs.next = nullptr;
if (next) {
next->prev = this;
}
}
Head<T>& operator=(Head<T>&& rhs) {
mgb_assert(!next);
next = rhs.next;
rhs.next = nullptr;
if (next) {
next->prev = this;
}
return *this;
}
~Head() {
if (next) {
next->prev = nullptr;
}
}
};
// invariant: prev->next == this
template <typename T>
struct Tail {
Head<T>* prev;
Tail(Head<T>* node = nullptr) : prev(node) {}
Tail(const Tail<T>&) = delete;
Tail<T>& operator=(const Tail<T>&) = delete;
Tail(Tail<T>&& rhs) : prev(rhs.prev) {
rhs.prev = nullptr;
if (prev) {
prev->next = this;
}
}
Tail<T>& operator=(Tail<T>&& rhs) {
mgb_assert(!prev);
prev = rhs.prev;
rhs.prev = nullptr;
if (prev) {
prev->next = this;
}
return *this;
}
~Tail() {
if (prev) {
prev->next = nullptr;
}
}
};
template <typename T, typename policy>
struct Node;
template <typename T>
class Iterator {
T* ptr;
void inc() { ptr = static_cast<T*>(ptr->Head<T>::next); }
void dec() { ptr = static_cast<T*>(ptr->Head<T>::prev); }
public:
Iterator(Head<T>& head) : ptr(static_cast<T*>(head.next)) {}
Iterator(Tail<T>& tail) : ptr(static_cast<T*>(tail.prev)) {}
template <typename policy>
Iterator(Node<T, policy>& node) : ptr(static_cast<T*>(&node)) {}
T& operator*() { return *static_cast<T*>(ptr); }
T* operator->() { return static_cast<T*>(ptr); }
operator bool() { return ptr; }
bool operator==(const Iterator<T>& rhs) { return ptr == rhs.ptr; }
Iterator& operator++() {
inc();
return *this;
}
Iterator& operator--() {
dec();
return *this;
}
Iterator operator++(int) {
auto ret = *this;
inc();
return ret;
}
Iterator operator--(int) {
auto ret = *this;
dec();
return ret;
}
};
// Node in a doubly linked list. Unlike std::list, nodes are not owned by a container.
// Instead, nodes may join or leave a list freely.
// NOTE: Derived classes have to explicitly declare copy / assignment as default,
// otherwise the compiler generated version would use the const T& signature,
// which is deleted.
template <typename T = void, typename policy = disable_t>
struct Node : Tail<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>>,
Head<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>> {
private:
using this_t = Node<T, policy>;
using U = std::conditional_t<std::is_same_v<T, void>, this_t, T>;
public:
using head_t = Head<U>;
using tail_t = Tail<U>;
using head_t::next;
using tail_t::prev;
Node() = default;
Node(const this_t&) = delete;
this_t& operator=(const this_t&) = delete;
//! constructed node is inserted after the input node
Node(after_t, head_t& node) : tail_t(&node), head_t(node.next) {
node.next = this;
if (next) {
next->prev = this;
}
}
//! constructed node is inserted before the input node
Node(before_t, tail_t& node) : head_t(&node), tail_t(node.prev) {
node.prev = this;
if (prev) {
prev->next = this;
}
}
Node(this_t&& rhs) : tail_t(rhs.prev), head_t(rhs.next) {
rhs.prev = nullptr;
rhs.next = nullptr;
if (prev) {
prev->next = this;
}
if (next) {
next->prev = this;
}
}
Node& operator=(this_t&& rhs) {
unlink();
prev = rhs.prev;
next = rhs.next;
rhs.prev = nullptr;
rhs.next = nullptr;
if (prev) {
prev->next = this;
}
if (next) {
next->prev = this;
}
return *this;
}
template <
typename p = policy,
typename = std::enable_if_t<
std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>>
Node(this_t& rhs) : Node(policy{}, rhs) {}
template <
typename p = policy,
typename = std::enable_if_t<
std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>>
this_t& operator=(this_t& rhs) {
insert(policy{}, rhs);
return *this;
}
void unlink() {
if (prev) {
prev->next = next;
}
if (next) {
next->prev = prev;
}
prev = nullptr;
next = nullptr;
}
//! this node is unlinked from its list and inserted after the input node
void insert(after_t, head_t& node) {
unlink();
prev = &node;
next = node.next;
node.next = this;
if (next) {
next->prev = this;
}
}
//! this node is unlinked from its list and inserted before the input node
void insert(before_t, tail_t& node) {
unlink();
next = &node;
prev = node.prev;
node.prev = this;
if (prev) {
prev->next = this;
}
}
void insert_before(tail_t& node) { insert(before_t{}, node); }
void insert_after(head_t& node) { insert(after_t{}, node); }
~Node() { unlink(); }
};
} // namespace mgb::imperative::utils::intrusive_list
/**
* \file imperative/src/include/megbrain/imperative/utils/local_ptr.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <optional>
#include "megbrain/imperative/utils/mempool.h"
#include "megbrain/utils/metahelper.h"
namespace mgb::imperative {
template <typename T>
class LocalPtrStorage : public NonCopyableObj {
private:
size_t m_ref_count = 0;
size_t m_weak_count = 0;
T* m_pointer = nullptr;
void (*reset)(LocalPtrStorage*) = nullptr;
void (*free)(LocalPtrStorage*) = nullptr;
void inc_ref() { m_ref_count++; }
void dec_ref() {
m_ref_count--;
if (m_ref_count == 0) {
reset(this);
m_pointer = nullptr;
reset = nullptr;
if (m_weak_count == 0) {
free(this);
// dead
}
}
}
void inc_weak_ref() { m_weak_count++; }
void dec_weak_ref() {
m_weak_count--;
if ((m_weak_count + m_ref_count) == 0) {
free(this);
// dead
}
}
template <typename U>
friend class LocalPtr;
template <typename U>
friend class LocalWeakPtr;
public:
};
template <typename T, typename TDerived>
class LocalPtrStorgeImpl : public LocalPtrStorage<T> {
private:
std::optional<TDerived> m_value;
void* m_pool = nullptr;
template <typename U>
friend class LocalPtr;
template <typename U>
friend class LocalWeakPtr;
};
template <typename T>
class LocalWeakPtr;
/**
* \brief thread-unsafe smart pointer
*
* \tparam T type of value
*/
template <typename T>
class LocalPtr {
public:
using storage_t = LocalPtrStorage<T>;
using pool_t = MemPool<storage_t>;
using weak_type = LocalWeakPtr<T>;
private:
storage_t* m_storage = nullptr;
void emplace(storage_t* ptr) {
if (ptr) {
ptr->inc_ref();
m_storage = ptr;
}
}
LocalPtr(storage_t* ptr) { emplace(ptr); }
public:
LocalPtr() = default;
LocalPtr(const LocalPtr& rhs) { (*this) = rhs; }
LocalPtr(LocalPtr&& rhs) { (*this) = std::move(rhs); }
LocalPtr& operator=(const LocalPtr& rhs) {
if (this == &rhs) {
return *this;
}
auto storage = rhs.m_storage;
if (storage) {
storage->inc_ref();
}
if (m_storage) {
m_storage->dec_ref();
// rhs.m_storage may be invalid here
}
m_storage = storage;
return *this;
}
LocalPtr& operator=(LocalPtr&& rhs) {
if (this == &rhs) {
return *this;
}
std::swap(m_storage, rhs.m_storage);
rhs.reset();
return *this;
}
bool operator==(const LocalPtr& rhs) const { return m_storage == rhs.m_storage; }
bool operator!=(const LocalPtr& rhs) const { return m_storage != rhs.m_storage; }
size_t hash() const { return reinterpret_cast<uintptr_t>(m_storage); }
~LocalPtr() { reset(); }
/**
* \brief Construct an instance of TDerived and return an LocalPtr
*
* There is an memory pool for each (T, TDerived) pair
*
* \tparam TDerived type of concrete instance, should be subclass of T
* \tparam TArgs
* \param args constructor arguments
* \return LocalPtr points to the instance
*/
template <typename TDerived = T, typename... TArgs>
static LocalPtr make(TArgs&&... args) {
static_assert(std::is_base_of_v<T, TDerived>);
using storage_impl_t = LocalPtrStorgeImpl<T, TDerived>;
constexpr auto normalize_size = [](size_t size) {
size_t normalized_size = 64;
while (normalized_size < size) {
normalized_size *= 2;
}
return normalized_size;
};
using raw_storage_t =
std::aligned_storage_t<normalize_size(sizeof(storage_impl_t))>;
static_assert(alignof(raw_storage_t) % alignof(storage_impl_t) == 0);
static_assert(sizeof(raw_storage_t) >= sizeof(storage_impl_t));
using pool_t = MemPool<raw_storage_t>;
pool_t& pool = MemPoolUtils<raw_storage_t>::get_thread_local();
auto* raw_storage = pool.alloc_raw();
auto* storage = reinterpret_cast<storage_impl_t*>(raw_storage);
new (storage) storage_impl_t();
storage->m_value.emplace(std::forward<TArgs&&>(args)...);
storage->m_pointer = &*storage->m_value;
storage->reset = [](storage_t* storage) {
auto* storage_impl = static_cast<storage_impl_t*>(storage);
storage_impl->m_value.reset();
storage_impl->m_pointer = nullptr;
};
storage->free = [](storage_t* storage_base) {
auto* storage = static_cast<storage_impl_t*>(storage_base);
auto* pool = reinterpret_cast<pool_t*>(storage->m_pool);
storage->m_pool = nullptr;
storage->~storage_impl_t();
auto* raw_storage = reinterpret_cast<raw_storage_t*>(storage);
pool->free_raw(raw_storage);
};
storage->m_pool = &pool;
return {(storage_t*)storage};
}
T& operator*() const { return *get(); }
T* get() const {
if ((!m_storage) || !m_storage->m_pointer) {
return nullptr;
}
return m_storage->m_pointer;
}
T* operator->() const { return get(); }
size_t ref_count() const { return m_storage->m_ref_count; }
bool unique() const { return ref_count() == 1; }
void reset() {
if (m_storage) {
m_storage->dec_ref();
m_storage = nullptr;
}
}
operator bool() const { return bool(m_storage); }
bool operator==(std::nullptr_t nptr) const { return m_storage == nullptr; }
bool operator!=(std::nullptr_t nptr) const { return m_storage != nullptr; }
template <typename U>
friend class LocalWeakPtr;
};
template <typename T>
class LocalWeakPtr {
public:
using storage_t = LocalPtrStorage<T>;
private:
storage_t* m_storage = nullptr;
void emplace(storage_t* ptr) {
if (ptr) {
ptr->inc_weak_ref();
m_storage = ptr;
}
}
public:
LocalWeakPtr() = default;
LocalWeakPtr(const LocalPtr<T>& rhs) { emplace(rhs.m_storage); }
LocalWeakPtr(const LocalWeakPtr& rhs) { (*this) = rhs; }
LocalWeakPtr(LocalWeakPtr&& rhs) { (*this) = std::move(rhs); }
LocalWeakPtr& operator=(const LocalWeakPtr& rhs) {
if (this == &rhs) {
return *this;
}
reset();
emplace(rhs.m_storage);
return *this;
}
LocalWeakPtr& operator=(LocalWeakPtr&& rhs) {
if (this == &rhs) {
return *this;
}
std::swap(m_storage, rhs.m_storage);
rhs.reset();
return *this;
}
~LocalWeakPtr() { reset(); }
void reset() {
if (m_storage) {
m_storage->dec_weak_ref();
m_storage = nullptr;
}
}
LocalPtr<T> lock() const {
if (m_storage && m_storage->m_ref_count) {
return {m_storage};
}
return {};
}
bool operator==(const LocalWeakPtr& rhs) const {
return m_storage == rhs.m_storage;
}
bool operator!=(const LocalWeakPtr& rhs) const {
return m_storage != rhs.m_storage;
}
size_t hash() const { return reinterpret_cast<uintptr_t>(m_storage); }
};
template <typename T, typename TDerived, typename... TArgs>
LocalPtr<T> make_local(TArgs&&... args) {
return LocalPtr<T>::template make<TDerived>(std::forward<TArgs&&>(args)...);
}
} // namespace mgb::imperative
/**
* \file imperative/src/include/megbrain/imperative/utils/map.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <optional>
#include "megbrain/utils/metahelper.h"
namespace mgb::imperative {
/**
* \brief an hash map optimized for weak pointer as key
*
* Keys were scanned automatically, so values referenced by invalid keys whould be
* released soon
*
* \tparam TKey key type, requires(bool(key.lock()))
* \tparam TValue value type
*/
template <typename TKey, typename TValue>
class WeakKeyMap : public NonCopyableObj {
public:
using storage_t = std::unordered_map<TKey, TValue>;
private:
storage_t m_storage;
typename storage_t::iterator m_cursor = m_storage.begin();
/**
* \brief select a key and verify that whether it is invalid. If yes, erase it
*
*/
void _step() {
if (m_cursor == m_storage.end()) {
m_cursor = m_storage.begin();
return;
}
auto key = m_cursor->first;
if (!key.lock()) {
m_cursor = m_storage.erase(m_cursor);
} else {
++m_cursor;
}
}
public:
size_t count(TKey key) {
_step();
_step();
return m_storage.count(key);
}
TValue& at(TKey key) const { return m_storage.at(key); }
TValue& at(TKey key) {
_step();
_step();
return m_storage.at(key);
}
TValue& operator[](TKey key) {
_step();
_step();
if (m_storage.count(key)) {
return m_storage.at(key);
} else {
size_t bucket_count = m_storage.bucket_count();
TValue& result = m_storage[key];
if (bucket_count != m_storage.bucket_count()) {
m_cursor = m_storage.begin();
}
return result;
}
}
std::optional<TValue> try_get(TKey key) const {
auto iter = m_storage.find(key);
if (iter == m_storage.end()) {
return {};
}
return {iter->second};
}
std::optional<TValue> try_get(TKey key) {
_step();
_step();
return ((const WeakKeyMap*)this)->try_get(std::move(key));
}
};
template <typename TKey, typename TValue>
class WeakValueMap : public NonCopyableObj {
public:
using storage_t = std::unordered_map<TKey, TValue>;
private:
storage_t m_storage;
typename storage_t::iterator m_cursor = m_storage.begin();
/**
* \brief select a key and verify that whether it is invalid. If yes, erase it
*
*/
void _step() {
if (m_cursor == m_storage.end()) {
m_cursor = m_storage.begin();
return;
}
auto value = m_cursor->second;
if (!value.lock()) {
m_cursor = m_storage.erase(m_cursor);
} else {
++m_cursor;
}
}
public:
size_t count(TKey key) {
_step();
_step();
return m_storage.count(key);
}
TValue& at(TKey key) const { return m_storage.at(key); }
TValue& at(TKey key) {
_step();
_step();
return m_storage.at(key);
}
TValue& operator[](TKey key) {
_step();
_step();
if (m_storage.count(key)) {
return m_storage.at(key);
} else {
size_t bucket_count = m_storage.bucket_count();
TValue& result = m_storage[key];
if (bucket_count != m_storage.bucket_count()) {
m_cursor = m_storage.begin();
}
return result;
}
}
};
} // namespace mgb::imperative
\ No newline at end of file
/**
* \file imperative/src/include/megbrain/imperative/utils/mempool.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <mutex>
#include <thread>
#include <unordered_map>
#include "megbrain/utils/mempool.h"
#include "megbrain/utils/metahelper.h"
namespace mgb::imperative {
template <typename T>
class MemPoolUtils {
private:
static std::mutex sm_mutex;
static std::unordered_map<std::thread::id, std::unique_ptr<MemPool<T>>>
sm_instances;
static thread_local MemPool<T>* tm_instance;
static MemPool<T>* sm_instance;
public:
static MemPool<T>& get_thread_local() {
if (!tm_instance) {
MGB_LOCK_GUARD(sm_mutex);
auto& instance = sm_instances[std::this_thread::get_id()];
if (!instance) { // thread id may be duplicated
instance = std::make_unique<MemPool<T>>();
}
tm_instance = instance.get();
}
return *tm_instance;
}
static MemPool<T>& get_static() {
if (!sm_instance) {
MGB_LOCK_GUARD(sm_mutex);
auto& instance = sm_instances[{}];
if (!instance) { // double check
instance = std::make_unique<MemPool<T>>();
sm_instance = instance.get();
}
mgb_assert(sm_instance);
}
}
};
template <typename T>
std::mutex MemPoolUtils<T>::sm_mutex;
template <typename T>
std::unordered_map<std::thread::id, std::unique_ptr<MemPool<T>>>
MemPoolUtils<T>::sm_instances;
template <typename T>
thread_local MemPool<T>* MemPoolUtils<T>::tm_instance;
template <typename T>
MemPool<T>* MemPoolUtils<T>::sm_instance;
} // namespace mgb::imperative
\ No newline at end of file
/**
* \file imperative/src/include/megbrain/imperative/utils/span.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <array>
#include <vector>
#include "megbrain/utils/small_vector.h"
namespace mgb::imperative {
/**
* \brief wrapper for c-style array
*
* \tparam T value type
*/
template <typename T>
class Span {
private:
const T* m_begin = nullptr;
const T* m_end = nullptr;
public:
Span() {}
Span(const T* begin, const T* end) : m_begin{begin}, m_end{end} {}
Span(const T* begin, size_t size) : Span(begin, begin + size) {}
template <typename TContainer>
Span(TContainer& container) : Span(container.data(), container.size()) {}
const T* begin() const { return m_begin; }
const T* end() const { return m_end; }
const T* data() const { return m_begin; }
size_t size() const { return m_end - m_begin; }
template <typename TContainer>
TContainer copy_into() {
return TContainer(m_begin, m_end);
}
const T& operator[](size_t idx) const { return m_begin[idx]; }
const T& at(size_t idx) const { return m_begin[idx]; }
const T& item() const {
mgb_assert(
m_end - m_begin == 1, "size mismatch: %zu vs %zu", (m_end - m_begin),
(size_t)1);
return m_begin[0];
}
template <size_t N>
const std::array<T, N>& as_array() {
mgb_assert(
m_end - m_begin == N, "size mismatch: %zu vs %zu", (m_end - m_begin),
N);
return *reinterpret_cast<const std::array<T, N>*>(m_begin);
}
Span sub(size_t begin, size_t length) {
mgb_assert(begin + length <= m_end - m_begin);
return {m_begin + begin, length};
}
};
} // namespace mgb::imperative
......@@ -16,6 +16,7 @@
#include <tuple>
#include <type_traits>
#include "megbrain/imperative/utils/span.h"
#include "megbrain/tensor.h"
#include "megbrain/utils/small_vector.h"
......@@ -59,6 +60,22 @@ struct ToStringTrait<SmallVector<T, N>> {
}
};
template <typename T>
struct ToStringTrait<std::vector<T>> {
std::string operator()(const std::vector<T>& v) const {
if (v.empty()) {
return "[]";
}
std::string result = "[";
result += to_string(v[0]);
for (size_t i = 1; i < v.size(); ++i) {
result += ", ";
result += to_string(v[i]);
}
return result + "]";
}
};
template <typename T>
struct ToStringTrait<std::shared_ptr<T>> {
std::string operator()(const std::shared_ptr<T>& sp) const {
......@@ -115,4 +132,36 @@ struct ToStringTrait<CompNode> {
std::string operator()(CompNode device) const { return device.to_string(); }
};
inline std::string string_join(Span<std::string> span, char delimiter = ',') {
std::string buffer = "[";
for (size_t i = 1; i < span.size(); ++i) {
if (i) {
buffer.push_back(delimiter);
}
buffer.append(span[0]);
}
return buffer + "]";
}
template <typename T>
struct ToStringTrait<Span<T>> {
std::string operator()(Span<T> span) const {
if (span.size() == 0) {
return "[]";
}
std::string result = "[";
result += to_string(span[0]);
for (size_t i = 1; i < span.size(); ++i) {
result += ", ";
result += to_string(span[i]);
}
return result + "]";
}
};
template <>
struct ToStringTrait<std::type_info> {
std::string operator()(const std::type_info& info) const { return info.name(); }
};
} // namespace mgb::imperative
/**
* \file imperative/src/include/megbrain/imperative/utils/visit.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <vector>
#include "megbrain/imperative/utils/span.h"
#include "megbrain/tensor.h"
namespace mgb::imperative {
/**
* \brief like TensorShape, but allow real scalar shape.
*
*/
struct ValueShape {
size_t shape[TensorShape::MAX_NDIM];
int ndim = 0;
ValueShape() = default;
ValueShape(std::initializer_list<size_t> dims) {
for (auto&& dim : dims) {
shape[ndim++] = dim;
}
}
ValueShape(Span<size_t> dims) {
for (auto&& dim : dims) {
shape[ndim++] = dim;
}
}
size_t& operator[](int axis) { return shape[axis]; }
size_t operator[](int axis) const { return shape[axis]; }
size_t at(int axis) const {
mgb_assert(axis < ndim);
return shape[axis];
}
size_t total_nr_elems() const {
size_t prod = 1;
for (int i = 0; i < ndim; ++i) {
prod *= shape[i];
}
return prod;
}
bool is_scalar() const { return ndim == 0; }
std::string to_string() const {
std::string buffer = "{";
for (size_t i = 0; i < ndim; ++i) {
if (i) {
buffer.append(",");
}
buffer.append(std::to_string(shape[i]));
}
buffer.append("}");
return buffer;
}
static ValueShape from(TensorShape tensor_shape) {
mgb_assert(tensor_shape.ndim);
return Span<size_t>{tensor_shape.shape, tensor_shape.ndim};
}
TensorShape as_tensor_shape() const {
mgb_assert(ndim != 0);
TensorShape ret;
for (size_t i = 0; i < ndim; ++i) {
ret.shape[i] = shape[i];
}
ret.ndim = ndim;
return ret;
}
bool operator==(const ValueShape& rhs) const {
if (ndim != rhs.ndim) {
return false;
}
for (size_t i = 0; i < ndim; ++i) {
if (shape[i] != rhs.shape[i]) {
return false;
}
}
return true;
}
};
static_assert(sizeof(size_t) >= sizeof(int));
static_assert(TensorShape::MAX_NDIM == 7);
static_assert(sizeof(ValueShape) <= sizeof(size_t) * 8);
} // namespace mgb::imperative
\ No newline at end of file
/**
* \file imperative/src/include/megbrain/imperative/utils/visit.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <vector>
#include "megbrain/utils/small_vector.h"
namespace mgb::imperative {
template <typename... TVisitors>
class Visitor : public TVisitors... {
public:
using TVisitors::operator()...;
};
} // namespace mgb::imperative
......@@ -28,10 +28,10 @@ TEST(TestProfiler, ImperativeLogProfile) {
auto results = imperative::Profiler::collect();
imperative::Profiler::stop_profile();
mgb_assert(results.entries.size() == 2);
auto* event_start = results.entries[0].data.as<profiler::CustomEvent>();
auto* event_finish = results.entries[1].data.as<profiler::CustomFinishEvent>();
mgb_assert(event_start && event_start->title == "XXX");
mgb_assert(event_finish && event_finish->title == "XXX");
auto& event_start = results.entries[0].data.cast<profiler::CustomEvent>();
auto& event_finish = results.entries[1].data.cast<profiler::CustomFinishEvent>();
mgb_assert(event_start.title == "XXX");
mgb_assert(event_finish.title == "XXX");
mgb_assert(results.entries[0].time < results.entries[1].time);
mgb_assert(results.entries[0].id < results.entries[1].id);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册