/**
 * \file imperative/src/impl/async_releaser.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/comp_node.h"
#include "megbrain/imperative/blob_manager.h"
#include "megbrain/imperative/resource_manager.h"
#include "megbrain/system.h"

#include "./event_pool.h"

#include <chrono>
#include <memory>
#include <thread>
#include <utility>

namespace mgb {
namespace imperative {

class AsyncReleaser : public CompNodeDepedentObject {
    struct WaiterParam {
        CompNode cn;
        CompNode::Event* event;
        BlobPtr blob;
        HostTensorStorage::RawStorage storage;
    };
    class Waiter final : public AsyncQueueSC<WaiterParam, Waiter> {
        AsyncReleaser* m_par_releaser;

    public:
        // disable busy wait by set max_spin=0 to save CPU cycle
        Waiter(AsyncReleaser* releaser)
M
Megvii Engine Team 已提交
37
                : AsyncQueueSC<WaiterParam, Waiter>(0), m_par_releaser(releaser) {}
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64

        void process_one_task(WaiterParam& param) {
            if (param.event->finished()) {
                param.blob.reset();
                param.storage.reset();
                EventPool::without_timer().free(param.event);
                return;
            }

            using namespace std::literals;
            std::this_thread::sleep_for(1us);
            add_task(std::move(param));
        }
        void on_async_queue_worker_thread_start() override {
            sys::set_thread_name("releaser");
        }
    };
    Waiter m_waiter{this};

protected:
    std::shared_ptr<void> on_comp_node_finalize() override {
        m_waiter.wait_task_queue_empty();
        return {};
    }

public:
    static AsyncReleaser* inst() {
65 66
        static auto* releaser = ResourceManager::create_global<AsyncReleaser>();
        return releaser;
67 68
    }

M
Megvii Engine Team 已提交
69
    ~AsyncReleaser() { m_waiter.wait_task_queue_empty(); }
70 71 72 73 74 75 76

    void add(BlobPtr blob, CompNode cn) { add(cn, std::move(blob), {}); }

    void add(const HostTensorND& hv) {
        add(hv.comp_node(), {}, hv.storage().raw_storage());
    }

M
Megvii Engine Team 已提交
77
    void add(CompNode cn, BlobPtr blob, HostTensorStorage::RawStorage storage = {}) {
78 79 80 81 82
        auto event = EventPool::without_timer().alloc(cn);
        event->record();
        m_waiter.add_task({cn, event, std::move(blob), std::move(storage)});
    }
};
}  // namespace imperative
}  // namespace mgb