trace.h 13.4 KB
Newer Older
1 2 3 4
#pragma once

#include <chrono>
#include <future>
5
#include <set>
6 7 8 9 10 11 12 13 14 15 16 17 18 19
#include <variant>
#include "megbrain/gopt/inference.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/utils/box.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/serializer.h"

namespace mgb::imperative {

struct TraceResult {
    struct SeqItem {
20 21 22 23 24 25 26
        enum OpKind {
            Unknown,
            TraceMarkVar,
            Rename,
            IOMarkVar,
            CreateTensor,
        };
27 28 29
        std::shared_ptr<OpDef> op;
        SmallVector<size_t> inputs;
        SmallVector<size_t> outputs;
30
        OpKind kind = OpKind::Unknown;
31 32
    };

33 34
    using OpKind = SeqItem::OpKind;

35 36 37 38 39 40 41 42 43 44 45
    struct VarInfo {
        enum Kind {
            External,  // End point of traced graph, its value is received from
                       // environment
            Constant,  // Also end point, but its value is constant in all executions,
                       // so we don't need to get from env every time, just capture it
            Internal,  // Not end point, produced by some op (or just forwarded) from
                       // op_seq
        };

        size_t id;
46 47
        DTypeValue::ref_t dtype;
        CompNodeValue::ref_t device;
48

49 50
        // if exists, for input: assert equal
        // for output: get_data/shape/value
51 52 53
        ValueRef bound_data;
        std::string mark;
        std::string name;
54
        int handle_id;
55 56 57 58 59

        Kind kind;
        bool value_required = false;
        bool data_required = false;
        bool shape_required = false;
60 61
        std::set<size_t> inp_marker;
        std::set<size_t> out_marker;
62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
        TensorShape shape;
    };

    using VarKind = VarInfo::Kind;

    std::vector<SeqItem> seq;
    std::vector<VarInfo> vars;

    /**
     * \brief dump to mgb computing graph
     *
     * \param graph mgb computing graph
     * \param inputs (input_id, input_name, input_shape)
     * \param outputs (output_id, outupt_name)
     * \param prefer_input_names
     * \return VarNodeArray output nodes
     */
    VarNodeArray dump(
            ComputingGraph& graph,
            std::vector<std::tuple<size_t, std::string, TensorShape>> inputs,
            std::vector<std::pair<size_t, std::string>> outputs,
            bool prefer_input_names);
};

/**
 * \brief mark a var as arg/kwarg/output
 *
 * Identity-like operator that carries a string mark.
 */
class TraceMarkVar : public OperatorImpl<TraceMarkVar, Operator::IdentityLike> {
private:
    std::string m_mark;

public:
    TraceMarkVar(std::string mark) : m_mark(mark) {}

    /// \return copy of the mark string given at construction
    std::string mark() const { return m_mark; }

    std::string to_string() const override {
        return ssprintf("TraceMarkVar{mark=%s}", imperative::quoted(m_mark).c_str());
    }

    // Declared `override` like IOMarkVar::raw_type(), so the compiler checks
    // the signature against the base-class virtual.
    std::string raw_type() const override { return "TraceMarkVar"; }
};

106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
/**
 * \brief identity-like operator carrying a numeric mark and an Input/Output kind
 */
class IOMarkVar : public OperatorImpl<IOMarkVar, Operator::IdentityLike> {
public:
    enum Kind {
        Input,
        Output,
    };

    IOMarkVar(size_t mark, Kind kind) : m_mark{mark}, m_kind{kind} {}

    /// \return whether this mark was created as Input or Output
    Kind kind() const { return m_kind; }

    /// \return the numeric mark given at construction
    size_t mark() const { return m_mark; }

    std::string raw_type() const override { return "IOMarkVar"; }
    std::string to_string() const override { return ssprintf("IOMarkVar"); }

private:
    size_t m_mark;
    Kind m_kind;
};

127
/**
 * \brief value object pairing a wrapped ValueRef with a trace var id
 */
class TracingValue final : public ObjectValue<TracingValue> {
public:
    TracingValue(ValueRef value, size_t id) : m_value(value), m_id(id) {}

    /// \return the trace var id given at construction
    size_t id() const { return m_id; }

    /// \return the wrapped value (empty after clear())
    ValueRef value() const { return m_value; }

    std::string to_string() const override {
        const auto inner = value().to_string();
        return ssprintf(
                "TracingValue{\"id\"=%zu, \"value\"=%s}", id(), inner.c_str());
    }

    // watch/unwatch requests are forwarded to the wrapped value
    void on_watch() override { value().watch(); }
    void on_unwatch() override { value().unwatch(); }

    // drops the wrapped value; the id is left untouched
    void clear() override { m_value = {}; }

private:
    ValueRef m_value = {};
    size_t m_id = 0;
};

/**
 * \brief trace operation sequence to TraceResult
 *
 * TracingTransformation records and forwards all operations to next layer,
 * as if it's transparent. When execution ends, it exports an operation sequence,
 * which is usually used to build CompiledTransformation.
 */
class TracingTransformation final : public Transformation {
public:
    using VarInfo = TraceResult::VarInfo;
    using VarKind = VarInfo::Kind;
161
    using OpKind = TraceResult::SeqItem::OpKind;
162 163 164 165 166

private:
    std::vector<TraceResult::SeqItem> m_seq;
    std::vector<TraceResult::VarInfo> m_vars;
    std::vector<TracingValue::weak_ref_t> m_weak_vars;
167
    std::unordered_map<size_t, size_t> extern_var_to_id;
168 169
    bool m_capture_as_const = false;
    bool m_record_input_shapes = false;
170
    bool m_record_all_shapes = false;
171
    ObjectType<TracingValue> m_value_type{"TracingValue"};
172

173 174 175 176
public:
    std::unordered_map<size_t, size_t> inpmark_to_id;
    std::unordered_map<size_t, size_t> outmark_to_id;

177 178 179 180 181 182 183 184 185 186 187 188 189 190
public:
    TracingTransformation(bool capture_as_const, bool record_input_shapes)
            : m_capture_as_const(capture_as_const),
              m_record_input_shapes(record_input_shapes) {}

    /**
     * \brief record values for trace
     *
     * \param value value to be traced
     * \param capture whether capture value or not
     * \param kind External, Constant or Internal
     * \return TypedValueRef<TracingValue> traced value
     */
    TypedValueRef<TracingValue> record_var(ValueRef value, bool capture, VarKind kind) {
191 192 193 194
        if (kind == VarKind::External &&
            extern_var_to_id.find(value.id()) != extern_var_to_id.end()) {
            return m_value_type.make(value, extern_var_to_id[value.id()]);
        }
195
        size_t id = m_vars.size();
196 197 198
        if (kind == VarKind::External) {
            extern_var_to_id[value.id()] = id;
        }
199
        auto wrapped_value = m_value_type.make(value, id);
200
        m_vars.push_back({id, value.dtype(), value.device()});
201 202 203 204 205
        auto& var = m_vars.back();
        if (capture) {
            var.bound_data = value;
        }
        var.kind = kind;
206 207
        if ((m_record_input_shapes && kind != VarKind::Internal) ||
            m_record_all_shapes) {
208 209
            var.shape = value.shape()->as_tensor_shape();
        }
210 211
        if (m_record_all_shapes)
            var.handle_id = value.handle_id();
212 213 214 215 216 217 218
        if (auto name = value.name()) {
            var.name = *name;
        }
        m_weak_vars.push_back(wrapped_value);
        return wrapped_value;
    }
    ValueRef unwrap_var(ValueRef value) {
219
        if (auto* tracing_value = value.as(m_value_type)) {
220 221 222 223 224
            return tracing_value->value();
        }
        return value;
    }

225
    ValueRefList apply_transformation(
226 227 228
            const Operator& op, Span<ValueRef> inputs) override;

    ValueRef unwrap(ValueRef value) override {
229
        if (auto* tracing_value = value.as(m_value_type)) {
230 231 232 233 234 235 236 237
            return tracing_value->value();
        }
        return value;
    }

    std::string name() const override { return "TracingTransformation"; }

    void on_unregister() noexcept override;
238
    void postprocess_trace_result();
239
    TraceResult get_result() { return {m_seq, m_vars}; }
240
    void enable_record_all_shapes() { m_record_all_shapes = true; }
241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264
};

class TraceError : public std::exception {
private:
    std::string m_message;

public:
    TraceError(std::string reason) {
        m_message = ssprintf("trace error because %s", reason.c_str());
    }
    const char* what() const noexcept override { return m_message.c_str(); }
};

/**
 * \brief boost with traced result from TracingTransformation
 *
 * CompiledTransformation is built with an operation sequence. It compiles a megbrain
 * graph with the sequence and handle operation requests with this graph. Besides that,
 * it also checks that if current operation is same as previous one in seq.
 */
class CompiledTransformation final : public Transformation {
public:
    using VarInfo = TraceResult::VarInfo;
    using VarKind = VarInfo::Kind;
265
    using OpKind = TraceResult::SeqItem::OpKind;
266 267 268 269 270 271 272

    struct VarAccessor {
        VarNode* node;
        std::function<TensorShape()> shape_getter;
        std::function<DeviceTensorND()> data_getter;
        std::function<HostTensorND()> value_getter;
        std::function<void(DeviceTensorND)> data_setter;
273 274 275
        std::function<void(std::exception_ptr)> exc_setter;
    };

276
    class TracedValue final : public ObjectValue<TracedValue> {
277 278 279 280 281 282 283 284 285
    private:
        size_t m_id = 0;
        VarInfo* m_var = nullptr;
        VarAccessor* m_accessor = nullptr;
        mutable ShapeValue::ref_t m_shape;
        mutable DTypeValue::ref_t m_dtype;
        mutable CompNodeValue::ref_t m_comp_node;

    public:
286
        TracedValue(size_t id, VarInfo* var, VarAccessor* accessor)
287 288 289 290 291 292 293 294 295 296 297 298 299 300
                : m_id(id), m_var(var), m_accessor(accessor) {}
        size_t id() const { return m_id; }
        ShapeValue::ref_t shape() const;
        DTypeValue::ref_t dtype() const;
        CompNodeValue::ref_t comp_node() const;
        const VarAccessor& accessor() const;

        void set_exception(std::exception_ptr exc) const {
            m_accessor->exc_setter(exc);
        }

        std::string to_string() const override {
            return ssprintf("TracedValue{\"id\"=%zu}", id());
        }
301 302

        void clear() override {}
303 304 305 306 307 308
    };

private:
    std::vector<TraceResult::SeqItem> m_seq;
    std::vector<TraceResult::VarInfo> m_vars;
    std::vector<VarAccessor> m_var_accessors;
309
    std::unordered_map<std::string, size_t> mark2id;
310 311 312 313 314 315 316 317
    size_t m_pc = 0;
    std::shared_ptr<ComputingGraph> m_graph;
    std::unique_ptr<cg::AsyncExecutable> m_executable;
    std::vector<TracedValue::weak_ref_t> m_weak_values;
    std::thread m_graph_executor;
    std::function<bool(ValueRef, ValueRef)> m_value_comparator;
    bool m_input_shape_static;
    std::mutex m_mutex;
318
    std::condition_variable m_cv;
319
    std::exception_ptr m_graph_exc;
320
    int m_graph_status = 0;  // 0 = stop, 1 = running, 2 = finalizing
321 322
    std::vector<std::shared_ptr<BoxBase>> m_boxes;
    ComputingGraph::OutputSpec m_output_spec;
323
    ObjectType<TracedValue> m_value_type{"TracedValue"};
324
    std::set<size_t> m_setted_extern;
325 326 327 328 329 330 331 332 333

public:
    CompiledTransformation(TraceResult result, bool input_shape_static)
            : m_seq(result.seq),
              m_vars(result.vars),
              m_input_shape_static(input_shape_static) {
        m_graph = ComputingGraph::make();
        options().no_force_inplace = true;
        options().async_exec_level = 0b100;
334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354
        m_graph_executor = std::thread([&] {
            while (true) {
                std::unique_lock lock{m_mutex};
                m_cv.wait(lock, [&] { return m_graph_status != 0; });
                lock.unlock();
                if (m_graph_status == 2) {
                    break;
                }
                try {
                    m_executable->execute();
                    m_executable->wait();
                } catch (...) {
                    auto exc = std::current_exception();
                    set_exception(exc);
                }
                lock.lock();
                m_graph_status = 0;
                lock.unlock();
                m_cv.notify_all();
            }
        });
355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399
    }

    ComputingGraph& graph() { return *m_graph; }

    ComputingGraph::Options& options() { return m_graph->options(); }

    /**
     * \brief Set the value comparator object (usually from python)
     *
     * \param comparator
     */
    void set_value_comparator(std::function<bool(ValueRef, ValueRef)> comparator) {
        m_value_comparator = comparator;
    }

    void compile();

    void recompile();

    void assert_tensor_equal(ValueRef lhs, ValueRef rhs);

    /**
     * \brief handle input for trace
     *
     * 1. For external, set input value to data_setter;
     * 2. For const, do nothing;
     * 3. For internal, assert var id;
     * *. Always assert data equals if there are data bound.
     *
     * \param id
     * \param value
     */
    void trace_input(size_t id, ValueRef value);

    /**
     * \brief make a placeholder for output.
     *
     * \param id trace_id
     * \return TracedValue::ref_t output placeholder, would be reset to real value when
     * trace exits
     */
    TracedValue::ref_t trace_output(size_t id);

    TraceResult::SeqItem& next_instruction();

400 401 402 403 404 405 406 407
    ValueRefList apply_op(const ApplyOp& apply_op, Span<ValueRef> inputs);

    ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs);

    ValueRefList apply_create_tensor(
            const CreateTensor& create_tensor, Span<ValueRef> inputs);

    ValueRefList apply_transformation(
408 409 410 411 412
            const Operator& op, Span<ValueRef> inputs) override;

    void on_unregister() noexcept override;

    ValueRef unwrap(ValueRef value) override {
413
        mgb_assert(!value.is(m_value_type));
414 415 416
        return value;
    }

417
    VarAccessor& get_accessor_by_id(size_t id) { return m_var_accessors[id]; }
418

419 420
    std::string name() const override { return "CompiledTransformation"; }
    void set_pc_to_end() { m_pc = m_seq.size(); }
421 422 423 424 425 426 427 428 429 430 431 432
    void execute();

    void wait();

    std::exception_ptr set_exception(std::exception_ptr exc) noexcept;

    template <typename T>
    std::shared_ptr<Box<T>> make_box() {
        auto box = Box<T>::make();
        m_boxes.push_back(box);
        return box;
    }
433 434 435 436 437 438 439 440 441

    ~CompiledTransformation() {
        {
            MGB_LOCK_GUARD(m_mutex);
            m_graph_status = 2;
        }
        m_cv.notify_all();
        m_graph_executor.join();
    }
442 443 444
};

}  // namespace mgb::imperative