未验证 提交 50beec12 编写于 作者: T TeslaZhao 提交者: GitHub

Merge pull request #1345 from HexToString/grpc_update

增加GRPC直接请求和HTTP(PROTO)
...@@ -39,11 +39,11 @@ INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR}) ...@@ -39,11 +39,11 @@ INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog") set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog")
if(WITH_LITE) if(WITH_LITE)
set(BRPC_REPO "https://github.com/zhangjun/incubator-brpc.git") set(BRPC_REPO "https://github.com/apache/incubator-brpc")
set(BRPC_TAG "master") set(BRPC_TAG "1.0.0-rc01")
else() else()
set(BRPC_REPO "https://github.com/wangjiawei04/brpc") set(BRPC_REPO "https://github.com/apache/incubator-brpc")
set(BRPC_TAG "6d79e0b17f25107c35b705ea58d888083f59ff47") set(BRPC_TAG "1.0.0-rc01")
endif() endif()
# If minimal .a is need, you can set WITH_DEBUG_SYMBOLS=OFF # If minimal .a is need, you can set WITH_DEBUG_SYMBOLS=OFF
......
...@@ -33,7 +33,9 @@ if (WITH_PYTHON) ...@@ -33,7 +33,9 @@ if (WITH_PYTHON)
add_custom_target(general_model_config_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(general_model_config_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(general_model_config_py_proto general_model_config_py_proto_init) add_dependencies(general_model_config_py_proto general_model_config_py_proto_init)
py_grpc_proto_compile(general_model_service_py_proto SRCS proto/general_model_service.proto)
add_custom_target(general_model_service_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(general_model_service_py_proto general_model_service_py_proto_init)
if (CLIENT) if (CLIENT)
py_proto_compile(sdk_configure_py_proto SRCS proto/sdk_configure.proto) py_proto_compile(sdk_configure_py_proto SRCS proto/sdk_configure.proto)
...@@ -50,7 +52,12 @@ if (WITH_PYTHON) ...@@ -50,7 +52,12 @@ if (WITH_PYTHON)
COMMAND cp -f *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto COMMAND cp -f *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto
COMMENT "Copy generated general_model_config proto file into directory paddle_serving_client/proto." COMMENT "Copy generated general_model_config proto file into directory paddle_serving_client/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_command(TARGET general_model_service_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto
COMMAND cp -f *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto
COMMENT "Copy generated general_model_service proto file into directory paddle_serving_client/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif() endif()
...@@ -77,6 +84,12 @@ if (WITH_PYTHON) ...@@ -77,6 +84,12 @@ if (WITH_PYTHON)
COMMAND cp -f *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto COMMAND cp -f *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
COMMENT "Copy generated general_model_config proto file into directory paddle_serving_server/proto." COMMENT "Copy generated general_model_config proto file into directory paddle_serving_server/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_command(TARGET general_model_service_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
COMMAND cp -f *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
COMMENT "Copy generated general_model_service proto file into directory paddle_serving_server/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif() endif()
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package baidu.paddle_serving.predictor.general_model;
option java_multiple_files = true;
message Tensor {
repeated string data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt
optional string alias_name = 9; // get from the Model prototxt
};
message Request {
repeated Tensor tensor = 1;
repeated string fetch_var_names = 2;
optional bool profile_server = 3 [ default = false ];
required uint64 log_id = 4 [ default = 0 ];
};
message Response {
repeated ModelOutput outputs = 1;
repeated int64 profile_time = 2;
};
message ModelOutput {
repeated Tensor tensor = 1;
optional string engine_name = 2;
}
service GeneralModelService {
rpc inference(Request) returns (Response) {}
rpc debug(Request) returns (Response) {}
};
...@@ -20,12 +20,12 @@ package baidu.paddle_serving.predictor.general_model; ...@@ -20,12 +20,12 @@ package baidu.paddle_serving.predictor.general_model;
option cc_generic_services = true; option cc_generic_services = true;
message Tensor { message Tensor {
repeated bytes data = 1; repeated string data = 1;
repeated int32 int_data = 2; repeated int32 int_data = 2;
repeated int64 int64_data = 3; repeated int64 int64_data = 3;
repeated float float_data = 4; repeated float float_data = 4;
optional int32 elem_type = optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means bytes(string) 5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
repeated int32 shape = 6; // shape should include batch repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt optional string name = 8; // get from the Model prototxt
......
...@@ -20,12 +20,12 @@ package baidu.paddle_serving.predictor.general_model; ...@@ -20,12 +20,12 @@ package baidu.paddle_serving.predictor.general_model;
option cc_generic_services = true; option cc_generic_services = true;
message Tensor { message Tensor {
repeated bytes data = 1; repeated string data = 1;
repeated int32 int_data = 2; repeated int32 int_data = 2;
repeated int64 int64_data = 3; repeated int64 int64_data = 3;
repeated float float_data = 4; repeated float float_data = 4;
optional int32 elem_type = optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means bytes(string) 5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
repeated int32 shape = 6; // shape should include batch repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt optional string name = 8; // get from the Model prototxt
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
([English](./README.md)|简体中文) ([English](./README.md)|简体中文)
### 开发环境 ## 开发环境
为了方便用户使用java进行开发,我们提供了编译好的Serving工程放置在java镜像当中,获取镜像并进入开发环境的方式是 为了方便用户使用java进行开发,我们提供了编译好的Serving工程放置在java镜像当中,获取镜像并进入开发环境的方式是
...@@ -15,7 +15,7 @@ cd Serving/java ...@@ -15,7 +15,7 @@ cd Serving/java
Serving文件夹是镜像生成时的develop分支工程目录,需要git pull 到最新版本,或者git checkout 到想要的分支。 Serving文件夹是镜像生成时的develop分支工程目录,需要git pull 到最新版本,或者git checkout 到想要的分支。
### 安装客户端依赖 ## 安装客户端依赖
由于依赖库数量庞大,因此镜像已经在生成时编译过一次,用户执行以下操作即可 由于依赖库数量庞大,因此镜像已经在生成时编译过一次,用户执行以下操作即可
...@@ -27,7 +27,34 @@ mvn compile ...@@ -27,7 +27,34 @@ mvn compile
mvn install mvn install
``` ```
### 启动服务端(Pipeline方式) ## 请求BRPC-Server
###服务端启动
以fit_a_line模型为例,服务端启动与常规BRPC-Server端启动命令一样。
```
cd ../../python/examples/fit_a_line
sh get_data.sh
python -m paddle_serving_server.serve --model uci_housing_model --thread 10 --port 9393
```
###客户端预测
客户端目前支持多种请求方式,目前支持HTTP(数据为JSON格式)、HTTP(数据为PROTO格式)、GRPC
推荐您使用HTTP(数据为PROTO格式),此时数据体为PROTO格式,传输的数据量小,速度快,目前已经帮用户实现了HTTP/GRPC的数据体(JSON/PROTO)的封装函数,详见[Client.java](./src/main/java/io/paddle/serving/client/Client.java)
```
cd ../../../java/examples/target
java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample http_proto <configPath>
```
**注意 <configPath>为客户端配置文件,一般是名为serving_client_conf.prototxt的文件。**
更多示例详见[PaddleServingClientExample.java](./examples/src/main/java/PaddleServingClientExample.java)
## 请求Pipeline-Server
###服务端启动
对于input data type = string类型,以IMDB model ensemble模型为例,服务端启动 对于input data type = string类型,以IMDB model ensemble模型为例,服务端启动
...@@ -71,7 +98,7 @@ java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar Pipeli ...@@ -71,7 +98,7 @@ java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar Pipeli
### 注意事项 ### 注意事项
1.在示例中,所有非Pipeline模型都需要使用`--use_multilang`来启动GRPC多编程语言支持,以及端口号都是9393,如果需要别的端口,需要在java文件里修改 1.在示例中,端口号都是9393,ip默认设置为了0.0.0.0表示本机,注意ip和port需要与Server端对应
2.目前Serving已推出Pipeline模式(原理详见[Pipeline Serving](../doc/PIPELINE_SERVING_CN.md)),面向Java的Pipeline Serving Client已发布。 2.目前Serving已推出Pipeline模式(原理详见[Pipeline Serving](../doc/PIPELINE_SERVING_CN.md)),面向Java的Pipeline Serving Client已发布。
...@@ -84,5 +111,3 @@ java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar Pipeli ...@@ -84,5 +111,3 @@ java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar Pipeli
第一种是GPU Serving和Java Client在运行在同一个GPU镜像中,需要用户在启动GPU镜像后,把在java镜像中编译完成后的文件(位于/Serving/java目录下)拷贝到GPU镜像中的/Serving/java目录下。 第一种是GPU Serving和Java Client在运行在同一个GPU镜像中,需要用户在启动GPU镜像后,把在java镜像中编译完成后的文件(位于/Serving/java目录下)拷贝到GPU镜像中的/Serving/java目录下。
第二种是GPU Serving和Java Client分别在各自的docker镜像中(或具备编译开发环境的不同主机上)部署,此时仅需注意Java Client端与GPU Serving端的ip和port需要对应,详见上述注意事项中的第3项。 第二种是GPU Serving和Java Client分别在各自的docker镜像中(或具备编译开发环境的不同主机上)部署,此时仅需注意Java Client端与GPU Serving端的ip和port需要对应,详见上述注意事项中的第3项。
...@@ -11,7 +11,7 @@ import org.nd4j.linalg.factory.Nd4j; ...@@ -11,7 +11,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.*; import java.util.*;
public class PaddleServingClientExample { public class PaddleServingClientExample {
boolean fit_a_line(String model_config_path) { boolean http_proto(String model_config_path) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f, 0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
...@@ -24,7 +24,7 @@ public class PaddleServingClientExample { ...@@ -24,7 +24,7 @@ public class PaddleServingClientExample {
}}; }};
List<String> fetch = Arrays.asList("price"); List<String> fetch = Arrays.asList("price");
HttpClient client = new HttpClient(); Client client = new Client();
client.setIP("0.0.0.0"); client.setIP("0.0.0.0");
client.setPort("9393"); client.setPort("9393");
client.loadClientConfig(model_config_path); client.loadClientConfig(model_config_path);
...@@ -34,6 +34,55 @@ public class PaddleServingClientExample { ...@@ -34,6 +34,55 @@ public class PaddleServingClientExample {
return true; return true;
} }
boolean http_json(String model_config_path) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
long[] batch_shape = {1,13};
INDArray batch_npdata = npdata.reshape(batch_shape);
HashMap<String, Object> feed_data
= new HashMap<String, Object>() {{
put("x", batch_npdata);
}};
List<String> fetch = Arrays.asList("price");
Client client = new Client();
//注意:跨docker,需要设置--net-host或直接访问另一个docker的ip
client.setIP("0.0.0.0");
client.setPort("9393");
client.set_http_proto(false);
client.loadClientConfig(model_config_path);
String result = client.predict(feed_data, fetch, true, 0);
System.out.println(result);
return true;
}
boolean grpc(String model_config_path) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
long[] batch_shape = {1,13};
INDArray batch_npdata = npdata.reshape(batch_shape);
HashMap<String, Object> feed_data
= new HashMap<String, Object>() {{
put("x", batch_npdata);
}};
List<String> fetch = Arrays.asList("price");
Client client = new Client();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
client.set_use_grpc_client(true);
String result = client.predict(feed_data, fetch, true, 0);
System.out.println(result);
return true;
}
boolean encrypt(String model_config_path,String keyFilePath) { boolean encrypt(String model_config_path,String keyFilePath) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f, 0.0582f, -0.0727f, -0.1583f, -0.0584f,
...@@ -47,7 +96,7 @@ public class PaddleServingClientExample { ...@@ -47,7 +96,7 @@ public class PaddleServingClientExample {
}}; }};
List<String> fetch = Arrays.asList("price"); List<String> fetch = Arrays.asList("price");
HttpClient client = new HttpClient(); Client client = new Client();
client.setIP("0.0.0.0"); client.setIP("0.0.0.0");
client.setPort("9393"); client.setPort("9393");
client.loadClientConfig(model_config_path); client.loadClientConfig(model_config_path);
...@@ -55,7 +104,6 @@ public class PaddleServingClientExample { ...@@ -55,7 +104,6 @@ public class PaddleServingClientExample {
try { try {
Thread.sleep(1000*3); // 休眠3秒,等待Server启动 Thread.sleep(1000*3); // 休眠3秒,等待Server启动
} catch (Exception e) { } catch (Exception e) {
//TODO: handle exception
} }
String result = client.predict(feed_data, fetch, true, 0); String result = client.predict(feed_data, fetch, true, 0);
...@@ -76,7 +124,7 @@ public class PaddleServingClientExample { ...@@ -76,7 +124,7 @@ public class PaddleServingClientExample {
}}; }};
List<String> fetch = Arrays.asList("price"); List<String> fetch = Arrays.asList("price");
HttpClient client = new HttpClient(); Client client = new Client();
client.setIP("0.0.0.0"); client.setIP("0.0.0.0");
client.setPort("9393"); client.setPort("9393");
client.loadClientConfig(model_config_path); client.loadClientConfig(model_config_path);
...@@ -127,7 +175,7 @@ public class PaddleServingClientExample { ...@@ -127,7 +175,7 @@ public class PaddleServingClientExample {
put("im_size", batch_im_size); put("im_size", batch_im_size);
}}; }};
List<String> fetch = Arrays.asList("save_infer_model/scale_0.tmp_0"); List<String> fetch = Arrays.asList("save_infer_model/scale_0.tmp_0");
HttpClient client = new HttpClient(); Client client = new Client();
client.setIP("0.0.0.0"); client.setIP("0.0.0.0");
client.setPort("9393"); client.setPort("9393");
client.loadClientConfig(model_config_path); client.loadClientConfig(model_config_path);
...@@ -149,7 +197,7 @@ public class PaddleServingClientExample { ...@@ -149,7 +197,7 @@ public class PaddleServingClientExample {
put("segment_ids", Nd4j.createFromArray(segment_ids)); put("segment_ids", Nd4j.createFromArray(segment_ids));
}}; }};
List<String> fetch = Arrays.asList("pooled_output"); List<String> fetch = Arrays.asList("pooled_output");
HttpClient client = new HttpClient(); Client client = new Client();
client.setIP("0.0.0.0"); client.setIP("0.0.0.0");
client.setPort("9393"); client.setPort("9393");
client.loadClientConfig(model_config_path); client.loadClientConfig(model_config_path);
...@@ -219,7 +267,7 @@ public class PaddleServingClientExample { ...@@ -219,7 +267,7 @@ public class PaddleServingClientExample {
put("embedding_0.tmp_0", Nd4j.createFromArray(embedding_0)); put("embedding_0.tmp_0", Nd4j.createFromArray(embedding_0));
}}; }};
List<String> fetch = Arrays.asList("prob"); List<String> fetch = Arrays.asList("prob");
HttpClient client = new HttpClient(); Client client = new Client();
client.setIP("0.0.0.0"); client.setIP("0.0.0.0");
client.setPort("9393"); client.setPort("9393");
client.loadClientConfig(model_config_path); client.loadClientConfig(model_config_path);
...@@ -236,13 +284,17 @@ public class PaddleServingClientExample { ...@@ -236,13 +284,17 @@ public class PaddleServingClientExample {
if (args.length < 2) { if (args.length < 2) {
System.out.println("Usage: java -cp <jar> PaddleServingClientExample <test-type> <configPath>."); System.out.println("Usage: java -cp <jar> PaddleServingClientExample <test-type> <configPath>.");
System.out.println("<test-type>: fit_a_line bert cube_local yolov4 encrypt"); System.out.println("<test-type>: http_proto http_json grpc bert cube_local yolov4 encrypt");
return; return;
} }
String testType = args[0]; String testType = args[0];
System.out.format("[Example] %s\n", testType); System.out.format("[Example] %s\n", testType);
if ("fit_a_line".equals(testType)) { if ("http_proto".equals(testType)) {
succ = e.fit_a_line(args[1]); succ = e.http_proto(args[1]);
} else if ("http_json".equals(testType)) {
succ = e.http_json(args[1]);
} else if ("grpc".equals(testType)) {
succ = e.grpc(args[1]);
} else if ("compress".equals(testType)) { } else if ("compress".equals(testType)) {
succ = e.compress(args[1]); succ = e.compress(args[1]);
} else if ("bert".equals(testType)) { } else if ("bert".equals(testType)) {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package baidu.paddle_serving.predictor.general_model;
option java_multiple_files = true;
message Tensor {
repeated string data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt
optional string alias_name = 9; // get from the Model prototxt
};
message Request {
repeated Tensor tensor = 1;
repeated string fetch_var_names = 2;
optional bool profile_server = 3 [ default = false ];
required uint64 log_id = 4 [ default = 0 ];
};
message Response {
repeated ModelOutput outputs = 1;
repeated int64 profile_time = 2;
};
message ModelOutput {
repeated Tensor tensor = 1;
optional string engine_name = 2;
}
service GeneralModelService {
rpc inference(Request) returns (Response) {}
rpc debug(Request) returns (Response) {}
};
...@@ -13,24 +13,41 @@ ...@@ -13,24 +13,41 @@
# limitations under the License. # limitations under the License.
# pylint: disable=doc-string-missing # pylint: disable=doc-string-missing
from paddle_serving_client.httpclient import HttpClient from paddle_serving_client.httpclient import GeneralClient
import sys import sys
import numpy as np import numpy as np
import time import time
client = HttpClient() client = GeneralClient()
client.load_client_config(sys.argv[1]) client.load_client_config(sys.argv[1])
# if you want to enable Encrypt Module,uncommenting the following line '''
# client.use_key("./key") if you want use GRPC-client, set_use_grpc_client(True)
client.set_response_compress(True) or you can directly use client.grpc_client_predict(...)
client.set_request_compress(True) as for HTTP-client,set_use_grpc_client(False)(which is default)
fetch_list = client.get_fetch_names() or you can directly use client.http_client_predict(...)
'''
#client.set_use_grpc_client(True)
'''
if you want to enable Encrypt Module,uncommenting the following line
'''
#client.use_key("./key")
'''
if you want to compress,uncommenting the following line
'''
#client.set_response_compress(True)
#client.set_request_compress(True)
'''
we recommend use Proto data format in HTTP-body, set True(which is default)
if you want use JSON data format in HTTP-body, set False
'''
#client.set_http_proto(True)
import paddle import paddle
test_reader = paddle.batch( test_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.uci_housing.test(), buf_size=500), paddle.dataset.uci_housing.test(), buf_size=500),
batch_size=1) batch_size=1)
fetch_list = client.get_fetch_names()
for data in test_reader(): for data in test_reader():
new_data = np.zeros((1, 13)).astype("float32") new_data = np.zeros((1, 13)).astype("float32")
new_data[0] = data[0][0] new_data[0] = data[0][0]
......
...@@ -345,68 +345,69 @@ class Client(object): ...@@ -345,68 +345,69 @@ class Client(object):
raise ValueError( raise ValueError(
"Fetch names should not be empty or out of saved fetch list.") "Fetch names should not be empty or out of saved fetch list.")
feed_i = feed_batch[0] feed_dict = feed_batch[0]
for key in feed_i: for key in feed_dict:
if ".lod" not in key and key not in self.feed_names_: if ".lod" not in key and key not in self.feed_names_:
raise ValueError("Wrong feed name: {}.".format(key)) raise ValueError("Wrong feed name: {}.".format(key))
if ".lod" in key: if ".lod" in key:
continue continue
self.shape_check(feed_i, key) self.shape_check(feed_dict, key)
if self.feed_types_[key] in int_type: if self.feed_types_[key] in int_type:
int_feed_names.append(key) int_feed_names.append(key)
shape_lst = [] shape_lst = []
if batch == False: if batch == False:
feed_i[key] = np.expand_dims(feed_i[key], 0).repeat( feed_dict[key] = np.expand_dims(feed_dict[key], 0).repeat(
1, axis=0) 1, axis=0)
if isinstance(feed_i[key], np.ndarray): if isinstance(feed_dict[key], np.ndarray):
shape_lst.extend(list(feed_i[key].shape)) shape_lst.extend(list(feed_dict[key].shape))
int_shape.append(shape_lst) int_shape.append(shape_lst)
else: else:
int_shape.append(self.feed_shapes_[key]) int_shape.append(self.feed_shapes_[key])
if "{}.lod".format(key) in feed_i: if "{}.lod".format(key) in feed_dict:
int_lod_slot_batch.append(feed_i["{}.lod".format(key)]) int_lod_slot_batch.append(feed_dict["{}.lod".format(key)])
else: else:
int_lod_slot_batch.append([]) int_lod_slot_batch.append([])
if isinstance(feed_i[key], np.ndarray): if isinstance(feed_dict[key], np.ndarray):
int_slot.append(np.ascontiguousarray(feed_i[key])) int_slot.append(np.ascontiguousarray(feed_dict[key]))
self.has_numpy_input = True self.has_numpy_input = True
else: else:
int_slot.append(np.ascontiguousarray(feed_i[key])) int_slot.append(np.ascontiguousarray(feed_dict[key]))
self.all_numpy_input = False self.all_numpy_input = False
elif self.feed_types_[key] in float_type: elif self.feed_types_[key] in float_type:
float_feed_names.append(key) float_feed_names.append(key)
shape_lst = [] shape_lst = []
if batch == False: if batch == False:
feed_i[key] = np.expand_dims(feed_i[key], 0).repeat( feed_dict[key] = np.expand_dims(feed_dict[key], 0).repeat(
1, axis=0) 1, axis=0)
if isinstance(feed_i[key], np.ndarray): if isinstance(feed_dict[key], np.ndarray):
shape_lst.extend(list(feed_i[key].shape)) shape_lst.extend(list(feed_dict[key].shape))
float_shape.append(shape_lst) float_shape.append(shape_lst)
else: else:
float_shape.append(self.feed_shapes_[key]) float_shape.append(self.feed_shapes_[key])
if "{}.lod".format(key) in feed_i: if "{}.lod".format(key) in feed_dict:
float_lod_slot_batch.append(feed_i["{}.lod".format(key)]) float_lod_slot_batch.append(feed_dict["{}.lod".format(key)])
else: else:
float_lod_slot_batch.append([]) float_lod_slot_batch.append([])
if isinstance(feed_i[key], np.ndarray): if isinstance(feed_dict[key], np.ndarray):
float_slot.append(np.ascontiguousarray(feed_i[key])) float_slot.append(np.ascontiguousarray(feed_dict[key]))
self.has_numpy_input = True self.has_numpy_input = True
else: else:
float_slot.append(np.ascontiguousarray(feed_i[key])) float_slot.append(np.ascontiguousarray(feed_dict[key]))
self.all_numpy_input = False self.all_numpy_input = False
#if input is string, feed is not numpy. #if input is string, feed is not numpy.
elif self.feed_types_[key] in string_type: elif self.feed_types_[key] in string_type:
string_feed_names.append(key) string_feed_names.append(key)
string_shape.append(self.feed_shapes_[key]) string_shape.append(self.feed_shapes_[key])
if "{}.lod".format(key) in feed_i: if "{}.lod".format(key) in feed_dict:
string_lod_slot_batch.append(feed_i["{}.lod".format(key)]) string_lod_slot_batch.append(feed_dict["{}.lod".format(
key)])
else: else:
string_lod_slot_batch.append([]) string_lod_slot_batch.append([])
string_slot.append(feed_i[key]) string_slot.append(feed_dict[key])
self.has_numpy_input = True self.has_numpy_input = True
self.profile_.record('py_prepro_1') self.profile_.record('py_prepro_1')
......
...@@ -107,9 +107,13 @@ def is_gpu_mode(unformatted_gpus): ...@@ -107,9 +107,13 @@ def is_gpu_mode(unformatted_gpus):
def serve_args(): def serve_args():
parser = argparse.ArgumentParser("serve") parser = argparse.ArgumentParser("serve")
parser.add_argument( parser.add_argument(
"--thread", type=int, default=2, help="Concurrency of server") "--thread",
type=int,
default=4,
help="Concurrency of server,[4,1024]",
choices=range(4, 1025))
parser.add_argument( parser.add_argument(
"--port", type=int, default=9292, help="Port of the starting gpu") "--port", type=int, default=9393, help="Port of the starting gpu")
parser.add_argument( parser.add_argument(
"--device", type=str, default="cpu", help="Type of device") "--device", type=str, default="cpu", help="Type of device")
parser.add_argument( parser.add_argument(
......
...@@ -123,7 +123,7 @@ class WebService(object): ...@@ -123,7 +123,7 @@ class WebService(object):
workdir, workdir,
port=9292, port=9292,
gpus=None, gpus=None,
thread_num=2, thread_num=4,
mem_optim=True, mem_optim=True,
use_lite=False, use_lite=False,
use_xpu=False, use_xpu=False,
...@@ -236,7 +236,7 @@ class WebService(object): ...@@ -236,7 +236,7 @@ class WebService(object):
use_lite=False, use_lite=False,
use_xpu=False, use_xpu=False,
ir_optim=False, ir_optim=False,
thread_num=2, thread_num=4,
mem_optim=True, mem_optim=True,
use_trt=False, use_trt=False,
gpu_multi_stream=False, gpu_multi_stream=False,
......
...@@ -20,12 +20,12 @@ package baidu.paddle_serving.predictor.general_model; ...@@ -20,12 +20,12 @@ package baidu.paddle_serving.predictor.general_model;
option cc_generic_services = true; option cc_generic_services = true;
message Tensor { message Tensor {
repeated bytes data = 1; repeated string data = 1;
repeated int32 int_data = 2; repeated int32 int_data = 2;
repeated int64 int64_data = 3; repeated int64 int64_data = 3;
repeated float float_data = 4; repeated float float_data = 4;
optional int32 elem_type = optional int32 elem_type =
5; // 0 means int64, 1 means float32, 2 means int32, 3 means bytes(string) 5; // 0 means int64, 1 means float32, 2 means int32, 3 means string
repeated int32 shape = 6; // shape should include batch repeated int32 shape = 6; // shape should include batch
repeated int32 lod = 7; // only for fetch tensor currently repeated int32 lod = 7; // only for fetch tensor currently
optional string name = 8; // get from the Model prototxt optional string name = 8; // get from the Model prototxt
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册