提交 717bbc08 编写于 作者: L lidanqing 提交者: ceci3

Add INT32 support. INT32 in last switch case

test=develop
上级 bcd7b993
...@@ -243,6 +243,8 @@ bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs, ...@@ -243,6 +243,8 @@ bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
input_ptr = input.mutable_data<int64_t>(ddim, place_); input_ptr = input.mutable_data<int64_t>(ddim, place_);
} else if (inputs[i].dtype == PaddleDType::FLOAT32) { } else if (inputs[i].dtype == PaddleDType::FLOAT32) {
input_ptr = input.mutable_data<float>(ddim, place_); input_ptr = input.mutable_data<float>(ddim, place_);
} else if (inputs[i].dtype == PaddleDType::INT32) {
input_ptr = input.mutable_data<int32_t>(ddim, place_);
} else { } else {
LOG(ERROR) << "unsupported feed type " << inputs[i].dtype; LOG(ERROR) << "unsupported feed type " << inputs[i].dtype;
return false; return false;
...@@ -326,8 +328,11 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs, ...@@ -326,8 +328,11 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
} else if (type == framework::proto::VarType::INT64) { } else if (type == framework::proto::VarType::INT64) {
GetFetchOne<int64_t>(fetch, output); GetFetchOne<int64_t>(fetch, output);
output->dtype = PaddleDType::INT64; output->dtype = PaddleDType::INT64;
} else if (type == framework::proto::VarType::INT32) {
GetFetchOne<int32_t>(fetch, output);
output->dtype = PaddleDType::INT32;
} else { } else {
LOG(ERROR) << "unknown type, only support float32 and int64 now."; LOG(ERROR) << "unknown type, only support float32, int64 and int32 now.";
} }
} }
return true; return true;
......
...@@ -28,6 +28,8 @@ int PaddleDtypeSize(PaddleDType dtype) { ...@@ -28,6 +28,8 @@ int PaddleDtypeSize(PaddleDType dtype) {
return sizeof(float); return sizeof(float);
case PaddleDType::INT64: case PaddleDType::INT64:
return sizeof(int64_t); return sizeof(int64_t);
case PaddleDType::INT32:
return sizeof(int32_t);
default: default:
assert(false); assert(false);
return -1; return -1;
......
...@@ -203,6 +203,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs, ...@@ -203,6 +203,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
input_ptr = input.mutable_data<int64_t>(ddim, place_); input_ptr = input.mutable_data<int64_t>(ddim, place_);
} else if (inputs[i].dtype == PaddleDType::FLOAT32) { } else if (inputs[i].dtype == PaddleDType::FLOAT32) {
input_ptr = input.mutable_data<float>(ddim, place_); input_ptr = input.mutable_data<float>(ddim, place_);
} else if (inputs[i].dtype == PaddleDType::INT32) {
input_ptr = input.mutable_data<int32_t>(ddim, place_);
} else { } else {
LOG(ERROR) << "unsupported feed type " << inputs[i].dtype; LOG(ERROR) << "unsupported feed type " << inputs[i].dtype;
return false; return false;
...@@ -281,8 +283,11 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs, ...@@ -281,8 +283,11 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
} else if (type == framework::DataTypeTrait<int64_t>::DataType) { } else if (type == framework::DataTypeTrait<int64_t>::DataType) {
GetFetchOne<int64_t>(fetch, output); GetFetchOne<int64_t>(fetch, output);
output->dtype = PaddleDType::INT64; output->dtype = PaddleDType::INT64;
} else if (type == framework::DataTypeTrait<int32_t>::DataType) {
GetFetchOne<int32_t>(fetch, output);
output->dtype = PaddleDType::INT32;
} else { } else {
LOG(ERROR) << "unknown type, only support float32 and int64 now."; LOG(ERROR) << "unknown type, only support float32, int64 and int32 now.";
} }
} }
return true; return true;
......
...@@ -42,6 +42,9 @@ PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) { ...@@ -42,6 +42,9 @@ PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
} else if (t->type() == framework::proto::VarType::FP32) { } else if (t->type() == framework::proto::VarType::FP32) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(float)); pt.data.Reset(t->data<void>(), t->numel() * sizeof(float));
pt.dtype = PaddleDType::FLOAT32; pt.dtype = PaddleDType::FLOAT32;
} else if (t->type() == framework::proto::VarType::INT32) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(int32_t));
pt.dtype = PaddleDType::INT32;
} else { } else {
LOG(FATAL) << "unsupported type."; LOG(FATAL) << "unsupported type.";
} }
......
...@@ -88,13 +88,20 @@ void CheckOutput(const std::string& referfile, const PaddleTensor& output) { ...@@ -88,13 +88,20 @@ void CheckOutput(const std::string& referfile, const PaddleTensor& output) {
} }
break; break;
} }
case PaddleDType::FLOAT32: case PaddleDType::FLOAT32: {
for (size_t i = 0; i < numel; ++i) { for (size_t i = 0; i < numel; ++i) {
CHECK_LT( CHECK_LT(
fabs(static_cast<float*>(output.data.data())[i] - refer.data[i]), fabs(static_cast<float*>(output.data.data())[i] - refer.data[i]),
1e-5); 1e-5);
} }
break; break;
}
case PaddleDType::INT32: {
for (size_t i = 0; i < numel; ++i) {
CHECK_EQ(static_cast<int32_t*>(output.data.data())[i], refer.data[i]);
}
break;
}
} }
} }
...@@ -113,11 +120,18 @@ static std::string SummaryTensor(const PaddleTensor& tensor) { ...@@ -113,11 +120,18 @@ static std::string SummaryTensor(const PaddleTensor& tensor) {
} }
break; break;
} }
case PaddleDType::FLOAT32: case PaddleDType::FLOAT32: {
for (int i = 0; i < std::min(num_elems, 10); i++) { for (int i = 0; i < std::min(num_elems, 10); i++) {
ss << static_cast<float*>(tensor.data.data())[i] << " "; ss << static_cast<float*>(tensor.data.data())[i] << " ";
} }
break; break;
}
case PaddleDType::INT32: {
for (int i = 0; i < std::min(num_elems, 10); i++) {
ss << static_cast<int32_t*>(tensor.data.data())[i] << " ";
}
break;
}
} }
return ss.str(); return ss.str();
} }
......
...@@ -202,6 +202,9 @@ static std::string DescribeTensor(const PaddleTensor &tensor, ...@@ -202,6 +202,9 @@ static std::string DescribeTensor(const PaddleTensor &tensor,
case PaddleDType::INT64: case PaddleDType::INT64:
os << "int64"; os << "int64";
break; break;
case PaddleDType::INT32:
os << "int32";
break;
default: default:
os << "unset"; os << "unset";
} }
......
...@@ -36,6 +36,7 @@ namespace paddle { ...@@ -36,6 +36,7 @@ namespace paddle {
enum PaddleDType { enum PaddleDType {
FLOAT32, FLOAT32,
INT64, INT64,
INT32,
// TODO(Superjomn) support more data types if needed. // TODO(Superjomn) support more data types if needed.
}; };
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#ifdef WITH_GPERFTOOLS #ifdef WITH_GPERFTOOLS
#include <gperftools/profiler.h> #include <gperftools/profiler.h>
#endif #endif
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/analyzer.h"
...@@ -97,6 +96,14 @@ void CompareResult(const std::vector<PaddleTensor> &outputs, ...@@ -97,6 +96,14 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
} }
break; break;
} }
case PaddleDType::INT32: {
int32_t *pdata = static_cast<int32_t *>(out.data.data());
int32_t *pdata_ref = static_cast<int32_t *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
}
} }
} }
} }
......
...@@ -65,7 +65,8 @@ void BindInferenceApi(py::module *m) { ...@@ -65,7 +65,8 @@ void BindInferenceApi(py::module *m) {
void BindPaddleDType(py::module *m) { void BindPaddleDType(py::module *m) {
py::enum_<PaddleDType>(*m, "PaddleDType") py::enum_<PaddleDType>(*m, "PaddleDType")
.value("FLOAT32", PaddleDType::FLOAT32) .value("FLOAT32", PaddleDType::FLOAT32)
.value("INT64", PaddleDType::INT64); .value("INT64", PaddleDType::INT64)
.value("INT32", PaddleDType::INT32);
} }
void BindPaddleBuf(py::module *m) { void BindPaddleBuf(py::module *m) {
...@@ -103,6 +104,11 @@ void BindPaddleBuf(py::module *m) { ...@@ -103,6 +104,11 @@ void BindPaddleBuf(py::module *m) {
int64_t *data = static_cast<int64_t *>(self.data()); int64_t *data = static_cast<int64_t *>(self.data());
return {data, data + self.length() / sizeof(*data)}; return {data, data + self.length() / sizeof(*data)};
}) })
.def("int32_data",
[](PaddleBuf &self) -> std::vector<int32_t> {
int32_t *data = static_cast<int32_t *>(self.data());
return {data, data + self.length() / sizeof(*data)};
})
.def("length", &PaddleBuf::length); .def("length", &PaddleBuf::length);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册