提交 123c79fa 编写于 作者: S Superjomn

refactor type system

上级 9d6a0c88
...@@ -13,126 +13,118 @@ ...@@ -13,126 +13,118 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/type_system.h" #include "paddle/fluid/lite/core/type_system.h"
#include "type_system.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
// ------------------------- GetType specification ---------------------------- // ------------------------- GetType specification ----------------------------
template <> // ------------------------- end GetType specification ------------------------
const Type*
Type::Get<false /*is_unsupported*/, false /*is_tensor*/, TargetType::kHost,
PrecisionType::kFloat, DataLayoutType::kNCHW>() {
static UnsupportedTy x;
return &x;
}
template <> size_t ParamTypeRegistry::KernelIdTy::hash() const {
const Type* std::hash<std::string> h;
Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kX86, size_t hash = h(kernel_type);
PrecisionType::kFloat, DataLayoutType::kNCHW>() { hash = hash_combine(hash, place.hash());
static TensorFp32NCHWTy x(TargetType::kX86); hash = hash_combine(hash, std::hash<int>()(static_cast<int>(io)));
return &x; hash = hash_combine(hash, std::hash<std::string>()(arg_name));
return hash;
} }
template <> std::ostream &operator<<(std::ostream &os, const Type &other) {
const Type* if (other.IsUnsupported()) {
Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kHost, os << "<Unsupported>";
PrecisionType::kFloat, DataLayoutType::kNCHW>() { return os;
static TensorFp32NCHWTy x(TargetType::kHost); }
return &x; if (other.IsVoid()) {
os << "<Void>";
return os;
}
if (other.IsTensor()) {
os << "<Tensor:";
} else {
os << "<";
}
os << TargetToStr(other.target()) << "/" << PrecisionToStr(other.precision())
<< "/" << DataLayoutToStr(other.layout()) << ">";
return os;
} }
template <> const Type *Type::GetTensorTy(TargetType target, PrecisionType precision,
const Type* Type::Get<UnsupportedTy>(TargetType target) { DataLayoutType layout, int device) {
return Get<false, false, TargetType::kHost, PrecisionType::kFloat, // NOTE quite naive implementation here, but not performance sensitive.
DataLayoutType::kNCHW>(); DataType::ID type_id = DataType::ID::Tensor;
}
template <TargetType Target> #define HASH_ONE(x) v = hash_combine(v, hasher(static_cast<int>(x)))
TensorListAnyTy* GetTensorListAnyTy() {
static TensorListAnyTy x(Target);
return &x;
}
template <TargetType Target>
TensorAnyTy* GetTensorAnyTy() {
static TensorAnyTy x(Target);
return &x;
}
template <> std::hash<int> hasher;
const Type* Type::Get<TensorListAnyTy>(TargetType target) { size_t v = hasher(static_cast<int>(type_id));
switch (target) { HASH_ONE(target);
case TargetType::kHost: HASH_ONE(precision);
return GetTensorListAnyTy<TARGET(kHost)>(); HASH_ONE(layout);
case TargetType::kCUDA: HASH_ONE(device);
return GetTensorListAnyTy<TARGET(kCUDA)>(); #undef HASH_ONE
case TargetType::kX86:
return GetTensorListAnyTy<TARGET(kX86)>();
default:
LOG(FATAL) << "unsupported type";
}
}
template <> std::stringstream name;
const Type* Type::Get<TensorAnyTy>(TargetType target) { name << "Tensor<";
switch (target) { name << TargetToStr(target) << ",";
case TargetType::kHost: name << PrecisionToStr(precision) << ",";
return GetTensorAnyTy<TARGET(kHost)>(); name << DataLayoutToStr(layout) << ",";
case TargetType::kCUDA: name << device;
return GetTensorAnyTy<TARGET(kCUDA)>(); name << ">";
case TargetType::kX86:
return GetTensorAnyTy<TARGET(kX86)>(); auto it = type_repo_.find(v);
default: if (it == type_repo_.end()) {
LOG(FATAL) << "unsupported type"; // The Types should alive across the process life, no need to delete.
type_repo_[v] =
new Type(type_id, name.str(), target, precision, layout, device);
} }
return type_repo_[v];
} }
template <TargetType Target> const Type *Type::GetTensorListTy(TargetType target, PrecisionType precision,
const Type* GetTensorFp32NCHWTy() { DataLayoutType layout, int device) {
static TensorFp32NCHWTy x(Target); DataType::ID type_id = DataType::ID::TensorList;
return &x;
}
template <> #define HASH_ONE(x) v = hash_combine(v, hasher(static_cast<int>(x)))
const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) {
switch (target) {
case TARGET(kHost):
return GetTensorFp32NCHWTy<TARGET(kHost)>();
case TARGET(kCUDA):
return GetTensorFp32NCHWTy<TARGET(kCUDA)>();
case TARGET(kX86):
return GetTensorFp32NCHWTy<TARGET(kX86)>();
default:
LOG(FATAL) << "unsupported target Type " << TargetToStr(target);
}
return nullptr;
}
const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown, std::hash<int> hasher;
bool is_tensor, Place place) { size_t v = hasher(static_cast<int>(type_id));
using id_t = DataTypeBase::ID; HASH_ONE(target);
switch (type_id) { HASH_ONE(precision);
case id_t::Tensor_Any: HASH_ONE(layout);
return Type::Get<TensorAnyTy>(place.target); HASH_ONE(device);
case id_t::Tensor_Fp32_NCHW: #undef HASH_ONE
return Type::Get<TensorFp32NCHWTy>(place.target);
case id_t::TensorList_Any: std::stringstream name;
return Type::Get<TensorListAnyTy>(place.target); name << "TensorList<";
default: name << TargetToStr(target) << ",";
LOG(FATAL) << "unsupported type"; name << PrecisionToStr(precision) << ",";
} name << DataLayoutToStr(layout) << ",";
return nullptr; name << device;
name << ">";
if (!type_repo_[v])
// The Types should alive across the process life, no need to delete.
type_repo_[v] =
new Type(type_id, name.str(), target, precision, layout, device);
return type_repo_[v];
} }
// ------------------------- end GetType specification ------------------------ const Type *Type::GetUnsupportedTy() {
std::hash<int> hasher;
size_t v = hasher(static_cast<int>(DataType::ID::Unsupported));
if (!type_repo_[v])
type_repo_[v] =
new Type(DataType::ID::Unsupported, "Unsupported", TARGET(kUnk),
PRECISION(kUnk), DATALAYOUT(kUnk), -1);
}
size_t ParamTypeRegistry::KernelIdTy::hash() const { const Type *Type::GetVoidTy() {
std::hash<std::string> h; std::hash<int> hasher;
size_t hash = h(kernel_type); size_t v = hasher(static_cast<int>(DataType::ID::Void));
hash = hash_combine(hash, place.hash()); if (!type_repo_[v])
hash = hash_combine(hash, std::hash<int>()(static_cast<int>(io))); type_repo_[v] = new Type(DataType::ID::Void, "Void", TARGET(kAny),
hash = hash_combine(hash, std::hash<std::string>()(arg_name)); PRECISION(kAny), DATALAYOUT(kAny), -1);
return hash;
} }
} // namespace lite } // namespace lite
......
...@@ -73,7 +73,7 @@ namespace lite { ...@@ -73,7 +73,7 @@ namespace lite {
// //
// TODO(Superjomn) Add operator/kernel-wise static checking to avoid unsupported // TODO(Superjomn) Add operator/kernel-wise static checking to avoid unsupported
// type mixed in the system. // type mixed in the system.
class DataTypeBase { class DataType {
public: public:
// The Void type can cast to any other type. // The Void type can cast to any other type.
// The Unsupported is the data type that developed include in the system, for // The Unsupported is the data type that developed include in the system, for
...@@ -83,43 +83,56 @@ class DataTypeBase { ...@@ -83,43 +83,56 @@ class DataTypeBase {
enum class ID : int { enum class ID : int {
Void = 0, // unknown type that can be cast to any data type. Void = 0, // unknown type that can be cast to any data type.
Unsupported, // Unsupported data type that will not be analyzed. Unsupported, // Unsupported data type that will not be analyzed.
Tensor_Fp32_NCHW,
Tensor_Int8_NCHW,
Tensor_Int64_NCHW,
// Tensor_Any represents a Tensor with any place, data, layout. It is used // Tensor_Any represents a Tensor with any place, data, layout. It is used
// in some IO kernels those doesn't care the data. // in some IO kernels those doesn't care the data.
Tensor_Any, Tensor,
// Used by feed or fetch op. // A tensor list, but all the elements should have the same type.
TensorList_Any, TensorList,
// ---------
NumTypes, // Must remains as last defined ID. NumTypes, // Must remains as last defined ID.
}; };
ID id() const { return id_; } ID id() const { return id_; }
// type check. // type check.
bool IsTensor() const { return is_tensor_; }
bool IsVoid() const { return id_ == ID::Void; } bool IsVoid() const { return id_ == ID::Void; }
bool IsUnsupported() const { return id_ == ID::Unsupported; } bool IsUnsupported() const { return id_ == ID::Unsupported; }
bool IsTensorFp32NCHW() const { return id_ == ID::Tensor_Fp32_NCHW; } bool IsTensor() const { return id_ == ID::Tensor; }
bool IsTensorInt8NCHW() const { return id_ == ID::Tensor_Int8_NCHW; } bool IsTensorList() const { return id_ == ID::TensorList; }
bool IsTensorInt64NCHW() const { return id_ == ID::Tensor_Int64_NCHW; } // Get number of types.
int num_types() const { return static_cast<int>(ID::NumTypes); } int num_types() const { return static_cast<int>(ID::NumTypes); }
protected: protected:
// Can only extended by subclass. // Can only extended by subclass.
DataTypeBase(ID id, bool is_tensor) : id_(id), is_tensor_(is_tensor) {} DataType(ID id) : id_(id) {}
ID id_{ID::Unsupported}; ID id_{ID::Unsupported};
bool is_tensor_{false};
}; };
/* /*
* Datatype with device info considered. * Datatype with device info considered.
* NOTE A Type with different device is treated as different DeviceDataType. * NOTE A Type with different device is treated as different DeviceDataType.
*/ */
class Type : public DataTypeBase { class Type : public DataType {
public: public:
// Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a statement to transform a type to another.
virtual bool TypeCastable(const Type& type) const { return id_ == type.id(); }
/// Get a Tensor type.
static const Type* GetTensorTy(TargetType target,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW),
int device = 0);
/// Get a TensorList type.
static const Type* GetTensorListTy(
TargetType target, PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW), int device = 0);
/// Get an Unsupported type.
static const Type* GetUnsupportedTy();
/// Get an Void type.
static const Type* GetVoidTy();
TargetType target() const { return place_.target; } TargetType target() const { return place_.target; }
PrecisionType precision() const { return place_.precision; } PrecisionType precision() const { return place_.precision; }
DataLayoutType layout() const { return place_.layout; } DataLayoutType layout() const { return place_.layout; }
...@@ -130,52 +143,23 @@ class Type : public DataTypeBase { ...@@ -130,52 +143,23 @@ class Type : public DataTypeBase {
bool operator==(const Type& other) { bool operator==(const Type& other) {
return id_ == other.id() && place_ == other.place(); return id_ == other.id() && place_ == other.place();
} }
friend std::ostream& operator<<(std::ostream& os, const Type& other) { friend std::ostream& operator<<(std::ostream& os, const Type& other);
if (other.IsUnsupported()) {
os << "<Unsupported>";
return os;
}
if (other.IsVoid()) {
os << "<Void>";
return os;
}
if (other.is_tensor_) {
os << "<Tensor:";
} else {
os << "<";
}
os << TargetToStr(other.target()) << "/"
<< PrecisionToStr(other.precision()) << "/"
<< DataLayoutToStr(other.layout()) << ">";
return os;
}
// Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a statement to transform a type to another.
virtual bool TypeCastable(const Type& type) const { return id_ == type.id(); }
template <bool is_unknown, bool is_tensor = true,
TargetType target = TargetType::kHost,
PrecisionType precision = PrecisionType::kFloat,
DataLayoutType layout = DataLayoutType::kNCHW>
// Get a type.
static const Type* Get();
template <typename TypeTy>
static const Type* Get(TargetType target = TargetType::kHost);
virtual ~Type() = default; virtual ~Type() = default;
protected: protected:
Type(ID id, const std::string& name, bool is_tensor, /// One should avoid using this construct.
TargetType target = TargetType::kHost, Type(ID id, const std::string& name, TargetType target = TargetType::kHost,
PrecisionType precision = PrecisionType::kFloat, PrecisionType precision = PrecisionType::kFloat,
DataLayoutType layout = DataLayoutType::kNCHW, short device = 0) DataLayoutType layout = DataLayoutType::kNCHW, short device = 0)
: DataTypeBase(id, is_tensor), : DataType(id), place_{target, precision, layout, device}, name_(name) {}
place_{target, precision, layout, device},
name_(name) {} // An map is used here to maintain a global repo for types. We don't use
// MACROs with static variables for that the TypeSystem should only used in
// compile time, that is not performance sensitive, and a map-based way is
// easier to implement and maintain.
static std::map<size_t, const Type*> type_repo_;
protected:
Place place_; Place place_;
const std::string name_; const std::string name_;
}; };
...@@ -224,46 +208,15 @@ static bool TypeCompatibleTo(const Type& a, const Type& b) { ...@@ -224,46 +208,15 @@ static bool TypeCompatibleTo(const Type& a, const Type& b) {
// is only one instance across the system. // is only one instance across the system.
class VoidTy : public Type { class VoidTy : public Type {
public: public:
VoidTy() : Type(ID::Void, "Void", false /*is_tensor*/) {} VoidTy() : Type(ID::Void, "Void") {}
}; };
class UnsupportedTy : public Type { class UnsupportedTy : public Type {
public: public:
UnsupportedTy() : Type(ID::Unsupported, "Unsupported", false /*is_tensor*/) {} UnsupportedTy() : Type(ID::Unsupported, "Unsupported", false /*is_tensor*/) {}
}; };
class TensorAnyTy : public Type {
public:
explicit TensorAnyTy(TargetType target)
: Type(ID::Tensor_Any, "TensorAny", true, target, PRECISION(kAny),
DATALAYOUT(kAny)) {}
};
// A list of tensor, and no assumption on the data layout or data type.
class TensorListAnyTy : public Type {
public:
explicit TensorListAnyTy(TargetType target)
: Type(ID::TensorList_Any, "TensorList_Any", false, target,
PRECISION(kAny), DATALAYOUT(kAny)) {}
};
class TensorFp32NCHWTy : public Type {
public:
explicit TensorFp32NCHWTy(TargetType target)
: Type(ID::Tensor_Fp32_NCHW, "TensorFp32NCHW", true /*is_tensor*/, target,
PrecisionType::kFloat, DataLayoutType::kNCHW) {}
};
class TensorInt8NCHWTy : public Type {
public:
explicit TensorInt8NCHWTy(TargetType target)
: Type(ID::Tensor_Int8_NCHW, "TensorInt8NCHW", true /*is_tensor*/, target,
PrecisionType::kInt8, DataLayoutType::kNCHW) {}
};
class TensorInt64NCHWTy : public Type {
public:
explicit TensorInt64NCHWTy(TargetType target)
: Type(ID::Tensor_Int64_NCHW, "TensorInt64NCHW", true /*is_tensor*/,
target, PrecisionType::kInt8, DataLayoutType::kNCHW) {}
};
const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown, const Type* LookupType(DataType::ID type_id, bool is_unknown, bool is_tensor,
bool is_tensor, Place place); Place place);
// ------------------------- end predefined types --------------------------- // ------------------------- end predefined types ---------------------------
/* /*
......
...@@ -18,15 +18,13 @@ ...@@ -18,15 +18,13 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
TEST(TypeSystem, test) { TEST(TypeSystem, CheckDuplicateGet) {
ASSERT_TRUE(TypeSystem::Global().Contains<lite::TensorBase>()); auto* tensor_ty =
} Type::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
auto* tensor_ty1 =
Type::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
TEST(TypeSystem, register_new) { ASSERT_EQ(tensor_ty, tensor_ty1);
TypeSystem::Global().Register<int>("int32");
ASSERT_TRUE(TypeSystem::Global().Contains<int>());
ASSERT_TRUE(TypeSystem::Global().Contains(typeid(int).hash_code()));
ASSERT_TRUE(TypeSystem::Global().Contains("int32"));
} }
} // namespace lite } // namespace lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册