Unverified commit 50beec12, authored by T TeslaZhao, committed by GitHub

Merge pull request #1345 from HexToString/grpc_update

Add direct GRPC requests and HTTP (PROTO) requests
@@ -39,11 +39,11 @@ INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})
 set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog")
 if(WITH_LITE)
-    set(BRPC_REPO "https://github.com/zhangjun/incubator-brpc.git")
-    set(BRPC_TAG "master")
+    set(BRPC_REPO "https://github.com/apache/incubator-brpc")
+    set(BRPC_TAG "1.0.0-rc01")
 else()
-    set(BRPC_REPO "https://github.com/wangjiawei04/brpc")
-    set(BRPC_TAG "6d79e0b17f25107c35b705ea58d888083f59ff47")
+    set(BRPC_REPO "https://github.com/apache/incubator-brpc")
+    set(BRPC_TAG "1.0.0-rc01")
 endif()
 # If minimal .a is need, you can set WITH_DEBUG_SYMBOLS=OFF
......
@@ -33,7 +33,9 @@ if (WITH_PYTHON)
 add_custom_target(general_model_config_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
 add_dependencies(general_model_config_py_proto general_model_config_py_proto_init)
+py_grpc_proto_compile(general_model_service_py_proto SRCS proto/general_model_service.proto)
+add_custom_target(general_model_service_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
+add_dependencies(general_model_service_py_proto general_model_service_py_proto_init)
 if (CLIENT)
 py_proto_compile(sdk_configure_py_proto SRCS proto/sdk_configure.proto)
@@ -51,6 +53,11 @@ if (WITH_PYTHON)
         COMMENT "Copy generated general_model_config proto file into directory paddle_serving_client/proto."
         WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+    add_custom_command(TARGET general_model_service_py_proto POST_BUILD
+        COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto
+        COMMAND cp -f *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto
+        COMMENT "Copy generated general_model_service proto file into directory paddle_serving_client/proto."
+        WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
 endif()
@@ -78,6 +85,12 @@ if (WITH_PYTHON)
         COMMENT "Copy generated general_model_config proto file into directory paddle_serving_server/proto."
         WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+    add_custom_command(TARGET general_model_service_py_proto POST_BUILD
+        COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
+        COMMAND cp -f *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
+        COMMENT "Copy generated general_model_service proto file into directory paddle_serving_server/proto."
+        WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
 endif()
 endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package baidu.paddle_serving.predictor.general_model;
option java_multiple_files = true;
message Tensor {
  repeated string data = 1;
  repeated int32 int_data = 2;
  repeated int64 int64_data = 3;
  repeated float float_data = 4;
  optional int32 elem_type =
      5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
  repeated int32 shape = 6;       // shape should include batch
  repeated int32 lod = 7;         // only for fetch tensor currently
  optional string name = 8;       // get from the Model prototxt
  optional string alias_name = 9; // get from the Model prototxt
};

message Request {
  repeated Tensor tensor = 1;
  repeated string fetch_var_names = 2;
  optional bool profile_server = 3 [ default = false ];
  required uint64 log_id = 4 [ default = 0 ];
};

message Response {
  repeated ModelOutput outputs = 1;
  repeated int64 profile_time = 2;
};

message ModelOutput {
  repeated Tensor tensor = 1;
  optional string engine_name = 2;
}

service GeneralModelService {
  rpc inference(Request) returns (Response) {}
  rpc debug(Request) returns (Response) {}
};
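The Request/Tensor layout above carries one feed variable per Tensor, with the payload placed in the repeated field selected by elem_type. A minimal sketch of how a client might populate it, assuming the file has been compiled with protoc into a module named general_model_service_pb2 (the module name used by httpclient.py later in this diff), with illustrative feed/fetch names from the fit_a_line example:

```python
from proto import general_model_service_pb2  # assumed generated module name

req = general_model_service_pb2.Request()
req.fetch_var_names.extend(["price"])  # illustrative fetch var name
req.log_id = 0                         # required field, defaults to 0

tensor = general_model_service_pb2.Tensor()
tensor.float_data.extend([0.0137, -0.1136, 0.2553])  # flattened values
tensor.elem_type = 1         # 1 means float32, per the comment above
tensor.shape.extend([1, 3])  # shape includes the batch dimension
tensor.name = "x"            # illustrative name from the model prototxt
tensor.alias_name = "x"
req.tensor.append(tensor)

body = req.SerializeToString()  # PROTO body, as sent over HTTP or GRPC
```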
@@ -20,12 +20,12 @@ package baidu.paddle_serving.predictor.general_model;
 option cc_generic_services = true;

 message Tensor {
-  repeated bytes data = 1;
+  repeated string data = 1;
   repeated int32 int_data = 2;
   repeated int64 int64_data = 3;
   repeated float float_data = 4;
   optional int32 elem_type =
-      5; // 0 means int64, 1 means float32, 2 means int32, 3 means bytes(string)
+      5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
   repeated int32 shape = 6; // shape should include batch
   repeated int32 lod = 7;   // only for fetch tensor currently
   optional string name = 8; // get from the Model prototxt
......
@@ -20,12 +20,12 @@ package baidu.paddle_serving.predictor.general_model;
 option cc_generic_services = true;

 message Tensor {
-  repeated bytes data = 1;
+  repeated string data = 1;
   repeated int32 int_data = 2;
   repeated int64 int64_data = 3;
   repeated float float_data = 4;
   optional int32 elem_type =
-      5; // 0 means int64, 1 means float32, 2 means int32, 3 means bytes(string)
+      5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
   repeated int32 shape = 6; // shape should include batch
   repeated int32 lod = 7;   // only for fetch tensor currently
   optional string name = 8; // get from the Model prototxt
......
@@ -2,7 +2,7 @@
 ([English](./README.md)|Simplified Chinese)

-### Development environment
+## Development environment

 To make development in Java easier, we ship a precompiled Serving project inside the Java image. To pull the image and enter the development environment:
@@ -15,7 +15,7 @@ cd Serving/java
 The Serving folder is the develop-branch project directory from when the image was built; run git pull to update it to the latest version, or git checkout the branch you want.

-### Install client dependencies
+## Install client dependencies

 Because the dependency libraries are numerous, the image was already compiled once when it was built; just run the following:
@@ -27,7 +27,34 @@ mvn compile
 mvn install
 ```

-### Start the server (Pipeline mode)
+## Requesting the BRPC-Server
+
+### Starting the server
+
+Taking the fit_a_line model as an example, the server is started with the same command as a regular BRPC-Server:
+```
+cd ../../python/examples/fit_a_line
+sh get_data.sh
+python -m paddle_serving_server.serve --model uci_housing_model --thread 10 --port 9393
+```
+
+### Client prediction
+
+The client supports several request modes: HTTP (JSON payload), HTTP (PROTO payload), and GRPC.
+
+We recommend HTTP with a PROTO payload: the body is in PROTO format, so less data is transferred and requests are faster. Wrapper functions for the HTTP/GRPC bodies (JSON/PROTO) are already provided; see [Client.java](./src/main/java/io/paddle/serving/client/Client.java) for details.
+```
+cd ../../../java/examples/target
+java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample http_proto <configPath>
+```
+**Note: <configPath> is the client configuration file, usually named serving_client_conf.prototxt.**
+
+For more examples, see [PaddleServingClientExample.java](./examples/src/main/java/PaddleServingClientExample.java).
+
+## Requesting the Pipeline-Server
+
+### Starting the server

 For input data type = string, taking the IMDB model ensemble as an example, start the server:
@@ -71,7 +98,7 @@
 ### Notes

-1. In the examples, all non-Pipeline models need `--use_multilang` to enable GRPC multi-language support, and the port is always 9393; if you need a different port, modify it in the Java file.
+1. In the examples, the port is always 9393 and the IP defaults to 0.0.0.0, meaning the local machine; make sure the IP and port match the Server side.
 2. Serving now provides a Pipeline mode (see [Pipeline Serving](../doc/PIPELINE_SERVING_CN.md) for how it works), and the Java-facing Pipeline Serving Client has been released.
@@ -84,5 +111,3 @@
 The first option runs GPU Serving and the Java Client in the same GPU image: after starting the GPU image, copy the files compiled in the Java image (under /Serving/java) into the /Serving/java directory of the GPU image.
 The second option deploys GPU Serving and the Java Client in separate docker images (or on separate hosts with build environments); in that case, just make sure the Java Client's IP and port match the GPU Serving side, as described in item 3 of the notes above.
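For reference, the http_json and grpc request modes in the README above are selected the same way, via the first argument to PaddleServingClientExample, as the main() dispatch in the Java example below shows:

```
cd ../../../java/examples/target
java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample http_json <configPath>
java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample grpc <configPath>
```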
@@ -11,7 +11,7 @@ import org.nd4j.linalg.factory.Nd4j;
 import java.util.*;

 public class PaddleServingClientExample {
-    boolean fit_a_line(String model_config_path) {
+    boolean http_proto(String model_config_path) {
         float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
             0.0582f, -0.0727f, -0.1583f, -0.0584f,
             0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
@@ -24,7 +24,7 @@ public class PaddleServingClientExample {
         }};
         List<String> fetch = Arrays.asList("price");

-        HttpClient client = new HttpClient();
+        Client client = new Client();
         client.setIP("0.0.0.0");
         client.setPort("9393");
         client.loadClientConfig(model_config_path);
@@ -34,6 +34,55 @@ public class PaddleServingClientExample {
         return true;
     }

+    boolean http_json(String model_config_path) {
+        float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
+            0.0582f, -0.0727f, -0.1583f, -0.0584f,
+            0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
+        INDArray npdata = Nd4j.createFromArray(data);
+        long[] batch_shape = {1,13};
+        INDArray batch_npdata = npdata.reshape(batch_shape);
+        HashMap<String, Object> feed_data
+            = new HashMap<String, Object>() {{
+                put("x", batch_npdata);
+            }};
+        List<String> fetch = Arrays.asList("price");
+
+        Client client = new Client();
+        // Note: across docker containers, set --net-host or use the other container's IP directly.
+        client.setIP("0.0.0.0");
+        client.setPort("9393");
+        client.set_http_proto(false);
+        client.loadClientConfig(model_config_path);
+        String result = client.predict(feed_data, fetch, true, 0);
+
+        System.out.println(result);
+        return true;
+    }
+
+    boolean grpc(String model_config_path) {
+        float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
+            0.0582f, -0.0727f, -0.1583f, -0.0584f,
+            0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
+        INDArray npdata = Nd4j.createFromArray(data);
+        long[] batch_shape = {1,13};
+        INDArray batch_npdata = npdata.reshape(batch_shape);
+        HashMap<String, Object> feed_data
+            = new HashMap<String, Object>() {{
+                put("x", batch_npdata);
+            }};
+        List<String> fetch = Arrays.asList("price");
+
+        Client client = new Client();
+        client.setIP("0.0.0.0");
+        client.setPort("9393");
+        client.loadClientConfig(model_config_path);
+        client.set_use_grpc_client(true);
+        String result = client.predict(feed_data, fetch, true, 0);
+
+        System.out.println(result);
+        return true;
+    }
+
     boolean encrypt(String model_config_path,String keyFilePath) {
         float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
             0.0582f, -0.0727f, -0.1583f, -0.0584f,
@@ -47,7 +96,7 @@ public class PaddleServingClientExample {
         }};
         List<String> fetch = Arrays.asList("price");

-        HttpClient client = new HttpClient();
+        Client client = new Client();
         client.setIP("0.0.0.0");
         client.setPort("9393");
         client.loadClientConfig(model_config_path);
@@ -55,7 +104,6 @@ public class PaddleServingClientExample {
         try {
             Thread.sleep(1000*3); // sleep 3 seconds, waiting for the Server to start
         } catch (Exception e) {
-            //TODO: handle exception
         }

         String result = client.predict(feed_data, fetch, true, 0);
@@ -76,7 +124,7 @@ public class PaddleServingClientExample {
         }};
         List<String> fetch = Arrays.asList("price");

-        HttpClient client = new HttpClient();
+        Client client = new Client();
         client.setIP("0.0.0.0");
         client.setPort("9393");
         client.loadClientConfig(model_config_path);
@@ -127,7 +175,7 @@ public class PaddleServingClientExample {
             put("im_size", batch_im_size);
         }};
         List<String> fetch = Arrays.asList("save_infer_model/scale_0.tmp_0");
-        HttpClient client = new HttpClient();
+        Client client = new Client();
         client.setIP("0.0.0.0");
         client.setPort("9393");
         client.loadClientConfig(model_config_path);
@@ -149,7 +197,7 @@ public class PaddleServingClientExample {
             put("segment_ids", Nd4j.createFromArray(segment_ids));
         }};
         List<String> fetch = Arrays.asList("pooled_output");
-        HttpClient client = new HttpClient();
+        Client client = new Client();
         client.setIP("0.0.0.0");
         client.setPort("9393");
         client.loadClientConfig(model_config_path);
@@ -219,7 +267,7 @@ public class PaddleServingClientExample {
             put("embedding_0.tmp_0", Nd4j.createFromArray(embedding_0));
         }};
         List<String> fetch = Arrays.asList("prob");
-        HttpClient client = new HttpClient();
+        Client client = new Client();
         client.setIP("0.0.0.0");
         client.setPort("9393");
         client.loadClientConfig(model_config_path);
@@ -236,13 +284,17 @@ public class PaddleServingClientExample {
         if (args.length < 2) {
             System.out.println("Usage: java -cp <jar> PaddleServingClientExample <test-type> <configPath>.");
-            System.out.println("<test-type>: fit_a_line bert cube_local yolov4 encrypt");
+            System.out.println("<test-type>: http_proto http_json grpc bert cube_local yolov4 encrypt");
             return;
         }
         String testType = args[0];
         System.out.format("[Example] %s\n", testType);
-        if ("fit_a_line".equals(testType)) {
-            succ = e.fit_a_line(args[1]);
+        if ("http_proto".equals(testType)) {
+            succ = e.http_proto(args[1]);
+        } else if ("http_json".equals(testType)) {
+            succ = e.http_json(args[1]);
+        } else if ("grpc".equals(testType)) {
+            succ = e.grpc(args[1]);
         } else if ("compress".equals(testType)) {
             succ = e.compress(args[1]);
         } else if ("bert".equals(testType)) {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package baidu.paddle_serving.predictor.general_model;
option java_multiple_files = true;
message Tensor {
  repeated string data = 1;
  repeated int32 int_data = 2;
  repeated int64 int64_data = 3;
  repeated float float_data = 4;
  optional int32 elem_type =
      5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
  repeated int32 shape = 6;       // shape should include batch
  repeated int32 lod = 7;         // only for fetch tensor currently
  optional string name = 8;       // get from the Model prototxt
  optional string alias_name = 9; // get from the Model prototxt
};

message Request {
  repeated Tensor tensor = 1;
  repeated string fetch_var_names = 2;
  optional bool profile_server = 3 [ default = false ];
  required uint64 log_id = 4 [ default = 0 ];
};

message Response {
  repeated ModelOutput outputs = 1;
  repeated int64 profile_time = 2;
};

message ModelOutput {
  repeated Tensor tensor = 1;
  optional string engine_name = 2;
}

service GeneralModelService {
  rpc inference(Request) returns (Response) {}
  rpc debug(Request) returns (Response) {}
};
@@ -13,24 +13,41 @@
 # limitations under the License.
 # pylint: disable=doc-string-missing

-from paddle_serving_client.httpclient import HttpClient
+from paddle_serving_client.httpclient import GeneralClient
 import sys
 import numpy as np
 import time

-client = HttpClient()
+client = GeneralClient()
 client.load_client_config(sys.argv[1])
-# if you want to enable Encrypt Module, uncomment the following line
-# client.use_key("./key")
-client.set_response_compress(True)
-client.set_request_compress(True)
-fetch_list = client.get_fetch_names()
+'''
+if you want to use the GRPC client, call set_use_grpc_client(True),
+or you can directly use client.grpc_client_predict(...);
+as for the HTTP client, call set_use_grpc_client(False) (which is the default),
+or you can directly use client.http_client_predict(...)
+'''
+#client.set_use_grpc_client(True)
+'''
+if you want to enable the Encrypt Module, uncomment the following line
+'''
+#client.use_key("./key")
+'''
+if you want to compress, uncomment the following lines
+'''
+#client.set_response_compress(True)
+#client.set_request_compress(True)
+'''
+we recommend the Proto data format in the HTTP body, set True (which is the default);
+if you want the JSON data format in the HTTP body, set False
+'''
+#client.set_http_proto(True)

 import paddle
 test_reader = paddle.batch(
     paddle.reader.shuffle(
         paddle.dataset.uci_housing.test(), buf_size=500),
     batch_size=1)
+fetch_list = client.get_fetch_names()

 for data in test_reader():
     new_data = np.zeros((1, 13)).astype("float32")
     new_data[0] = data[0][0]
......
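For orientation, the three request paths that GeneralClient exposes, per the comments in the example above, look like the following. This is a minimal sketch that reuses the `client` object from the example, with illustrative fit_a_line feed/fetch values and a server assumed on the default port:

```python
import numpy as np

feed = {"x": np.zeros((1, 13)).astype("float32")}
fetch = ["price"]

# default path: predict() dispatches to HTTP with a Proto body
result = client.predict(feed=feed, fetch=fetch, batch=True)

# explicit HTTP call; the body is JSON if set_http_proto(False) was called
result = client.http_client_predict(feed=feed, fetch=fetch, batch=True)

# explicit GRPC call; no set_use_grpc_client(True) needed on this path
result = client.grpc_client_predict(feed=feed, fetch=fetch, batch=True)
```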
@@ -345,68 +345,69 @@ class Client(object):
             raise ValueError(
                 "Fetch names should not be empty or out of saved fetch list.")

-        feed_i = feed_batch[0]
-        for key in feed_i:
+        feed_dict = feed_batch[0]
+        for key in feed_dict:
             if ".lod" not in key and key not in self.feed_names_:
                 raise ValueError("Wrong feed name: {}.".format(key))
             if ".lod" in key:
                 continue
-            self.shape_check(feed_i, key)
+            self.shape_check(feed_dict, key)
             if self.feed_types_[key] in int_type:
                 int_feed_names.append(key)
                 shape_lst = []
                 if batch == False:
-                    feed_i[key] = np.expand_dims(feed_i[key], 0).repeat(
+                    feed_dict[key] = np.expand_dims(feed_dict[key], 0).repeat(
                         1, axis=0)
-                if isinstance(feed_i[key], np.ndarray):
-                    shape_lst.extend(list(feed_i[key].shape))
+                if isinstance(feed_dict[key], np.ndarray):
+                    shape_lst.extend(list(feed_dict[key].shape))
                     int_shape.append(shape_lst)
                 else:
                     int_shape.append(self.feed_shapes_[key])
-                if "{}.lod".format(key) in feed_i:
-                    int_lod_slot_batch.append(feed_i["{}.lod".format(key)])
+                if "{}.lod".format(key) in feed_dict:
+                    int_lod_slot_batch.append(feed_dict["{}.lod".format(key)])
                 else:
                     int_lod_slot_batch.append([])
-                if isinstance(feed_i[key], np.ndarray):
-                    int_slot.append(np.ascontiguousarray(feed_i[key]))
+                if isinstance(feed_dict[key], np.ndarray):
+                    int_slot.append(np.ascontiguousarray(feed_dict[key]))
                     self.has_numpy_input = True
                 else:
-                    int_slot.append(np.ascontiguousarray(feed_i[key]))
+                    int_slot.append(np.ascontiguousarray(feed_dict[key]))
                     self.all_numpy_input = False
             elif self.feed_types_[key] in float_type:
                 float_feed_names.append(key)
                 shape_lst = []
                 if batch == False:
-                    feed_i[key] = np.expand_dims(feed_i[key], 0).repeat(
+                    feed_dict[key] = np.expand_dims(feed_dict[key], 0).repeat(
                         1, axis=0)
-                if isinstance(feed_i[key], np.ndarray):
-                    shape_lst.extend(list(feed_i[key].shape))
+                if isinstance(feed_dict[key], np.ndarray):
+                    shape_lst.extend(list(feed_dict[key].shape))
                     float_shape.append(shape_lst)
                 else:
                     float_shape.append(self.feed_shapes_[key])
-                if "{}.lod".format(key) in feed_i:
-                    float_lod_slot_batch.append(feed_i["{}.lod".format(key)])
+                if "{}.lod".format(key) in feed_dict:
+                    float_lod_slot_batch.append(feed_dict["{}.lod".format(key)])
                 else:
                     float_lod_slot_batch.append([])
-                if isinstance(feed_i[key], np.ndarray):
-                    float_slot.append(np.ascontiguousarray(feed_i[key]))
+                if isinstance(feed_dict[key], np.ndarray):
+                    float_slot.append(np.ascontiguousarray(feed_dict[key]))
                     self.has_numpy_input = True
                 else:
-                    float_slot.append(np.ascontiguousarray(feed_i[key]))
+                    float_slot.append(np.ascontiguousarray(feed_dict[key]))
                     self.all_numpy_input = False
             #if input is string, feed is not numpy.
             elif self.feed_types_[key] in string_type:
                 string_feed_names.append(key)
                 string_shape.append(self.feed_shapes_[key])
-                if "{}.lod".format(key) in feed_i:
-                    string_lod_slot_batch.append(feed_i["{}.lod".format(key)])
+                if "{}.lod".format(key) in feed_dict:
+                    string_lod_slot_batch.append(feed_dict["{}.lod".format(
+                        key)])
                 else:
                     string_lod_slot_batch.append([])
-                string_slot.append(feed_i[key])
+                string_slot.append(feed_dict[key])
                 self.has_numpy_input = True
         self.profile_.record('py_prepro_1')
......
@@ -21,7 +21,13 @@ import google.protobuf.text_format
 import gzip
 from collections import Iterable
 import base64
+import sys
+import grpc
+from .proto import general_model_service_pb2
+sys.path.append(
+    os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
+from .proto import general_model_service_pb2_grpc

 #param 'type'(which is in feed_var or fetch_var) = 0 means dataType is int64
 #param 'type'(which is in feed_var or fetch_var) = 1 means dataType is float32
 #param 'type'(which is in feed_var or fetch_var) = 2 means dataType is int32
@@ -60,7 +66,14 @@ def data_bytes_number(datalist):
     return total_bytes_number

-class HttpClient(object):
+# This file is named httpclient.py for now; whether it should replace client.py
+# will be decided after further testing.
+# HTTP is used by default, with Proto in the HTTP body by default.
+# To use JSON in the HTTP body, call set_http_proto(False).
+# Predict() wraps http_client_predict/grpc_client_predict;
+# you can also call http_client_predict/grpc_client_predict directly.
+# For example, to use GRPC, call set_use_grpc_client(True),
+# or call grpc_client_predict() directly.
+class GeneralClient(object):
     def __init__(self,
                  ip="0.0.0.0",
                  port="9393",
@@ -71,7 +84,7 @@ class HttpClient(object):
         self.feed_shapes_ = {}
         self.feed_types_ = {}
         self.feed_names_to_idx_ = {}
-        self.http_timeout_ms = 200000
+        self.timeout_ms = 200000
         self.ip = ip
         self.port = port
         self.server_port = port
@@ -79,6 +92,10 @@ class HttpClient(object):
         self.key = None
         self.try_request_gzip = False
         self.try_response_gzip = False
+        self.total_data_number = 0
+        self.http_proto = True
+        self.max_body_size = 512 * 1024 * 1024
+        self.use_grpc_client = False

     def load_client_config(self, model_config_path_list):
         if isinstance(model_config_path_list, str):
@@ -136,11 +153,14 @@ class HttpClient(object):
                     self.lod_tensor_set.add(var.alias_name)
         return

-    def set_http_timeout_ms(self, http_timeout_ms):
-        if not isinstance(http_timeout_ms, int):
-            raise ValueError("http_timeout_ms must be int type.")
+    def set_max_body_size(self, max_body_size):
+        self.max_body_size = max_body_size
+
+    def set_timeout_ms(self, timeout_ms):
+        if not isinstance(timeout_ms, int):
+            raise ValueError("timeout_ms must be int type.")
         else:
-            self.http_timeout_ms = http_timeout_ms
+            self.timeout_ms = timeout_ms

     def set_ip(self, ip):
         self.ip = ip
@@ -157,6 +177,12 @@ class HttpClient(object):
     def set_response_compress(self, try_response_gzip):
         self.try_response_gzip = try_response_gzip

+    def set_http_proto(self, http_proto):
+        self.http_proto = http_proto
+
+    def set_use_grpc_client(self, use_grpc_client):
+        self.use_grpc_client = use_grpc_client
+
     # use_key is the function of encryption.
     def use_key(self, key_filename):
         with open(key_filename, "rb") as f:
@@ -171,7 +197,6 @@ class HttpClient(object):
         req = json.dumps({})
         r = requests.post(encrypt_url, req)
         result = r.json()
-        print(result)
         if "endpoint_list" not in result:
             raise ValueError("server not ready")
         else:
@@ -184,15 +209,8 @@ class HttpClient(object):
     def get_fetch_names(self):
         return self.fetch_names_

-    # feed supports Numpy types, as well as plain List and tuple;
-    # str is not supported, because the proto uses repeated fields.
-    def predict(self,
-                feed=None,
-                fetch=None,
-                batch=False,
-                need_variant_tag=False,
-                log_id=0):
-        if feed is None or fetch is None:
+    def get_legal_fetch(self, fetch):
+        if fetch is None:
             raise ValueError("You should specify feed and fetch for prediction")

         fetch_list = []
@@ -203,27 +221,6 @@ class HttpClient(object):
         else:
             raise ValueError("Fetch only accepts string and list of string")

-        feed_batch = []
-        if isinstance(feed, dict):
-            feed_batch.append(feed)
-        elif isinstance(feed, (list, str, tuple)):
-            # if input is a list or str or tuple, and the number of feed_var is 1,
-            # create a temp_dict { key = feed_var_name, value = list }
-            # and put the temp_dict into the feed_batch.
-            if len(self.feed_names_) != 1:
-                raise ValueError(
-                    "input is a list, but we got 0 or 2+ feed_var, don`t know how to divide the feed list"
-                )
-            temp_dict = {}
-            temp_dict[self.feed_names_[0]] = feed
-            feed_batch.append(temp_dict)
-        else:
-            raise ValueError("Feed only accepts dict and list of dict")
-
-        # batch_size must be 1, cause batch is already in Tensor.
-        if len(feed_batch) != 1:
-            raise ValueError("len of feed_batch can only be 1.")
-
         fetch_names = []
         for key in fetch_list:
             if key in self.fetch_names_:
@@ -233,51 +230,137 @@ class HttpClient(object):
             raise ValueError(
                 "Fetch names should not be empty or out of saved fetch list.")
             return {}
+        return fetch_names

-        feed_i = feed_batch[0]
+    def get_feedvar_dict(self, feed):
+        if feed is None:
+            raise ValueError("You should specify feed and fetch for prediction")
+        feed_dict = {}
+        if isinstance(feed, dict):
+            feed_dict = feed
+        elif isinstance(feed, (list, str, tuple)):
+            # if the input is a list, str or tuple and there is exactly one feed_var,
+            # create a feed_dict { key = feed_var_name, value = list }
+            if len(self.feed_names_) == 1:
+                feed_dict[self.feed_names_[0]] = feed
+            elif len(self.feed_names_) > 1:
+                if isinstance(feed, str):
+                    raise ValueError(
+                        "input is a str, but we got 2+ feed_var, don`t know how to divide the string"
+                    )
+                # feed is a list or tuple
+                elif len(self.feed_names_) == len(feed):
+                    for index in range(len(feed)):
+                        feed_dict[self.feed_names_[index]] = feed[index]
+                else:
+                    raise ValueError("len(feed) ≠ len(feed_var), error")
+            else:
+                raise ValueError("we got feed, but feed_var is None")
+        else:
+            raise ValueError("Feed only accepts dict/str/list/tuple")
+        return feed_dict

+    def process_json_data(self, feed_dict, fetch_list, batch, log_id):
         Request = {}
         Request["fetch_var_names"] = fetch_list
         Request["log_id"] = int(log_id)
         Request["tensor"] = []
-        index = 0
-        total_data_number = 0
-        for key in feed_i:
+        for key in feed_dict:
             if ".lod" not in key and key not in self.feed_names_:
                 raise ValueError("Wrong feed name: {}.".format(key))
             if ".lod" in key:
                 continue
-            Request["tensor"].append('')
-            Request["tensor"][index] = {}
+            tensor_dict = self.process_tensor(key, feed_dict, batch)
+            data_key = tensor_dict["data_key"]
+            data_value = tensor_dict["data_value"]
+            tensor = {}
+            tensor[data_key] = data_value
+            tensor["shape"] = tensor_dict["shape"]
+            tensor["elem_type"] = tensor_dict["elem_type"]
+            tensor["name"] = tensor_dict["name"]
+            tensor["alias_name"] = tensor_dict["alias_name"]
+            if "lod" in tensor_dict:
+                tensor["lod"] = tensor_dict["lod"]
+            Request["tensor"].append(tensor)
+        # request
+        postData = json.dumps(Request)
+        return postData
+
+    def process_proto_data(self, feed_dict, fetch_list, batch, log_id):
+        req = general_model_service_pb2.Request()
+        req.fetch_var_names.extend(fetch_list)
+        req.log_id = log_id
+        for key in feed_dict:
+            tensor = general_model_service_pb2.Tensor()
+            if ".lod" not in key and key not in self.feed_names_:
+                raise ValueError("Wrong feed name: {}.".format(key))
+            if ".lod" in key:
+                continue
+            tensor_dict = self.process_tensor(key, feed_dict, batch)
+            tensor.shape.extend(tensor_dict["shape"])
+            tensor.name = tensor_dict["name"]
+            tensor.alias_name = tensor_dict["alias_name"]
+            tensor.elem_type = tensor_dict["elem_type"]
+            if "lod" in tensor_dict:
+                tensor.lod.extend(tensor_dict["lod"])
+            if tensor_dict["data_key"] == "int64_data":
+                tensor.int64_data.extend(tensor_dict["data_value"])
+            elif tensor_dict["data_key"] == "float_data":
+                tensor.float_data.extend(tensor_dict["data_value"])
+            elif tensor_dict["data_key"] == "int_data":
+                tensor.int_data.extend(tensor_dict["data_value"])
+            elif tensor_dict["data_key"] == "data":
+                tensor.data.extend(tensor_dict["data_value"])
+            else:
+                raise ValueError(
+                    "tensor element_type must be one of [int64_data,float_data,int_data,data]."
+                )
+            req.tensor.append(tensor)
+        return req
+
+    def process_tensor(self, key, feed_dict, batch):
         lod = []
-        if "{}.lod".format(key) in feed_i:
-            lod = feed_i["{}.lod".format(key)]
+        if "{}.lod".format(key) in feed_dict:
+            lod = feed_dict["{}.lod".format(key)]
         shape = self.feed_shapes_[key].copy()
         elem_type = self.feed_types_[key]
-        data_value = feed_i[key]
+        data_value = feed_dict[key]
         data_key = proto_data_key_list[elem_type]
+        proto_index = self.feed_names_to_idx_[key]
+        name = self.feed_real_names[proto_index]
+        alias_name = key

-        # feed_i[key] can be an np.ndarray,
+        # feed_dict[key] can be an np.ndarray,
         # or a list or tuple;
         # an np.ndarray has to be converted to a list
-        if isinstance(feed_i[key], np.ndarray):
+        if isinstance(feed_dict[key], np.ndarray):
             shape_lst = []
             # a 0-dim numpy array needs an extra [] around it
-            if feed_i[key].ndim == 0:
-                data_value = [feed_i[key].tolist()]
+            if feed_dict[key].ndim == 0:
+                data_value = [feed_dict[key].tolist()]
                 shape_lst.append(1)
             else:
-                shape_lst.extend(list(feed_i[key].shape))
+                shape_lst.extend(list(feed_dict[key].shape))
                 shape = shape_lst
-                data_value = feed_i[key].flatten().tolist()
+                data_value = feed_dict[key].flatten().tolist()
             # when batch is False, insert a 1 in front of the shape as the batch dim;
             # when batch is True, use numpy.shape directly as the batch dim
             if batch == False:
                 shape.insert(0, 1)
         # for a list or tuple, the nested layers have to be flattened
-        elif isinstance(feed_i[key], (list, tuple)):
+        elif isinstance(feed_dict[key], (list, tuple)):
             # when batch is False, insert a 1 in front of the shape as the batch dim;
             # when batch is True, a list is not as regular as numpy, so the shape
             # cannot be derived; take the first dimension as the batch dim.
@@ -285,16 +368,16 @@ class HttpClient(object):
             if batch == False:
                 shape.insert(0, 1)
             else:
-                shape.insert(0, len(feed_i[key]))
-            feed_i[key] = [x for x in list_flatten(feed_i[key])]
-            data_value = feed_i[key]
+                shape.insert(0, len(feed_dict[key]))
+            feed_dict[key] = [x for x in list_flatten(feed_dict[key])]
+            data_value = feed_dict[key]
         else:
             # the input may be a single str or int value, etc.;
             # first normalize it into a list;
             # since such input is special, keep the shape from the original feedvar
             data_value = []
-            data_value.append(feed_i[key])
-            if isinstance(feed_i[key], str):
+            data_value.append(feed_dict[key])
+            if isinstance(feed_dict[key], str):
                 if self.feed_types_[key] != bytes_type:
                     raise ValueError(
                         "feedvar is not string-type,feed can`t be a single string."
@@ -306,34 +389,114 @@ class HttpClient(object):
                     )
         # if we do not compress, there is no need to count the data size.
         if self.try_request_gzip:
-            total_data_number = total_data_number + data_bytes_number(
+            self.total_data_number = self.total_data_number + data_bytes_number(
                 data_value)
-        Request["tensor"][index]["elem_type"] = elem_type
-        Request["tensor"][index]["shape"] = shape
-        Request["tensor"][index][data_key] = data_value
-        proto_index = self.feed_names_to_idx_[key]
-        Request["tensor"][index]["name"] = self.feed_real_names[proto_index]
-        Request["tensor"][index]["alias_name"] = key
-        if len(lod) > 0:
-            Request["tensor"][index]["lod"] = lod
-        index = index + 1
+        tensor_dict = {}
+        tensor_dict["data_key"] = data_key
+        tensor_dict["data_value"] = data_value
+        tensor_dict["shape"] = shape
+        tensor_dict["elem_type"] = elem_type
+        tensor_dict["name"] = name
+        tensor_dict["alias_name"] = alias_name
+        if len(lod) > 0:
+            tensor_dict["lod"] = lod
+        return tensor_dict

-        result = None
-        # request
-        web_url = "http://" + self.ip + ":" + self.server_port + self.service_name
-        postData = json.dumps(Request)
+    # feed must be a dict, List, tuple or string;
+    # the data in feed supports Numpy, list, tuple and basic types;
+    # by default, fetch takes all fetch_var from the model config file
+    def predict(self,
+                feed=None,
+                fetch=None,
+                batch=False,
+                need_variant_tag=False,
+                log_id=0):
+        if self.use_grpc_client:
+            return self.grpc_client_predict(feed, fetch, batch,
+                                            need_variant_tag, log_id)
+        else:
+            return self.http_client_predict(feed, fetch, batch,
+                                            need_variant_tag, log_id)
+
+    def http_client_predict(self,
+                            feed=None,
+                            fetch=None,
+                            batch=False,
+                            need_variant_tag=False,
+                            log_id=0):
+        feed_dict = self.get_feedvar_dict(feed)
+        fetch_list = self.get_legal_fetch(fetch)
         headers = {}
+        postData = ''
+        if self.http_proto == True:
+            postData = self.process_proto_data(feed_dict, fetch_list, batch,
+                                               log_id).SerializeToString()
+            headers["Content-Type"] = "application/proto"
+        else:
+            postData = self.process_json_data(feed_dict, fetch_list, batch,
+                                              log_id)
+            headers["Content-Type"] = "application/json"
+        web_url = "http://" + self.ip + ":" + self.server_port + self.service_name
         # compress only when the data section is larger than 512 bytes.
-        if self.try_request_gzip and total_data_number > 512:
-            postData = gzip.compress(bytes(postData, 'utf-8'))
-            headers["Content-Encoding"] = "gzip"
-        if self.try_response_gzip:
-            headers["Accept-encoding"] = "gzip"
+        try:
+            if self.try_request_gzip and self.total_data_number > 512:
+                origin_data = postData
+                postData = gzip.compress(bytes(postData, 'utf-8'))
+                headers["Content-Encoding"] = "gzip"
+            if self.try_response_gzip:
+                headers["Accept-encoding"] = "gzip"
+        # compression failed, fall back to the original data
+        except:
+            print("compress error, we will use the no-compress data")
+            headers.pop("Content-Encoding", "nokey")
+            postData = origin_data

         # requests can detect and decompress the response automatically
-        result = requests.post(url=web_url, headers=headers, data=postData)
+        try:
+            result = requests.post(url=web_url, headers=headers, data=postData)
+        except:
+            print("http post error")
+            return None
+        else:
+            if result == None:
+                return None
+            if result.status_code == 200:
+                if result.headers["Content-Type"] == 'application/proto':
+                    response = general_model_service_pb2.Response()
+                    response.ParseFromString(result.content)
+                    return response
+                else:
+                    return result.json()
+            return result

-        if result == None:
-            return None
-        if result.status_code == 200:
-            return result.json()
-        return result
+    def grpc_client_predict(self,
+                            feed=None,
+                            fetch=None,
+                            batch=False,
+                            need_variant_tag=False,
+                            log_id=0):
+        feed_dict = self.get_feedvar_dict(feed)
+        fetch_list = self.get_legal_fetch(fetch)
+        postData = self.process_proto_data(feed_dict, fetch_list, batch, log_id)
+        # https://github.com/tensorflow/serving/issues/1382
+        options = [('grpc.max_receive_message_length', self.max_body_size),
+                   ('grpc.max_send_message_length', self.max_body_size)]
+        endpoints = [self.ip + ":" + self.server_port]
+        g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
+        self.channel_ = grpc.insecure_channel(g_endpoint, options=options)
+        self.stub_ = general_model_service_pb2_grpc.GeneralModelServiceStub(
+            self.channel_)
+        try:
+            resp = self.stub_.inference(postData, timeout=self.timeout_ms)
+        except:
+            print("Grpc inference error occur")
+            return None
+        else:
+            return resp
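The grpc path above serializes the same proto Request as the HTTP-proto path and opens an insecure channel per call with raised message-size limits. A standalone sketch of the equivalent stub call, assuming the generated general_model_service_pb2/general_model_service_pb2_grpc modules; note that grpc-python's timeout parameter is in seconds, not milliseconds:

```python
import grpc
from proto import general_model_service_pb2, general_model_service_pb2_grpc  # assumed modules

max_body_size = 512 * 1024 * 1024  # matches the client default above
options = [('grpc.max_receive_message_length', max_body_size),
           ('grpc.max_send_message_length', max_body_size)]

# 'ipv4:host:port' is the resolver-style target format used above
channel = grpc.insecure_channel('ipv4:0.0.0.0:9393', options=options)
stub = general_model_service_pb2_grpc.GeneralModelServiceStub(channel)

req = general_model_service_pb2.Request()  # populate as in process_proto_data
req.log_id = 0
resp = stub.inference(req, timeout=200)  # deadline in seconds in grpc-python
```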
@@ -107,9 +107,13 @@ def is_gpu_mode(unformatted_gpus):
 def serve_args():
     parser = argparse.ArgumentParser("serve")
     parser.add_argument(
-        "--thread", type=int, default=2, help="Concurrency of server")
+        "--thread",
+        type=int,
+        default=4,
+        help="Concurrency of server, [4,1024]",
+        choices=range(4, 1025))
     parser.add_argument(
-        "--port", type=int, default=9292, help="Port of the starting gpu")
+        "--port", type=int, default=9393, help="Port of the starting gpu")
     parser.add_argument(
         "--device", type=str, default="cpu", help="Type of device")
     parser.add_argument(
......
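With the new `choices=range(4, 1025)` constraint, argparse rejects `--thread` values below 4 or above 1024 at startup. A start command consistent with the new defaults, reusing the uci_housing_model example from the README above, might look like:

```
python -m paddle_serving_server.serve --model uci_housing_model --thread 16 --port 9393
```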
@@ -123,7 +123,7 @@ class WebService(object):
                  workdir,
                  port=9292,
                  gpus=None,
-                 thread_num=2,
+                 thread_num=4,
                  mem_optim=True,
                  use_lite=False,
                  use_xpu=False,
......
@@ -236,7 +236,7 @@ class WebService(object):
                  use_lite=False,
                  use_xpu=False,
                  ir_optim=False,
-                 thread_num=2,
+                 thread_num=4,
                  mem_optim=True,
                  use_trt=False,
                  gpu_multi_stream=False,
......
@@ -20,12 +20,12 @@ package baidu.paddle_serving.predictor.general_model;
 option cc_generic_services = true;

 message Tensor {
-  repeated bytes data = 1;
+  repeated string data = 1;
   repeated int32 int_data = 2;
   repeated int64 int64_data = 3;
   repeated float float_data = 4;
   optional int32 elem_type =
-      5; // 0 means int64, 1 means float32, 2 means int32, 3 means bytes(string)
+      5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
   repeated int32 shape = 6; // shape should include batch
   repeated int32 lod = 7;   // only for fetch tensor currently
   optional string name = 8; // get from the Model prototxt
......