/**
 * \file src/core/include/megbrain/graph/bases.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/utils/json.h"
#include "megbrain/utils/metahelper.h"
#include "megbrain/exception.h"
#include "megbrain/comp_node.h"

#include <string>

// Whether DTR support is compiled in. By default it is enabled only when the
// build is not slim-serving and has thread support; define MGB_ENABLE_DTR
// before including this header to override the default.
#ifndef MGB_ENABLE_DTR
#define MGB_ENABLE_DTR ((!MGB_BUILD_SLIM_SERVING) && (!!MGB_HAVE_THREAD))
#endif  //  MGB_ENABLE_DTR

// Whether sublinear support is compiled in. By default it is enabled only when
// the build is not slim-serving and has thread support; define
// MGB_ENABLE_SUBLINEAR before including this header to override the default.
#ifndef MGB_ENABLE_SUBLINEAR
#define MGB_ENABLE_SUBLINEAR ((!MGB_BUILD_SLIM_SERVING) && (!!MGB_HAVE_THREAD))
#endif  //  MGB_ENABLE_SUBLINEAR

// FIXME: reopen when rewriting memory swap or existing tests are passed
//
// Memory swap is force-disabled: the unconditional definition below always
// wins, so the #ifndef fallback after it is currently dead code. That fallback
// only records the default to restore once the FIXME is resolved. A
// pre-existing definition with a different value is overridden here, and
// compilers diagnose the redefinition.
#define MGB_ENABLE_MEMORY_SWAP 0
#ifndef MGB_ENABLE_MEMORY_SWAP
#define MGB_ENABLE_MEMORY_SWAP \
    ((!MGB_BUILD_SLIM_SERVING) && (!!MGB_HAVE_THREAD) && (MGB_CUDA))
#endif  //  MGB_ENABLE_MEMORY_SWAP
// Whether partial execution is compiled in. By default it is enabled unless
// this is a slim-serving build; define MGB_ENABLE_PARTIAL_EXECUTION before
// including this header to override the default.
#ifndef MGB_ENABLE_PARTIAL_EXECUTION
#define MGB_ENABLE_PARTIAL_EXECUTION (!MGB_BUILD_SLIM_SERVING)
#endif  //  MGB_ENABLE_PARTIAL_EXECUTION

// Whether conditional execution is compiled in. By default it is enabled unless
// this is a slim-serving build; define MGB_ENABLE_COND_EXEC before including
// this header to override the default. The expansion is parenthesized so it
// composes safely inside larger preprocessor expressions.
#ifndef MGB_ENABLE_COND_EXEC
#define MGB_ENABLE_COND_EXEC (!MGB_BUILD_SLIM_SERVING)
#endif  //  MGB_ENABLE_COND_EXEC

// MGB_IF_COND_EXEC(...) expands to its arguments (commas included) when
// conditional execution is enabled, and to nothing otherwise. It uses the
// standard __VA_ARGS__ form rather than the GNU-only named variadic `x...`
// extension, so any conforming preprocessor accepts it.
#if MGB_ENABLE_COND_EXEC
#define MGB_IF_COND_EXEC(...) __VA_ARGS__
#else
#define MGB_IF_COND_EXEC(...)
#endif

// Whether the var device-memory defragmenter is compiled in: 1 only when both
// MGB_CUDA and MGB_ENABLE_EXCEPTION are true, otherwise 0. Unlike the flags
// above, this one cannot be overridden from outside.
#if MGB_CUDA && MGB_ENABLE_EXCEPTION
#define MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER 1
#else
#define MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER 0
#endif  // MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER

55 56
namespace mgb {

57 58 59 60 61 62 63 64 65
class GraphError : public MegBrainError {
public:
    using MegBrainError::MegBrainError;
};

}  // namespace mgb

namespace mgb {

//! computing graph
namespace cg {

namespace static_infer {
    struct DepElement;
};

73
using GraphError = mgb::GraphError;
74
class VarNode;
75 76
class OperatorNodeBase;
class ComputingGraph;
77
using VarNodeArray = mgb::SmallVector<VarNode*>;
78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
/*!
 * \brief Base class for a node in the graph.
 *
 * Each node must have a name for debugging and graph dump, and each node is
 * uniquely identified by its memory address. Every node in a computing graph
 * has its unique numerical ID.
 */
class GraphNodeBase: public json::Serializable, public NonCopyableObj {
    ComputingGraph* const m_owner_graph;
    size_t m_id;

    protected:
        ~GraphNodeBase() = default;

    public:
        GraphNodeBase(ComputingGraph *owner_graph);

        ComputingGraph* owner_graph() const {
            return m_owner_graph;
        }

        //! get node ID as string
        std::string id_str() const {
            return std::to_string(m_id);
        }

        //! get node ID as number
        size_t id() const {
            return m_id;
        }
};

110 111 112 113 114 115 116 117 118 119 120
class OutputVarsUserData final : public mgb::UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

private:
    VarNodeArray m_output_vars;

public:
    void set_output_vars(VarNodeArray vars) { m_output_vars = std::move(vars); }
    const VarNodeArray& get_output_vars() const { return m_output_vars; }
};

121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
/*!
 * \brief an object that executes asynchronously
 */
class AsyncExecutable : public json::Serializable,
                        public CompNodeDepedentObject {
    UserDataContainer m_user_data;

    public:
        virtual ~AsyncExecutable() noexcept;

        virtual AsyncExecutable& execute() = 0;

        /*!
         * \brief wait for current task to finish
         */
        virtual AsyncExecutable& wait() = 0;

        /*!
         * \brief previous execution time in seconds
         */
        virtual double get_prev_exec_time() const = 0;

        /*!
         * \brief iterate over operator sequence
         * \param cb callback function, return false to stop iterating
         */
        virtual AsyncExecutable& iter_opr_seq(
                thin_function<bool(OperatorNodeBase*)> cb) = 0;

        /*!
         * \brief get RT_STATIC deps needed for static infer in this func
         */
        virtual const SmallVector<static_infer::DepElement>&
            get_rt_static_source_deps() = 0;

        /*!
         * \brief number of calls to execute()
         */
        virtual size_t get_run_id() const = 0;

        /*!
         * \brief update static memory allocation plan and allocation size
         *
         * Note: as a side effect, static shape inference would be executed and
         * var shapes are updated.
         *
         * \return static allocation size for each comp node
         */
        virtual const CompNode::UnorderedMap<size_t>&
        update_static_alloc_plan_and_get_size() = 0;

        /*!
         * \brief clear device memory; memory would be allocated in the next run
         */
        virtual void clear_device_memory() = 0;

        //! get the graph that owns this executable; nullptr if no owner graph
        virtual ComputingGraph* owner_graph() const = 0;

        //! user data associated with a compiled executable
        UserDataContainer& user_data() {
            return m_user_data;
        }
184 185 186 187 188 189 190 191 192 193 194 195 196

        void set_output_vars(const VarNodeArray& vars) {
            std::shared_ptr<OutputVarsUserData> ud =
                    std::make_shared<OutputVarsUserData>();
            ud->set_output_vars(vars);
            m_user_data.add_user_data(ud);
        }

        const VarNodeArray& get_output_vars() const {
            auto output_vars_pair =
                    m_user_data.get_user_data<OutputVarsUserData>();
            return (*(output_vars_pair.first))->get_output_vars();
        }
197
#ifndef __IN_TEE_ENV__
198 199
        virtual void get_static_memory_alloc_info(
                const std::string& svg_name) const {
200 201 202
            mgb_assert(svg_name.length() < 0,
                       "can't call this function directly\n");
        }
203
#endif
204 205 206 207 208 209 210 211
};


} // namespace cg
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}