#include "megbrain/imperative/value.h"
#include "megbrain/imperative/basic_operators.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/utils/map.h"

// Implementation of the imperative runtime's value handles: ValueRef (strong
// handle), ValueWeakRef (weak handle), Value (the referenced object, plus its
// type registry, id registry and creation recording) and ValueRefList (a list
// of ValueRef with a single inline slot).
//
// NOTE(review): this file reached review with its line breaks collapsed (which
// turned everything after each `//` into comment text and put all #includes on
// one line) and with every `<...>` template-argument list stripped. The
// argument lists below were restored from the upstream MegEngine declarations
// (value.h / basic_values.h) -- confirm them against the headers in this tree.

namespace mgb {
namespace imperative {

namespace {
// File-local bookkeeping. The `thread_local` qualifiers are commented out, so
// these are process-wide statics and nothing in this file synchronizes them.

// Incremented by ValueRef::watch(), decremented by ValueRef::unwatch().
static /*thread_local*/ size_t nr_watched_values = 0;
// Source of Value::m_id; post-incremented in Value's constructor.
static /*thread_local*/ uint64_t nr_values = 0;
// Set by Value::begin_record_values(), cleared by Value::end_record_values().
static /*thread_local*/ bool recording_values = false;
// Weak refs to the values created through ValueRef::make() while recording.
static /*thread_local*/ std::vector<ValueWeakRef> recorded_values;
// id -> weak ref table filled by Value::register_value().
static WeakValueMap<uint64_t, ValueWeakRef> registered_values;
}  // namespace

/**
 * Resolve this reference to the storage it currently stands for.
 *
 * A Value that has been superseded records its replacement in m_successor;
 * this follows that chain to its last element. The cached m_storage is
 * overwritten along the way (this is a const method, so the member must be
 * mutable), which lets later calls take the fast path.
 *
 * Precondition: m_storage is non-null -- it is dereferenced unconditionally.
 */
ValueRef::storage_t& ValueRef::storage() const {
    if (mgb_likely(!m_storage->m_successor.m_storage)) {
        return m_storage;
    }
    while (m_storage->m_successor.m_storage) {
        m_storage = m_storage->m_successor.m_storage;
    }
    return m_storage;
}

/**
 * Checked downcast: returns the resolved Value when its type code equals
 * `typecode`, otherwise nullptr.
 */
const Value* ValueRef::as(size_t typecode) const {
    auto&& storage = this->storage();
    if (storage->m_typecode != typecode) {
        return nullptr;
    }
    return static_cast<Value*>(storage.get());
}

/// True iff the resolved Value's type code equals `typecode`.
bool ValueRef::is(size_t typecode) const {
    return this->storage()->m_typecode == typecode;
}

// Attribute accessors. Each dispatches an operator on this value through
// imperative::apply, takes the first result and casts it to the value type
// the accessor promises.

TypedValueRef<DeviceValue> ValueRef::dev_tensor() const {
    return imperative::apply(GetAttr(GetAttr::Data), *this)[0].cast_ref<DeviceValue>();
}

TypedValueRef<HostValue> ValueRef::numpy() const {
    return imperative::apply(GetAttr(GetAttr::Value), *this)[0].cast_ref<HostValue>();
}

TypedValueRef<CompNodeValue> ValueRef::device() const {
    return imperative::apply(GetAttr(GetAttr::Device), *this)[0]
            .cast_ref<CompNodeValue>();
}

TypedValueRef<ShapeValue> ValueRef::shape() const {
    return imperative::apply(GetAttr(GetAttr::Shape), *this)[0].cast_ref<ShapeValue>();
}

TypedValueRef<DTypeValue> ValueRef::dtype() const {
    return imperative::apply(GetAttr(GetAttr::DType), *this)[0].cast_ref<DTypeValue>();
}

TypedValueRef<StringValue> ValueRef::name() const {
    return imperative::apply(GetName(), *this)[0].cast_ref<StringValue>();
}

bool ValueRef::is_scalar() const {
    return imperative::apply(IsScalar(), *this)[0].cast<BoolValue>();
}

/**
 * Mark the resolved value as watched: bumps its m_watching count and the
 * global nr_watched_values counter, then invokes its on_watch() hook.
 * Expected to be paired with unwatch(), which reverses both counters.
 */
void ValueRef::watch() const {
    mgb_assert(m_storage);
    storage()->m_watching++;
    nr_watched_values++;
    storage()->on_watch();
    // TODO:
    // imperative::apply(Watch(), this);
}

/// Undo one watch(): decrements both counters and invokes on_unwatch().
void ValueRef::unwatch() const {
    mgb_assert(m_storage);
    storage()->m_watching--;
    nr_watched_values--;
    storage()->on_unwatch();
}

/**
 * Pass this value through unwrap() of transformations[0 .. next_transformation)
 * of the current transformation context, in index order, and return the
 * result. Returns *this unchanged when next_transformation is 0.
 */
ValueRef ValueRef::unwrap() const {
    auto& context = Transformation::get_context();
    if (mgb_unlikely(context.next_transformation)) {
        ValueRef value = *this;
        for (size_t i = 0; i < context.next_transformation; ++i) {
            value = context.transformations[i]->unwrap(value);
        }
        return value;
    }
    return *this;
}

/// Debug string of the form "(id():m_id) <value's to_string()>".
std::string ValueRef::to_string() const {
    if (!m_storage) {
        // NOTE(review): upstream may have returned a placeholder such as
        // "<empty value>" here, lost with the stripped angle brackets -- confirm.
        return "";
    }
    return ssprintf(
            "(%zu:%zu) %s", id(), storage()->m_id, storage()->to_string().c_str());
}

/**
 * Name of the registered C++ type of this ref's *own* storage, or "null" for
 * an empty ref. Unlike most accessors it reads m_storage directly and does not
 * follow the successor chain. The name comes from std::type_index::name(),
 * which is implementation-defined (often mangled).
 */
std::string ValueRef::raw_type() const {
    if (!m_storage) {
        return "null";
    }
    auto& types = Value::registered_types();
    mgb_assert(types.size() > m_storage->m_typecode);
    return types[m_storage->m_typecode].name();
}

/// Whether the resolved value has a nonzero watch count; false for an empty ref.
bool ValueRef::watching() const {
    if (!m_storage) {
        return false;
    }
    return this->storage()->m_watching;
}

/**
 * Wrap `storage` in a ValueRef. While recording is active (see
 * Value::begin_record_values) a weak ref to the new value is also remembered.
 */
ValueRef ValueRef::make(ValueRef::storage_t storage) {
    if (recording_values) {
        recorded_values.push_back({storage});
    }
    return {storage};
}

/// True if any watch() is currently outstanding.
bool ValueRef::any_watching() {
    return nr_watched_values != 0;
}

/**
 * Promote to a strong ref. Returns an empty ValueRef if the value has expired
 * or has been superseded (its m_successor is set).
 */
ValueRef ValueWeakRef::lock() {
    auto strong_storage = m_storage.lock();
    if ((!strong_storage) || strong_storage->m_successor) {
        return {};
    }
    return {strong_storage};
}

/// Records the type code and assigns the next id from the global counter.
Value::Value(size_t typecode) : m_typecode{typecode} {
    m_id = nr_values++;
}

/// Emits a "dtor" debug event when a value is destroyed while still watched.
Value::~Value() {
    if (m_watching) {
        debug::notify_event("dtor");
    }
}

/**
 * Append `type` to the type registry and return its index, which is what
 * raw_type() expects to find in m_typecode.
 */
size_t Value::register_type(std::type_index type) {
    // The registry is exposed read-only; registration is the one writer.
    auto& types = const_cast<std::vector<std::type_index>&>(registered_types());
    types.push_back(type);
    return types.size() - 1;
}

/// The type registry; a function-local static, so it is created on first use.
const std::vector<std::type_index>& Value::registered_types() {
    static std::vector<std::type_index> sm_registered_types;
    return sm_registered_types;
}

/// Make `value` findable by its id through get_value_by_id() (weakly held).
void Value::register_value(ValueRef value) {
    registered_values[value.id()] = ValueWeakRef(value);
}

/**
 * Look up a registered value by id. Returns an empty ValueRef if the id is
 * unknown or the value is no longer alive / has been superseded.
 */
ValueRef Value::get_value_by_id(uint64_t id) {
    // NOTE(review): if WeakValueMap::operator[] behaves like std::map's, a miss
    // inserts an empty entry for `id`; consider a non-inserting lookup.
    auto& weak_value = registered_values[id];
    if (auto value = weak_value.lock()) {
        return value;
    }
    return {};
}

/// Start remembering every value created through ValueRef::make(). Not reentrant.
void Value::begin_record_values() {
    mgb_assert(!recording_values);
    recording_values = true;
    recorded_values.clear();
}

/**
 * Stop recording and return strong refs to the recorded values that are still
 * alive and not superseded.
 */
std::vector<ValueRef> Value::end_record_values() {
    recording_values = false;
    std::vector<ValueRef> recorded_strong_values;
    for (auto&& weak_value : recorded_values) {
        if (auto value = weak_value.lock()) {
            recorded_strong_values.push_back(std::move(value));
        }
    }
    // Drop the weak refs now rather than at the next begin_record_values():
    // depending on how storage_t is implemented, a weak ref may keep the
    // referenced value's allocation alive.
    recorded_values.clear();
    return recorded_strong_values;
}

/// If this value is an ErrorValue, throw MegBrainError carrying its message.
void Value::try_rethrow() {
    if (m_typecode == ErrorValue::TYPE_CODE) {
        auto message = static_cast<ErrorValue*>(this)->message();
        mgb_throw(MegBrainError, "invalid value: %s", message.c_str());
    }
}

/**
 * Construct `nr_elems` empty ValueRefs in fresh storage. A single element uses
 * the inline slot, larger lists come from the context allocator, and an empty
 * list leaves m_data null. Assumes the list owns nothing yet (callers clear()
 * first).
 */
inline void ValueRefList::init(size_t nr_elems) {
    m_size = nr_elems;
    if (m_size > 0) {
        if (m_size == 1) {
            m_data = inline_storage();
        } else {
            auto& context = Transformation::get_context();
            m_data = context.allocator.allocate(m_size);
        }
        for (size_t i = 0; i < m_size; ++i) {
            new (m_data + i) ValueRef();
        }
    } else {
        m_data = nullptr;
    }
}

ValueRefList::ValueRefList(size_t nr_elems) {
    init(nr_elems);
}

ValueRefList::ValueRefList(std::initializer_list<ValueRef> values)
        : ValueRefList(values.begin(), values.end()) {}

ValueRefList::ValueRefList(const ValueRefList& rhs)
        : ValueRefList(rhs.cbegin(), rhs.cend()) {}

/**
 * Move constructor. A heap buffer is stolen outright; the inline slot cannot
 * be, so its single element is moved into our own inline slot (rhs keeps
 * m_size == 1 holding an empty ref).
 */
ValueRefList::ValueRefList(ValueRefList&& rhs) : ValueRefList() {
    m_size = rhs.m_size;
    if (rhs.m_data == rhs.inline_storage()) {
        m_data = inline_storage();
        new (m_data) ValueRef();
        m_data[0] = std::move(rhs.m_data[0]);
    } else {
        m_data = rhs.m_data;
        rhs.m_data = nullptr;
        rhs.m_size = 0;
    }
}

/// Copy assignment: release current elements, then copy rhs element-wise.
ValueRefList& ValueRefList::operator=(const ValueRefList& rhs) {
    if (this == &rhs) {
        return *this;
    }
    clear();
    init(rhs.m_size);
    for (size_t i = 0; i < m_size; ++i) {
        m_data[i] = rhs.m_data[i];
    }
    return *this;
}

/**
 * Move assignment. Steals rhs's heap buffer; for the inline slot the single
 * element is moved into our own inline slot and rhs is cleared.
 */
ValueRefList& ValueRefList::operator=(ValueRefList&& rhs) {
    if (this == &rhs) {
        return *this;
    }
    clear();
    if (rhs.m_data == rhs.inline_storage()) {
        m_data = inline_storage();
        new (m_data) ValueRef();
        // Move (not copy) the element, matching the move constructor; rhs is
        // cleared right after, so a copy would only add a refcount round-trip.
        m_data[0] = std::move(rhs.m_data[0]);
        m_size = 1;
        rhs.clear();
    } else {
        m_data = rhs.m_data;
        m_size = rhs.m_size;
        rhs.m_data = nullptr;
        rhs.m_size = 0;
    }
    return *this;
}

ValueRefList::~ValueRefList() {
    clear();
}

/**
 * Destroy every element and return heap storage to the context allocator.
 * The inline slot (used exactly when m_size == 1) is never deallocated.
 * Leaves the list empty with m_data == nullptr.
 */
void ValueRefList::clear() {
    for (size_t i = 0; i < m_size; ++i) {
        m_data[i].~ValueRef();
    }
    if (m_data) {
        if (m_size != 1) {
            Transformation::get_context().allocator.deallocate(m_data, m_size);
        } else {
            mgb_assert(m_data == inline_storage());
        }
    }
    m_data = nullptr;
    m_size = 0;
}

}  // namespace imperative
}  // namespace mgb