commands.h 4.7 KB
Newer Older
1 2 3
#pragma once

#include <string>
4
#include <unordered_set>
M
Megvii Engine Team 已提交
5
#include <variant>
6

7
#include "megbrain/imperative/backtrace.h"
8 9
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/utils/to_string.h"
M
Megvii Engine Team 已提交
10
#include "megbrain/tensor.h"
11

12
#include "./stack_manager.h"
13 14
#include "./tensor_info.h"

15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
namespace mgb::imperative {

namespace interpreter::intl {

struct TensorInfo;
class InterpreterProfiler;

// Command carrying a host-side tensor value for the tensor `dest`.
struct Put {
    // Target tensor. Not freed by this struct. Defaults to nullptr so a
    // default-constructed command never holds an indeterminate pointer.
    TensorInfo* dest = nullptr;
    // Host-side tensor payload.
    HostTensorND value;
    // NOTE(review): presumably tells the consumer not to cache `value` — confirm.
    bool no_cache = false;

    // Reports the properties used for the textual form of the command.
    // `value` is intentionally not reported (the call has been disabled).
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("dest", dest);
        functor("no_cache", no_cache);
        // functor("value", value);
    }

    const char* get_name() const { return "Put"; }
};

struct ApplyOp {
M
Megvii Engine Team 已提交
38
    uint64_t id;  // used by profiler to identify unique apply
39 40 41
    std::shared_ptr<OpDef> op;
    SmallVector<TensorInfo*> inputs;
    SmallVector<TensorInfo*> outputs;
42
    bool validated = false;
43
    BackTraceInfoPtr bt = nullptr;
44 45 46 47 48 49 50 51

    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("op", op);
        functor("inputs", inputs);
        functor("outputs", outputs);
    }

M
Megvii Engine Team 已提交
52
    const char* get_name() const { return "ApplyOp"; }
53 54 55 56 57 58 59 60 61 62
};

// Command targeting the single tensor `dest` (presumably releasing it — confirm
// at the consumer).
struct Del {
    // Target tensor. Defaults to nullptr so a default-constructed command never
    // holds an indeterminate pointer.
    TensorInfo* dest = nullptr;

    // Reports the single `dest` property.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("dest", dest);
    }

    const char* get_name() const { return "Del"; }
};

// Command targeting the single tensor `dest` (presumably requesting its host
// value — confirm at the consumer).
struct GetValue {
    // Target tensor. Defaults to nullptr so a default-constructed command never
    // holds an indeterminate pointer.
    TensorInfo* dest = nullptr;

    // Reports the single `dest` property.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("dest", dest);
    }

    const char* get_name() const { return "GetValue"; }
};

// Command targeting the single tensor `dest` (presumably dropping its storage —
// confirm at the consumer).
struct Drop {
    // Target tensor. Defaults to nullptr so a default-constructed command never
    // holds an indeterminate pointer.
    TensorInfo* dest = nullptr;

    // Reports the single `dest` property.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("dest", dest);
    }

    const char* get_name() const { return "Drop"; }
};

// Command carrying an option `key` / `value` pair.
struct SetOption {
    std::string key;
    // Defaults to 0 so a default-constructed command never carries an
    // indeterminate value.
    size_t value = 0;

    // Reports `key` and then `value`.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("key", key);
        functor("value", value);
    }

    const char* get_name() const { return "SetOption"; }
};

struct StartProfile {
102
    std::unordered_set<TensorInfo*> capture_tensors;
103 104 105 106

    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {}

M
Megvii Engine Team 已提交
107
    const char* get_name() const { return "StartProfile"; }
108 109 110
};

struct StopProfile {
111
    std::unordered_set<TensorInfo*> escape_tensors;
112 113

    template <typename TFunctor>
114
    void get_props(TFunctor&& functor) const {}
115

M
Megvii Engine Team 已提交
116
    const char* get_name() const { return "StopProfile"; }
117 118
};

119 120 121 122 123 124 125
// Payload-free command. It exposes only its name and reports no properties.
struct StopStep {
    const char* get_name() const {
        return "StopStep";
    }

    // Intentionally reports no properties.
    template <typename TFunctor>
    void get_props(TFunctor&& /*functor*/) const {}
};

126 127 128 129 130 131 132 133
// Command carrying the name of a scope to push.
struct PushScope {
    std::string scope_name;

    const char* get_name() const {
        return "PushScope";
    }

    // Reports the single `scope_name` property.
    template <typename TFunctor>
    void get_props(TFunctor&& emit) const {
        emit("scope_name", scope_name);
    }
};

// Command carrying the name of a scope to pop.
struct PopScope {
    std::string scope_name;

    const char* get_name() const {
        return "PopScope";
    }

    // Reports the single `scope_name` property.
    template <typename TFunctor>
    void get_props(TFunctor&& emit) const {
        emit("scope_name", scope_name);
    }
};

148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
// Command targeting the single tensor `dest` (presumably marking the start of
// its regeneration — confirm at the consumer).
struct StartRegen {
    // Target tensor. Defaults to nullptr so a default-constructed command never
    // holds an indeterminate pointer.
    TensorInfo* dest = nullptr;

    // Reports the single `dest` property.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("dest", dest);
    }

    const char* get_name() const { return "StartRegen"; }
};

// Command targeting the single tensor `dest` (presumably marking the end of
// its regeneration — confirm at the consumer).
struct StopRegen {
    // Target tensor. Defaults to nullptr so a default-constructed command never
    // holds an indeterminate pointer.
    TensorInfo* dest = nullptr;

    // Reports the single `dest` property.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("dest", dest);
    }

    const char* get_name() const { return "StopRegen"; }
};

M
Megvii Engine Team 已提交
170
using CommandData = std::variant<
171
        Put, ApplyOp, Del, GetValue, Drop, SetOption, StartProfile, StopProfile,
172
        StopStep, PushScope, PopScope, StartRegen, StopRegen>;
173

174 175 176 177 178 179
struct Command {
    uint64_t id;
    CommandData data;
    StackManager::Trace trace;
};
// using IdentifiedCommand = std::pair<uint64_t, Command>;
180

M
Megvii Engine Team 已提交
181
}  // namespace interpreter::intl
182 183

template <>
M
Megvii Engine Team 已提交
184
struct ToStringTrait<interpreter::intl::Command> {
185
    std::string operator()(const interpreter::intl::Command& cmd) const {
M
Megvii Engine Team 已提交
186 187 188 189 190 191 192 193 194 195 196 197 198 199
        std::string content = std::visit(
                [](const auto& cmd) {
                    std::string result = cmd.get_name();
                    result += "{";
                    cmd.get_props([&](const char* key, auto&& value) {
                        result += key;
                        result += ": ";
                        result += to_string(value);
                        result += ",";
                    });
                    result += "}";
                    return result;
                },
                cmd.data);
200
        return content;
201 202 203
    }
};

M
Megvii Engine Team 已提交
204
}  // namespace mgb::imperative