diff --git a/paddle/phi/kernels/funcs/data_layout_transform.h b/paddle/phi/kernels/funcs/data_layout_transform.h index 62cfb45b8c02eefd6868dafbc3c7a209543161ad..a2a50937752e4f2de6ecea3037ef054728751e33 100644 --- a/paddle/phi/kernels/funcs/data_layout_transform.h +++ b/paddle/phi/kernels/funcs/data_layout_transform.h @@ -48,14 +48,30 @@ inline OneDNNMemoryFormat ToOneDNNFormat(const DataLayout& layout) { } } -// Caution: proto::VarType::Type -> phi::DataType after transfer inline OneDNNDataType ToOneDNNDataType(DataType type) { - static std::unordered_map dict{ - {DataType::FLOAT32, OneDNNDataType::f32}, - {DataType::INT8, OneDNNDataType::s8}, - {DataType::UINT8, OneDNNDataType::u8}, - {DataType::INT32, OneDNNDataType::s32}, - {DataType::BFLOAT16, OneDNNDataType::bf16}}; +#if __GNUC__ > 5 + using DataTypeMapping = std::unordered_map; +#else + struct DataTypeHash { + std::size_t operator()(const DataType& f) const { + return std::hash{}(static_cast(f)); + } + }; + struct DataTypeEqual { + bool operator()(const DataType& lhs, const DataType& rhs) const { + return static_cast(lhs) == static_cast(rhs); + } + }; + using DataTypeMapping = + std::unordered_map; +#endif + + static DataTypeMapping dict{{DataType::FLOAT32, OneDNNDataType::f32}, + {DataType::INT8, OneDNNDataType::s8}, + {DataType::UINT8, OneDNNDataType::u8}, + {DataType::INT32, OneDNNDataType::s32}, + {DataType::BFLOAT16, OneDNNDataType::bf16}}; + auto iter = dict.find(type); if (iter != dict.end()) return iter->second; return OneDNNDataType::undef;