提交 bfdd07b0 编写于 作者: D dongdaxiang

refine this pr

上级 1021499c
...@@ -190,12 +190,19 @@ int PredictorClient::create_predictor() { ...@@ -190,12 +190,19 @@ int PredictorClient::create_predictor() {
const std::string &PredictorClient::get_model_config() { const std::string &PredictorClient::get_model_config() {
Request req; Request req;
Response res; Response res;
VLOG(2) << "going to send request";
req.set_request_type("GetConf"); req.set_request_type("GetConf");
_api.thrd_initialize();
std::string variant_tag;
_predictor = _api.fetch_predictor("general_model", &variant_tag);
VLOG(2) << "sending";
if (_predictor->inference(&req, &res) != 0) { if (_predictor->inference(&req, &res) != 0) {
LOG(ERROR) << "failed call predictor with req: " << req.ShortDebugString(); LOG(ERROR) << "failed call predictor with req: " << req.ShortDebugString();
_api.thrd_clear(); _api.thrd_clear();
return ""; return "";
} else { } else {
VLOG(2) << "get model config succeed";
const std::string &config_str = res.config_str(); const std::string &config_str = res.config_str();
return config_str; return config_str;
} }
......
...@@ -27,16 +27,24 @@ using baidu::paddle_serving::predictor::general_model::Response; ...@@ -27,16 +27,24 @@ using baidu::paddle_serving::predictor::general_model::Response;
int GeneralGetConfOp::inference() { int GeneralGetConfOp::inference() {
// reade request from client // reade request from client
VLOG(2) << "going to get request";
const Request *req = dynamic_cast<const Request *>(get_request_message()); const Request *req = dynamic_cast<const Request *>(get_request_message());
VLOG(2) << "request got";
baidu::paddle_serving::predictor::Resource &resource = baidu::paddle_serving::predictor::Resource &resource =
baidu::paddle_serving::predictor::Resource::instance(); baidu::paddle_serving::predictor::Resource::instance();
std::string conf_str = resource.get_general_model_conf_str(); VLOG(2) << "request type : " << req->request_type();
VLOG(2) << "fetching conf str";
const std::string &conf_str = resource.get_general_model_conf_str();
VLOG(2) << conf_str;
Response *res = mutable_data<Response>(); Response *res = mutable_data<Response>();
res->set_config_str(conf_str); res->set_config_str(conf_str.c_str());
VLOG(2) << "done.";
return 0; return 0;
} }
DEFINE_OP(GeneralGetConfOp); DEFINE_OP(GeneralGetConfOp);
} // namespace serving } // namespace serving
} // namespace paddle_serving } // namespace paddle_serving
......
...@@ -136,7 +136,7 @@ const std::string& InferService::name() const { return _infer_service_format; } ...@@ -136,7 +136,7 @@ const std::string& InferService::name() const { return _infer_service_format; }
int InferService::inference(const google::protobuf::Message* request, int InferService::inference(const google::protobuf::Message* request,
google::protobuf::Message* response, google::protobuf::Message* response,
butil::IOBufBuilder* debug_os) { butil::IOBufBuilder* debug_os) {
TRACEPRINTF("start to inference"); VLOG(2) << "start to inference";
// when funtion call begins, framework will reset // when funtion call begins, framework will reset
// thread local variables&resources automatically. // thread local variables&resources automatically.
if (Resource::instance().thread_clear() != 0) { if (Resource::instance().thread_clear() != 0) {
...@@ -144,7 +144,7 @@ int InferService::inference(const google::protobuf::Message* request, ...@@ -144,7 +144,7 @@ int InferService::inference(const google::protobuf::Message* request,
return ERR_INTERNAL_FAILURE; return ERR_INTERNAL_FAILURE;
} }
TRACEPRINTF("finish to thread clear"); VLOG(2) << "finish to thread clear";
if (_enable_map_request_to_workflow) { if (_enable_map_request_to_workflow) {
LOG(INFO) << "enable map request == True"; LOG(INFO) << "enable map request == True";
...@@ -155,14 +155,15 @@ int InferService::inference(const google::protobuf::Message* request, ...@@ -155,14 +155,15 @@ int InferService::inference(const google::protobuf::Message* request,
} }
size_t fsize = workflows->size(); size_t fsize = workflows->size();
for (size_t fi = 0; fi < fsize; ++fi) { for (size_t fi = 0; fi < fsize; ++fi) {
VLOG(2) << "workflow index: " << fi;
Workflow* workflow = (*workflows)[fi]; Workflow* workflow = (*workflows)[fi];
if (workflow == NULL) { if (workflow == NULL) {
LOG(ERROR) << "Failed to get valid workflow at: " << fi; LOG(ERROR) << "Failed to get valid workflow at: " << fi;
return ERR_INTERNAL_FAILURE; return ERR_INTERNAL_FAILURE;
} }
TRACEPRINTF("start to execute workflow[%s]", workflow->name().c_str()); VLOG(2) << "start to execute workflow[" << workflow->name() << "]";
int errcode = _execute_workflow(workflow, request, response, debug_os); int errcode = _execute_workflow(workflow, request, response, debug_os);
TRACEPRINTF("finish to execute workflow[%s]", workflow->name().c_str()); VLOG(2) << "finish to execute workflow[" << workflow->name() << "]";
if (errcode < 0) { if (errcode < 0) {
LOG(ERROR) << "Failed execute workflow[" << workflow->name() LOG(ERROR) << "Failed execute workflow[" << workflow->name()
<< "] in:" << name(); << "] in:" << name();
...@@ -171,12 +172,12 @@ int InferService::inference(const google::protobuf::Message* request, ...@@ -171,12 +172,12 @@ int InferService::inference(const google::protobuf::Message* request,
} }
} else { } else {
LOG(INFO) << "enable map request == False"; LOG(INFO) << "enable map request == False";
TRACEPRINTF("start to execute one workflow"); VLOG(2) << "start to execute one workflow";
size_t fsize = _flows.size(); size_t fsize = _flows.size();
for (size_t fi = 0; fi < fsize; ++fi) { for (size_t fi = 0; fi < fsize; ++fi) {
TRACEPRINTF("start to execute one workflow-%lu", fi); VLOG(2) << "start to execute one workflow-" << fi;
int errcode = execute_one_workflow(fi, request, response, debug_os); int errcode = execute_one_workflow(fi, request, response, debug_os);
TRACEPRINTF("finish to execute one workflow-%lu", fi); VLOG(2) << "finish to execute one workflow-" << fi;
if (errcode < 0) { if (errcode < 0) {
LOG(ERROR) << "Failed execute 0-th workflow in:" << name(); LOG(ERROR) << "Failed execute 0-th workflow in:" << name();
return errcode; return errcode;
...@@ -215,6 +216,7 @@ int InferService::_execute_workflow(Workflow* workflow, ...@@ -215,6 +216,7 @@ int InferService::_execute_workflow(Workflow* workflow,
req_channel.init(0, START_OP_NAME); req_channel.init(0, START_OP_NAME);
req_channel = request; req_channel = request;
VLOG(2) << "dag full name: " << full_name();
DagView* dv = workflow->fetch_dag_view(full_name()); DagView* dv = workflow->fetch_dag_view(full_name());
dv->set_request_channel(req_channel); dv->set_request_channel(req_channel);
...@@ -225,14 +227,14 @@ int InferService::_execute_workflow(Workflow* workflow, ...@@ -225,14 +227,14 @@ int InferService::_execute_workflow(Workflow* workflow,
return errcode; return errcode;
} }
TRACEPRINTF("finish to dv execute"); VLOG(2) << "finish to dv execute";
// create ender channel and copy // create ender channel and copy
const Channel* res_channel = dv->get_response_channel(); const Channel* res_channel = dv->get_response_channel();
if (!_merger || !_merger->merge(res_channel->message(), response)) { if (!_merger || !_merger->merge(res_channel->message(), response)) {
LOG(ERROR) << "Failed merge channel res to response"; LOG(ERROR) << "Failed merge channel res to response";
return ERR_INTERNAL_FAILURE; return ERR_INTERNAL_FAILURE;
} }
TRACEPRINTF("finish to copy from"); VLOG(2) << "finish to copy from";
workflow_time.stop(); workflow_time.stop();
LOG(INFO) << "workflow total time: " << workflow_time.u_elapsed(); LOG(INFO) << "workflow total time: " << workflow_time.u_elapsed();
...@@ -241,7 +243,7 @@ int InferService::_execute_workflow(Workflow* workflow, ...@@ -241,7 +243,7 @@ int InferService::_execute_workflow(Workflow* workflow,
// return tls data to object pool // return tls data to object pool
workflow->return_dag_view(dv); workflow->return_dag_view(dv);
TRACEPRINTF("finish to return dag view"); VLOG(2) << "finish to return dag view";
return ERR_OK; return ERR_OK;
} }
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_client import Client
import sys
client = Client()
#client.load_client_config(sys.argv[1])
client.connect(["127.0.0.1:9696"])
'''
import paddle
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.test(), buf_size=500),
batch_size=1)
for data in test_reader():
fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"])
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
'''
...@@ -154,6 +154,10 @@ class Client(object): ...@@ -154,6 +154,10 @@ class Client(object):
print( print(
"parameter endpoints({}) will not take effect, because you use the add_variant function.". "parameter endpoints({}) will not take effect, because you use the add_variant function.".
format(endpoints)) format(endpoints))
from .serving_client import PredictorClient
from .serving_client import PredictorRes
self.result_handle_ = PredictorRes()
self.client_handle_ = PredictorClient()
sdk_desc = self.predictor_sdk_.gen_desc() sdk_desc = self.predictor_sdk_.gen_desc()
self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString( self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
)) ))
...@@ -163,11 +167,10 @@ class Client(object): ...@@ -163,11 +167,10 @@ class Client(object):
model_conf = google.protobuf.text_format.Merge( model_conf = google.protobuf.text_format.Merge(
str(model_config_str), model_conf) str(model_config_str), model_conf)
self.result_handle_ = PredictorRes()
self.client_handle_ = PredictorClient()
self.client_handle_.init_from_string(model_config_str) self.client_handle_.init_from_string(model_config_str)
if "FLAGS_max_body_size" not in os.environ: if "FLAGS_max_body_size" not in os.environ:
os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024) os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
read_env_flags = ["profile_client", "profile_server", "max_body_size"]
self.client_handle_.init_gflags([sys.argv[ self.client_handle_.init_gflags([sys.argv[
0]] + ["--tryfromenv=" + ",".join(read_env_flags)]) 0]] + ["--tryfromenv=" + ",".join(read_env_flags)])
self.feed_names_ = [var.alias_name for var in model_conf.feed_var] self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
......
...@@ -107,7 +107,7 @@ class OpSeqMaker(object): ...@@ -107,7 +107,7 @@ class OpSeqMaker(object):
node = server_sdk.DAGNode() node = server_sdk.DAGNode()
node.name = "general_get_conf_0" node.name = "general_get_conf_0"
node.type = "GeneralGetConfOp" node.type = "GeneralGetConfOp"
get_conf_workflow.nodes.extend(node) get_conf_workflow.nodes.extend([node])
workflow_conf.workflows.extend([get_conf_workflow]) workflow_conf.workflows.extend([get_conf_workflow])
return workflow_conf return workflow_conf
...@@ -233,15 +233,13 @@ class Server(object): ...@@ -233,15 +233,13 @@ class Server(object):
infer_service.request_field_key = "request_type" infer_service.request_field_key = "request_type"
kv_workflow1 = server_sdk.ValueMappedWorkflow() kv_workflow1 = server_sdk.ValueMappedWorkflow()
kv_workflow1.request_field_value = "GetConf" kv_workflow1.request_field_value = "Predict"
kv_workflow1.workflows = "workflow1" kv_workflow1.workflow = "workflow1"
infer_service.value_mapped_workflows.extend([kv_workflow1]) infer_service.value_mapped_workflows.extend([kv_workflow1])
kv_workflow2 = server_sdk.ValueMappedWorkflow() kv_workflow2 = server_sdk.ValueMappedWorkflow()
kv_workflow2.request_field_value = "Predict" kv_workflow2.request_field_value = "GetConf"
kv_workflow2.workflows = "workflow2" kv_workflow2.workflow = "workflow2"
infer_service.value_mapped_workflows.extend([kv_workflow2]) infer_service.value_mapped_workflows.extend([kv_workflow2])
infer_service.name = "GeneralModelService" infer_service.name = "GeneralModelService"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册