handle.cpp 4.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
/**
 * \file dnn/src/common/handle.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megdnn/basic_types.h"

#include "src/common/handle_impl.h"
#include "src/common/utils.h"
#include "src/fallback/handle.h"
#include "src/naive/handle.h"

#include "midout.h"

#if MEGDNN_X86
#include "src/x86/handle.h"
#endif


#if MEGDNN_WITH_CUDA
#include "src/cuda/handle.h"
#endif


using namespace megdnn;

// midout tag for the CPU-platform region opened in Handle::make
MIDOUT_DECL(HandlePlatform);
// midout tag for the per-operator regions opened in Handle::create_operator
MIDOUT_DECL(HandleOpr);

Handle::Handle(megcoreComputingHandle_t computing_handle, HandleType type)
        : m_computing_handle(computing_handle), m_handle_type(type) {}

/**
 * \brief create a megdnn Handle bound to \p computing_handle
 *
 * The concrete HandleImpl is chosen from the platform of the device owning
 * \p computing_handle:
 *  - CPU: naive::HandleImpl when built with MEGDNN_NAIVE; otherwise
 *    \p debug_level selects the implementation
 *    (0: x86 when MEGDNN_X86, else fallback; 1: fallback; 2: naive) and any
 *    other value throws.
 *  - any other platform must be CUDA (checked by megdnn_assert_internal);
 *    a cuda::HandleImpl is returned when built with MEGDNN_WITH_CUDA,
 *    nullptr otherwise.
 *
 * \param computing_handle megcore computing handle the new Handle wraps
 * \param debug_level CPU implementation selector (ignored under MEGDNN_NAIVE)
 */
std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle,
                                     int debug_level) {
    // debug_level is not read in MEGDNN_NAIVE builds
    (void)debug_level;
    megcoreDeviceHandle_t device_handle;
    megcorePlatform_t platform;
    megcoreGetDeviceHandle(computing_handle, &device_handle);

    megcoreGetPlatform(device_handle, &platform);
    if (platform == megcorePlatformCPU) {
        // only enable midout for CPU, because CPU might be unused when some
        // other platforms are used
        MIDOUT_BEGIN(HandlePlatform, midout_iv(megcorePlatformCPU)) {
            // CPU
#if MEGDNN_NAIVE
            return make_unique<naive::HandleImpl>(computing_handle);
#else
            if (debug_level == 0) {
#if MEGDNN_X86
                // Because of ICC bug, we cannot use make_unique here. It will
                // trigger an internal compiler error.
                return std::unique_ptr<x86::HandleImpl>(
                        new x86::HandleImpl(computing_handle));
                // return make_unique<x86::HandleImpl>(computing_handle);
#else
                return make_unique<fallback::HandleImpl>(computing_handle);
#endif
            } else if (debug_level == 1) {
                return make_unique<fallback::HandleImpl>(computing_handle);
            } else if (debug_level == 2) {
                return make_unique<naive::HandleImpl>(computing_handle);
            } else {
                megdnn_throw(megdnn_mangle("Debug level must be 0/1/2."));
            }
#endif  // MEGDNN_NAIVE
        }
        // The closing brace and MIDOUT_END() must be emitted in every build
        // configuration; keeping them outside the MEGDNN_NAIVE conditional
        // keeps the midout region balanced.
        MIDOUT_END();
    } else {
        // CUDA
        megdnn_assert_internal(platform == megcorePlatformCUDA);
#if MEGDNN_WITH_CUDA
        return make_unique<cuda::HandleImpl>(computing_handle);
#else
        return nullptr;
#endif
    }
}


    /**
     * \brief register a callback that ~Handle() invokes
     *
     * Only one callback may be registered: calling this when m_destructor is
     * already set fails the megdnn_assert.
     *
     * \param d callback stored into m_destructor
     */
    void Handle::set_destructor(const thin_function<void()>& d) {
        megdnn_assert(!m_destructor, "destructor can be set only once");
        m_destructor = d;
    }

    /**
     * \brief run the registered destructor callback (if any), then clear the
     * alive marker
     *
     * Zeroing m_alive_magic lets on_opr_destructed() detect operators that
     * outlive their handle.
     */
    Handle::~Handle() {
        if (m_destructor) {
            m_destructor();
        }
        m_alive_magic = 0;
    }

    /**
     * \brief alignment requirement reported by the base Handle
     *
     * The base implementation always reports 32.
     */
    size_t Handle::alignment_requirement() const {
        constexpr size_t kDefaultAlignment = 32;
        return kDefaultAlignment;
    }

    /**
     * \brief base implementation: image2d tensor format is not supported
     * here, so this always throws via megdnn_throw
     */
    size_t Handle::image2d_pitch_alignment() const {
        megdnn_throw("image2d tensor format not supported on this handle");
    }

    /**
     * \brief default cross-device copy constraint
     *
     * \param src layout of the copy source
     * \return true iff \p src is contiguous
     */
    bool Handle::check_cross_dev_copy_constraint(const TensorLayout& src) {
        const bool contiguous = src.is_contiguous();
        return contiguous;
    }

    /**
     * \brief hook called from ~OperatorBase() with the operator being
     * destroyed
     *
     * If this handle has already been destructed (~Handle() cleared
     * m_alive_magic), logs an error and calls abort(). Otherwise forwards
     * \p opr to the m_on_opr_destructed callback when one is set.
     */
    void Handle::on_opr_destructed(OperatorBase* opr) {
        const bool handle_destructed = m_alive_magic != ALIVE_MAGIC;
        if (handle_destructed) {
            megdnn_log_error(
                    "Handle is destructed before opr gets destructed. "
                    "Please fix the destruction order as this would cause "
                    "undefined memory access. "
                    "Abort now to avoid further problems.");
            abort();
        }
        if (!m_on_opr_destructed) {
            return;
        }
        m_on_opr_destructed(opr);
    }

    // Notify the handle stored in m_handle that this operator is going away;
    // Handle::on_opr_destructed aborts if that handle was already destructed.
    OperatorBase::~OperatorBase() { m_handle->on_opr_destructed(this); }

    /**
     * \brief create an operator of type \p Opr on this handle
     *
     * Dispatches on m_handle_type: casts this Handle to the matching
     * backend HandleImpl and forwards to its create_operator<Opr>(). Handle
     * types whose backend is not compiled into this build reach the default
     * label and throw "bad handle type".
     */
    template <typename Opr>
    std::unique_ptr<Opr> Handle::create_operator() {
// CASE(etype, nm): for HandleType::etype, forward to nm::HandleImpl inside a
// midout region tagged by the operator type and the handle type.
// NOTE(review): there is no `break` after MIDOUT_END(); if the midout region
// can be skipped, control would fall through to the next case label and cast
// `this` to the wrong HandleImpl -- confirm MIDOUT_END cannot complete
// normally in trimmed builds.
#define CASE(etype, nm)                                                        \
    case HandleType::etype: {                                                  \
        MIDOUT_BEGIN(HandleOpr, Opr, midout_iv(HandleType::etype)) {           \
            return static_cast<nm::HandleImpl*>(this)->create_operator<Opr>(); \
        }                                                                      \
        MIDOUT_END();                                                          \
    }

        switch (m_handle_type) {
            CASE(NAIVE, naive);
#if !MEGDNN_NAIVE
            CASE(FALLBACK, fallback);
#if MEGDNN_X86
            CASE(X86, x86);
#endif
#endif  // !MEGDNN_NAIVE
#if MEGDNN_WITH_CUDA
            CASE(CUDA,cuda);
#endif
            default:
                megdnn_throw(megdnn_mangle("bad handle type"));
        }
#undef CASE
    }

// Explicitly instantiate Handle::create_operator<Opr>() for every operator
// class listed by MEGDNN_FOREACH_OPR_CLASS; the template definition lives in
// this .cpp file, so these instantiations are what other translation units
// link against.
#define INST(opr) template std::unique_ptr<opr> Handle::create_operator();
        MEGDNN_FOREACH_OPR_CLASS(INST)
#undef INST
// vim: syntax=cpp.doxygen