From aa10b11e425c18a05aa3373f1436dcbf986dea8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com> Date: Thu, 4 Jun 2020 10:10:53 +0800 Subject: [PATCH] refactor any.h, test=develop (#3736) --- lite/core/context.h | 19 -- lite/core/kernel.h | 2 +- lite/model_parser/cpp/op_desc.cc | 2 +- lite/utils/any.cc | 10 +- lite/utils/any.h | 304 ++++++++++++++++++++++++++----- lite/utils/factory.h | 1 + lite/utils/logging.h | 2 +- 7 files changed, 262 insertions(+), 78 deletions(-) diff --git a/lite/core/context.h b/lite/core/context.h index bf2635da38..f606eeffaf 100644 --- a/lite/core/context.h +++ b/lite/core/context.h @@ -76,8 +76,6 @@ class Context { template <> class Context { public: - Context() {} - explicit Context(const NPUContext& ctx); // NOTE: InitOnce should only be used by ContextScheduler void InitOnce() {} void CopySharedTo(NPUContext* ctx) {} @@ -101,8 +99,6 @@ class Context { template <> class Context { public: - Context() {} - explicit Context(const APUContext& ctx); // NOTE: InitOnce should only be used by ContextScheduler void InitOnce() {} void CopySharedTo(APUContext* ctx) {} @@ -116,8 +112,6 @@ class Context { template <> class Context { public: - Context() {} - explicit Context(const BMContext& ctx); // NOTE: InitOnce should only be used by ContextScheduler void InitOnce() { TargetWrapperBM::SetDevice(TargetWrapperBM::GetDevice()); } void CopySharedTo(BMContext* ctx) {} @@ -131,8 +125,6 @@ class Context { template <> class Context { public: - Context() {} - explicit Context(const RKNPUContext& ctx); // NOTE: InitOnce should only be used by ContextScheduler void InitOnce() {} void CopySharedTo(RKNPUContext* ctx) {} @@ -146,9 +138,6 @@ class Context { template <> class Context { public: - Context() {} - explicit Context(const XPUContext& ctx); - // NOTE: InitOnce should only be used by ContextScheduler void InitOnce() {} @@ -200,11 +189,6 @@ class Context { template <> class Context { public: - Context() {} - explicit Context(const ARMContext& ctx); - - ARMContext& operator=(const ARMContext& ctx) {} - // NOTE: InitOnce should only be used by ContextScheduler void InitOnce() { DeviceInfo::Init(); } @@ -246,7 +230,6 @@ class Context { template <> class Context { public: - Context() {} void InitOnce() {} FPGAContext& operator=(const FPGAContext& ctx) {} @@ -340,8 +323,6 @@ class Context { template <> class Context { public: - Context() {} - // NOTE: InitOnce should only be used by ContextScheduler void InitOnce() {} diff --git a/lite/core/kernel.h b/lite/core/kernel.h index cbd9e8afff..5f98eb0d3e 100644 --- a/lite/core/kernel.h +++ b/lite/core/kernel.h @@ -115,7 +115,7 @@ class KernelBase { } template void SetParam(T param) { - param_.set(param); + param_.set(param); } template P& Param() const { diff --git a/lite/model_parser/cpp/op_desc.cc b/lite/model_parser/cpp/op_desc.cc index f4be0106fc..712ef64fa9 100644 --- a/lite/model_parser/cpp/op_desc.cc +++ b/lite/model_parser/cpp/op_desc.cc @@ -24,7 +24,7 @@ namespace cpp { template <> \ void OpDesc::SetAttr(const std::string& name, const T& v) { \ attr_types_[name] = AttrType::repr__; \ - attrs_[name].set(v); \ + attrs_[name].set(v); \ } SET_ATTR_IMPL(int32_t, INT); diff --git a/lite/utils/any.cc b/lite/utils/any.cc index fde832aae0..c58bcf5716 100644 --- a/lite/utils/any.cc +++ b/lite/utils/any.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,11 +13,3 @@ // limitations under the License. #include "lite/utils/any.h" - -namespace paddle { -namespace lite { - -size_t Any::kInvalidType{typeid(void).hash_code()}; - -} // namespace lite -} // namespace paddle diff --git a/lite/utils/any.h b/lite/utils/any.h index 3f7029e98c..f658e4e658 100644 --- a/lite/utils/any.h +++ b/lite/utils/any.h @@ -13,8 +13,12 @@ // limitations under the License. #pragma once -#include -#include +#include +#include +#include +#include +#include + #include "lite/utils/cp_logging.h" namespace paddle { @@ -22,67 +26,273 @@ namespace lite { class Any { public: - Any() = default; - explicit Any(const Any& other) { - type_ = other.type_; - data_ = other.clone_data_(other.data_); - deleter_ = other.deleter_; - clone_data_ = other.clone_data_; - } + inline Any() = default; + inline explicit Any(Any&& other); + inline explicit Any(const Any& other); template - void set(const T& v) { - set(); - *get_mutable() = v; - } + void set(); + + template + void set(T&& other); + + template + const T& get() const; + + template + T* get_mutable(); + + template + inline explicit Any(T&& other); + + inline ~Any(); + + inline Any& operator=(Any&& other); + inline Any& operator=(const Any& other); + + template + inline Any& operator=(T&& other); + + inline bool empty() const; + inline bool valid() const; + inline void clear(); + inline void swap(Any& other); + inline const std::type_info& type() const; + + template + inline void construct(Args&&... args); + + private: + template + class TypeOnHeap; + + template + class TypeOnStack; + + template + class TypeInfo; + + static const size_t kStack = sizeof(void*) * 3; + static const size_t kAlign = sizeof(void*); + + union Data { + std::aligned_storage::type stack; + void* pheap; + }; + + struct Type { + void (*destroy)(Data* data); + void (*create_from_data)(Data* dst, const Data& src); + const std::type_info* ptype_info; + }; template - void set() { - if (type_ != kInvalidType) { - CHECK(type_ == typeid(T).hash_code()); + struct data_on_stack { + static const bool value = ((alignof(T) <= kAlign) && (sizeof(T) <= kStack)); + }; + + inline void construct(Any&& other); + inline void construct(const Any& other); + + template + inline void check_type() const; + + template + inline void check_type_by_name() const; + + const Type* type_{nullptr}; + Data data_; +}; + +template +inline Any::Any(T&& other) { + typedef typename std::decay::type DT; + if (std::is_same::value) { + this->construct(std::forward(other)); + } else { + static_assert(std::is_copy_constructible
::value, + "Any can only hold value that is copy constructable"); + type_ = TypeInfo
::get_type(); + if (data_on_stack
::value) { +#pragma GCC diagnostic push +#if 6 <= __GNUC__ +#pragma GCC diagnostic ignored "-Wplacement-new" +#endif + new (&(data_.stack)) DT(std::forward(other)); +#pragma GCC diagnostic pop } else { - type_ = typeid(T).hash_code(); - deleter_ = [&](void** data) { - delete static_cast(*data); - *data = nullptr; - }; - clone_data_ = [&](void* data) { - T* res = new T; - CHECK(data) << "data pointer is nullptr"; - *res = *static_cast(data); - return res; - }; + data_.pheap = new DT(std::forward(other)); } - data_ = new T; } +} - template - const T& get() const { - CHECK(data_); - CHECK(type_ == typeid(T).hash_code()); - return *static_cast(data_); +inline Any::Any(Any&& other) { this->construct(std::move(other)); } + +inline Any::Any(const Any& other) { this->construct(other); } + +inline void Any::construct(Any&& other) { + type_ = other.type_; + data_ = other.data_; + other.type_ = nullptr; +} + +inline void Any::construct(const Any& other) { + type_ = other.type_; + if (type_ != nullptr) { + type_->create_from_data(&data_, other.data_); } - template - T* get_mutable() { - CHECK(data_); - CHECK(type_ == typeid(T).hash_code()); - return static_cast(data_); +} + +template +inline void Any::construct(Args&&... args) { + clear(); + typedef typename std::decay::type DT; + type_ = TypeInfo
::get_type(); + if (data_on_stack
::value) { +#pragma GCC diagnostic push +#if 6 <= __GNUC__ +#pragma GCC diagnostic ignored "-Wplacement-new" +#endif + new (&(data_.stack)) DT(std::forward(args)...); +#pragma GCC diagnostic pop + } else { + data_.pheap = new DT(std::forward(args)...); } +} + +template +void Any::set() { + this->construct(); +} + +template +void Any::set(T&& other) { + this->construct(std::forward(other)); +} - bool valid() const { return (data_ != nullptr); } +inline Any::~Any() { this->clear(); } - ~Any() { - if (valid()) { - deleter_(&data_); +inline Any& Any::operator=(Any&& other) { + Any(std::move(other)).swap(*this); + return *this; +} + +inline Any& Any::operator=(const Any& other) { + Any(other).swap(*this); + return *this; +} + +template +inline Any& Any::operator=(T&& other) { + Any(std::forward(other)).swap(*this); + return *this; +} + +inline void Any::swap(Any& other) { + std::swap(type_, other.type_); + std::swap(data_, other.data_); +} + +inline void Any::clear() { + if (type_ != nullptr) { + if (type_->destroy != nullptr) { + type_->destroy(&data_); } + type_ = nullptr; + } +} + +inline bool Any::empty() const { return type_ == nullptr; } + +inline bool Any::valid() const { return empty() == false; } + +inline const std::type_info& Any::type() const { + if (type_ != nullptr) { + return *(type_->ptype_info); + } else { + return typeid(void); + } +} + +template +inline void Any::check_type() const { + CHECK_EQ((type_ == nullptr), false); + CHECK_EQ((*(type_->ptype_info) == typeid(T)), true); +} + +template +inline void Any::check_type_by_name() const { + CHECK_EQ((type_ == nullptr), false); + CHECK_EQ(strcmp(type_->ptype_info->name(), typeid(T).name()), 0); +} + +template +inline const T& Any::get() const { + this->check_type(); + return *Any::TypeInfo::get_ptr(&(this->data_)); +} + +template +T* Any::get_mutable() { + return Any::TypeInfo::get_ptr(&(this->data_)); +} + +template +class Any::TypeOnHeap { + public: + inline static T* get_ptr(Any::Data* data) { + return static_cast(data->pheap); + } + inline static const T* get_ptr(const Any::Data* data) { + return static_cast(data->pheap); + } + inline static void create_from_data(Any::Data* dst, const Any::Data& data) { + dst->pheap = new T(*get_ptr(&data)); + } + inline static void destroy(Data* data) { + delete static_cast(data->pheap); + } +}; + +template +class Any::TypeOnStack { + public: + inline static T* get_ptr(Any::Data* data) { + return reinterpret_cast(&(data->stack)); + } + inline static const T* get_ptr(const Any::Data* data) { + return reinterpret_cast(&(data->stack)); + } + inline static void create_from_data(Any::Data* dst, const Any::Data& data) { + new (&(dst->stack)) T(*get_ptr(&data)); + } + inline static void destroy(Data* data) { + T* dptr = reinterpret_cast(&(data->stack)); + dptr->~T(); + } +}; + +template +class Any::TypeInfo : public std::conditional::value, + Any::TypeOnStack, + Any::TypeOnHeap>::type { + public: + inline static const Type* get_type() { + static TypeInfo tp; + return &(tp.type_); } private: - static size_t kInvalidType; - size_t type_{kInvalidType}; - void* data_{nullptr}; - std::function deleter_; - std::function clone_data_; + Type type_; + TypeInfo() { + if (std::is_pod::value && data_on_stack::value) { + type_.destroy = nullptr; + } else { + type_.destroy = TypeInfo::destroy; + } + type_.create_from_data = TypeInfo::create_from_data; + type_.ptype_info = &typeid(T); + } }; } // namespace lite diff --git a/lite/utils/factory.h b/lite/utils/factory.h index 9685c16544..d286ceb42c 100644 --- a/lite/utils/factory.h +++ b/lite/utils/factory.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include #include #include diff --git a/lite/utils/logging.h b/lite/utils/logging.h index e30fe08b22..f292f220c0 100644 --- a/lite/utils/logging.h +++ b/lite/utils/logging.h @@ -103,7 +103,7 @@ static int gettimeofday(struct timeval* tp, void* tzp) { #define _CHECK_BINARY(x, cmp, y) CHECK(x cmp y) #else #define CHECK(x) if (!(x)) paddle::lite::LogMessageFatal(__FILE__, __FUNCTION__, __LINE__).stream() << "Check failed: " #x << ": " // NOLINT(*) -#define _CHECK_BINARY(x, cmp, y) CHECK(x cmp y) << x << "!" #cmp << y << " " +#define _CHECK_BINARY(x, cmp, y) CHECK((x cmp y)) << (x) << "!" #cmp << (y) << " " // NOLINT(*) #endif // clang-format on -- GitLab