From fd2b4b478ed6b3dffcf75b87bb90d33473e80f86 Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Wed, 16 May 2018 12:47:07 +0800 Subject: [PATCH] Make tensor support uint8 --- paddle/fluid/framework/data_type.cc | 1 + paddle/fluid/framework/data_type.h | 8 +++++++- paddle/fluid/framework/framework.proto | 1 + paddle/fluid/framework/lod_tensor_test.cc | 17 +++++++++++++---- paddle/fluid/operators/math/math_function.cc | 4 +++- 5 files changed, 25 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/data_type.cc b/paddle/fluid/framework/data_type.cc index b9c90cb0c32..b6b93cf422a 100644 --- a/paddle/fluid/framework/data_type.cc +++ b/paddle/fluid/framework/data_type.cc @@ -58,6 +58,7 @@ static DataTypeMap* InitDataTypeMap() { RegType(bool, proto::VarType::BOOL); RegType(size_t, proto::VarType::SIZE_T); RegType(int16_t, proto::VarType::INT16); + RegType(uint8_t, proto::VarType::UINT8); #undef RegType return retv; diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index 4b9f572ec5f..491413db8c8 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -47,8 +47,14 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { case proto::VarType::BOOL: visitor.template operator()(); break; + case proto::VarType::UINT8: + visitor.template operator()(); + break; + case proto::VarType::INT16: + visitor.template operator()(); + break; default: - PADDLE_THROW("Not supported"); + PADDLE_THROW("Not supported %d", type); } } diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index d2558f111f4..d35125fe8c3 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -103,6 +103,7 @@ message VarType { FP64 = 6; // Tensor is used in C++. SIZE_T = 19; + UINT8 = 20; // Other types that may need additional descriptions LOD_TENSOR = 7; diff --git a/paddle/fluid/framework/lod_tensor_test.cc b/paddle/fluid/framework/lod_tensor_test.cc index 77e5ec4c7dd..2ceffc93319 100644 --- a/paddle/fluid/framework/lod_tensor_test.cc +++ b/paddle/fluid/framework/lod_tensor_test.cc @@ -228,11 +228,12 @@ TEST(LoD, CheckAbsLoD) { ASSERT_FALSE(CheckAbsLoD(abs_lod0)); } -TEST(LoDTensor, RecordIO) { +template +static void TestRecordIO() { LoDTensor tensor; - int* tmp = tensor.mutable_data(make_ddim({4, 5}), platform::CPUPlace()); + T* tmp = tensor.mutable_data(make_ddim({4, 5}), platform::CPUPlace()); for (int i = 0; i < 20; ++i) { - tmp[i] = i; + tmp[i] = static_cast(i); } std::stringstream* stream = new std::stringstream(); @@ -247,7 +248,7 @@ TEST(LoDTensor, RecordIO) { auto assert_tensor_ok = [](const LoDTensor& tensor) { for (int i = 0; i < 20; ++i) { - ASSERT_EQ(tensor.data()[i], i); + ASSERT_EQ(tensor.data()[i], static_cast(i)); } }; @@ -265,5 +266,13 @@ TEST(LoDTensor, RecordIO) { } } +TEST(LoDTensor, RecordIO) { + TestRecordIO(); + TestRecordIO(); + TestRecordIO(); + TestRecordIO(); + TestRecordIO(); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index d62ea387cc5..d39154c6f88 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -38,7 +38,9 @@ template struct SetConstant; template struct Transpose; \ template struct Transpose; \ template struct Transpose; \ - template struct Transpose; + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; DEFINE_CPU_TRANS(1); DEFINE_CPU_TRANS(2); -- GitLab