diff --git a/src/core/impl/comp_node/cpu/comp_node.cpp b/src/core/impl/comp_node/cpu/comp_node.cpp index 06867df5a5754548bd361a20554a06e14586d678..92633e155ae70d2f9d2ae818cfed932abc679ea6 100644 --- a/src/core/impl/comp_node/cpu/comp_node.cpp +++ b/src/core/impl/comp_node/cpu/comp_node.cpp @@ -604,12 +604,8 @@ class CpuCompNode::CompNodeRecorderImpl final : public CompNodeBaseImpl { using EventImpl::EventImpl; }; -//! TODO: because the x-code bug, see -//! https://github.com/tensorflow/tensorflow/issues/18356 -//! thread local is no support on IOS, -//! When update x-xode, this code should be deleted -#if !defined(IOS) && MGB_HAVE_THREAD - static thread_local SeqRecorderImpl* sm_cur_recorder; +#if MGB_HAVE_THREAD + static MGB_THREAD_LOCAL_PTR(SeqRecorderImpl) sm_cur_recorder; #else SeqRecorderImpl* sm_cur_recorder = nullptr; #endif @@ -822,9 +818,9 @@ public: } }; MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeRecorderImpl); -#if !defined(IOS) && MGB_HAVE_THREAD -thread_local CpuCompNode::SeqRecorderImpl* - CompNodeRecorderImpl::sm_cur_recorder = nullptr; +#if MGB_HAVE_THREAD +MGB_THREAD_LOCAL_PTR(CpuCompNode::SeqRecorderImpl) +CompNodeRecorderImpl::sm_cur_recorder = nullptr; #endif /* ======================== CpuCompNode ======================== */ diff --git a/src/core/include/megbrain/utils/thread.h b/src/core/include/megbrain/utils/thread.h index f94fd70bb5dd8edceeac6caa80b30c42a8ec9882..0e6dece5792efa4050b09e98cac234f4d3f9c066 100644 --- a/src/core/include/megbrain/utils/thread.h +++ b/src/core/include/megbrain/utils/thread.h @@ -14,6 +14,7 @@ #include "megbrain_build_config.h" #if MGB_HAVE_THREAD #include "./thread_impl_1.h" +#include "./thread_local.h" #else #include "./thread_impl_0.h" #endif diff --git a/src/core/include/megbrain/utils/thread_local.h b/src/core/include/megbrain/utils/thread_local.h new file mode 100644 index 0000000000000000000000000000000000000000..66c64713293c4b40337d10f1333a51145d9616eb --- /dev/null +++ b/src/core/include/megbrain/utils/thread_local.h @@ -0,0 +1,123 @@ +/** + * \file src/core/include/megbrain/utils/thread_local.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once +#include +#include +#include + +#if !defined(__APPLE__) +#define USE_STL_THREAD_LOCAL 1 +#else +#define USE_STL_THREAD_LOCAL 0 +#endif + +// clang-format off +#if defined(__APPLE__) +# if (__ENVIRONMENT_MAC_OS_X_VERSION_MIN_REQUIRED__ + 0) >= 101000 +# define USE_STL_THREAD_LOCAL 1 +# else +# define USE_STL_THREAD_LOCAL 0 +# endif +#endif + +#if USE_STL_THREAD_LOCAL +#define MGB_THREAD_LOCAL_PTR(T) thread_local T* +#else +#define MGB_THREAD_LOCAL_PTR(T) ThreadLocalPtr +#endif +// clang-format on + +#if !USE_STL_THREAD_LOCAL +#include + +namespace mgb { + +template +class ThreadLocalPtr { + struct ThreadData { + const ThreadLocalPtr* self = nullptr; + T** data = nullptr; + }; + pthread_key_t m_key; + std::function m_constructor = nullptr; + std::function m_destructor = nullptr; + + void move_to(T** data) { + if(void* d = pthread_getspecific(m_key)){ + *data = *static_cast(d)->data; + } + } + + T** get() const { + if (auto d = pthread_getspecific(m_key)) { + return static_cast(d)->data; + } + ThreadData* t_data = new ThreadData(); + t_data->data = m_constructor(); + t_data->self = this; + pthread_setspecific(m_key, t_data); + return t_data->data; + } + + static void exit(void* d) { + ThreadData* td = static_cast(d); + if (td && td->self->m_destructor) + td->self->m_destructor(td->data); + delete td; + } + +public: + ThreadLocalPtr( + std::function constructor, + std::function destructor = std::default_delete()) + : m_constructor(constructor), m_destructor(destructor) { + pthread_key_create(&m_key, exit); + } + + ThreadLocalPtr() + : ThreadLocalPtr(std::function([] { return new T*(); })) {} + + ThreadLocalPtr(std::nullptr_t) + : ThreadLocalPtr([] { return new T*(nullptr); }) {} + + ThreadLocalPtr(ThreadLocalPtr&& other) : ThreadLocalPtr() { + other.move_to(get()); + } + + ThreadLocalPtr& operator=(ThreadLocalPtr&& other) { + other.move_to(get()); + return *this; + } + ThreadLocalPtr& operator=(T* v) { + *get() = v; + return *this; + } + ~ThreadLocalPtr() { pthread_key_delete(m_key); } + + //!& operator like std thread_local + T** operator&() const { return get(); } + + //! use in if() + operator bool() const { return *get(); } + + //! directly access its member + T* operator->() const { return *get(); } + + //! convert to T* + operator T*() const { return *get(); } +}; + +} // namespace mgb + +#endif + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}