Commit c585f678 authored by dongdaxiang

add general_get_conf_str op, add get_config_str client api, refine request and response proto to add new workflow
Parent 1627446b
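The pieces below compose into one round trip: the Python client sends a Request whose request_type is "GetConf", the new server-side GeneralGetConfOp answers with the cached model config string, and the client parses it to discover feed/fetch variables. A minimal usage sketch under stated assumptions (the endpoint address is hypothetical, and on this revision load_client_config() may still be required before connect()):

    # Minimal sketch of the workflow this commit enables; the endpoint is
    # hypothetical and load_client_config() may still be required on this
    # revision before connect().
    from paddle_serving_client import Client

    client = Client()
    # connect() now calls get_model_config() internally, so feed/fetch
    # metadata comes from the server's "GetConf" response.
    client.connect(["127.0.0.1:9292"])
    print(client.get_feed_names())  # populated from the server-side config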
@@ -167,6 +167,8 @@ class PredictorClient {
  int destroy_predictor();

  // Returns by value; the string is copied out of a request-local Response.
  std::string get_model_config();
int batch_predict(
const std::vector<std::vector<std::vector<float>>>& float_feed_batch,
const std::vector<std::string>& float_feed_name,
......
@@ -135,6 +135,20 @@ int PredictorClient::create_predictor() {
return 0;
}
std::string PredictorClient::get_model_config() {
  Request req;
  Response res;
  req.set_request_type("GetConf");
  if (_predictor->inference(&req, &res) != 0) {
    LOG(ERROR) << "failed to call predictor with req: "
               << req.ShortDebugString();
    _api.thrd_clear();
    return "";
  }
  // Copy the string out of the request-local Response; returning a
  // reference to it would dangle once res goes out of scope.
  return res.config_str();
}
int PredictorClient::batch_predict(
const std::vector<std::vector<std::vector<float>>> &float_feed_batch,
const std::vector<std::string> &float_feed_name,
......
@@ -56,6 +56,8 @@ PYBIND11_MODULE(serving_client, m) {
py::class_<PredictorClient>(m, "PredictorClient", py::buffer_protocol())
.def(py::init())
.def("get_model_config",
[](PredictorClient &self) { return self.get_model_config(); })
.def("init_gflags",
[](PredictorClient &self, std::vector<std::string> argv) {
self.init_gflags(argv);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_get_conf_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
namespace baidu {
namespace paddle_serving {
namespace serving {
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::Response;
int GeneralGetConfOp::inference() {
  // Read the request from the client. The payload is ignored: this op only
  // returns the cached model config string.
  const Request *req = dynamic_cast<const Request *>(get_request_message());
  (void)req;  // unused beyond the sanity cast

  baidu::paddle_serving::predictor::Resource &resource =
      baidu::paddle_serving::predictor::Resource::instance();
  std::string conf_str = resource.get_general_model_conf_str();

  Response *res = mutable_data<Response>();
  res->set_config_str(conf_str);
  return 0;
}
DEFINE_OP(GeneralGetConfOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "core/general-server/general_model_service.pb.h"
#include "core/general-server/load_general_model_service.pb.h"
#include "core/general-server/op/general_infer_helper.h"
#include "core/predictor/framework/resource.h"
namespace baidu {
namespace paddle_serving {
namespace serving {
// The channel type must match what inference() writes via
// mutable_data<Response>(), so this op channels Response, not GeneralBlob.
class GeneralGetConfOp
    : public baidu::paddle_serving::predictor::OpWithChannel<
          baidu::paddle_serving::predictor::general_model::Response> {
public:
DECLARE_OP(GeneralGetConfOp);
int inference();
};
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
@@ -37,11 +37,13 @@ message Request {
repeated FeedInst insts = 1;
repeated string fetch_var_names = 2;
optional bool profile_server = 3 [ default = false ];
optional string request_type = 4;
};
message Response {
repeated ModelOutput outputs = 1;
repeated int64 profile_time = 2;
optional string config_str = 3;
};
message ModelOutput {
......
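The two new optional fields form the control channel for this workflow: a client marks a config request by setting request_type, and the server replies through config_str instead of outputs. A hedged sketch of building such a request with the generated Python bindings (the general_model_service_pb2 module name is an assumption):

    # Building a "GetConf" request by hand; the generated module name below
    # is assumed for illustration.
    import general_model_service_pb2 as pb

    req = pb.Request()
    req.request_type = "GetConf"  # routes the request to GeneralGetConfOp
    # insts and fetch_var_names stay empty: the op ignores the payload and
    # answers with Response.config_str holding the text-format model config.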
@@ -43,6 +43,10 @@ std::shared_ptr<PaddleGeneralModelConfig> Resource::get_general_model_config() {
return _config;
}
std::string Resource::get_general_model_conf_str() {
return _general_model_conf_str;
}
void Resource::print_general_model_config(
const std::shared_ptr<PaddleGeneralModelConfig>& config) {
if (config == nullptr) {
@@ -210,6 +214,13 @@ int Resource::general_model_initialize(const std::string& path,
return -1;
}
  // Cache the general model config as a protobuf text-format string so that
  // GeneralGetConfOp can serve it to clients on "GetConf" requests.
  bool print_flag = google::protobuf::TextFormat::PrintToString(
      model_config, &_general_model_conf_str);
  if (!print_flag) {
    LOG(ERROR) << "failed to serialize model config message to string";
  }
_config.reset(new PaddleGeneralModelConfig());
int feed_var_num = model_config.feed_var_size();
VLOG(2) << "load general model config";
......
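Because _general_model_conf_str is produced with protobuf's TextFormat, the client can rebuild the full GeneralModelConfig message from it, which is what the Python connect() code later in this commit does. A standalone round-trip sketch (the sample config line and the general_model_config_pb2 module name are illustrative, with field names taken from the FeedVar message used elsewhere in this repository):

    from google.protobuf import text_format
    import general_model_config_pb2 as m_config  # generated module name assumed

    # A one-variable config string of the kind the server would cache.
    conf_str = 'feed_var { name: "x" alias_name: "x" is_lod_tensor: false shape: 13 }'
    model_conf = m_config.GeneralModelConfig()
    text_format.Merge(conf_str, model_conf)
    assert list(model_conf.feed_var[0].shape) == [13]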
@@ -96,6 +96,8 @@ class Resource {
std::shared_ptr<PaddleGeneralModelConfig> get_general_model_config();
std::string get_general_model_conf_str();
void print_general_model_config(
const std::shared_ptr<PaddleGeneralModelConfig>& config);
@@ -108,6 +110,7 @@ class Resource {
private:
int thread_finalize() { return 0; }
std::shared_ptr<PaddleGeneralModelConfig> _config;
std::string _general_model_conf_str;
std::string cube_config_fullpath;
int cube_quant_bits; // 0 if quantization is not used
......
@@ -37,11 +37,13 @@ message Request {
repeated FeedInst insts = 1;
repeated string fetch_var_names = 2;
optional bool profile_server = 3 [ default = false ];
optional string request_type = 4;
};
message Response {
repeated ModelOutput outputs = 1;
repeated int64 profile_time = 2;
optional string config_str = 3;
};
message ModelOutput {
......
@@ -200,6 +200,45 @@ class Client(object):
        sdk_desc = self.predictor_sdk_.gen_desc()
        self.client_handle_.create_predictor_by_desc(
            sdk_desc.SerializeToString())

        # fetch the model config from the server right after connecting
        model_config_str = self.client_handle_.get_model_config()
        model_conf = m_config.GeneralModelConfig()
        model_conf = google.protobuf.text_format.Merge(
            str(model_config_str), model_conf)

        self.result_handle_ = PredictorRes()
        # reuse the connected handle: re-creating PredictorClient here would
        # discard the predictor created above
        self.client_handle_.init_from_string(model_config_str)
        if "FLAGS_max_body_size" not in os.environ:
            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
        # read_env_flags is defined earlier in this file
        self.client_handle_.init_gflags(
            [sys.argv[0]] + ["--tryfromenv=" + ",".join(read_env_flags)])

        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.feed_names_to_idx_ = {}
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}
        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
            else:
                counter = 1
                for dim in self.feed_shapes_[var.alias_name]:
                    counter *= dim
                self.feed_tensor_len[var.alias_name] = counter
        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
        return

    def get_feed_names(self):
        return self.feed_names_
......
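The per-variable metadata cached in connect() above (feed_tensor_len, lod_tensor_set) also lends itself to cheap client-side validation before an RPC; a hypothetical helper sketching that use:

    # Hypothetical helper (not part of this diff) that uses the metadata
    # cached by connect() to validate a flattened feed before predicting.
    def check_feed_len(client, name, values):
        if name in client.lod_tensor_set:
            return  # LoD tensors are variable-length; nothing to check
        expected = client.feed_tensor_len[name]
        if len(values) != expected:
            raise ValueError("feed '%s' expects %d values, got %d"
                             % (name, expected, len(values)))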