diff --git a/lite/src/global.cpp b/lite/src/global.cpp index f436bd70889f55b7445fc4632e867217b4526766..e77a020e040a5bf1de8f662adb7bb14447b9589c 100644 --- a/lite/src/global.cpp +++ b/lite/src/global.cpp @@ -16,6 +16,7 @@ #include "decryption/rc4_cryption.h" #include "lite/global.h" #include "misc.h" +#include "network_impl_base.h" #include "parse_info/default_parse.h" #include "parse_info/parse_info_base.h" diff --git a/lite/src/misc.h b/lite/src/misc.h index 2ff3c79a8e8936718133d8cbe13f9c482a3f76cb..8d04c38cf0713d8465e78426bb839352a456cb12 100644 --- a/lite/src/misc.h +++ b/lite/src/misc.h @@ -51,6 +51,29 @@ LITE_API std::string ssprintf(const char* fmt = 0, ...) */ LITE_API void print_log(LiteLogLevel level, const char* format = 0, ...) __attribute__((format(printf, 2, 3))); + +/*! + * \brief NonCopyableObj base. + */ +class NonCopyableObj { +public: + NonCopyableObj() {} + +private: + NonCopyableObj(const NonCopyableObj&); + NonCopyableObj& operator=(const NonCopyableObj&); +}; + +template +class Singleton : public NonCopyableObj { +public: + Singleton() {} + static T& Instance() { + static T _; + return _; + } +}; + } // namespace lite #if LITE_ENABLE_LOGGING diff --git a/lite/src/network_impl_base.h b/lite/src/network_impl_base.h index e6a05bb01a80f071a5a3f2e5d4ce0ea9d96918fd..0d1e42430c0b8008284add8c36c3c6de438f3340 100644 --- a/lite/src/network_impl_base.h +++ b/lite/src/network_impl_base.h @@ -16,10 +16,32 @@ #include "tensor_impl_base.h" #include "type_info.h" +#include #include namespace lite { +/*! + * \brief network reference count + */ +class NetworkRefCount : public Singleton { +public: + NetworkRefCount() : count(0) {} + + NetworkRefCount& operator++(int) { + ++count; + return *this; + } + NetworkRefCount& operator--(int) { + --count; + return *this; + } + int refcount() { return count; } + +private: + std::atomic count; +}; + /*! * \brief the Inner IO data struct, add some inner data from IO */ @@ -54,7 +76,8 @@ struct NetworkIOInner { */ class Network::NetworkImplBase : public DynTypeObj { public: - virtual ~NetworkImplBase() = default; + virtual ~NetworkImplBase() { NetworkRefCount::Instance()--; }; + NetworkImplBase() { NetworkRefCount::Instance()++; }; //! set the config of the network, include: //! the inference device diff --git a/lite/test/test_network.cpp b/lite/test/test_network.cpp index 78f139d5fb669eeb218164b13fce4b7c145df766..f971033fc3fe25ac7c6190350e3e85148901fc47 100644 --- a/lite/test/test_network.cpp +++ b/lite/test/test_network.cpp @@ -70,6 +70,19 @@ TEST(TestNetWork, Basic) { compare_lite_tensor(result_lite, result_mgb); } +TEST(TestNetWork, RefCount) { + Config config; + ASSERT_EQ(NetworkRefCount::Instance().refcount(), 0); + std::shared_ptr network = std::make_shared(config); + ASSERT_EQ(NetworkRefCount::Instance().refcount(), 1); + std::shared_ptr network_s = std::make_shared(config); + ASSERT_EQ(NetworkRefCount::Instance().refcount(), 2); + network.reset(); + ASSERT_EQ(NetworkRefCount::Instance().refcount(), 1); + network_s.reset(); + ASSERT_EQ(NetworkRefCount::Instance().refcount(), 0); +} + TEST(TestNetWork, SetDeviceId) { Config config; auto lite_tensor = get_input_data("./input_data.npy");