// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <vector>

#include "core/general-client/include/general_model.h"

namespace py = pybind11;

namespace baidu {
namespace paddle_serving {
namespace general_model {

PYBIND11_MODULE(serving_client, m) {
  m.doc() = R"pddoc(Python bindings for the general model client.)pddoc";

  py::class_<PredictorRes>(m, "PredictorRes", py::buffer_protocol())
      .def(py::init())
      .def("get_int64_by_name",
           [](PredictorRes &self, int model_idx, std::string &name) {
             // Move the result into a heap-allocated vector and expose it as a
             // NumPy array without copying; the capsule frees the vector once
             // the array is garbage-collected.
             // see more: https://github.com/pybind/pybind11/issues/1042
             std::vector<int64_t> *ptr = new std::vector<int64_t>(
                 std::move(self.get_int64_by_name_with_rv(model_idx, name)));
             auto capsule = py::capsule(ptr, [](void *p) {
               delete reinterpret_cast<std::vector<int64_t> *>(p);
             });
             return py::array(ptr->size(), ptr->data(), capsule);
           })
      .def("get_float_by_name",
           [](PredictorRes &self, int model_idx, std::string &name) {
             std::vector<float> *ptr = new std::vector<float>(
                 std::move(self.get_float_by_name_with_rv(model_idx, name)));
             auto capsule = py::capsule(ptr, [](void *p) {
               delete reinterpret_cast<std::vector<float> *>(p);
             });
             return py::array(ptr->size(), ptr->data(), capsule);
           })
      .def("get_shape",
           [](PredictorRes &self, int model_idx, std::string &name) {
             std::vector<int> *ptr = new std::vector<int>(
                 std::move(self.get_shape_by_name_with_rv(model_idx, name)));
             auto capsule = py::capsule(ptr, [](void *p) {
               delete reinterpret_cast<std::vector<int> *>(p);
             });
             return py::array(ptr->size(), ptr->data(), capsule);
           })
      .def("get_lod",
           [](PredictorRes &self, int model_idx, std::string &name) {
             std::vector<int> *ptr = new std::vector<int>(
                 std::move(self.get_lod_by_name_with_rv(model_idx, name)));
             auto capsule = py::capsule(ptr, [](void *p) {
               delete reinterpret_cast<std::vector<int> *>(p);
             });
             return py::array(ptr->size(), ptr->data(), capsule);
           })
      .def("variant_tag",
           [](PredictorRes &self) { return self.variant_tag(); })
      .def("get_engine_names",
           [](PredictorRes &self) { return self.get_engine_names(); });

  py::class_<PredictorClient>(m, "PredictorClient", py::buffer_protocol())
      .def(py::init())
      .def("init_gflags",
           [](PredictorClient &self, std::vector<std::string> argv) {
             self.init_gflags(argv);
           })
      .def("init",
           [](PredictorClient &self, const std::string &conf) {
             return self.init(conf);
           })
      .def("set_predictor_conf",
           [](PredictorClient &self,
              const std::string &conf_path,
              const std::string &conf_file) {
             self.set_predictor_conf(conf_path, conf_file);
           })
      .def("create_predictor_by_desc",
           [](PredictorClient &self, const std::string &sdk_desc) {
             self.create_predictor_by_desc(sdk_desc);
           })
      .def("create_predictor",
           [](PredictorClient &self) { self.create_predictor(); })
      .def("destroy_predictor",
           [](PredictorClient &self) { self.destroy_predictor(); })
      .def("batch_predict",
           [](PredictorClient &self,
              const std::vector<std::vector<std::vector<float>>>
                  &float_feed_batch,
              const std::vector<std::string> &float_feed_name,
              const std::vector<std::vector<int>> &float_shape,
              const std::vector<std::vector<std::vector<int64_t>>>
                  &int_feed_batch,
              const std::vector<std::string> &int_feed_name,
              const std::vector<std::vector<int>> &int_shape,
              const std::vector<std::string> &fetch_name,
              PredictorRes &predict_res_batch,
              const int &pid,
              const uint64_t log_id) {
             return self.batch_predict(float_feed_batch,
                                       float_feed_name,
                                       float_shape,
                                       int_feed_batch,
                                       int_feed_name,
                                       int_shape,
                                       fetch_name,
                                       predict_res_batch,
                                       pid,
                                       log_id);
           },
           // Release the GIL while the blocking prediction call runs.
           py::call_guard<py::gil_scoped_release>())
      .def("numpy_predict",
           [](PredictorClient &self,
              const std::vector<std::vector<py::array_t<float>>>
                  &float_feed_batch,
              const std::vector<std::string> &float_feed_name,
              const std::vector<std::vector<int>> &float_shape,
              const std::vector<std::vector<py::array_t<int64_t>>>
                  &int_feed_batch,
              const std::vector<std::string> &int_feed_name,
              const std::vector<std::vector<int>> &int_shape,
              const std::vector<std::string> &fetch_name,
              PredictorRes &predict_res_batch,
              const int &pid,
              const uint64_t log_id) {
             return self.numpy_predict(float_feed_batch,
                                       float_feed_name,
                                       float_shape,
                                       int_feed_batch,
                                       int_feed_name,
                                       int_shape,
                                       fetch_name,
                                       predict_res_batch,
                                       pid,
                                       log_id);
           },
           py::call_guard<py::gil_scoped_release>());
}

}  // namespace general_model
}  // namespace paddle_serving
}  // namespace baidu
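
// Usage sketch (assumption, not part of the original file): once this
// extension is built and importable as `serving_client`, the bindings above
// could be driven from Python roughly as follows. Only methods defined in
// this file are used; the configuration file names are hypothetical
// placeholders.
//
//   from serving_client import PredictorClient, PredictorRes
//
//   client = PredictorClient()
//   client.init("serving_client_conf.prototxt")        # hypothetical path
//   client.set_predictor_conf(".", "predictor.conf")   # hypothetical files
//   client.create_predictor()
//
//   res = PredictorRes()
//   # numpy_predict(...) fills `res`; fetched tensors are then read back via
//   # res.get_float_by_name(model_idx, name), res.get_shape(model_idx, name),
//   # and res.get_lod(model_idx, name).
//
//   client.destroy_predictor()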