提交 2e42bc08 编写于 作者: M Megvii Engine Team

feat(dispatch): implement new op dispatch system

GitOrigin-RevId: 355da6b81499b4a519e5a9156c3aab6ee7263988
上级 2be6ceda
#include "megbrain/imperative/basic_operators.h"
#include "megbrain/imperative/basic_values.h"
namespace mgb {
namespace imperative {
std::string ApplyOp::to_string() const {
    // An ApplyOp is rendered simply as the wrapped OpDef's own representation.
    return m_op.to_string();
}
std::string GetAttr::to_string() const {
    // Map the attr enum to a readable name; unknown values fall back to the
    // raw numeric representation.
    std::string attr_name;
    switch (m_attr) {
        case None:
            attr_name = "None";
            break;
        case DType:
            attr_name = "DType";
            break;
        case Device:
            attr_name = "Device";
            break;
        case Shape:
            attr_name = "Shape";
            break;
        case Value:
            attr_name = "Value";
            break;
        case Data:
            attr_name = "Data";
            break;
        default:
            attr_name = std::to_string(m_attr);
            break;
    }
    return ssprintf("GetAttr{attr=%s}", attr_name.c_str());
}
// Construct from explicit metadata; the actual storage arrives later as
// inputs to the dispatch and is unpacked by parse().
CreateTensor::CreateTensor(Kind kind, CompNode device, DType dtype, ValueShape shape)
        : m_kind(kind), m_device(device), m_dtype(dtype), m_shape(shape) {}
// Construct from a TensorLayout. Only contiguous (or empty) layouts are
// accepted, because this operator keeps just dtype+shape and therefore cannot
// represent arbitrary strides.
CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout)
        : m_kind(kind),
          m_device(device),
          m_dtype(layout.dtype),
          m_shape(ValueShape::from(layout)) {
    mgb_assert(
            layout.is_contiguous() || layout.is_empty(), "layout should be contiguous");
}
// Unpack the inputs of a CreateTensor request (each either a HostStorage or a
// DeviceStorage) into concrete tensors described by this operator's metadata.
// At most one of each kind is allowed and at least one must be present.
auto CreateTensor::parse(Span<ValueRef> inputs) -> Args {
    Args result;
    for (auto&& input : inputs) {
        if (auto host_storage = input.as_ref<HostStorage>()) {
            mgb_assert(!result.host, "duplicated host value");
            result.host.emplace();
            // Reinterpret the raw host storage with this op's shape/dtype.
            result.host->reset(*host_storage, {shape().as_tensor_shape(), dtype()});
            mgb_assert(result.host->layout().ndim, "invalid shape");
        } else if (auto device_storage = input.as_ref<DeviceStorage>()) {
            mgb_assert(!result.device, "duplicated device value");
            result.device.emplace(device(), shape().as_tensor_shape(), dtype());
            // Reinterpret the raw device storage with this op's shape/dtype.
            result.device->reset(*device_storage, {shape().as_tensor_shape(), dtype()});
            mgb_assert(result.device->layout().ndim, "invalid shape");
        } else {
            mgb_throw(
                    MegBrainError,
                    "unknown input type, expects HostStorage or DeviceStorage, got "
                    "%s",
                    input.name()->c_str());
        }
    }
    mgb_assert(
            result.host || result.device, "require at least one of host/device value");
    result.kind = kind();
    return result;
}
std::string CreateTensor::to_string() const {
    // Render metadata only; storage is not known at this point.
    auto device_repr = m_device.to_string();
    auto shape_repr = m_shape.to_string();
    return ssprintf(
            "CreateTensor{kind=%d, device=%s, dtype=%s, shape=%s}", (int)m_kind,
            device_repr.c_str(), m_dtype.name(), shape_repr.c_str());
}
// Debug representation of a DTR command.
std::string DTRCommand::to_string() const {
    // Fixed: previously printed the stale tag "DTRCommandValue"; every other
    // operator prints its own class name, so DTRCommand does too.
    return ssprintf("DTRCommand{kind=%d}", (int)m_kind);
}
std::string GetName::to_string() const {
    // GetName carries no state, so its representation is constant.
    return std::string{"GetName{}"};
}
std::string RenameValue::to_string() const {
    // Quote the name so embedded spaces/escapes are unambiguous in logs.
    auto quoted_name = imperative::quoted(m_name);
    return ssprintf("RenameValue{name=%s}", quoted_name.c_str());
}
std::string IsScalar::to_string() const {
    // Stateless operator; constant representation.
    return std::string{"IsScalar"};
}
} // namespace imperative
} // namespace mgb
#include "megbrain/imperative/basic_values.h"
namespace mgb {
namespace imperative {
std::string ShapeValue::to_string() const {
    // Prefix the underlying ValueShape representation with its type tag.
    return "ValueShape" + ValueShape::to_string();
}
std::string CompNodeValue::to_string() const {
    // Delegate to the wrapped CompNode's own representation.
    std::string repr = CompNode::to_string();
    return repr;
}
// Render the wrapped bool.
std::string BoolValue::to_string() const {
    // Robustness fix: m_value is an std::optional that clear() resets, so the
    // previous unconditional dereference was UB on a cleared value.
    if (!m_value) {
        return "<invalid>";
    }
    return *m_value ? "true" : "false";
}
std::string HostStorage::to_string() const {
    // Only the device is useful to show; the bytes themselves are opaque.
    auto device_repr = comp_node().to_string();
    return ssprintf("HostStorage{device=%s}", device_repr.c_str());
}
std::string DeviceStorage::to_string() const {
    // Only the device is useful to show; the bytes themselves are opaque.
    auto device_repr = comp_node().to_string();
    return ssprintf("DeviceStorage{device=%s}", device_repr.c_str());
}
std::string HostValue::to_string() const {
    auto device_repr = device().to_string();
    auto shape_repr = m_shape.to_string();
    return ssprintf(
            "HostValue{device=%s, dtype=%s, shape=%s}", device_repr.c_str(),
            m_dtype.name(), shape_repr.c_str());
}
// Convert to a HostTensorND; a scalar value (only legal when allow_scalar is
// set) is materialized with shape {1}.
HostTensorND HostValue::as_nd(bool allow_scalar) const {
    TensorShape tshape;
    if (!m_shape.is_scalar()) {
        tshape = m_shape.as_tensor_shape();
    } else {
        mgb_assert(allow_scalar);
        tshape = TensorShape{1};
    }
    HostTensorND nd;
    nd.reset(m_storage, {tshape, dtype()});
    return nd;
}
std::string DeviceValue::to_string() const {
    auto device_repr = device().to_string();
    auto shape_repr = m_shape.to_string();
    return ssprintf(
            "DeviceValue{device=%s, dtype=%s, shape=%s}", device_repr.c_str(),
            m_dtype.name(), shape_repr.c_str());
}
// Convert to a DeviceTensorND; a scalar value (only legal when allow_scalar is
// set) is materialized with shape {1}.
DeviceTensorND DeviceValue::as_nd(bool allow_scalar) const {
    TensorShape tshape;
    if (!m_shape.is_scalar()) {
        tshape = m_shape.as_tensor_shape();
    } else {
        mgb_assert(allow_scalar);
        tshape = TensorShape{1};
    }
    DeviceTensorND nd;
    nd.reset(m_storage, {tshape, dtype()});
    return nd;
}
std::string FunctionValue::to_string() const {
    // Show the RTTI name of the callable stored in the std::function.
    auto& callable_type = target_type();
    return ssprintf("FunctionValue{type=%s}", callable_type.name());
}
std::string DTypeValue::to_string() const {
return DType::name();
}
// Render the wrapped string, quoted.
std::string StringValue::to_string() const {
    // Fixed: the previous C-style cast `(std::string&)*this` silently cast
    // away const inside a const member function; a named static_cast to a
    // const reference expresses the intended (safe) upcast only.
    return imperative::quoted(static_cast<const std::string&>(*this));
}
std::string ErrorValue::to_string() const {
    auto msg = message();
    return ssprintf("ErrorValue{message=%s}", msg.c_str());
}
} // namespace imperative
} // namespace mgb
/**
* \file imperative/src/impl/dispatch.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/utils/debug.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/map.h"
namespace mgb {
namespace imperative {
// Dispatch entrance: route the request to the transformation at the current
// stack depth, or to the operator's fallback once the stack is exhausted.
// Optionally logs every dispatch (MGE_LOG_OP_DISPATCH) or any dispatch that
// touches a watched value.
std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs) {
    static bool log_dispatch = MGB_GETENV("MGE_LOG_OP_DISPATCH");
    bool enable_watch = ValueRef::any_watching();
    auto& context = Transformation::get_context();
    size_t& depth = context.next_transformation;
    // Log indentation: point into a string of 16 tabs so deeper dispatches
    // indent further. NOTE(review): `tabs` underflows the array if depth ever
    // exceeds 16 — confirm the transformation stack is bounded accordingly.
    static const char tabs_storage[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t";
    const char* tabs = tabs_storage + sizeof(tabs_storage) / sizeof(char) - depth - 1;
    bool log_current_dispatch = log_dispatch;
    if (enable_watch) {
        // Force-log a dispatch whose inputs include a watched value.
        for (size_t i = 0; i < inputs.size(); ++i) {
            auto& input = inputs[i];
            if (input.watching()) {
                log_current_dispatch = true;
                mgb_log_debug("%sinput[%zu] is %s", tabs, i, input.to_string().c_str());
                debug::notify_event("apply");
            }
        }
    }
    // entrance
    std::vector<ValueRef> outputs;
    if (depth >= context.transformations.size()) {
        // fallback: no transformation left to handle the request
        if (log_current_dispatch) {
            mgb_log_debug(
                    "%sfallback apply %s in %s", tabs, op.to_string().c_str(),
                    imperative::to_string(inputs).c_str());
        }
        outputs = op.fallback(inputs);
    } else {
        // dispatch to stack top
        auto& transformation = *context.transformations[depth];
        ++depth;
        context.frames.push_back({op, inputs});
        // Restore depth/frames on scope exit, including when an exception
        // propagates out of apply_transformation.
        CleanupGuard _{[&] {
            context.frames.pop_back();
            --depth;
        }};
        if (log_current_dispatch) {
            mgb_log_debug(
                    "%s%s apply %s in %s", tabs, transformation.name().c_str(),
                    op.to_string().c_str(), imperative::to_string(inputs).c_str());
        }
        outputs = transformation.apply_transformation(op, inputs);
    }
    if (log_current_dispatch) {
        mgb_log_debug("%sreturn %s", tabs, imperative::to_string(outputs).c_str());
    }
    return outputs;
}
std::vector<ValueRef> apply(const OpDef& def, Span<ValueRef> inputs) {
return imperative::apply(ApplyOp{def}, inputs);
}
// Evaluate a Subgraph on the given inputs. Each contained op is re-dispatched
// through imperative::apply, so the active transformation stack still sees
// every step; subgraph constants are materialized as Const tensor values.
std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs) {
    SmallVector<ValueRef> inputs_storage;
    for (size_t i = 0; i < inputs.size(); ++i) {
        inputs_storage.push_back(inputs[i]);
    }
    // Applies one OpDef of the subgraph via the normal dispatch path.
    auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<ValueRef> inputs,
                            size_t) {
        auto outputs = imperative::apply(ApplyOp(*op), inputs);
        return SmallVector<ValueRef>(outputs.begin(), outputs.end());
    };
    // Turns a TensorPtr constant into a value via a CreateTensor::Const request
    // carrying both host and device storages.
    auto make_const = [](TensorPtr constant) -> ValueRef {
        auto host_value = constant->get_value();
        auto device_value = constant->dev_tensor();
        mgb_assert(
                host_value.layout().is_contiguous() &&
                device_value.layout().is_contiguous());
        ValueShape shape;
        // FIXME: assume Tensor with shape {1} is scalar
        if (!constant->shape().is_scalar()) {
            shape = ValueShape::from(constant->shape());
        }
        return imperative::apply(
                CreateTensor(
                        CreateTensor::Const, constant->comp_node(), constant->dtype(),
                        shape),
                HostStorage::make(host_value.storage()),
                DeviceStorage::make(device_value.storage()))[0];
    };
    auto outputs = graph.apply(inputs_storage, apply_functor, make_const);
    return {outputs.begin(), outputs.end()};
}
} // namespace imperative
} // namespace mgb
#include "megbrain/imperative/operator.h"
namespace mgb {
namespace imperative {
// Default fallback: most operators cannot be applied without a transformation
// handling them, so reaching this base implementation is an error.
std::vector<ValueRef> Operator::fallback(Span<ValueRef> inputs) const {
    mgb_throw(MegBrainError, "no fallback implementation for %s", to_string().c_str());
}
// Assign the next typecode to a concrete operator type.
// The const_cast is needed because the public accessor only exposes a const
// view of the registry. NOTE(review): registration is not synchronized —
// presumably all types register during static initialization; confirm.
size_t Operator::register_type(std::type_index type) {
    auto& types = const_cast<std::vector<std::type_index>&>(registered_types());
    types.push_back(type);
    return types.size() - 1;
}
// Registry of all operator types, indexed by typecode.
const std::vector<std::type_index>& Operator::registered_types() {
    // Function-local static: created on first use, shared process-wide.
    static std::vector<std::type_index> types;
    return types;
}
} // namespace imperative
} // namespace mgb
#include "megbrain/imperative/transformation.h"
namespace mgb {
namespace imperative {
// Each thread owns an independent context (transformation stack, scopes,
// in-flight dispatch frames).
TransformationContext& Transformation::get_context() {
    static thread_local TransformationContext context;
    return context;
}
} // namespace imperative
} // namespace mgb
/**
* \file imperative/src/impl/utils/debug.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <typeindex>
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/utils/debug.h"
#include "megbrain/imperative/value.h"
namespace mgb::imperative::debug {
// Expose the (implementation-mangled) RTTI name; demangling is the caller's job.
const char* get_type_name(const std::type_info& type) {
    return type.name();
}
// Overload for std::type_index; same semantics as the type_info overload.
const char* get_type_name(const std::type_index& type) {
    return type.name();
}
// Intentionally empty: a hook point for debug tooling (e.g. set a breakpoint
// here to observe events such as "apply" and "dtor").
void notify_event(const char* event) {}
// Mark a value as watched so later dispatches involving it are logged and
// notify_event is invoked.
void watch_value(ValueRef value) {
    value.watch();
}
} // namespace mgb::imperative::debug
\ No newline at end of file
#include "megbrain/imperative/value.h"
#include "megbrain/imperative/basic_operators.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/utils/map.h"
namespace mgb {
namespace imperative {
namespace {
static thread_local size_t nr_watched_values = 0;
static thread_local uint64_t nr_values = 0;
static thread_local bool recording_values = false;
static thread_local std::vector<ValueWeakRef> recorded_values;
static WeakValueMap<uint64_t, ValueWeakRef> registered_values;
} // namespace
// Resolve the effective storage of this value by following the successor
// chain (a value may have been forwarded/replaced). The first successor link
// is rewritten in place while walking (path compression), so repeated lookups
// get shorter. Returns the null storage unchanged when the ref is empty.
ValueRef::storage_t& ValueRef::storage() const {
    if (!m_storage) {
        return m_storage;
    }
    if (auto& storage = m_storage->m_successor.m_storage) {
        // Compress: collapse the chain into the first successor slot.
        while (storage->m_successor.m_storage) {
            storage = storage->m_successor.m_storage;
        }
        return storage;
    } else {
        return m_storage;
    }
}
TypedValueRef<DeviceValue> ValueRef::dev_tensor() const {
    // Route through the dispatch stack so transformations may intercept.
    auto outputs = imperative::apply(GetAttr(GetAttr::Data), *this);
    return outputs[0].as_ref<DeviceValue>();
}
TypedValueRef<HostValue> ValueRef::numpy() const {
    // Route through the dispatch stack so transformations may intercept.
    auto outputs = imperative::apply(GetAttr(GetAttr::Value), *this);
    return outputs[0].as_ref<HostValue>();
}
TypedValueRef<CompNodeValue> ValueRef::device() const {
    auto outputs = imperative::apply(GetAttr(GetAttr::Device), *this);
    return outputs[0].as_ref<CompNodeValue>();
}
TypedValueRef<ShapeValue> ValueRef::shape() const {
    auto outputs = imperative::apply(GetAttr(GetAttr::Shape), *this);
    return outputs[0].as_ref<ShapeValue>();
}
TypedValueRef<DTypeValue> ValueRef::dtype() const {
    auto outputs = imperative::apply(GetAttr(GetAttr::DType), *this);
    return outputs[0].as_ref<DTypeValue>();
}
TypedValueRef<StringValue> ValueRef::name() const {
    auto outputs = imperative::apply(GetName(), *this);
    return outputs[0].as_ref<StringValue>();
}
bool ValueRef::is_scalar() const {
    auto outputs = imperative::apply(IsScalar(), *this);
    return outputs[0].cast<BoolValue>();
}
// Mark this value as watched for debugging: bumps both the per-storage
// counter and the thread-local watcher count, then notifies the storage.
void ValueRef::watch() const {
    mgb_assert(m_storage);
    storage()->m_watching++;
    nr_watched_values++;
    storage()->on_watch();
    // TODO:
    // imperative::apply(Watch(), this);
}
// Undo watch(). NOTE(review): the counters are decremented without an
// underflow check — callers must pair unwatch() with a previous watch().
void ValueRef::unwatch() const {
    mgb_assert(m_storage);
    storage()->m_watching--;
    nr_watched_values--;
    storage()->on_unwatch();
}
// Strip every wrapper the active transformations (up to the current dispatch
// depth) have put around this value.
ValueRef ValueRef::unwrap() const {
    auto& context = Transformation::get_context();
    ValueRef unwrapped = *this;
    for (size_t i = 0; i < context.next_transformation; ++i) {
        unwrapped = context.transformations[i]->unwrap(unwrapped);
    }
    mgb_assert(unwrapped);
    return unwrapped;
}
std::string ValueRef::to_string() const {
    if (!m_storage) {
        return "<empty value>";
    }
    // Show both this ref's id and the id of the resolved (successor) storage.
    auto& resolved = storage();
    return ssprintf(
            "(%zu:%zu) %s", id(), resolved->m_id, resolved->to_string().c_str());
}
std::string ValueRef::raw_type() const {
    if (!m_storage) {
        return "null";
    }
    // Look up the RTTI name recorded for this value's typecode.
    size_t typecode = m_storage->m_typecode;
    auto& types = Value::registered_types();
    mgb_assert(types.size() > typecode);
    return types[typecode].name();
}
uint64_t ValueRef::id() const {
    // Empty refs report the sentinel max id.
    if (!m_storage) {
        return std::numeric_limits<uint64_t>::max();
    }
    return m_storage->m_id;
}
// Whether the (resolved) storage of this value is currently watched.
bool ValueRef::watching() const {
    // Bind by reference: storage() returns storage_t& and copying the smart
    // pointer only churned its refcount for no benefit.
    auto& storage = this->storage();
    return storage && storage->m_watching;
}
ValueRef ValueRef::make(ValueRef::storage_t storage) {
    // While recording is active, remember every value that gets created.
    if (recording_values) {
        recorded_values.push_back({storage});
    }
    ValueRef result{storage};
    return result;
}
bool ValueRef::any_watching() {
    // Thread-local counter: true while at least one value is watched.
    return nr_watched_values > 0;
}
// Promote to a strong reference; yields an empty ref if the value has expired
// or has been forwarded to a successor.
ValueRef ValueWeakRef::lock() {
    auto strong = m_storage.lock();
    if (!strong || strong->m_successor) {
        return {};
    }
    return {strong};
}
// Assign a sequential id from the thread-local counter; typecode identifies
// the concrete Value subclass. NOTE(review): ids are only unique within one
// thread because nr_values is thread_local — confirm cross-thread use.
Value::Value(size_t typecode) : m_typecode{typecode} {
    m_id = nr_values++;
}
Value::~Value() {
    // Let debug tooling observe destruction of a still-watched value.
    if (m_watching) {
        debug::notify_event("dtor");
    }
}
// Assign the next typecode to a concrete value type (mirror of
// Operator::register_type). The const_cast is needed because the accessor
// only exposes a const view. NOTE(review): not synchronized — presumably all
// registration happens during static initialization; confirm.
size_t Value::register_type(std::type_index type) {
    auto& types = const_cast<std::vector<std::type_index>&>(registered_types());
    types.push_back(type);
    return types.size() - 1;
}
// Registry of all value types, indexed by typecode.
const std::vector<std::type_index>& Value::registered_types() {
    static std::vector<std::type_index> types;
    return types;
}
// Make a value findable by id via get_value_by_id (debugging aid).
// NOTE(review): registered_values is a plain static (not thread_local) while
// ids come from a thread_local counter — confirm intended cross-thread usage.
void Value::register_value(ValueRef value) {
    registered_values[value.id()] = ValueWeakRef(value);
}
// Look up a previously registered value; returns an empty ref if the id is
// unknown or the value has expired. NOTE(review): operator[] default-inserts
// an empty weak ref for unknown ids, growing the map — consider find().
ValueRef Value::get_value_by_id(uint64_t id) {
    auto& weak_value = registered_values[id];
    if (auto value = weak_value.lock()) {
        return value;
    }
    return {};
}
// Start recording every value created on this thread (see ValueRef::make).
void Value::begin_record_values() {
    mgb_assert(!recording_values);
    recorded_values.clear();
    recording_values = true;
}
std::vector<ValueRef> Value::end_record_values() {
recording_values = false;
std::vector<ValueRef> recorded_strong_values;
for (auto&& weak_value : recorded_values) {
if (auto value = weak_value.lock()) {
recorded_strong_values.push_back(value);
}
}
return recorded_strong_values;
}
// If this value is an ErrorValue, raise its stored message as a
// MegBrainError; otherwise do nothing.
void Value::try_rethrow() {
    if (m_typecode == ErrorValue::TYPE_CODE) {
        auto message = static_cast<ErrorValue*>(this)->message();
        mgb_throw(MegBrainError, "invalid value: %s", message.c_str());
    }
}
} // namespace imperative
} // namespace mgb
/**
* \file imperative/src/include/megbrain/imperative/basic_operators.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <future>
#include <iomanip>
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/operator.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/value_shape.h"
namespace mgb {
namespace imperative {
class GradKey;
using GenericFunction = std::function<std::vector<ValueRef>(Span<ValueRef>)>;
/**
* \brief apply an OpDef to values
*
*/
class ApplyOp final : public OperatorImpl<ApplyOp> {
private:
    // Borrowed reference: the OpDef must outlive this operator instance.
    const OpDef& m_op;

public:
    ApplyOp(const OpDef& op) : m_op(op) {}
    // The wrapped op definition.
    const OpDef& op() { return m_op; }
    std::string to_string() const override;
};
/**
* \brief get an basic attribute from Value
*
*/
class GetAttr final : public OperatorImpl<GetAttr, Operator::GetAttrLike> {
public:
    // Which attribute of the value to fetch.
    enum Attr {
        None,
        DType,
        Device,
        Shape,
        Value,  // host tensor contents
        Data,   // device tensor contents
    };

private:
    Attr m_attr = None;

public:
    GetAttr(Attr attr) : m_attr(attr) {
        mgb_assert(attr != None, "invalid attr value: None");
    }
    Attr attr() const { return m_attr; }
    // Added `override` for consistency with the sibling operators; to_string
    // overrides the pure virtual on Operator.
    std::string to_string() const override;
};
/**
* \brief create a tensor value from host value or device value
*
*/
class CreateTensor final : public OperatorImpl<CreateTensor> {
public:
    enum Kind {
        Common,  // common mode, h2d can be cached to speed up
        Unique,  // require output value to be unique (do not share memory with
                 // other values)
        Const,   // put as constant (guaranteed to be same each time)
        NoTrace, // won't be traced in any case, would be used in
                 // make_backward_graph (looking for a better name)
    };
    // Result of parse(): the unpacked storages plus the creation kind.
    struct Args {
        std::optional<HostTensorND> host;
        std::optional<DeviceTensorND> device;
        Kind kind;
    };

private:
    Kind m_kind;
    CompNode m_device;
    DType m_dtype;
    ValueShape m_shape;

public:
    CreateTensor(Kind kind, CompNode device, DType dtype, ValueShape shape);
    // Layout must be contiguous or empty (asserted in the definition).
    CreateTensor(Kind kind, CompNode device, TensorLayout layout);
    /**
     * \brief utility function to unpack args of CreateTensor
     *
     * \param inputs contains host_storage and device_storage
     * \return Args unpacked args
     */
    Args parse(Span<ValueRef> inputs);
    Kind kind() const { return m_kind; }
    CompNode device() const { return m_device; }
    DType dtype() const { return m_dtype; }
    ValueShape shape() const { return m_shape; }
    std::string to_string() const override;
};
// Command for dynamic tensor rematerialization (e.g. dropping a tensor).
class DTRCommand final : public OperatorImpl<DTRCommand, Operator::GetAttrLike> {
public:
    enum Kind {
        None,
        Drop,
    };

private:
    Kind m_kind = None;

public:
    DTRCommand(Kind kind) : m_kind(kind) {}
    // const-qualified for consistency with the other operators' accessors
    // (e.g. CreateTensor::kind()); backward compatible for all callers.
    Kind kind() const { return m_kind; }
    std::string to_string() const override;
    // DTR commands produce no outputs; without a handling transformation they
    // are a no-op.
    std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { return {}; }
};
// deprecated
class GetName final : public OperatorImpl<GetName, Operator::GetAttrLike> {
public:
    std::string to_string() const override;
    // Fallback: a value without a name-providing transformation yields an
    // empty ref.
    std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
        return {ValueRef()};
    }
};
/**
* \brief return a value with new name
*
*/
class RenameValue : public OperatorImpl<RenameValue, Operator::IdentityLike> {
private:
    std::string m_name;

public:
    // Take by value and move: avoids a second copy of the name.
    RenameValue(std::string name) : m_name(std::move(name)) {}
    std::string name() const { return m_name; }
    std::string to_string() const override;
    // Without a handling transformation, renaming is identity on its single
    // input.
    std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
        return {inputs.as_array<1>()[0]};
    }
};
// Query operator: asks whether a value is a scalar (answered by a
// transformation with a BoolValue; no fallback is provided here).
class IsScalar final : public OperatorImpl<IsScalar, Operator::GetAttrLike> {
private:
public:
    std::string to_string() const override;
};
} // namespace imperative
} // namespace mgb
/**
* \file imperative/src/include/megbrain/imperative/basic_values.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <future>
#include <iomanip>
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/value_shape.h"
#include "megbrain/imperative/value.h"
namespace mgb {
namespace imperative {
class GradKey;
using GenericFunction = std::function<std::vector<ValueRef>(Span<ValueRef>)>;
// Value wrapper around a ValueShape.
class ShapeValue final : public MixinValueImpl<ShapeValue, ValueShape> {
public:
    using MixinValueImpl::MixinValueImpl;
    std::string to_string() const override;
};
// Value wrapper around a CompNode (device).
class CompNodeValue final : public MixinValueImpl<CompNodeValue, CompNode> {
public:
    using MixinValueImpl::MixinValueImpl;
    std::string to_string() const override;
};
// TODO: override factory method
// TODO: override factory method
class BoolValue final : public ValueImpl<BoolValue> {
private:
    std::optional<bool> m_value;

public:
    BoolValue(bool value) : m_value{value} {}
    // NOTE(review): dereferences without a check — UB if used after clear().
    operator bool() const { return *m_value; }
    std::string to_string() const override;
    void clear() override { m_value.reset(); }
};
// Value wrapper around a raw host-memory storage.
class HostStorage final : public MixinValueImpl<HostStorage, HostTensorStorage> {
public:
    using MixinValueImpl::MixinValueImpl;
    std::string to_string() const override;
};
// Value wrapper around a raw device-memory storage.
class DeviceStorage final : public MixinValueImpl<DeviceStorage, DeviceTensorStorage> {
public:
    using MixinValueImpl::MixinValueImpl;
    std::string to_string() const override;
};
/**
* \brief like HostTensorND mixin, but allow scalar value
*
*/
class HostValue final : public ValueImpl<HostValue> {
private:
    DType m_dtype;
    ValueShape m_shape;
    HostTensorStorage m_storage;

public:
    // storage is taken by value and moved in, avoiding an extra copy of the
    // (refcounted) storage handle.
    HostValue(DType dtype, ValueShape shape, HostTensorStorage storage)
            : m_dtype(dtype), m_shape(shape), m_storage(std::move(storage)) {}
    HostValue(HostTensorND value)
            : HostValue(
                      value.dtype(), ValueShape::from(value.shape()), value.storage()) {
    }
    std::string to_string() const override;
    // Release all contents (called when the value dies or is forwarded).
    void clear() override {
        m_dtype = {};
        m_shape = {};
        m_storage = {};
    }
    DType dtype() const { return m_dtype; }
    ValueShape shape() const { return m_shape; }
    CompNode device() const { return m_storage.comp_node(); }
    HostTensorStorage storage() const { return m_storage; }
    // Materialize as HostTensorND; scalars (shape {}) require allow_scalar
    // and come out with shape {1}.
    HostTensorND as_nd(bool allow_scalar = false) const;
};
/**
* \brief like DeviceTensorND mixin, but allow scalar value
*
*/
class DeviceValue final : public ValueImpl<DeviceValue> {
private:
    DType m_dtype;
    ValueShape m_shape;
    DeviceTensorStorage m_storage;

public:
    // storage is taken by value and moved in, avoiding an extra copy of the
    // (refcounted) storage handle.
    DeviceValue(DType dtype, ValueShape shape, DeviceTensorStorage storage)
            : m_dtype(dtype), m_shape(shape), m_storage(std::move(storage)) {}
    DeviceValue(DeviceTensorND value)
            : DeviceValue(
                      value.dtype(), ValueShape::from(value.shape()), value.storage()) {
    }
    std::string to_string() const override;
    // Release all contents (called when the value dies or is forwarded).
    void clear() override {
        m_dtype = {};
        m_shape = {};
        m_storage = {};
    }
    DType dtype() const { return m_dtype; }
    ValueShape shape() const { return m_shape; }
    CompNode device() const { return m_storage.comp_node(); }
    DeviceTensorStorage storage() const { return m_storage; }
    // Materialize as DeviceTensorND; scalars (shape {}) require allow_scalar
    // and come out with shape {1}.
    DeviceTensorND as_nd(bool allow_scalar = false) const;
};
// Value wrapper around a generic callable (Span<ValueRef> -> vector<ValueRef>).
class FunctionValue final : public MixinValueImpl<FunctionValue, GenericFunction> {
public:
    using MixinValueImpl::MixinValueImpl;
    std::string to_string() const override;
};
// Value wrapper around a DType.
class DTypeValue final : public MixinValueImpl<DTypeValue, DType> {
public:
    using MixinValueImpl::MixinValueImpl;
    std::string to_string() const override;
};
// Value wrapper around a std::string (used e.g. for value names).
class StringValue final : public MixinValueImpl<StringValue, std::string> {
public:
    using MixinValueImpl::MixinValueImpl;
    std::string to_string() const override;
};
// Plain error payload carried by ErrorValue.
class Error {
protected:
    std::string m_message;

public:
    Error() = default;
    // Take by value and move: callers passing temporaries pay no extra copy.
    Error(std::string message) : m_message(std::move(message)) {}
    std::string message() const { return m_message; }
};
// Value wrapper around an Error; Value::try_rethrow turns it into a throw.
class ErrorValue final : public MixinValueImpl<ErrorValue, Error> {
public:
    using MixinValueImpl::MixinValueImpl;
    std::string to_string() const override;
};
} // namespace imperative
} // namespace mgb
/**
* \file imperative/src/include/megbrain/imperative/dispatch.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <list>
#include <map>
#include <memory>
#include <typeinfo>
#include <vector>
#include "megbrain/common.h"
#include "megbrain/imperative/basic_operators.h"
#include "megbrain/imperative/basic_values.h"
#include "megbrain/imperative/operator.h"
#include "megbrain/imperative/subgraph.h"
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/utils/local_ptr.h"
#include "megbrain/imperative/utils/span.h"
#include "megbrain/imperative/value.h"
namespace mgb {
namespace imperative {
/**
* \brief dispatch entrance, requests would be forwarded to current top transformation
* (or fallback)
*
* \param op
* \param inputs
* \return std::vector<ValueRef>
*/
std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs);
std::vector<ValueRef> apply(const OpDef& def, Span<ValueRef> inputs);
std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs);
// True iff every TArg decays to ValueRef or a type derived from it.
template <typename... TArgs>
constexpr bool is_all_value_ref_v =
        (... && (std::is_base_of_v<ValueRef, std::decay_t<TArgs>> ||
                 std::is_same_v<ValueRef, std::decay_t<TArgs>>));

// Convenience overload: apply(op, v0, v1, ...) packs the individual ValueRefs
// into a stack array and forwards them as a Span.
template <typename T, typename... TArgs>
static auto apply(T&& op, TArgs&&... args)
        -> std::enable_if_t<is_all_value_ref_v<TArgs...>, std::vector<ValueRef>> {
    ValueRef args_arr[sizeof...(TArgs)] = {std::forward<TArgs&&>(args)...};
    return imperative::apply(
            std::forward<T&&>(op),
            Span<ValueRef>(std::begin(args_arr), std::end(args_arr)));
}
// Convenience overload: accepts any contiguous container of ValueRef (one
// with data()/size(), e.g. std::vector or SmallVector) and forwards it as a
// Span. Excludes Span itself to avoid ambiguity with the primary overload.
template <typename T, typename TContainer>
static auto apply(T&& op, TContainer&& container) -> std::enable_if_t<
        std::is_same_v<
                std::remove_const_t<std::remove_pointer_t<decltype(container.data())>>,
                ValueRef> &&
                std::is_same_v<decltype(container.size()), size_t> &&
                !std::is_same_v<std::decay_t<TContainer>, Span<ValueRef>>,
        std::vector<ValueRef>> {
    return imperative::apply(
            std::forward<T&&>(op), Span<ValueRef>(container.data(), container.size()));
}
} // namespace imperative
} // namespace mgb
/**
* \file imperative/src/include/megbrain/imperative/operator.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <list>
#include <map>
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <vector>
#include "megbrain/common.h"
#include "megbrain/imperative/utils/span.h"
#include "megbrain/imperative/value.h"
namespace mgb {
namespace imperative {
/**
* \brief base class for all operators
*
*/
class Operator {
public:
    // Coarse classification used by transformations that treat whole groups
    // of operators uniformly.
    enum Kind {
        IdentityLike,  // one input, one output, output is like input
        GetAttrLike,   // no tensor output
        Other,
    };

private:
    // Index into registered_types(), assigned by register_type().
    size_t m_typecode;
    Kind m_kind;

protected:
    Operator(size_t typecode, Kind kind) : m_typecode{typecode}, m_kind{kind} {}

public:
    size_t typecode() const { return m_typecode; }
    Kind kind() const { return m_kind; }
    // Typecode-checked downcast; returns nullptr on mismatch.
    // NOTE(review): strips const from *this via const_cast.
    template <typename U>
    U* as() const {
        if (m_typecode != U::TYPE_CODE) {
            return nullptr;
        }
        return static_cast<U*>(const_cast<Operator*>(this));
    }
    // Type test by concrete operator type.
    template <typename U>
    bool is() const {
        return as<U>() != nullptr;
    }
    // Type test by kind.
    template <Kind kKind>
    bool is() const {
        return kind() == kKind;
    }
    // Checked downcast: asserts on mismatch.
    template <typename U>
    U& cast() const {
        U* ptr = as<U>();
        mgb_assert(ptr);
        return *ptr;
    }
    virtual std::string to_string() const = 0;
    /**
     * \brief fallback implementation of this. Not all operators has fallback
     * implementation.
     *
     * \param inputs
     * \return std::vector<ValueRef>
     */
    virtual std::vector<ValueRef> fallback(Span<ValueRef> inputs) const;
    // RTTI identity of the concrete operator type.
    std::type_index type() const { return registered_types()[m_typecode]; }
    static size_t register_type(std::type_index type);
    static const std::vector<std::type_index>& registered_types();
};
// CRTP helper: assigns each concrete operator type a unique TYPE_CODE during
// static initialization and forwards it (plus the kind) to the Operator base.
template <typename T, Operator::Kind kKind = Operator::Other>
class OperatorImpl : public Operator {
protected:
    OperatorImpl() : Operator(TYPE_CODE, kKind) {}

public:
    static inline size_t TYPE_CODE = [] { return register_type(typeid(T)); }();
    std::string to_string() const override = 0;
};
} // namespace imperative
} // namespace mgb
/**
* \file imperative/src/include/megbrain/imperative/transformation.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <list>
#include <map>
#include <memory>
#include <vector>
#include "megbrain/common.h"
#include "megbrain/imperative/subgraph.h"
#include "megbrain/imperative/utils/local_ptr.h"
#include "megbrain/imperative/utils/span.h"
namespace mgb {
namespace imperative {
class ValueRef;
class Operator;
class Transformation;
/**
* \brief args of dispatch action
*
*/
struct TransformationFrame {
    const Operator& op;
    // NOTE(review): reference member to a Span — a frame is only valid for
    // the duration of the apply() call that pushed it; confirm no frame is
    // read after that call returns.
    const Span<ValueRef>& inputs;
};
struct TransformationContext {
    // The transformation stack; index 0 is the top (first to see a request).
    std::vector<std::shared_ptr<Transformation>> transformations;
    std::vector<std::string> scopes;
    // TODO: deprecate TransformationGuard, let next_transformation == frames.size()
    // Index of the transformation the next dispatch will be routed to.
    size_t next_transformation = 0;
    // One frame per in-flight dispatch (innermost last).
    std::vector<TransformationFrame> frames;
};
/**
* \brief Transformation handles operation requests.
*
* There is a transformation stack in each context. When a user sends an operation
* request, it is firstly passed to the top transformation. When a transformation in the
* stack receiving a request, it should handle it and give a response. Transformations
* are allowed to send requests when handling other requests, those requests would be
* sent to downstairs. A transformation can only be added to one stack.
*/
class Transformation : public std::enable_shared_from_this<Transformation> {
public:
using pos_t =
decltype(std::declval<TransformationContext>().transformations)::iterator;
class TransformationGuard {
private:
size_t m_priority;
public:
TransformationGuard(size_t priority) : m_priority{priority} {
auto& context = get_context();
std::swap(m_priority, context.next_transformation);
mgb_assert(
context.next_transformation <= context.transformations.size(),
"invalid priority: %zu vs %zu", context.next_transformation,
context.transformations.size());
}
~TransformationGuard() {
std::swap(m_priority, get_context().next_transformation);
}
};
private:
size_t m_priority = std::numeric_limits<size_t>::max();
public:
/**
* \brief handle a dispatch request
*
* \param op
* \param inputs
* \return std::vector<ValueRef>
*/
virtual std::vector<ValueRef> apply_transformation(
const Operator& op, Span<ValueRef> inputs) = 0;
virtual ValueRef unwrap(ValueRef value) = 0;
virtual std::string name() const = 0;
/**
* \brief called when added to a stack.
*/
virtual void on_register(){};
/**
* \brief called when remove from a stack.
*
* Some transformations, like GradTransformation and TraceTransformation, produce
* special values when handling requests. Thus they should recover these values on
* unregistering because other transformations cann't recognize them.
*/
virtual void on_unregister() noexcept {};
public:
static auto top() { return get_context().transformations.begin(); }
static auto bottom() { return get_context().transformations.end(); }
static void push_scope(std::string scope) { get_context().scopes.push_back(scope); }
static void pop_scope(std::string scope) {
auto& context = get_context();
auto top = context.scopes.back();
context.scopes.pop_back();
mgb_assert(top == scope);
}
static std::vector<std::string> scopes() { return get_context().scopes; }
/**
* \brief position at transformation stack
*
* \return auto position
*/
auto pos() const {
mgb_assert(
m_priority != std::numeric_limits<size_t>::max(), "not yet registered");
return top() + m_priority;
}
/**
* \brief register this at given position
*
* \param pos position
*/
void register_at(pos_t pos) {
auto& context = get_context();
mgb_assert(
m_priority == std::numeric_limits<size_t>::max(), "already registered");
size_t priority = pos - context.transformations.begin();
for (auto iter = pos; iter != context.transformations.end(); ++iter) {
iter->get()->m_priority++;
}
m_priority = priority;
context.transformations.insert(pos, shared_from_this());
{
TransformationGuard _{m_priority + 1};
on_register();
}
// assert priority
}
/**
* \brief unregister this from transformation stack
*/
void unregister() noexcept {
auto& context = get_context();
mgb_assert(
m_priority != std::numeric_limits<size_t>::max(), "not yet registered");
{
TransformationGuard _{m_priority + 1};
on_unregister();
}
size_t priority = m_priority;
auto pos = top() + priority;
for (auto iter = pos; iter != context.transformations.end(); ++iter) {
iter->get()->m_priority--;
}
m_priority = std::numeric_limits<size_t>::max();
context.transformations.erase(pos);
// TODO: assert priority
}
    // FIXME: deprecated
    // Guard scoped at this transformation's own level; relies on the
    // implicit TransformationGuard{size_t} conversion.
    [[nodiscard]] TransformationGuard current_level_guard() { return m_priority; }
/**
* \brief swap current context with target
*
* \param context target context
*/
static void swap_context(TransformationContext& context) {
auto& current_context = get_context();
std::swap(context.transformations, current_context.transformations);
std::swap(context.scopes, current_context.scopes);
std::swap(context.next_transformation, current_context.next_transformation);
}
static TransformationContext& get_context();
friend std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs);
friend class ValueRef;
};
} // namespace imperative
} // namespace mgb
/**
* \file imperative/src/include/megbrain/imperative/utils/debug.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <memory>
namespace mgb::imperative::debug {
void notify_event(const char* event);
}
/**
* \file imperative/src/include/megbrain/imperative/value.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <list>
#include <map>
#include <memory>
#include <typeindex>
#include <vector>
#include "megbrain/common.h"
#include "megbrain/imperative/subgraph.h"
#include "megbrain/imperative/utils/allocator.h"
#include "megbrain/imperative/utils/debug.h"
#include "megbrain/imperative/utils/local_ptr.h"
#include "megbrain/imperative/utils/span.h"
namespace mgb {
namespace imperative {
class Value;
class ValueRef;
template <typename T>
class TypedValueRef;
template <typename T>
class TypedValueWeakRef;
class Transformation;
class HostValue;
class DeviceValue;
class ShapeValue;
class DTypeValue;
class CompNodeValue;
class StringValue;
class Operator;
/**
 * \brief a smart reference to a value
*
 * A ValueRef is either empty or refers to a value. Values are organized as linked lists
* and only the tail node is valid. ValueRef stores a value node, and it may be
* an invalid internal node. When you dereference it, it will check its successor,
* automatically find the tail node and return. This list would be modified to reduce
* list length by change value's successor, but a ValueRef always has steady m_storage
* when not explicitly modified.
 * So we use m_storage to identify a ValueRef (hash / equality / id).
*/
class ValueRef {
public:
    using storage_t = LocalPtr<Value>;
protected:
    // head of this ref's value chain; stable identity of the ref even
    // when the chain is later path-compressed (see class doc above)
    mutable storage_t m_storage;
    ValueRef(storage_t storage) { m_storage = storage; }
private:
    /**
     * \brief recursively get dest value storage and shorten path
     *
     * \return storage_t dest storage
     */
    storage_t& storage() const;
public:
    ValueRef() = default;
    /**
     * \brief whether value is instance of target type or not
     *
     * \tparam TValue target type
     * \return true if type of value is TValue
     * \return false if empty or type of value is not TValue
     */
    template <typename TValue>
    bool is() const;
    /**
     * \brief try cast value as target type
     *
     * \tparam TValue target type
     * \return TValue* raw pointer if success, otherwise nullptr
     */
    template <typename TValue>
    const TValue* as() const;
    /**
     * \brief cast value to target type
     *
     * \tparam TValue target type
     * \return TValue& reference of value
     */
    template <typename TValue>
    const TValue& cast() const;
    /**
     * \brief like as(), but returns TypedValueRef instead
     *
     * \tparam TValue target type
     * \return TypedValueRef<TValue> reference if success, otherwise empty reference
     */
    template <typename TValue>
    inline TypedValueRef<TValue> as_ref() const;
    // true iff this ref is non-empty
    operator bool() const { return bool(m_storage); }
    // attribute accessors; bodies not visible here — presumably dispatched
    // through the transformation stack (GetAttr), confirm in the .cpp
    TypedValueRef<DeviceValue> dev_tensor() const;
    TypedValueRef<HostValue> numpy() const;
    TypedValueRef<CompNodeValue> device() const;
    TypedValueRef<ShapeValue> shape() const;
    TypedValueRef<DTypeValue> dtype() const;
    TypedValueRef<StringValue> name() const;
    bool is_scalar() const;
    // debug-watch facilities (see Value::on_watch/on_unwatch)
    void watch() const;
    void unwatch() const;
    bool watching() const;
    ValueRef unwrap() const;
    std::string to_string() const;
    std::string raw_type() const;
    // id of the underlying value; also used as the hash
    uint64_t id() const;
    size_t hash() const { return id(); }
    static ValueRef make(storage_t storage);
    static bool any_watching();
    friend class ValueWeakRef;
    template <typename T>
    friend class TypedValueRef;
    template <typename T>
    friend class ValueImpl;
    friend std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs);
};
template <>
struct ToStringTrait<ValueRef> {
public:
std::string operator()(const ValueRef& value) const { return value.to_string(); }
};
class ValueWeakRef {
public:
    using storage_t = ValueRef::storage_t::weak_type;
protected:
    // id captured at construction; remains valid after the value expires,
    // so hashing stays stable for container membership
    uint64_t m_id = std::numeric_limits<uint64_t>::max();
    mutable storage_t m_storage;
public:
    ValueWeakRef() = default;
    ValueWeakRef(ValueRef value) : m_id(value.id()), m_storage(value.m_storage) {}
    /**
     * \brief try promote to ValueRef
     *
     * \return ValueRef strong ref if value exists, otherwise empty ref
     */
    ValueRef lock();
    size_t hash() const { return m_id; }
    // NOTE(review): equality compares m_storage while hash() uses m_id;
    // presumably equal storages always imply equal ids — confirm, otherwise
    // unordered-container invariants could break.
    bool operator==(const ValueWeakRef& rhs) const {
        return m_storage == rhs.m_storage;
    }
    bool operator!=(const ValueWeakRef& rhs) const { return !(*this == rhs); }
};
/**
 * \brief base class for all generic value involved in dispatch system
 *
 */
class Value : public NonCopyableObj {
private:
    // unique id; max<uint64_t> until assigned
    uint64_t m_id = std::numeric_limits<uint64_t>::max();
    // index into registered_types()
    size_t m_typecode = 0;
    // forwarding pointer: when set, this node is stale and readers follow
    // the chain to the tail (see ValueRef class doc and TypedValueRef::reset)
    ValueRef m_successor;
    // number of active watchers (debugging aid)
    size_t m_watching = 0;
protected:
    Value(size_t typecode);
public:
    size_t typecode() const { return m_typecode; }
    const std::type_index type() const { return registered_types()[m_typecode]; }
    static size_t register_type(std::type_index type);
    static const std::vector<std::type_index>& registered_types();
    // value recording / lookup helpers for debugging
    static void register_value(ValueRef value);
    static ValueRef get_value_by_id(uint64_t id);
    static void begin_record_values();
    static std::vector<ValueRef> end_record_values();
    virtual std::string to_string() const = 0;
    /**
     * \brief clear all states of this value
     *
     */
    virtual void clear() = 0;
    // hooks invoked when ValueRef::watch()/unwatch() toggles observation
    virtual void on_watch() {}
    virtual void on_unwatch() {}
    friend class ValueRef;
    friend class ValueWeakRef;
    template <typename T>
    friend class ValueImpl;
    template <typename T>
    friend class TypedValueRef;
    ~Value();
private:
    // rethrows a stored error if this is an error value — used by
    // ValueRef::cast() before asserting; see its call site
    void try_rethrow();
};
/**
 * \brief base class of values, with typecode and factory method support
 *
 * \tparam T type of value
 */
template <typename T>
class ValueImpl : public Value {
protected:
    ValueImpl() : Value(TYPE_CODE) {}
public:
    using ref_t = TypedValueRef<T>;
    using weak_ref_t = TypedValueWeakRef<T>;
    // per-type typecode, assigned once during static initialization
    static inline size_t TYPE_CODE = [] { return register_type(typeid(T)); }();
    /**
     * \brief helper function for construct a value
     *
     * \tparam TArgs types of arguments
     * \param args arguments
     * \return TypedValueRef<T> reference of value
     */
    template <typename... TArgs>
    static TypedValueRef<T> make(TArgs&&... args) {
        static_assert(std::is_final_v<T>);
        // forward with the deduced parameter type TArgs, not TArgs&&:
        // std::forward<TArgs&&> compiles but is the wrong idiom
        return ValueRef::make(LocalPtr<Value>::make<T>(std::forward<TArgs>(args)...));
    }
};
/**
 * \brief base class of values, with mixin support
 *
 * \tparam T type of value
 * \tparam TMixin type of mixin class, supplies the actual payload/state
 */
template <typename T, typename TMixin>
class MixinValueImpl : public ValueImpl<T>, public TMixin {
public:
    using TMixin::TMixin;
    MixinValueImpl(TMixin mixin) : TMixin(std::move(mixin)) {}
public:
    // reset mixin state by assigning a default-constructed TMixin
    void clear() override final { ((TMixin&)*this) = {}; }
    // compare only the mixin payload, ignoring Value identity
    bool eq(const TMixin& value) const { return ((const TMixin&)*this) == value; }
};
/**
 * \brief try cast the referenced value to TValue
 *
 * \tparam TValue target type (must derive from ValueImpl<TValue>)
 * \return pointer to the value if it is exactly of type TValue,
 *         otherwise nullptr (also nullptr for an empty ref)
 */
template <typename TValue>
const TValue* ValueRef::as() const {
    static_assert(std::is_base_of_v<ValueImpl<TValue>, TValue>);
    // storage() returns a reference; bind to it instead of copying the
    // smart pointer just for a read-only type check
    auto& storage = this->storage();
    if (!storage) {
        return nullptr;
    }
    if (storage->m_typecode != TValue::TYPE_CODE) {
        return nullptr;
    }
    return static_cast<TValue*>(storage.get());
}
/**
 * \brief cast to TValue, asserting on failure
 *
 * On type mismatch, try_rethrow() gets a chance to propagate a stored
 * error first; only then does the assert fire with a diagnostic.
 */
template <typename TValue>
const TValue& ValueRef::cast() const {
    auto* ptr = as<TValue>();
    if (!ptr) {
        // if this is ErrorValue, rethrow directly
        storage()->try_rethrow();
        mgb_assert(
                ptr, "expect type %s, got %s", typeid(TValue).name(),
                to_string().c_str());
    }
    return *ptr;
}
// True iff the referenced value is exactly of type TValue.
template <typename TValue>
bool ValueRef::is() const {
    return as<TValue>() != nullptr;
}
// Typed variant of as(): yields an empty TypedValueRef on type mismatch.
template <typename TValue>
TypedValueRef<TValue> ValueRef::as_ref() const {
    return is<TValue>() ? TypedValueRef<TValue>(*this) : TypedValueRef<TValue>();
}
/**
 * \brief ValueRef with concrete type, convenient for dereference
 *
 * \tparam T type of value
 */
template <typename T>
class TypedValueRef : public ValueRef {
private:
    // only ValueRef (as_ref) and ValueImpl (make) may wrap an untyped ref
    TypedValueRef(ValueRef value) : ValueRef(value) {}
public:
    TypedValueRef() = default;
    const T& operator*() const { return this->template cast<T>(); }
    const T* operator->() const { return this->template as<T>(); }
    /**
     * \brief reset underlying value to another value
     *
     * Clears this node's state and forwards readers to \p successor's
     * storage, making this node a stale link in the value chain.
     *
     * \param successor new value
     */
    inline void reset(ValueRef successor) {
        mgb_assert(m_storage);
        // a value node may be forwarded at most once
        mgb_assert(!m_storage->m_successor);
        if (m_storage->m_watching) {
            debug::notify_event("reset");
        }
        m_storage->clear();
        m_storage->m_successor = ValueRef(successor.storage());
    }
    friend class ValueRef;
    template <typename U>
    friend class ValueImpl;
};
/**
 * \brief ValueWeakRef with concrete type
 *
 * \tparam T type of value
 */
template <typename T>
class TypedValueWeakRef : public ValueWeakRef {
public:
    TypedValueWeakRef(ValueRef value) : ValueWeakRef(value) {}
    TypedValueWeakRef(ValueWeakRef value) : ValueWeakRef(value) {}
    // Promote to a strong typed ref; empty if expired or of another type.
    TypedValueRef<T> lock() { return ValueWeakRef::lock().template as_ref<T>(); }
};
// TODO: add proxy value type, which is meant to be reset in the end
} // namespace imperative
} // namespace mgb
namespace std {

// Hash support so ValueWeakRef / ValueRef can be used as keys of
// unordered containers; both delegate to the type's own hash().
template <>
struct hash<mgb::imperative::ValueWeakRef> {
    std::size_t operator()(const mgb::imperative::ValueWeakRef& ref) const {
        return ref.hash();
    }
};

template <>
struct hash<mgb::imperative::ValueRef> {
    std::size_t operator()(const mgb::imperative::ValueRef& ref) const {
        return ref.hash();
    }
};

}  // namespace std
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册