From 36abc964dff01156119be4c87282a7142ee1998c Mon Sep 17 00:00:00 2001 From: nhzlx Date: Fri, 25 Jan 2019 04:02:01 +0000 Subject: [PATCH] fix pybind problem: add an enum to AnalysisConfig test=develop --- paddle/fluid/pybind/inference_api.cc | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 2624702666..e05667d2c7 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -180,8 +180,14 @@ void BindNativePredictor(py::module *m) { } void BindAnalysisConfig(py::module *m) { - py::class_(*m, "AnalysisConfig") - .def(py::init()) + py::class_ analysis_config(*m, "AnalysisConfig"); + + py::enum_(analysis_config, "Precision") + .value("Float32", AnalysisConfig::Precision::kFloat32) + .value("Int8", AnalysisConfig::Precision::kInt8) + .export_values(); + + analysis_config.def(py::init()) .def(py::init()) .def(py::init()) .def("set_model", (void (AnalysisConfig::*)(const std::string &)) & @@ -215,7 +221,8 @@ void BindAnalysisConfig(py::module *m) { .def("specify_input_name", &AnalysisConfig::specify_input_name) .def("enable_tensorrt_engine", &AnalysisConfig::EnableTensorRtEngine, py::arg("workspace_size") = 1 << 20, py::arg("max_batch_size") = 1, - py::arg("min_subgraph_size") = 3) + py::arg("min_subgraph_size") = 3, + py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32) .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled) .def("switch_ir_debug", &AnalysisConfig::SwitchIrDebug, py::arg("x") = true) -- GitLab