/**
 * \file imperative/src/impl/proxy_graph_detail.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./proxy_graph.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/ops/autogen.h"

namespace mgb {
namespace imperative {
namespace proxy_graph_detail {

namespace {
/*!
 * \brief collect the raw Tensor pointers held by \p inputs
 *
 * \param inputs tensors to convert; every entry must be non-null
 *      (checked with mgb_assert)
 * \param ensure_storage when true, call storage() on each tensor's blob so
 *      that lazily-allocated memory is materialized before the pointer is used
 * \return pointers obtained via get(), in the same order as \p inputs
 */
SmallVector<Tensor*> to_raw_ptr_array(
        const SmallVector<TensorPtr>& inputs, bool ensure_storage = true) {
    SmallVector<Tensor*> raw_ptrs;
    for (const auto& tensor : inputs) {
        mgb_assert(tensor);
        raw_ptrs.push_back(tensor.get());
        if (!ensure_storage) {
            continue;
        }
        // touching the storage triggers the blob's lazy allocation
        tensor->blob()->storage();
    }
    return raw_ptrs;
}

/*!
 * \brief ask the default proxy graph for the output descriptors of \p def
 *
 * The input tensors are passed through to_raw_ptr_array() with its default
 * arguments, so their storage is materialized before inference runs.
 *
 * \param def operator whose outputs are inferred
 * \param inputs concrete input tensors
 * \return one LogicalTensorDesc per output (layout + comp node)
 */
SmallVector<LogicalTensorDesc> infer_output_attrs(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    auto&& proxy_graph = ProxyGraph::get_default_graph();
    return proxy_graph->infer_output_attrs(
            def, to_raw_ptr_array(inputs));
}
} // anonymous namespace

void exec(const OpDef& def,
45
        const SmallVector<TensorPtr>& inputs,
46 47
        const SmallVector<TensorPtr>& outputs,
        const SmallVector<TensorPtr>& workspaces) {
48
    auto&& graph = ProxyGraph::get_default_graph();
49
    auto raw_inputs = to_raw_ptr_array(inputs),
50 51
         raw_outputs = to_raw_ptr_array(outputs),
         raw_workspaces = to_raw_ptr_array(workspaces);
52
    CompNode::UnorderedSet used_cns;
53
    for (auto&& out: raw_outputs) {
54 55 56 57 58 59 60 61 62 63
        auto cn = out->comp_node();
        if (used_cns.insert(cn).second) {
            for (auto&& in: inputs) {
                if (in->comp_node() != cn) {
                    auto&& e = in->get_or_create_event();
                    e->device_wait_by(cn);
                }
            }
        }
    }
64
    graph->invoke_op(def, raw_inputs, raw_outputs, raw_workspaces);
65 66 67 68 69 70 71 72 73
    for (auto&& cn: used_cns) {
        for (auto&& in: inputs) {
            if (in->comp_node() != cn) {
                in->add_release_callback(cn);
            }
        }
    }
}

/*!
 * \brief allocate outputs for \p def from its inferred attributes and run it
 *
 * \param def operator to apply
 * \param inputs input tensors
 * \return freshly created output tensors, one per inferred output descriptor
 * \throws a copy of the error returned by ProxyGraph::get_async_error(), if
 *      one is pending after execution
 */
SmallVector<TensorPtr>
apply_on_physical_tensor(const OpDef& def,
        SmallVector<TensorPtr> inputs) {
    auto output_descs = infer_output_attrs(def, inputs);
    SmallVector<TensorPtr> outputs;
    outputs.reserve(output_descs.size());
    for (auto&& desc : output_descs) {
        outputs.push_back(Tensor::make(desc.layout, desc.comp_node));
    }
    exec(def, inputs, outputs, {});
    // NOTE(review): `throw *async_error` copy-initializes the exception from
    // the pointee's static type; if the stored error can be a derived class,
    // its dynamic type is lost here — confirm against get_async_error().
    if (auto async_error = ProxyGraph::get_async_error()) {
        throw *async_error;
    }
    return outputs;
}
/*!
 * \brief forward memory-descriptor inference for \p def to the default proxy
 *      graph
 *
 * \param def operator whose memory descriptors are inferred
 * \param inputs_tensors input tensors; their storage is materialized by
 *      to_raw_ptr_array() before the call
 * \param inputs_mems memory descriptors of the inputs
 * \return a pair of MemoryDesc lists as produced by
 *      ProxyGraph::infer_output_mem_desc (presumably outputs and workspaces —
 *      TODO confirm)
 */
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    auto&& proxy_graph = ProxyGraph::get_default_graph();
    return proxy_graph->infer_output_mem_desc(
            def, to_raw_ptr_array(inputs_tensors), inputs_mems);
}

/*!
 * \brief execute \p def on caller-provided tensors, then surface any pending
 *      asynchronous error from the proxy graph
 *
 * \param def operator to execute
 * \param inputs input tensors
 * \param outputs output tensors the op is invoked with
 * \param workspace workspace tensors handed to the op
 * \throws a copy of the error returned by ProxyGraph::get_async_error(), if any
 */
void execute(const OpDef& def,
        SmallVector<TensorPtr> inputs,
        SmallVector<TensorPtr> outputs,
        SmallVector<TensorPtr> workspace) {
    exec(def, inputs, outputs, workspace);
    if (auto pending_error = ProxyGraph::get_async_error()) {
        throw *pending_error;
    }
}

// std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
//         const SmallVector<LogicalTensorDesc>& inputs) {
//     auto&& graph = ProxyGraph::get_default_graph();
//     return graph->infer_output_attrs_fallible(def, inputs);
// }

/*!
 * \brief build the backward graph of \p def using the default proxy graph
 *
 * \param def forward operator
 * \param inputs descriptors of the forward inputs
 * \param input_requires_grad per-input flags, presumably whether a gradient
 *      is requested for that input (TODO confirm)
 * \param output_has_grad per-output flags, presumably whether an incoming
 *      gradient exists for that output (TODO confirm)
 * \return the EncodedSubgraph produced by ProxyGraph::make_backward_graph
 */
EncodedSubgraph
make_backward_graph(const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    return ProxyGraph::get_default_graph()->make_backward_graph(
            def, inputs, input_requires_grad, output_has_grad);
}

} // namespace proxy_graph_detail
} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}