/**
 * \file dnn/src/common/opr_delegate.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megdnn/handle.h"
#include "megdnn/oprs/base.h"

#include "src/common/utils.h"

namespace megdnn {

/*!
 * \brief get a handle that dispatches to caller cpu thread
 *
 * Usually used for calling other opr impls from some opr impl. You probably
 * want to use CpuOprDelegationStorage instead.
 */
27 28
MGE_WIN_DECLSPEC_FUC const std::shared_ptr<Handle>& inplace_cpu_handle(
        int debug_level = 0);
29 30 31 32 33 34 35 36 37 38 39 40 41

/*!
 * \brief storage for oprs on inplace CPU handle
 *
 * This class takes care of thread safety and destruction order. Usage example:
 *
 *      MatrixMul* get_matmul() {
 *          static CpuOprDelegationStorage<> storage;
 *          return storage.get<MatrixMul>();
 *      }
 */
template <int nr_opr = 1>
class CpuOprDelegationStorage {
42
    DNN_MUTEX m_mtx;
43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
    std::shared_ptr<Handle> m_handle;
    std::unique_ptr<OperatorBase> m_oprs[nr_opr];

public:
    ~CpuOprDelegationStorage();

    template <typename Opr, int idx = 0>
    Opr* get(const typename Opr::Param& param = {});
};

template <int nr_opr>
CpuOprDelegationStorage<nr_opr>::~CpuOprDelegationStorage() = default;

/*!
 * Returns the opr in slot \p idx, creating it on the inplace cpu handle the
 * first time the slot is requested. \p param is assigned only at creation
 * time; subsequent calls return the existing opr unchanged.
 *
 * The mutex is taken on every call. The previous unlocked pre-check read
 * m_oprs[idx] while another thread could be assigning it under the lock,
 * which is a data race (undefined behavior). On weakly-ordered CPUs it could
 * also let a reader see the new pointer before the `opr->param() = param`
 * store became visible.
 */
template <int nr_opr>
template <typename Opr, int idx>
Opr* CpuOprDelegationStorage<nr_opr>::get(const typename Opr::Param& param) {
    // a negative idx would index before the start of m_oprs
    static_assert(idx >= 0 && idx < nr_opr, "invalid idx");
    MEGDNN_LOCK_GUARD(m_mtx);
    auto& slot = m_oprs[idx];
    if (!slot) {
        if (!m_handle) {
            m_handle = inplace_cpu_handle();
        }
        auto opr = m_handle->create_operator<Opr>();
        // the returned opr is shared by every caller of this storage, so it
        // must tolerate concurrent use
        megdnn_assert(opr->is_thread_safe());
        opr->param() = param;
        slot = std::move(opr);
    }
    return static_cast<Opr*>(slot.get());
}

}  // namespace megdnn

// vim: syntax=cpp.doxygen