Unverified commit 2c516c8d, authored by Jiawei Wang, committed by GitHub

Merge pull request #959 from HexToString/develop

fix python bug and add java example
@@ -27,7 +27,7 @@ mvn compile
mvn install
```
### Start the server (not pipeline)
Take the fit_a_line model as an example, the server starts
@@ -59,6 +59,48 @@ Client prediction
java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample yolov4 ../../../python/examples/yolov4/000000570688.jpg
# The case of yolov4 needs to specify a picture as input
```
### Start the server (pipeline)
When the input data type is string, take the IMDB model ensemble as an example; the server starts
```
cd ../../python/examples/pipeline/imdb_model_ensemble
sh get_data.sh
python -m paddle_serving_server.serve --model imdb_cnn_model --port 9292 &> cnn.log &
python -m paddle_serving_server.serve --model imdb_bow_model --port 9393 &> bow.log &
python test_pipeline_server.py &>pipeline.log &
```
Client prediction (Synchronous)
```
cd ../../../java/examples/target
java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PipelineClientExample string_imdb_predict
```
Client prediction (Asynchronous)
```
cd ../../../java/examples/target
java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PipelineClientExample asyn_predict
```
When the input data type is INDArray, take uci_housing_model as an example; the server starts
```
cd ../../python/examples/pipeline/simple_web_service
sh get_data.sh
python web_service_java.py &>log.txt &
```
Client prediction (Synchronous)
```
cd ../../../java/examples/target
java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PipelineClientExample indarray_predict
```
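The `indarray_predict` case encodes the tensor as a string before sending it. Below is a minimal sketch of that encoding, mirroring what `PipelineClientExample` (shown later in this diff) does; the class name is hypothetical and the exact printed form depends on the Nd4j version:
```
import java.util.HashMap;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Sketch: the pipeline protocol carries feed values as strings, so the
// tensor is wrapped as "array(...)" for the Python side to rebuild.
public class IndArrayFeedSketch {
    public static void main(String[] args) {
        float[] data = {0.0137f, -0.1136f, 0.2553f};
        INDArray npdata = Nd4j.createFromArray(data);
        HashMap<String, String> feed = new HashMap<String, String>();
        feed.put("x", "array(" + npdata.toString() + ")");
        System.out.println(feed);
    }
}
```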
### Customization guidance
@@ -70,4 +112,8 @@ The second is to deploy GPU Serving and Java Client separately. If they are on t
**It should be noted that in the examples, all models need to be started with `--use_multilang` to enable gRPC multi-language support, and the port number is 9393. If you need another port, modify it in the Java file.**
**Currently Serving has launched the Pipeline mode (see [Pipeline Serving](../doc/PIPELINE_SERVING.md) for details). The Pipeline Serving Client for Java has been released; a multi-threaded Java client example will come in the next version.**
**It should be noted that in the example, the Java Pipeline Client code is under /Java/Examples and /Java/src/main, and the corresponding Pipeline server code is under /python/examples/pipeline/.**
@@ -27,7 +27,7 @@ mvn compile
mvn install
```
### Start the server (not pipeline)
Take the fit_a_line model as an example; the server starts
@@ -58,6 +58,49 @@ python -m paddle_serving_server_gpu.serve --model yolov4_model --port 9393 --gpu
# in /Serving/java/examples/target
java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PaddleServingClientExample yolov4 ../../../python/examples/yolov4/000000570688.jpg
# The yolov4 case needs a picture specified as input
```
### Start the server (pipeline)
When the input data type is string, take the IMDB model ensemble as an example; the server starts
```
cd ../../python/examples/pipeline/imdb_model_ensemble
sh get_data.sh
python -m paddle_serving_server.serve --model imdb_cnn_model --port 9292 &> cnn.log &
python -m paddle_serving_server.serve --model imdb_bow_model --port 9393 &> bow.log &
python test_pipeline_server.py &>pipeline.log &
```
Client prediction (Synchronous)
```
cd ../../../java/examples/target
java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PipelineClientExample string_imdb_predict
```
Client prediction (Asynchronous)
```
cd ../../../java/examples/target
java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PipelineClientExample asyn_predict
```
When the input data type is INDArray, take the uci_housing_model in the Simple Pipeline WebService as an example; the server starts
```
cd ../../python/examples/pipeline/simple_web_service
sh get_data.sh
python web_service_java.py &>log.txt &
```
Client prediction (Synchronous)
```
cd ../../../java/examples/target
java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar PipelineClientExample indarray_predict
```
### Customization guidance
@@ -70,4 +113,9 @@ java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar Paddle
**It should be noted that in the examples, all models need to be started with `--use_multilang` to enable gRPC multi-language support, and the port number is 9393. If you need another port, modify it in the Java file.**
**Currently Serving has launched the Pipeline mode (see [Pipeline Serving](../doc/PIPELINE_SERVING_CN.md) for details). The Pipeline Serving Client for Java has been released; a multi-threaded Java example will come in the next update.**
**It should be noted that the Java Pipeline Client examples are under /Java/Examples and /Java/src/main, and the corresponding Pipeline servers are under /python/examples/pipeline/.**
import io.paddle.serving.pipelineclient.*;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.datavec.image.loader.NativeImageLoader;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
/**
* This class gives an example of using the client to predict (gRPC).
* StaticPipelineClient.client supports multi-threading.
* By setting StaticPipelineClient.client properties, you can change the maximum concurrency.
* There is no need to create multiple client instances; use StaticPipelineClient.client or a singleton instead.
* @author HexToString
*/
public class PipelineClientExample {
/**
* This method gives an example of synchronous prediction whose input type is string.
*/
boolean string_imdb_predict() {
HashMap<String, String> feed_data
= new HashMap<String, String>() {{
put("words", "i am very sad | 0");
}};
System.out.println(feed_data);
List<String> fetch = Arrays.asList("prediction");
System.out.println(fetch);
if (StaticPipelineClient.succ != true) {
if(!StaticPipelineClient.initClient("172.17.0.2","18070")){
System.out.println("connect failed.");
return false;
}
}
HashMap<String,String> result = StaticPipelineClient.client.predict(feed_data, fetch,false,0);
if (result == null) {
return false;
}
System.out.println(result);
return true;
}
/**
* This method gives an example of asynchronous prediction whose input type is string.
*/
boolean asyn_predict() {
HashMap<String, String> feed_data
= new HashMap<String, String>() {{
put("words", "i am very sad | 0");
}};
System.out.println(feed_data);
List<String> fetch = Arrays.asList("prediction");
System.out.println(fetch);
if (StaticPipelineClient.succ != true) {
if(!StaticPipelineClient.initClient("172.17.0.2","18070")){
System.out.println("connect failed.");
return false;
}
}
PipelineFuture future = StaticPipelineClient.client.asyn_predict(feed_data, fetch, false, 0);
HashMap<String,String> result = future.get();
if (result == null) {
return false;
}
System.out.println(result);
return true;
}
/**
* This method gives an example of synchronous prediction whose input type is an array, list, or matrix.
* It uses the Nd4j.createFromArray method to convert the array to an INDArray,
* and the convertINDArrayToString method to convert the INDArray to the string form expected by the Python numpy eval on the server side.
*/
boolean indarray_predict() {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, 0.0582f, -0.0727f, -0.1583f, -0.0584f, 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
HashMap<String, String> feed_data
= new HashMap<String, String>() {{
put("x", convertINDArrayToString(npdata));
}};
List<String> fetch = Arrays.asList("prediction");
if (StaticPipelineClient.succ != true) {
if(!StaticPipelineClient.initClient("172.17.0.2","9998")){
System.out.println("connect failed.");
return false;
}
}
HashMap<String,String> result = StaticPipelineClient.client.predict(feed_data, fetch,false,0);
if (result == null) {
return false;
}
System.out.println(result);
return true;
}
/**
* This method converts an INDArray to the specified String type.
* @param npdata INDArray type(The input data).
* @return String (specified String type for python Numpy eval method).
*/
String convertINDArrayToString(INDArray npdata){
return "array("+npdata.toString()+")";
}
/**
* This method is the entry function.
* @param args String[] type(Command line parameters)
*/
public static void main( String[] args ) {
PipelineClientExample e = new PipelineClientExample();
boolean succ = false;
if (args.length < 1) {
System.out.println("Usage: java -cp <jar> PaddleServingClientExample <test-type>.");
System.out.println("<test-type>: fit_a_line bert model_ensemble asyn_predict batch_predict cube_local cube_quant yolov4");
return;
}
String testType = args[0];
System.out.format("[Example] %s\n", testType);
if ("string_imdb_predict".equals(testType)) {
succ = e.string_imdb_predict();
}else if ("asyn_predict".equals(testType)) {
succ = e.asyn_predict();
}else if ("indarray_predict".equals(testType)) {
succ = e.indarray_predict();
} else {
System.out.format("test-type(%s) not match.\n", testType);
return;
}
if (succ == true) {
System.out.println("[Example] succ.");
} else {
System.out.println("[Example] fail.");
}
}
}
import io.paddle.serving.pipelineclient.*;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.datavec.image.loader.NativeImageLoader;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
/**
* static resource management class
* @author HexToString
*/
public class StaticPipelineClient {
/**
* Static Variable PipelineClient
*/
public static PipelineClient client = new PipelineClient();
/**
* the sign of connect status
*/
public static boolean succ = false;
/**
* This method returns the sign of connect status.
* @param strIp String type(The server ipv4) such as "192.168.10.10".
* @param strPort String type(The server port) such as "8891".
* @return boolean (the sign of connect status).
*/
public static boolean initClient(String strIp,String strPort){
String target = strIp + ":" + strPort; // e.g. "172.17.0.2:18070"
System.out.println("initial connect.");
if(succ){
System.out.println("already connect.");
return true;
}
succ = client.connect(target);
if (succ != true) {
System.out.println("connect failed.");
return false;
}
return true;
}
}
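The class comment in PipelineClientExample states that the shared StaticPipelineClient.client supports multi-threading. A sketch under that assumption follows; the class name, server address, and feed contents are illustrative:
```
import java.util.Arrays;
import java.util.HashMap;

// Sketch: two threads sharing the one static client instance.
public class MultiThreadSketch {
    public static void main(String[] args) throws InterruptedException {
        if (!StaticPipelineClient.initClient("172.17.0.2", "18070")) {
            return;
        }
        Runnable task = () -> {
            HashMap<String, String> feed = new HashMap<String, String>();
            feed.put("words", "i am very sad | 0");
            HashMap<String, String> result = StaticPipelineClient.client
                .predict(feed, Arrays.asList("prediction"), false, 0);
            System.out.println(result);
        };
        Thread t1 = new Thread(task);
        Thread t2 = new Thread(task);
        t1.start(); t2.start();
        t1.join(); t2.join();
    }
}
```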
package io.paddle.serving.pipelineclient;
import java.util.*;
import java.util.function.Function;
import java.lang.management.ManagementFactory;
import java.lang.management.RuntimeMXBean;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.StatusRuntimeException;
import com.google.protobuf.ByteString;
import com.google.common.util.concurrent.ListenableFuture;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.factory.Nd4j;
import io.paddle.serving.pipelineproto.*;
import io.paddle.serving.pipelineclient.PipelineFuture;
/**
* PipelineClient class definition
* @author HexToString
*/
public class PipelineClient {
private ManagedChannel channel_;
private PipelineServiceGrpc.PipelineServiceBlockingStub blockingStub_;
private PipelineServiceGrpc.PipelineServiceFutureStub futureStub_;
private String clientip;
private String _profile_key;
private String _profile_value;
public PipelineClient() {
channel_ = null;
blockingStub_ = null;
futureStub_ = null;
boolean is_profile = false;
clientip = null;
_profile_value = "1";
_profile_key = "pipeline.profile";
}
/**
* This method returns the sign of connect status.
* @param target String type(The server ipv4 and port) such as "192.168.10.10:8891".
* @return boolean (the sign of connect status).
*/
public boolean connect(String target) {
try {
String[] temp = target.split(":");
this.clientip = temp[0].equals("localhost") ? "127.0.0.1" : temp[0]; // String.equals, not ==, to compare contents
channel_ = ManagedChannelBuilder.forTarget(target)
.defaultLoadBalancingPolicy("round_robin")
.maxInboundMessageSize(Integer.MAX_VALUE)
.usePlaintext()
.build();
blockingStub_ = PipelineServiceGrpc.newBlockingStub(channel_);
futureStub_ = PipelineServiceGrpc.newFutureStub(channel_);
} catch (Exception e) {
System.out.format("Connect failed: %s\n", e.toString());
return false;
}
return true;
}
/**
* This method returns the Packaged Request.
* @param feed_dict HashMap<String, String>(input data).
* @param profile boolean(profile sign).
* @param logid int
* @return Request (the grpc protobuf Request).
*/
private Request _packInferenceRequest(
HashMap<String, String> feed_dict,
boolean profile,
int logid) throws IllegalArgumentException {
List<String> keys = new ArrayList<String>();
List<String> values = new ArrayList<String>();
long[] flattened_shape = {-1};
Request.Builder req_builder = Request.newBuilder()
.setClientip(this.clientip)
.setLogid(logid);
for (Map.Entry<String, String> entry : feed_dict.entrySet()) {
keys.add(entry.getKey());
values.add(entry.getValue());
}
if(profile){
keys.add(_profile_key);
values.add(_profile_value);
}
req_builder.addAllKey(keys);
req_builder.addAllValue(values);
return req_builder.build();
}
/**
* This method returns the HashMap which is unpackaged from Response.
* @param resp Response(the grpc protobuf Response).
* @return HashMap<String,String> (the output).
*/
private HashMap<String,String> _unpackResponse(Response resp) throws IllegalArgumentException{
return PipelineClient._staticUnpackResponse(resp);
}
/**
* This static method returns the HashMap which is unpackaged from Response.
* @param resp Response(the grpc protobuf Response).
* @return HashMap<String,String> (the output).
*/
private static HashMap<String,String> _staitcUnpackResponse(Response resp) {
HashMap<String,String> ret_Map = new HashMap<String,String>();
int err_no = resp.getErrNo();
if ( err_no!= 0) {
return null;
}
List<String> keys = resp.getKeyList();
List<String> values= resp.getValueList();
for (int i = 0;i<keys.size();i++) {
ret_Map.put(keys.get(i),values.get(i));
}
return ret_Map;
}
/**
* The synchronous prediction method.
* @param feed_batch HashMap<String, String>(input data).
* @param fetch Iterable<String>(the output key list).
* @param profile boolean(profile sign).
* @param logid int
* @return HashMap<String,String> (the output).
*/
public HashMap<String,String> predict(
HashMap<String, String> feed_batch,
Iterable<String> fetch,
boolean profile,
int logid) {
try {
Request req = _packInferenceRequest(
feed_batch, profile,logid);
Response resp = blockingStub_.inference(req);
return _unpackResponse(resp);
} catch (StatusRuntimeException e) {
System.out.format("Failed to predict: %s\n", e.toString());
return null;
}
}
/**
* The synchronous prediction overload function.
*/
public HashMap<String,String> predict(
HashMap<String, String> feed_batch,
Iterable<String> fetch) {
return predict(feed_batch,fetch,false,0);
}
/**
* The synchronous prediction overload function.
*/
public HashMap<String,String> predict(
HashMap<String, String> feed_batch,
Iterable<String> fetch,
boolean profile) {
return predict(feed_batch,fetch,profile,0);
}
/**
* The synchronous prediction overload function.
*/
public HashMap<String,String> predict(
HashMap<String, String> feed_batch,
Iterable<String> fetch,
int logid) {
return predict(feed_batch,fetch,false,logid);
}
/**
* The asynchronous prediction method. Use future.get() to get the result.
* @param feed_batch HashMap<String, String>(input data).
* @param fetch Iterable<String>(the output key list).
* @param profile boolean(profile sign).
* @param logid int
* @return PipelineFuture(the output future).
*/
public PipelineFuture asyn_predict(
HashMap<String, String> feed_batch,
Iterable<String> fetch,
boolean profile,
int logid) {
Request req = _packInferenceRequest(
feed_batch, profile, logid);
ListenableFuture<Response> future = futureStub_.inference(req);
PipelineFuture predict_future = new PipelineFuture(future,
(Response resp) -> {
return PipelineClient._staticUnpackResponse(resp);
}
);
return predict_future;
}
/**
* The asynchronous prediction overload function.
*/
public PipelineFuture asyn_predict(
HashMap<String, String> feed_batch,
Iterable<String> fetch) {
return asyn_predict(feed_batch,fetch,false,0);
}
/**
* The asynchronous prediction overload function.
*/
public PipelineFuture asyn_predict(
HashMap<String, String> feed_batch,
Iterable<String> fetch,
boolean profile) {
return asyn_predict(feed_batch,fetch,profile,0);
}
/**
* The asynchronous prediction overload function.
*/
public PipelineFuture asyn_predict(
HashMap<String, String> feed_batch,
Iterable<String> fetch,
int logid) {
return asyn_predict(feed_batch,fetch,false,logid);
}
}
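Putting the two call styles together, here is a hedged usage sketch of PipelineClient based on the signatures above; the server address, feed key, and class name are illustrative:
```
import java.util.Arrays;
import java.util.HashMap;
import io.paddle.serving.pipelineclient.*;

// Sketch: synchronous predict vs. asynchronous asyn_predict + future.get().
public class ClientUsageSketch {
    public static void main(String[] args) {
        PipelineClient client = new PipelineClient();
        if (!client.connect("172.17.0.2:18070")) {
            return;
        }
        HashMap<String, String> feed = new HashMap<String, String>();
        feed.put("words", "i am very sad | 0");
        // Synchronous: blocks until the Response arrives.
        HashMap<String, String> result =
            client.predict(feed, Arrays.asList("prediction"));
        // Asynchronous: returns immediately; get() blocks later.
        PipelineFuture future =
            client.asyn_predict(feed, Arrays.asList("prediction"));
        HashMap<String, String> asyncResult = future.get();
        System.out.println(result);
        System.out.println(asyncResult);
    }
}
```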
package io.paddle.serving.pipelineclient;
import java.util.*;
import java.util.function.Function;
import io.grpc.StatusRuntimeException;
import com.google.common.util.concurrent.ListenableFuture;
import org.nd4j.linalg.api.ndarray.INDArray;
import io.paddle.serving.pipelineclient.PipelineClient;
import io.paddle.serving.pipelineproto.*;
/**
* PipelineFuture class is for asynchronous prediction
* @author HexToString
*/
public class PipelineFuture {
private ListenableFuture<Response> callFuture_;
private Function<Response,
HashMap<String,String> > callBackFunc_;
PipelineFuture(ListenableFuture<Response> call_future,
Function<Response,
HashMap<String,String> > call_back_func) {
callFuture_ = call_future;
callBackFunc_ = call_back_func;
}
/**
* use this method to get the result of asynchronous prediction.
*/
public HashMap<String,String> get() {
Response resp = null;
try {
resp = callFuture_.get();
} catch (Exception e) {
System.out.format("predict failed: %s\n", e.toString());
return null;
}
HashMap<String,String> result
= callBackFunc_.apply(resp);
return result;
}
}
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
option java_multiple_files = true;
option java_package = "io.paddle.serving.pipelineproto";
option java_outer_classname = "PipelineProto";
package baidu.paddle_serving.pipeline_serving;
message Request {
repeated string key = 1;
repeated string value = 2;
optional string name = 3;
optional string method = 4;
optional int64 logid = 5;
optional string clientip = 6;
};
message Response {
optional int32 err_no = 1;
optional string err_msg = 2;
repeated string key = 3;
repeated string value = 4;
};
service PipelineService {
rpc inference(Request) returns (Response) {}
};
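With java_multiple_files = true, each message above becomes its own generated Java class. A sketch of packing a Request by hand with the generated builder, as _packInferenceRequest does internally (values and class name are illustrative):
```
import io.paddle.serving.pipelineproto.Request;

// Sketch: the generated protobuf builder behind _packInferenceRequest.
public class RequestBuildSketch {
    public static void main(String[] args) {
        Request req = Request.newBuilder()
            .setClientip("127.0.0.1")
            .setLogid(0)
            .addKey("words")
            .addValue("i am very sad | 0")
            .build();
        System.out.println(req);
    }
}
```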
# Encryption Model Prediction
([简体中文](README_CN.md)|English)
## Get Origin Model
The example uses the model file of the fit_a_line example as a origin model
```
sh get_data.sh
```
## Encrypt Model
```
python encrypt.py
```
The key is stored in the `key` file; the encrypted model file and server-side configuration file are stored in the `encrypt_server` directory, and the client-side configuration file is stored in the `encrypt_client` directory.
## Start Encryption Service
CPU Service
```
python -m paddle_serving_server.serve --model encrypt_server/ --port 9300 --use_encryption_model
```
GPU Service
```
python -m paddle_serving_server_gpu.serve --model encrypt_server/ --port 9300 --use_encryption_model --gpu_ids 0
```
## Prediction
```
python test_client.py uci_housing_client/serving_client_conf.prototxt
```
# Encryption Model Prediction
(简体中文|[English](README.md))
## Get Origin Model
The example uses the model file of the fit_a_line example as the plaintext (origin) model
```
sh get_data.sh
```
## Encrypt Model
```
python encrypt.py
```
The key is stored in the `key` file; the encrypted model file and server-side configuration file are stored in the `encrypt_server` directory, and the client-side configuration file is stored in the `encrypt_client` directory.
## Start Encryption Service
CPU Service
```
python -m paddle_serving_server.serve --model encrypt_server/ --port 9300 --use_encryption_model
```
GPU Service
```
python -m paddle_serving_server_gpu.serve --model encrypt_server/ --port 9300 --use_encryption_model --gpu_ids 0
```
## Prediction
```
python test_client.py uci_housing_client/serving_client_conf.prototxt
```
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client.io import inference_model_to_serving
def serving_encryption():
inference_model_to_serving(
dirname="./uci_housing_model",
serving_server="encrypt_server",
serving_client="encrypt_client",
encryption=True)
if __name__ == "__main__":
serving_encryption()
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing_example/encrypt.tar.gz
tar -xzf encrypt.tar.gz
cp -rvf ../fit_a_line/uci_housing_model .
cp -rvf ../fit_a_line/uci_housing_client .
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_client import Client
import sys
client = Client()
client.load_client_config(sys.argv[1])
client.use_key("./key")
client.connect(["127.0.0.1:9300"], encryption=True)
import paddle
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.test(), buf_size=500),
batch_size=1)
for data in test_reader():
fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"])
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
@@ -3,6 +3,7 @@
worker_num: 1
#http port; rpc_port and http_port must not both be empty. When rpc_port is available and http_port is empty, http_port is not generated automatically
rpc_port: 9998
http_port: 18082
dag:
@@ -20,7 +21,7 @@ op:
model_config: uci_housing_model
#Compute device IDs: when devices is "" or unset, prediction runs on CPU; when devices is "0" or "0,1,2", prediction runs on the listed GPU cards
devices: "0" # "0,1"
devices: "" # "0,1"
#Client type: brpc, grpc, or local_predictor. local_predictor does not start the Serving service; prediction runs in-process
client_type: local_predictor
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_server.web_service import WebService, Op
import logging
import numpy as np
from numpy import array
import sys
import base64
_LOGGER = logging.getLogger()
np.set_printoptions(threshold=sys.maxsize)
class UciOp(Op):
def init_op(self):
self.separator = ","
def preprocess(self, input_dicts, data_id, log_id):
"""
Differs from web_server.py:
the Java client input type is INDArray, while the RESTful request input is a list.
This function simply reshapes the input to the specified shape.
"""
(_, input_dict), = input_dicts.items()
_LOGGER.error("UciOp::preprocess >>> log_id:{}, input:{}".format(
log_id, input_dict))
proc_dict = {}
x_value = input_dict["x"]
input_dict["x"] = x_value.reshape(1,13)
return input_dict, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
_LOGGER.info("UciOp::postprocess >>> log_id:{}, fetch_dict:{}".format(
log_id, fetch_dict))
fetch_dict["price"] = str(fetch_dict["price"][0][0])
return fetch_dict, None, ""
class UciService(WebService):
def get_pipeline_response(self, read_op):
uci_op = UciOp(name="uci", input_ops=[read_op])
return uci_op
uci_service = UciService(name="uci")
uci_service.prepare_pipeline_config("config.yml")
uci_service.run_service()
@@ -21,6 +21,7 @@ import contextlib
from contextlib import closing
import multiprocessing
import yaml
import io
from .proto import pipeline_service_pb2_grpc, pipeline_service_pb2
from . import operator
@@ -333,7 +334,7 @@ class ServerYamlConfChecker(object):
raise SystemExit("Failed to prepare_server: only one of yml_file"
" or yml_dict can be selected as the parameter.")
if yml_file is not None:
with open(yml_file, encoding='utf-8') as f:
with io.open(yml_file, encoding='utf-8') as f:
conf = yaml.load(f.read())
elif yml_dict is not None:
conf = yml_dict