eval.cpp 4.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
/**
 * \file imperative/src/impl/transformations/eval.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
 */

#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/grad.h"

namespace mgb {
namespace imperative {

std::vector<ValueRef> InterpreterTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto* op_val = op.as<ApplyOp>()) {
        if (op_val->op().same_type<FastpathCopy>()) {
            return {inputs[0]};
        }
        SmallVector<Handle> input_handles;
        SmallVector<Handle> output_handles;
        CleanupGuard _{[&] {
            for (auto handle : output_handles) {
                if (handle) {
                    m_channel->del(handle);
                }
            }
        }};
        for (auto input : inputs) {
            input_handles.push_back(*input.cast<InterpreterValue>().handle());
        }
        output_handles =
                m_channel->apply_op(op_val->op().shared_from_this(), input_handles);
        std::vector<ValueRef> outputs;
        for (auto& handle : output_handles) {
            outputs.push_back(InterpreterValue::make(share_handle(handle)));
            handle = nullptr;
        }
        return outputs;
    } else if (auto* get_attr = op.as<GetAttr>()) {
        Handle handle = *inputs[0].cast<InterpreterValue>().handle();
        ValueRef output;
        switch (get_attr->attr()) {
            case GetAttr::DType:
                output = DTypeValue::make(m_channel->get_dtype(handle));
                break;
            case GetAttr::Shape:
                output = ShapeValue::make(
                        ValueShape::from(m_channel->get_shape(handle)));
                break;
            case GetAttr::Device:
                output = CompNodeValue::make(m_channel->get_device(handle));
                break;
            case GetAttr::Value:
                output = HostValue::make(m_channel->get_value(handle));
                break;
            case GetAttr::Data:
                output = DeviceValue::make(m_channel->get_dev_tensor(handle));
                break;
            default:
                mgb_throw(
                        MegBrainError, "Interpreter: malformed GetAttr: %s",
                        op.to_string().c_str());
        }
        return {output};
    } else if (auto* create_tensor = op.as<CreateTensor>()) {
        auto args = create_tensor->parse(inputs);
        if (!args.device) {
            // implies H2D
            mgb_assert(args.host, "neither host and device value is valid");
            return {InterpreterValue::make(share_handle(
                    m_channel->put(*args.host, args.kind == CreateTensor::Unique)))};
        } else {
            return {InterpreterValue::make(share_handle(m_channel->put(
                    *args.device, args.host ? *args.host : HostTensorND())))};
        }
    } else if (auto* dtr_command = op.as<DTRCommand>()) {
        auto handle = *inputs[0].cast<InterpreterValue>().handle();
        switch (dtr_command->kind()) {
            case DTRCommand::Drop:
                m_channel->drop(handle);
                break;
            default:
                mgb_throw(AssertionError, "unknown DTRCommand %d", dtr_command->kind());
        }
        return {};
    } else if (auto* rename_value = op.as<RenameValue>()) {
        auto& input = inputs[0].cast<InterpreterValue>();
        return {InterpreterValue::make(input.handle(), rename_value->name())};
    } else if (op.is<GetName>()) {
        auto name = inputs[0].cast<InterpreterValue>().name();
        if (!name.empty()) {
            return {StringValue::make(name)};
        } else {
            return {ValueRef()};
        }
    } else {
        return imperative::apply(op, inputs);
    }
}

}  // namespace imperative
}  // namespace mgb