提交 d7a468dd 编写于 作者: H HexToString

add java mutil-thread client and add annotation

上级 2fd5551c
......@@ -91,7 +91,7 @@ as for input data type = INDArray,take uci_housing_model as an example,the s
```
cd ../../python/examples/pipeline/simple_web_service
sh get_data.sh
python web_service_forJava.py &>log.txt &
python web_service_java.py &>log.txt &
```
Client prediction(Synchronous)
......
......@@ -93,7 +93,7 @@ java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar Pipeli
```
cd ../../python/examples/pipeline/simple_web_service
sh get_data.sh
python web_service_forJava.py &>log.txt &
python web_service_java.py &>log.txt &
```
客户端预测(同步)
......
......@@ -10,8 +10,18 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
/**
* this class give an example for using the client to predict(grpc)
* StaticPipelineClient.client supports mutil-thread.
* By setting StaticPipelineClient.client properties,you can change the Maximum concurrency
* Do not need to generate multiple instances of client,Use the StaticPipelineClient.client or SingleTon instead.
* @author HexToString
*/
public class PipelineClientExample {
/**
* This method gives an example of synchronous prediction whose input type is string.
*/
boolean string_imdb_predict() {
HashMap<String, String> feed_data
= new HashMap<String, String>() {{
......@@ -20,15 +30,14 @@ public class PipelineClientExample {
System.out.println(feed_data);
List<String> fetch = Arrays.asList("prediction");
System.out.println(fetch);
PipelineClient client = new PipelineClient();
String target = "172.17.0.2:18070";
boolean succ = client.connect(target);
if (succ != true) {
System.out.println("connect failed.");
return false;
if (StaticPipelineClient.succ != true) {
if(!StaticPipelineClient.initClient("172.17.0.2","18070")){
System.out.println("connect failed.");
return false;
}
}
HashMap<String,String> result = client.predict(feed_data, fetch,false,0);
HashMap<String,String> result = StaticPipelineClient.client.predict(feed_data, fetch,false,0);
if (result == null) {
return false;
}
......@@ -36,6 +45,9 @@ public class PipelineClientExample {
return true;
}
/**
* This method gives an example of asynchronous prediction whose input type is string.
*/
boolean asyn_predict() {
HashMap<String, String> feed_data
= new HashMap<String, String>() {{
......@@ -44,14 +56,13 @@ public class PipelineClientExample {
System.out.println(feed_data);
List<String> fetch = Arrays.asList("prediction");
System.out.println(fetch);
PipelineClient client = new PipelineClient();
String target = "172.17.0.2:18070";
boolean succ = client.connect(target);
if (succ != true) {
System.out.println("connect failed.");
return false;
if (StaticPipelineClient.succ != true) {
if(!StaticPipelineClient.initClient("172.17.0.2","18070")){
System.out.println("connect failed.");
return false;
}
}
PipelineFuture future = client.asyn_predict(feed_data, fetch,false,0);
PipelineFuture future = StaticPipelineClient.client.asyn_pr::qedict(feed_data, fetch,false,0);
HashMap<String,String> result = future.get();
if (result == null) {
return false;
......@@ -60,24 +71,28 @@ public class PipelineClientExample {
return true;
}
/**
* This method gives an example of synchronous prediction whose input type is Array or list or matrix.
* use Nd4j.createFromArray method to convert Array to INDArray.
* use convertINDArrayToString method to convert INDArray to specified String type(for python Numpy eval method).
*/
boolean indarray_predict() {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, 0.0582f, -0.0727f, -0.1583f, -0.0584f, 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
HashMap<String, String> feed_data
= new HashMap<String, String>() {{
put("x", "array("+npdata.toString()+")");
put("x", convertINDArrayToString(npdata));
}};
List<String> fetch = Arrays.asList("prediction");
PipelineClient client = new PipelineClient();
String target = "172.17.0.2:9998";
boolean succ = client.connect(target);
if (succ != true) {
System.out.println("connect failed.");
return false;
if (StaticPipelineClient.succ != true) {
if(!StaticPipelineClient.initClient("172.17.0.2","9998")){
System.out.println("connect failed.");
return false;
}
}
HashMap<String,String> result = client.predict(feed_data, fetch,false,0);
HashMap<String,String> result = StaticPipelineClient.client.predict(feed_data, fetch,false,0);
if (result == null) {
return false;
}
......@@ -85,9 +100,21 @@ public class PipelineClientExample {
return true;
}
/**
* This method convert INDArray to specified String type.
* @param npdata INDArray type(The input data).
* @return String (specified String type for python Numpy eval method).
*/
String convertINDArrayToString(INDArray npdata){
return "array("+npdata.toString()+")";
}
/**
* This method is entry function.
* @param args String[] type(Command line parameters)
*/
public static void main( String[] args ) {
PipelineClientExample e = new PipelineClientExample();
boolean succ = false;
if (args.length < 1) {
......@@ -95,6 +122,7 @@ public class PipelineClientExample {
System.out.println("<test-type>: fit_a_line bert model_ensemble asyn_predict batch_predict cube_local cube_quant yolov4");
return;
}
String testType = args[0];
System.out.format("[Example] %s\n", testType);
if ("string_imdb_predict".equals(testType)) {
......@@ -114,12 +142,6 @@ public class PipelineClientExample {
System.out.println("[Example] fail.");
}
}
}
//if list or array or matrix,please Convert to INDArray,for example:
//INDArray npdata = Nd4j.createFromArray(data);
//INDArray Convert to String,for example:
//string value = "array("+npdata.toString()+")"
import io.paddle.serving.pipelineclient.*;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.datavec.image.loader.NativeImageLoader;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
/**
* static resource management class
* @author HexToString
*/
public class StaticPipelineClient {
/**
* Static Variable PipelineClient
*/
public static PipelineClient client = new PipelineClient();
/**
* the sign of connect status
*/
public static boolean succ = false;
/**
* This method returns the sign of connect status.
* @param strIp String type(The server ipv4) such as "192.168.10.10".
* @param strPort String type(The server port) such as "8891".
* @return boolean (the sign of connect status).
*/
public static boolean initClient(String strIp,String strPort){
String target = strIp+ ":"+ strPort;//"172.17.0.2:18070";
System.out.println("initial connect.");
if(succ){
System.out.println("already connect.");
return true;
}
succ = clieint.connect(target);
if (succ != true) {
System.out.println("connect failed.");
return false;
}
return true;
}
}
......@@ -19,6 +19,11 @@ import org.nd4j.linalg.factory.Nd4j;
import io.paddle.serving.pipelineproto.*;
import io.paddle.serving.pipelineclient.PipelineFuture;
/**
* PipelineClient class defination
* @author HexToString
*/
public class PipelineClient {
private ManagedChannel channel_;
private PipelineServiceGrpc.PipelineServiceBlockingStub blockingStub_;
......@@ -38,6 +43,11 @@ public class PipelineClient {
_profile_key = "pipeline.profile";
}
/**
* This method returns the sign of connect status.
* @param target String type(The server ipv4 and port) such as "192.168.10.10:8891".
* @return boolean (the sign of connect status).
*/
public boolean connect(String target) {
try {
String[] temp = target.split(":");
......@@ -56,6 +66,13 @@ public class PipelineClient {
return true;
}
/**
* This method returns the Packaged Request.
* @param feed_dict HashMap<String, String>(input data).
* @param profile boolean(profile sign).
* @param logid int
* @return Request (the grpc protobuf Request).
*/
private Request _packInferenceRequest(
HashMap<String, String> feed_dict,
boolean profile,
......@@ -80,11 +97,20 @@ public class PipelineClient {
return req_builder.build();
}
/**
* This method returns the HashMap which is unpackaged from Response.
* @param resp Response(the grpc protobuf Response).
* @return HashMap<String,String> (the output).
*/
private HashMap<String,String> _unpackResponse(Response resp) throws IllegalArgumentException{
return PipelineClient._staitcUnpackResponse(resp);
}
/**
* This static method returns the HashMap which is unpackaged from Response.
* @param resp Response(the grpc protobuf Response).
* @return HashMap<String,String> (the output).
*/
private static HashMap<String,String> _staitcUnpackResponse(Response resp) {
HashMap<String,String> ret_Map = new HashMap<String,String>();
int err_no = resp.getErrNo();
......@@ -99,6 +125,14 @@ public class PipelineClient {
return ret_Map;
}
/**
* The synchronous prediction method.
* @param feed_batch HashMap<String, String>(input data).
* @param fetch Iterable<String>(the output key list).
* @param profile boolean(profile sign).
* @param logid int
* @return HashMap<String,String> (the output).
*/
public HashMap<String,String> predict(
HashMap<String, String> feed_batch,
Iterable<String> fetch,
......@@ -115,12 +149,18 @@ public class PipelineClient {
}
}
/**
* The synchronous prediction overload function.
*/
public HashMap<String,String> predict(
HashMap<String, String> feed_batch,
Iterable<String> fetch) {
return predict(feed_batch,fetch,false,0);
}
/**
* The synchronous prediction overload function.
*/
public HashMap<String,String> predict(
HashMap<String, String> feed_batch,
Iterable<String> fetch,
......@@ -128,6 +168,9 @@ public class PipelineClient {
return predict(feed_batch,fetch,profile,0);
}
/**
* The synchronous prediction overload function.
*/
public HashMap<String,String> predict(
HashMap<String, String> feed_batch,
Iterable<String> fetch,
......@@ -135,6 +178,14 @@ public class PipelineClient {
return predict(feed_batch,fetch,false,logid);
}
/**
* The asynchronous prediction method.use future.get() to get the result.
* @param feed_batch HashMap<String, String>(input data).
* @param fetch Iterable<String>(the output key list).
* @param profile boolean(profile sign).
* @param logid int
* @return PipelineFuture(the output future).
*/
public PipelineFuture asyn_predict(
HashMap<String, String> feed_batch,
Iterable<String> fetch,
......@@ -151,12 +202,18 @@ public class PipelineClient {
return predict_future;
}
/**
* The asynchronous prediction overload function.
*/
public PipelineFuture asyn_predict(
HashMap<String, String> feed_batch,
Iterable<String> fetch) {
return asyn_predict(feed_batch,fetch,false,0);
}
/**
* The asynchronous prediction overload function.
*/
public PipelineFuture asyn_predict(
HashMap<String, String> feed_batch,
Iterable<String> fetch,
......@@ -164,6 +221,9 @@ public class PipelineClient {
return asyn_predict(feed_batch,fetch,profile,0);
}
/**
* The asynchronous prediction overload function.
*/
public PipelineFuture asyn_predict(
HashMap<String, String> feed_batch,
Iterable<String> fetch,
......
......@@ -9,6 +9,10 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import io.paddle.serving.pipelineclient.PipelineClient;
import io.paddle.serving.pipelineproto.*;
/**
* PipelineFuture class is for asynchronous prediction
* @author HexToString
*/
public class PipelineFuture {
private ListenableFuture<Response> callFuture_;
private Function<Response,
......@@ -21,6 +25,9 @@ public class PipelineFuture {
callBackFunc_ = call_back_func;
}
/**
* use this method to get the result of asynchronous prediction.
*/
public HashMap<String,String> get() {
Response resp = null;
try {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册