From 240a685fc358bd6ea756b6591bb4c7fef77ce560 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 11 Mar 2022 18:38:13 +0800 Subject: [PATCH] feat(opencl): opt lite-OpenCL api: opencl_clear_global_data and enable_opencl_deploy lite api GitOrigin-RevId: 9d932ff27e16011448f5d1c0fa63d0ed737d842c --- lite/src/global.cpp | 1 + lite/src/misc.h | 23 +++++++++++++++++++++++ lite/src/network_impl_base.h | 25 ++++++++++++++++++++++++- lite/test/test_network.cpp | 13 +++++++++++++ 4 files changed, 61 insertions(+), 1 deletion(-) diff --git a/lite/src/global.cpp b/lite/src/global.cpp index f436bd708..e77a020e0 100644 --- a/lite/src/global.cpp +++ b/lite/src/global.cpp @@ -16,6 +16,7 @@ #include "decryption/rc4_cryption.h" #include "lite/global.h" #include "misc.h" +#include "network_impl_base.h" #include "parse_info/default_parse.h" #include "parse_info/parse_info_base.h" diff --git a/lite/src/misc.h b/lite/src/misc.h index 2ff3c79a8..8d04c38cf 100644 --- a/lite/src/misc.h +++ b/lite/src/misc.h @@ -51,6 +51,29 @@ LITE_API std::string ssprintf(const char* fmt = 0, ...) */ LITE_API void print_log(LiteLogLevel level, const char* format = 0, ...) __attribute__((format(printf, 2, 3))); + +/*! + * \brief NonCopyableObj base. + */ +class NonCopyableObj { +public: + NonCopyableObj() {} + +private: + NonCopyableObj(const NonCopyableObj&); + NonCopyableObj& operator=(const NonCopyableObj&); +}; + +template +class Singleton : public NonCopyableObj { +public: + Singleton() {} + static T& Instance() { + static T _; + return _; + } +}; + } // namespace lite #if LITE_ENABLE_LOGGING diff --git a/lite/src/network_impl_base.h b/lite/src/network_impl_base.h index e6a05bb01..0d1e42430 100644 --- a/lite/src/network_impl_base.h +++ b/lite/src/network_impl_base.h @@ -16,10 +16,32 @@ #include "tensor_impl_base.h" #include "type_info.h" +#include #include namespace lite { +/*! + * \brief network reference count + */ +class NetworkRefCount : public Singleton { +public: + NetworkRefCount() : count(0) {} + + NetworkRefCount& operator++(int) { + ++count; + return *this; + } + NetworkRefCount& operator--(int) { + --count; + return *this; + } + int refcount() { return count; } + +private: + std::atomic count; +}; + /*! * \brief the Inner IO data struct, add some inner data from IO */ @@ -54,7 +76,8 @@ struct NetworkIOInner { */ class Network::NetworkImplBase : public DynTypeObj { public: - virtual ~NetworkImplBase() = default; + virtual ~NetworkImplBase() { NetworkRefCount::Instance()--; }; + NetworkImplBase() { NetworkRefCount::Instance()++; }; //! set the config of the network, include: //! the inference device diff --git a/lite/test/test_network.cpp b/lite/test/test_network.cpp index 78f139d5f..f971033fc 100644 --- a/lite/test/test_network.cpp +++ b/lite/test/test_network.cpp @@ -70,6 +70,19 @@ TEST(TestNetWork, Basic) { compare_lite_tensor(result_lite, result_mgb); } +TEST(TestNetWork, RefCount) { + Config config; + ASSERT_EQ(NetworkRefCount::Instance().refcount(), 0); + std::shared_ptr network = std::make_shared(config); + ASSERT_EQ(NetworkRefCount::Instance().refcount(), 1); + std::shared_ptr network_s = std::make_shared(config); + ASSERT_EQ(NetworkRefCount::Instance().refcount(), 2); + network.reset(); + ASSERT_EQ(NetworkRefCount::Instance().refcount(), 1); + network_s.reset(); + ASSERT_EQ(NetworkRefCount::Instance().refcount(), 0); +} + TEST(TestNetWork, SetDeviceId) { Config config; auto lite_tensor = get_input_data("./input_data.npy"); -- GitLab