提交 fd2b4b47 编写于 作者: Y yuyang18

Make tensor support uint8

上级 9707aa6b
...@@ -58,6 +58,7 @@ static DataTypeMap* InitDataTypeMap() { ...@@ -58,6 +58,7 @@ static DataTypeMap* InitDataTypeMap() {
RegType(bool, proto::VarType::BOOL); RegType(bool, proto::VarType::BOOL);
RegType(size_t, proto::VarType::SIZE_T); RegType(size_t, proto::VarType::SIZE_T);
RegType(int16_t, proto::VarType::INT16); RegType(int16_t, proto::VarType::INT16);
RegType(uint8_t, proto::VarType::UINT8);
#undef RegType #undef RegType
return retv; return retv;
......
...@@ -47,8 +47,14 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { ...@@ -47,8 +47,14 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
case proto::VarType::BOOL: case proto::VarType::BOOL:
visitor.template operator()<bool>(); visitor.template operator()<bool>();
break; break;
case proto::VarType::UINT8:
visitor.template operator()<uint8_t>();
break;
case proto::VarType::INT16:
visitor.template operator()<int16_t>();
break;
default: default:
PADDLE_THROW("Not supported"); PADDLE_THROW("Not supported %d", type);
} }
} }
......
...@@ -103,6 +103,7 @@ message VarType { ...@@ -103,6 +103,7 @@ message VarType {
FP64 = 6; FP64 = 6;
// Tensor<size_t> is used in C++. // Tensor<size_t> is used in C++.
SIZE_T = 19; SIZE_T = 19;
UINT8 = 20;
// Other types that may need additional descriptions // Other types that may need additional descriptions
LOD_TENSOR = 7; LOD_TENSOR = 7;
......
...@@ -228,11 +228,12 @@ TEST(LoD, CheckAbsLoD) { ...@@ -228,11 +228,12 @@ TEST(LoD, CheckAbsLoD) {
ASSERT_FALSE(CheckAbsLoD(abs_lod0)); ASSERT_FALSE(CheckAbsLoD(abs_lod0));
} }
TEST(LoDTensor, RecordIO) { template <typename T>
static void TestRecordIO() {
LoDTensor tensor; LoDTensor tensor;
int* tmp = tensor.mutable_data<int>(make_ddim({4, 5}), platform::CPUPlace()); T* tmp = tensor.mutable_data<T>(make_ddim({4, 5}), platform::CPUPlace());
for (int i = 0; i < 20; ++i) { for (int i = 0; i < 20; ++i) {
tmp[i] = i; tmp[i] = static_cast<T>(i);
} }
std::stringstream* stream = new std::stringstream(); std::stringstream* stream = new std::stringstream();
...@@ -247,7 +248,7 @@ TEST(LoDTensor, RecordIO) { ...@@ -247,7 +248,7 @@ TEST(LoDTensor, RecordIO) {
auto assert_tensor_ok = [](const LoDTensor& tensor) { auto assert_tensor_ok = [](const LoDTensor& tensor) {
for (int i = 0; i < 20; ++i) { for (int i = 0; i < 20; ++i) {
ASSERT_EQ(tensor.data<int>()[i], i); ASSERT_EQ(tensor.data<T>()[i], static_cast<T>(i));
} }
}; };
...@@ -265,5 +266,13 @@ TEST(LoDTensor, RecordIO) { ...@@ -265,5 +266,13 @@ TEST(LoDTensor, RecordIO) {
} }
} }
TEST(LoDTensor, RecordIO) {
TestRecordIO<int>();
TestRecordIO<int16_t>();
TestRecordIO<uint8_t>();
TestRecordIO<float>();
TestRecordIO<double>();
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -38,7 +38,9 @@ template struct SetConstant<platform::CPUDeviceContext, bool>; ...@@ -38,7 +38,9 @@ template struct SetConstant<platform::CPUDeviceContext, bool>;
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \ template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \ template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \ template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>;
DEFINE_CPU_TRANS(1); DEFINE_CPU_TRANS(1);
DEFINE_CPU_TRANS(2); DEFINE_CPU_TRANS(2);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册