diff --git a/imperative/src/impl/transformations/eval.cpp b/imperative/src/impl/transformations/eval.cpp new file mode 100644 index 0000000000000000000000000000000000000000..289975a38e374d2b92c9100af29104d6ed763c19 --- /dev/null +++ b/imperative/src/impl/transformations/eval.cpp @@ -0,0 +1,107 @@ +/** + * \file imperative/src/impl/transformations/trace.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/imperative/transformations/eval.h" +#include "megbrain/imperative/transformations/grad.h" + +namespace mgb { +namespace imperative { + +std::vector InterpreterTransformation::apply_transformation( + const Operator& op, Span inputs) { + if (auto* op_val = op.as()) { + if (op_val->op().same_type()) { + return {inputs[0]}; + } + SmallVector input_handles; + SmallVector output_handles; + CleanupGuard _{[&] { + for (auto handle : output_handles) { + if (handle) { + m_channel->del(handle); + } + } + }}; + for (auto input : inputs) { + input_handles.push_back(*input.cast().handle()); + } + output_handles = + m_channel->apply_op(op_val->op().shared_from_this(), input_handles); + std::vector outputs; + for (auto& handle : output_handles) { + outputs.push_back(InterpreterValue::make(share_handle(handle))); + handle = nullptr; + } + return outputs; + } else if (auto* get_attr = op.as()) { + Handle handle = *inputs[0].cast().handle(); + ValueRef output; + switch (get_attr->attr()) { + case GetAttr::DType: + output = DTypeValue::make(m_channel->get_dtype(handle)); + break; + case GetAttr::Shape: + output = ShapeValue::make( + ValueShape::from(m_channel->get_shape(handle))); + break; + case GetAttr::Device: + output = CompNodeValue::make(m_channel->get_device(handle)); + break; + case GetAttr::Value: + output = HostValue::make(m_channel->get_value(handle)); + break; + case GetAttr::Data: + output = DeviceValue::make(m_channel->get_dev_tensor(handle)); + break; + default: + mgb_throw( + MegBrainError, "Interpreter: malformed GetAttr: %s", + op.to_string().c_str()); + } + return {output}; + } else if (auto* create_tensor = op.as()) { + auto args = create_tensor->parse(inputs); + if (!args.device) { + // implies H2D + mgb_assert(args.host, "neither host and device value is valid"); + return {InterpreterValue::make(share_handle( + m_channel->put(*args.host, args.kind == CreateTensor::Unique)))}; + } else { + return {InterpreterValue::make(share_handle(m_channel->put( + *args.device, args.host ? *args.host : HostTensorND())))}; + } + } else if (auto* dtr_command = op.as()) { + auto handle = *inputs[0].cast().handle(); + switch (dtr_command->kind()) { + case DTRCommand::Drop: + m_channel->drop(handle); + break; + default: + mgb_throw(AssertionError, "unknown DTRCommand %d", dtr_command->kind()); + } + return {}; + } else if (auto* rename_value = op.as()) { + auto& input = inputs[0].cast(); + return {InterpreterValue::make(input.handle(), rename_value->name())}; + } else if (op.is()) { + auto name = inputs[0].cast().name(); + if (!name.empty()) { + return {StringValue::make(name)}; + } else { + return {ValueRef()}; + } + } else { + return imperative::apply(op, inputs); + } +} + +} // namespace imperative +} // namespace mgb diff --git a/imperative/src/include/megbrain/imperative/transformations/eval.h b/imperative/src/include/megbrain/imperative/transformations/eval.h new file mode 100644 index 0000000000000000000000000000000000000000..a35c74de6a66a4485839b8d1cd2250e3927e81e4 --- /dev/null +++ b/imperative/src/include/megbrain/imperative/transformations/eval.h @@ -0,0 +1,95 @@ +/** + * \file imperative/src/include/megbrain/imperative/eval.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "megbrain/imperative/dispatch.h" +#include "megbrain/imperative/interpreter.h" +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/imperative/utils/helper.h" + +namespace mgb::imperative { + +struct InterpreterInfo { +public: + using Handle = interpreter::Interpreter::Handle; + using Channel = interpreter::Interpreter::Channel; + +private: + std::shared_ptr m_handle = nullptr; + std::string m_name; + +public: + InterpreterInfo() = default; + InterpreterInfo(std::shared_ptr handle, std::string name = {}) + : m_handle(handle), m_name(name) {} + + std::shared_ptr handle() const { return m_handle; } + + std::string name() const { return m_name; } +}; + +class InterpreterValue final + : public MixinValueImpl { +public: + using MixinValueImpl::MixinValueImpl; + + std::string to_string() const override { + return ssprintf( + "Handle{ptr=%p, name=%s}", handle().get(), + imperative::quoted(name()).c_str()); + } +}; + +/** + * \brief interpret operations with interpreter + * + * This is the most basic and simplest transformation. It read operation requests and + * forwards them to interpreter. Not all tensor requests would be handled by it, + * some were resolved by CompiledTransformation or LazyEvalTransformation. + */ +class InterpreterTransformation final : public Transformation { +public: + using Interpreter = interpreter::Interpreter; + using Handle = Interpreter::Handle; + using Channel = Interpreter::Channel; + +private: + std::unique_ptr m_channel; + +public: + explicit InterpreterTransformation(std::unique_ptr channel) + : m_channel{std::move(channel)} {} + + Channel* channel() { return m_channel.get(); } + + std::vector apply_transformation( + const Operator& op, Span inputs) override; + + ValueRef unwrap(ValueRef value) override { + mgb_assert(!value.is()); + return value; + } + + std::string name() const override { return "InterpreterTransformation"; } + + std::shared_ptr share_handle(Handle handle) { + return std::shared_ptr( + new Handle(handle), [channel = m_channel.get()](Handle* ptr) { + if (ptr) { + channel->del(*ptr); + delete ptr; + } + }); + } +}; + +} // namespace mgb::imperative