提交 e6c68a2d 编写于 作者: K Kavya Srinet

Fixed data_type.h

上级 e18bc024
...@@ -20,35 +20,35 @@ limitations under the License. */ ...@@ -20,35 +20,35 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
inline proto::DataType ToDataType(std::type_index type) { inline proto::VarType::Type ToDataType(std::type_index type) {
using namespace paddle::framework::proto; using namespace paddle::framework::proto;
if (typeid(float).hash_code() == type.hash_code()) { if (typeid(float).hash_code() == type.hash_code()) {
return DataType::FP32; return VarType::Type::FP32;
} else if (typeid(double).hash_code() == type.hash_code()) { } else if (typeid(double).hash_code() == type.hash_code()) {
return DataType::FP64; return VarType::Type::FP64;
} else if (typeid(int).hash_code() == type.hash_code()) { } else if (typeid(int).hash_code() == type.hash_code()) {
return DataType::INT32; return VarType::Type::INT32;
} else if (typeid(int64_t).hash_code() == type.hash_code()) { } else if (typeid(int64_t).hash_code() == type.hash_code()) {
return DataType::INT64; return VarType::Type::INT64;
} else if (typeid(bool).hash_code() == type.hash_code()) { } else if (typeid(bool).hash_code() == type.hash_code()) {
return DataType::BOOL; return VarType::Type::BOOL;
} else { } else {
PADDLE_THROW("Not supported"); PADDLE_THROW("Not supported");
} }
} }
inline std::type_index ToTypeIndex(proto::DataType type) { inline std::type_index ToTypeIndex(proto::VarType::Type type) {
using namespace paddle::framework::proto; using namespace paddle::framework::proto;
switch (type) { switch (type) {
case DataType::FP32: case VarType::Type::FP32:
return typeid(float); return typeid(float);
case DataType::FP64: case VarType::Type::FP64:
return typeid(double); return typeid(double);
case DataType::INT32: case VarType::Type::INT32:
return typeid(int); return typeid(int);
case DataType::INT64: case VarType::Type::INT64:
return typeid(int64_t); return typeid(int64_t);
case DataType::BOOL: case VarType::Type::BOOL:
return typeid(bool); return typeid(bool);
default: default:
PADDLE_THROW("Not support type %d", type); PADDLE_THROW("Not support type %d", type);
...@@ -56,22 +56,22 @@ inline std::type_index ToTypeIndex(proto::DataType type) { ...@@ -56,22 +56,22 @@ inline std::type_index ToTypeIndex(proto::DataType type) {
} }
template <typename Visitor> template <typename Visitor>
inline void VisitDataType(proto::DataType type, Visitor visitor) { inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
using namespace paddle::framework::proto; using namespace paddle::framework::proto;
switch (type) { switch (type) {
case DataType::FP32: case VarType::Type::FP32:
visitor.template operator()<float>(); visitor.template operator()<float>();
break; break;
case DataType::FP64: case VarType::Type::FP64:
visitor.template operator()<double>(); visitor.template operator()<double>();
break; break;
case DataType::INT32: case VarType::Type::INT32:
visitor.template operator()<int>(); visitor.template operator()<int>();
break; break;
case DataType::INT64: case VarType::Type::INT64:
visitor.template operator()<int64_t>(); visitor.template operator()<int64_t>();
break; break;
case DataType::BOOL: case VarType::Type::BOOL:
visitor.template operator()<bool>(); visitor.template operator()<bool>();
break; break;
default: default:
...@@ -79,22 +79,22 @@ inline void VisitDataType(proto::DataType type, Visitor visitor) { ...@@ -79,22 +79,22 @@ inline void VisitDataType(proto::DataType type, Visitor visitor) {
} }
} }
inline std::string DataTypeToString(const proto::DataType type) { inline std::string DataTypeToString(const proto::VarType::Type type) {
using namespace paddle::framework::proto; using namespace paddle::framework::proto;
switch (type) { switch (type) {
case DataType::FP16: case VarType::Type::FP16:
return "float16"; return "float16";
case DataType::FP32: case VarType::Type::FP32:
return "float32"; return "float32";
case DataType::FP64: case VarType::Type::FP64:
return "float64"; return "float64";
case DataType::INT16: case VarType::Type::INT16:
return "int16"; return "int16";
case DataType::INT32: case VarType::Type::INT32:
return "int32"; return "int32";
case DataType::INT64: case VarType::Type::INT64:
return "int64"; return "int64";
case DataType::BOOL: case VarType::Type::BOOL:
return "bool"; return "bool";
default: default:
PADDLE_THROW("Not support type %d", type); PADDLE_THROW("Not support type %d", type);
...@@ -102,7 +102,7 @@ inline std::string DataTypeToString(const proto::DataType type) { ...@@ -102,7 +102,7 @@ inline std::string DataTypeToString(const proto::DataType type) {
} }
inline std::ostream& operator<<(std::ostream& out, inline std::ostream& operator<<(std::ostream& out,
const proto::DataType& type) { const proto::VarType::Type& type) {
out << DataTypeToString(type); out << DataTypeToString(type);
return out; return out;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册