提交 123c79fa 编写于 作者: S Superjomn

refactor type system

上级 9d6a0c88
......@@ -13,126 +13,118 @@
// limitations under the License.
#include "paddle/fluid/lite/core/type_system.h"
#include "type_system.h"
namespace paddle {
namespace lite {
// ------------------------- GetType specification ----------------------------
template <>
const Type*
Type::Get<false /*is_unsupported*/, false /*is_tensor*/, TargetType::kHost,
PrecisionType::kFloat, DataLayoutType::kNCHW>() {
static UnsupportedTy x;
return &x;
}
// ------------------------- end GetType specification ------------------------
template <>
const Type*
Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kX86,
PrecisionType::kFloat, DataLayoutType::kNCHW>() {
static TensorFp32NCHWTy x(TargetType::kX86);
return &x;
// Hash key for a kernel-argument identity: folds the kernel type string,
// the place, the IO direction and the argument name into one value.
size_t ParamTypeRegistry::KernelIdTy::hash() const {
  size_t seed = std::hash<std::string>()(kernel_type);
  seed = hash_combine(seed, place.hash());
  seed = hash_combine(seed, std::hash<int>()(static_cast<int>(io)));
  seed = hash_combine(seed, std::hash<std::string>()(arg_name));
  return seed;
}
template <>
const Type*
Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kHost,
PrecisionType::kFloat, DataLayoutType::kNCHW>() {
static TensorFp32NCHWTy x(TargetType::kHost);
return &x;
// Streams a human-readable form of a Type, e.g. "<Tensor:x86/float/NCHW>",
// or the special "<Unsupported>"/"<Void>" markers.
std::ostream &operator<<(std::ostream &os, const Type &other) {
  if (other.IsUnsupported()) return os << "<Unsupported>";
  if (other.IsVoid()) return os << "<Void>";
  // Tensor types get an explicit "Tensor:" tag; all others share the
  // bare "<target/precision/layout>" shape.
  os << (other.IsTensor() ? "<Tensor:" : "<");
  os << TargetToStr(other.target()) << "/" << PrecisionToStr(other.precision())
     << "/" << DataLayoutToStr(other.layout()) << ">";
  return os;
}
template <>
const Type* Type::Get<UnsupportedTy>(TargetType target) {
return Get<false, false, TargetType::kHost, PrecisionType::kFloat,
DataLayoutType::kNCHW>();
}
const Type *Type::GetTensorTy(TargetType target, PrecisionType precision,
DataLayoutType layout, int device) {
// NOTE quite naive implementation here, but not performance sensitive.
DataType::ID type_id = DataType::ID::Tensor;
// Process-wide singleton of the "tensor list, any precision/layout" type,
// one instance per compile-time Target.
template <TargetType Target>
TensorListAnyTy* GetTensorListAnyTy() {
  static TensorListAnyTy instance(Target);
  return &instance;
}
// Process-wide singleton of the "tensor with any precision/layout" type,
// one instance per compile-time Target.
template <TargetType Target>
TensorAnyTy* GetTensorAnyTy() {
  static TensorAnyTy instance(Target);
  return &instance;
}
#define HASH_ONE(x) v = hash_combine(v, hasher(static_cast<int>(x)))
// Type::Get specialization for TensorListAnyTy: maps the runtime target to
// the matching per-target singleton.
template <>
const Type* Type::Get<TensorListAnyTy>(TargetType target) {
  switch (target) {
    case TargetType::kHost:
      return GetTensorListAnyTy<TARGET(kHost)>();
    case TargetType::kCUDA:
      return GetTensorListAnyTy<TARGET(kCUDA)>();
    case TargetType::kX86:
      return GetTensorListAnyTy<TARGET(kX86)>();
    default:
      LOG(FATAL) << "unsupported type";
  }
  // Fix: the original had no return after the switch, so a compiler that does
  // not know LOG(FATAL) is noreturn sees control fall off a non-void function
  // (UB, -Wreturn-type). Mirrors the sibling Get<TensorFp32NCHWTy>.
  return nullptr;
}
std::hash<int> hasher;
size_t v = hasher(static_cast<int>(type_id));
HASH_ONE(target);
HASH_ONE(precision);
HASH_ONE(layout);
HASH_ONE(device);
#undef HASH_ONE
template <>
const Type* Type::Get<TensorAnyTy>(TargetType target) {
switch (target) {
case TargetType::kHost:
return GetTensorAnyTy<TARGET(kHost)>();
case TargetType::kCUDA:
return GetTensorAnyTy<TARGET(kCUDA)>();
case TargetType::kX86:
return GetTensorAnyTy<TARGET(kX86)>();
default:
LOG(FATAL) << "unsupported type";
std::stringstream name;
name << "Tensor<";
name << TargetToStr(target) << ",";
name << PrecisionToStr(precision) << ",";
name << DataLayoutToStr(layout) << ",";
name << device;
name << ">";
auto it = type_repo_.find(v);
if (it == type_repo_.end()) {
// The Types should alive across the process life, no need to delete.
type_repo_[v] =
new Type(type_id, name.str(), target, precision, layout, device);
}
return type_repo_[v];
}
// Process-wide singleton of the fp32/NCHW tensor type, one instance per
// compile-time Target.
template <TargetType Target>
const Type* GetTensorFp32NCHWTy() {
  static TensorFp32NCHWTy instance(Target);
  return &instance;
}
const Type *Type::GetTensorListTy(TargetType target, PrecisionType precision,
DataLayoutType layout, int device) {
DataType::ID type_id = DataType::ID::TensorList;
// Type::Get specialization for TensorFp32NCHWTy: maps the runtime target to
// the matching per-target singleton; aborts on an unknown target.
template <>
const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) {
  if (target == TARGET(kHost)) return GetTensorFp32NCHWTy<TARGET(kHost)>();
  if (target == TARGET(kCUDA)) return GetTensorFp32NCHWTy<TARGET(kCUDA)>();
  if (target == TARGET(kX86)) return GetTensorFp32NCHWTy<TARGET(kX86)>();
  LOG(FATAL) << "unsupported target Type " << TargetToStr(target);
  return nullptr;
}
#define HASH_ONE(x) v = hash_combine(v, hasher(static_cast<int>(x)))
const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown,
bool is_tensor, Place place) {
using id_t = DataTypeBase::ID;
switch (type_id) {
case id_t::Tensor_Any:
return Type::Get<TensorAnyTy>(place.target);
case id_t::Tensor_Fp32_NCHW:
return Type::Get<TensorFp32NCHWTy>(place.target);
case id_t::TensorList_Any:
return Type::Get<TensorListAnyTy>(place.target);
default:
LOG(FATAL) << "unsupported type";
}
return nullptr;
std::hash<int> hasher;
size_t v = hasher(static_cast<int>(type_id));
HASH_ONE(target);
HASH_ONE(precision);
HASH_ONE(layout);
HASH_ONE(device);
#undef HASH_ONE
std::stringstream name;
name << "TensorList<";
name << TargetToStr(target) << ",";
name << PrecisionToStr(precision) << ",";
name << DataLayoutToStr(layout) << ",";
name << device;
name << ">";
if (!type_repo_[v])
// The Types should alive across the process life, no need to delete.
type_repo_[v] =
new Type(type_id, name.str(), target, precision, layout, device);
return type_repo_[v];
}
// ------------------------- end GetType specification ------------------------
// Returns the process-wide singleton Unsupported type, lazily inserted into
// the global type_repo_ keyed by the hash of its ID.
const Type *Type::GetUnsupportedTy() {
  std::hash<int> hasher;
  size_t v = hasher(static_cast<int>(DataType::ID::Unsupported));
  if (!type_repo_[v])
    // Types stay alive for the whole process life; no need to delete.
    type_repo_[v] =
        new Type(DataType::ID::Unsupported, "Unsupported", TARGET(kUnk),
                 PRECISION(kUnk), DATALAYOUT(kUnk), -1);
  // Fix: the original had no return statement at all — flowing off the end of
  // a non-void function is undefined behavior. Return the cached singleton,
  // matching GetTensorTy/GetTensorListTy.
  return type_repo_[v];
}
size_t ParamTypeRegistry::KernelIdTy::hash() const {
std::hash<std::string> h;
size_t hash = h(kernel_type);
hash = hash_combine(hash, place.hash());
hash = hash_combine(hash, std::hash<int>()(static_cast<int>(io)));
hash = hash_combine(hash, std::hash<std::string>()(arg_name));
return hash;
// Returns the process-wide singleton Void type (castable to any other type),
// lazily inserted into the global type_repo_ keyed by the hash of its ID.
const Type *Type::GetVoidTy() {
  std::hash<int> hasher;
  size_t v = hasher(static_cast<int>(DataType::ID::Void));
  if (!type_repo_[v])
    // Types stay alive for the whole process life; no need to delete.
    type_repo_[v] = new Type(DataType::ID::Void, "Void", TARGET(kAny),
                             PRECISION(kAny), DATALAYOUT(kAny), -1);
  // Fix: the original had no return statement — flowing off the end of a
  // non-void function is undefined behavior. Return the cached singleton.
  return type_repo_[v];
}
} // namespace lite
......
......@@ -73,7 +73,7 @@ namespace lite {
//
// TODO(Superjomn) Add operator/kernel-wise static checking to avoid unsupported
// type mixed in the system.
class DataTypeBase {
class DataType {
public:
// The Void type can cast to any other type.
// The Unsupported is the data type that developed include in the system, for
......@@ -83,43 +83,56 @@ class DataTypeBase {
enum class ID : int {
Void = 0, // unknown type that can be cast to any data type.
Unsupported, // Unsupported data type that will not be analyzed.
Tensor_Fp32_NCHW,
Tensor_Int8_NCHW,
Tensor_Int64_NCHW,
// Tensor_Any represents a Tensor with any place, data, layout. It is used
// in some IO kernels those doesn't care the data.
Tensor_Any,
// Used by feed or fetch op.
TensorList_Any,
Tensor,
// A tensor list, but all the elements should have the same type.
TensorList,
// ---------
NumTypes, // Must remains as last defined ID.
};
ID id() const { return id_; }
// type check.
bool IsTensor() const { return is_tensor_; }
bool IsVoid() const { return id_ == ID::Void; }
bool IsUnsupported() const { return id_ == ID::Unsupported; }
bool IsTensorFp32NCHW() const { return id_ == ID::Tensor_Fp32_NCHW; }
bool IsTensorInt8NCHW() const { return id_ == ID::Tensor_Int8_NCHW; }
bool IsTensorInt64NCHW() const { return id_ == ID::Tensor_Int64_NCHW; }
bool IsTensor() const { return id_ == ID::Tensor; }
bool IsTensorList() const { return id_ == ID::TensorList; }
// Get number of types.
int num_types() const { return static_cast<int>(ID::NumTypes); }
protected:
// Can only extended by subclass.
DataTypeBase(ID id, bool is_tensor) : id_(id), is_tensor_(is_tensor) {}
DataType(ID id) : id_(id) {}
ID id_{ID::Unsupported};
bool is_tensor_{false};
};
/*
* Datatype with device info considered.
* NOTE A Type with different device is treated as different DeviceDataType.
*/
class Type : public DataTypeBase {
class Type : public DataType {
public:
// Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a statement to transform a type to another.
virtual bool TypeCastable(const Type& type) const { return id_ == type.id(); }
/// Get a Tensor type.
static const Type* GetTensorTy(TargetType target,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW),
int device = 0);
/// Get a TensorList type.
static const Type* GetTensorListTy(
TargetType target, PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW), int device = 0);
/// Get an Unsupported type.
static const Type* GetUnsupportedTy();
/// Get an Void type.
static const Type* GetVoidTy();
TargetType target() const { return place_.target; }
PrecisionType precision() const { return place_.precision; }
DataLayoutType layout() const { return place_.layout; }
......@@ -130,52 +143,23 @@ class Type : public DataTypeBase {
bool operator==(const Type& other) {
return id_ == other.id() && place_ == other.place();
}
friend std::ostream& operator<<(std::ostream& os, const Type& other) {
if (other.IsUnsupported()) {
os << "<Unsupported>";
return os;
}
if (other.IsVoid()) {
os << "<Void>";
return os;
}
if (other.is_tensor_) {
os << "<Tensor:";
} else {
os << "<";
}
os << TargetToStr(other.target()) << "/"
<< PrecisionToStr(other.precision()) << "/"
<< DataLayoutToStr(other.layout()) << ">";
return os;
}
// Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a statement to transform a type to another.
virtual bool TypeCastable(const Type& type) const { return id_ == type.id(); }
template <bool is_unknown, bool is_tensor = true,
TargetType target = TargetType::kHost,
PrecisionType precision = PrecisionType::kFloat,
DataLayoutType layout = DataLayoutType::kNCHW>
// Get a type.
static const Type* Get();
template <typename TypeTy>
static const Type* Get(TargetType target = TargetType::kHost);
friend std::ostream& operator<<(std::ostream& os, const Type& other);
virtual ~Type() = default;
protected:
Type(ID id, const std::string& name, bool is_tensor,
TargetType target = TargetType::kHost,
/// One should avoid using this construct.
Type(ID id, const std::string& name, TargetType target = TargetType::kHost,
PrecisionType precision = PrecisionType::kFloat,
DataLayoutType layout = DataLayoutType::kNCHW, short device = 0)
: DataTypeBase(id, is_tensor),
place_{target, precision, layout, device},
name_(name) {}
: DataType(id), place_{target, precision, layout, device}, name_(name) {}
// An map is used here to maintain a global repo for types. We don't use
// MACROs with static variables for that the TypeSystem should only used in
// compile time, that is not performance sensitive, and a map-based way is
// easier to implement and maintain.
static std::map<size_t, const Type*> type_repo_;
protected:
Place place_;
const std::string name_;
};
......@@ -224,46 +208,15 @@ static bool TypeCompatibleTo(const Type& a, const Type& b) {
// is only one instance across the system.
// The Void type: can be cast to any other type; obtained as a process-wide
// singleton through Type::GetVoidTy().
class VoidTy : public Type {
 public:
  // Fix: the block contained two conflicting default constructors (a stale
  // one still passing the removed is_tensor flag); keep only the one that
  // matches the current Type(ID, name, ...) constructor.
  VoidTy() : Type(ID::Void, "Void") {}
};
// The Unsupported type: data the type system will not analyze; obtained as a
// process-wide singleton through Type::GetUnsupportedTy().
class UnsupportedTy : public Type {
 public:
  // Fix: dropped the stale `false /*is_tensor*/` argument — the refactored
  // Type constructor no longer takes an is_tensor flag, and its third
  // parameter is TargetType (an enum class a bool cannot convert to).
  UnsupportedTy() : Type(ID::Unsupported, "Unsupported") {}
};
// Tensor type with a concrete target but "any" precision and layout; used by
// IO kernels that do not care about the concrete data representation.
// NOTE(review): legacy pre-refactor class — references ID::Tensor_Any and the
// old is_tensor constructor flag; confirm it is still meant to exist.
class TensorAnyTy : public Type {
public:
explicit TensorAnyTy(TargetType target)
: Type(ID::Tensor_Any, "TensorAny", true, target, PRECISION(kAny),
DATALAYOUT(kAny)) {}
};
// A list of tensor, and no assumption on the data layout or data type.
// NOTE(review): legacy pre-refactor class — passes false for the old
// is_tensor flag (a list is not itself a tensor) and references the removed
// ID::TensorList_Any; confirm it is still meant to exist.
class TensorListAnyTy : public Type {
public:
explicit TensorListAnyTy(TargetType target)
: Type(ID::TensorList_Any, "TensorList_Any", false, target,
PRECISION(kAny), DATALAYOUT(kAny)) {}
};
// Float32 tensor in NCHW layout for a given target.
// NOTE(review): legacy pre-refactor class — uses the old is_tensor
// constructor flag and ID::Tensor_Fp32_NCHW; confirm it is still meant to
// exist after the type-system refactor.
class TensorFp32NCHWTy : public Type {
public:
explicit TensorFp32NCHWTy(TargetType target)
: Type(ID::Tensor_Fp32_NCHW, "TensorFp32NCHW", true /*is_tensor*/, target,
PrecisionType::kFloat, DataLayoutType::kNCHW) {}
};
// Int8 tensor in NCHW layout for a given target.
// NOTE(review): legacy pre-refactor class — uses the old is_tensor
// constructor flag and ID::Tensor_Int8_NCHW; confirm it is still meant to
// exist after the type-system refactor.
class TensorInt8NCHWTy : public Type {
public:
explicit TensorInt8NCHWTy(TargetType target)
: Type(ID::Tensor_Int8_NCHW, "TensorInt8NCHW", true /*is_tensor*/, target,
PrecisionType::kInt8, DataLayoutType::kNCHW) {}
};
// Int64 tensor in NCHW layout for a given target.
// NOTE(review): the constructor passes PrecisionType::kInt8 for an *Int64*
// type — this looks like a copy-paste bug from TensorInt8NCHWTy; confirm
// whether PrecisionType has a kInt64 member and switch to it if so.
class TensorInt64NCHWTy : public Type {
public:
explicit TensorInt64NCHWTy(TargetType target)
: Type(ID::Tensor_Int64_NCHW, "TensorInt64NCHW", true /*is_tensor*/,
target, PrecisionType::kInt8, DataLayoutType::kNCHW) {}
};
// Looks up a registered Type instance by its id and place.
// Fix: the block contained two conflicting declarations — a stale one taking
// DataTypeBase::ID (a class renamed to DataType in this refactor) and the
// current one; keep only the declaration matching the new class name.
const Type* LookupType(DataType::ID type_id, bool is_unknown, bool is_tensor,
                       Place place);
// ------------------------- end predefined types ---------------------------
/*
......
......@@ -18,15 +18,13 @@
namespace paddle {
namespace lite {
// Smoke test: the global TypeSystem registry should contain the core
// TensorBase type out of the box.
// NOTE(review): legacy pre-refactor test — confirm TypeSystem::Global()
// still exists after the type-system refactor.
TEST(TypeSystem, test) {
ASSERT_TRUE(TypeSystem::Global().Contains<lite::TensorBase>());
}
TEST(TypeSystem, CheckDuplicateGet) {
auto* tensor_ty =
Type::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
auto* tensor_ty1 =
Type::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
TEST(TypeSystem, register_new) {
TypeSystem::Global().Register<int>("int32");
ASSERT_TRUE(TypeSystem::Global().Contains<int>());
ASSERT_TRUE(TypeSystem::Global().Contains(typeid(int).hash_code()));
ASSERT_TRUE(TypeSystem::Global().Contains("int32"));
ASSERT_EQ(tensor_ty, tensor_ty1);
}
} // namespace lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册