From 2e42bc083d20a1ecc49b81c0f3c2cb6f769f31f3 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Fri, 14 Jan 2022 13:10:52 +0800
Subject: [PATCH] feat(dispatch): implement new op dispatch system

GitOrigin-RevId: 355da6b81499b4a519e5a9156c3aab6ee7263988
---
 imperative/src/impl/basic_operators.cpp       |  98 +++++
 imperative/src/impl/basic_values.cpp          |  81 ++++
 imperative/src/impl/dispatch.cpp              | 108 +++++
 imperative/src/impl/operator.cpp              |  22 +
 imperative/src/impl/transformation.cpp        |  12 +
 imperative/src/impl/utils/debug.cpp           |  34 ++
 imperative/src/impl/value.cpp                 | 190 +++++++++
 .../megbrain/imperative/basic_operators.h     | 176 ++++++++
 .../megbrain/imperative/basic_values.h        | 178 ++++++++
 .../include/megbrain/imperative/dispatch.h    |  72 ++++
 .../include/megbrain/imperative/operator.h    | 102 +++++
 .../megbrain/imperative/transformation.h      | 199 +++++++++
 .../include/megbrain/imperative/utils/debug.h |  20 +
 .../src/include/megbrain/imperative/value.h   | 388 ++++++++++++++++++
 14 files changed, 1680 insertions(+)
 create mode 100644 imperative/src/impl/basic_operators.cpp
 create mode 100644 imperative/src/impl/basic_values.cpp
 create mode 100644 imperative/src/impl/dispatch.cpp
 create mode 100644 imperative/src/impl/operator.cpp
 create mode 100644 imperative/src/impl/transformation.cpp
 create mode 100644 imperative/src/impl/utils/debug.cpp
 create mode 100644 imperative/src/impl/value.cpp
 create mode 100644 imperative/src/include/megbrain/imperative/basic_operators.h
 create mode 100644 imperative/src/include/megbrain/imperative/basic_values.h
 create mode 100644 imperative/src/include/megbrain/imperative/dispatch.h
 create mode 100644 imperative/src/include/megbrain/imperative/operator.h
 create mode 100644 imperative/src/include/megbrain/imperative/transformation.h
 create mode 100644 imperative/src/include/megbrain/imperative/utils/debug.h
 create mode 100644 imperative/src/include/megbrain/imperative/value.h
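Note on the overall flow before the diffs: everything added below funnels through imperative::apply(op, inputs) in dispatch.cpp, which hands the request to a thread-local transformation stack and falls back to Operator::fallback when the stack is exhausted. A minimal caller-side sketch, where op_def and x are hypothetical placeholders rather than names from this patch:

    // Hypothetical usage; assumes an OpDef `op_def` and an input ValueRef `x`.
    std::vector<ValueRef> outputs = imperative::apply(ApplyOp{op_def}, x);
    ValueRef y = outputs[0];  // results come back as plain ValueRefs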
diff --git a/imperative/src/impl/basic_operators.cpp b/imperative/src/impl/basic_operators.cpp
new file mode 100644
index 000000000..c0835b249
--- /dev/null
+++ b/imperative/src/impl/basic_operators.cpp
+#include "megbrain/imperative/basic_operators.h"
+
+#include "megbrain/imperative/basic_values.h"
+
+namespace mgb {
+namespace imperative {
+
+std::string ApplyOp::to_string() const {
+    return m_op.to_string();
+}
+
+std::string GetAttr::to_string() const {
+    std::string buffer;
+    const char* attr_name = ([&] {
+        switch (m_attr) {
+            case None:
+                return "None";
+            case DType:
+                return "DType";
+            case Device:
+                return "Device";
+            case Shape:
+                return "Shape";
+            case Value:
+                return "Value";
+            case Data:
+                return "Data";
+            default:
+                buffer = std::to_string(m_attr);
+                return buffer.c_str();
+        }
+    })();
+    return ssprintf("GetAttr{attr=%s}", attr_name);
+}
+
+CreateTensor::CreateTensor(Kind kind, CompNode device, DType dtype, ValueShape shape)
+        : m_kind(kind), m_device(device), m_dtype(dtype), m_shape(shape) {}
+
+CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout)
+        : m_kind(kind),
+          m_device(device),
+          m_dtype(layout.dtype),
+          m_shape(ValueShape::from(layout)) {
+    mgb_assert(
+            layout.is_contiguous() || layout.is_empty(), "layout should be contiguous");
+}
+
+auto CreateTensor::parse(Span<ValueRef> inputs) -> Args {
+    Args result;
+    for (auto&& input : inputs) {
+        if (auto host_storage = input.as_ref<HostStorage>()) {
+            mgb_assert(!result.host, "duplicated host value");
+            result.host.emplace();
+            result.host->reset(*host_storage, {shape().as_tensor_shape(), dtype()});
+            mgb_assert(result.host->layout().ndim, "invalid shape");
+        } else if (auto device_storage = input.as_ref<DeviceStorage>()) {
+            mgb_assert(!result.device, "duplicated device value");
+            result.device.emplace(device(), shape().as_tensor_shape(), dtype());
+            result.device->reset(*device_storage, {shape().as_tensor_shape(), dtype()});
+            mgb_assert(result.device->layout().ndim, "invalid shape");
+        } else {
+            mgb_throw(
+                    MegBrainError,
+                    "unknown input type, expects HostStorage or DeviceStorage, got "
+                    "%s",
+                    input.name()->c_str());
+        }
+    }
+    mgb_assert(
+            result.host || result.device, "require at least one of host/device value");
+    result.kind = kind();
+    return result;
+}
+
+std::string CreateTensor::to_string() const {
+    return ssprintf(
+            "CreateTensor{kind=%d, device=%s, dtype=%s, shape=%s}", (int)m_kind,
+            m_device.to_string().c_str(), m_dtype.name(), m_shape.to_string().c_str());
+}
+
+std::string DTRCommand::to_string() const {
+    return ssprintf("DTRCommandValue{kind=%d}", (int)m_kind);
+}
+
+std::string GetName::to_string() const {
+    return "GetName{}";
+}
+
+std::string RenameValue::to_string() const {
+    return ssprintf("RenameValue{name=%s}", imperative::quoted(m_name).c_str());
+}
+
+std::string IsScalar::to_string() const {
+    return "IsScalar";
+}
+
+} // namespace imperative
+} // namespace mgb
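CreateTensor::parse above shows the construction protocol: the operator carries the metadata (kind, device, dtype, shape) while the actual storages travel as input values. A sketch of the intended call pattern, modeled on the make_const lambda in dispatch.cpp further below; the HostTensorND `host` is a made-up local, not part of this patch:

    // Hypothetical: lift an existing HostTensorND `host` into the dispatch system.
    auto value = imperative::apply(
            CreateTensor(
                    CreateTensor::Common, host.comp_node(), host.dtype(),
                    ValueShape::from(host.shape())),
            HostStorage::make(host.storage()))[0];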
"true" : "false"; +} + +std::string HostStorage::to_string() const { + return ssprintf("HostStorage{device=%s}", comp_node().to_string().c_str()); +} + +std::string DeviceStorage::to_string() const { + return ssprintf("DeviceStorage{device=%s}", comp_node().to_string().c_str()); +} + +std::string HostValue::to_string() const { + return ssprintf( + "HostValue{device=%s, dtype=%s, shape=%s}", device().to_string().c_str(), + m_dtype.name(), m_shape.to_string().c_str()); +} + +HostTensorND HostValue::as_nd(bool allow_scalar) const { + HostTensorND nd; + TensorShape tensor_shape; + if (m_shape.is_scalar()) { + mgb_assert(allow_scalar); + tensor_shape = TensorShape{1}; + } else { + tensor_shape = m_shape.as_tensor_shape(); + } + nd.reset(m_storage, {tensor_shape, dtype()}); + return nd; +} + +std::string DeviceValue::to_string() const { + return ssprintf( + "DeviceValue{device=%s, dtype=%s, shape=%s}", device().to_string().c_str(), + m_dtype.name(), m_shape.to_string().c_str()); +} + +DeviceTensorND DeviceValue::as_nd(bool allow_scalar) const { + DeviceTensorND nd; + TensorShape tensor_shape; + if (m_shape.is_scalar()) { + mgb_assert(allow_scalar); + tensor_shape = TensorShape{1}; + } else { + tensor_shape = m_shape.as_tensor_shape(); + } + nd.reset(m_storage, {tensor_shape, dtype()}); + return nd; +} + +std::string FunctionValue::to_string() const { + return ssprintf("FunctionValue{type=%s}", target_type().name()); +} + +std::string DTypeValue::to_string() const { + return DType::name(); +} + +std::string StringValue::to_string() const { + return imperative::quoted((std::string&)*this); +} + +std::string ErrorValue::to_string() const { + return ssprintf("ErrorValue{message=%s}", message().c_str()); +} + +} // namespace imperative +} // namespace mgb diff --git a/imperative/src/impl/dispatch.cpp b/imperative/src/impl/dispatch.cpp new file mode 100644 index 000000000..2fc49f139 --- /dev/null +++ b/imperative/src/impl/dispatch.cpp @@ -0,0 +1,108 @@ +/** + * \file imperative/src/impl/dispatch.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "megbrain/imperative/dispatch.h" + +#include "megbrain/imperative/utils/debug.h" +#include "megbrain/imperative/utils/helper.h" +#include "megbrain/imperative/utils/map.h" + +namespace mgb { +namespace imperative { + +std::vector apply(const Operator& op, Span inputs) { + static bool log_dispatch = MGB_GETENV("MGE_LOG_OP_DISPATCH"); + bool enable_watch = ValueRef::any_watching(); + auto& context = Transformation::get_context(); + size_t& depth = context.next_transformation; + static const char tabs_storage[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t"; + const char* tabs = tabs_storage + sizeof(tabs_storage) / sizeof(char) - depth - 1; + bool log_current_dispatch = log_dispatch; + if (enable_watch) { + for (size_t i = 0; i < inputs.size(); ++i) { + auto& input = inputs[i]; + if (input.watching()) { + log_current_dispatch = true; + mgb_log_debug("%sinput[%zu] is %s", tabs, i, input.to_string().c_str()); + debug::notify_event("apply"); + } + } + } + // entrance + std::vector outputs; + if (depth >= context.transformations.size()) { + // fallback + if (log_current_dispatch) { + mgb_log_debug( + "%sfallback apply %s in %s", tabs, op.to_string().c_str(), + imperative::to_string(inputs).c_str()); + } + outputs = op.fallback(inputs); + } else { + // dispatch to stack top + auto& transformation = *context.transformations[depth]; + ++depth; + context.frames.push_back({op, inputs}); + CleanupGuard _{[&] { + context.frames.pop_back(); + --depth; + }}; + if (log_current_dispatch) { + mgb_log_debug( + "%s%s apply %s in %s", tabs, transformation.name().c_str(), + op.to_string().c_str(), imperative::to_string(inputs).c_str()); + } + outputs = transformation.apply_transformation(op, inputs); + } + if (log_current_dispatch) { + mgb_log_debug("%sreturn %s", tabs, imperative::to_string(outputs).c_str()); + } + return outputs; +} + +std::vector apply(const OpDef& def, Span inputs) { + return imperative::apply(ApplyOp{def}, inputs); +} + +std::vector apply(Subgraph graph, Span inputs) { + SmallVector inputs_storage; + for (size_t i = 0; i < inputs.size(); ++i) { + inputs_storage.push_back(inputs[i]); + } + auto apply_functor = [](std::shared_ptr op, SmallVector inputs, + size_t) { + auto outputs = imperative::apply(ApplyOp(*op), inputs); + return SmallVector(outputs.begin(), outputs.end()); + }; + auto make_const = [](TensorPtr constant) -> ValueRef { + auto host_value = constant->get_value(); + auto device_value = constant->dev_tensor(); + mgb_assert( + host_value.layout().is_contiguous() && + device_value.layout().is_contiguous()); + ValueShape shape; + // FIXME: assume Tensor with shape {1} is scalar + if (!constant->shape().is_scalar()) { + shape = ValueShape::from(constant->shape()); + } + return imperative::apply( + CreateTensor( + CreateTensor::Const, constant->comp_node(), constant->dtype(), + shape), + HostStorage::make(host_value.storage()), + DeviceStorage::make(device_value.storage()))[0]; + }; + auto outputs = graph.apply(inputs_storage, apply_functor, make_const); + return {outputs.begin(), outputs.end()}; +} + +} // namespace imperative +} // namespace mgb diff --git a/imperative/src/impl/operator.cpp b/imperative/src/impl/operator.cpp new file mode 100644 index 000000000..4337adf91 --- /dev/null +++ b/imperative/src/impl/operator.cpp @@ -0,0 +1,22 @@ +#include "megbrain/imperative/operator.h" + +namespace mgb { +namespace imperative { + +std::vector Operator::fallback(Span inputs) const { + mgb_throw(MegBrainError, "no fallback implementation for %s", to_string().c_str()); +} + 
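apply() above indexes the stack with context.next_transformation, so a transformation that calls imperative::apply while handling a request automatically dispatches one level below itself. A minimal, hypothetical pass-through transformation (not part of this patch) illustrates the contract defined in transformation.h:

    // Hypothetical transformation that forwards every request downstream.
    class NoopTransformation final : public Transformation {
    public:
        std::vector<ValueRef> apply_transformation(
                const Operator& op, Span<ValueRef> inputs) override {
            return imperative::apply(op, inputs);  // re-enter one level deeper
        }
        ValueRef unwrap(ValueRef value) override { return value; }
        std::string name() const override { return "NoopTransformation"; }
    };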
diff --git a/imperative/src/impl/operator.cpp b/imperative/src/impl/operator.cpp
new file mode 100644
index 000000000..4337adf91
--- /dev/null
+++ b/imperative/src/impl/operator.cpp
+#include "megbrain/imperative/operator.h"
+
+namespace mgb {
+namespace imperative {
+
+std::vector<ValueRef> Operator::fallback(Span<ValueRef> inputs) const {
+    mgb_throw(MegBrainError, "no fallback implementation for %s", to_string().c_str());
+}
+
+size_t Operator::register_type(std::type_index type) {
+    auto& types = const_cast<std::vector<std::type_index>&>(registered_types());
+    types.push_back(type);
+    return types.size() - 1;
+}
+
+const std::vector<std::type_index>& Operator::registered_types() {
+    static std::vector<std::type_index> sm_registered_types;
+    return sm_registered_types;
+}
+
+} // namespace imperative
+} // namespace mgb
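register_type hands out dense typecodes at static-initialization time, which is what lets Operator::as() compare a single size_t instead of using RTTI. Declaring a new operator therefore needs no explicit registration call; a hypothetical example (the name GetFancyAttr is invented for illustration):

    // Hypothetical operator; TYPE_CODE is assigned via register_type(typeid(...)).
    class GetFancyAttr final : public OperatorImpl<GetFancyAttr, Operator::GetAttrLike> {
    public:
        std::string to_string() const override { return "GetFancyAttr{}"; }
    };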
diff --git a/imperative/src/impl/transformation.cpp b/imperative/src/impl/transformation.cpp
new file mode 100644
index 000000000..2b3b326f9
--- /dev/null
+++ b/imperative/src/impl/transformation.cpp
+#include "megbrain/imperative/transformation.h"
+
+namespace mgb {
+namespace imperative {
+
+TransformationContext& Transformation::get_context() {
+    thread_local TransformationContext tl_context;
+    return tl_context;
+}
+
+} // namespace imperative
+} // namespace mgb
diff --git a/imperative/src/impl/utils/debug.cpp b/imperative/src/impl/utils/debug.cpp
new file mode 100644
index 000000000..510b4bdb7
--- /dev/null
+++ b/imperative/src/impl/utils/debug.cpp
+/**
+ * \file imperative/src/impl/utils/debug.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include <typeindex>
+
+#include "megbrain/imperative/transformation.h"
+#include "megbrain/imperative/utils/debug.h"
+#include "megbrain/imperative/value.h"
+
+namespace mgb::imperative::debug {
+
+const char* get_type_name(const std::type_info& type) {
+    return type.name();
+}
+
+const char* get_type_name(const std::type_index& type) {
+    return type.name();
+}
+
+void notify_event(const char* event) {}
+
+void watch_value(ValueRef value) {
+    value.watch();
+}
+
+} // namespace mgb::imperative::debug
\ No newline at end of file
diff --git a/imperative/src/impl/value.cpp b/imperative/src/impl/value.cpp
new file mode 100644
index 000000000..5bbb8038d
--- /dev/null
+++ b/imperative/src/impl/value.cpp
+#include "megbrain/imperative/value.h"
+
+#include "megbrain/imperative/basic_operators.h"
+#include "megbrain/imperative/dispatch.h"
+#include "megbrain/imperative/utils/map.h"
+
+namespace mgb {
+namespace imperative {
+
+namespace {
+static thread_local size_t nr_watched_values = 0;
+static thread_local uint64_t nr_values = 0;
+static thread_local bool recording_values = false;
+static thread_local std::vector<ValueWeakRef> recorded_values;
+static WeakValueMap<uint64_t, ValueWeakRef> registered_values;
+} // namespace
+
+ValueRef::storage_t& ValueRef::storage() const {
+    if (!m_storage) {
+        return m_storage;
+    }
+    if (auto& storage = m_storage->m_successor.m_storage) {
+        while (storage->m_successor.m_storage) {
+            storage = storage->m_successor.m_storage;
+        }
+        return storage;
+    } else {
+        return m_storage;
+    }
+}
+
+TypedValueRef<DeviceValue> ValueRef::dev_tensor() const {
+    return imperative::apply(GetAttr(GetAttr::Data), *this)[0].as_ref<DeviceValue>();
+}
+
+TypedValueRef<HostValue> ValueRef::numpy() const {
+    return imperative::apply(GetAttr(GetAttr::Value), *this)[0].as_ref<HostValue>();
+}
+
+TypedValueRef<CompNodeValue> ValueRef::device() const {
+    return imperative::apply(GetAttr(GetAttr::Device), *this)[0]
+            .as_ref<CompNodeValue>();
+}
+
+TypedValueRef<ShapeValue> ValueRef::shape() const {
+    return imperative::apply(GetAttr(GetAttr::Shape), *this)[0].as_ref<ShapeValue>();
+}
+
+TypedValueRef<DTypeValue> ValueRef::dtype() const {
+    return imperative::apply(GetAttr(GetAttr::DType), *this)[0].as_ref<DTypeValue>();
+}
+
+TypedValueRef<StringValue> ValueRef::name() const {
+    return imperative::apply(GetName(), *this)[0].as_ref<StringValue>();
+}
+
+bool ValueRef::is_scalar() const {
+    return imperative::apply(IsScalar(), *this)[0].cast<BoolValue>();
+}
+
+void ValueRef::watch() const {
+    mgb_assert(m_storage);
+    storage()->m_watching++;
+    nr_watched_values++;
+    storage()->on_watch();
+    // TODO:
+    // imperative::apply(Watch(), this);
+}
+
+void ValueRef::unwatch() const {
+    mgb_assert(m_storage);
+    storage()->m_watching--;
+    nr_watched_values--;
+    storage()->on_unwatch();
+}
+
+ValueRef ValueRef::unwrap() const {
+    ValueRef value = *this;
+    auto& context = Transformation::get_context();
+    for (size_t i = 0; i < context.next_transformation; ++i) {
+        value = context.transformations[i]->unwrap(value);
+    }
+    mgb_assert(value);
+    return value;
+}
+
+std::string ValueRef::to_string() const {
+    if (!m_storage) {
+        return "";
+    }
+    return ssprintf(
+            "(%zu:%zu) %s", id(), storage()->m_id, storage()->to_string().c_str());
+}
+
+std::string ValueRef::raw_type() const {
+    if (!m_storage) {
+        return "null";
+    }
+    auto& types = Value::registered_types();
+    mgb_assert(types.size() > m_storage->m_typecode);
+    return types[m_storage->m_typecode].name();
+}
+
+uint64_t ValueRef::id() const {
+    return m_storage ? m_storage->m_id : std::numeric_limits<uint64_t>::max();
+}
+
+bool ValueRef::watching() const {
+    auto storage = this->storage();
+    return storage && storage->m_watching;
+}
+
+ValueRef ValueRef::make(ValueRef::storage_t storage) {
+    if (recording_values) {
+        recorded_values.push_back({storage});
+    }
+    return {storage};
+}
+
+bool ValueRef::any_watching() {
+    return nr_watched_values != 0;
+}
+
+ValueRef ValueWeakRef::lock() {
+    auto strong_storage = m_storage.lock();
+    if ((!strong_storage) || strong_storage->m_successor) {
+        return {};
+    }
+    return {strong_storage};
+}
+
+Value::Value(size_t typecode) : m_typecode{typecode} {
+    m_id = nr_values++;
+}
+
+Value::~Value() {
+    if (m_watching) {
+        debug::notify_event("dtor");
+    }
+}
+
+size_t Value::register_type(std::type_index type) {
+    auto& types = const_cast<std::vector<std::type_index>&>(registered_types());
+    types.push_back(type);
+    return types.size() - 1;
+}
+
+const std::vector<std::type_index>& Value::registered_types() {
+    static std::vector<std::type_index> sm_registered_types;
+    return sm_registered_types;
+}
+
+void Value::register_value(ValueRef value) {
+    registered_values[value.id()] = ValueWeakRef(value);
+}
+
+ValueRef Value::get_value_by_id(uint64_t id) {
+    auto& weak_value = registered_values[id];
+    if (auto value = weak_value.lock()) {
+        return value;
+    }
+    return {};
+}
+
+void Value::begin_record_values() {
+    mgb_assert(!recording_values);
+    recording_values = true;
+    recorded_values.clear();
+}
+
+std::vector<ValueRef> Value::end_record_values() {
+    recording_values = false;
+    std::vector<ValueRef> recorded_strong_values;
+    for (auto&& weak_value : recorded_values) {
+        if (auto value = weak_value.lock()) {
+            recorded_strong_values.push_back(value);
+        }
+    }
+    return recorded_strong_values;
+}
+
+void Value::try_rethrow() {
+    if (m_typecode == ErrorValue::TYPE_CODE) {
+        auto message = static_cast<ErrorValue*>(this)->message();
+        mgb_throw(MegBrainError, "invalid value: %s", message.c_str());
+    }
+}
+
+} // namespace imperative
+} // namespace mgb
diff --git a/imperative/src/include/megbrain/imperative/basic_operators.h b/imperative/src/include/megbrain/imperative/basic_operators.h
new file mode 100644
index 000000000..ee821557f
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/basic_operators.h
+/**
+ * \file imperative/src/include/megbrain/imperative/basic_operators.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include <functional>
+#include <optional>
+
+#include "megbrain/imperative/op_def.h"
+#include "megbrain/imperative/operator.h"
+#include "megbrain/imperative/utils/helper.h"
+#include "megbrain/imperative/utils/value_shape.h"
+
+namespace mgb {
+namespace imperative {
+
+class GradKey;
+
+using GenericFunction = std::function<std::vector<ValueRef>(Span<ValueRef>)>;
+
+/**
+ * \brief apply an OpDef to values
+ *
+ */
+class ApplyOp final : public OperatorImpl<ApplyOp> {
+private:
+    const OpDef& m_op;
+
+public:
+    ApplyOp(const OpDef& op) : m_op(op) {}
+
+    const OpDef& op() { return m_op; }
+
+    std::string to_string() const override;
+};
+
+/**
+ * \brief get a basic attribute from a value
+ *
+ */
+class GetAttr final : public OperatorImpl<GetAttr, Operator::GetAttrLike> {
+public:
+    enum Attr {
+        None,
+        DType,
+        Device,
+        Shape,
+        Value,
+        Data,
+    };
+
+private:
+    Attr m_attr = None;
+
+public:
+    GetAttr(Attr attr) : m_attr(attr) {
+        mgb_assert(attr != None, "invalid attr value: None");
+    }
+
+    Attr attr() const { return m_attr; }
+
+    std::string to_string() const override;
+};
+
+/**
+ * \brief create a tensor value from host value or device value
+ *
+ */
+class CreateTensor final : public OperatorImpl<CreateTensor> {
+public:
+    enum Kind {
+        Common,   // common mode, h2d can be cached to speed up
+        Unique,   // require output value to be unique (do not share memory with
+                  // other values)
+        Const,    // put as constant (guaranteed to be same each time)
+        NoTrace,  // won't be traced in any case, would be used in make_backward_graph
+                  // (looking for a better name)
+    };
+    struct Args {
+        std::optional<HostTensorND> host;
+        std::optional<DeviceTensorND> device;
+        Kind kind;
+    };
+
+private:
+    Kind m_kind;
+    CompNode m_device;
+    DType m_dtype;
+    ValueShape m_shape;
+
+public:
+    CreateTensor(Kind kind, CompNode device, DType dtype, ValueShape shape);
+    CreateTensor(Kind kind, CompNode device, TensorLayout layout);
+
+    /**
+     * \brief utility function to unpack args of CreateTensor
+     *
+     * \param inputs contains host_storage and device_storage
+     * \return Args unpacked args
+     */
+    Args parse(Span<ValueRef> inputs);
+
+    Kind kind() const { return m_kind; }
+    CompNode device() const { return m_device; }
+    DType dtype() const { return m_dtype; }
+    ValueShape shape() const { return m_shape; }
+
+    std::string to_string() const override;
+};
+
+class DTRCommand final : public OperatorImpl<DTRCommand, Operator::GetAttrLike> {
+public:
+    enum Kind {
+        None,
+        Drop,
+    };
+
+private:
+    Kind m_kind = None;
+
+public:
+    DTRCommand(Kind kind) : m_kind(kind) {}
+
+    Kind kind() { return m_kind; }
+
+    std::string to_string() const override;
+
+    std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { return {}; }
+};
+
+// deprecated
+class GetName final : public OperatorImpl<GetName, Operator::GetAttrLike> {
+public:
+    std::string to_string() const override;
+
+    std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
+        return {ValueRef()};
+    }
+};
+
+/**
+ * \brief return a value with a new name
+ *
+ */
+class RenameValue : public OperatorImpl<RenameValue, Operator::IdentityLike> {
+private:
+    std::string m_name;
+
+public:
+    RenameValue(std::string name) : m_name(name) {}
+
+    std::string name() const { return m_name; }
+
+    std::string to_string() const override;
+
+    std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
+        return {inputs.as_array<1>()[0]};
+    }
+};
+
+class IsScalar final : public OperatorImpl<IsScalar, Operator::GetAttrLike> {
+private:
+public:
+    std::string to_string() const override;
+};
+
+} // namespace imperative
+} // namespace mgb
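The GetAttr operator declared above is the workhorse behind ValueRef's accessors; value.cpp implements shape(), dtype(), numpy(), etc. as one-line dispatches through it. Asking for an attribute by hand looks like this, where `value` is a hypothetical ValueRef:

    // Equivalent to value.shape(); see ValueRef::shape() in value.cpp above.
    auto shape = imperative::apply(GetAttr(GetAttr::Shape), value)[0]
                         .as_ref<ShapeValue>();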
diff --git a/imperative/src/include/megbrain/imperative/basic_values.h b/imperative/src/include/megbrain/imperative/basic_values.h
new file mode 100644
index 000000000..487ed002b
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/basic_values.h
+/**
+ * \file imperative/src/include/megbrain/imperative/basic_values.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include <functional>
+#include <optional>
+
+#include "megbrain/imperative/utils/helper.h"
+#include "megbrain/imperative/utils/value_shape.h"
+#include "megbrain/imperative/value.h"
+
+namespace mgb {
+namespace imperative {
+
+class GradKey;
+
+using GenericFunction = std::function<std::vector<ValueRef>(Span<ValueRef>)>;
+
+class ShapeValue final : public MixinValueImpl<ShapeValue, ValueShape> {
+public:
+    using MixinValueImpl::MixinValueImpl;
+
+    std::string to_string() const override;
+};
+
+class CompNodeValue final : public MixinValueImpl<CompNodeValue, CompNode> {
+public:
+    using MixinValueImpl::MixinValueImpl;
+
+    std::string to_string() const override;
+};
+
+// TODO: override factory method
+class BoolValue final : public ValueImpl<BoolValue> {
+private:
+    std::optional<bool> m_value;
+
+public:
+    BoolValue(bool value) : m_value{value} {}
+    operator bool() const { return *m_value; }
+
+    std::string to_string() const override;
+
+    void clear() override { m_value.reset(); }
+};
+
+class HostStorage final : public MixinValueImpl<HostStorage, HostTensorStorage> {
+public:
+    using MixinValueImpl::MixinValueImpl;
+
+    std::string to_string() const override;
+};
+
+class DeviceStorage final : public MixinValueImpl<DeviceStorage, DeviceTensorStorage> {
+public:
+    using MixinValueImpl::MixinValueImpl;
+
+    std::string to_string() const override;
+};
+
+/**
+ * \brief like the HostTensorND mixin, but allows scalar values
+ *
+ */
+class HostValue final : public ValueImpl<HostValue> {
+private:
+    DType m_dtype;
+    ValueShape m_shape;
+    HostTensorStorage m_storage;
+
+public:
+    HostValue(DType dtype, ValueShape shape, HostTensorStorage storage)
+            : m_dtype(dtype), m_shape(shape), m_storage(storage) {}
+    HostValue(HostTensorND value)
+            : HostValue(
+                      value.dtype(), ValueShape::from(value.shape()), value.storage()) {
+    }
+
+    std::string to_string() const override;
+
+    void clear() override {
+        m_dtype = {};
+        m_shape = {};
+        m_storage = {};
+    }
+
+    DType dtype() const { return m_dtype; }
+    ValueShape shape() const { return m_shape; }
+    CompNode device() const { return m_storage.comp_node(); }
+    HostTensorStorage storage() const { return m_storage; }
+
+    HostTensorND as_nd(bool allow_scalar = false) const;
+};
+
+/**
+ * \brief like the DeviceTensorND mixin, but allows scalar values
+ *
+ */
+class DeviceValue final : public ValueImpl<DeviceValue> {
+private:
+    DType m_dtype;
+    ValueShape m_shape;
+    DeviceTensorStorage m_storage;
+
+public:
+    DeviceValue(DType dtype, ValueShape shape, DeviceTensorStorage storage)
+            : m_dtype(dtype), m_shape(shape), m_storage(storage) {}
+    DeviceValue(DeviceTensorND value)
+            : DeviceValue(
+                      value.dtype(), ValueShape::from(value.shape()), value.storage()) {
+    }
+
+    std::string to_string() const override;
+
+    void clear() override {
+        m_dtype = {};
+        m_shape = {};
+        m_storage = {};
+    }
+
+    DType dtype() const { return m_dtype; }
+    ValueShape shape() const { return m_shape; }
+    CompNode device() const { return m_storage.comp_node(); }
+    DeviceTensorStorage storage() const { return m_storage; }
+
+    DeviceTensorND as_nd(bool allow_scalar = false) const;
+};
+
+class FunctionValue final : public MixinValueImpl<FunctionValue, GenericFunction> {
+public:
+    using MixinValueImpl::MixinValueImpl;
+
+    std::string to_string() const override;
+};
+
+class DTypeValue final : public MixinValueImpl<DTypeValue, DType> {
+public:
+    using MixinValueImpl::MixinValueImpl;
+
+    std::string to_string() const override;
+};
+
+class StringValue final : public MixinValueImpl<StringValue, std::string> {
+public:
+    using MixinValueImpl::MixinValueImpl;
+
+    std::string to_string() const override;
+};
+
+class Error {
+protected:
+    std::string m_message;
+
+public:
+    Error() = default;
+    Error(std::string message) : m_message(message) {}
+
+    std::string message() const { return m_message; }
+};
+
+class ErrorValue final : public MixinValueImpl<ErrorValue, Error> {
+public:
+    using MixinValueImpl::MixinValueImpl;
+
+    std::string to_string() const override;
+};
+
+} // namespace imperative
+} // namespace mgb
diff --git a/imperative/src/include/megbrain/imperative/dispatch.h b/imperative/src/include/megbrain/imperative/dispatch.h
new file mode 100644
index 000000000..3a1033ce5
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/dispatch.h
+/**
+ * \file imperative/src/include/megbrain/imperative/dispatch.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <type_traits>
+#include <typeindex>
+#include <vector>
+
+#include "megbrain/common.h"
+#include "megbrain/imperative/basic_operators.h"
+#include "megbrain/imperative/basic_values.h"
+#include "megbrain/imperative/operator.h"
+#include "megbrain/imperative/subgraph.h"
+#include "megbrain/imperative/transformation.h"
+#include "megbrain/imperative/utils/local_ptr.h"
+#include "megbrain/imperative/utils/span.h"
+#include "megbrain/imperative/value.h"
+
+namespace mgb {
+namespace imperative {
+
+/**
+ * \brief dispatch entrance; requests are forwarded to the current top transformation
+ * (or to the operator's fallback)
+ *
+ * \param op
+ * \param inputs
+ * \return std::vector<ValueRef>
+ */
+std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs);
+std::vector<ValueRef> apply(const OpDef& def, Span<ValueRef> inputs);
+std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs);
+
+template <typename... TArgs>
+constexpr bool is_all_value_ref_v =
+        (... && (std::is_base_of_v<ValueRef, std::decay_t<TArgs>> ||
+                 std::is_same_v<ValueRef, std::decay_t<TArgs>>));
+
+template <typename T, typename... TArgs>
+static auto apply(T&& op, TArgs&&... args)
+        -> std::enable_if_t<is_all_value_ref_v<TArgs...>, std::vector<ValueRef>> {
+    ValueRef args_arr[sizeof...(TArgs)] = {std::forward<TArgs>(args)...};
+    return imperative::apply(
+            std::forward<T>(op),
+            Span<ValueRef>(std::begin(args_arr), std::end(args_arr)));
+}
+
+template <typename T, typename TContainer>
+static auto apply(T&& op, TContainer&& container) -> std::enable_if_t<
+        std::is_same_v<
+                std::remove_const_t<std::remove_pointer_t<decltype(container.data())>>,
+                ValueRef> &&
+                std::is_same_v<decltype(container.size()), size_t> &&
+                !std::is_same_v<std::decay_t<TContainer>, Span<ValueRef>>,
+        std::vector<ValueRef>> {
+    return imperative::apply(
+            std::forward<T>(op), Span<ValueRef>(container.data(), container.size()));
+}
+
+} // namespace imperative
+} // namespace mgb
diff --git a/imperative/src/include/megbrain/imperative/operator.h b/imperative/src/include/megbrain/imperative/operator.h
new file mode 100644
index 000000000..587c72972
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/operator.h
+/**
+ * \file imperative/src/include/megbrain/imperative/operator.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <typeindex>
+#include <typeinfo>
+#include <vector>
+
+#include "megbrain/common.h"
+#include "megbrain/imperative/utils/span.h"
+#include "megbrain/imperative/value.h"
+
+namespace mgb {
+namespace imperative {
+
+/**
+ * \brief base class for all operators
+ *
+ */
+class Operator {
+public:
+    enum Kind {
+        IdentityLike,  // one input, one output, output is like input
+        GetAttrLike,   // no tensor output
+        Other,
+    };
+
+private:
+    size_t m_typecode;
+    Kind m_kind;
+
+protected:
+    Operator(size_t typecode, Kind kind) : m_typecode{typecode}, m_kind{kind} {}
+
+public:
+    size_t typecode() const { return m_typecode; }
+    Kind kind() const { return m_kind; }
+
+    template <typename U>
+    U* as() const {
+        if (m_typecode != U::TYPE_CODE) {
+            return nullptr;
+        }
+        return static_cast<U*>(const_cast<Operator*>(this));
+    }
+    template <typename U>
+    bool is() const {
+        return as<U>() != nullptr;
+    }
+    template <Kind kKind>
+    bool is() const {
+        return kind() == kKind;
+    }
+    template <typename U>
+    U& cast() const {
+        U* ptr = as<U>();
+        mgb_assert(ptr);
+        return *ptr;
+    }
+
+    virtual std::string to_string() const = 0;
+
+    /**
+     * \brief fallback implementation of this operator. Not all operators have a
+     * fallback implementation.
+     *
+     * \param inputs
+     * \return std::vector<ValueRef>
+     */
+    virtual std::vector<ValueRef> fallback(Span<ValueRef> inputs) const;
+
+    std::type_index type() const { return registered_types()[m_typecode]; }
+
+    static size_t register_type(std::type_index type);
+    static const std::vector<std::type_index>& registered_types();
+};
+
+template <typename T, Operator::Kind kKind = Operator::Other>
+class OperatorImpl : public Operator {
+protected:
+    OperatorImpl() : Operator(TYPE_CODE, kKind) {}
+
+public:
+    static inline size_t TYPE_CODE = [] { return register_type(typeid(T)); }();
+
+    std::string to_string() const override = 0;
+};
+
+} // namespace imperative
+} // namespace mgb
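A transformation usually branches on the concrete operator type via the typecode-based as()/is() helpers above; a hypothetical handler skeleton, assuming the operator types from basic_operators.h:

    // Hypothetical body of some Transformation::apply_transformation.
    std::vector<ValueRef> apply_transformation(
            const Operator& op, Span<ValueRef> inputs) override {
        if (auto* apply_op = op.as<ApplyOp>()) {
            // intercept ApplyOp: inspect apply_op->op() here
        } else if (op.is<Operator::GetAttrLike>()) {
            // attribute queries produce no new tensors
        }
        return imperative::apply(op, inputs);  // forward everything else downstream
    }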
diff --git a/imperative/src/include/megbrain/imperative/transformation.h b/imperative/src/include/megbrain/imperative/transformation.h
new file mode 100644
index 000000000..322f3cc5f
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/transformation.h
+/**
+ * \file imperative/src/include/megbrain/imperative/transformation.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "megbrain/common.h"
+#include "megbrain/imperative/subgraph.h"
+#include "megbrain/imperative/utils/local_ptr.h"
+#include "megbrain/imperative/utils/span.h"
+
+namespace mgb {
+namespace imperative {
+
+class ValueRef;
+class Operator;
+class Transformation;
+
+/**
+ * \brief args of dispatch action
+ *
+ */
+struct TransformationFrame {
+    const Operator& op;
+    const Span<ValueRef>& inputs;
+};
+
+struct TransformationContext {
+    std::vector<std::shared_ptr<Transformation>> transformations;
+    std::vector<std::string> scopes;
+    // TODO: deprecate TransformationGuard, let next_transformation == frames.size()
+    size_t next_transformation = 0;
+    std::vector<TransformationFrame> frames;
+};
+
+/**
+ * \brief Transformation handles operation requests.
+ *
+ * There is a transformation stack in each context. When a user sends an operation
+ * request, it is first passed to the top transformation. When a transformation in the
+ * stack receives a request, it should handle it and give a response. Transformations
+ * are allowed to send requests while handling other requests; those requests are
+ * dispatched to the transformations below. A transformation can only be added to one
+ * stack.
+ */
+class Transformation : public std::enable_shared_from_this<Transformation> {
+public:
+    using pos_t =
+            decltype(std::declval<TransformationContext>().transformations)::iterator;
+
+    class TransformationGuard {
+    private:
+        size_t m_priority;
+
+    public:
+        TransformationGuard(size_t priority) : m_priority{priority} {
+            auto& context = get_context();
+            std::swap(m_priority, context.next_transformation);
+            mgb_assert(
+                    context.next_transformation <= context.transformations.size(),
+                    "invalid priority: %zu vs %zu", context.next_transformation,
+                    context.transformations.size());
+        }
+        ~TransformationGuard() {
+            std::swap(m_priority, get_context().next_transformation);
+        }
+    };
+
+private:
+    size_t m_priority = std::numeric_limits<size_t>::max();
+
+public:
+    /**
+     * \brief handle a dispatch request
+     *
+     * \param op
+     * \param inputs
+     * \return std::vector<ValueRef>
+     */
+    virtual std::vector<ValueRef> apply_transformation(
+            const Operator& op, Span<ValueRef> inputs) = 0;
+
+    virtual ValueRef unwrap(ValueRef value) = 0;
+
+    virtual std::string name() const = 0;
+
+    /**
+     * \brief called when added to a stack.
+     */
+    virtual void on_register(){};
+
+    /**
+     * \brief called when removed from a stack.
+     *
+     * Some transformations, like GradTransformation and TraceTransformation, produce
+     * special values when handling requests. Thus they should recover these values on
+     * unregistering because other transformations can't recognize them.
+     */
+    virtual void on_unregister() noexcept {};
+
+public:
+    static auto top() { return get_context().transformations.begin(); }
+    static auto bottom() { return get_context().transformations.end(); }
+    static void push_scope(std::string scope) { get_context().scopes.push_back(scope); }
+    static void pop_scope(std::string scope) {
+        auto& context = get_context();
+        auto top = context.scopes.back();
+        context.scopes.pop_back();
+        mgb_assert(top == scope);
+    }
+    static std::vector<std::string> scopes() { return get_context().scopes; }
+
+    /**
+     * \brief position in the transformation stack
+     *
+     * \return auto position
+     */
+    auto pos() const {
+        mgb_assert(
+                m_priority != std::numeric_limits<size_t>::max(), "not yet registered");
+        return top() + m_priority;
+    }
+
+    /**
+     * \brief register this transformation at the given position
+     *
+     * \param pos position
+     */
+    void register_at(pos_t pos) {
+        auto& context = get_context();
+        mgb_assert(
+                m_priority == std::numeric_limits<size_t>::max(), "already registered");
+        size_t priority = pos - context.transformations.begin();
+        for (auto iter = pos; iter != context.transformations.end(); ++iter) {
+            iter->get()->m_priority++;
+        }
+        m_priority = priority;
+        context.transformations.insert(pos, shared_from_this());
+        {
+            TransformationGuard _{m_priority + 1};
+            on_register();
+        }
+        // assert priority
+    }
+
+    /**
+     * \brief unregister this transformation from the stack
+     */
+    void unregister() noexcept {
+        auto& context = get_context();
+        mgb_assert(
+                m_priority != std::numeric_limits<size_t>::max(), "not yet registered");
+        {
+            TransformationGuard _{m_priority + 1};
+            on_unregister();
+        }
+        size_t priority = m_priority;
+        auto pos = top() + priority;
+        for (auto iter = pos; iter != context.transformations.end(); ++iter) {
+            iter->get()->m_priority--;
+        }
+        m_priority = std::numeric_limits<size_t>::max();
+        context.transformations.erase(pos);
+        // TODO: assert priority
+    }
+    // FIXME: deprecated
+    [[nodiscard]] TransformationGuard current_level_guard() { return m_priority; }
+
+    /**
+     * \brief swap current context with target
+     *
+     * \param context target context
+     */
+    static void swap_context(TransformationContext& context) {
+        auto& current_context = get_context();
+        std::swap(context.transformations, current_context.transformations);
+        std::swap(context.scopes, current_context.scopes);
+        std::swap(context.next_transformation, current_context.next_transformation);
+    }
+
+    static TransformationContext& get_context();
+
+    friend std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs);
+    friend class ValueRef;
+};
+
+} // namespace imperative
+} // namespace mgb
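register_at/unregister above keep m_priority consistent across the whole stack by shifting the priorities of every transformation behind the insertion point. Typical lifecycle, reusing the hypothetical NoopTransformation sketched earlier:

    auto trans = std::make_shared<NoopTransformation>();
    trans->register_at(Transformation::top());  // now sees every request first
    // ... imperative::apply(...) calls go through `trans` here ...
    trans->unregister();  // restores the previous stack order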
diff --git a/imperative/src/include/megbrain/imperative/utils/debug.h b/imperative/src/include/megbrain/imperative/utils/debug.h
new file mode 100644
index 000000000..09596ea36
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/utils/debug.h
+/**
+ * \file imperative/src/include/megbrain/imperative/utils/debug.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include <typeindex>
+
+namespace mgb::imperative::debug {
+
+void notify_event(const char* event);
+
+}
diff --git a/imperative/src/include/megbrain/imperative/value.h b/imperative/src/include/megbrain/imperative/value.h
new file mode 100644
index 000000000..0fa7e978a
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/value.h
+/**
+ * \file imperative/src/include/megbrain/imperative/value.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include <functional>
+#include <limits>
+#include <string>
+#include <typeindex>
+#include <vector>
+
+#include "megbrain/common.h"
+#include "megbrain/imperative/subgraph.h"
+#include "megbrain/imperative/utils/allocator.h"
+#include "megbrain/imperative/utils/debug.h"
+#include "megbrain/imperative/utils/local_ptr.h"
+#include "megbrain/imperative/utils/span.h"
+
+namespace mgb {
+namespace imperative {
+
+class Value;
+class ValueRef;
+
+template <typename TValue>
+class TypedValueRef;
+
+template <typename TValue>
+class TypedValueWeakRef;
+
+class Transformation;
+
+class HostValue;
+class DeviceValue;
+class ShapeValue;
+class DTypeValue;
+class CompNodeValue;
+class StringValue;
+
+class Operator;
+
+/**
+ * \brief a smart reference to a value
+ *
+ * A ValueRef is either empty or refers to a value. Values are organized as linked
+ * lists in which only the tail node is valid. A ValueRef stores a value node, which
+ * may be an invalid internal node. When you dereference it, it checks the successors,
+ * automatically finds the tail node and returns it. The list may be rewritten to
+ * reduce its length by changing a value's successor, but a ValueRef always keeps a
+ * steady m_storage unless explicitly modified.
+ * So we use m_storage to identify a ValueRef (hash / equality / id).
+ */
+class ValueRef {
+public:
+    using storage_t = LocalPtr<Value>;
+
+protected:
+    mutable storage_t m_storage;
+
+    ValueRef(storage_t storage) { m_storage = storage; }
+
+private:
+    /**
+     * \brief recursively get the dest value storage and shorten the path
+     *
+     * \return storage_t dest storage
+     */
+    storage_t& storage() const;
+
+public:
+    ValueRef() = default;
+
+    /**
+     * \brief whether the value is an instance of the target type or not
+     *
+     * \tparam TValue target type
+     * \return true if the type of the value is TValue
+     * \return false if empty or the type of the value is not TValue
+     */
+    template <typename TValue>
+    bool is() const;
+
+    /**
+     * \brief try to cast the value to the target type
+     *
+     * \tparam TValue target type
+     * \return TValue* raw pointer if success, otherwise nullptr
+     */
+    template <typename TValue>
+    const TValue* as() const;
+
+    /**
+     * \brief cast the value to the target type
+     *
+     * \tparam TValue target type
+     * \return TValue& reference of value
+     */
+    template <typename TValue>
+    const TValue& cast() const;
+
+    /**
+     * \brief like as(), but returns TypedValueRef instead
+     *
+     * \tparam TValue target type
+     * \return TypedValueRef<TValue> reference if success, otherwise empty reference
+     */
+    template <typename TValue>
+    inline TypedValueRef<TValue> as_ref() const;
+
+    operator bool() const { return bool(m_storage); }
+
+    TypedValueRef<DeviceValue> dev_tensor() const;
+    TypedValueRef<HostValue> numpy() const;
+    TypedValueRef<CompNodeValue> device() const;
+    TypedValueRef<ShapeValue> shape() const;
+    TypedValueRef<DTypeValue> dtype() const;
+    TypedValueRef<StringValue> name() const;
+    bool is_scalar() const;
+
+    void watch() const;
+    void unwatch() const;
+    bool watching() const;
+
+    ValueRef unwrap() const;
+    std::string to_string() const;
+    std::string raw_type() const;
+    uint64_t id() const;
+    size_t hash() const { return id(); }
+
+    static ValueRef make(storage_t storage);
+
+    static bool any_watching();
+
+    friend class ValueWeakRef;
+    template <typename TValue>
+    friend class TypedValueRef;
+    template <typename T>
+    friend class ValueImpl;
+    friend std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs);
+};
+
+template <>
+struct ToStringTrait<ValueRef> {
+public:
+    std::string operator()(const ValueRef& value) const { return value.to_string(); }
+};
+
+class ValueWeakRef {
+public:
+    using storage_t = ValueRef::storage_t::weak_type;
+
+protected:
+    uint64_t m_id = std::numeric_limits<uint64_t>::max();
+    mutable storage_t m_storage;
+
+public:
+    ValueWeakRef() = default;
+    ValueWeakRef(ValueRef value) : m_id(value.id()), m_storage(value.m_storage) {}
+
+    /**
+     * \brief try to promote to a ValueRef
+     *
+     * \return ValueRef strong ref if the value exists, otherwise empty ref
+     */
+    ValueRef lock();
+    size_t hash() const { return m_id; }
+
+    bool operator==(const ValueWeakRef& rhs) const {
+        return m_storage == rhs.m_storage;
+    }
+    bool operator!=(const ValueWeakRef& rhs) const { return !(*this == rhs); }
+};
+
+/**
+ * \brief base class for all generic values involved in the dispatch system
+ *
+ */
+class Value : public NonCopyableObj {
+private:
+    uint64_t m_id = std::numeric_limits<uint64_t>::max();
+    size_t m_typecode = 0;
+    ValueRef m_successor;
+    size_t m_watching = 0;
+
+protected:
+    Value(size_t typecode);
+
+public:
+    size_t typecode() const { return m_typecode; }
+    const std::type_index type() const { return registered_types()[m_typecode]; }
+
+    static size_t register_type(std::type_index type);
+    static const std::vector<std::type_index>& registered_types();
+
+    static void register_value(ValueRef value);
+    static ValueRef get_value_by_id(uint64_t id);
+    static void begin_record_values();
+    static std::vector<ValueRef> end_record_values();
+
+    virtual std::string to_string() const = 0;
+
+    /**
+     * \brief clear all states of this value
+     *
+     */
+    virtual void clear() = 0;
+
+    virtual void on_watch() {}
+    virtual void on_unwatch() {}
+
+    friend class ValueRef;
+    friend class ValueWeakRef;
+
+    template <typename T>
+    friend class ValueImpl;
+    template <typename T>
+    friend class TypedValueRef;
+
+    ~Value();
+
+private:
+    void try_rethrow();
+};
+
+/**
+ * \brief base class of values, with typecode and factory method support
+ *
+ * \tparam T type of value
+ */
+template <typename T>
+class ValueImpl : public Value {
+protected:
+    ValueImpl() : Value(TYPE_CODE) {}
+
+public:
+    using ref_t = TypedValueRef<T>;
+    using weak_ref_t = TypedValueWeakRef<T>;
+
+    static inline size_t TYPE_CODE = [] { return register_type(typeid(T)); }();
+
+    /**
+     * \brief helper function to construct a value
+     *
+     * \tparam TArgs types of arguments
+     * \param args arguments
+     * \return TypedValueRef<T> reference of value
+     */
+    template <typename... TArgs>
+    static TypedValueRef<T> make(TArgs&&... args) {
+        static_assert(std::is_final_v<T>);
+        return ValueRef::make(LocalPtr<Value>::make<T>(std::forward<TArgs>(args)...));
+    }
+};
+
+/**
+ * \brief base class of values, with mixin support
+ *
+ * \tparam T type of value
+ * \tparam TMixin type of mixin class
+ */
+template <typename T, typename TMixin>
+class MixinValueImpl : public ValueImpl<T>, public TMixin {
+public:
+    using TMixin::TMixin;
+
+    MixinValueImpl(TMixin mixin) : TMixin(std::move(mixin)) {}
+
+public:
+    void clear() override final { ((TMixin&)*this) = {}; }
+
+    bool eq(const TMixin& value) const { return ((const TMixin&)*this) == value; }
+};
+
+template <typename TValue>
+const TValue* ValueRef::as() const {
+    static_assert(std::is_base_of_v<Value, TValue>);
+    auto storage = this->storage();
+    if (!storage) {
+        return nullptr;
+    }
+    if (storage->m_typecode != TValue::TYPE_CODE) {
+        return nullptr;
+    }
+    return static_cast<TValue*>(storage.get());
+}
+
+template <typename TValue>
+const TValue& ValueRef::cast() const {
+    auto* ptr = as<TValue>();
+    if (!ptr) {
+        // if this is ErrorValue, rethrow directly
+        storage()->try_rethrow();
+        mgb_assert(
+                ptr, "expect type %s, got %s", typeid(TValue).name(),
+                to_string().c_str());
+    }
+    return *ptr;
+}
+
+template <typename TValue>
+bool ValueRef::is() const {
+    auto* ptr = as<TValue>();
+    return ptr != nullptr;
+}
+
+template <typename TValue>
+TypedValueRef<TValue> ValueRef::as_ref() const {
+    if (!is<TValue>()) {
+        return {};
+    }
+    return TypedValueRef<TValue>(*this);
+}
+
+/**
+ * \brief ValueRef with concrete type, convenient for dereference
+ *
+ * \tparam T type of value
+ */
+template <typename T>
+class TypedValueRef : public ValueRef {
+private:
+    TypedValueRef(ValueRef value) : ValueRef(value) {}
+
+public:
+    TypedValueRef() = default;
+    const T& operator*() const { return this->template cast<T>(); }
+    const T* operator->() const { return this->template as<T>(); }
+
+    /**
+     * \brief reset the underlying value to another value
+     *
+     * \param successor the new value
+     */
+    inline void reset(ValueRef successor) {
+        mgb_assert(m_storage);
+        mgb_assert(!m_storage->m_successor);
+        if (m_storage->m_watching) {
+            debug::notify_event("reset");
+        }
+        m_storage->clear();
+        m_storage->m_successor = ValueRef(successor.storage());
+    }
+
+    friend class ValueRef;
+
+    template <typename TValue>
+    friend class ValueImpl;
+};
+
+template <typename T>
+class TypedValueWeakRef : public ValueWeakRef {
+private:
+public:
+    TypedValueWeakRef(ValueRef value) : ValueWeakRef(value) {}
+    TypedValueWeakRef(ValueWeakRef value) : ValueWeakRef(value) {}
+    TypedValueRef<T> lock() { return ValueWeakRef::lock().template as_ref<T>(); }
+};
+
+// TODO: add proxy value type, which is meant to be reset in the end
+
+} // namespace imperative
+} // namespace mgb
+
+namespace std {
+
+template <>
+struct hash<mgb::imperative::ValueWeakRef> {
+    std::size_t operator()(const mgb::imperative::ValueWeakRef& weak_ref) const {
+        return weak_ref.hash();
+    }
+};
+
+template <>
+struct hash<mgb::imperative::ValueRef> {
+    std::size_t operator()(const mgb::imperative::ValueRef& ref) const {
+        return ref.hash();
+    }
+};
+
+} // namespace std
--
GitLab
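Appendix: how a concrete value type plugs into value.h. MixinValueImpl supplies clear()/eq() and ValueImpl supplies the make() factory; the Tag/TagValue types below are hypothetical and only illustrate the pattern, assuming the headers from this patch:

    // Hypothetical mixin-backed value type.
    struct Tag {
        std::string text;
        bool operator==(const Tag& rhs) const { return text == rhs.text; }
    };

    class TagValue final : public MixinValueImpl<TagValue, Tag> {
    public:
        using MixinValueImpl::MixinValueImpl;
        std::string to_string() const override { return "TagValue{" + text + "}"; }
    };

    auto tag = TagValue::make(Tag{"loss"});  // TypedValueRef<TagValue>
    mgb_assert(tag->text == "loss");         // operator-> resolves to the tail node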