eval.cpp 4.8 KB
Newer Older
1 2 3 4 5
#include "megbrain/imperative/transformations/eval.h"

#include <utility>

namespace mgb {
namespace imperative {

6
DTypeValue::ref_t InterpreterValue::dtype() const {
7 8 9 10 11 12
    if (!m_dtype) {
        m_dtype = DTypeValue::make(handle()->channel()->get_dtype(handle()->handle()));
    }
    return m_dtype;
}

13
/**
 * Returns the compute node (device) holding the tensor wrapped by this value.
 *
 * The first call asks the interpreter channel for the device of the wrapped
 * handle and caches the result in m_comp_node; subsequent calls reuse it.
 */
CompNodeValue::ref_t InterpreterValue::comp_node() const {
    if (m_comp_node) {
        return m_comp_node;
    }
    const auto& holder = handle();
    m_comp_node =
            CompNodeValue::make(holder->channel()->get_device(holder->handle()));
    return m_comp_node;
}

/**
 * Returns the shape of the tensor wrapped by this value.
 *
 * On first use the shape is fetched from the interpreter channel, converted
 * with ValueShape::from, and stored in m_shape; later calls return the cached
 * ShapeValue directly.
 */
ShapeValue::ref_t InterpreterValue::shape() const {
    if (m_shape) {
        return m_shape;
    }
    const auto& holder = handle();
    m_shape = ShapeValue::make(
            ValueShape::from(holder->channel()->get_shape(holder->handle())));
    return m_shape;
}

/**
 * Executes an ApplyOp on the interpreter channel.
 *
 * Every input is unwrapped to its raw channel handle, the op is dispatched via
 * m_channel->apply_op(), and each returned handle is wrapped (share_handle) into
 * a new interpreter value. The result list has one entry per output handle.
 *
 * Exception safety: the CleanupGuard deletes every output handle whose slot is
 * still non-null when the function exits. A slot is cleared right after its
 * handle has been passed to share_handle() and BEFORE m_value_type.make() runs.
 * NOTE(review): this assumes share_handle() takes ownership of the handle
 * (the original code already stopped tracking a slot once it was wrapped).
 * Under that assumption, a throwing make() can no longer cause the handle to
 * be released both by the destroyed wrapper and by the guard.
 */
ValueRefList InterpreterTransformation::apply_op(
        const ApplyOp& apply_op, Span<ValueRef> inputs) {
    SmallVector<Handle> input_handles;
    SmallVector<Handle> output_handles;
    // Releases output handles that no value has taken ownership of yet
    // (e.g. when an exception escapes below).
    CleanupGuard _{[&] {
        for (auto handle : output_handles) {
            if (handle) {
                m_channel->del(handle);
            }
        }
    }};
    for (auto input : inputs) {
        input_handles.push_back(input.cast(m_value_type).handle()->handle());
    }
    output_handles =
            m_channel->apply_op(apply_op.op().shared_from_this(), input_handles);
    ValueRefList outputs(output_handles.size());
    for (size_t i = 0; i < output_handles.size(); ++i) {
        // Hand the handle to its wrapper, then stop tracking it in the guard
        // before anything else can throw; otherwise a failing make() would
        // release it twice.
        auto shared = share_handle(output_handles[i]);
        output_handles[i] = nullptr;
        outputs[i] = m_value_type.make(std::move(shared));
    }
    // All slots are null at this point, so the guard has nothing to delete.
    return outputs;
}

/**
 * Answers a GetAttr query on a single interpreter-backed input
 * (obtained through inputs.item()).
 *
 * DType, Shape and Device are served by the value's dtype()/shape()/
 * comp_node() accessors. Value and Data query the channel and wrap the result
 * in a HostValue / DeviceValue respectively. Any other attribute raises
 * MegBrainError.
 *
 * @return a one-element list holding the requested attribute.
 */
ValueRefList InterpreterTransformation::apply_get_attr(
        const GetAttr& get_attr, Span<ValueRef> inputs) {
    auto& value = inputs.item().cast(m_value_type);
    // Raw channel handle of the input; only evaluated by the channel queries.
    auto raw_handle = [&value] { return value.handle()->handle(); };
    ValueRef result;
    switch (get_attr.attr()) {
        case GetAttr::DType:
            result = value.dtype();
            break;
        case GetAttr::Shape:
            result = value.shape();
            break;
        case GetAttr::Device:
            result = value.comp_node();
            break;
        case GetAttr::Value:
            result = HostValue::make(m_channel->get_value(raw_handle()));
            break;
        case GetAttr::Data:
            result = DeviceValue::make(m_channel->get_dev_tensor(raw_handle()));
            break;
        default:
            mgb_throw(
                    MegBrainError, "Interpreter: malformed GetAttr: %s",
                    get_attr.to_string().c_str());
    }
    return {result};
}

/**
 * Materializes a CreateTensor request as a new interpreter value.
 *
 * If a device value was supplied, it is put on the channel together with the
 * host value, or an empty HostTensorND when no host value was given.
 * Otherwise the request is a host-to-device upload: a host value must be
 * present, and `kind == CreateTensor::Unique` is forwarded as put()'s flag.
 *
 * @return a one-element list holding the newly created value.
 */
ValueRefList InterpreterTransformation::apply_create_tensor(
        const CreateTensor& create_tensor, Span<ValueRef> inputs) {
    auto args = create_tensor.parse(inputs);
    if (args.device) {
        return {m_value_type.make(share_handle(m_channel->put(
                *args.device, args.host ? *args.host : HostTensorND())))};
    }
    // No device value: this is an H2D upload, so the host value is mandatory.
    mgb_assert(args.host, "neither host and device value is valid");
    const bool is_unique = args.kind == CreateTensor::Unique;
    return {m_value_type.make(share_handle(m_channel->put(*args.host, is_unique)))};
}

/**
 * Dispatches an operator to the matching interpreter handler. Operators this
 * transformation does not handle are passed to op.fallback().
 *
 * - ApplyOp: a FastpathCopy returns inputs[0] unchanged; any other op runs
 *   through apply_op().
 * - GetAttr / CreateTensor: forwarded to apply_get_attr() /
 *   apply_create_tensor().
 * - DTRCommand: only Drop is supported; it drops the input's channel handle
 *   and yields no outputs. Any other kind throws AssertionError.
 * - RenameValue: builds a new value from the input's handle() plus the new
 *   name.
 * - GetName: a StringValue with the input's name, or an empty ValueRef when
 *   the name is empty.
 * - DupTensor: copies the input's device tensor into a fresh DeviceTensorND
 *   and puts that copy on the channel as a new value.
 */
ValueRefList InterpreterTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto* apply_cmd = op.as<ApplyOp>()) {
        if (apply_cmd->op().same_type<FastpathCopy>()) {
            return inputs[0];
        }
        return apply_op(*apply_cmd, inputs);
    }
    if (auto* attr_query = op.as<GetAttr>()) {
        return apply_get_attr(*attr_query, inputs);
    }
    if (auto* create_cmd = op.as<CreateTensor>()) {
        return apply_create_tensor(*create_cmd, inputs);
    }
    if (auto* dtr_cmd = op.as<DTRCommand>()) {
        // Resolve the handle first, exactly as before the kind is inspected.
        auto target = inputs[0].cast(m_value_type).handle()->handle();
        if (dtr_cmd->kind() != DTRCommand::Drop) {
            mgb_throw(AssertionError, "unknown DTRCommand %d", dtr_cmd->kind());
        }
        m_channel->drop(target);
        return {};
    }
    if (auto* rename_cmd = op.as<RenameValue>()) {
        auto& source = inputs[0].cast(m_value_type);
        return {m_value_type.make(source.handle(), rename_cmd->name())};
    }
    if (op.is<GetName>()) {
        auto name = inputs[0].cast(m_value_type).name();
        if (name.empty()) {
            return {ValueRef()};
        }
        return {StringValue::make(name)};
    }
    if (op.is<DupTensor>()) {
        auto& source = inputs[0].cast(m_value_type);
        DeviceTensorND duplicate;
        duplicate.copy_from(m_channel->get_dev_tensor(source.handle()->handle()));
        return m_value_type.make(share_handle(m_channel->put(duplicate, {})));
    }
    return op.fallback(inputs);
}

}  // namespace imperative
}  // namespace mgb