提交 d7a468dd 编写于 作者: H HexToString

add java mutil-thread client and add annotation

上级 2fd5551c
...@@ -91,7 +91,7 @@ as for input data type = INDArray,take uci_housing_model as an example,the s ...@@ -91,7 +91,7 @@ as for input data type = INDArray,take uci_housing_model as an example,the s
``` ```
cd ../../python/examples/pipeline/simple_web_service cd ../../python/examples/pipeline/simple_web_service
sh get_data.sh sh get_data.sh
python web_service_forJava.py &>log.txt & python web_service_java.py &>log.txt &
``` ```
Client prediction(Synchronous) Client prediction(Synchronous)
......
...@@ -93,7 +93,7 @@ java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar Pipeli ...@@ -93,7 +93,7 @@ java -cp paddle-serving-sdk-java-examples-0.0.1-jar-with-dependencies.jar Pipeli
``` ```
cd ../../python/examples/pipeline/simple_web_service cd ../../python/examples/pipeline/simple_web_service
sh get_data.sh sh get_data.sh
python web_service_forJava.py &>log.txt & python web_service_java.py &>log.txt &
``` ```
客户端预测(同步) 客户端预测(同步)
......
...@@ -10,8 +10,18 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp; ...@@ -10,8 +10,18 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.*; import java.util.*;
/**
* this class give an example for using the client to predict(grpc)
* StaticPipelineClient.client supports mutil-thread.
* By setting StaticPipelineClient.client properties,you can change the Maximum concurrency
* Do not need to generate multiple instances of client,Use the StaticPipelineClient.client or SingleTon instead.
* @author HexToString
*/
public class PipelineClientExample { public class PipelineClientExample {
/**
* This method gives an example of synchronous prediction whose input type is string.
*/
boolean string_imdb_predict() { boolean string_imdb_predict() {
HashMap<String, String> feed_data HashMap<String, String> feed_data
= new HashMap<String, String>() {{ = new HashMap<String, String>() {{
...@@ -20,15 +30,14 @@ public class PipelineClientExample { ...@@ -20,15 +30,14 @@ public class PipelineClientExample {
System.out.println(feed_data); System.out.println(feed_data);
List<String> fetch = Arrays.asList("prediction"); List<String> fetch = Arrays.asList("prediction");
System.out.println(fetch); System.out.println(fetch);
PipelineClient client = new PipelineClient();
String target = "172.17.0.2:18070"; if (StaticPipelineClient.succ != true) {
boolean succ = client.connect(target); if(!StaticPipelineClient.initClient("172.17.0.2","18070")){
if (succ != true) { System.out.println("connect failed.");
System.out.println("connect failed."); return false;
return false; }
} }
HashMap<String,String> result = StaticPipelineClient.client.predict(feed_data, fetch,false,0);
HashMap<String,String> result = client.predict(feed_data, fetch,false,0);
if (result == null) { if (result == null) {
return false; return false;
} }
...@@ -36,6 +45,9 @@ public class PipelineClientExample { ...@@ -36,6 +45,9 @@ public class PipelineClientExample {
return true; return true;
} }
/**
* This method gives an example of asynchronous prediction whose input type is string.
*/
boolean asyn_predict() { boolean asyn_predict() {
HashMap<String, String> feed_data HashMap<String, String> feed_data
= new HashMap<String, String>() {{ = new HashMap<String, String>() {{
...@@ -44,14 +56,13 @@ public class PipelineClientExample { ...@@ -44,14 +56,13 @@ public class PipelineClientExample {
System.out.println(feed_data); System.out.println(feed_data);
List<String> fetch = Arrays.asList("prediction"); List<String> fetch = Arrays.asList("prediction");
System.out.println(fetch); System.out.println(fetch);
PipelineClient client = new PipelineClient(); if (StaticPipelineClient.succ != true) {
String target = "172.17.0.2:18070"; if(!StaticPipelineClient.initClient("172.17.0.2","18070")){
boolean succ = client.connect(target); System.out.println("connect failed.");
if (succ != true) { return false;
System.out.println("connect failed."); }
return false;
} }
PipelineFuture future = client.asyn_predict(feed_data, fetch,false,0); PipelineFuture future = StaticPipelineClient.client.asyn_pr::qedict(feed_data, fetch,false,0);
HashMap<String,String> result = future.get(); HashMap<String,String> result = future.get();
if (result == null) { if (result == null) {
return false; return false;
...@@ -60,24 +71,28 @@ public class PipelineClientExample { ...@@ -60,24 +71,28 @@ public class PipelineClientExample {
return true; return true;
} }
/**
* This method gives an example of synchronous prediction whose input type is Array or list or matrix.
* use Nd4j.createFromArray method to convert Array to INDArray.
* use convertINDArrayToString method to convert INDArray to specified String type(for python Numpy eval method).
*/
boolean indarray_predict() { boolean indarray_predict() {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, 0.0582f, -0.0727f, -0.1583f, -0.0584f, 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, 0.0582f, -0.0727f, -0.1583f, -0.0584f, 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data); INDArray npdata = Nd4j.createFromArray(data);
HashMap<String, String> feed_data HashMap<String, String> feed_data
= new HashMap<String, String>() {{ = new HashMap<String, String>() {{
put("x", "array("+npdata.toString()+")"); put("x", convertINDArrayToString(npdata));
}}; }};
List<String> fetch = Arrays.asList("prediction"); List<String> fetch = Arrays.asList("prediction");
PipelineClient client = new PipelineClient(); if (StaticPipelineClient.succ != true) {
String target = "172.17.0.2:9998"; if(!StaticPipelineClient.initClient("172.17.0.2","9998")){
boolean succ = client.connect(target); System.out.println("connect failed.");
if (succ != true) { return false;
System.out.println("connect failed."); }
return false;
} }
HashMap<String,String> result = client.predict(feed_data, fetch,false,0); HashMap<String,String> result = StaticPipelineClient.client.predict(feed_data, fetch,false,0);
if (result == null) { if (result == null) {
return false; return false;
} }
...@@ -85,9 +100,21 @@ public class PipelineClientExample { ...@@ -85,9 +100,21 @@ public class PipelineClientExample {
return true; return true;
} }
/**
* This method convert INDArray to specified String type.
* @param npdata INDArray type(The input data).
* @return String (specified String type for python Numpy eval method).
*/
String convertINDArrayToString(INDArray npdata){
return "array("+npdata.toString()+")";
}
/**
* This method is entry function.
* @param args String[] type(Command line parameters)
*/
public static void main( String[] args ) { public static void main( String[] args ) {
PipelineClientExample e = new PipelineClientExample(); PipelineClientExample e = new PipelineClientExample();
boolean succ = false; boolean succ = false;
if (args.length < 1) { if (args.length < 1) {
...@@ -95,6 +122,7 @@ public class PipelineClientExample { ...@@ -95,6 +122,7 @@ public class PipelineClientExample {
System.out.println("<test-type>: fit_a_line bert model_ensemble asyn_predict batch_predict cube_local cube_quant yolov4"); System.out.println("<test-type>: fit_a_line bert model_ensemble asyn_predict batch_predict cube_local cube_quant yolov4");
return; return;
} }
String testType = args[0]; String testType = args[0];
System.out.format("[Example] %s\n", testType); System.out.format("[Example] %s\n", testType);
if ("string_imdb_predict".equals(testType)) { if ("string_imdb_predict".equals(testType)) {
...@@ -114,12 +142,6 @@ public class PipelineClientExample { ...@@ -114,12 +142,6 @@ public class PipelineClientExample {
System.out.println("[Example] fail."); System.out.println("[Example] fail.");
} }
} }
} }
//if list or array or matrix,please Convert to INDArray,for example:
//INDArray npdata = Nd4j.createFromArray(data);
//INDArray Convert to String,for example:
//string value = "array("+npdata.toString()+")"
import io.paddle.serving.pipelineclient.*;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.datavec.image.loader.NativeImageLoader;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
/**
* static resource management class
* @author HexToString
*/
public class StaticPipelineClient {
/**
* Static Variable PipelineClient
*/
public static PipelineClient client = new PipelineClient();
/**
* the sign of connect status
*/
public static boolean succ = false;
/**
* This method returns the sign of connect status.
* @param strIp String type(The server ipv4) such as "192.168.10.10".
* @param strPort String type(The server port) such as "8891".
* @return boolean (the sign of connect status).
*/
public static boolean initClient(String strIp,String strPort){
String target = strIp+ ":"+ strPort;//"172.17.0.2:18070";
System.out.println("initial connect.");
if(succ){
System.out.println("already connect.");
return true;
}
succ = clieint.connect(target);
if (succ != true) {
System.out.println("connect failed.");
return false;
}
return true;
}
}
...@@ -19,6 +19,11 @@ import org.nd4j.linalg.factory.Nd4j; ...@@ -19,6 +19,11 @@ import org.nd4j.linalg.factory.Nd4j;
import io.paddle.serving.pipelineproto.*; import io.paddle.serving.pipelineproto.*;
import io.paddle.serving.pipelineclient.PipelineFuture; import io.paddle.serving.pipelineclient.PipelineFuture;
/**
* PipelineClient class defination
* @author HexToString
*/
public class PipelineClient { public class PipelineClient {
private ManagedChannel channel_; private ManagedChannel channel_;
private PipelineServiceGrpc.PipelineServiceBlockingStub blockingStub_; private PipelineServiceGrpc.PipelineServiceBlockingStub blockingStub_;
...@@ -38,6 +43,11 @@ public class PipelineClient { ...@@ -38,6 +43,11 @@ public class PipelineClient {
_profile_key = "pipeline.profile"; _profile_key = "pipeline.profile";
} }
/**
* This method returns the sign of connect status.
* @param target String type(The server ipv4 and port) such as "192.168.10.10:8891".
* @return boolean (the sign of connect status).
*/
public boolean connect(String target) { public boolean connect(String target) {
try { try {
String[] temp = target.split(":"); String[] temp = target.split(":");
...@@ -56,6 +66,13 @@ public class PipelineClient { ...@@ -56,6 +66,13 @@ public class PipelineClient {
return true; return true;
} }
/**
* This method returns the Packaged Request.
* @param feed_dict HashMap<String, String>(input data).
* @param profile boolean(profile sign).
* @param logid int
* @return Request (the grpc protobuf Request).
*/
private Request _packInferenceRequest( private Request _packInferenceRequest(
HashMap<String, String> feed_dict, HashMap<String, String> feed_dict,
boolean profile, boolean profile,
...@@ -80,11 +97,20 @@ public class PipelineClient { ...@@ -80,11 +97,20 @@ public class PipelineClient {
return req_builder.build(); return req_builder.build();
} }
/**
* This method returns the HashMap which is unpackaged from Response.
* @param resp Response(the grpc protobuf Response).
* @return HashMap<String,String> (the output).
*/
private HashMap<String,String> _unpackResponse(Response resp) throws IllegalArgumentException{ private HashMap<String,String> _unpackResponse(Response resp) throws IllegalArgumentException{
return PipelineClient._staitcUnpackResponse(resp); return PipelineClient._staitcUnpackResponse(resp);
} }
/**
* This static method returns the HashMap which is unpackaged from Response.
* @param resp Response(the grpc protobuf Response).
* @return HashMap<String,String> (the output).
*/
private static HashMap<String,String> _staitcUnpackResponse(Response resp) { private static HashMap<String,String> _staitcUnpackResponse(Response resp) {
HashMap<String,String> ret_Map = new HashMap<String,String>(); HashMap<String,String> ret_Map = new HashMap<String,String>();
int err_no = resp.getErrNo(); int err_no = resp.getErrNo();
...@@ -99,6 +125,14 @@ public class PipelineClient { ...@@ -99,6 +125,14 @@ public class PipelineClient {
return ret_Map; return ret_Map;
} }
/**
* The synchronous prediction method.
* @param feed_batch HashMap<String, String>(input data).
* @param fetch Iterable<String>(the output key list).
* @param profile boolean(profile sign).
* @param logid int
* @return HashMap<String,String> (the output).
*/
public HashMap<String,String> predict( public HashMap<String,String> predict(
HashMap<String, String> feed_batch, HashMap<String, String> feed_batch,
Iterable<String> fetch, Iterable<String> fetch,
...@@ -115,12 +149,18 @@ public class PipelineClient { ...@@ -115,12 +149,18 @@ public class PipelineClient {
} }
} }
/**
* The synchronous prediction overload function.
*/
public HashMap<String,String> predict( public HashMap<String,String> predict(
HashMap<String, String> feed_batch, HashMap<String, String> feed_batch,
Iterable<String> fetch) { Iterable<String> fetch) {
return predict(feed_batch,fetch,false,0); return predict(feed_batch,fetch,false,0);
} }
/**
* The synchronous prediction overload function.
*/
public HashMap<String,String> predict( public HashMap<String,String> predict(
HashMap<String, String> feed_batch, HashMap<String, String> feed_batch,
Iterable<String> fetch, Iterable<String> fetch,
...@@ -128,6 +168,9 @@ public class PipelineClient { ...@@ -128,6 +168,9 @@ public class PipelineClient {
return predict(feed_batch,fetch,profile,0); return predict(feed_batch,fetch,profile,0);
} }
/**
* The synchronous prediction overload function.
*/
public HashMap<String,String> predict( public HashMap<String,String> predict(
HashMap<String, String> feed_batch, HashMap<String, String> feed_batch,
Iterable<String> fetch, Iterable<String> fetch,
...@@ -135,6 +178,14 @@ public class PipelineClient { ...@@ -135,6 +178,14 @@ public class PipelineClient {
return predict(feed_batch,fetch,false,logid); return predict(feed_batch,fetch,false,logid);
} }
/**
* The asynchronous prediction method.use future.get() to get the result.
* @param feed_batch HashMap<String, String>(input data).
* @param fetch Iterable<String>(the output key list).
* @param profile boolean(profile sign).
* @param logid int
* @return PipelineFuture(the output future).
*/
public PipelineFuture asyn_predict( public PipelineFuture asyn_predict(
HashMap<String, String> feed_batch, HashMap<String, String> feed_batch,
Iterable<String> fetch, Iterable<String> fetch,
...@@ -151,12 +202,18 @@ public class PipelineClient { ...@@ -151,12 +202,18 @@ public class PipelineClient {
return predict_future; return predict_future;
} }
/**
* The asynchronous prediction overload function.
*/
public PipelineFuture asyn_predict( public PipelineFuture asyn_predict(
HashMap<String, String> feed_batch, HashMap<String, String> feed_batch,
Iterable<String> fetch) { Iterable<String> fetch) {
return asyn_predict(feed_batch,fetch,false,0); return asyn_predict(feed_batch,fetch,false,0);
} }
/**
* The asynchronous prediction overload function.
*/
public PipelineFuture asyn_predict( public PipelineFuture asyn_predict(
HashMap<String, String> feed_batch, HashMap<String, String> feed_batch,
Iterable<String> fetch, Iterable<String> fetch,
...@@ -164,6 +221,9 @@ public class PipelineClient { ...@@ -164,6 +221,9 @@ public class PipelineClient {
return asyn_predict(feed_batch,fetch,profile,0); return asyn_predict(feed_batch,fetch,profile,0);
} }
/**
* The asynchronous prediction overload function.
*/
public PipelineFuture asyn_predict( public PipelineFuture asyn_predict(
HashMap<String, String> feed_batch, HashMap<String, String> feed_batch,
Iterable<String> fetch, Iterable<String> fetch,
......
...@@ -9,6 +9,10 @@ import org.nd4j.linalg.api.ndarray.INDArray; ...@@ -9,6 +9,10 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import io.paddle.serving.pipelineclient.PipelineClient; import io.paddle.serving.pipelineclient.PipelineClient;
import io.paddle.serving.pipelineproto.*; import io.paddle.serving.pipelineproto.*;
/**
* PipelineFuture class is for asynchronous prediction
* @author HexToString
*/
public class PipelineFuture { public class PipelineFuture {
private ListenableFuture<Response> callFuture_; private ListenableFuture<Response> callFuture_;
private Function<Response, private Function<Response,
...@@ -21,6 +25,9 @@ public class PipelineFuture { ...@@ -21,6 +25,9 @@ public class PipelineFuture {
callBackFunc_ = call_back_func; callBackFunc_ = call_back_func;
} }
/**
* use this method to get the result of asynchronous prediction.
*/
public HashMap<String,String> get() { public HashMap<String,String> get() {
Response resp = null; Response resp = null;
try { try {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册