提交 14ce96a2 编写于 作者: B barrierye

add some example

上级 eea93609
import io.paddle.serving.client.Client; import io.paddle.serving.client.*;
/** import org.nd4j.linalg.api.ndarray.INDArray;
* Hello world! import org.nd4j.linalg.api.iter.NdIndexIterator;
* import org.nd4j.linalg.factory.Nd4j;
*/ import java.util.*;
public class PaddleServingClientExample { public class PaddleServingClientExample {
public static void main( String[] args ) { boolean fit_a_line() {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("x", npdata);
}};
List<String> fetch = Arrays.asList("price");
Client client = new Client();
List<String> endpoints = Arrays.asList("localhost:9393");
boolean succ = client.connect(endpoints);
if (succ != true) {
System.out.println("connect failed.");
return false;
}
Map<String, INDArray> fetch_map = client.predict(feed_data, fetch);
for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
return true;
}
boolean batch_predict() {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("x", npdata);
}};
List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>() {{
add(feed_data);
add(feed_data);
}};
List<String> fetch = Arrays.asList("price");
Client client = new Client(); Client client = new Client();
System.out.println( "Hello World!" ); List<String> endpoints = Arrays.asList("localhost:9393");
boolean succ = client.connect(endpoints);
if (succ != true) {
System.out.println("connect failed.");
return false;
}
Map<String, INDArray> fetch_map = client.predict(feed_batch, fetch);
for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
return true;
}
boolean asyn_predict() {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("x", npdata);
}};
List<String> fetch = Arrays.asList("price");
Client client = new Client();
List<String> endpoints = Arrays.asList("localhost:9393");
boolean succ = client.connect(endpoints);
if (succ != true) {
System.out.println("connect failed.");
return false;
}
PredictFuture future = client.asyn_predict(feed_data, fetch);
Map<String, INDArray> fetch_map = future.get();
if (fetch_map == null) {
System.out.println("Get future reslut failed");
return false;
}
for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
return true;
}
boolean model_ensemble() {
long[] data = {8, 233, 52, 601};
INDArray npdata = Nd4j.createFromArray(data);
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("words", npdata);
}};
List<String> fetch = Arrays.asList("prediction");
Client client = new Client();
List<String> endpoints = Arrays.asList("localhost:9393");
boolean succ = client.connect(endpoints);
if (succ != true) {
System.out.println("connect failed.");
return false;
}
Map<String, HashMap<String, INDArray>> fetch_map
= client.ensemble_predict(feed_data, fetch);
for (Map.Entry<String, HashMap<String, INDArray>> entry : fetch_map.entrySet()) {
System.out.println("Model = " + entry.getKey());
HashMap<String, INDArray> tt = entry.getValue();
for (Map.Entry<String, INDArray> e : tt.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
}
return true;
}
boolean bert() {
float[] input_mask = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
long[] position_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
long[] input_ids = {101, 6843, 3241, 749, 8024, 7662, 2533, 1391, 2533, 2523, 7676, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
long[] segment_ids = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("input_mask", Nd4j.createFromArray(input_mask));
put("position_ids", Nd4j.createFromArray(position_ids));
put("input_ids", Nd4j.createFromArray(input_ids));
put("segment_ids", Nd4j.createFromArray(segment_ids));
}};
List<String> fetch = Arrays.asList("pooled_output");
Client client = new Client();
List<String> endpoints = Arrays.asList("localhost:9393");
boolean succ = client.connect(endpoints);
if (succ != true) {
System.out.println("connect failed.");
return false;
}
Map<String, HashMap<String, INDArray>> fetch_map
= client.ensemble_predict(feed_data, fetch);
for (Map.Entry<String, HashMap<String, INDArray>> entry : fetch_map.entrySet()) {
System.out.println("Model = " + entry.getKey());
HashMap<String, INDArray> tt = entry.getValue();
for (Map.Entry<String, INDArray> e : tt.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
}
return true;
}
public static void main( String[] args ) {
// DL4J(Deep Learning for Java)Document:
// https://www.bookstack.cn/read/deeplearning4j/bcb48e8eeb38b0c6.md
PaddleServingClientExample e = new PaddleServingClientExample();
boolean succ = false;
for (String arg : args) {
System.out.format("[Example] %s\n", arg);
if ("fit_a_line".equals(arg)) {
succ = e.fit_a_line();
} else if ("bert".equals(arg)) {
succ = e.bert();
} else if ("model_ensemble".equals(arg)) {
succ = e.model_ensemble();
} else if ("asyn_predict".equals(arg)) {
succ = e.asyn_predict();
} else if ("batch_predict".equals(arg)) {
succ = e.batch_predict();
} else {
System.out.format("%s not match: java -cp <jar> PaddleServingClientExample <exp>.\n", arg);
}
}
if (succ == true) {
System.out.println("[Example] succ.");
} else {
System.out.println("[Example] fail.");
}
} }
} }
...@@ -7,8 +7,6 @@ import io.grpc.ManagedChannel; ...@@ -7,8 +7,6 @@ import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder; import io.grpc.ManagedChannelBuilder;
import io.grpc.StatusRuntimeException; import io.grpc.StatusRuntimeException;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListenableFuture;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
...@@ -17,6 +15,7 @@ import org.nd4j.linalg.factory.Nd4j; ...@@ -17,6 +15,7 @@ import org.nd4j.linalg.factory.Nd4j;
import io.paddle.serving.grpc.*; import io.paddle.serving.grpc.*;
import io.paddle.serving.configure.*; import io.paddle.serving.configure.*;
import io.paddle.serving.client.PredictFuture;
public class Client { public class Client {
private ManagedChannel channel_; private ManagedChannel channel_;
...@@ -84,7 +83,7 @@ public class Client { ...@@ -84,7 +83,7 @@ public class Client {
GetClientConfigResponse resp; GetClientConfigResponse resp;
try { try {
resp = blockingStub_.getClientConfig(get_client_config_req); resp = blockingStub_.getClientConfig(get_client_config_req);
} catch (StatusRuntimeException e) { } catch (Exception e) {
System.out.format("Get Client config failed: %s\n", e.toString()); System.out.format("Get Client config failed: %s\n", e.toString());
return false; return false;
} }
...@@ -298,10 +297,10 @@ public class Client { ...@@ -298,10 +297,10 @@ public class Client {
return ensemble_predict(feed, fetch, false); return ensemble_predict(feed, fetch, false);
} }
public PredictFuture async_predict( public PredictFuture asyn_predict(
HashMap<String, INDArray> feed, HashMap<String, INDArray> feed,
Iterable<String> fetch) { Iterable<String> fetch) {
return async_predict(feed, fetch, false); return asyn_predict(feed, fetch, false);
} }
public Map<String, INDArray> predict( public Map<String, INDArray> predict(
...@@ -324,14 +323,14 @@ public class Client { ...@@ -324,14 +323,14 @@ public class Client {
return ensemble_predict(feed_batch, fetch, need_variant_tag); return ensemble_predict(feed_batch, fetch, need_variant_tag);
} }
public PredictFuture async_predict( public PredictFuture asyn_predict(
HashMap<String, INDArray> feed, HashMap<String, INDArray> feed,
Iterable<String> fetch, Iterable<String> fetch,
Boolean need_variant_tag) { Boolean need_variant_tag) {
List<HashMap<String, INDArray>> feed_batch List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>(); = new ArrayList<HashMap<String, INDArray>>();
feed_batch.add(feed); feed_batch.add(feed);
return async_predict(feed_batch, fetch, need_variant_tag); return asyn_predict(feed_batch, fetch, need_variant_tag);
} }
public Map<String, INDArray> predict( public Map<String, INDArray> predict(
...@@ -346,10 +345,10 @@ public class Client { ...@@ -346,10 +345,10 @@ public class Client {
return ensemble_predict(feed_batch, fetch, false); return ensemble_predict(feed_batch, fetch, false);
} }
public PredictFuture async_predict( public PredictFuture asyn_predict(
List<HashMap<String, INDArray>> feed_batch, List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) { Iterable<String> fetch) {
return async_predict(feed_batch, fetch, false); return asyn_predict(feed_batch, fetch, false);
} }
public Map<String, INDArray> predict( public Map<String, INDArray> predict(
...@@ -390,7 +389,7 @@ public class Client { ...@@ -390,7 +389,7 @@ public class Client {
} }
} }
public PredictFuture async_predict( public PredictFuture asyn_predict(
List<HashMap<String, INDArray>> feed_batch, List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch, Iterable<String> fetch,
Boolean need_variant_tag) { Boolean need_variant_tag) {
...@@ -405,36 +404,20 @@ public class Client { ...@@ -405,36 +404,20 @@ public class Client {
return predict_future; return predict_future;
} }
public static void main( String[] args ) { public static void main( String[] args ) {
// DL4J(Deep Learning for Java)Document: // DL4J(Deep Learning for Java)Document:
// https://www.bookstack.cn/read/deeplearning4j/bcb48e8eeb38b0c6.md // https://www.bookstack.cn/read/deeplearning4j/bcb48e8eeb38b0c6.md
//float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
// 0.0582f, -0.0727f, -0.1583f, -0.0584f, 0.0582f, -0.0727f, -0.1583f, -0.0584f,
// 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
//INDArray npdata = Nd4j.createFromArray(data);
//HashMap<String, INDArray> feed_data
// = new HashMap<String, INDArray>() {{
// put("x", npdata);
//}};
//List<String> fetch = Arrays.asList("price");
//Map<String, INDArray> fetch_map = client.predict(feed_data, fetch);
//for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
// System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
//}
long[] data = {8, 233, 52, 601};
INDArray npdata = Nd4j.createFromArray(data); INDArray npdata = Nd4j.createFromArray(data);
//System.out.println(npdata);
HashMap<String, INDArray> feed_data HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{ = new HashMap<String, INDArray>() {{
put("words", npdata); put("x", npdata);
}}; }};
List<String> fetch = Arrays.asList("prediction"); List<String> fetch = Arrays.asList("price");
Client client = new Client(); Client client = new Client();
List<String> endpoints = Arrays.asList("localhost:9393"); List<String> endpoints = Arrays.asList("localhost:9393");
boolean succ = client.connect(endpoints); boolean succ = client.connect(endpoints);
...@@ -443,59 +426,9 @@ public class Client { ...@@ -443,59 +426,9 @@ public class Client {
return; return;
} }
Map<String, HashMap<String, INDArray>> fetch_map Map<String, INDArray> fetch_map = client.predict(feed_data, fetch);
= client.ensemble_predict(feed_data, fetch); for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
for (Map.Entry<String, HashMap<String, INDArray>> entry : fetch_map.entrySet()) { System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
System.out.println("Model = " + entry.getKey());
HashMap<String, INDArray> tt = entry.getValue();
for (Map.Entry<String, INDArray> e : tt.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
}
}
}
class PredictFuture {
private ListenableFuture<InferenceResponse> callFuture_;
private Function<InferenceResponse,
Map<String, HashMap<String, INDArray>>> callBackFunc_;
PredictFuture(ListenableFuture<InferenceResponse> call_future,
Function<InferenceResponse,
Map<String, HashMap<String, INDArray>>> call_back_func) {
callFuture_ = call_future;
callBackFunc_ = call_back_func;
}
public Map<String, INDArray> get() throws Exception {
InferenceResponse resp = null;
try {
resp = callFuture_.get();
} catch (Exception e) {
System.out.format("grpc failed: %s\n", e.toString());
return null;
}
Map<String, HashMap<String, INDArray>> ensemble_result
= callBackFunc_.apply(resp);
List<Map.Entry<String, HashMap<String, INDArray>>> list
= new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>(
ensemble_result.entrySet());
if (list.size() != 1) {
System.out.format("grpc failed: please use get_ensemble impl.\n");
return null;
}
return list.get(0).getValue();
}
public Map<String, HashMap<String, INDArray>> ensemble_get() throws Exception {
InferenceResponse resp = null;
try {
resp = callFuture_.get();
} catch (Exception e) {
System.out.format("grpc failed: %s\n", e.toString());
return null;
} }
return callBackFunc_.apply(resp);
} }
} }
package io.paddle.serving.client;
import java.util.*;
import java.util.function.Function;
import io.grpc.StatusRuntimeException;
import com.google.common.util.concurrent.ListenableFuture;
import org.nd4j.linalg.api.ndarray.INDArray;
import io.paddle.serving.client.Client;
import io.paddle.serving.grpc.*;
public class PredictFuture {
private ListenableFuture<InferenceResponse> callFuture_;
private Function<InferenceResponse,
Map<String, HashMap<String, INDArray>>> callBackFunc_;
PredictFuture(ListenableFuture<InferenceResponse> call_future,
Function<InferenceResponse,
Map<String, HashMap<String, INDArray>>> call_back_func) {
callFuture_ = call_future;
callBackFunc_ = call_back_func;
}
public Map<String, INDArray> get() {
InferenceResponse resp = null;
try {
resp = callFuture_.get();
} catch (Exception e) {
System.out.format("grpc failed: %s\n", e.toString());
return null;
}
Map<String, HashMap<String, INDArray>> ensemble_result
= callBackFunc_.apply(resp);
List<Map.Entry<String, HashMap<String, INDArray>>> list
= new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>(
ensemble_result.entrySet());
if (list.size() != 1) {
System.out.format("grpc failed: please use get_ensemble impl.\n");
return null;
}
return list.get(0).getValue();
}
public Map<String, HashMap<String, INDArray>> ensemble_get() {
InferenceResponse resp = null;
try {
resp = callFuture_.get();
} catch (Exception e) {
System.out.format("grpc failed: %s\n", e.toString());
return null;
}
return callBackFunc_.apply(resp);
}
}
<?xml version="1.0" encoding="UTF-8"?>
<Configuration status="INFO">
<Appenders>
<Console name="Console" target="SYSTEM_OUT">
<PatternLayout pattern="%highlight{%d{yyyy-MM-dd HH:mm:ss} %C %M %n%p: %m%n}{STYLE=Logback}"/>
</Console>
</Appenders>
<Loggers>
<Root level="INFO">
<AppenderRef ref="Console"/>
</Root>
</Loggers>
</Configuration>
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册