diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h
index f853738160b2b307bf6ce9cc6cf5b9969d72bf48..c13b7624a3bdf839a5d74b5dbd822c0ce78e688e 100644
--- a/paddle/fluid/inference/api/analysis_predictor.h
+++ b/paddle/fluid/inference/api/analysis_predictor.h
@@ -109,7 +109,7 @@ class AnalysisPredictor : public PaddlePredictor {
     // negative sharing_identifier directly. In the future, this may affect
     // the meaning of negative predictor id.
     predictor_id_ = -trt_identifier;
-    LOG(WARNING)
+    LOG_FIRST_N(WARNING, 1)
         << "Since the engine context memory of multiple predictors "
            "is enabled in Paddle-TRT, we set the id of current predictor to "
            "negative sharing_identifier you specified.";
diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc
index b87d7b361136223c6c878d636faf4fc2ebe29be1..e7cda8707c872471c7e54816652c24d765077302 100644
--- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc
+++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc
@@ -176,6 +176,8 @@ DataType Tensor::type() const {
     return DataType::UINT8;
   } else if (type == paddle::framework::proto::VarType::INT8) {
     return DataType::INT8;
+  } else if (type == paddle::framework::proto::VarType::BOOL) {
+    return DataType::BOOL;
   }
   return DataType::FLOAT32;
 }
@@ -279,6 +281,11 @@ void Tensor::CopyFromCpu(const T *data) {
 template <typename T>
 struct DataTypeInfo;
+template <>
+struct DataTypeInfo<bool> {
+  paddle::experimental::DataType TYPE = paddle::experimental::DataType::BOOL;
+};
+
 template <>
 struct DataTypeInfo<float> {
   paddle::experimental::DataType TYPE = paddle::experimental::DataType::FLOAT32;
 };
@@ -513,6 +520,7 @@ template PD_INFER_DECL void Tensor::CopyFromCpu<int32_t>(
 template PD_INFER_DECL void Tensor::CopyFromCpu<uint8_t>(const uint8_t *data);
 template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data);
 template PD_INFER_DECL void Tensor::CopyFromCpu<float16>(const float16 *data);
+template PD_INFER_DECL void Tensor::CopyFromCpu<bool>(const bool *data);
 
 template PD_INFER_DECL void Tensor::ShareExternalData<float>(
     const float *data,
@@ -544,6 +552,11 @@ template PD_INFER_DECL void Tensor::ShareExternalData<float16>(
     const std::vector<int> &shape,
     PlaceType place,
     DataLayout layout);
+template PD_INFER_DECL void Tensor::ShareExternalData<bool>(
+    const bool *data,
+    const std::vector<int> &shape,
+    PlaceType place,
+    DataLayout layout);
 
 template PD_INFER_DECL void Tensor::CopyToCpu<float>(float *data) const;
 template PD_INFER_DECL void Tensor::CopyToCpu<int64_t>(int64_t *data) const;
@@ -551,6 +564,7 @@ template PD_INFER_DECL void Tensor::CopyToCpu<int32_t>(
 template PD_INFER_DECL void Tensor::CopyToCpu<uint8_t>(uint8_t *data) const;
 template PD_INFER_DECL void Tensor::CopyToCpu<int8_t>(int8_t *data) const;
 template PD_INFER_DECL void Tensor::CopyToCpu<float16>(float16 *data) const;
+template PD_INFER_DECL void Tensor::CopyToCpu<bool>(bool *data) const;
 
 template PD_INFER_DECL void Tensor::CopyToCpuImpl<float>(float *data,
                                                          void *exec_stream,
@@ -566,6 +580,10 @@ template PD_INFER_DECL void Tensor::CopyToCpuImpl<int8_t>(
     int8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
 template PD_INFER_DECL void Tensor::CopyToCpuImpl<float16>(
     float16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
+template PD_INFER_DECL void Tensor::CopyToCpuImpl<bool>(bool *data,
+                                                        void *exec_stream,
+                                                        CallbackFunc cb,
+                                                        void *cb_params) const;
 
 template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
     float *data, void *exec_stream) const;
@@ -579,6 +597,8 @@ template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
     int8_t *data, void *exec_stream) const;
 template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
     float16 *data, void *exec_stream) const;
+template PD_INFER_DECL void Tensor::CopyToCpuAsync<bool>(
+    bool *data, void *exec_stream) const;
 
 template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
     float *data, CallbackFunc cb, void *cb_params) const;
@@ -592,6 +612,9 @@ template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
     int8_t *data, CallbackFunc cb, void *cb_params) const;
 template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
     float16 *data, CallbackFunc cb, void *cb_params) const;
+template PD_INFER_DECL void Tensor::CopyToCpuAsync<bool>(bool *data,
+                                                         CallbackFunc cb,
+                                                         void *cb_params) const;
 
 template PD_INFER_DECL float *Tensor::data<float>(PlaceType *place,
                                                   int *size) const;
@@ -605,6 +628,8 @@ template PD_INFER_DECL int8_t *Tensor::data<int8_t>(PlaceType *place,
                                                     int *size) const;
 template PD_INFER_DECL float16 *Tensor::data<float16>(PlaceType *place,
                                                       int *size) const;
+template PD_INFER_DECL bool *Tensor::data<bool>(PlaceType *place,
+                                                int *size) const;
 
 template PD_INFER_DECL float *Tensor::mutable_data<float>(PlaceType place);
 template PD_INFER_DECL int64_t *Tensor::mutable_data<int64_t>(PlaceType place);
@@ -612,6 +637,7 @@ template PD_INFER_DECL int32_t *Tensor::mutable_data<int32_t>(PlaceType place);
 template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place);
 template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place);
 template PD_INFER_DECL float16 *Tensor::mutable_data<float16>(PlaceType place);
+template PD_INFER_DECL bool *Tensor::mutable_data<bool>(PlaceType place);
 
 Tensor::Tensor(void *scope, const void *device_contexts)
     : scope_{scope}, device_contexs_(device_contexts) {}
@@ -895,6 +921,8 @@ template void InternalUtils::CopyFromCpuWithIoStream<int8_t>(
     paddle_infer::Tensor *t, const int8_t *data, cudaStream_t stream);
 template void InternalUtils::CopyFromCpuWithIoStream<float16>(
     paddle_infer::Tensor *t, const float16 *data, cudaStream_t stream);
+template void InternalUtils::CopyFromCpuWithIoStream<bool>(
+    paddle_infer::Tensor *t, const bool *data, cudaStream_t stream);
 
 template void InternalUtils::CopyToCpuWithIoStream<float>(
     paddle_infer::Tensor *t, float *data, cudaStream_t stream);
@@ -908,6 +936,8 @@ template void InternalUtils::CopyToCpuWithIoStream<int8_t>(
     paddle_infer::Tensor *t, int8_t *data, cudaStream_t stream);
 template void InternalUtils::CopyToCpuWithIoStream<float16>(
     paddle_infer::Tensor *t, float16 *data, cudaStream_t stream);
+template void InternalUtils::CopyToCpuWithIoStream<bool>(
+    paddle_infer::Tensor *t, bool *data, cudaStream_t stream);
 
 }  // namespace experimental
 }  // namespace paddle_infer
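Taken together, the zero_copy_tensor.cc hunks make `bool` a first-class element type across `CopyFromCpu`, `CopyToCpu`, `ShareExternalData`, `data`, `mutable_data`, and the stream-based copy helpers. A minimal sketch of how a caller would exercise the new instantiations through the public `paddle_infer` API; the model paths, the input name `"mask"`, and the 1x16 shapes are placeholders, and the output is assumed to be boolean with the same element count:

```cpp
#include <array>
#include <iostream>

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("model.pdmodel", "model.pdiparams");
  auto predictor = paddle_infer::CreatePredictor(config);

  // std::array rather than std::vector<bool>: the latter's packed storage
  // cannot hand out the bool* that CopyFromCpu/CopyToCpu expect.
  std::array<bool, 16> mask;
  mask.fill(true);

  auto input = predictor->GetInputHandle("mask");
  input->Reshape({1, 16});
  input->CopyFromCpu(mask.data());  // instantiates the new CopyFromCpu<bool>

  predictor->Run();

  auto output = predictor->GetOutputHandle(predictor->GetOutputNames()[0]);
  std::array<bool, 16> result;
  output->CopyToCpu(result.data());  // instantiates the new CopyToCpu<bool>
  std::cout << std::boolalpha << result[0] << std::endl;
  return 0;
}
```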
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index 0adeaf356de0ac2a131de1e8845a2e6d66a0b44b..eab8818a1df009feea4e391d6c2f75a819b8a597 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -161,7 +161,7 @@ struct PD_INFER_DECL AnalysisConfig {
   explicit AnalysisConfig(const std::string& prog_file,
                           const std::string& params_file);
   ///
-  /// \brief Precision of inference in TensorRT.
+  /// \brief Precision of inference.
   ///
   enum class Precision {
     kFloat32 = 0,  ///< fp32
diff --git a/paddle/fluid/inference/api/paddle_tensor.h b/paddle/fluid/inference/api/paddle_tensor.h
index 9bc95f251eb60c7fbc21a8e31012b31bd24f43a6..3748ba7338756fbf65d4f82bbe899af2cca9b9db 100644
--- a/paddle/fluid/inference/api/paddle_tensor.h
+++ b/paddle/fluid/inference/api/paddle_tensor.h
@@ -52,13 +52,14 @@ class InternalUtils;
 
 /// \brief Paddle data type.
 enum DataType {
-  FLOAT32,
   INT64,
   INT32,
   UINT8,
   INT8,
+  FLOAT32,
   FLOAT16,
-  // TODO(Superjomn) support more data types if needed.
+  BOOL,
+  // TODO(Inference): support more data types if needed.
 };
 
 enum class PlaceType { kUNK = -1, kCPU, kGPU, kXPU, kNPU, kIPU, kCUSTOM };
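With `BOOL` now part of the public `DataType` enum, callers that branch on `Tensor::type()` can handle boolean tensors explicitly. A small sketch of such a dispatch (`ElementSize` is a hypothetical helper, not Paddle API):

```cpp
#include <cstddef>
#include <cstdint>

#include "paddle_inference_api.h"

// Map a tensor's runtime dtype to the element size of a matching host
// buffer; only the BOOL arm is new with this patch.
size_t ElementSize(const paddle_infer::Tensor &t) {
  switch (t.type()) {
    case paddle_infer::DataType::BOOL:
      return sizeof(bool);
    case paddle_infer::DataType::FLOAT32:
      return sizeof(float);
    case paddle_infer::DataType::INT64:
      return sizeof(int64_t);
    default:
      return 0;  // remaining dtypes omitted for brevity
  }
}
```

Note that the reordering above changes the integer values of the existing enumerators (`FLOAT32` is no longer 0), which matters for any caller that stores or transmits raw `DataType` values.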
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index d314a9a7835190643b165ae287a52531d87b4b9d..80c4751ad565fcce9acab3a3c902ff395df834e5 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -175,16 +175,22 @@ py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) {
     case PaddleDType::FLOAT32:
      dt = py::dtype::of<float>();
       break;
+    case PaddleDType::FLOAT16:
+      dt = py::dtype::of<paddle_infer::float16>();
+      break;
     case PaddleDType::UINT8:
       dt = py::dtype::of<uint8_t>();
       break;
-    case PaddleDType::FLOAT16:
-      dt = py::dtype::of<paddle_infer::float16>();
+    case PaddleDType::INT8:
+      dt = py::dtype::of<int8_t>();
+      break;
+    case PaddleDType::BOOL:
+      dt = py::dtype::of<bool>();
       break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
-          "Unsupported data type. Now only supports INT32, INT64, UINT8 and "
-          "FLOAT32."));
+          "Unsupported data type. Now only supports INT32, INT64, FLOAT32, "
+          "FLOAT16, INT8, UINT8 and BOOL."));
   }
 
   return dt;
@@ -282,10 +288,22 @@ size_t PaddleGetDTypeSize(PaddleDType dt) {
     case PaddleDType::FLOAT32:
       size = sizeof(float);
       break;
+    case PaddleDType::FLOAT16:
+      size = sizeof(paddle_infer::float16);
+      break;
+    case PaddleDType::INT8:
+      size = sizeof(int8_t);
+      break;
+    case PaddleDType::UINT8:
+      size = sizeof(uint8_t);
+      break;
+    case PaddleDType::BOOL:
+      size = sizeof(bool);
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
-          "Unsupported data type. Now only supports INT32, INT64 and "
-          "FLOAT32."));
+          "Unsupported data type. Now only supports INT32, INT64, FLOAT32, "
+          "FLOAT16, INT8, UINT8 and BOOL."));
   }
   return size;
 }
@@ -316,10 +334,13 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) {  // NOLINT
     case PaddleDType::INT8:
       tensor.copy_to_cpu(static_cast<int8_t *>(array.mutable_data()));
       break;
+    case PaddleDType::BOOL:
+      tensor.copy_to_cpu(static_cast<bool *>(array.mutable_data()));
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
-          "Unsupported data type. Now only supports INT32, INT64, UINT8 and "
-          "FLOAT32."));
+          "Unsupported data type. Now only supports INT32, INT64, FLOAT32, "
+          "FLOAT16, INT8, UINT8 and BOOL."));
   }
   return array;
 }
@@ -350,10 +371,13 @@ py::array PaddleInferTensorToNumpy(paddle_infer::Tensor &tensor) {  // NOLINT
     case PaddleDType::INT8:
       tensor.CopyToCpu(static_cast<int8_t *>(array.mutable_data()));
       break;
+    case PaddleDType::BOOL:
+      tensor.CopyToCpu(static_cast<bool *>(array.mutable_data()));
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
-          "Unsupported data type. Now only supports INT32, INT64 and "
-          "FLOAT32."));
+          "Unsupported data type. Now only supports INT32, INT64, FLOAT32, "
+          "FLOAT16, INT8, UINT8 and BOOL."));
   }
   return array;
 }
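The four hunks above extend the same conversion pattern: select the numpy dtype that matches the tensor's element type, allocate the array, then copy elements into its buffer. A standalone pybind11 sketch of that pattern for the new `bool` case (`BoolBufferToNumpy` is a hypothetical name, not part of this patch):

```cpp
#include <pybind11/numpy.h>

#include <cstring>
#include <vector>

namespace py = pybind11;

// Mirror of the ZeroCopyTensorToNumpy flow: pick the matching dtype,
// allocate the array, then fill its buffer from a host-side copy.
py::array BoolBufferToNumpy(const bool *data,
                            const std::vector<py::ssize_t> &shape) {
  py::array array(py::dtype::of<bool>(), shape);
  std::memcpy(array.mutable_data(), data, array.nbytes());
  return array;
}
```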
@@ -433,8 +457,12 @@ namespace {
 void BindPaddleDType(py::module *m) {
   py::enum_<PaddleDType>(*m, "PaddleDType")
       .value("FLOAT32", PaddleDType::FLOAT32)
+      .value("FLOAT16", PaddleDType::FLOAT16)
       .value("INT64", PaddleDType::INT64)
-      .value("INT32", PaddleDType::INT32);
+      .value("INT32", PaddleDType::INT32)
+      .value("UINT8", PaddleDType::UINT8)
+      .value("INT8", PaddleDType::INT8)
+      .value("BOOL", PaddleDType::BOOL);
 }
 
 void BindPaddleDataLayout(py::module *m) {
@@ -538,7 +566,8 @@ void BindPaddlePlace(py::module *m) {
       .value("CPU", PaddlePlace::kCPU)
       .value("GPU", PaddlePlace::kGPU)
       .value("XPU", PaddlePlace::kXPU)
-      .value("NPU", PaddlePlace::kNPU);
+      .value("NPU", PaddlePlace::kNPU)
+      .value("CUSTOM", PaddlePlace::kCUSTOM);
 }
 
 void BindPaddlePredictor(py::module *m) {
@@ -990,10 +1019,13 @@ void BindZeroCopyTensor(py::module *m) {
       .def("reshape",
            py::overload_cast<const std::vector<std::string> &>(
                &paddle_infer::Tensor::ReshapeStrings))
+      .def("copy_from_cpu", &ZeroCopyTensorCreate<int8_t>)
+      .def("copy_from_cpu", &ZeroCopyTensorCreate<uint8_t>)
       .def("copy_from_cpu", &ZeroCopyTensorCreate<int32_t>)
       .def("copy_from_cpu", &ZeroCopyTensorCreate<int64_t>)
       .def("copy_from_cpu", &ZeroCopyTensorCreate<float>)
       .def("copy_from_cpu", &ZeroCopyTensorCreate<paddle_infer::float16>)
+      .def("copy_from_cpu", &ZeroCopyTensorCreate<bool>)
       .def("copy_from_cpu", &ZeroCopyStringTensorCreate)
       .def("copy_to_cpu", &ZeroCopyTensorToNumpy)
       .def("shape", &ZeroCopyTensor::shape)
@@ -1010,11 +1042,14 @@ void BindPaddleInferTensor(py::module *m) {
       .def("reshape",
           py::overload_cast<const std::vector<std::string> &>(
               &paddle_infer::Tensor::ReshapeStrings))
+      .def("copy_from_cpu_bind", &PaddleInferTensorCreate<int8_t>)
+      .def("copy_from_cpu_bind", &PaddleInferTensorCreate<uint8_t>)
       .def("copy_from_cpu_bind", &PaddleInferTensorCreate<int32_t>)
       .def("copy_from_cpu_bind", &PaddleInferTensorCreate<int64_t>)
       .def("copy_from_cpu_bind", &PaddleInferTensorCreate<float>)
       .def("copy_from_cpu_bind",
            &PaddleInferTensorCreate<paddle_infer::float16>)
+      .def("copy_from_cpu_bind", &PaddleInferTensorCreate<bool>)
       .def("copy_from_cpu_bind", &PaddleInferStringTensorCreate)
       .def("share_external_data_bind", &PaddleInferShareExternalData)
       .def("copy_to_cpu", &PaddleInferTensorToNumpy)
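The repeated `.def("copy_from_cpu", ...)` registrations rely on pybind11's overload resolution: same-named overloads are tried in registration order, first requiring an exact dtype match and only then allowing implicit conversion, so a `np.bool_` array from Python now lands in the `bool` creator. A standalone sketch of that mechanism, with a toy stand-in for the templated creator functions (not Paddle code):

```cpp
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

#include <typeinfo>

namespace py = pybind11;

// Toy stand-in for ZeroCopyTensorCreate<T>: report which overload fired.
template <typename T>
const char *ElemKind(py::array_t<T>) {
  return typeid(T).name();
}

PYBIND11_MODULE(overload_demo, m) {
  // elem_kind(np.ones(3, dtype=np.float32)) selects the float overload;
  // elem_kind(np.ones(3, dtype=np.bool_)) selects the bool overload,
  // mirroring how the new copy_from_cpu bindings are dispatched.
  m.def("elem_kind", &ElemKind<float>);
  m.def("elem_kind", &ElemKind<bool>);
}
```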