提交 36abc964 编写于 作者: N nhzlx

fix pybind problem: add an enum to AnalysisConfig

test=develop
上级 0779e355
......@@ -180,8 +180,14 @@ void BindNativePredictor(py::module *m) {
}
void BindAnalysisConfig(py::module *m) {
py::class_<AnalysisConfig>(*m, "AnalysisConfig")
.def(py::init<const AnalysisConfig &>())
py::class_<AnalysisConfig> analysis_config(*m, "AnalysisConfig");
py::enum_<AnalysisConfig::Precision>(analysis_config, "Precision")
.value("Float32", AnalysisConfig::Precision::kFloat32)
.value("Int8", AnalysisConfig::Precision::kInt8)
.export_values();
analysis_config.def(py::init<const AnalysisConfig &>())
.def(py::init<const std::string &>())
.def(py::init<const std::string &, const std::string &>())
.def("set_model", (void (AnalysisConfig::*)(const std::string &)) &
......@@ -215,7 +221,8 @@ void BindAnalysisConfig(py::module *m) {
.def("specify_input_name", &AnalysisConfig::specify_input_name)
.def("enable_tensorrt_engine", &AnalysisConfig::EnableTensorRtEngine,
py::arg("workspace_size") = 1 << 20, py::arg("max_batch_size") = 1,
py::arg("min_subgraph_size") = 3)
py::arg("min_subgraph_size") = 3,
py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32)
.def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
.def("switch_ir_debug", &AnalysisConfig::SwitchIrDebug,
py::arg("x") = true)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册