Commit 121e63e4 authored by: H HexToString

add http_proto and grpcclient

Parent fa153d54
......@@ -39,11 +39,11 @@ INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog")
if(WITH_LITE)
set(BRPC_REPO "https://github.com/zhangjun/incubator-brpc.git")
set(BRPC_REPO "https://github.com/apache/incubator-brpc")
set(BRPC_TAG "master")
else()
set(BRPC_REPO "https://github.com/wangjiawei04/brpc")
set(BRPC_TAG "6d79e0b17f25107c35b705ea58d888083f59ff47")
set(BRPC_REPO "https://github.com/apache/incubator-brpc")
set(BRPC_TAG "master")
endif()
# If a minimal .a is needed, you can set WITH_DEBUG_SYMBOLS=OFF
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package baidu.paddle_serving.predictor.general_model;
option java_multiple_files = true;
message Tensor {
repeated string data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt
optional string alias_name = 9; // get from the Model prototxt
};
message Request {
repeated Tensor tensor = 1;
repeated string fetch_var_names = 2;
optional bool profile_server = 3 [ default = false ];
required uint64 log_id = 4 [ default = 0 ];
};
message Response {
repeated ModelOutput outputs = 1;
repeated int64 profile_time = 2;
};
message ModelOutput {
repeated Tensor tensor = 1;
optional string engine_name = 2;
}
service GeneralModelService {
rpc inference(Request) returns (Response) {}
rpc debug(Request) returns (Response) {}
};
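For reference, since this proto sets java_multiple_files = true, the generated Java classes can build a Request directly. A minimal sketch — the builder calls mirror the Client.java changes later in this commit, and the concrete values are illustrative:

import java.util.Arrays;

import baidu.paddle_serving.predictor.general_model.Request;
import baidu.paddle_serving.predictor.general_model.Tensor;

public class BuildRequestSketch {
    public static void main(String[] args) {
        // elem_type 1 == float32; the shape must include the batch dimension.
        Tensor tensor = Tensor.newBuilder()
                .setName("x")            // real name from the model prototxt
                .setAliasName("x")       // alias name from the model prototxt
                .setElemType(1)
                .addAllShape(Arrays.asList(1, 13))
                .addAllFloatData(Arrays.asList(
                        0.0137f, -0.1136f, 0.2553f, -0.0692f, 0.0582f, -0.0727f,
                        -0.1583f, -0.0584f, 0.6283f, 0.4919f, 0.1856f, 0.0795f,
                        -0.0332f))
                .build();
        Request request = Request.newBuilder()
                .addTensor(tensor)
                .addAllFetchVarNames(Arrays.asList("price"))
                .setLogId(0)
                .build();
        // Serialized proto bytes are what the HTTP "application/proto" path posts.
        byte[] body = request.toByteArray();
        System.out.println(request + " (" + body.length + " bytes)");
    }
}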
......@@ -20,12 +20,12 @@ package baidu.paddle_serving.predictor.general_model;
option cc_generic_services = true;
message Tensor {
repeated bytes data = 1;
repeated string data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means bytes(string)
5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt
......
......@@ -20,12 +20,12 @@ package baidu.paddle_serving.predictor.general_model;
option cc_generic_services = true;
message Tensor {
repeated bytes data = 1;
repeated string data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means bytes(string)
5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt
......
......@@ -11,7 +11,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
public class PaddleServingClientExample {
boolean fit_a_line(String model_config_path) {
boolean http_proto(String model_config_path) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
......@@ -24,7 +24,7 @@ public class PaddleServingClientExample {
}};
List<String> fetch = Arrays.asList("price");
HttpClient client = new HttpClient();
Client client = new Client();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
......@@ -34,6 +34,55 @@ public class PaddleServingClientExample {
return true;
}
boolean http_json(String model_config_path) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
long[] batch_shape = {1,13};
INDArray batch_npdata = npdata.reshape(batch_shape);
HashMap<String, Object> feed_data
= new HashMap<String, Object>() {{
put("x", batch_npdata);
}};
List<String> fetch = Arrays.asList("price");
Client client = new Client();
// Note: when calling across Docker containers, use --net=host or access the other container's IP directly.
client.setIP("0.0.0.0");
client.setPort("9393");
client.set_http_proto(false);
client.loadClientConfig(model_config_path);
String result = client.predict(feed_data, fetch, true, 0);
System.out.println(result);
return true;
}
boolean grpc(String model_config_path) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
long[] batch_shape = {1,13};
INDArray batch_npdata = npdata.reshape(batch_shape);
HashMap<String, Object> feed_data
= new HashMap<String, Object>() {{
put("x", batch_npdata);
}};
List<String> fetch = Arrays.asList("price");
Client client = new Client();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
client.set_use_grpc_client(true);
String result = client.predict(feed_data, fetch, true, 0);
System.out.println(result);
return true;
}
boolean encrypt(String model_config_path,String keyFilePath) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
......@@ -47,7 +96,7 @@ public class PaddleServingClientExample {
}};
List<String> fetch = Arrays.asList("price");
HttpClient client = new HttpClient();
Client client = new Client();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
......@@ -55,7 +104,6 @@ public class PaddleServingClientExample {
try {
Thread.sleep(1000*3); // sleep for 3 seconds to wait for the server to start
} catch (Exception e) {
//TODO: handle exception
}
String result = client.predict(feed_data, fetch, true, 0);
......@@ -76,7 +124,7 @@ public class PaddleServingClientExample {
}};
List<String> fetch = Arrays.asList("price");
HttpClient client = new HttpClient();
Client client = new Client();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
......@@ -127,7 +175,7 @@ public class PaddleServingClientExample {
put("im_size", batch_im_size);
}};
List<String> fetch = Arrays.asList("save_infer_model/scale_0.tmp_0");
HttpClient client = new HttpClient();
Client client = new Client();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
......@@ -149,7 +197,7 @@ public class PaddleServingClientExample {
put("segment_ids", Nd4j.createFromArray(segment_ids));
}};
List<String> fetch = Arrays.asList("pooled_output");
HttpClient client = new HttpClient();
Client client = new Client();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
......@@ -219,7 +267,7 @@ public class PaddleServingClientExample {
put("embedding_0.tmp_0", Nd4j.createFromArray(embedding_0));
}};
List<String> fetch = Arrays.asList("prob");
HttpClient client = new HttpClient();
Client client = new Client();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
......@@ -236,13 +284,17 @@ public class PaddleServingClientExample {
if (args.length < 2) {
System.out.println("Usage: java -cp <jar> PaddleServingClientExample <test-type> <configPath>.");
System.out.println("<test-type>: fit_a_line bert cube_local yolov4 encrypt");
System.out.println("<test-type>: http_proto grpc bert cube_local yolov4 encrypt");
return;
}
String testType = args[0];
System.out.format("[Example] %s\n", testType);
if ("fit_a_line".equals(testType)) {
succ = e.fit_a_line(args[1]);
if ("http_proto".equals(testType)) {
succ = e.http_proto(args[1]);
} else if ("http_json".equals(testType)) {
succ = e.http_json(args[1]);
} else if ("grpc".equals(testType)) {
succ = e.grpc(args[1]);
} else if ("compress".equals(testType)) {
succ = e.compress(args[1]);
} else if ("bert".equals(testType)) {
......
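The renamed Client below keeps a single predict() entry point and picks the transport from two flags. A condensed sketch of switching between the three modes exercised by this example, assuming a server on 0.0.0.0:9393 and the client config path passed as the first argument:

import java.util.*;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import io.paddle.serving.client.Client;

public class TransportSwitchSketch {
    public static void main(String[] args) {
        float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
                        0.0582f, -0.0727f, -0.1583f, -0.0584f,
                        0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
        INDArray batchNpdata = Nd4j.createFromArray(data).reshape(new long[]{1, 13});
        HashMap<String, Object> feedData = new HashMap<>();
        feedData.put("x", batchNpdata);
        List<String> fetch = Arrays.asList("price");

        Client client = new Client();
        client.setIP("0.0.0.0");
        client.setPort("9393");
        client.loadClientConfig(args[0]); // path to the client config prototxt

        // Default: HTTP with a serialized-proto body ("application/proto").
        // Uncomment exactly one of the following to switch transports:
        // client.set_http_proto(false);     // HTTP with a JSON body
        // client.set_use_grpc_client(true); // gRPC instead of HTTP

        String result = client.predict(feedData, fetch, true, 0);
        System.out.println(result);
    }
}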
package io.paddle.serving.client;
import java.util.*;
import java.util.function.Function;
import java.util.stream.*;
import java.util.Arrays;
import java.util.Iterator;
import java.lang.management.ManagementFactory;
import java.lang.management.RuntimeMXBean;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.Map.Entry;
......@@ -11,6 +16,8 @@ import java.nio.file.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.Nd4jCpu.boolean_and;
import java.lang.reflect.*;
import org.apache.http.HttpEntity;
......@@ -28,11 +35,22 @@ import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.message.BasicNameValuePair;
import org.apache.http.util.EntityUtils;
import org.hamcrest.core.IsInstanceOf;
import org.apache.http.entity.InputStreamEntity;
import org.apache.http.entity.ByteArrayEntity;
import org.json.*;
import io.paddle.serving.configure.*;
import baidu.paddle_serving.predictor.general_model.*;
import org.apache.commons.lang3.ArrayUtils;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.StatusRuntimeException;
import com.google.protobuf.ByteString;
import java.io.ByteArrayInputStream;
......@@ -79,8 +97,8 @@ class Profiler {
enable_ = flag;
}
}
public class HttpClient {
private int httpTimeoutS_;
public class Client {
private int timeoutS_;
private List<String> feedNames_;
private Map<String, String> feedRealNames_;
private Map<String, Integer> feedTypes_;
......@@ -99,8 +117,13 @@ public class HttpClient {
private boolean request_compress_flag;
private boolean response_compress_flag;
private String GLOG_v;
private boolean http_proto;
private boolean use_grpc_client;
private ManagedChannel channel_;
private GeneralModelServiceGrpc.GeneralModelServiceBlockingStub blockingStub_;
public HttpClient() {
public Client() {
feedNames_ = null;
feedRealNames_ = null;
feedTypes_ = null;
......@@ -110,7 +133,7 @@ public class HttpClient {
lodTensorSet_ = null;
feedTensorLen_ = null;
feedNameToIndex_ = null;
httpTimeoutS_ = 200000;
timeoutS_ = 200000;
ip = "0.0.0.0";
port = "9393";
serverPort = "9393";
......@@ -118,6 +141,11 @@ public class HttpClient {
request_compress_flag = false;
response_compress_flag = false;
GLOG_v = System.getenv("GLOG_v");
http_proto = true; // use proto in the HTTP body by default.
use_grpc_client = false;
channel_ = null;
blockingStub_ = null;
feedTypeToDataKey_ = new HashMap<Integer, String>();
feedTypeToDataKey_.put(0, "int64_data");
......@@ -134,8 +162,8 @@ public class HttpClient {
profiler_.enable(is_profile);
}
public void setTimeOut(int httpTimeoutS_) {
this.httpTimeoutS_ = httpTimeoutS_;
public void setTimeOut(int timeoutS_) {
this.timeoutS_ = timeoutS_;
}
public void setIP(String ip) {
......@@ -231,24 +259,33 @@ public class HttpClient {
}
public void set_request_compress(boolean request_compress_flag) {
// need to be done.
this.request_compress_flag = request_compress_flag;
}
public void set_response_compress(boolean response_compress_flag) {
// need to be done.
this.response_compress_flag = response_compress_flag;
}
public byte[] compress(String str) {
if (str == null || str.length() == 0) {
public void set_http_proto(boolean http_proto){
this.http_proto = http_proto;
}
public void set_use_grpc_client(boolean use_grpc_client){
this.use_grpc_client = use_grpc_client;
}
public byte[] compress(Object obj) {
if (obj == null) {
return null;
}
ByteArrayOutputStream out = new ByteArrayOutputStream();
GZIPOutputStream gzip;
try {
gzip = new GZIPOutputStream(out);
gzip.write(str.getBytes("UTF-8"));
if(obj instanceof String){
gzip.write(((String)obj).getBytes("UTF-8"));
}else{
gzip.write((byte[])obj);
}
gzip.close();
} catch (Exception e) {
e.printStackTrace();
......@@ -287,123 +324,56 @@ public class HttpClient {
List<String> fetch,
boolean batchFlag,
int log_id) {
String server_url = "http://" + this.ip + ":" + this.serverPort + this.serviceName;
// build the fetch list
JSONArray jsonFetchList = new JSONArray();
Iterator<String> fetchIterator = fetch.iterator();
while (fetchIterator.hasNext()) {
jsonFetchList.put(fetchIterator.next());
}
if(this.use_grpc_client){
return grpc_predict(feedData, feedLod, fetch, batchFlag, log_id);
}
return http_predict(feedData, feedLod, fetch, batchFlag, log_id);
}
// build the tensors
JSONArray jsonTensorArray = new JSONArray();
try{
if (null != feedData && feedData.size() > 0) {
// get the entries via the map's entrySet()
Set<Entry<String, Object>> entrySet = feedData.entrySet();
// iterate over the entries with an iterator
Iterator<Entry<String, Object>> iterator = entrySet.iterator();
while (iterator.hasNext()) {
JSONObject jsonTensor = new JSONObject();
Entry<String, Object> mapEntry = iterator.next();
Object objectValue = mapEntry.getValue();
String feed_alias_name = mapEntry.getKey();
String feed_real_name = feedRealNames_.get(feed_alias_name);
List<Integer> shape = new ArrayList<Integer>(feedShapes_.get(feed_alias_name));
int element_type = feedTypes_.get(feed_alias_name);
jsonTensor.put("alias_name", feed_alias_name);
jsonTensor.put("name", feed_real_name);
jsonTensor.put("elem_type", element_type);
// handle data and shape
String protoDataKey = feedTypeToDataKey_.get(element_type);
// For an INDArray, flatten to 1-D first;
// the shape is then the INDArray's shape.
if(objectValue instanceof INDArray){
INDArray tempIndArray = (INDArray)objectValue;
long[] indarrayShape = tempIndArray.shape();
shape.clear();
for(long dim:indarrayShape){
shape.add((int)dim);
}
objectValue = tempIndArray.data().asDouble();
}else if(objectValue.getClass().isArray()){
// Arrays are used as-is.
// Arrays cannot be nested, so the batch size cannot be read from the data;
// assume a batch dimension of 1, or that the feedVar shape already includes it.
}else if(objectValue instanceof List){
// Lists may be nested and need flattening.
// If batchFlag is true, treat it as a nested list
// and take the outermost level as the batch dimension.
if (batchFlag) {
List<?> list = new ArrayList<>();
list = new ArrayList<>((Collection<?>)objectValue);
// insert the batch size at index 0
shape.add(0, list.size());
}
objectValue = recursiveExtract(objectValue);
}else{
// Otherwise the value is a single String, Integer, etc.;
// no batch information is available, so the shape is left untouched.
// The proto fields are repeated, so wrap the value in a list.
if(objectValue instanceof String){
if(element_type != ElementType.Bytes_type.ordinal()){
throw new Exception("feedvar is not string-type, feed can't be a single string.");
}
}else{
if(element_type == ElementType.Bytes_type.ordinal()){
throw new Exception("feedvar is string-type, feed can't be a single int or other scalar.");
}
}
List<Object> list = new ArrayList<>();
list.add(objectValue);
objectValue = list;
}
jsonTensor.put(protoDataKey,objectValue);
if(!batchFlag){
// insert batch=1 at index 0
shape.add(0, 1);
}
jsonTensor.put("shape", shape);
// handle lod info; INDArray, Array and Iterable are supported
Object feedLodValue = null;
if(feedLod != null){
feedLodValue = feedLod.get(feed_alias_name);
if(feedLodValue != null) {
if(feedLodValue instanceof INDArray){
INDArray tempIndArray = (INDArray)feedLodValue;
feedLodValue = tempIndArray.data().asInt();
}else if(feedLodValue.getClass().isArray()){
// Arrays are used as-is.
}else if(feedLodValue instanceof Iterable){
// Lists may be nested and need flattening.
feedLodValue = recursiveExtract(feedLodValue);
}else{
throw new Exception("Lod must be INDArray or Array or Iterable.");
}
jsonTensor.put("lod", feedLodValue);
}
}
jsonTensorArray.put(jsonTensor);
}
}
}catch (Exception e) {
e.printStackTrace();
}
JSONObject jsonRequest = new JSONObject();
jsonRequest.put("log_id",log_id);
jsonRequest.put("fetch_var_names", jsonFetchList);
jsonRequest.put("tensor",jsonTensorArray);
if(GLOG_v != null){
System.out.format("------- Final jsonRequest: %s\n", jsonRequest.toString());
}
return doPost(server_url, jsonRequest.toString());
}
public String grpc_predict(Map<String, Object> feedData,
Map<String, Object> feedLod,
List<String> fetch,
boolean batchFlag,
int log_id) {
String result = null;
try {
String server_url = this.ip + ":" + this.serverPort;
channel_ = ManagedChannelBuilder.forTarget(server_url)
.defaultLoadBalancingPolicy("round_robin")
.maxInboundMessageSize(Integer.MAX_VALUE)
.usePlaintext()
.build();
blockingStub_ = GeneralModelServiceGrpc.newBlockingStub(channel_);
Request request = process_proto_data(feedData, feedLod, fetch, batchFlag, log_id);
Response resp = blockingStub_.inference(request);
result = resp.toString();
} catch (Exception e) {
System.out.format("grpc_predict failed: %s\n", e.toString());
return null;
}
return result;
}
public String doPost(String url, String strPostData) {
public String http_predict(Map<String, Object> feedData,
Map<String, Object> feedLod,
List<String> fetch,
boolean batchFlag,
int log_id) {
String server_url = "http://" + this.ip + ":" + this.serverPort + this.serviceName;
// build the request body
String result = null;
if(this.http_proto){
Request request = process_proto_data(feedData, feedLod, fetch, batchFlag, log_id);
result = doPost(server_url, request.toByteArray());
}else{
JSONObject jsonRequest = process_json_data(feedData,feedLod,fetch,batchFlag,log_id);
result = doPost(server_url, jsonRequest.toString());
}
return result;
}
public String doPost(String url, Object postData) {
CloseableHttpClient httpClient = null;
CloseableHttpResponse httpResponse = null;
String result = "";
......@@ -412,13 +382,18 @@ public class HttpClient {
// create the HttpPost instance for the remote connection
HttpPost httpPost = new HttpPost(url);
// configure the request parameters
RequestConfig requestConfig = RequestConfig.custom().setConnectTimeout(httpTimeoutS_)// connect timeout
.setConnectionRequestTimeout(httpTimeoutS_)// connection-request timeout
.setSocketTimeout(httpTimeoutS_)// socket read timeout
RequestConfig requestConfig = RequestConfig.custom().setConnectTimeout(timeoutS_)// connect timeout
.setConnectionRequestTimeout(timeoutS_)// connection-request timeout
.setSocketTimeout(timeoutS_)// socket read timeout
.build();
// apply the config to the HttpPost instance
httpPost.setConfig(requestConfig);
httpPost.setHeader("Content-Type", "application/json;charset=utf-8");
if(this.http_proto){
httpPost.setHeader("Content-Type", "application/proto");
}else{
httpPost.setHeader("Content-Type", "application/json");
}
// set request headers
if(response_compress_flag){
httpPost.addHeader("Accept-encoding", "gzip");
......@@ -428,17 +403,34 @@ public class HttpClient {
}
try {
if(request_compress_flag && strPostData.length()>1024){
try{
byte[] gzipEncrypt = compress(strPostData);
httpPost.setEntity(new InputStreamEntity(new ByteArrayInputStream(gzipEncrypt), gzipEncrypt.length));
httpPost.addHeader("Content-Encoding", "gzip");
} catch (Exception e) {
e.printStackTrace();
if(postData instanceof String){
if(request_compress_flag && ((String)postData).length()>1024){
try{
byte[] gzipEncrypt = compress(postData);
httpPost.setEntity(new InputStreamEntity(new ByteArrayInputStream(gzipEncrypt), gzipEncrypt.length));
httpPost.addHeader("Content-Encoding", "gzip");
} catch (Exception e) {
e.printStackTrace();
}
}else{
httpPost.setEntity(new StringEntity((String)postData, "UTF-8"));
}
}else{
httpPost.setEntity(new StringEntity(strPostData, "UTF-8"));
if(request_compress_flag && ((byte[])postData).length>1024){
try{
byte[] gzipEncrypt = compress(postData);
httpPost.setEntity(new InputStreamEntity(new ByteArrayInputStream(gzipEncrypt), gzipEncrypt.length));
httpPost.addHeader("Content-Encoding", "gzip");
} catch (Exception e) {
e.printStackTrace();
}
}else{
httpPost.setEntity(new ByteArrayEntity((byte[])postData));
//httpPost.setEntity(new InputStreamEntity(new ByteArrayInputStream((byte[])postData), ((byte[])postData).length));
}
}
// execute the POST request via httpClient and obtain the response
httpResponse = httpClient.execute(httpPost);
// read the response content from the response object
......@@ -454,7 +446,13 @@ public class HttpClient {
System.out.format("------- degzip response: %s\n", result);
}
}else{
result = EntityUtils.toString(entity);
if(this.http_proto){
Response resp = Response.parseFrom(EntityUtils.toByteArray(entity));
result = resp.toString();
}else{
result = EntityUtils.toString(entity);
}
}
} catch (ClientProtocolException e) {
e.printStackTrace();
......@@ -499,5 +497,279 @@ public class HttpClient {
return mylist;
}
public JSONObject process_json_data(Map<String, Object> feedData,
Map<String, Object> feedLod,
List<String> fetch,
boolean batchFlag,
int log_id){
// build the tensors
JSONArray jsonTensorArray = new JSONArray();
try{
if (null != feedData && feedData.size() > 0) {
// get the entries via the map's entrySet()
Set<Entry<String, Object>> entrySet = feedData.entrySet();
// iterate over the entries with an iterator
Iterator<Entry<String, Object>> iterator = entrySet.iterator();
while (iterator.hasNext()) {
JSONObject jsonTensor = new JSONObject();
Entry<String, Object> mapEntry = iterator.next();
Object objectValue = mapEntry.getValue();
String feed_alias_name = mapEntry.getKey();
String feed_real_name = feedRealNames_.get(feed_alias_name);
List<Integer> shape = new ArrayList<Integer>(feedShapes_.get(feed_alias_name));
int element_type = feedTypes_.get(feed_alias_name);
jsonTensor.put("alias_name", feed_alias_name);
jsonTensor.put("name", feed_real_name);
jsonTensor.put("elem_type", element_type);
// handle data and shape
String protoDataKey = feedTypeToDataKey_.get(element_type);
// For an INDArray, flatten to 1-D first;
// the shape is then the INDArray's shape.
if(objectValue instanceof INDArray){
INDArray tempIndArray = (INDArray)objectValue;
long[] indarrayShape = tempIndArray.shape();
shape.clear();
for(long dim:indarrayShape){
shape.add((int)dim);
}
if(element_type == ElementType.Int64_type.ordinal()){
objectValue = tempIndArray.data().asLong();
}else if(element_type == ElementType.Int32_type.ordinal()){
objectValue = tempIndArray.data().asInt();
}else if(element_type == ElementType.Float32_type.ordinal()){
objectValue = tempIndArray.data().asFloat();
}else{
throw new Exception("INDArray 类型不支持");
}
}else if(objectValue.getClass().isArray()){
// Arrays are used as-is.
// Arrays cannot be nested, so the batch size cannot be read from the data;
// assume a batch dimension of 1, or that the feedVar shape already includes it.
}else if(objectValue instanceof List){
// Lists may be nested and need flattening.
// If batchFlag is true, treat it as a nested list
// and take the outermost level as the batch dimension.
if (batchFlag) {
List<?> list = new ArrayList<>();
list = new ArrayList<>((Collection<?>)objectValue);
// insert the batch size at index 0
shape.add(0, list.size());
}
objectValue = recursiveExtract(objectValue);
}else{
// Otherwise the value is a single String, Integer, etc.;
// no batch information is available, so the shape is left untouched.
// The proto fields are repeated, so wrap the value in a list.
if(objectValue instanceof String){
if(element_type != ElementType.Bytes_type.ordinal()){
throw new Exception("feedvar is not string-type, feed can't be a single string.");
}
}else{
if(element_type == ElementType.Bytes_type.ordinal()){
throw new Exception("feedvar is string-type, feed can't be a single int or other scalar.");
}
}
List<Object> list = new ArrayList<>();
list.add(objectValue);
objectValue = list;
}
jsonTensor.put(protoDataKey,objectValue);
if(!batchFlag){
// insert batch=1 at index 0
shape.add(0, 1);
}
jsonTensor.put("shape", shape);
// handle lod info; INDArray, Array and Iterable are supported
Object feedLodValue = null;
if(feedLod != null){
feedLodValue = feedLod.get(feed_alias_name);
if(feedLodValue != null) {
if(feedLodValue instanceof INDArray){
INDArray tempIndArray = (INDArray)feedLodValue;
feedLodValue = tempIndArray.data().asInt();
}else if(feedLodValue.getClass().isArray()){
// Arrays are used as-is.
}else if(feedLodValue instanceof Iterable){
// Lists may be nested and need flattening.
feedLodValue = recursiveExtract(feedLodValue);
}else{
throw new Exception("Lod must be INDArray or Array or Iterable.");
}
jsonTensor.put("lod", feedLodValue);
}
}
jsonTensorArray.put(jsonTensor);
}
}
}catch (Exception e) {
e.printStackTrace();
}
JSONArray jsonFetchList = new JSONArray(fetch);
/*
Iterator<String> fetchIterator = fetch.iterator();
while (fetchIterator.hasNext()) {
jsonFetchList.put(fetchIterator.next());
}
*/
JSONObject jsonRequest = new JSONObject();
jsonRequest.put("log_id",log_id);
jsonRequest.put("fetch_var_names", jsonFetchList);
jsonRequest.put("tensor",jsonTensorArray);
if(GLOG_v != null){
System.out.format("------- Final jsonRequest: %s\n", jsonRequest.toString());
}
return jsonRequest;
}
public Request process_proto_data(Map<String, Object> feedData,
Map<String, Object> feedLod,
List<String> fetch,
boolean batchFlag,
int log_id){
// build the tensors
Request.Builder request_builder = Request.newBuilder().addAllFetchVarNames(fetch).setLogId(log_id);
try{
if (null != feedData && feedData.size() > 0) {
// get the entries via the map's entrySet()
Set<Entry<String, Object>> entrySet = feedData.entrySet();
// iterate over the entries with an iterator
Iterator<Entry<String, Object>> iterator = entrySet.iterator();
while (iterator.hasNext()) {
Tensor.Builder tensor_builder = Tensor.newBuilder();
Entry<String, Object> mapEntry = iterator.next();
Object objectValue = mapEntry.getValue();
String feed_alias_name = mapEntry.getKey();
String feed_real_name = feedRealNames_.get(feed_alias_name);
List<Integer> shape = new ArrayList<Integer>(feedShapes_.get(feed_alias_name));
int element_type = feedTypes_.get(feed_alias_name);
tensor_builder.setAliasName(feed_alias_name);
tensor_builder.setName(feed_real_name);
tensor_builder.setElemType(element_type);
// handle data and shape
// For an INDArray, flatten to 1-D first;
// the shape is then the INDArray's shape.
if(objectValue instanceof INDArray){
INDArray tempIndArray = (INDArray)objectValue;
long[] indarrayShape = tempIndArray.shape();
shape.clear();
for(long dim:indarrayShape){
shape.add((int)dim);
}
if(element_type == ElementType.Int64_type.ordinal()){
List<Long> iter = Arrays.stream(tempIndArray.data().asLong()).boxed().collect(Collectors.toList());
tensor_builder.addAllInt64Data(iter);
}else if(element_type == ElementType.Int32_type.ordinal()){
List<Integer> iter = Arrays.stream(tempIndArray.data().asInt()).boxed().collect(Collectors.toList());
tensor_builder.addAllIntData(iter);
}else if(element_type == ElementType.Float32_type.ordinal()){
List<Float> iter = Arrays.asList(ArrayUtils.toObject(tempIndArray.data().asFloat()));
tensor_builder.addAllFloatData(iter);
}else{
// depends on whether the interface uses String or Bytes
throw new Exception("unsupported INDArray element type");
}
}else if(objectValue.getClass().isArray()){
// Arrays are used as-is.
// Arrays cannot be nested, so the batch size cannot be read from the data;
// assume a batch dimension of 1, or that the feedVar shape already includes it.
if(element_type == ElementType.Int64_type.ordinal()){
List<Long> iter = Arrays.stream((long[])objectValue).boxed().collect(Collectors.toList());
tensor_builder.addAllInt64Data(iter);
}else if(element_type == ElementType.Int32_type.ordinal()){
List<Integer> iter = Arrays.stream((int[])objectValue).boxed().collect(Collectors.toList());
tensor_builder.addAllIntData(iter);
}else if(element_type == ElementType.Float32_type.ordinal()){
List<Float> iter = Arrays.asList(ArrayUtils.toObject((float[])objectValue));
tensor_builder.addAllFloatData(iter);
}else{
List<String> iter = Arrays.asList((String[])objectValue);
tensor_builder.addAllData(iter);
}
}else if(objectValue instanceof List){
// Lists may be nested and need flattening.
// If batchFlag is true, treat it as a nested list
// and take the outermost level as the batch dimension.
if (batchFlag) {
List<?> list = new ArrayList<>();
list = new ArrayList<>((Collection<?>)objectValue);
// insert the batch size at index 0
shape.add(0, list.size());
}
if(element_type == ElementType.Int64_type.ordinal()){
tensor_builder.addAllInt64Data((List<Long>)(List)recursiveExtract(objectValue));
}else if(element_type == ElementType.Int32_type.ordinal()){
tensor_builder.addAllIntData((List<Integer>)(List)recursiveExtract(objectValue));
}else if(element_type == ElementType.Float32_type.ordinal()){
tensor_builder.addAllFloatData((List<Float>)(List)recursiveExtract(objectValue));
}else{
// depends on whether the interface uses String or Bytes
tensor_builder.addAllData((List<String>)(List)recursiveExtract(objectValue));
}
}else{
// Otherwise the value is a single String, Integer, etc.;
// no batch information is available, so the shape is left untouched.
// The proto fields are repeated, so wrap the value in a list.
List<Object> tempList = new ArrayList<>();
tempList.add(objectValue);
if(element_type == ElementType.Int64_type.ordinal()){
tensor_builder.addAllInt64Data((List<Long>)(List)tempList);
}else if(element_type == ElementType.Int32_type.ordinal()){
tensor_builder.addAllIntData((List<Integer>)(List)tempList);
}else if(element_type == ElementType.Float32_type.ordinal()){
tensor_builder.addAllFloatData((List<Float>)(List)tempList);
}else{
// depends on whether the interface uses String or Bytes
tensor_builder.addAllData((List<String>)(List)tempList);
}
}
if(!batchFlag){
// insert batch=1 at index 0
shape.add(0, 1);
}
tensor_builder.addAllShape(shape);
// handle lod info; INDArray, Array and Iterable are supported
Object feedLodValue = null;
if(feedLod != null){
feedLodValue = feedLod.get(feed_alias_name);
if(feedLodValue != null) {
if(feedLodValue instanceof INDArray){
INDArray tempIndArray = (INDArray)feedLodValue;
List<Integer> iter = Arrays.stream(tempIndArray.data().asInt()).boxed().collect(Collectors.toList());
tensor_builder.addAllLod(iter);
}else if(feedLodValue.getClass().isArray()){
// Arrays are used as-is.
List<Integer> iter = Arrays.stream((int[])feedLodValue).boxed().collect(Collectors.toList());
tensor_builder.addAllLod(iter);
}else if(feedLodValue instanceof Iterable){
// Lists may be nested and need flattening.
tensor_builder.addAllLod((List<Integer>)(List)recursiveExtract(feedLodValue));
}else{
throw new Exception("Lod must be INDArray or Array or Iterable.");
}
}
}
request_builder.addTensor(tensor_builder.build());
}
}
}catch (Exception e) {
e.printStackTrace();
}
return request_builder.build();
}
}
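The lod branches in process_json_data/process_proto_data above accept an INDArray, a primitive array, or a nested list. A hedged sketch of feeding one LoD input through the five-argument predict overload; the feed var "words" and fetch var "prediction" are hypothetical and depend on the model config:

import java.util.*;
import io.paddle.serving.client.Client;

public class LodFeedSketch {
    public static void main(String[] args) {
        // Two sequences of lengths 3 and 2, flattened into one int64 feed.
        HashMap<String, Object> feedData = new HashMap<>();
        feedData.put("words", new long[]{8L, 233L, 52L, 601L, 9L}); // hypothetical feed var
        HashMap<String, Object> feedLod = new HashMap<>();
        feedLod.put("words", new int[]{0, 3, 5}); // lod offsets

        List<String> fetch = Arrays.asList("prediction"); // hypothetical fetch var

        Client client = new Client();
        client.setIP("0.0.0.0");
        client.setPort("9393");
        client.loadClientConfig(args[0]); // path to the client config prototxt

        // batchFlag=true: the flattened array already carries the batch.
        String result = client.predict(feedData, feedLod, fetch, true, 0);
        System.out.println(result);
    }
}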
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package baidu.paddle_serving.predictor.general_model;
option java_multiple_files = true;
message Tensor {
repeated string data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt
optional string alias_name = 9; // get from the Model prototxt
};
message Request {
repeated Tensor tensor = 1;
repeated string fetch_var_names = 2;
optional bool profile_server = 3 [ default = false ];
required uint64 log_id = 4 [ default = 0 ];
};
message Response {
repeated ModelOutput outputs = 1;
repeated int64 profile_time = 2;
};
message ModelOutput {
repeated Tensor tensor = 1;
optional string engine_name = 2;
}
service GeneralModelService {
rpc inference(Request) returns (Response) {}
rpc debug(Request) returns (Response) {}
};
......@@ -13,34 +13,45 @@
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_client.httpclient import HttpClient
from paddle_serving_client.httpclient import GeneralClient
import sys
import numpy as np
import time
client = HttpClient()
client = GeneralClient()
client.load_client_config(sys.argv[1])
'''
if you want to use the gRPC client, call set_use_grpc_client(True),
or call client.grpc_client_predict(...) directly;
for the HTTP client, use set_use_grpc_client(False) (the default),
or call client.http_client_predict(...) directly.
'''
#client.set_use_grpc_client(True)
'''
if you want to enable the encryption module, uncomment the following line
'''
# client.use_key("./key")
#client.use_key("./key")
'''
if you want compression, uncomment the following lines
'''
#client.set_response_compress(True)
#client.set_request_compress(True)
'''
we recommend the proto data format in the HTTP body: set True (the default);
if you want the JSON data format in the HTTP body, set False
'''
#client.set_http_proto(True)
fetch_list = client.get_fetch_names()
import paddle
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.test(), buf_size=500),
batch_size=1)
fetch_list = client.get_fetch_names()
for data in test_reader():
new_data = np.zeros((1, 13)).astype("float32")
new_data[0] = data[0][0]
fetch_map = client.grpc_client_predict(
fetch_map = client.predict(
feed={"x": new_data}, fetch=fetch_list, batch=True)
print(fetch_map)
break
......@@ -66,7 +66,14 @@ def data_bytes_number(datalist):
return total_bytes_number
class HttpClient(object):
# This file is tentatively named httpclient.py; whether it replaces client.py
# will be decided after further testing.
# HTTP is the default transport, with proto in the HTTP body by default.
# For JSON in the HTTP body, call set_http_proto(False).
# predict() wraps http_client_predict/grpc_client_predict;
# either can also be called directly.
# For example, to use gRPC, call set_use_grpc_client(True)
# or call grpc_client_predict() directly.
class GeneralClient(object):
def __init__(self,
ip="0.0.0.0",
port="9393",
......@@ -77,7 +84,7 @@ class HttpClient(object):
self.feed_shapes_ = {}
self.feed_types_ = {}
self.feed_names_to_idx_ = {}
self.http_timeout_ms = 200000
self.timeout_ms = 200000
self.ip = ip
self.port = port
self.server_port = port
......@@ -86,7 +93,9 @@ class HttpClient(object):
self.try_request_gzip = False
self.try_response_gzip = False
self.total_data_number = 0
self.http_proto = False
self.http_proto = True
self.max_body_size = 512 * 1024 * 1024
self.use_grpc_client = False
def load_client_config(self, model_config_path_list):
if isinstance(model_config_path_list, str):
......@@ -144,11 +153,14 @@ class HttpClient(object):
self.lod_tensor_set.add(var.alias_name)
return
def set_http_timeout_ms(self, http_timeout_ms):
if not isinstance(http_timeout_ms, int):
raise ValueError("http_timeout_ms must be int type.")
def set_max_body_size(self, max_body_size):
self.max_body_size = max_body_size
def set_timeout_ms(self, timeout_ms):
if not isinstance(timeout_ms, int):
raise ValueError("timeout_ms must be int type.")
else:
self.http_timeout_ms = http_timeout_ms
self.timeout_ms = timeout_ms
def set_ip(self, ip):
self.ip = ip
......@@ -168,6 +180,9 @@ class HttpClient(object):
def set_http_proto(self, http_proto):
self.http_proto = http_proto
def set_use_grpc_client(self, use_grpc_client):
self.use_grpc_client = use_grpc_client
# use_key is the function of encryption.
def use_key(self, key_filename):
with open(key_filename, "rb") as f:
......@@ -195,50 +210,6 @@ class HttpClient(object):
def get_fetch_names(self):
return self.fetch_names_
# feed supports NumPy arrays, as well as plain lists and tuples.
# str is not supported, because the proto fields are repeated.
def predict(self,
feed=None,
fetch=None,
batch=False,
need_variant_tag=False,
log_id=0):
feed_dict = self.get_feedvar_dict(feed)
fetch_list = self.get_legal_fetch(fetch)
headers = {}
postData = ''
if self.http_proto == True:
postData = self.process_proto_data(feed_dict, fetch_list, batch,
log_id).SerializeToString()
headers["Content-Type"] = "application/proto"
else:
postData = self.process_json_data(feed_dict, fetch_list, batch,
log_id)
headers["Content-Type"] = "application/json"
web_url = "http://" + self.ip + ":" + self.server_port + self.service_name
# compress only when the payload exceeds 512 bytes.
if self.try_request_gzip and self.total_data_number > 512:
postData = gzip.compress(bytes(postData, 'utf-8'))
headers["Content-Encoding"] = "gzip"
if self.try_response_gzip:
headers["Accept-encoding"] = "gzip"
# requests decompresses the response automatically.
result = requests.post(url=web_url, headers=headers, data=postData)
if result == None:
return None
if result.status_code == 200:
if result.headers["Content-Type"] == 'application/proto':
response = general_model_service_pb2.Response()
response.ParseFromString(result.content)
return response
else:
return result.json()
return result
def get_legal_fetch(self, fetch):
if fetch is None:
raise ValueError("You should specify feed and fetch for prediction")
......@@ -265,29 +236,32 @@ class HttpClient(object):
def get_feedvar_dict(self, feed):
if feed is None:
raise ValueError("You should specify feed and fetch for prediction")
feed_batch = []
feed_dict = {}
if isinstance(feed, dict):
feed_batch.append(feed)
feed_dict = feed
elif isinstance(feed, (list, str, tuple)):
# if input is a list or str or tuple, and the number of feed_var is 1.
# create a temp_dict { key = feed_var_name, value = list}
# put the temp_dict into the feed_batch.
if len(self.feed_names_) != 1:
raise ValueError(
"input is a list, but we got 0 or 2+ feed_var, don`t know how to divide the feed list"
)
temp_dict = {}
temp_dict[self.feed_names_[0]] = feed
feed_batch.append(temp_dict)
else:
raise ValueError("Feed only accepts dict and list of dict")
# create a feed_dict { key = feed_var_name, value = list}
if len(self.feed_names_) == 1:
feed_dict[self.feed_names_[0]] = feed
elif len(self.feed_names_) > 1:
if isinstance(feed, str):
raise ValueError(
"input is a str, but we got 2+ feed_var, don`t know how to divide the string"
)
# feed is a list or tuple
elif len(self.feed_names_) == len(feed):
for index in range(len(feed)):
feed_dict[self.feed_names_[index]] = feed[index]
else:
raise ValueError("len(feed) ≠ len(feed_var), error")
else:
raise ValueError("we got feed, but feed_var is None")
# batch_size must be 1, because batch is already in the Tensor.
if len(feed_batch) != 1:
raise ValueError("len of feed_batch can only be 1.")
else:
raise ValueError("Feed only accepts dict/str/list/tuple")
return feed_batch[0]
return feed_dict
def process_json_data(self, feed_dict, fetch_list, batch, log_id):
Request = {}
......@@ -429,6 +403,64 @@ class HttpClient(object):
tensor_dict["lod"] = lod
return tensor_dict
# feed must be a dict, list, tuple or str.
# The data in feed may be NumPy arrays, lists, tuples, or primitive types.
# fetch defaults to all fetch_vars from the model config file.
def predict(self,
feed=None,
fetch=None,
batch=False,
need_variant_tag=False,
log_id=0):
if self.use_grpc_client:
return self.grpc_client_predict(feed, fetch, batch,
need_variant_tag, log_id)
else:
return self.http_client_predict(feed, fetch, batch,
need_variant_tag, log_id)
def http_client_predict(self,
feed=None,
fetch=None,
batch=False,
need_variant_tag=False,
log_id=0):
feed_dict = self.get_feedvar_dict(feed)
fetch_list = self.get_legal_fetch(fetch)
headers = {}
postData = ''
if self.http_proto == True:
postData = self.process_proto_data(feed_dict, fetch_list, batch,
log_id).SerializeToString()
headers["Content-Type"] = "application/proto"
else:
postData = self.process_json_data(feed_dict, fetch_list, batch,
log_id)
headers["Content-Type"] = "application/json"
web_url = "http://" + self.ip + ":" + self.server_port + self.service_name
# compress only when the payload exceeds 512 bytes.
if self.try_request_gzip and self.total_data_number > 512:
postData = gzip.compress(bytes(postData, 'utf-8'))
headers["Content-Encoding"] = "gzip"
if self.try_response_gzip:
headers["Accept-encoding"] = "gzip"
# requests decompresses the response automatically.
result = requests.post(url=web_url, headers=headers, data=postData)
if result == None:
return None
if result.status_code == 200:
if result.headers["Content-Type"] == 'application/proto':
response = general_model_service_pb2.Response()
response.ParseFromString(result.content)
return response
else:
return result.json()
return result
def grpc_client_predict(self,
feed=None,
fetch=None,
......@@ -440,19 +472,17 @@ class HttpClient(object):
fetch_list = self.get_legal_fetch(fetch)
postData = self.process_proto_data(feed_dict, fetch_list, batch, log_id)
print('proto data', postData)
'''
# https://github.com/tensorflow/serving/issues/1382
options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
('grpc.max_send_message_length', 512 * 1024 * 1024),
('grpc.lb_policy_name', 'round_robin')]
'''
options = [('grpc.max_receive_message_length', self.max_body_size),
('grpc.max_send_message_length', self.max_body_size)]
endpoints = [self.ip + ":" + self.server_port]
g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
print("my endpoint is ", g_endpoint)
self.channel_ = grpc.insecure_channel(g_endpoint, options=options)
self.stub_ = general_model_service_pb2_grpc.GeneralModelServiceStub(
self.channel_)
resp = self.stub_.inference(postData, timeout=self.http_timeout_ms)
resp = self.stub_.inference(postData, timeout=self.timeout_ms)
return resp
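Note the asymmetry with the Java client: grpc_predict there builds its channel with maxInboundMessageSize(Integer.MAX_VALUE), while this Python client bounds messages with max_body_size (512 MB by default). A sketch, assuming one wants the Java channel capped the same way (class name and shutdown handling are illustrative):

import java.util.concurrent.TimeUnit;

import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;

public class BoundedChannelSketch {
    public static void main(String[] args) throws InterruptedException {
        int maxBodySize = 512 * 1024 * 1024; // mirror the Python client's default
        ManagedChannel channel = ManagedChannelBuilder.forTarget("0.0.0.0:9393")
                .defaultLoadBalancingPolicy("round_robin")
                .maxInboundMessageSize(maxBodySize)
                .usePlaintext()
                .build();
        // ... create a GeneralModelServiceGrpc blocking stub and call inference() ...
        channel.shutdown().awaitTermination(5, TimeUnit.SECONDS);
    }
}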
......@@ -108,9 +108,13 @@ def is_gpu_mode(unformatted_gpus):
def serve_args():
parser = argparse.ArgumentParser("serve")
parser.add_argument(
"--thread", type=int, default=2, help="Concurrency of server")
"--thread",
type=int,
default=4,
help="Concurrency of server,[4,1024]",
choices=range(4, 1025))
parser.add_argument(
"--port", type=int, default=9292, help="Port of the starting gpu")
"--port", type=int, default=9393, help="Port of the starting gpu")
parser.add_argument(
"--device", type=str, default="cpu", help="Type of device")
parser.add_argument(
......@@ -180,8 +184,6 @@ def serve_args():
default=False,
action="store_true",
help="Use gpu_multi_stream")
parser.add_argument(
"--grpc", default=False, action="store_true", help="Use grpc test")
return parser.parse_args()
......@@ -386,33 +388,4 @@ if __name__ == "__main__":
)
server.serve_forever()
else:
# this is for grpc Test
if args.grpc:
from .proto import general_model_service_pb2
sys.path.append(
os.path.join(
os.path.abspath(os.path.dirname(__file__)), 'proto'))
from .proto import general_model_service_pb2_grpc
import google.protobuf.text_format
from concurrent import futures
import grpc
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
class GeneralModelService(
general_model_service_pb2_grpc.GeneralModelServiceServicer):
def inference(self, request, context):
return general_model_service_pb2.Response()
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
general_model_service_pb2_grpc.add_GeneralModelServiceServicer_to_server(
GeneralModelService(), server)
server.add_insecure_port('[::]:9393')
server.start()
try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)
else:
start_multi_card(args)
start_multi_card(args)
......@@ -20,12 +20,12 @@ package baidu.paddle_serving.predictor.general_model;
option cc_generic_services = true;
message Tensor {
repeated bytes data = 1;
repeated string data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means bytes(string)
5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt
......