diff --git a/java/examples/src/main/java/PaddleServingClientExample.java b/java/examples/src/main/java/PaddleServingClientExample.java index 1ad650c2d6fecdc285c84df698d54b0249c4b303..476130c02e67ac155c91f10062739cf775de836f 100644 --- a/java/examples/src/main/java/PaddleServingClientExample.java +++ b/java/examples/src/main/java/PaddleServingClientExample.java @@ -1,11 +1,189 @@ -import io.paddle.serving.client.Client; -/** - * Hello world! - * - */ +import io.paddle.serving.client.*; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.iter.NdIndexIterator; +import org.nd4j.linalg.factory.Nd4j; +import java.util.*; + public class PaddleServingClientExample { - public static void main( String[] args ) { + boolean fit_a_line() { + float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, + 0.0582f, -0.0727f, -0.1583f, -0.0584f, + 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; + INDArray npdata = Nd4j.createFromArray(data); + HashMap feed_data + = new HashMap() {{ + put("x", npdata); + }}; + List fetch = Arrays.asList("price"); + + Client client = new Client(); + List endpoints = Arrays.asList("localhost:9393"); + boolean succ = client.connect(endpoints); + if (succ != true) { + System.out.println("connect failed."); + return false; + } + + Map fetch_map = client.predict(feed_data, fetch); + for (Map.Entry e : fetch_map.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); + } + return true; + } + + boolean batch_predict() { + float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, + 0.0582f, -0.0727f, -0.1583f, -0.0584f, + 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; + INDArray npdata = Nd4j.createFromArray(data); + HashMap feed_data + = new HashMap() {{ + put("x", npdata); + }}; + List> feed_batch + = new ArrayList>() {{ + add(feed_data); + add(feed_data); + }}; + List fetch = Arrays.asList("price"); + Client client = new Client(); - System.out.println( "Hello World!" ); + List endpoints = Arrays.asList("localhost:9393"); + boolean succ = client.connect(endpoints); + if (succ != true) { + System.out.println("connect failed."); + return false; + } + + Map fetch_map = client.predict(feed_batch, fetch); + for (Map.Entry e : fetch_map.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); + } + return true; + } + + boolean asyn_predict() { + float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, + 0.0582f, -0.0727f, -0.1583f, -0.0584f, + 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; + INDArray npdata = Nd4j.createFromArray(data); + HashMap feed_data + = new HashMap() {{ + put("x", npdata); + }}; + List fetch = Arrays.asList("price"); + + Client client = new Client(); + List endpoints = Arrays.asList("localhost:9393"); + boolean succ = client.connect(endpoints); + if (succ != true) { + System.out.println("connect failed."); + return false; + } + + PredictFuture future = client.asyn_predict(feed_data, fetch); + Map fetch_map = future.get(); + if (fetch_map == null) { + System.out.println("Get future reslut failed"); + return false; + } + + for (Map.Entry e : fetch_map.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); + } + return true; + } + + boolean model_ensemble() { + long[] data = {8, 233, 52, 601}; + INDArray npdata = Nd4j.createFromArray(data); + HashMap feed_data + = new HashMap() {{ + put("words", npdata); + }}; + List fetch = Arrays.asList("prediction"); + + Client client = new Client(); + List endpoints = Arrays.asList("localhost:9393"); + boolean succ = client.connect(endpoints); + if (succ != true) { + System.out.println("connect failed."); + return false; + } + + Map> fetch_map + = client.ensemble_predict(feed_data, fetch); + for (Map.Entry> entry : fetch_map.entrySet()) { + System.out.println("Model = " + entry.getKey()); + HashMap tt = entry.getValue(); + for (Map.Entry e : tt.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); + } + } + return true; + } + + boolean bert() { + float[] input_mask = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + long[] position_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + long[] input_ids = {101, 6843, 3241, 749, 8024, 7662, 2533, 1391, 2533, 2523, 7676, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + long[] segment_ids = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + HashMap feed_data + = new HashMap() {{ + put("input_mask", Nd4j.createFromArray(input_mask)); + put("position_ids", Nd4j.createFromArray(position_ids)); + put("input_ids", Nd4j.createFromArray(input_ids)); + put("segment_ids", Nd4j.createFromArray(segment_ids)); + }}; + List fetch = Arrays.asList("pooled_output"); + + Client client = new Client(); + List endpoints = Arrays.asList("localhost:9393"); + boolean succ = client.connect(endpoints); + if (succ != true) { + System.out.println("connect failed."); + return false; + } + + Map> fetch_map + = client.ensemble_predict(feed_data, fetch); + for (Map.Entry> entry : fetch_map.entrySet()) { + System.out.println("Model = " + entry.getKey()); + HashMap tt = entry.getValue(); + for (Map.Entry e : tt.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); + } + } + return true; + } + + public static void main( String[] args ) { + // DL4J(Deep Learning for Java)Document: + // https://www.bookstack.cn/read/deeplearning4j/bcb48e8eeb38b0c6.md + PaddleServingClientExample e = new PaddleServingClientExample(); + boolean succ = false; + + for (String arg : args) { + System.out.format("[Example] %s\n", arg); + if ("fit_a_line".equals(arg)) { + succ = e.fit_a_line(); + } else if ("bert".equals(arg)) { + succ = e.bert(); + } else if ("model_ensemble".equals(arg)) { + succ = e.model_ensemble(); + } else if ("asyn_predict".equals(arg)) { + succ = e.asyn_predict(); + } else if ("batch_predict".equals(arg)) { + succ = e.batch_predict(); + } else { + System.out.format("%s not match: java -cp PaddleServingClientExample .\n", arg); + } + } + + if (succ == true) { + System.out.println("[Example] succ."); + } else { + System.out.println("[Example] fail."); + } } } diff --git a/java/src/main/java/io/paddle/serving/client/Client.java b/java/src/main/java/io/paddle/serving/client/Client.java index ad950a38169cd71f04af04bd9e32729e708cd8eb..c267563cd095a65b8c5488b257d0876bb1a26ab1 100644 --- a/java/src/main/java/io/paddle/serving/client/Client.java +++ b/java/src/main/java/io/paddle/serving/client/Client.java @@ -7,8 +7,6 @@ import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.StatusRuntimeException; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import org.nd4j.linalg.api.ndarray.INDArray; @@ -17,6 +15,7 @@ import org.nd4j.linalg.factory.Nd4j; import io.paddle.serving.grpc.*; import io.paddle.serving.configure.*; +import io.paddle.serving.client.PredictFuture; public class Client { private ManagedChannel channel_; @@ -84,7 +83,7 @@ public class Client { GetClientConfigResponse resp; try { resp = blockingStub_.getClientConfig(get_client_config_req); - } catch (StatusRuntimeException e) { + } catch (Exception e) { System.out.format("Get Client config failed: %s\n", e.toString()); return false; } @@ -298,10 +297,10 @@ public class Client { return ensemble_predict(feed, fetch, false); } - public PredictFuture async_predict( + public PredictFuture asyn_predict( HashMap feed, Iterable fetch) { - return async_predict(feed, fetch, false); + return asyn_predict(feed, fetch, false); } public Map predict( @@ -324,14 +323,14 @@ public class Client { return ensemble_predict(feed_batch, fetch, need_variant_tag); } - public PredictFuture async_predict( + public PredictFuture asyn_predict( HashMap feed, Iterable fetch, Boolean need_variant_tag) { List> feed_batch = new ArrayList>(); feed_batch.add(feed); - return async_predict(feed_batch, fetch, need_variant_tag); + return asyn_predict(feed_batch, fetch, need_variant_tag); } public Map predict( @@ -346,10 +345,10 @@ public class Client { return ensemble_predict(feed_batch, fetch, false); } - public PredictFuture async_predict( + public PredictFuture asyn_predict( List> feed_batch, Iterable fetch) { - return async_predict(feed_batch, fetch, false); + return asyn_predict(feed_batch, fetch, false); } public Map predict( @@ -390,7 +389,7 @@ public class Client { } } - public PredictFuture async_predict( + public PredictFuture asyn_predict( List> feed_batch, Iterable fetch, Boolean need_variant_tag) { @@ -405,36 +404,20 @@ public class Client { return predict_future; } - public static void main( String[] args ) { // DL4J(Deep Learning for Java)Document: // https://www.bookstack.cn/read/deeplearning4j/bcb48e8eeb38b0c6.md - //float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, - // 0.0582f, -0.0727f, -0.1583f, -0.0584f, - // 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; - //INDArray npdata = Nd4j.createFromArray(data); - //HashMap feed_data - // = new HashMap() {{ - // put("x", npdata); - //}}; - //List fetch = Arrays.asList("price"); - - //Map fetch_map = client.predict(feed_data, fetch); - //for (Map.Entry e : fetch_map.entrySet()) { - // System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); - //} - - - long[] data = {8, 233, 52, 601}; + float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, + 0.0582f, -0.0727f, -0.1583f, -0.0584f, + 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; INDArray npdata = Nd4j.createFromArray(data); - //System.out.println(npdata); HashMap feed_data = new HashMap() {{ - put("words", npdata); + put("x", npdata); }}; - List fetch = Arrays.asList("prediction"); - + List fetch = Arrays.asList("price"); + Client client = new Client(); List endpoints = Arrays.asList("localhost:9393"); boolean succ = client.connect(endpoints); @@ -443,59 +426,9 @@ public class Client { return; } - Map> fetch_map - = client.ensemble_predict(feed_data, fetch); - for (Map.Entry> entry : fetch_map.entrySet()) { - System.out.println("Model = " + entry.getKey()); - HashMap tt = entry.getValue(); - for (Map.Entry e : tt.entrySet()) { - System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); - } - } - } - -} - -class PredictFuture { - private ListenableFuture callFuture_; - private Function>> callBackFunc_; - - PredictFuture(ListenableFuture call_future, - Function>> call_back_func) { - callFuture_ = call_future; - callBackFunc_ = call_back_func; - } - - public Map get() throws Exception { - InferenceResponse resp = null; - try { - resp = callFuture_.get(); - } catch (Exception e) { - System.out.format("grpc failed: %s\n", e.toString()); - return null; - } - Map> ensemble_result - = callBackFunc_.apply(resp); - List>> list - = new ArrayList>>( - ensemble_result.entrySet()); - if (list.size() != 1) { - System.out.format("grpc failed: please use get_ensemble impl.\n"); - return null; - } - return list.get(0).getValue(); - } - - public Map> ensemble_get() throws Exception { - InferenceResponse resp = null; - try { - resp = callFuture_.get(); - } catch (Exception e) { - System.out.format("grpc failed: %s\n", e.toString()); - return null; + Map fetch_map = client.predict(feed_data, fetch); + for (Map.Entry e : fetch_map.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); } - return callBackFunc_.apply(resp); } } diff --git a/java/src/main/java/io/paddle/serving/client/PredictFuture.java b/java/src/main/java/io/paddle/serving/client/PredictFuture.java new file mode 100644 index 0000000000000000000000000000000000000000..960e6d2b50c7636bb8ae3c59526fa911c97711f3 --- /dev/null +++ b/java/src/main/java/io/paddle/serving/client/PredictFuture.java @@ -0,0 +1,54 @@ +package io.paddle.serving.client; + +import java.util.*; +import java.util.function.Function; +import io.grpc.StatusRuntimeException; +import com.google.common.util.concurrent.ListenableFuture; +import org.nd4j.linalg.api.ndarray.INDArray; + +import io.paddle.serving.client.Client; +import io.paddle.serving.grpc.*; + +public class PredictFuture { + private ListenableFuture callFuture_; + private Function>> callBackFunc_; + + PredictFuture(ListenableFuture call_future, + Function>> call_back_func) { + callFuture_ = call_future; + callBackFunc_ = call_back_func; + } + + public Map get() { + InferenceResponse resp = null; + try { + resp = callFuture_.get(); + } catch (Exception e) { + System.out.format("grpc failed: %s\n", e.toString()); + return null; + } + Map> ensemble_result + = callBackFunc_.apply(resp); + List>> list + = new ArrayList>>( + ensemble_result.entrySet()); + if (list.size() != 1) { + System.out.format("grpc failed: please use get_ensemble impl.\n"); + return null; + } + return list.get(0).getValue(); + } + + public Map> ensemble_get() { + InferenceResponse resp = null; + try { + resp = callFuture_.get(); + } catch (Exception e) { + System.out.format("grpc failed: %s\n", e.toString()); + return null; + } + return callBackFunc_.apply(resp); + } +} diff --git a/java/src/main/resources/log4j2.xml b/java/src/main/resources/log4j2.xml new file mode 100644 index 0000000000000000000000000000000000000000..e13b79d3f92acca50cafde874b501513dbdb292f --- /dev/null +++ b/java/src/main/resources/log4j2.xml @@ -0,0 +1,13 @@ + + + + + + + + + + + + +