/**
 * \file src/core/impl/comp_node/mem_alloc/impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/comp_node/alloc.h"

#include <atomic>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace mgb {
namespace mem_alloc {

class DevMemAllocImpl;

class MemAllocImplHelper: virtual public MemAllocBase {
    friend class DevMemAllocImpl;

    protected:
        struct MemAddr {
            //! whether it is head of a chunk from raw allocator; if true, it
            //! could not be merged with chunks with lower address
            bool is_head = false;
            size_t addr = -1;

            void* addr_ptr() const {
                return reinterpret_cast<void*>(addr);
            }

            bool operator < (const MemAddr &rhs) const {
                return addr < rhs.addr;
            }

            MemAddr operator + (size_t delta) const {
                return {false, addr + delta};
            }
        };

        struct FreeBlock {
            MemAddr addr;
            size_t size = -1;

            size_t end() const {
                return addr.addr + size;
            }
        };

        struct FreeCmpBySize{
            bool operator() (const FreeBlock &a, const FreeBlock &b) const {
                // prefer more recent (hotter) block
                return a.size < b.size || (a.size == b.size && a.addr < b.addr);
            }
        };

        struct BlkByAddrIter;
        struct FreeBlockAddrInfo;

        //! free blocks sorted by size, and map to corresponding iterator in
        //! m_free_blk_addr
        std::map<FreeBlock, BlkByAddrIter, FreeCmpBySize> m_free_blk_size;

        //! map from address to size and size iter
        std::map<size_t, FreeBlockAddrInfo> m_free_blk_addr;

        std::mutex m_mutex;

        struct BlkByAddrIter {
            decltype(m_free_blk_addr.begin()) aiter;
        };

        struct FreeBlockAddrInfo {
            bool is_head;   //! always equals to siter->first.addr.is_head
            size_t size;
            decltype(m_free_blk_size.begin()) siter;
        };

        /*!
         * \brief merge a block into free list, without locking
         */
        void merge_free_unsafe(FreeBlock block);

        /*!
         * \brief directly insert a free block into m_free_blk_size and
         *      m_free_blk_addr, without merging
         */
97
        virtual void insert_free_unsafe(const FreeBlock &block);
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117

        /*!
         * \brief allocate from parent allocator; this method must either return
         *      a valid address or throw an exception
         *
         * m_free_blk_addr and m_free_blk_size must be maintained if necessary
         */
        virtual MemAddr alloc_from_parent(size_t size) = 0;

        /*!
         * \brief get name of this allocator
         */
        virtual std::string get_name() const = 0;

        MemAddr do_alloc(size_t size, bool allow_from_parent,
                bool log_stat_on_error = false);

        //! get free mem for this allocator, without locking
        FreeMemStat get_free_memory_self_unsafe();

118 119 120 121
#if !MGB_BUILD_SLIM_SERVING
        std::pair<size_t, size_t> get_free_left_and_right(size_t begin_ptr, size_t end_ptr) override;
#endif

122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
    public:
        void print_memory_state() override;

        FreeMemStat get_free_memory() override final;
};


class StreamMemAllocImpl final: public StreamMemAlloc,
                                public MemAllocImplHelper {
    struct AllocatedBlock {
        bool is_head;
        size_t size;
    };

    DevMemAllocImpl *m_dev_alloc;
    int m_stream_id;

    //! map from address to block info
    std::unordered_map<void*, AllocatedBlock> m_allocated_blocks;

    void* alloc(size_t size) override;

    void free(void *addr) override;

    void get_mem_info(size_t& free, size_t& tot) override;

    std::string get_name() const override;

    MemAddr alloc_from_parent(size_t size) override;
    size_t get_used_memory() override;
    FreeMemStat get_free_memory_dev() override;

    public:
        StreamMemAllocImpl(DevMemAllocImpl *dev_alloc, int stream_id):
            m_dev_alloc(dev_alloc), m_stream_id(stream_id)
        {}
};

160 161 162 163 164 165
/*!
 * \Note: DevMemAlloc has two-level structure, but when only one stream was
 * registered into the DevMemAlloc, the DevMemAlloc would behave like a
 * single-level allocator(i.e. only the FreeBlock pool in its child stream
 * allocator will be used) for better performance
 */
166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205
class DevMemAllocImpl final: public DevMemAlloc,
                             public MemAllocImplHelper {
    friend class StreamMemAllocImpl;
    int m_device;
    std::shared_ptr<RawAllocator> m_raw_allocator;
    std::shared_ptr<DeviceRuntimePolicy> m_runtime_policy;
    ThinHashMap<StreamKey, std::unique_ptr<StreamMemAllocImpl>> m_stream_alloc;

    //!< blocks allocated from raw alloc, addr to size
    std::unordered_map<void*, size_t> m_alloc_from_raw;

    size_t m_tot_allocated_from_raw = 0;
    std::atomic_size_t m_used_size{0};

    /*!
     * \brief gather all free blocks from child streams, and release full chunks
     *      back to parent allocator
     * \return number of bytes released
     */
    size_t gather_stream_free_blk_and_release_full() override;

    StreamMemAlloc* add_stream(StreamKey stream) override;

    MemAddr alloc_from_parent(size_t size) override;

    std::string get_name() const override {
        return ssprintf("dev allocator %d", m_device);
    }

    const std::shared_ptr<RawAllocator>& raw_allocator() const override {
        return m_raw_allocator;
    }

    const std::shared_ptr<DeviceRuntimePolicy>& device_runtime_policy()
            const override {
        return m_runtime_policy;
    }

    size_t get_used_memory() override { return m_used_size.load(); }

206 207 208 209 210 211 212 213
    void insert_free_unsafe(const FreeBlock &block) override;

    /*!
     * \brief return stream allocator if DevMemAlloc has single child,
     * otherwise return nullptr
     */
    StreamMemAllocImpl* get_single_child_stream_unsafe();

214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231
public:
    DevMemAllocImpl(
            int device, size_t reserve_size,
            const std::shared_ptr<mem_alloc::RawAllocator>& raw_allocator,
            const std::shared_ptr<mem_alloc::DeviceRuntimePolicy>&
                    runtime_policy);

    ~DevMemAllocImpl();

    int device() const { return m_device; }

    MemAddr alloc(size_t size);

    void print_memory_state() override;

    FreeMemStat get_free_memory_dev() override;
};

232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257
class SimpleCachingAllocImpl : public SimpleCachingAlloc,
                               public MemAllocImplHelper {
    struct AllocatedBlock {
        bool is_head;
        size_t size;
    };

    std::unique_ptr<RawAllocator> m_raw_alloc;
    std::unordered_map<void*, size_t> m_alloc_from_raw;
    std::unordered_map<void*, AllocatedBlock> m_allocated_blocks;
    size_t m_used_size = 0;

public:
    SimpleCachingAllocImpl(std::unique_ptr<RawAllocator> m_raw_alloc);
    ~SimpleCachingAllocImpl();

    void* alloc(size_t size) override;
    void free(void* ptr) override;
    size_t get_used_memory() override;
    FreeMemStat get_free_memory_dev() override;

protected:
    MemAddr alloc_from_parent(size_t size) override;
    std::string get_name() const override;
};

}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}