/**
 * \file imperative/src/include/megbrain/imperative/physical_tensor.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include <memory>
#include <mutex>
#include <type_traits>

#include "megbrain/tensor.h"

namespace mgb {
namespace imperative {

/************************** Tensor *****************************/
class Blob;
using BlobPtr = std::shared_ptr<Blob>;

class BlobManagerImpl;

class Blob : public NonCopyableObj {
public:
    Blob(const DeviceTensorStorage& s);
    Blob(CompNode cn, size_t sz);
    ~Blob();

    template<typename ...Args>
    static BlobPtr make(Args&& ...args) {
        return std::make_shared<Blob>(std::forward<Args>(args)...);
    }

    using RawStorage = DeviceTensorStorage::RawStorage;
    const RawStorage& storage();

    const CompNode& comp_node() const {
        return m_comp_node;
    }

    size_t size() const {
        return m_size;
    }

    size_t id() const {
        return m_id;
    }
private:
    friend class BlobManagerImpl;
    CompNode m_comp_node;
    mutable RawStorage m_storage;
    size_t m_size = 0;
    size_t m_id;
};
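
// Illustrative usage sketch, not part of the original header; it assumes a
// valid comp node location string such as "xpu0" accepted by CompNode::load:
//
//     BlobPtr blob = Blob::make(CompNode::load("xpu0"), 1024);
//     const Blob::RawStorage& raw = blob->storage(); // raw device storage
//     size_t nbytes = blob->size();                  // 1024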

// Custom deleter for CompNode::Event; defined out of line so the destruction
// policy (e.g. returning events to a pool) stays in the implementation file.
struct EventDeleter {
    void operator()(CompNode::Event*);
};
using EventPtr = std::unique_ptr<CompNode::Event, EventDeleter>;

class Tensor;
using TensorPtr = std::shared_ptr<Tensor>;
class Tensor : public NonCopyableObj {
public:
    Tensor() = default;
    Tensor(BlobPtr blob, const TensorLayout& layout, size_t offset = 0, const HostTensorND& hv = {});
    Tensor(BlobPtr blob, const TensorLayout& layout, const HostTensorND& hv = {})
        : Tensor(std::move(blob), layout, 0, hv) {}
    Tensor(const HostTensorND &hv);
    Tensor(const DeviceTensorND &dv, const HostTensorND& hv = {});
    Tensor(const TensorLayout& layout, const CompNode& cn);
    Tensor(const BlobPtr blob, const size_t offset, const TensorLayout& layout);

    static TensorPtr make(const HostTensorND& hv);

    // Route HostTensorND arguments, lvalue or rvalue, to the overload above;
    // otherwise the generic forwarding overload below would be selected.
    template<typename T, typename = std::enable_if_t<std::is_same_v<std::decay_t<T>, HostTensorND>>>
    static TensorPtr make(T&& hv) {
        TensorPtr (*f)(const HostTensorND&) = &make;
        return f(std::forward<T>(hv));
    }

    template<typename ...Args>
    static TensorPtr make(Args&& ...args) {
        return std::make_shared<Tensor>(std::forward<Args>(args)...);
    }
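
    // Illustrative: all calls below resolve to a factory overload above. The
    // HostTensorND-typed calls go through make(const HostTensorND&); other
    // argument lists are forwarded to a matching Tensor constructor:
    //
    //     HostTensorND hv = /* populated host tensor */;
    //     TensorPtr a = Tensor::make(hv);             // lvalue HostTensorND
    //     TensorPtr b = Tensor::make(std::move(hv));  // rvalue HostTensorND
    //     TensorPtr c = Tensor::make(layout, cn);     // Tensor(layout, cn)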

    CompNode comp_node() const {
        mgb_assert(m_blob, "uninitialized tensor.");
        return m_blob->comp_node();
    }

    DType dtype() const {
        return m_layout.dtype;
    }

    TensorLayout layout() const {
        return m_layout;
    }

    const TensorShape& shape() const {
        return m_layout;
    }

    size_t offset() const {
        return m_offset;
    }

    DeviceTensorND dev_tensor();

    static TensorPtr make_scalar(DTypeScalar value, CompNode cn);

    TensorPtr make_scalar(DTypeScalar value) const {
        mgb_assert(m_blob, "uninitialized tensor.");
        return make_scalar(value, m_blob->comp_node());
    }
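
    // Illustrative: build a scalar on the same comp node as an existing
    // tensor `t` (the DTypeScalar value shown is hypothetical):
    //
    //     TensorPtr one = t->make_scalar(DTypeScalar(1.f));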

    BlobPtr& blob() {
        return m_blob;
    }

    void fetch_value();
    bool value_fetched();
    TensorPtr sub(size_t offset, TensorShape shape);

    // m_value is set once and is read-only afterwards,
    // so the returned reference is thread-safe
    const HostTensorND& get_value();
    // return a pointer instead of a reference to ensure thread safety
    const HostTensorND* try_get_value();
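
    // Illustrative consumer pattern: try_get_value() is the non-blocking
    // probe, while get_value() may have to fetch the host value first:
    //
    //     if (auto* hv = t->try_get_value()) {
    //         consume(*hv);             // host value already available
    //     } else {
    //         consume(t->get_value());  // may wait for the device copy
    //     }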

    void add_release_callback(CompNode cn);
    CompNode::Event* get_or_create_event();

    // Make sure all static objects required to destruct a tensor have
    // completed construction. Every object with static storage duration
    // that holds tensors must call this method before its constructor
    // completes.
    static void static_initialize();
private:
    TensorLayout m_layout;
    BlobPtr m_blob;
    size_t m_offset;
    std::mutex m_mtx;
    HostTensorND m_value;
    EventPtr m_value_ready = nullptr;
};
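
// End-to-end sketch (illustrative only; assumes `host` is a populated
// HostTensorND on a valid comp node):
//
//     TensorPtr t = Tensor::make(host);           // device tensor from host data
//     DeviceTensorND dev = t->dev_tensor();       // view of the device storage
//     const HostTensorND& back = t->get_value();  // read the value back on host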

struct LogicalTensorDesc {
    TensorLayout layout;
    CompNode comp_node;
    DeviceTensorND value; // if set, typically on the "cpu:default" comp node
};

} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}