Commit 121e63e4 authored by HexToString

add http_proto and grpcclient

Parent fa153d54
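In short, this commit unifies the Java and Python clients behind one class (HttpClient becomes Client in Java and GeneralClient in Python), keeps proto-over-HTTP as the default body format, and adds JSON and gRPC paths behind set_http_proto(False) and set_use_grpc_client(True). A minimal sketch of the three request paths in Python follows; the client config path and a server listening on 0.0.0.0:9393 are assumptions, not part of this commit:

import numpy as np
from paddle_serving_client.httpclient import GeneralClient

client = GeneralClient(ip="0.0.0.0", port="9393")
client.load_client_config("uci_housing_client/serving_client_conf.prototxt")  # assumed path
data = np.zeros((1, 13)).astype("float32")
fetch_list = client.get_fetch_names()

# default path: HTTP with a proto body
print(client.predict(feed={"x": data}, fetch=fetch_list, batch=True))

# JSON in the HTTP body instead of proto
client.set_http_proto(False)
print(client.predict(feed={"x": data}, fetch=fetch_list, batch=True))

# gRPC transport; predict() dispatches to grpc_client_predict()
client.set_use_grpc_client(True)
print(client.predict(feed={"x": data}, fetch=fetch_list, batch=True))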
......@@ -39,11 +39,11 @@ INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog")
if(WITH_LITE)
set(BRPC_REPO "https://github.com/zhangjun/incubator-brpc.git")
set(BRPC_REPO "https://github.com/apache/incubator-brpc")
set(BRPC_TAG "master")
else()
set(BRPC_REPO "https://github.com/wangjiawei04/brpc")
set(BRPC_TAG "6d79e0b17f25107c35b705ea58d888083f59ff47")
set(BRPC_REPO "https://github.com/apache/incubator-brpc")
set(BRPC_TAG "master")
endif()
# If a minimal .a is needed, you can set WITH_DEBUG_SYMBOLS=OFF
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package baidu.paddle_serving.predictor.general_model;
option java_multiple_files = true;
message Tensor {
repeated string data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt
optional string alias_name = 9; // get from the Model prototxt
};
message Request {
repeated Tensor tensor = 1;
repeated string fetch_var_names = 2;
optional bool profile_server = 3 [ default = false ];
required uint64 log_id = 4 [ default = 0 ];
};
message Response {
repeated ModelOutput outputs = 1;
repeated int64 profile_time = 2;
};
message ModelOutput {
repeated Tensor tensor = 1;
optional string engine_name = 2;
}
service GeneralModelService {
rpc inference(Request) returns (Response) {}
rpc debug(Request) returns (Response) {}
};
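To make the schema concrete, here is a hedged sketch of filling a Request from Python. Only the message and field names come from the .proto above; the import path of the generated module is an assumption:

from paddle_serving_client.proto import general_model_service_pb2  # assumed import path

req = general_model_service_pb2.Request()
req.log_id = 0                                # required field, defaults to 0
req.fetch_var_names.append("price")

tensor = req.tensor.add()
tensor.float_data.extend([0.0137, -0.1136])   # payload goes in the field matching elem_type
tensor.elem_type = 1                          # 1 means float32 (see the comment on field 5)
tensor.shape.extend([1, 2])                   # shape includes the batch dimension
tensor.name = "x"                             # taken from the model prototxt
tensor.alias_name = "x"
payload = req.SerializeToString()             # bytes for an application/proto HTTP body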
......@@ -20,12 +20,12 @@ package baidu.paddle_serving.predictor.general_model;
option cc_generic_services = true;
message Tensor {
repeated bytes data = 1;
repeated string data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means bytes(string)
5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt
......
......@@ -20,12 +20,12 @@ package baidu.paddle_serving.predictor.general_model;
option cc_generic_services = true;
message Tensor {
repeated bytes data = 1;
repeated string data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means bytes(string)
5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt
......
......@@ -11,7 +11,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
public class PaddleServingClientExample {
boolean fit_a_line(String model_config_path) {
boolean http_proto(String model_config_path) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
......@@ -24,7 +24,7 @@ public class PaddleServingClientExample {
}};
List<String> fetch = Arrays.asList("price");
HttpClient client = new HttpClient();
Client client = new Client();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
......@@ -34,6 +34,55 @@ public class PaddleServingClientExample {
return true;
}
boolean http_json(String model_config_path) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
long[] batch_shape = {1,13};
INDArray batch_npdata = npdata.reshape(batch_shape);
HashMap<String, Object> feed_data
= new HashMap<String, Object>() {{
put("x", batch_npdata);
}};
List<String> fetch = Arrays.asList("price");
Client client = new Client();
// Note: across docker containers, set --net=host or access the other container's IP directly
client.setIP("0.0.0.0");
client.setPort("9393");
client.set_http_proto(false);
client.loadClientConfig(model_config_path);
String result = client.predict(feed_data, fetch, true, 0);
System.out.println(result);
return true;
}
boolean grpc(String model_config_path) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
long[] batch_shape = {1,13};
INDArray batch_npdata = npdata.reshape(batch_shape);
HashMap<String, Object> feed_data
= new HashMap<String, Object>() {{
put("x", batch_npdata);
}};
List<String> fetch = Arrays.asList("price");
Client client = new Client();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
client.set_use_grpc_client(true);
String result = client.predict(feed_data, fetch, true, 0);
System.out.println(result);
return true;
}
boolean encrypt(String model_config_path,String keyFilePath) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
......@@ -47,7 +96,7 @@ public class PaddleServingClientExample {
}};
List<String> fetch = Arrays.asList("price");
HttpClient client = new HttpClient();
Client client = new Client();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
......@@ -55,7 +104,6 @@ public class PaddleServingClientExample {
try {
Thread.sleep(1000*3); // sleep for 3 seconds, waiting for the server to start
} catch (Exception e) {
//TODO: handle exception
}
String result = client.predict(feed_data, fetch, true, 0);
......@@ -76,7 +124,7 @@ public class PaddleServingClientExample {
}};
List<String> fetch = Arrays.asList("price");
HttpClient client = new HttpClient();
Client client = new Client();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
......@@ -127,7 +175,7 @@ public class PaddleServingClientExample {
put("im_size", batch_im_size);
}};
List<String> fetch = Arrays.asList("save_infer_model/scale_0.tmp_0");
HttpClient client = new HttpClient();
Client client = new Client();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
......@@ -149,7 +197,7 @@ public class PaddleServingClientExample {
put("segment_ids", Nd4j.createFromArray(segment_ids));
}};
List<String> fetch = Arrays.asList("pooled_output");
HttpClient client = new HttpClient();
Client client = new Client();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
......@@ -219,7 +267,7 @@ public class PaddleServingClientExample {
put("embedding_0.tmp_0", Nd4j.createFromArray(embedding_0));
}};
List<String> fetch = Arrays.asList("prob");
HttpClient client = new HttpClient();
Client client = new Client();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
......@@ -236,13 +284,17 @@ public class PaddleServingClientExample {
if (args.length < 2) {
System.out.println("Usage: java -cp <jar> PaddleServingClientExample <test-type> <configPath>.");
System.out.println("<test-type>: fit_a_line bert cube_local yolov4 encrypt");
System.out.println("<test-type>: http_proto grpc bert cube_local yolov4 encrypt");
return;
}
String testType = args[0];
System.out.format("[Example] %s\n", testType);
if ("fit_a_line".equals(testType)) {
succ = e.fit_a_line(args[1]);
if ("http_proto".equals(testType)) {
succ = e.http_proto(args[1]);
} else if ("http_json".equals(testType)) {
succ = e.http_json(args[1]);
} else if ("grpc".equals(testType)) {
succ = e.grpc(args[1]);
} else if ("compress".equals(testType)) {
succ = e.compress(args[1]);
} else if ("bert".equals(testType)) {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package baidu.paddle_serving.predictor.general_model;
option java_multiple_files = true;
message Tensor {
repeated string data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt
optional string alias_name = 9; // get from the Model prototxt
};
message Request {
repeated Tensor tensor = 1;
repeated string fetch_var_names = 2;
optional bool profile_server = 3 [ default = false ];
required uint64 log_id = 4 [ default = 0 ];
};
message Response {
repeated ModelOutput outputs = 1;
repeated int64 profile_time = 2;
};
message ModelOutput {
repeated Tensor tensor = 1;
optional string engine_name = 2;
}
service GeneralModelService {
rpc inference(Request) returns (Response) {}
rpc debug(Request) returns (Response) {}
};
......@@ -13,34 +13,45 @@
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_client.httpclient import HttpClient
from paddle_serving_client.httpclient import GeneralClient
import sys
import numpy as np
import time
client = HttpClient()
client = GeneralClient()
client.load_client_config(sys.argv[1])
'''
if you want to use the gRPC client, call set_use_grpc_client(True),
or directly call client.grpc_client_predict(...);
for the HTTP client, call set_use_grpc_client(False) (the default),
or directly call client.http_client_predict(...)
'''
#client.set_use_grpc_client(True)
'''
if you want to enable the encryption module, uncomment the following line
'''
# client.use_key("./key")
#client.use_key("./key")
'''
if you want to compress, uncomment the following lines
'''
#client.set_response_compress(True)
#client.set_request_compress(True)
'''
we recommend the Proto data format in the HTTP body; set True (the default)
if you want the JSON data format in the HTTP body, set False
'''
#client.set_http_proto(True)
fetch_list = client.get_fetch_names()
import paddle
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.test(), buf_size=500),
batch_size=1)
fetch_list = client.get_fetch_names()
for data in test_reader():
new_data = np.zeros((1, 13)).astype("float32")
new_data[0] = data[0][0]
fetch_map = client.grpc_client_predict(
fetch_map = client.predict(
feed={"x": new_data}, fetch=fetch_list, batch=True)
print(fetch_map)
break
......@@ -66,7 +66,14 @@ def data_bytes_number(datalist):
return total_bytes_number
class HttpClient(object):
# This file is named httpclient.py for now; whether it replaces client.py
# will be decided after further testing.
# HTTP is the default transport, with Proto in the HTTP body by default.
# To use JSON in the HTTP body, call set_http_proto(False).
# predict() wraps http_client_predict/grpc_client_predict;
# either can also be called directly.
# For example, to use gRPC, call set_use_grpc_client(True)
# or call grpc_client_predict() directly.
class GeneralClient(object):
def __init__(self,
ip="0.0.0.0",
port="9393",
......@@ -77,7 +84,7 @@ class HttpClient(object):
self.feed_shapes_ = {}
self.feed_types_ = {}
self.feed_names_to_idx_ = {}
self.http_timeout_ms = 200000
self.timeout_ms = 200000
self.ip = ip
self.port = port
self.server_port = port
......@@ -86,7 +93,9 @@ class HttpClient(object):
self.try_request_gzip = False
self.try_response_gzip = False
self.total_data_number = 0
self.http_proto = False
self.http_proto = True
self.max_body_size = 512 * 1024 * 1024
self.use_grpc_client = False
def load_client_config(self, model_config_path_list):
if isinstance(model_config_path_list, str):
......@@ -144,11 +153,14 @@ class HttpClient(object):
self.lod_tensor_set.add(var.alias_name)
return
def set_http_timeout_ms(self, http_timeout_ms):
if not isinstance(http_timeout_ms, int):
raise ValueError("http_timeout_ms must be int type.")
def set_max_body_size(self, max_body_size):
self.max_body_size = max_body_size
def set_timeout_ms(self, timeout_ms):
if not isinstance(timeout_ms, int):
raise ValueError("timeout_ms must be int type.")
else:
self.http_timeout_ms = http_timeout_ms
self.timeout_ms = timeout_ms
def set_ip(self, ip):
self.ip = ip
......@@ -168,6 +180,9 @@ class HttpClient(object):
def set_http_proto(self, http_proto):
self.http_proto = http_proto
def set_use_grpc_client(self, use_grpc_client):
self.use_grpc_client = use_grpc_client
# use_key is the function of encryption.
def use_key(self, key_filename):
with open(key_filename, "rb") as f:
......@@ -195,50 +210,6 @@ class HttpClient(object):
def get_fetch_names(self):
return self.fetch_names_
# feed accepts numpy types, as well as plain list and tuple;
# str is not supported, because the proto field is repeated.
def predict(self,
feed=None,
fetch=None,
batch=False,
need_variant_tag=False,
log_id=0):
feed_dict = self.get_feedvar_dict(feed)
fetch_list = self.get_legal_fetch(fetch)
headers = {}
postData = ''
if self.http_proto == True:
postData = self.process_proto_data(feed_dict, fetch_list, batch,
log_id).SerializeToString()
headers["Content-Type"] = "application/proto"
else:
postData = self.process_json_data(feed_dict, fetch_list, batch,
log_id)
headers["Content-Type"] = "application/json"
web_url = "http://" + self.ip + ":" + self.server_port + self.service_name
# compress only when the data section is larger than 512 bytes.
if self.try_request_gzip and self.total_data_number > 512:
postData = gzip.compress(bytes(postData, 'utf-8'))
headers["Content-Encoding"] = "gzip"
if self.try_response_gzip:
headers["Accept-encoding"] = "gzip"
# requests detects and decompresses gzip responses automatically
result = requests.post(url=web_url, headers=headers, data=postData)
if result == None:
return None
if result.status_code == 200:
if result.headers["Content-Type"] == 'application/proto':
response = general_model_service_pb2.Response()
response.ParseFromString(result.content)
return response
else:
return result.json()
return result
def get_legal_fetch(self, fetch):
if fetch is None:
raise ValueError("You should specify feed and fetch for prediction")
......@@ -265,29 +236,32 @@ class HttpClient(object):
def get_feedvar_dict(self, feed):
if feed is None:
raise ValueError("You should specify feed and fetch for prediction")
feed_batch = []
feed_dict = {}
if isinstance(feed, dict):
feed_batch.append(feed)
feed_dict = feed
elif isinstance(feed, (list, str, tuple)):
# if input is a list or str or tuple, and the number of feed_var is 1.
# create a temp_dict { key = feed_var_name, value = list}
# put the temp_dict into the feed_batch.
if len(self.feed_names_) != 1:
# create a feed_dict { key = feed_var_name, value = list}
if len(self.feed_names_) == 1:
feed_dict[self.feed_names_[0]] = feed
elif len(self.feed_names_) > 1:
if isinstance(feed, str):
raise ValueError(
"input is a list, but we got 0 or 2+ feed_var, don`t know how to divide the feed list"
"input is a str, but we got 2+ feed_var, don`t know how to divide the string"
)
temp_dict = {}
temp_dict[self.feed_names_[0]] = feed
feed_batch.append(temp_dict)
# feed is a list or tuple
elif len(self.feed_names_) == len(feed):
for index in range(len(feed)):
feed_dict[self.feed_names_[index]] = feed[index]
else:
raise ValueError("len(feed) ≠ len(feed_var), error")
else:
raise ValueError("Feed only accepts dict and list of dict")
raise ValueError("we got feed, but feed_var is None")
# batch_size must be 1, because the batch dim is already in the Tensor.
if len(feed_batch) != 1:
raise ValueError("len of feed_batch can only be 1.")
else:
raise ValueError("Feed only accepts dict/str/list/tuple")
return feed_batch[0]
return feed_dict
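# An illustrative sketch of the feed forms accepted above, assuming a model
# whose feed vars are named "x" and "y" (hypothetical names):
#   client.predict(feed={"x": x_arr, "y": y_arr})  # dict, used as-is
#   client.predict(feed=[x_arr, y_arr])            # list/tuple, matched to feed_names_ by position
#   client.predict(feed="1:0.5 2:0.1")             # str, only legal with exactly one feed var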
def process_json_data(self, feed_dict, fetch_list, batch, log_id):
Request = {}
......@@ -429,6 +403,64 @@ class HttpClient(object):
tensor_dict["lod"] = lod
return tensor_dict
# feed must be a dict, list, tuple, or str
# the data inside feed supports numpy, list, tuple, and primitive types
# by default, fetch takes all fetch_vars from the model config file
def predict(self,
feed=None,
fetch=None,
batch=False,
need_variant_tag=False,
log_id=0):
if self.use_grpc_client:
return self.grpc_client_predict(feed, fetch, batch,
need_variant_tag, log_id)
else:
return self.http_client_predict(feed, fetch, batch,
need_variant_tag, log_id)
def http_client_predict(self,
feed=None,
fetch=None,
batch=False,
need_variant_tag=False,
log_id=0):
feed_dict = self.get_feedvar_dict(feed)
fetch_list = self.get_legal_fetch(fetch)
headers = {}
postData = ''
if self.http_proto == True:
postData = self.process_proto_data(feed_dict, fetch_list, batch,
log_id).SerializeToString()
headers["Content-Type"] = "application/proto"
else:
postData = self.process_json_data(feed_dict, fetch_list, batch,
log_id)
headers["Content-Type"] = "application/json"
web_url = "http://" + self.ip + ":" + self.server_port + self.service_name
# compress only when the data section is larger than 512 bytes.
if self.try_request_gzip and self.total_data_number > 512:
postData = gzip.compress(bytes(postData, 'utf-8'))
headers["Content-Encoding"] = "gzip"
if self.try_response_gzip:
headers["Accept-encoding"] = "gzip"
# requests detects and decompresses gzip responses automatically
result = requests.post(url=web_url, headers=headers, data=postData)
if result == None:
return None
if result.status_code == 200:
if result.headers["Content-Type"] == 'application/proto':
response = general_model_service_pb2.Response()
response.ParseFromString(result.content)
return response
else:
return result.json()
return result
def grpc_client_predict(self,
feed=None,
fetch=None,
......@@ -440,19 +472,17 @@ class HttpClient(object):
fetch_list = self.get_legal_fetch(fetch)
postData = self.process_proto_data(feed_dict, fetch_list, batch, log_id)
print('proto data', postData)
'''
# https://github.com/tensorflow/serving/issues/1382
options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
('grpc.max_send_message_length', 512 * 1024 * 1024),
('grpc.lb_policy_name', 'round_robin')]
'''
options = [('grpc.max_receive_message_length', self.max_body_size),
('grpc.max_send_message_length', self.max_body_size)]
endpoints = [self.ip + ":" + self.server_port]
g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
print("my endpoint is ", g_endpoint)
self.channel_ = grpc.insecure_channel(g_endpoint, options=options)
self.stub_ = general_model_service_pb2_grpc.GeneralModelServiceStub(
self.channel_)
resp = self.stub_.inference(postData, timeout=self.http_timeout_ms)
resp = self.stub_.inference(postData, timeout=self.timeout_ms)
return resp
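Because the channel options above read self.max_body_size and the call passes self.timeout_ms straight to the stub's timeout argument, both knobs must be set before the first gRPC request. A minimal sketch with illustrative values; the config path is an assumption:

import numpy as np
from paddle_serving_client.httpclient import GeneralClient

client = GeneralClient(ip="0.0.0.0", port="9393")
client.load_client_config("serving_client_conf.prototxt")  # assumed path
client.set_max_body_size(1024 * 1024 * 1024)  # lift the default 512 MB send/receive cap
client.set_timeout_ms(60000)                  # forwarded to the grpc stub's timeout argument
client.set_use_grpc_client(True)

data = np.zeros((1, 13)).astype("float32")
resp = client.predict(feed={"x": data}, fetch=client.get_fetch_names(), batch=True)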
......@@ -108,9 +108,13 @@ def is_gpu_mode(unformatted_gpus):
def serve_args():
parser = argparse.ArgumentParser("serve")
parser.add_argument(
"--thread", type=int, default=2, help="Concurrency of server")
"--thread",
type=int,
default=4,
help="Concurrency of server,[4,1024]",
choices=range(4, 1025))
parser.add_argument(
"--port", type=int, default=9292, help="Port of the starting gpu")
"--port", type=int, default=9393, help="Port of the starting gpu")
parser.add_argument(
"--device", type=str, default="cpu", help="Type of device")
parser.add_argument(
......@@ -180,8 +184,6 @@ def serve_args():
default=False,
action="store_true",
help="Use gpu_multi_stream")
parser.add_argument(
"--grpc", default=False, action="store_true", help="Use grpc test")
return parser.parse_args()
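The tightened defaults can be exercised end to end; a hedged sketch, where the serve entry point and the model directory are assumptions carried over from the usual Serving workflow:

import subprocess

subprocess.run([
    "python3", "-m", "paddle_serving_server.serve",  # assumed entry point
    "--model", "uci_housing_model",                  # assumed model directory
    "--port", "9393",   # the new default port
    "--thread", "8",    # must now fall in [4, 1024]
])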
......@@ -385,34 +387,5 @@ if __name__ == "__main__":
'Starting encryption server, waiting for key from client, use <Ctrl-C> to stop'
)
server.serve_forever()
else:
# this is for grpc Test
if args.grpc:
from .proto import general_model_service_pb2
sys.path.append(
os.path.join(
os.path.abspath(os.path.dirname(__file__)), 'proto'))
from .proto import general_model_service_pb2_grpc
import google.protobuf.text_format
from concurrent import futures
import grpc
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
class GeneralModelService(
general_model_service_pb2_grpc.GeneralModelServiceServicer):
def inference(self, request, context):
return general_model_service_pb2.Response()
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
general_model_service_pb2_grpc.add_GeneralModelServiceServicer_to_server(
GeneralModelService(), server)
server.add_insecure_port('[::]:9393')
server.start()
try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)
else:
start_multi_card(args)
......@@ -20,12 +20,12 @@ package baidu.paddle_serving.predictor.general_model;
option cc_generic_services = true;
message Tensor {
repeated bytes data = 1;
repeated string data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means bytes(string)
5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt
......