From c585f678d7e586cd308c5af803a2dbcf4b33c89a Mon Sep 17 00:00:00 2001
From: dongdaxiang
Date: Mon, 4 May 2020 10:16:46 +0800
Subject: [PATCH] add GeneralGetConfOp, add get_model_config client API,
 refine request and response protos to support the new get-config workflow

---
 core/general-client/include/general_model.h   |  2 +
 core/general-client/src/general_model.cpp     | 14 ++++
 .../src/pybind_general_model.cpp              |  2 +
 .../general-server/op/general_get_conf_op.cpp | 43 +++++++++++++++++++
 core/general-server/op/general_get_conf_op.h  | 37 ++++++++++++++++
 .../proto/general_model_service.proto         |  2 +
 core/predictor/framework/resource.cpp         | 11 +++++
 core/predictor/framework/resource.h           |  3 ++
 .../sdk-cpp/proto/general_model_service.proto |  2 +
 python/paddle_serving_client/__init__.py      | 41 +++++++++++++++++
 10 files changed, 157 insertions(+)
 create mode 100644 core/general-server/op/general_get_conf_op.cpp
 create mode 100644 core/general-server/op/general_get_conf_op.h

diff --git a/core/general-client/include/general_model.h b/core/general-client/include/general_model.h
index 7e04ae11..5ae01323 100644
--- a/core/general-client/include/general_model.h
+++ b/core/general-client/include/general_model.h
@@ -167,6 +167,8 @@ class PredictorClient {
 
   int destroy_predictor();
 
+  std::string get_model_config();
+
   int batch_predict(
       const std::vector<std::vector<std::vector<float>>>& float_feed_batch,
       const std::vector<std::string>& float_feed_name,
diff --git a/core/general-client/src/general_model.cpp b/core/general-client/src/general_model.cpp
index 86f75bc1..eff3aba7 100644
--- a/core/general-client/src/general_model.cpp
+++ b/core/general-client/src/general_model.cpp
@@ -135,6 +135,20 @@ int PredictorClient::create_predictor() {
   return 0;
 }
 
+// Fetch the text-format general model config from the server.
+std::string PredictorClient::get_model_config() {
+  Request req;
+  Response res;
+  req.set_request_type("GetConf");
+  if (_predictor->inference(&req, &res) != 0) {
+    LOG(ERROR) << "failed to call predictor, req: " << req.ShortDebugString();
+    _api.thrd_clear();
+    return "";
+  }
+  // return by value: a reference into the local res would dangle
+  return res.config_str();
+}
+
 int PredictorClient::batch_predict(
     const std::vector<std::vector<std::vector<float>>> &float_feed_batch,
     const std::vector<std::string> &float_feed_name,
diff --git a/core/general-client/src/pybind_general_model.cpp b/core/general-client/src/pybind_general_model.cpp
index b0d1d2d6..724d9a9c 100644
--- a/core/general-client/src/pybind_general_model.cpp
+++ b/core/general-client/src/pybind_general_model.cpp
@@ -56,6 +56,8 @@ PYBIND11_MODULE(serving_client, m) {
 
   py::class_<PredictorClient>(m, "PredictorClient", py::buffer_protocol())
       .def(py::init<>())
+      .def("get_model_config",
+           [](PredictorClient &self) { return self.get_model_config(); })
       .def("init_gflags",
            [](PredictorClient &self, std::vector<std::string> argv) {
              self.init_gflags(argv);
diff --git a/core/general-server/op/general_get_conf_op.cpp b/core/general-server/op/general_get_conf_op.cpp
new file mode 100644
index 00000000..64384164
--- /dev/null
+++ b/core/general-server/op/general_get_conf_op.cpp
@@ -0,0 +1,43 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "core/general-server/op/general_get_conf_op.h"
+#include <iostream>
+#include <memory>
+#include <sstream>
+#include <string>
+
+namespace baidu {
+namespace paddle_serving {
+namespace serving {
+
+using baidu::paddle_serving::predictor::general_model::Request;
+using baidu::paddle_serving::predictor::general_model::Response;
+
+int GeneralGetConfOp::inference() {
+  // read the request from the client
+  const Request *req = dynamic_cast<const Request *>(get_request_message());
+
+  baidu::paddle_serving::predictor::Resource &resource =
+      baidu::paddle_serving::predictor::Resource::instance();
+
+  // answer with the text-format model config cached by Resource
+  std::string conf_str = resource.get_general_model_conf_str();
+  Response *res = mutable_data<Response>();
+  res->set_config_str(conf_str);
+  return 0;
+}
+DEFINE_OP(GeneralGetConfOp);
+}  // namespace serving
+}  // namespace paddle_serving
+}  // namespace baidu
diff --git a/core/general-server/op/general_get_conf_op.h b/core/general-server/op/general_get_conf_op.h
new file mode 100644
index 00000000..040b9855
--- /dev/null
+++ b/core/general-server/op/general_get_conf_op.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "core/general-server/general_model_service.pb.h"
+#include "core/general-server/load_general_model_service.pb.h"
+#include "core/general-server/op/general_infer_helper.h"
+#include "core/predictor/framework/resource.h"
+
+namespace baidu {
+namespace paddle_serving {
+namespace serving {
+
+class GeneralGetConfOp : public baidu::paddle_serving::predictor::OpWithChannel<
+                             baidu::paddle_serving::predictor::general_model::Response> {
+ public:
+  DECLARE_OP(GeneralGetConfOp);
+
+  int inference();
+};
+
+}  // namespace serving
+}  // namespace paddle_serving
+}  // namespace baidu
diff --git a/core/general-server/proto/general_model_service.proto b/core/general-server/proto/general_model_service.proto
index 8581ecb2..7754c597 100644
--- a/core/general-server/proto/general_model_service.proto
+++ b/core/general-server/proto/general_model_service.proto
@@ -37,11 +37,13 @@ message Request {
   repeated FeedInst insts = 1;
   repeated string fetch_var_names = 2;
   optional bool profile_server = 3 [ default = false ];
+  optional string request_type = 4;
 };
 
 message Response {
   repeated ModelOutput outputs = 1;
   repeated int64 profile_time = 2;
+  optional string config_str = 3;
 };
 
 message ModelOutput {
diff --git a/core/predictor/framework/resource.cpp b/core/predictor/framework/resource.cpp
index ca219519..2f542cd2 100644
--- a/core/predictor/framework/resource.cpp
+++ b/core/predictor/framework/resource.cpp
@@ -43,6 +43,10 @@ std::shared_ptr<PaddleGeneralModelConfig> Resource::get_general_model_config() {
   return _config;
 }
 
+std::string Resource::get_general_model_conf_str() {
+  return _general_model_conf_str;
+}
+
 void Resource::print_general_model_config(
     const std::shared_ptr<PaddleGeneralModelConfig>& config) {
   if (config == nullptr) {
@@ -210,6 +214,13 @@ int Resource::general_model_initialize(const std::string& path,
     return -1;
   }
+  // cache the model config as a text-format string for GeneralGetConfOp
+  bool print_flag = google::protobuf::TextFormat::PrintToString(
+      model_config, &_general_model_conf_str);
+  if (!print_flag) {
+    LOG(ERROR) << "failed to print model config message to string";
+  }
+
   _config.reset(new PaddleGeneralModelConfig());
   int feed_var_num = model_config.feed_var_size();
   VLOG(2) << "load general model config";
diff --git a/core/predictor/framework/resource.h b/core/predictor/framework/resource.h
index 56b87666..156c2e6b 100644
--- a/core/predictor/framework/resource.h
+++ b/core/predictor/framework/resource.h
@@ -96,6 +96,8 @@ class Resource {
 
   std::shared_ptr<PaddleGeneralModelConfig> get_general_model_config();
 
+  std::string get_general_model_conf_str();
+
   void print_general_model_config(
       const std::shared_ptr<PaddleGeneralModelConfig>& config);
 
@@ -108,6 +110,7 @@ class Resource {
  private:
   int thread_finalize() { return 0; }
   std::shared_ptr<PaddleGeneralModelConfig> _config;
+  std::string _general_model_conf_str;
   std::string cube_config_fullpath;
   int cube_quant_bits;  // 0 if no empty
diff --git a/core/sdk-cpp/proto/general_model_service.proto b/core/sdk-cpp/proto/general_model_service.proto
index 51c0335a..00b3c683 100644
--- a/core/sdk-cpp/proto/general_model_service.proto
+++ b/core/sdk-cpp/proto/general_model_service.proto
@@ -37,11 +37,13 @@ message Request {
   repeated FeedInst insts = 1;
   repeated string fetch_var_names = 2;
   optional bool profile_server = 3 [ default = false ];
+  optional string request_type = 4;
 };
 
 message Response {
   repeated ModelOutput outputs = 1;
   repeated int64 profile_time = 2;
+  optional string config_str = 3;
 };
 
 message ModelOutput {
diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py
index 33809349..6af7545d 100644
--- a/python/paddle_serving_client/__init__.py
+++ b/python/paddle_serving_client/__init__.py
@@ -200,6 +200,47 @@ class Client(object):
         sdk_desc = self.predictor_sdk_.gen_desc()
         self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
         ))
+        # fetch the model config from the server right after connecting
+        model_config_str = self.client_handle_.get_model_config()
+        model_conf = m_config.GeneralModelConfig()
+        model_conf = google.protobuf.text_format.Merge(
+            str(model_config_str), model_conf)
+
+        self.result_handle_ = PredictorRes()
+        self.client_handle_ = PredictorClient()
+        self.client_handle_.init_from_string(model_config_str)
+        if "FLAGS_max_body_size" not in os.environ:
+            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
+        self.client_handle_.init_gflags([sys.argv[
+            0]] + ["--tryfromenv=" + ",".join(read_env_flags)])
+        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
+        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
+        self.feed_names_to_idx_ = {}
+        self.feed_types_ = {}
+        self.feed_shapes_ = {}
+        self.fetch_names_to_type_ = {}
+        self.fetch_names_to_idx_ = {}
+        self.lod_tensor_set = set()
+        self.feed_tensor_len = {}
+
+        for i, var in enumerate(model_conf.feed_var):
+            self.feed_names_to_idx_[var.alias_name] = i
+            self.feed_types_[var.alias_name] = var.feed_type
+            self.feed_shapes_[var.alias_name] = var.shape
+
+            if var.is_lod_tensor:
+                self.lod_tensor_set.add(var.alias_name)
+            else:
+                counter = 1
+                for dim in self.feed_shapes_[var.alias_name]:
+                    counter *= dim
+                self.feed_tensor_len[var.alias_name] = counter
+        for i, var in enumerate(model_conf.fetch_var):
+            self.fetch_names_to_idx_[var.alias_name] = i
+            self.fetch_names_to_type_[var.alias_name] = var.fetch_type
+            if var.is_lod_tensor:
+                self.lod_tensor_set.add(var.alias_name)
+        return
 
     def get_feed_names(self):
         return self.feed_names_
-- 
GitLab
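
Usage sketch (not part of the patch): a minimal example of the client-side workflow this commit enables, assuming a serving instance is already listening on 127.0.0.1:9393 (a placeholder address). Client(), connect(), get_feed_names() and get_fetch_names() are the existing paddle_serving_client API; the automatic config fetch inside connect() is the behavior added above.

    from paddle_serving_client import Client

    client = Client()
    # With this patch, connect() issues a request whose request_type is
    # "GetConf"; GeneralGetConfOp answers it with the text-format model
    # config, and the client re-initializes its handle from that string,
    # so no local client-side config file has to be loaded beforehand.
    client.connect(["127.0.0.1:9393"])

    # feed/fetch metadata is populated from the server-provided config
    print(client.get_feed_names())
    print(client.get_fetch_names())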