#pragma once

#include <string>
#include <unordered_set>
#include <variant>

#include "./stack_manager.h"
#include "./tensor_info.h"
#include "megbrain/imperative/backtrace.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/utils/to_string.h"
#include "megbrain/tensor.h"

namespace mgb::imperative {

namespace interpreter::intl {

// Forward declarations. TensorInfo is referenced through pointers by the
// command payloads declared below; InterpreterProfiler is only declared here.
class InterpreterProfiler;
struct TensorInfo;

// Command payload that pairs a target tensor with a host-side value.
// get_props() reports the fields that ToStringTrait<Command> renders.
struct Put {
    // Default to null so a default-initialized Put never carries an
    // indeterminate pointer.
    TensorInfo* dest = nullptr;
    HostTensorND value;
    bool no_cache = false;

    // Invokes functor(name, field) for each reported property.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("dest", dest);
        functor("no_cache", no_cache);
        // `value` is intentionally not reported.
        // NOTE(review): presumably HostTensorND is too large/costly to
        // stringify here — confirm before re-enabling.
        // functor("value", value);
    }

    const char* get_name() const { return "Put"; }
};

struct ApplyOp {
M
Megvii Engine Team 已提交
38
    uint64_t id;  // used by profiler to identify unique apply
39 40 41
    std::shared_ptr<OpDef> op;
    SmallVector<TensorInfo*> inputs;
    SmallVector<TensorInfo*> outputs;
42
    bool validated = false;
43
    BackTraceInfoPtr bt = nullptr;
44 45 46 47 48 49 50 51

    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("op", op);
        functor("inputs", inputs);
        functor("outputs", outputs);
    }

M
Megvii Engine Team 已提交
52
    const char* get_name() const { return "ApplyOp"; }
53 54 55 56 57 58 59 60 61 62
};

// Command payload carrying a single target tensor, reported as "dest".
struct Del {
    // Null by default instead of an indeterminate pointer.
    TensorInfo* dest = nullptr;

    // Invokes functor(name, field) for each reported property.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("dest", dest);
    }

    const char* get_name() const { return "Del"; }
};

// Command payload carrying a single target tensor, reported as "dest".
struct GetValue {
    // Null by default instead of an indeterminate pointer.
    TensorInfo* dest = nullptr;

    // Invokes functor(name, field) for each reported property.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("dest", dest);
    }

    const char* get_name() const { return "GetValue"; }
};

// Command payload carrying a single target tensor, reported as "dest".
struct Drop {
    // Null by default instead of an indeterminate pointer.
    TensorInfo* dest = nullptr;

    // Invokes functor(name, field) for each reported property.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("dest", dest);
    }

    const char* get_name() const { return "Drop"; }
};

// Command payload carrying a string key and a numeric value; both are
// reported by get_props().
struct SetOption {
    std::string key;
    // Zero by default so a default-initialized SetOption is fully defined.
    size_t value = 0;

    // Invokes functor(name, field) for each reported property.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("key", key);
        functor("value", value);
    }

    const char* get_name() const { return "SetOption"; }
};

struct StartProfile {
102
    std::unordered_set<TensorInfo*> capture_tensors;
103 104 105 106

    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {}

M
Megvii Engine Team 已提交
107
    const char* get_name() const { return "StartProfile"; }
108 109 110
};

struct StopProfile {
111
    std::unordered_set<TensorInfo*> escape_tensors;
112 113

    template <typename TFunctor>
114
    void get_props(TFunctor&& functor) const {}
115

M
Megvii Engine Team 已提交
116
    const char* get_name() const { return "StopProfile"; }
117 118
};

// Field-less command payload; get_props() reports nothing.
struct StopStep {
    // The functor parameter is unnamed because no properties are reported.
    template <typename TFunctor>
    void get_props(TFunctor&&) const {}

    const char* get_name() const {
        return "StopStep";
    }
};

struct PushScope {
    std::string scope_name;
128
    ScopeType type;
129 130 131 132 133 134

    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("scope_name", scope_name);
    }

M
Megvii Engine Team 已提交
135
    const char* get_name() const { return "PushScope"; }
136 137 138 139
};

struct PopScope {
    std::string scope_name;
140
    ScopeType type;
141 142 143 144 145
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("scope_name", scope_name);
    }

M
Megvii Engine Team 已提交
146
    const char* get_name() const { return "PopScope"; }
147 148
};

// Command payload carrying a single target tensor, reported as "dest".
struct StartRegen {
    // Null by default instead of an indeterminate pointer.
    TensorInfo* dest = nullptr;

    // Invokes functor(name, field) for each reported property.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("dest", dest);
    }

    const char* get_name() const { return "StartRegen"; }
};

// Command payload carrying a single target tensor, reported as "dest".
struct StopRegen {
    // Null by default instead of an indeterminate pointer.
    TensorInfo* dest = nullptr;

    // Invokes functor(name, field) for each reported property.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("dest", dest);
    }

    const char* get_name() const { return "StopRegen"; }
};

using CommandData = std::variant<
172
        Put, ApplyOp, Del, GetValue, Drop, SetOption, StartProfile, StopProfile,
173
        StopStep, PushScope, PopScope, StartRegen, StopRegen>;

// An identified command: a numeric id, the payload variant, and the
// associated StackManager::Trace.
struct Command {
    // Zero by default so a default-initialized Command has a defined id.
    uint64_t id = 0;
    CommandData data;
    StackManager::Trace trace;
};

}  // namespace interpreter::intl

template <>
M
Megvii Engine Team 已提交
185
struct ToStringTrait<interpreter::intl::Command> {
186
    std::string operator()(const interpreter::intl::Command& cmd) const {
M
Megvii Engine Team 已提交
187 188 189 190 191 192 193 194 195 196 197 198 199 200
        std::string content = std::visit(
                [](const auto& cmd) {
                    std::string result = cmd.get_name();
                    result += "{";
                    cmd.get_props([&](const char* key, auto&& value) {
                        result += key;
                        result += ": ";
                        result += to_string(value);
                        result += ",";
                    });
                    result += "}";
                    return result;
                },
                cmd.data);
201
        return content;
202 203 204
    }
};

}  // namespace mgb::imperative