diff --git a/paddle/fluid/lite/core/type_system.cc b/paddle/fluid/lite/core/type_system.cc
index 233e04d7bc5f0d8f2191beb77061d4e1dd1853c8..4c1ea9d729e0438f33c20db2a2ca675c757cc80f 100644
--- a/paddle/fluid/lite/core/type_system.cc
+++ b/paddle/fluid/lite/core/type_system.cc
@@ -13,126 +13,118 @@
 // limitations under the License.
 
 #include "paddle/fluid/lite/core/type_system.h"
+#include "type_system.h"
 
 namespace paddle {
 namespace lite {
 
 // ------------------------- GetType specification ----------------------------
-template <>
-const Type*
-Type::Get<false, false, TargetType::kHost>() {
-  static UnsupportedTy x;
-  return &x;
-}
+// ------------------------- end GetType specification ------------------------
 
-template <>
-const Type*
-Type::Get<false, true, TargetType::kX86>() {
-  static TensorFp32NCHWTy x(TargetType::kX86);
-  return &x;
+size_t ParamTypeRegistry::KernelIdTy::hash() const {
+  std::hash<std::string> h;
+  size_t hash = h(kernel_type);
+  hash = hash_combine(hash, place.hash());
+  hash = hash_combine(hash, std::hash<int>()(static_cast<int>(io)));
+  hash = hash_combine(hash, std::hash<std::string>()(arg_name));
+  return hash;
 }
 
-template <>
-const Type*
-Type::Get<false, true, TargetType::kHost>() {
-  static TensorFp32NCHWTy x(TargetType::kHost);
-  return &x;
+std::ostream &operator<<(std::ostream &os, const Type &other) {
+  if (other.IsUnsupported()) {
+    os << "<Unsupported Type>";
+    return os;
+  }
+  if (other.IsVoid()) {
+    os << "<Void Type>";
+    return os;
+  }
+  if (other.IsTensor()) {
+    os << "<Tensor: " << other.name() << ">";
+    return os;
+  }
+  return os;
 }
 
-template <>
-const Type* Type::Get<UnsupportedTy>(TargetType target) {
-  return Get<false, false, TargetType::kHost>();
-}
+const Type *Type::GetTensorTy(TargetType target, PrecisionType precision,
+                              DataLayoutType layout, int device) {
+  // NOTE quite naive implementation here, but not performance sensitive.
+  DataType::ID type_id = DataType::ID::Tensor;
 
-template <TargetType Target>
-TensorListAnyTy* GetTensorListAnyTy() {
-  static TensorListAnyTy x(Target);
-  return &x;
-}
-template <TargetType Target>
-TensorAnyTy* GetTensorAnyTy() {
-  static TensorAnyTy x(Target);
-  return &x;
-}
+#define HASH_ONE(x) v = hash_combine(v, hasher(static_cast<int>(x)))
 
-template <>
-const Type* Type::Get<TensorListAnyTy>(TargetType target) {
-  switch (target) {
-    case TargetType::kHost:
-      return GetTensorListAnyTy<TargetType::kHost>();
-    case TargetType::kCUDA:
-      return GetTensorListAnyTy<TargetType::kCUDA>();
-    case TargetType::kX86:
-      return GetTensorListAnyTy<TargetType::kX86>();
-    default:
-      LOG(FATAL) << "unsupported type";
-  }
-}
+  std::hash<int> hasher;
+  size_t v = hasher(static_cast<int>(type_id));
+  HASH_ONE(target);
+  HASH_ONE(precision);
+  HASH_ONE(layout);
+  HASH_ONE(device);
+#undef HASH_ONE
 
-template <>
-const Type* Type::Get<TensorAnyTy>(TargetType target) {
-  switch (target) {
-    case TargetType::kHost:
-      return GetTensorAnyTy<TargetType::kHost>();
-    case TargetType::kCUDA:
-      return GetTensorAnyTy<TargetType::kCUDA>();
-    case TargetType::kX86:
-      return GetTensorAnyTy<TargetType::kX86>();
-    default:
-      LOG(FATAL) << "unsupported type";
+  std::stringstream name;
+  name << "Tensor<";
+  name << TargetToStr(target) << ",";
+  name << PrecisionToStr(precision) << ",";
+  name << DataLayoutToStr(layout) << ",";
+  name << device;
+  name << ">";
+
+  auto it = type_repo_.find(v);
+  if (it == type_repo_.end()) {
+    // The Types should be alive across the process lifetime, no need to delete.
+    type_repo_[v] =
+        new Type(type_id, name.str(), target, precision, layout, device);
   }
+  return type_repo_[v];
 }
 
-template <TargetType Target>
-const Type* GetTensorFp32NCHWTy() {
-  static TensorFp32NCHWTy x(Target);
-  return &x;
-}
+const Type *Type::GetTensorListTy(TargetType target, PrecisionType precision,
+                                  DataLayoutType layout, int device) {
+  DataType::ID type_id = DataType::ID::TensorList;
 
-template <>
-const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) {
-  switch (target) {
-    case TARGET(kHost):
-      return GetTensorFp32NCHWTy<TARGET(kHost)>();
-    case TARGET(kCUDA):
-      return GetTensorFp32NCHWTy<TARGET(kCUDA)>();
-    case TARGET(kX86):
-      return GetTensorFp32NCHWTy<TARGET(kX86)>();
-    default:
-      LOG(FATAL) << "unsupported target Type " << TargetToStr(target);
-  }
-  return nullptr;
-}
+#define HASH_ONE(x) v = hash_combine(v, hasher(static_cast<int>(x)))
 
-const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown,
-                       bool is_tensor, Place place) {
-  using id_t = DataTypeBase::ID;
-  switch (type_id) {
-    case id_t::Tensor_Any:
-      return Type::Get<TensorAnyTy>(place.target);
-    case id_t::Tensor_Fp32_NCHW:
-      return Type::Get<TensorFp32NCHWTy>(place.target);
-    case id_t::TensorList_Any:
-      return Type::Get<TensorListAnyTy>(place.target);
-    default:
-      LOG(FATAL) << "unsupported type";
-  }
-  return nullptr;
+  std::hash<int> hasher;
+  size_t v = hasher(static_cast<int>(type_id));
+  HASH_ONE(target);
+  HASH_ONE(precision);
+  HASH_ONE(layout);
+  HASH_ONE(device);
+#undef HASH_ONE
+
+  std::stringstream name;
+  name << "TensorList<";
+  name << TargetToStr(target) << ",";
+  name << PrecisionToStr(precision) << ",";
+  name << DataLayoutToStr(layout) << ",";
+  name << device;
+  name << ">";
+
+  if (!type_repo_[v])
+    // The Types should be alive across the process lifetime, no need to delete.
+    type_repo_[v] =
+        new Type(type_id, name.str(), target, precision, layout, device);
+  return type_repo_[v];
 }
 
-// ------------------------- end GetType specification ------------------------
+const Type *Type::GetUnsupportedTy() {
+  std::hash<int> hasher;
+  size_t v = hasher(static_cast<int>(DataType::ID::Unsupported));
+  if (!type_repo_[v])
+    type_repo_[v] =
+        new Type(DataType::ID::Unsupported, "Unsupported", TARGET(kUnk),
+                 PRECISION(kUnk), DATALAYOUT(kUnk), -1);
+  return type_repo_[v];
+}
 
-size_t ParamTypeRegistry::KernelIdTy::hash() const {
-  std::hash<std::string> h;
-  size_t hash = h(kernel_type);
-  hash = hash_combine(hash, place.hash());
-  hash = hash_combine(hash, std::hash<int>()(static_cast<int>(io)));
-  hash = hash_combine(hash, std::hash<std::string>()(arg_name));
-  return hash;
+const Type *Type::GetVoidTy() {
+  std::hash<int> hasher;
+  size_t v = hasher(static_cast<int>(DataType::ID::Void));
+  if (!type_repo_[v])
+    type_repo_[v] = new Type(DataType::ID::Void, "Void", TARGET(kAny),
+                             PRECISION(kAny), DATALAYOUT(kAny), -1);
+  return type_repo_[v];
 }
 
 }  // namespace lite
diff --git a/paddle/fluid/lite/core/type_system.h b/paddle/fluid/lite/core/type_system.h
index 94f336d0b3154a73b1ec754a18d0c9985838a0e0..4ebcfbc2acfc8afbfecc7a96405615d93b599597 100644
--- a/paddle/fluid/lite/core/type_system.h
+++ b/paddle/fluid/lite/core/type_system.h
@@ -73,7 +73,7 @@ namespace lite {
 //
 // TODO(Superjomn) Add operator/kernel-wise static checking to avoid unsupported
 // type mixed in the system.
-class DataTypeBase {
+class DataType {
  public:
   // The Void type can cast to any other type.
   // The Unsupported is the data type that developed include in the system, for
@@ -83,43 +83,56 @@ class DataTypeBase {
   enum class ID : int {
     Void = 0,     // unknown type that can be cast to any data type.
     Unsupported,  // Unsupported data type that will not be analyzed.
-    Tensor_Fp32_NCHW,
-    Tensor_Int8_NCHW,
-    Tensor_Int64_NCHW,
     // Tensor_Any represents a Tensor with any place, data, layout. It is used
     // in some IO kernels those doesn't care the data.
-    Tensor_Any,
-    // Used by feed or fetch op.
-    TensorList_Any,
+    Tensor,
+    // A tensor list, but all the elements should have the same type.
+    TensorList,
+    // ---------
     NumTypes,  // Must remains as last defined ID.
   };
 
   ID id() const { return id_; }
 
   // type check.
-  bool IsTensor() const { return is_tensor_; }
   bool IsVoid() const { return id_ == ID::Void; }
   bool IsUnsupported() const { return id_ == ID::Unsupported; }
-  bool IsTensorFp32NCHW() const { return id_ == ID::Tensor_Fp32_NCHW; }
-  bool IsTensorInt8NCHW() const { return id_ == ID::Tensor_Int8_NCHW; }
-  bool IsTensorInt64NCHW() const { return id_ == ID::Tensor_Int64_NCHW; }
-
+  bool IsTensor() const { return id_ == ID::Tensor; }
+  bool IsTensorList() const { return id_ == ID::TensorList; }
+  // Get number of types.
   int num_types() const { return static_cast<int>(ID::NumTypes); }
 
  protected:
   // Can only extended by subclass.
-  DataTypeBase(ID id, bool is_tensor) : id_(id), is_tensor_(is_tensor) {}
+  DataType(ID id) : id_(id) {}
 
   ID id_{ID::Unsupported};
-  bool is_tensor_{false};
 };
 
 /*
  * Datatype with device info considered.
  * NOTE A Type with different device is treated as different DeviceDataType.
  */
-class Type : public DataTypeBase {
+class Type : public DataType {
  public:
+  // Can cast to another type. This is heavily used in MIR, by determining
+  // whether it is possible to add a statement to transform a type to another.
+  virtual bool TypeCastable(const Type& type) const { return id_ == type.id(); }
+
+  /// Get a Tensor type.
+  static const Type* GetTensorTy(TargetType target,
+                                 PrecisionType precision = PRECISION(kFloat),
+                                 DataLayoutType layout = DATALAYOUT(kNCHW),
+                                 int device = 0);
+  /// Get a TensorList type.
+  static const Type* GetTensorListTy(
+      TargetType target, PrecisionType precision = PRECISION(kFloat),
+      DataLayoutType layout = DATALAYOUT(kNCHW), int device = 0);
+  /// Get an Unsupported type.
+  static const Type* GetUnsupportedTy();
+  /// Get a Void type.
+  static const Type* GetVoidTy();
+
   TargetType target() const { return place_.target; }
   PrecisionType precision() const { return place_.precision; }
   DataLayoutType layout() const { return place_.layout; }
@@ -130,52 +143,23 @@ class Type : public DataTypeBase {
   bool operator==(const Type& other) {
     return id_ == other.id() && place_ == other.place();
   }
-  friend std::ostream& operator<<(std::ostream& os, const Type& other) {
-    if (other.IsUnsupported()) {
-      os << "<Unsupported Type>";
-      return os;
-    }
-    if (other.IsVoid()) {
-      os << "<Void Type>";
-      return os;
-    }
-    if (other.is_tensor_) {
-      os << "<Tensor: " << other.name() << ">";
-      return os;
-    }
-    return os;
-  }
-
-  // Can cast to another type. This is heavily used in MIR, by determine whether
-  // is is possible to add a statement to transform a type to another.
-  virtual bool TypeCastable(const Type& type) const { return id_ == type.id(); }
-
-  template <bool is_unknown, bool is_tensor = true,
-            TargetType target = TargetType::kHost>
-  // Get a type.
-  static const Type* Get();
-
-  template <typename TypeTy>
-  static const Type* Get(TargetType target = TargetType::kHost);
+  friend std::ostream& operator<<(std::ostream& os, const Type& other);
 
   virtual ~Type() = default;
 
  protected:
-  Type(ID id, const std::string& name, bool is_tensor,
-       TargetType target = TargetType::kHost,
+  /// One should avoid using this constructor.
+  Type(ID id, const std::string& name, TargetType target = TargetType::kHost,
        PrecisionType precision = PrecisionType::kFloat,
        DataLayoutType layout = DataLayoutType::kNCHW, short device = 0)
-      : DataTypeBase(id, is_tensor),
-        place_{target, precision, layout, device},
-        name_(name) {}
+      : DataType(id), place_{target, precision, layout, device}, name_(name) {}
+
+  // A map is used here to maintain a global repo for types. We don't use
+  // MACROs with static variables because the TypeSystem is only used at
+  // compile time; it is not performance sensitive, and a map-based way is
+  // easier to implement and maintain.
+  static std::map<size_t, const Type*> type_repo_;
 
- protected:
   Place place_;
   const std::string name_;
 };
@@ -224,46 +208,15 @@ static bool TypeCompatibleTo(const Type& a, const Type& b) {
 // is only one instance across the system.
 class VoidTy : public Type {
  public:
-  VoidTy() : Type(ID::Void, "Void", false /*is_tensor*/) {}
+  VoidTy() : Type(ID::Void, "Void") {}
 };
 class UnsupportedTy : public Type {
  public:
-  UnsupportedTy() : Type(ID::Unsupported, "Unsupported", false /*is_tensor*/) {}
+  UnsupportedTy() : Type(ID::Unsupported, "Unsupported") {}
 };
-class TensorAnyTy : public Type {
- public:
-  explicit TensorAnyTy(TargetType target)
-      : Type(ID::Tensor_Any, "TensorAny", true, target, PRECISION(kAny),
-             DATALAYOUT(kAny)) {}
-};
-// A list of tensor, and no assumption on the data layout or data type.
-class TensorListAnyTy : public Type {
- public:
-  explicit TensorListAnyTy(TargetType target)
-      : Type(ID::TensorList_Any, "TensorList_Any", false, target,
-             PRECISION(kAny), DATALAYOUT(kAny)) {}
-};
-class TensorFp32NCHWTy : public Type {
- public:
-  explicit TensorFp32NCHWTy(TargetType target)
-      : Type(ID::Tensor_Fp32_NCHW, "TensorFp32NCHW", true /*is_tensor*/, target,
-             PrecisionType::kFloat, DataLayoutType::kNCHW) {}
-};
-class TensorInt8NCHWTy : public Type {
- public:
-  explicit TensorInt8NCHWTy(TargetType target)
-      : Type(ID::Tensor_Int8_NCHW, "TensorInt8NCHW", true /*is_tensor*/, target,
-             PrecisionType::kInt8, DataLayoutType::kNCHW) {}
-};
-class TensorInt64NCHWTy : public Type {
- public:
-  explicit TensorInt64NCHWTy(TargetType target)
-      : Type(ID::Tensor_Int64_NCHW, "TensorInt64NCHW", true /*is_tensor*/,
-             target, PrecisionType::kInt8, DataLayoutType::kNCHW) {}
-};
-const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown,
-                       bool is_tensor, Place place);
+const Type* LookupType(DataType::ID type_id, bool is_unknown, bool is_tensor,
+                       Place place);
 // ------------------------- end predefined types ---------------------------
 
 /*
diff --git a/paddle/fluid/lite/core/type_system_test.cc b/paddle/fluid/lite/core/type_system_test.cc
index b26234b7c8bf423237fdd0208ec022958097de2a..b01aa0852ffa96721f63e45f0023e49f62868484 100644
--- a/paddle/fluid/lite/core/type_system_test.cc
+++ b/paddle/fluid/lite/core/type_system_test.cc
@@ -18,15 +18,13 @@
 namespace paddle {
 namespace lite {
 
-TEST(TypeSystem, test) {
-  ASSERT_TRUE(TypeSystem::Global().Contains<int>());
-}
+TEST(TypeSystem, CheckDuplicateGet) {
+  auto* tensor_ty =
+      Type::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
+  auto* tensor_ty1 =
+      Type::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
 
-TEST(TypeSystem, register_new) {
-  TypeSystem::Global().Register<int32_t>("int32");
-  ASSERT_TRUE(TypeSystem::Global().Contains<int32_t>());
-  ASSERT_TRUE(TypeSystem::Global().Contains(typeid(int).hash_code()));
-  ASSERT_TRUE(TypeSystem::Global().Contains("int32"));
+  ASSERT_EQ(tensor_ty, tensor_ty1);
 }
 
 }  // namespace lite
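
Usage note (not part of the patch): the registry-style getters replace the old per-type Type::Get specializations, and the CheckDuplicateGet test above checks that repeated lookups hand back the same pointer. The sketch below is a minimal example of the intended call pattern; it assumes the lite headers, the TARGET/PRECISION/DATALAYOUT macros, and the glog-style CHECK/LOG facilities used elsewhere in this tree are available, and the function name Example is purely illustrative.

    #include "paddle/fluid/lite/core/type_system.h"

    namespace paddle {
    namespace lite {

    void Example() {
      // The same (target, precision, layout, device) key hashes to the same
      // slot in type_repo_, so repeated calls return one singleton instance.
      const Type* a = Type::GetTensorTy(TARGET(kX86), PRECISION(kFloat),
                                        DATALAYOUT(kNCHW), /*device=*/0);
      const Type* b = Type::GetTensorTy(TARGET(kX86), PRECISION(kFloat),
                                        DATALAYOUT(kNCHW), /*device=*/0);
      CHECK_EQ(a, b);  // pointer identity, as asserted by CheckDuplicateGet

      // A different place yields a distinct instance whose name encodes the
      // target, precision, layout, and device chosen at lookup time.
      const Type* c = Type::GetTensorListTy(TARGET(kHost));
      LOG(INFO) << *a << " vs " << *c;  // printed via the new operator<<
    }

    }  // namespace lite
    }  // namespace paddle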