/**
 * \file imperative/src/impl/transformations/eval.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/grad.h"
#include "megbrain/imperative/utils/stats.h"

namespace mgb {
namespace imperative {

/**
 * Returns the dtype of the tensor behind this handle as a DTypeValue.
 *
 * While m_dtype is still empty the dtype is requested from the handle's
 * channel and stored in m_dtype; later calls return the stored value.
 */
DTypeValue::ref_t InterpreterInfo::dtype() const {
    if (!m_dtype) {
        auto&& owner = handle();
        auto* channel = owner->channel();
        m_dtype = DTypeValue::make(channel->get_dtype(owner->handle()));
    }
    return m_dtype;
}

/**
 * Returns the device of the tensor behind this handle as a CompNodeValue.
 *
 * While m_comp_node is still empty the device is requested from the handle's
 * channel and stored in m_comp_node; later calls return the stored value.
 */
CompNodeValue::ref_t InterpreterInfo::comp_node() const {
    if (!m_comp_node) {
        auto&& owner = handle();
        auto* channel = owner->channel();
        m_comp_node = CompNodeValue::make(channel->get_device(owner->handle()));
    }
    return m_comp_node;
}

/**
 * Returns the shape of the tensor behind this handle as a ShapeValue.
 *
 * While m_shape is still empty the shape is requested from the handle's
 * channel, converted with ValueShape::from() and stored in m_shape; later
 * calls return the stored value.
 */
ShapeValue::ref_t InterpreterInfo::shape() const {
    if (!m_shape) {
        auto&& owner = handle();
        auto* channel = owner->channel();
        m_shape = ShapeValue::make(
                ValueShape::from(channel->get_shape(owner->handle())));
    }
    return m_shape;
}

/**
 * Runs \p apply_op on the interpreter channel.
 *
 * Every input is cast to InterpreterValue to obtain its raw channel handle.
 * The handles returned by the channel are wrapped through share_handle() into
 * fresh InterpreterValue outputs.
 *
 * \param apply_op operator to execute; its op is passed to the channel via
 *        shared_from_this()
 * \param inputs operands, each cast to InterpreterValue
 * \return one InterpreterValue per handle returned by the channel
 */
ValueRefList InterpreterTransformation::apply_op(
        const ApplyOp& apply_op, Span<ValueRef> inputs) {
    SmallVector<Handle> input_handles;
    SmallVector<Handle> output_handles;
    // Presumably invoked when the guard goes out of scope: deletes every
    // output handle that is still non-null, i.e. one that was obtained from
    // the channel but not yet handed over to share_handle() below.
    CleanupGuard _{[&] {
        for (auto handle : output_handles) {
            if (handle) {
                m_channel->del(handle);
            }
        }
    }};
    // Bind by const reference: the ValueRef is only read here, so there is no
    // need to copy it for each iteration.
    for (const auto& input : inputs) {
        input_handles.push_back(input.cast<InterpreterValue>().handle()->handle());
    }
    output_handles =
            m_channel->apply_op(apply_op.op().shared_from_this(), input_handles);
    ValueRefList outputs(output_handles.size());
    for (size_t i = 0; i < output_handles.size(); ++i) {
        outputs[i] = InterpreterValue::make(share_handle(output_handles[i]));
        // The handle has been passed to share_handle(); clear the slot so the
        // cleanup guard does not delete it.
        output_handles[i] = nullptr;
    }
    return outputs;
}

/**
 * Answers a GetAttr query for an interpreter-backed value.
 *
 * DType, Shape and Device are served by the corresponding InterpreterValue
 * accessors; Value and Data are fetched from the channel using the raw handle.
 *
 * \param get_attr query describing which attribute is requested
 * \param inputs operand list; the single operand is taken with inputs.item()
 *        and cast to InterpreterValue
 * \return a list holding the requested attribute value
 * \throws MegBrainError if the attribute kind is not one of the handled cases
 */
ValueRefList InterpreterTransformation::apply_get_attr(
        const GetAttr& get_attr, Span<ValueRef> inputs) {
    auto& source = inputs.item().cast<InterpreterValue>();
    ValueRef result;
    switch (get_attr.attr()) {
        case GetAttr::Value:
            result = HostValue::make(m_channel->get_value(source.handle()->handle()));
            break;
        case GetAttr::Data:
            result = DeviceValue::make(
                    m_channel->get_dev_tensor(source.handle()->handle()));
            break;
        case GetAttr::Shape:
            result = source.shape();
            break;
        case GetAttr::DType:
            result = source.dtype();
            break;
        case GetAttr::Device:
            result = source.comp_node();
            break;
        default:
            mgb_throw(
                    MegBrainError, "Interpreter: malformed GetAttr: %s",
                    get_attr.to_string().c_str());
    }
    return {result};
}

/**
 * Creates an interpreter tensor from the arguments parsed out of
 * \p create_tensor and \p inputs.
 *
 * If a device value is present it is put on the channel together with the
 * host value (or an empty HostTensorND when there is none). Otherwise a host
 * value is required and is put on the channel, with a flag telling whether
 * the requested kind is CreateTensor::Unique.
 *
 * \return a single-element list holding the new InterpreterValue
 */
ValueRefList InterpreterTransformation::apply_create_tensor(
        const CreateTensor& create_tensor, Span<ValueRef> inputs) {
    auto args = create_tensor.parse(inputs);
    if (args.device) {
        return {InterpreterValue::make(share_handle(m_channel->put(
                *args.device, args.host ? *args.host : HostTensorND())))};
    }
    // implies H2D
    mgb_assert(args.host, "neither host and device value is valid");
    const bool is_unique = args.kind == CreateTensor::Unique;
    return {InterpreterValue::make(
            share_handle(m_channel->put(*args.host, is_unique)))};
}

/**
 * Entry point of the interpreter transformation: dispatches \p op by its
 * concrete operator type.
 *
 * - ApplyOp: FastpathCopy returns the first input unchanged; every other op
 *   goes through apply_op().
 * - GetAttr / CreateTensor: handled by apply_get_attr() / apply_create_tensor().
 * - DTRCommand: only Drop is handled; it is issued on the first input's handle
 *   and produces no outputs.
 * - RenameValue: returns a new InterpreterValue built from the same handle and
 *   the requested name.
 * - GetName: returns the input's name as a StringValue, or an empty ValueRef
 *   when the name is empty.
 * - Anything else is passed on to imperative::apply().
 *
 * \throws AssertionError for a DTRCommand kind other than Drop
 */
ValueRefList InterpreterTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto* op_val = op.as<ApplyOp>()) {
        if (op_val->op().same_type<FastpathCopy>()) {
            // No channel call is made for FastpathCopy: the input itself is
            // the result.
            return inputs[0];
        } else {
            return apply_op(*op_val, inputs);
        }
    } else if (auto* get_attr = op.as<GetAttr>()) {
        return apply_get_attr(*get_attr, inputs);
    } else if (auto* create_tensor = op.as<CreateTensor>()) {
        return apply_create_tensor(*create_tensor, inputs);
    } else if (auto* dtr_command = op.as<DTRCommand>()) {
        auto handle = inputs[0].cast<InterpreterValue>().handle()->handle();
        switch (dtr_command->kind()) {
            case DTRCommand::Drop:
                m_channel->drop(handle);
                break;
            default:
                // Convert explicitly so the argument matches the %d conversion
                // regardless of the enum's underlying type.
                mgb_throw(
                        AssertionError, "unknown DTRCommand %d",
                        static_cast<int>(dtr_command->kind()));
        }
        return {};
    } else if (auto* rename_value = op.as<RenameValue>()) {
        auto& input = inputs[0].cast<InterpreterValue>();
        return {InterpreterValue::make(input.handle(), rename_value->name())};
    } else if (op.is<GetName>()) {
        auto name = inputs[0].cast<InterpreterValue>().name();
        if (!name.empty()) {
            return {StringValue::make(name)};
        } else {
            return {ValueRef()};
        }
    } else {
        return imperative::apply(op, inputs);
    }
}

}  // namespace imperative
}  // namespace mgb