diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index cb92bb8211b25f436c1c3a0da014f1dc40520fbb..b58c60e96a0bd6695b827e7063fa7a07f42fe586 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -243,6 +243,8 @@ bool AnalysisPredictor::SetFeed(const std::vector &inputs, input_ptr = input.mutable_data(ddim, place_); } else if (inputs[i].dtype == PaddleDType::FLOAT32) { input_ptr = input.mutable_data(ddim, place_); + } else if (inputs[i].dtype == PaddleDType::INT32) { + input_ptr = input.mutable_data(ddim, place_); } else { LOG(ERROR) << "unsupported feed type " << inputs[i].dtype; return false; @@ -326,8 +328,11 @@ bool AnalysisPredictor::GetFetch(std::vector *outputs, } else if (type == framework::proto::VarType::INT64) { GetFetchOne(fetch, output); output->dtype = PaddleDType::INT64; + } else if (type == framework::proto::VarType::INT32) { + GetFetchOne(fetch, output); + output->dtype = PaddleDType::INT32; } else { - LOG(ERROR) << "unknown type, only support float32 and int64 now."; + LOG(ERROR) << "unknown type, only support float32, int64 and int32 now."; } } return true; diff --git a/paddle/fluid/inference/api/api.cc b/paddle/fluid/inference/api/api.cc index f83537f064187e67a08c8bbce52707d1c824abeb..7d57b6ec74468dbdb0519f85140629a0ac01c18d 100644 --- a/paddle/fluid/inference/api/api.cc +++ b/paddle/fluid/inference/api/api.cc @@ -28,6 +28,8 @@ int PaddleDtypeSize(PaddleDType dtype) { return sizeof(float); case PaddleDType::INT64: return sizeof(int64_t); + case PaddleDType::INT32: + return sizeof(int32_t); default: assert(false); return -1; diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 048286a843f0190a8139cb86eda4f3a3a40d89a1..54f40563c3662af24e794422be4d3262d86c76a7 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -203,6 +203,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector &inputs, input_ptr = input.mutable_data(ddim, place_); } else if (inputs[i].dtype == PaddleDType::FLOAT32) { input_ptr = input.mutable_data(ddim, place_); + } else if (inputs[i].dtype == PaddleDType::INT32) { + input_ptr = input.mutable_data(ddim, place_); } else { LOG(ERROR) << "unsupported feed type " << inputs[i].dtype; return false; @@ -281,8 +283,11 @@ bool NativePaddlePredictor::GetFetch(std::vector *outputs, } else if (type == framework::DataTypeTrait::DataType) { GetFetchOne(fetch, output); output->dtype = PaddleDType::INT64; + } else if (type == framework::DataTypeTrait::DataType) { + GetFetchOne(fetch, output); + output->dtype = PaddleDType::INT32; } else { - LOG(ERROR) << "unknown type, only support float32 and int64 now."; + LOG(ERROR) << "unknown type, only support float32, int64 and int32 now."; } } return true; diff --git a/paddle/fluid/inference/api/api_impl_tester.cc b/paddle/fluid/inference/api/api_impl_tester.cc index e82cb53bf073d3d1ab9a518218edaf430728463f..2dc5dda34d02c6df9c0ccbc47a1ac960e1aca3f5 100644 --- a/paddle/fluid/inference/api/api_impl_tester.cc +++ b/paddle/fluid/inference/api/api_impl_tester.cc @@ -42,6 +42,9 @@ PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) { } else if (t->type() == framework::proto::VarType::FP32) { pt.data.Reset(t->data(), t->numel() * sizeof(float)); pt.dtype = PaddleDType::FLOAT32; + } else if (t->type() == framework::proto::VarType::INT32) { + pt.data.Reset(t->data(), t->numel() * sizeof(int32_t)); + pt.dtype = PaddleDType::INT32; } else { LOG(FATAL) << "unsupported type."; } diff --git a/paddle/fluid/inference/api/demo_ci/utils.h b/paddle/fluid/inference/api/demo_ci/utils.h index d70c6aea791219a40c3164b51499f9d5e562be71..1505a898c5bba285b377203c1503b8615666b196 100644 --- a/paddle/fluid/inference/api/demo_ci/utils.h +++ b/paddle/fluid/inference/api/demo_ci/utils.h @@ -88,13 +88,20 @@ void CheckOutput(const std::string& referfile, const PaddleTensor& output) { } break; } - case PaddleDType::FLOAT32: + case PaddleDType::FLOAT32: { for (size_t i = 0; i < numel; ++i) { CHECK_LT( fabs(static_cast(output.data.data())[i] - refer.data[i]), 1e-5); } break; + } + case PaddleDType::INT32: { + for (size_t i = 0; i < numel; ++i) { + CHECK_EQ(static_cast(output.data.data())[i], refer.data[i]); + } + break; + } } } @@ -113,11 +120,18 @@ static std::string SummaryTensor(const PaddleTensor& tensor) { } break; } - case PaddleDType::FLOAT32: + case PaddleDType::FLOAT32: { for (int i = 0; i < std::min(num_elems, 10); i++) { ss << static_cast(tensor.data.data())[i] << " "; } break; + } + case PaddleDType::INT32: { + for (int i = 0; i < std::min(num_elems, 10); i++) { + ss << static_cast(tensor.data.data())[i] << " "; + } + break; + } } return ss.str(); } diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h index ec3bef42fd91cea04a656dd38a4e5c45c1a76476..f65a8b89818663940f84ed32d6fd16e8ba8e401e 100644 --- a/paddle/fluid/inference/api/helper.h +++ b/paddle/fluid/inference/api/helper.h @@ -202,6 +202,9 @@ static std::string DescribeTensor(const PaddleTensor &tensor, case PaddleDType::INT64: os << "int64"; break; + case PaddleDType::INT32: + os << "int32"; + break; default: os << "unset"; } diff --git a/paddle/fluid/inference/api/paddle_api.h b/paddle/fluid/inference/api/paddle_api.h index f807289f6aee06e3ff61bc0dd92f47c599421354..703fd18069474f28b29c6f16c6308fc19bd3527f 100644 --- a/paddle/fluid/inference/api/paddle_api.h +++ b/paddle/fluid/inference/api/paddle_api.h @@ -36,6 +36,7 @@ namespace paddle { enum PaddleDType { FLOAT32, INT64, + INT32, // TODO(Superjomn) support more data types if needed. }; diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index 2e53fddfe7f6f0c5b31ff069fb1661f143022841..41daff83c482c5f95d02afee9637d19d469ca507 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -25,7 +25,6 @@ #ifdef WITH_GPERFTOOLS #include #endif - #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/inference/analysis/analyzer.h" @@ -97,6 +96,14 @@ void CompareResult(const std::vector &outputs, } break; } + case PaddleDType::INT32: { + int32_t *pdata = static_cast(out.data.data()); + int32_t *pdata_ref = static_cast(ref_out.data.data()); + for (size_t j = 0; j < size; ++j) { + EXPECT_EQ(pdata_ref[j], pdata[j]); + } + break; + } } } } diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 03c1b0bd092181e4f20bf8944823c688ff98d65f..236afc77f708c344665821edd4f7c7841c300465 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -65,7 +65,8 @@ void BindInferenceApi(py::module *m) { void BindPaddleDType(py::module *m) { py::enum_(*m, "PaddleDType") .value("FLOAT32", PaddleDType::FLOAT32) - .value("INT64", PaddleDType::INT64); + .value("INT64", PaddleDType::INT64) + .value("INT32", PaddleDType::INT32); } void BindPaddleBuf(py::module *m) { @@ -103,6 +104,11 @@ void BindPaddleBuf(py::module *m) { int64_t *data = static_cast(self.data()); return {data, data + self.length() / sizeof(*data)}; }) + .def("int32_data", + [](PaddleBuf &self) -> std::vector { + int32_t *data = static_cast(self.data()); + return {data, data + self.length() / sizeof(*data)}; + }) .def("length", &PaddleBuf::length); }