diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 26247026667158a2f43cdac21bf5600479455e16..e05667d2c7e9ce5c64cfacee4919cd36d7383c0c 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -180,8 +180,14 @@ void BindNativePredictor(py::module *m) { } void BindAnalysisConfig(py::module *m) { - py::class_(*m, "AnalysisConfig") - .def(py::init()) + py::class_ analysis_config(*m, "AnalysisConfig"); + + py::enum_(analysis_config, "Precision") + .value("Float32", AnalysisConfig::Precision::kFloat32) + .value("Int8", AnalysisConfig::Precision::kInt8) + .export_values(); + + analysis_config.def(py::init()) .def(py::init()) .def(py::init()) .def("set_model", (void (AnalysisConfig::*)(const std::string &)) & @@ -215,7 +221,8 @@ void BindAnalysisConfig(py::module *m) { .def("specify_input_name", &AnalysisConfig::specify_input_name) .def("enable_tensorrt_engine", &AnalysisConfig::EnableTensorRtEngine, py::arg("workspace_size") = 1 << 20, py::arg("max_batch_size") = 1, - py::arg("min_subgraph_size") = 3) + py::arg("min_subgraph_size") = 3, + py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32) .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled) .def("switch_ir_debug", &AnalysisConfig::SwitchIrDebug, py::arg("x") = true)