#pragma once

#include <string>
#include <unordered_set>
#include <variant>

#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/utils/to_string.h"
#include "megbrain/tensor.h"

#include "./stack_manager.h"
#include "./tensor_info.h"

namespace mgb::imperative {

namespace interpreter::intl {

// Forward declarations of interpreter types. TensorInfo is referenced below
// only through pointers; InterpreterProfiler is not referenced in this part
// of the file.
class InterpreterProfiler;
struct TensorInfo;

// Command payload named "Put".
//
// Fields:
//   dest     - tensor record this command refers to. Defaults to nullptr so a
//              default-constructed command never holds an indeterminate pointer.
//   value    - HostTensorND payload; not reported by get_props().
//   no_cache - flag reported by get_props(). Its effect is decided by whoever
//              consumes the command (NOTE(review): not visible in this header).
struct Put {
    TensorInfo* dest = nullptr;
    HostTensorND value;
    bool no_cache = false;

    // Calls `functor(name, field)` for each field included when a Command is
    // rendered by ToStringTrait<Command>.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("dest", dest);
        functor("no_cache", no_cache);
        // `value` is deliberately left out of the rendering.
        // NOTE(review): presumably no to_string overload for HostTensorND is
        // available; confirm before enabling the line below.
        // functor("value", value);
    }

    const char* get_name() const { return "Put"; }
};

struct ApplyOp {
M
Megvii Engine Team 已提交
37
    uint64_t id;  // used by profiler to identify unique apply
38 39 40
    std::shared_ptr<OpDef> op;
    SmallVector<TensorInfo*> inputs;
    SmallVector<TensorInfo*> outputs;
41
    bool validated = false;
42 43 44 45 46 47 48 49

    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("op", op);
        functor("inputs", inputs);
        functor("outputs", outputs);
    }

M
Megvii Engine Team 已提交
50
    const char* get_name() const { return "ApplyOp"; }
51 52 53 54 55 56 57 58 59 60
};

// Command payload named "Del".
//
// Fields:
//   dest - tensor record this command refers to. Defaults to nullptr so a
//          default-constructed command never holds an indeterminate pointer.
struct Del {
    TensorInfo* dest = nullptr;

    // Calls `functor(name, field)` for each field included when a Command is
    // rendered by ToStringTrait<Command>.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("dest", dest);
    }

    const char* get_name() const { return "Del"; }
};

// Command payload named "GetValue".
//
// Fields:
//   dest - tensor record this command refers to. Defaults to nullptr so a
//          default-constructed command never holds an indeterminate pointer.
struct GetValue {
    TensorInfo* dest = nullptr;

    // Calls `functor(name, field)` for each field included when a Command is
    // rendered by ToStringTrait<Command>.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("dest", dest);
    }

    const char* get_name() const { return "GetValue"; }
};

// Command payload named "Drop".
//
// Fields:
//   dest - tensor record this command refers to. Defaults to nullptr so a
//          default-constructed command never holds an indeterminate pointer.
struct Drop {
    TensorInfo* dest = nullptr;

    // Calls `functor(name, field)` for each field included when a Command is
    // rendered by ToStringTrait<Command>.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("dest", dest);
    }

    const char* get_name() const { return "Drop"; }
};

// Command payload named "SetOption": a (key, value) option pair.
//
// Fields:
//   key   - option name.
//   value - option value. Defaults to 0 so a default-constructed command does
//           not carry an indeterminate value.
struct SetOption {
    std::string key;
    size_t value = 0;

    // Calls `functor(name, field)` for each field included when a Command is
    // rendered by ToStringTrait<Command>.
    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("key", key);
        functor("value", value);
    }

    const char* get_name() const { return "SetOption"; }
};

struct StartProfile {
100
    std::unordered_set<TensorInfo*> capture_tensors;
101 102 103 104

    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {}

M
Megvii Engine Team 已提交
105
    const char* get_name() const { return "StartProfile"; }
106 107 108
};

struct StopProfile {
109
    std::unordered_set<TensorInfo*> escape_tensors;
110 111

    template <typename TFunctor>
112
    void get_props(TFunctor&& functor) const {}
113

M
Megvii Engine Team 已提交
114
    const char* get_name() const { return "StopProfile"; }
115 116 117 118 119 120 121 122 123 124
};

// Command payload named "PushScope"; carries the name of the scope.
struct PushScope {
    std::string scope_name;

    // Hands `scope_name` to the visitor under the key "scope_name".
    template <typename Fn>
    void get_props(Fn&& emit) const {
        emit("scope_name", scope_name);
    }

    const char* get_name() const {
        return "PushScope";
    }
};

// Command payload named "PopScope"; carries the name of the scope.
struct PopScope {
    std::string scope_name;

    // Hands `scope_name` to the visitor under the key "scope_name".
    template <typename Fn>
    void get_props(Fn&& emit) const {
        emit("scope_name", scope_name);
    }

    const char* get_name() const {
        return "PopScope";
    }
};

using CommandData = std::variant<
140 141
        Put, ApplyOp, Del, GetValue, Drop, SetOption, StartProfile, StopProfile,
        PushScope, PopScope>;
142

// A command: an id, one CommandData payload, and a StackManager trace.
//
// Fields:
//   id    - command id. Defaults to 0 so a default-constructed command does
//           not carry an indeterminate id.
//   data  - the payload (one of the CommandData alternatives).
//   trace - StackManager::Trace captured with the command.
struct Command {
    uint64_t id = 0;
    CommandData data;
    StackManager::Trace trace;
};
// using IdentifiedCommand = std::pair<uint64_t, Command>;

}  // namespace interpreter::intl

template <>
M
Megvii Engine Team 已提交
153
struct ToStringTrait<interpreter::intl::Command> {
154
    std::string operator()(const interpreter::intl::Command& cmd) const {
M
Megvii Engine Team 已提交
155 156 157 158 159 160 161 162 163 164 165 166 167 168
        std::string content = std::visit(
                [](const auto& cmd) {
                    std::string result = cmd.get_name();
                    result += "{";
                    cmd.get_props([&](const char* key, auto&& value) {
                        result += key;
                        result += ": ";
                        result += to_string(value);
                        result += ",";
                    });
                    result += "}";
                    return result;
                },
                cmd.data);
169
        return content;
170 171 172
    }
};

}  // namespace mgb::imperative