// eval.cpp
#include "megbrain/imperative/transformations/eval.h"

namespace mgb {
namespace imperative {

6
DTypeValue::ref_t InterpreterValue::dtype() const {
7 8 9 10 11 12
    if (!m_dtype) {
        m_dtype = DTypeValue::make(handle()->channel()->get_dtype(handle()->handle()));
    }
    return m_dtype;
}

13
CompNodeValue::ref_t InterpreterValue::comp_node() const {
14 15 16 17 18 19 20
    if (!m_comp_node) {
        m_comp_node = CompNodeValue::make(
                handle()->channel()->get_device(handle()->handle()));
    }
    return m_comp_node;
}

21
ShapeValue::ref_t InterpreterValue::shape() const {
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
    if (!m_shape) {
        m_shape = ShapeValue::make(
                ValueShape::from(handle()->channel()->get_shape(handle()->handle())));
    }
    return m_shape;
}

// Executes an ApplyOp on the interpreter channel.
//
// Every input is cast to the interpreter value type and its channel handle is
// forwarded, together with a shared pointer to the op, to
// m_channel->apply_op. Each returned handle is wrapped in a new interpreter
// value via share_handle.
//
// Exception safety:
//  * output handles that have not yet been wrapped when an exception escapes
//    are released with m_channel->del, so the channel does not leak them;
//  * the backtrace attached with set_backtrace is cleared on every exit path
//    (previously a throw from the channel left it attached).
// NOTE(review): both guards run their callbacks from a destructor, so
// m_channel->del / clear_backtrace are assumed not to throw - confirm.
ValueRefList InterpreterTransformation::apply_op(
        const ApplyOp& apply_op, Span<ValueRef> inputs) {
    SmallVector<Handle> input_handles;
    SmallVector<Handle> output_handles;
    CleanupGuard _{[&] {
        for (auto handle : output_handles) {
            if (handle) {
                m_channel->del(handle);
            }
        }
    }};
    input_handles.reserve(inputs.size());
    // bind by reference: no per-input ValueRef copy
    for (auto&& input : inputs) {
        input_handles.push_back(input.cast(m_value_type).handle()->handle());
    }
    m_channel->set_backtrace(Transformation::get_context().bt);
    // Declared after `_`, so it runs first on unwind: the backtrace is
    // detached before any leftover outputs are deleted.
    CleanupGuard backtrace_guard{[&] { m_channel->clear_backtrace(); }};
    output_handles =
            m_channel->apply_op(apply_op.op().shared_from_this(), input_handles);
    ValueRefList outputs(output_handles.size());
    for (size_t i = 0; i < output_handles.size(); ++i) {
        outputs[i] = m_value_type.make(share_handle(output_handles[i]));
        // ownership now lives in outputs[i]; keep the guard from deleting it
        output_handles[i] = nullptr;
    }
    output_handles.clear();
    return outputs;
}

// Answers a GetAttr query on exactly one interpreter value.
//
// DType, Shape and Device are served by the value's cached accessors; Value
// fetches a host tensor (HostValue) and Data a device tensor (DeviceValue)
// from the channel. Any other attribute throws MegBrainError.
//
// The current context's backtrace is attached to the channel for the duration
// of the query and is cleared on every exit path. Previously the
// malformed-attribute branch (and any throwing channel call) skipped
// clear_backtrace and left a stale backtrace attached.
// NOTE(review): the guard calls clear_backtrace from a destructor, so it is
// assumed not to throw - confirm.
ValueRefList InterpreterTransformation::apply_get_attr(
        const GetAttr& get_attr, Span<ValueRef> inputs) {
    auto& input = inputs.item().cast(m_value_type);
    ValueRef output;
    m_channel->set_backtrace(Transformation::get_context().bt);
    CleanupGuard backtrace_guard{[&] { m_channel->clear_backtrace(); }};
    switch (get_attr.attr()) {
        case GetAttr::DType:
            output = input.dtype();
            break;
        case GetAttr::Shape:
            output = input.shape();
            break;
        case GetAttr::Device:
            output = input.comp_node();
            break;
        case GetAttr::Value:
            output = HostValue::make(m_channel->get_value(input.handle()->handle()));
            break;
        case GetAttr::Data:
            output = DeviceValue::make(
                    m_channel->get_dev_tensor(input.handle()->handle()));
            break;
        default:
            mgb_throw(
                    MegBrainError, "Interpreter: malformed GetAttr: %s",
                    get_attr.to_string().c_str());
    }
    return {output};
}

// Materializes a CreateTensor request as a new interpreter value.
//
// * No device value: the host value is put on the channel (host-to-device
//   upload); the second argument to put is true iff the request kind is
//   CreateTensor::Unique. Asserts that a host value is present.
// * Device value present: it is put on the channel together with the host
//   value when one was supplied, or an empty HostTensorND otherwise.
ValueRefList InterpreterTransformation::apply_create_tensor(
        const CreateTensor& create_tensor, Span<ValueRef> inputs) {
    auto args = create_tensor.parse(inputs);
    if (!args.device) {
        // implies H2D
        mgb_assert(args.host, "neither host and device value is valid");
        return {m_value_type.make(share_handle(
                m_channel->put(*args.host, args.kind == CreateTensor::Unique)))};
    } else {
        return {m_value_type.make(share_handle(m_channel->put(
                *args.device, args.host ? *args.host : HostTensorND())))};
    }
}

// Entry point of the interpreter transformation: dispatches `op` by its
// concrete operator type.
//
//  * ApplyOp      - FastpathCopy returns inputs[0] unchanged; every other op
//                   runs on the channel via apply_op().
//  * GetAttr      - handled by apply_get_attr().
//  * CreateTensor - handled by apply_create_tensor().
//  * DTRCommand   - Drop calls m_channel->drop on inputs[0]'s handle and
//                   returns no values; any other kind throws AssertionError.
//  * RenameValue  - returns a new interpreter value built from inputs[0]'s
//                   handle and the requested name.
//  * GetName      - returns inputs[0]'s name as a StringValue, or a single
//                   empty ValueRef when the name is empty.
//  * DupTensor    - copies inputs[0]'s device tensor into a fresh
//                   DeviceTensorND (DeviceTensorND::copy_from) and puts it on
//                   the channel as a new value.
//  * anything else is forwarded to op.fallback(inputs).
ValueRefList InterpreterTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto* op_val = op.as<ApplyOp>()) {
        if (op_val->op().same_type<FastpathCopy>()) {
            return inputs[0];
        } else {
            return apply_op(*op_val, inputs);
        }
    } else if (auto* get_attr = op.as<GetAttr>()) {
        return apply_get_attr(*get_attr, inputs);
    } else if (auto* create_tensor = op.as<CreateTensor>()) {
        return apply_create_tensor(*create_tensor, inputs);
    } else if (auto* dtr_command = op.as<DTRCommand>()) {
        auto handle = inputs[0].cast(m_value_type).handle()->handle();
        switch (dtr_command->kind()) {
            case DTRCommand::Drop:
                m_channel->drop(handle);
                break;
            default:
                // `%d` consumes an int; cast explicitly so the vararg matches
                // the format even if Kind is a scoped enum (which is not
                // promoted when passed through `...`).
                mgb_throw(
                        AssertionError, "unknown DTRCommand %d",
                        static_cast<int>(dtr_command->kind()));
        }
        return {};
    } else if (auto* rename_value = op.as<RenameValue>()) {
        auto& input = inputs[0].cast(m_value_type);
        return {m_value_type.make(input.handle(), rename_value->name())};
    } else if (op.is<GetName>()) {
        auto name = inputs[0].cast(m_value_type).name();
        if (!name.empty()) {
            return {StringValue::make(name)};
        } else {
            return {ValueRef()};
        }
    } else if (op.is<DupTensor>()) {
        auto& input = inputs[0].cast(m_value_type);
        DeviceTensorND dev_tensor;
        dev_tensor.copy_from(m_channel->get_dev_tensor(input.handle()->handle()));
        return m_value_type.make(share_handle(m_channel->put(dev_tensor, {})));
    } else {
        return op.fallback(inputs);
    }
}

}  // namespace imperative
}  // namespace mgb