/**
 * \file src/core/include/megbrain/comp_node_env.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/common.h"
#include "megbrain/comp_node.h"
#include "megbrain/utils/metahelper.h"
#include "megbrain/utils/thread.h"
#include "megbrain_build_config.h"

#include "megdnn/handle.h"


#if MGB_CUDA
#include <cuda_runtime.h>

/*
 * MGB_CUDA_CHECK(expr): evaluate a CUDA runtime call exactly once and, if it
 * does not return cudaSuccess, report it through ::mgb::_on_cuda_error
 * (declared [[noreturn]] inside namespace mgb below). Both variants pass the
 * stringified expression and the error code. With MGB_ENABLE_LOGGING the call
 * site (__FILE__, __func__, __LINE__) is passed as well; without it,
 * placeholder values ("", "", 1) are passed instead.
 */
#if MGB_ENABLE_LOGGING
#define MGB_CUDA_CHECK(expr)                                          \
    do {                                                              \
        cudaError_t __cuda_check_code = (expr);                       \
        if (!mgb_likely(__cuda_check_code == cudaSuccess)) {          \
            ::mgb::_on_cuda_error(#expr, __cuda_check_code, __FILE__, \
                                  __func__, __LINE__);                \
        }                                                             \
    } while (0)
#else
#define MGB_CUDA_CHECK(expr)                                            \
    do {                                                                \
        cudaError_t __cuda_check_code = (expr);                         \
        if (!mgb_likely(__cuda_check_code == cudaSuccess)) {            \
            ::mgb::_on_cuda_error(#expr, __cuda_check_code, "", "", 1); \
        }                                                               \
    } while (0)
#endif  // MGB_ENABLE_LOGGING

#endif  // MGB_CUDA

//! whether to enable asynchronous initialization for CompNode and CompNodeEnv;
//! currently tied to MGB_CUDA, because only the CUDA env is initialized
//! asynchronously (CompNodeEnv::init_cuda_async). When set, the <atomic> and
//! <future> headers and the async-init members of CompNodeEnv are enabled.
#define MGB_ENABLE_COMP_NODE_ASYNC_INIT (MGB_CUDA)

//! whether AsyncErrorInfo is needed; when set, MegDNNHandle keeps an
//! AsyncErrorInfo pointer (async_error_info_devptr). Currently tied to MGB_CUDA.
#define MGB_NEED_MEGDNN_ASYNC_ERROR (MGB_CUDA)

#if MGB_ENABLE_COMP_NODE_ASYNC_INIT
#include <atomic>
#include <future>
#endif

#include <memory>
#include <type_traits>
#include "megbrain/utils/thin/function.h"

namespace mgb {

#if MGB_CUDA
//! report a failed CUDA runtime call; this is what MGB_CUDA_CHECK calls when
//! its expression does not return cudaSuccess. \p expr is the stringified
//! expression and \p err the returned code. \p file / \p func / \p line
//! describe the call site, or are "" / "" / 1 when MGB_ENABLE_LOGGING is off.
//! Declared [[noreturn]]: control never comes back to the caller.
[[noreturn]] void _on_cuda_error(const char* expr, cudaError_t err,
                                 const char* file, const char* func, int line);
#endif


/*!
 * \brief CPU task dispatcher used by CPU comp nodes
 *
 * Extends MegcoreCPUDispatcher with a dispatched-task counter and an optional
 * CPU-affinity hook.
 */
class CPUDispatcher : public MegcoreCPUDispatcher {
public:
    //! callback type accepted by set_affinity(); the meaning of the size_t
    //! argument is defined by the concrete dispatcher (presumably a worker
    //! thread index — confirm)
    using AffinityCallBack = thin_function<void(size_t)>;

    //! get number of tasks already dispatched
    virtual size_t get_nr_dispatched_tasks() const = 0;

    /*!
     * \brief set the cpu affinity callback
     *
     * The default implementation does not support affinity and fails through
     * mgb_assert; dispatchers that support affinity must override it.
     */
    virtual void set_affinity(AffinityCallBack&& /*affinity_cb*/) {
        mgb_assert(0, "The CompNode set_affinity is not implemented");
    }
};


/*!
 * \brief CompNode environment
 *
 * CompNodeEnv contains necessary information to launch a kernel on a comp node,
 * or calling other libraries on a comp node. It has common fields for all comp
 * nodes and also specific fields for a given comp node type.
 *
 * Each CompNode is associated with a CompNodeEnv that could be retrieved by
 * CompNodeEnv::from_comp_node.
 *
 * Note: a CUDA CompNodeEnv is initialized asynchronously via init_cuda_async().
 * property() and comp_node() return immediately. cuda_env(), get_user_data()
 * and has_user_data() first call ensure_async_init_finished(), which calls
 * wait_async_init() while an async init is still pending. When
 * MGB_ENABLE_COMP_NODE_ASYNC_INIT is off, ensure_async_init_finished() is a
 * no-op.
 */
class CompNodeEnv final : public NonCopyableObj {
public:
    using DeviceType = CompNode::DeviceType;
    //! callback invoked by on_mem_event() for each alloc/free event
    using MemEventHandler =
            thin_function<void(size_t alloc_size, bool is_host, void* ptr)>;

    //! extra properties for a CompNodeEnv
    struct Property {
        //! type of the underlying device; has no default initializer here
        //! (presumably always assigned by the init_* methods — confirm)
        DeviceType type;

        //! alignment requirement in bytes, for memory allocating
        size_t mem_alignment = 0;
    };

    //! get user data by calling UserDataContainer::get_user_data_or_create;
    //! \p maker is forwarded to it (presumably only invoked when no T exists
    //! yet — confirm in UserDataContainer).
    //! this method is thread-safe: it holds m_user_data_container_mtx, after
    //! first waiting for any pending async init
    template <typename T, typename Maker>
    T& get_user_data(Maker&& maker) const {
        ensure_async_init_finished();
        MGB_LOCK_GUARD(m_user_data_container_mtx);
        return *m_user_data_container->get_user_data_or_create<T>(
                std::forward<Maker>(maker));
    }

    //! same as get_user_data(maker), with std::make_shared<T> passed as the
    //! maker
    template <typename T>
    T& get_user_data() const {
        ensure_async_init_finished();
        MGB_LOCK_GUARD(m_user_data_container_mtx);
        return *m_user_data_container->get_user_data_or_create<T>(
                std::make_shared<T>);
    }

    //! check whether a user data object has been registered; locks
    //! m_user_data_container_mtx like get_user_data()
    template <typename T>
    bool has_user_data() const {
        ensure_async_init_finished();
        MGB_LOCK_GUARD(m_user_data_container_mtx);
        return m_user_data_container->get_user_data<T>().second;
    }

    //! get property; does not wait for async init
    const Property& property() const { return m_property; }

    //! get the comp node to which this env belongs; does not wait for async
    //! init
    CompNode comp_node() const { return m_comp_node; }

    /*!
     * \brief create CompNodeEnv from comp_node
     */
    static inline const CompNodeEnv& from_comp_node(const CompNode& node);

    /*!
     * \brief activate this env for current thread
     *
     * Currently only calls cuda_env().activate() if type is cuda
     *
     * NOTE(review): this reads m_cuda_env directly, without calling
     * ensure_async_init_finished(); presumably the device id is set
     * synchronously by init_cuda_async — confirm.
     */
    void activate() const {
#if MGB_CUDA
        if (m_property.type == DeviceType::CUDA) {
            m_cuda_env.activate();
        }
#endif
    }

    /*!
     * \brief set a callback to be invoked on alloc/free events
     * \param[in,out] handler the new handler to be set; the previous handler
     *      would be returned
     *
     * The swap is not guarded by any lock in this class.
     */
    void mem_event_handler(MemEventHandler& handler) {
        m_mem_event_handler.swap(handler);
    }

    //! invoke mem event handler on a mem event; only be called from CompNode.
    //! Does nothing when no handler is set.
    void on_mem_event(size_t alloc_size, bool is_host, void* ptr) {
        if (m_mem_event_handler) {
            m_mem_event_handler(alloc_size, is_host, ptr);
        }
    }

    // following are impls for various envs

#if MGB_CUDA
    //! CUDA-specific part of the env
    struct CudaEnv {
        //! device ordinal passed to cudaSetDevice() by activate(); default -1
        int device = -1;
        //! stream associated with this comp node; default 0 (null stream)
        cudaStream_t stream = 0;
        cudaDeviceProp device_prop;

        //! select \p device for the calling host thread; failures are
        //! reported through MGB_CUDA_CHECK
        void activate() const { MGB_CUDA_CHECK(cudaSetDevice(device)); }
    };

    //! get the CUDA env; reports an error through on_bad_device_type()
    //! (declared [[noreturn]]) if this is not a CUDA env, and waits for any
    //! pending async init before returning
    const CudaEnv& cuda_env() const {
        if (mgb_unlikely(m_property.type != DeviceType::CUDA))
            on_bad_device_type(DeviceType::CUDA);
        ensure_async_init_finished();
        return m_cuda_env;
    }

    //! init this as a cuda env asynchronously; \p cont presumably receives
    //! the created stream (ContinuationCtx<cudaStream_t>) — confirm in the
    //! implementation
    void init_cuda_async(int dev, CompNode comp_node,
                         const ContinuationCtx<cudaStream_t>& cont);
#endif


    //! CPU-specific part of the env: a thin forwarding wrapper around a
    //! CPUDispatcher
    struct CpuEnv {
        using Task = CPUDispatcher::Task;
        using MultiThreadingTask = CPUDispatcher::MultiThreadingTask;
        using AffinityCallBack = thin_function<void(size_t)>;

        //! dispatcher used by all methods below; not null-checked here
        std::shared_ptr<CPUDispatcher> dispatcher;

        //! forward a single task to the dispatcher
        void dispatch(Task&& task) const {
            dispatcher->dispatch(std::move(task));
        }

        //! forward a multi-threading task with the given parallelism
        void dispatch(MultiThreadingTask&& task, size_t parallelism) const {
            dispatcher->dispatch(std::move(task), parallelism);
        }

        //! forward the affinity callback; the default CPUDispatcher
        //! implementation rejects it via mgb_assert
        void set_affinity(AffinityCallBack&& cb) const {
            dispatcher->set_affinity(std::move(cb));
        }
    };

    //! get the CPU env; reports an error through on_bad_device_type() if this
    //! is not a CPU env. Unlike cuda_env(), it does not wait for async init.
    const CpuEnv& cpu_env() const {
        if (mgb_unlikely(m_property.type != DeviceType::CPU))
            on_bad_device_type(DeviceType::CPU);
        return m_cpu_env;
    }

    //! init this as a cpu env
    void init_cpu(const CpuEnv& env, CompNode comp_node);

    //! defined elsewhere; presumably releases what the init_* methods set up
    //! — confirm
    void fini();

private:
    CompNode m_comp_node;
    Property m_property;
    MemEventHandler m_mem_event_handler;

#if MGB_CUDA
    CudaEnv m_cuda_env;
#endif
    CpuEnv m_cpu_env;

    //! storage for get_user_data()/has_user_data(); always accessed under
    //! m_user_data_container_mtx
    std::unique_ptr<UserDataContainer> m_user_data_container;
    //! mutable so that the const user-data accessors can lock it
    mutable RecursiveSpinlock m_user_data_container_mtx;

    [[noreturn]] void on_bad_device_type(DeviceType expected) const;

#if MGB_ENABLE_COMP_NODE_ASYNC_INIT
    //! whether async init is in future; set by init*_async methods
    std::atomic_bool m_async_init_need_wait{false};
    std::mutex m_async_init_mtx;
    std::future<void> m_async_init_future;
    std::thread::id m_async_init_tid;

    //! fast path: only enters wait_async_init() while the flag is set;
    //! const_cast because wait_async_init() is non-const but is called from
    //! const accessors
    void ensure_async_init_finished() const {
        if (m_async_init_need_wait.load()) {
            const_cast<CompNodeEnv*>(this)->wait_async_init();
        }
    }

    void wait_async_init();
#else
    //! no async init in this build, so there is nothing to wait for
    void ensure_async_init_finished() const {}
#endif
};

//! megdnn handle stored in a CompNodeEnv
//!
//! It derives from UserDataContainer::UserData, so it can live in a
//! CompNodeEnv's user data; get() returns the handle for a given env
//! (presumably creating it on first use via CompNodeEnv::get_user_data —
//! confirm in the implementation).
class MegDNNHandle final : public UserDataContainer::UserData,
                           public std::enable_shared_from_this<MegDNNHandle> {
    MGB_TYPEINFO_OBJ_DECL;

    //! process-wide default debug level, read/written by
    //! exchange_default_dbg_level(); presumably applied when a handle is
    //! created — confirm
    static int sm_default_dbg_level;
    megcoreDeviceHandle_t m_dev_hdl = nullptr;
    megcoreComputingHandle_t m_comp_hdl = nullptr;
    std::unique_ptr<megdnn::Handle> m_megdnn_handle;

#if MGB_NEED_MEGDNN_ASYNC_ERROR
    std::shared_ptr<megcore::AsyncErrorInfo> m_async_error_info_devptr;
    megcore::AsyncErrorInfo* make_async_error_info(const CompNodeEnv& env);
#endif

public:
    //! explicit: a CompNodeEnv should never silently convert into a new
    //! megdnn handle
    explicit MegDNNHandle(const CompNodeEnv& env);
    ~MegDNNHandle() noexcept;

    static MegDNNHandle& get(const CompNodeEnv& env);

    //! convenience access to the underlying megdnn::Handle
    megdnn::Handle* operator->() const { return handle(); }

    //! underlying megdnn::Handle; ownership stays with this object
    megdnn::Handle* handle() const { return m_megdnn_handle.get(); }

    //! set the default debug level; return original setting.
    //! Not synchronized: concurrent callers race on sm_default_dbg_level.
    static int exchange_default_dbg_level(int level) {
        auto ret = sm_default_dbg_level;
        sm_default_dbg_level = level;
        return ret;
    }

#if MGB_NEED_MEGDNN_ASYNC_ERROR
    /*!
     * \brief get pointer to underlying AsyncErrorInfo
     *
     * return nullptr if the device does not need async error report.
     */
    megcore::AsyncErrorInfo* async_error_info_devptr() const {
        return m_async_error_info_devptr.get();
    }
#endif
};

/*!
 * \brief comp node implementation base that owns a CompNodeEnv
 *
 * CompNodeEnv::from_comp_node() reaches the env by casting CompNode::m_impl to
 * this type, so every concrete comp node implementation is expected to derive
 * from it.
 */
class CompNode::Impl : public CompNode::ImplBase {
public:
    //! mutable access to the env owned by this comp node
    CompNodeEnv& env() { return m_env; }

protected:
    //! reuse all ImplBase constructors
    using ImplBase::ImplBase;
    ~Impl() = default;

    //! the environment of this comp node
    CompNodeEnv m_env;
};

//! Definition of the inline CompNodeEnv::from_comp_node(): the env is owned by
//! the comp node's CompNode::Impl, reached through its m_impl pointer.
//! Asserts that \p node is valid.
const CompNodeEnv& CompNodeEnv::from_comp_node(const CompNode& node) {
    mgb_assert(node.valid());
    auto* impl = static_cast<CompNode::Impl*>(node.m_impl);
    return impl->env();
}

}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}