/**
 * \file imperative/src/impl/interpreter_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include <variant>
#include <future>

#include "megbrain/utils/mempool.h"
#include "megbrain/imperative/interpreter.h"


namespace mgb::imperative::interpreter::intl {

using Handle = Interpreter::Handle;

//! Concrete Interpreter declared in this header; its only member is the
//! create_channel() override, which hands back an owning pointer to a Channel.
struct InterpreterImpl : public Interpreter {
    std::unique_ptr<Channel> create_channel() override;
};

//! Record that ChannelImpl allocates from its MemPool and that the command
//! structs refer to through raw TensorInfo pointers.
struct TensorInfo {
    //! Tensor object associated with this record.
    TensorPtr ptr;
    //! Logical descriptor stored alongside the tensor.
    LogicalTensorDesc desc;
    //! Starts out false.
    //! NOTE(review): presumably set once the host value has been fetched --
    //! confirm against the implementation file.
    bool value_fetched{false};
};

//! Command payload pairing a host tensor value with the TensorInfo it targets.
//!
//! `dest` defaults to nullptr so that a default-initialized Put never carries
//! an indeterminate pointer. The struct remains an aggregate, so
//! `Put{info, value}` keeps working unchanged.
struct Put {
    TensorInfo* dest = nullptr;
    HostTensorND value;
};
//! Command payload describing one operator application: the operator
//! definition plus the TensorInfo records used as its inputs and outputs.
struct ApplyOp {
    //! Shared handle to the operator definition.
    std::shared_ptr<OpDef> op;
    //! Records read by the operator (non-owning raw pointers).
    SmallVector<TensorInfo*> inputs;
    //! Records the operator's results are associated with (non-owning raw pointers).
    SmallVector<TensorInfo*> outputs;
};
//! Command payload naming the TensorInfo a delete request refers to.
//!
//! `dest` defaults to nullptr so that a default-initialized Del never carries
//! an indeterminate pointer; `Del{info}` aggregate initialization is unchanged.
struct Del {
    TensorInfo* dest = nullptr;
};
//! Command payload naming the TensorInfo a get-value request refers to.
//!
//! `dest` defaults to nullptr so that a default-initialized GetValue never
//! carries an indeterminate pointer; `GetValue{info}` aggregate initialization
//! is unchanged.
struct GetValue {
    TensorInfo* dest = nullptr;
};
//! One unit of work for ChannelImpl's worker queue: holds exactly one of the
//! command payload structs (Put, ApplyOp, Del or GetValue).
using Command = std::variant<Put, ApplyOp, Del, GetValue>;

struct ChannelImpl : Interpreter::Channel {
    ChannelImpl() : m_worker(this) {}
    ~ChannelImpl() override;

    Handle put(const HostTensorND& value) override;
M
Megvii Engine Team 已提交
58
    Handle put(const DeviceTensorND& value) override;
59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76

    void del(Handle) override;

    SmallVector<Handle> apply_op(
            std::shared_ptr<OpDef> op,
            const SmallVector<Handle>& inputs) override;

    HostTensorND get_value(Handle) override;
    TensorShape get_shape(Handle) override;
    DType get_dtype(Handle) override;
    CompNode get_device(Handle) override;

    DeviceTensorND get_dev_tensor(Handle) override;

    void sync() override;
    void close() override;

    void config_async_level(int level) override;
77
    int get_async_level() override;
78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104

private:
    TensorInfo* alloc();
    void free(TensorInfo*);

    void process_one_task(Command&);

    void check_worker_exc_unsafe();

    void produce_tensor(TensorInfo* dest, TensorPtr ptr);

    std::mutex m_mutex;
    std::condition_variable m_cv;
    MemPool<TensorInfo> m_pool;
    std::unordered_set<Handle> m_valid_handle;
    TensorInfo* m_waitee = nullptr;
    std::exception_ptr m_worker_exc;

    struct WorkQueue : AsyncQueueSC<Command, WorkQueue> {
        WorkQueue(ChannelImpl* owner) : m_owner(owner) {}
        void process_one_task(Command& cmd) {
            m_owner->process_one_task(cmd);
        }
    private:
        ChannelImpl* m_owner;
    } m_worker;

105 106 107 108 109
    //! config whether raise error exactly when invoking op.
    //! level 2: both device and user side errors are async;
    //! level 1: user side errors are sync;
    //! level 0: both sync.
    int m_async_level = 1;
110 111 112
};

} // namespace mgb::imperative::interpreter::intl