未验证 提交 aa10b11e 编写于 作者: 石晓伟 提交者: GitHub

refactor any.h, test=develop (#3736)

上级 9e361a4d
...@@ -76,8 +76,6 @@ class Context<TargetType::kHost> { ...@@ -76,8 +76,6 @@ class Context<TargetType::kHost> {
template <> template <>
class Context<TargetType::kNPU> { class Context<TargetType::kNPU> {
public: public:
Context() {}
explicit Context(const NPUContext& ctx);
// NOTE: InitOnce should only be used by ContextScheduler // NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {} void InitOnce() {}
void CopySharedTo(NPUContext* ctx) {} void CopySharedTo(NPUContext* ctx) {}
...@@ -101,8 +99,6 @@ class Context<TargetType::kNPU> { ...@@ -101,8 +99,6 @@ class Context<TargetType::kNPU> {
template <> template <>
class Context<TargetType::kAPU> { class Context<TargetType::kAPU> {
public: public:
Context() {}
explicit Context(const APUContext& ctx);
// NOTE: InitOnce should only be used by ContextScheduler // NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {} void InitOnce() {}
void CopySharedTo(APUContext* ctx) {} void CopySharedTo(APUContext* ctx) {}
...@@ -116,8 +112,6 @@ class Context<TargetType::kAPU> { ...@@ -116,8 +112,6 @@ class Context<TargetType::kAPU> {
template <> template <>
class Context<TargetType::kBM> { class Context<TargetType::kBM> {
public: public:
Context() {}
explicit Context(const BMContext& ctx);
// NOTE: InitOnce should only be used by ContextScheduler // NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() { TargetWrapperBM::SetDevice(TargetWrapperBM::GetDevice()); } void InitOnce() { TargetWrapperBM::SetDevice(TargetWrapperBM::GetDevice()); }
void CopySharedTo(BMContext* ctx) {} void CopySharedTo(BMContext* ctx) {}
...@@ -131,8 +125,6 @@ class Context<TargetType::kBM> { ...@@ -131,8 +125,6 @@ class Context<TargetType::kBM> {
template <> template <>
class Context<TargetType::kRKNPU> { class Context<TargetType::kRKNPU> {
public: public:
Context() {}
explicit Context(const RKNPUContext& ctx);
// NOTE: InitOnce should only be used by ContextScheduler // NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {} void InitOnce() {}
void CopySharedTo(RKNPUContext* ctx) {} void CopySharedTo(RKNPUContext* ctx) {}
...@@ -146,9 +138,6 @@ class Context<TargetType::kRKNPU> { ...@@ -146,9 +138,6 @@ class Context<TargetType::kRKNPU> {
template <> template <>
class Context<TargetType::kXPU> { class Context<TargetType::kXPU> {
public: public:
Context() {}
explicit Context(const XPUContext& ctx);
// NOTE: InitOnce should only be used by ContextScheduler // NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {} void InitOnce() {}
...@@ -200,11 +189,6 @@ class Context<TargetType::kXPU> { ...@@ -200,11 +189,6 @@ class Context<TargetType::kXPU> {
template <> template <>
class Context<TargetType::kARM> { class Context<TargetType::kARM> {
public: public:
Context() {}
explicit Context(const ARMContext& ctx);
ARMContext& operator=(const ARMContext& ctx) {}
// NOTE: InitOnce should only be used by ContextScheduler // NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() { DeviceInfo::Init(); } void InitOnce() { DeviceInfo::Init(); }
...@@ -246,7 +230,6 @@ class Context<TargetType::kARM> { ...@@ -246,7 +230,6 @@ class Context<TargetType::kARM> {
template <> template <>
class Context<TargetType::kFPGA> { class Context<TargetType::kFPGA> {
public: public:
Context() {}
void InitOnce() {} void InitOnce() {}
FPGAContext& operator=(const FPGAContext& ctx) {} FPGAContext& operator=(const FPGAContext& ctx) {}
...@@ -340,8 +323,6 @@ class Context<TargetType::kMLU> { ...@@ -340,8 +323,6 @@ class Context<TargetType::kMLU> {
template <> template <>
class Context<TargetType::kX86> { class Context<TargetType::kX86> {
public: public:
Context() {}
// NOTE: InitOnce should only be used by ContextScheduler // NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {} void InitOnce() {}
......
...@@ -115,7 +115,7 @@ class KernelBase { ...@@ -115,7 +115,7 @@ class KernelBase {
} }
template <typename T> template <typename T>
void SetParam(T param) { void SetParam(T param) {
param_.set<T>(param); param_.set(param);
} }
template <typename P> template <typename P>
P& Param() const { P& Param() const {
......
...@@ -24,7 +24,7 @@ namespace cpp { ...@@ -24,7 +24,7 @@ namespace cpp {
template <> \ template <> \
void OpDesc::SetAttr<T>(const std::string& name, const T& v) { \ void OpDesc::SetAttr<T>(const std::string& name, const T& v) { \
attr_types_[name] = AttrType::repr__; \ attr_types_[name] = AttrType::repr__; \
attrs_[name].set<T>(v); \ attrs_[name].set(v); \
} }
SET_ATTR_IMPL(int32_t, INT); SET_ATTR_IMPL(int32_t, INT);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -13,11 +13,3 @@ ...@@ -13,11 +13,3 @@
// limitations under the License. // limitations under the License.
#include "lite/utils/any.h" #include "lite/utils/any.h"
namespace paddle {
namespace lite {
size_t Any::kInvalidType{typeid(void).hash_code()};
} // namespace lite
} // namespace paddle
...@@ -13,8 +13,12 @@ ...@@ -13,8 +13,12 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <functional> #include <algorithm>
#include <set> #include <cstring>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include "lite/utils/cp_logging.h" #include "lite/utils/cp_logging.h"
namespace paddle { namespace paddle {
...@@ -22,67 +26,273 @@ namespace lite { ...@@ -22,67 +26,273 @@ namespace lite {
class Any { class Any {
public: public:
Any() = default; inline Any() = default;
explicit Any(const Any& other) { inline explicit Any(Any&& other);
type_ = other.type_; inline explicit Any(const Any& other);
data_ = other.clone_data_(other.data_);
deleter_ = other.deleter_;
clone_data_ = other.clone_data_;
}
template <typename T> template <typename T>
void set(const T& v) { void set();
set<T>();
*get_mutable<T>() = v;
}
template <typename T> template <typename T>
void set() { void set(T&& other);
if (type_ != kInvalidType) {
CHECK(type_ == typeid(T).hash_code()); template <typename T>
} else { const T& get() const;
type_ = typeid(T).hash_code();
deleter_ = [&](void** data) { template <typename T>
delete static_cast<T*>(*data); T* get_mutable();
*data = nullptr;
template <typename T>
inline explicit Any(T&& other);
inline ~Any();
inline Any& operator=(Any&& other);
inline Any& operator=(const Any& other);
template <typename T>
inline Any& operator=(T&& other);
inline bool empty() const;
inline bool valid() const;
inline void clear();
inline void swap(Any& other);
inline const std::type_info& type() const;
template <typename T, typename... Args>
inline void construct(Args&&... args);
private:
template <typename T>
class TypeOnHeap;
template <typename T>
class TypeOnStack;
template <typename T>
class TypeInfo;
static const size_t kStack = sizeof(void*) * 3;
static const size_t kAlign = sizeof(void*);
union Data {
std::aligned_storage<kStack, kAlign>::type stack;
void* pheap;
}; };
clone_data_ = [&](void* data) {
T* res = new T; struct Type {
CHECK(data) << "data pointer is nullptr"; void (*destroy)(Data* data);
*res = *static_cast<T*>(data); void (*create_from_data)(Data* dst, const Data& src);
return res; const std::type_info* ptype_info;
}; };
template <typename T>
struct data_on_stack {
static const bool value = ((alignof(T) <= kAlign) && (sizeof(T) <= kStack));
};
inline void construct(Any&& other);
inline void construct(const Any& other);
template <typename T>
inline void check_type() const;
template <typename T>
inline void check_type_by_name() const;
const Type* type_{nullptr};
Data data_;
};
template <typename T>
inline Any::Any(T&& other) {
typedef typename std::decay<T>::type DT;
if (std::is_same<DT, Any>::value) {
this->construct(std::forward<T>(other));
} else {
static_assert(std::is_copy_constructible<DT>::value,
"Any can only hold value that is copy constructable");
type_ = TypeInfo<DT>::get_type();
if (data_on_stack<DT>::value) {
#pragma GCC diagnostic push
#if 6 <= __GNUC__
#pragma GCC diagnostic ignored "-Wplacement-new"
#endif
new (&(data_.stack)) DT(std::forward<T>(other));
#pragma GCC diagnostic pop
} else {
data_.pheap = new DT(std::forward<T>(other));
} }
data_ = new T;
} }
}
template <typename T> inline Any::Any(Any&& other) { this->construct(std::move(other)); }
const T& get() const {
CHECK(data_); inline Any::Any(const Any& other) { this->construct(other); }
CHECK(type_ == typeid(T).hash_code());
return *static_cast<T*>(data_); inline void Any::construct(Any&& other) {
type_ = other.type_;
data_ = other.data_;
other.type_ = nullptr;
}
inline void Any::construct(const Any& other) {
type_ = other.type_;
if (type_ != nullptr) {
type_->create_from_data(&data_, other.data_);
} }
template <typename T> }
T* get_mutable() {
CHECK(data_); template <typename T, typename... Args>
CHECK(type_ == typeid(T).hash_code()); inline void Any::construct(Args&&... args) {
return static_cast<T*>(data_); clear();
typedef typename std::decay<T>::type DT;
type_ = TypeInfo<DT>::get_type();
if (data_on_stack<DT>::value) {
#pragma GCC diagnostic push
#if 6 <= __GNUC__
#pragma GCC diagnostic ignored "-Wplacement-new"
#endif
new (&(data_.stack)) DT(std::forward<Args>(args)...);
#pragma GCC diagnostic pop
} else {
data_.pheap = new DT(std::forward<Args>(args)...);
}
}
template <typename T>
void Any::set() {
this->construct<T>();
}
template <typename T>
void Any::set(T&& other) {
this->construct<T>(std::forward<T>(other));
}
inline Any::~Any() { this->clear(); }
inline Any& Any::operator=(Any&& other) {
Any(std::move(other)).swap(*this);
return *this;
}
inline Any& Any::operator=(const Any& other) {
Any(other).swap(*this);
return *this;
}
template <typename T>
inline Any& Any::operator=(T&& other) {
Any(std::forward<T>(other)).swap(*this);
return *this;
}
inline void Any::swap(Any& other) {
std::swap(type_, other.type_);
std::swap(data_, other.data_);
}
inline void Any::clear() {
if (type_ != nullptr) {
if (type_->destroy != nullptr) {
type_->destroy(&data_);
}
type_ = nullptr;
}
}
inline bool Any::empty() const { return type_ == nullptr; }
inline bool Any::valid() const { return empty() == false; }
inline const std::type_info& Any::type() const {
if (type_ != nullptr) {
return *(type_->ptype_info);
} else {
return typeid(void);
} }
}
template <typename T>
inline void Any::check_type() const {
CHECK_EQ((type_ == nullptr), false);
CHECK_EQ((*(type_->ptype_info) == typeid(T)), true);
}
template <typename T>
inline void Any::check_type_by_name() const {
CHECK_EQ((type_ == nullptr), false);
CHECK_EQ(strcmp(type_->ptype_info->name(), typeid(T).name()), 0);
}
bool valid() const { return (data_ != nullptr); } template <typename T>
inline const T& Any::get() const {
this->check_type<T>();
return *Any::TypeInfo<T>::get_ptr(&(this->data_));
}
~Any() { template <typename T>
if (valid()) { T* Any::get_mutable() {
deleter_(&data_); return Any::TypeInfo<T>::get_ptr(&(this->data_));
}
template <typename T>
class Any::TypeOnHeap {
public:
inline static T* get_ptr(Any::Data* data) {
return static_cast<T*>(data->pheap);
}
inline static const T* get_ptr(const Any::Data* data) {
return static_cast<const T*>(data->pheap);
} }
inline static void create_from_data(Any::Data* dst, const Any::Data& data) {
dst->pheap = new T(*get_ptr(&data));
}
inline static void destroy(Data* data) {
delete static_cast<T*>(data->pheap);
}
};
template <typename T>
class Any::TypeOnStack {
public:
inline static T* get_ptr(Any::Data* data) {
return reinterpret_cast<T*>(&(data->stack));
}
inline static const T* get_ptr(const Any::Data* data) {
return reinterpret_cast<const T*>(&(data->stack));
}
inline static void create_from_data(Any::Data* dst, const Any::Data& data) {
new (&(dst->stack)) T(*get_ptr(&data));
}
inline static void destroy(Data* data) {
T* dptr = reinterpret_cast<T*>(&(data->stack));
dptr->~T();
}
};
template <typename T>
class Any::TypeInfo : public std::conditional<Any::data_on_stack<T>::value,
Any::TypeOnStack<T>,
Any::TypeOnHeap<T>>::type {
public:
inline static const Type* get_type() {
static TypeInfo<T> tp;
return &(tp.type_);
} }
private: private:
static size_t kInvalidType; Type type_;
size_t type_{kInvalidType}; TypeInfo() {
void* data_{nullptr}; if (std::is_pod<T>::value && data_on_stack<T>::value) {
std::function<void(void**)> deleter_; type_.destroy = nullptr;
std::function<void*(void*)> clone_data_; } else {
type_.destroy = TypeInfo<T>::destroy;
}
type_.create_from_data = TypeInfo<T>::create_from_data;
type_.ptype_info = &typeid(T);
}
}; };
} // namespace lite } // namespace lite
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <functional>
#include <iostream> #include <iostream>
#include <list> #include <list>
#include <map> #include <map>
......
...@@ -103,7 +103,7 @@ static int gettimeofday(struct timeval* tp, void* tzp) { ...@@ -103,7 +103,7 @@ static int gettimeofday(struct timeval* tp, void* tzp) {
#define _CHECK_BINARY(x, cmp, y) CHECK(x cmp y) #define _CHECK_BINARY(x, cmp, y) CHECK(x cmp y)
#else #else
#define CHECK(x) if (!(x)) paddle::lite::LogMessageFatal(__FILE__, __FUNCTION__, __LINE__).stream() << "Check failed: " #x << ": " // NOLINT(*) #define CHECK(x) if (!(x)) paddle::lite::LogMessageFatal(__FILE__, __FUNCTION__, __LINE__).stream() << "Check failed: " #x << ": " // NOLINT(*)
#define _CHECK_BINARY(x, cmp, y) CHECK(x cmp y) << x << "!" #cmp << y << " " #define _CHECK_BINARY(x, cmp, y) CHECK((x cmp y)) << (x) << "!" #cmp << (y) << " " // NOLINT(*)
#endif #endif
// clang-format on // clang-format on
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册