/**
 * \file imperative/src/impl/proxy_graph.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/imperative.h"
#include "megbrain/graph/cg.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/comp_node.h"

#include "megbrain/imperative/ops/backward_graph.h"

namespace mgb {
namespace imperative {

class ProxyGraph : public NonCopyableObj {
public:
    static ProxyGraph* get_default_graph();

    /********************** Physical Tensor API **********************/

    SmallVector<LogicalTensorDesc> infer_output_attrs(
            const OpDef& opdef,
            const SmallVector<Tensor*>& inputs);

    void invoke_op(
            const OpDef& opdef,
            const SmallVector<Tensor*>& inputs,
            const SmallVector<Tensor*>& outputs);

    BackwardGraphResult make_backward_graph(
            const OpDef& opdef,
            const SmallVector<LogicalTensorDesc>& input_descs,
            const SmallVector<bool>& input_requires_grad,
            const SmallVector<bool>& output_has_grad);

    /********************** Logical Tensor API **********************/

    size_t get_opr_output_size(
            const OpDef& opdef,
            const SmallVector<LogicalTensorDesc>& inputs);

51
    std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
            const OpDef& opdef,
            const SmallVector<LogicalTensorDesc>& inputs);

private:
    ProxyGraph();

    class ProxyGraphImpl;
    class ExecEnv;
    class StaticInferManager;
    class SeqCompNodeOptimizer;
    class InputPlaceholder;
    struct ProxyGraphInst;
    struct GradGraph;
    struct CurOprGuard;

    void reset();

    /********************** Physical Tensor Helper **********************/

    void cleanup();

    void init_output_tensor(
            const SmallVector<Tensor*>& outputs);

    cg::OperatorNodeBase* get_proxy_opr(
            const OpDef& opdef,
            const SmallVector<Tensor*>& inputs);

    /********************** Logical Tensor Helper **********************/

    cg::OperatorNodeBase* get_proxy_opr(
            const OpDef& opdef,
            const SmallVector<LogicalTensorDesc>& inputs);

    cg::VarNodeArray make_input_place_holders(
            const SmallVector<LogicalTensorDesc>& inputs);

    /********************** Common Helper **********************/

91
    bool do_shape_infer(bool sync_value);
92 93 94 95 96

    TensorPtr as_tensor(cg::OperatorNodeBase* opr, bool share=true);

    cg::OperatorNodeBase* m_cur_opr = nullptr;
    std::unique_ptr<ProxyGraphImpl> m_graph;
97
    size_t m_max_op_cnt = 100;
98 99 100 101 102 103 104 105 106
    std::unique_ptr<ExecEnv> m_env;
    std::unique_ptr<StaticInferManager> m_static_infer_manager;
    std::unique_ptr<SeqCompNodeOptimizer> m_seq_comp_node_optimizer;
};

} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}