/**
 * \file imperative/src/impl/transformations/trace.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

// NOTE(review): the \file tag above names trace.cpp, but this translation unit
// includes eval.h and defines InterpreterTransformation — confirm the tag.

#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/grad.h"

namespace mgb {
namespace imperative {

/*!
 * \brief Dispatch one operator onto the interpreter channel.
 *
 * The operator is matched against the kinds handled here; anything else is
 * forwarded to imperative::apply(op, inputs).
 *
 *  - ApplyOp:      when the wrapped op is a FastpathCopy the first input is
 *                  returned untouched; otherwise the input handles are
 *                  collected, m_channel->apply_op() is invoked and every
 *                  returned handle is wrapped into an InterpreterValue.
 *  - GetAttr:      queries dtype / shape / device / host value / device data
 *                  of the first input from the channel.
 *  - CreateTensor: puts a host tensor (when no device value is given) or a
 *                  device tensor (optionally paired with a host value) onto
 *                  the channel.
 *  - DTRCommand:   forwards Drop to m_channel->drop(); yields no outputs.
 *  - RenameValue:  re-wraps the handle of the first input under a new name.
 *  - GetName:      returns the name of the first input as a StringValue, or
 *                  an empty ValueRef when the name is empty.
 *
 * NOTE(review): the template arguments in this function (ValueRef, Handle,
 * ApplyOp, FastpathCopy, GetAttr, CreateTensor, DTRCommand, RenameValue,
 * GetName, InterpreterValue) were reconstructed from how each branch uses its
 * result; FastpathCopy and GetName in particular should be confirmed against
 * the project headers.
 *
 * \param op     operator to execute
 * \param inputs operands of \p op; every branch except ApplyOp/CreateTensor
 *               and the fallback reads inputs[0] without a size check, so
 *               callers are presumably expected to pass at least one value
 * \return the values produced by the operator (empty for DTRCommand)
 * \throws MegBrainError  on a GetAttr kind not listed above
 * \throws AssertionError on a DTRCommand kind other than Drop
 */
std::vector<ValueRef> InterpreterTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto* op_val = op.as<ApplyOp>()) {
        if (op_val->op().same_type<FastpathCopy>()) {
            return {inputs[0]};
        }
        SmallVector<Handle> input_handles;
        SmallVector<Handle> output_handles;
        // Presumably runs at scope exit: deletes every output handle that has
        // not been handed over to an InterpreterValue yet (those are reset to
        // nullptr below), so nothing leaks if wrapping throws midway.
        CleanupGuard _{[&] {
            for (auto handle : output_handles) {
                if (handle) {
                    m_channel->del(handle);
                }
            }
        }};
        // Bind by reference: only the handle is read, no need to copy the
        // value reference itself.
        for (const auto& input : inputs) {
            input_handles.push_back(*input.cast<InterpreterValue>().handle());
        }
        output_handles =
                m_channel->apply_op(op_val->op().shared_from_this(), input_handles);
        std::vector<ValueRef> outputs;
        outputs.reserve(output_handles.size());
        for (auto& handle : output_handles) {
            outputs.push_back(InterpreterValue::make(share_handle(handle)));
            // Ownership moved into the value above; hide it from the guard.
            handle = nullptr;
        }
        return outputs;
    } else if (auto* get_attr = op.as<GetAttr>()) {
        Handle handle = *inputs[0].cast<InterpreterValue>().handle();
        ValueRef output;
        switch (get_attr->attr()) {
            case GetAttr::DType:
                output = DTypeValue::make(m_channel->get_dtype(handle));
                break;
            case GetAttr::Shape:
                output = ShapeValue::make(
                        ValueShape::from(m_channel->get_shape(handle)));
                break;
            case GetAttr::Device:
                output = CompNodeValue::make(m_channel->get_device(handle));
                break;
            case GetAttr::Value:
                output = HostValue::make(m_channel->get_value(handle));
                break;
            case GetAttr::Data:
                output = DeviceValue::make(m_channel->get_dev_tensor(handle));
                break;
            default:
                mgb_throw(
                        MegBrainError, "Interpreter: malformed GetAttr: %s",
                        op.to_string().c_str());
        }
        return {output};
    } else if (auto* create_tensor = op.as<CreateTensor>()) {
        auto args = create_tensor->parse(inputs);
        if (!args.device) {
            /* implies H2D */
            mgb_assert(args.host, "neither host and device value is valid");
            // The second argument of put() is true only for Unique tensors;
            // its exact meaning is defined by the channel interface.
            return {InterpreterValue::make(share_handle(
                    m_channel->put(*args.host, args.kind == CreateTensor::Unique)))};
        } else {
            // A missing host value is replaced by a default-constructed
            // HostTensorND.
            return {InterpreterValue::make(share_handle(m_channel->put(
                    *args.device, args.host ? *args.host : HostTensorND())))};
        }
    } else if (auto* dtr_command = op.as<DTRCommand>()) {
        auto handle = *inputs[0].cast<InterpreterValue>().handle();
        switch (dtr_command->kind()) {
            case DTRCommand::Drop:
                m_channel->drop(handle);
                break;
            default:
                // Cast explicitly: an enum must not be fed to a "%d"
                // conversion through varargs as-is.
                mgb_throw(
                        AssertionError, "unknown DTRCommand %d",
                        static_cast<int>(dtr_command->kind()));
        }
        return {};
    } else if (auto* rename_value = op.as<RenameValue>()) {
        auto& input = inputs[0].cast<InterpreterValue>();
        return {InterpreterValue::make(input.handle(), rename_value->name())};
    } else if (op.is<GetName>()) {
        auto name = inputs[0].cast<InterpreterValue>().name();
        if (!name.empty()) {
            return {StringValue::make(name)};
        } else {
            return {ValueRef()};
        }
    } else {
        // Not handled by the interpreter itself: hand over to the generic
        // dispatcher.
        return imperative::apply(op, inputs);
    }
}

}  // namespace imperative
}  // namespace mgb