/**
 * \file imperative/src/impl/proxy_graph_detail.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./proxy_graph.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/ops/autogen.h"

namespace mgb {
namespace imperative {
namespace proxy_graph_detail {

namespace {

/*!
 * Convert a list of tensor handles into the raw-pointer list consumed by
 * ProxyGraph.
 *
 * \param inputs tensors to convert; every element must be non-null
 *        (enforced with mgb_assert).
 * \param ensure_storage when true, call blob()->storage() on each tensor so
 *        that any deferred (lazy) allocation happens before the pointers are
 *        handed out.
 * \return pointers in the same order as \p inputs. They are obtained via
 *         get() and are non-owning; \p inputs must outlive their use.
 */
SmallVector<Tensor*> to_raw_ptr_array(
        const SmallVector<TensorPtr>& inputs,
        bool ensure_storage=true) {
    SmallVector<Tensor*> ret;
    // final size is known up front; avoid repeated regrowth while filling
    ret.reserve(inputs.size());
    for (auto&& i : inputs) {
        mgb_assert(i);
        ret.push_back(i.get());
        if (ensure_storage) {
            // apply lazy allocation
            i->blob()->storage();
        }
    }
    return ret;
}

/*!
 * Infer output tensor descriptors of \p def on the default proxy graph.
 * Input storage is materialized first (to_raw_ptr_array's default).
 */
SmallVector<LogicalTensorDesc>
infer_output_attrs(const OpDef& def,
        const SmallVector<TensorPtr>& inputs) {
    auto&& graph = ProxyGraph::get_default_graph();
    return graph->infer_output_attrs(def, to_raw_ptr_array(inputs));
}

} // anonymous namespace

44
void exec(const OpDef& def,
45
        const SmallVector<TensorPtr>& inputs,
46 47
        const SmallVector<TensorPtr>& outputs,
        const SmallVector<TensorPtr>& workspaces) {
48
    auto&& graph = ProxyGraph::get_default_graph();
49
    auto raw_inputs = to_raw_ptr_array(inputs),
50 51
         raw_outputs = to_raw_ptr_array(outputs),
         raw_workspaces = to_raw_ptr_array(workspaces);
52
    CompNode::UnorderedSet used_cns;
53
    for (auto&& out: raw_outputs) {
54 55 56 57 58 59 60 61 62 63
        auto cn = out->comp_node();
        if (used_cns.insert(cn).second) {
            for (auto&& in: inputs) {
                if (in->comp_node() != cn) {
                    auto&& e = in->get_or_create_event();
                    e->device_wait_by(cn);
                }
            }
        }
    }
64
    graph->invoke_op(def, raw_inputs, raw_outputs, raw_workspaces);
65 66 67 68 69 70 71 72 73
    for (auto&& cn: used_cns) {
        for (auto&& in: inputs) {
            if (in->comp_node() != cn) {
                in->add_release_callback(cn);
            }
        }
    }
}

74 75
SmallVector<TensorPtr>
apply_on_physical_tensor(const OpDef& def,
76 77 78 79
        SmallVector<TensorPtr> inputs) {
    auto output_descs = infer_output_attrs(def, inputs);
    SmallVector<TensorPtr> outputs(output_descs.size(), {});
    for (size_t i = 0; i < outputs.size(); i++) {
80
        outputs[i] = Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
81
    }
82
    exec(def, inputs, outputs, {});
83 84 85 86
    auto async_error = ProxyGraph::get_async_error();
    if (async_error) {
        throw *async_error;
    }
87 88
    return outputs;
}
89

90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
    const OpDef& def,
    const SmallVector<TensorPtr>& inputs_tensors,
    const SmallVector<MemoryDesc>& inputs_mems) {
    auto&& graph = ProxyGraph::get_default_graph();
    return graph->infer_output_mem_desc(def, to_raw_ptr_array(inputs_tensors), inputs_mems);
}

/*!
 * Execute \p def on caller-provided outputs and workspace, then rethrow
 * any asynchronous error reported by ProxyGraph::get_async_error().
 */
void execute(const OpDef& def,
        SmallVector<TensorPtr> inputs,
        SmallVector<TensorPtr> outputs,
        SmallVector<TensorPtr> workspace) {
    exec(def, inputs, outputs, workspace);
    if (auto async_error = ProxyGraph::get_async_error()) {
        throw *async_error;
    }
}

// Disabled (commented-out) implementation of infer_output_attrs_fallible:
// std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
//         const SmallVector<LogicalTensorDesc>& inputs) {
//     auto&& graph = ProxyGraph::get_default_graph();
//     return graph->infer_output_attrs_fallible(def, inputs);
// }

namespace {

size_t get_backward_graph_hash_key(const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    XXHash state;
    size_t length = 0, data[3 + 2 * inputs.size()];
    data[length ++] = def.hash();
    for (auto &&i : inputs) {
        data[length ++] = mgb::hash(i.layout.dtype.handle());
        data[length ++] = mgb::hash(i.comp_node);
    }
    data[length ++] = mgb::hash(input_requires_grad);
    data[length ++] = mgb::hash(output_has_grad);
    mgb_assert(length == 3 + 2 * inputs.size());
    state.update(data, length * sizeof(size_t));
    return state.digest();
}

136
struct BackwardGraphCache : std::unordered_map<size_t, EncodedSubraph>, CompNodeDepedentObject {
137 138 139 140 141 142 143 144
    std::shared_ptr<void> on_comp_node_finalize() override {
        clear();
        return {};
    }
} backward_graph_cache;

} // anonymous namespace

145
EncodedSubraph
146 147 148 149 150 151 152 153 154
make_backward_graph(const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    auto hash_key = get_backward_graph_hash_key(def, inputs, input_requires_grad, output_has_grad);
    auto&& iter = backward_graph_cache.find(hash_key);
    if (iter != backward_graph_cache.end()) {
        return iter->second;
    }
155
    auto&& graph = ProxyGraph::get_default_graph();
156 157 158 159 160 161 162 163 164
    auto res = graph->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
    backward_graph_cache.emplace(hash_key, res);
    return res;
}

} // namespace proxy_graph_detail
} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}