From d980d775f7957de2280242580ba15553888f1bee Mon Sep 17 00:00:00 2001 From: barrierye Date: Sun, 28 Jun 2020 22:08:08 +0800 Subject: [PATCH] [WIP] use INDArray --- java/paddle-serving-sdk-java/pom.xml | 7 ++ .../java/io/paddle/serving/client/Client.java | 100 +++++++++--------- 2 files changed, 56 insertions(+), 51 deletions(-) diff --git a/java/paddle-serving-sdk-java/pom.xml b/java/paddle-serving-sdk-java/pom.xml index 693c1bfa..821f395f 100644 --- a/java/paddle-serving-sdk-java/pom.xml +++ b/java/paddle-serving-sdk-java/pom.xml @@ -56,6 +56,8 @@ 1.27.2 3.11.0 3.11.0 + nd4j-native + 1.0.0-beta7 1.8 1.8 @@ -144,6 +146,11 @@ log4j-slf4j-impl 2.12.1 + + org.nd4j + ${nd4j.backend} + ${nd4j.version} + diff --git a/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java b/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java index 083eb56e..cef718d9 100644 --- a/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java +++ b/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java @@ -11,6 +11,9 @@ import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + import io.paddle.serving.grpc.*; import io.paddle.serving.configure.*; @@ -140,13 +143,8 @@ public class Client { } } - private List _flattenList(List x) { - // TODO - return x; - } - private InferenceRequest _packInferenceRequest( - List>> feed_batch, + List> feed_batch, Iterable fetch) throws IllegalArgumentException { List feed_var_names = new ArrayList(); feed_var_names.addAll(feed_batch.get(0).keySet()); @@ -155,24 +153,24 @@ public class Client { .addAllFeedVarNames(feed_var_names) .addAllFetchVarNames(fetch) .setIsPython(false); - for (Map> feed_data: feed_batch) { + for (Map feed_data: feed_batch) { FeedInst.Builder inst_builder = FeedInst.newBuilder(); for (String name: feed_var_names) { Tensor.Builder tensor_builder = Tensor.newBuilder(); - List variable = feed_data.get(name); - List flattened_list = _flattenList(variable); + INDArray variable = feed_data.get(name); + INDArray flattened_list = variable.reshape({-1}); int v_type = feedTypes_.get(name); if (v_type == 0) { // int64 - for (Number x: flattened_list) { - tensor_builder.addInt64Data((long)x); + for (long x: flattened_list) { + tensor_builder.addInt64Data(x); } } else if (v_type == 1) { // float32 - for (Number x: flattened_list) { - tensor_builder.addFloatData((float)x); + for (float x: flattened_list) { + tensor_builder.addFloatData(x); } } else if (v_type == 2) { // int32 - for (Number x: flattened_list) { - tensor_builder.addIntData((int)x); + for (int x: flattened_list) { + tensor_builder.addIntData(x); } } else { throw new IllegalArgumentException("error tensor value type."); @@ -185,7 +183,7 @@ public class Client { return req_builder.build(); } - private Map>> + private Map> _unpackInferenceResponse( InferenceResponse resp, Iterable fetch, @@ -194,7 +192,7 @@ public class Client { resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag); } - private static Map>> + private static Map> _staticUnpackInferenceResponse( InferenceResponse resp, Iterable fetch, @@ -205,12 +203,12 @@ public class Client { return null; } String tag = resp.getTag(); - Map>> multi_result_map - = new HashMap>>(); + Map> multi_result_map + = new HashMap>(); for (ModelOutput model_result: resp.getOutputsList()) { FetchInst inst = model_result.getInsts(0); - Map> result_map - = new HashMap>(); + Map result_map + = new HashMap(); int index = 0; for (String name: fetch) { Tensor variable = inst.getTensorArray(index); @@ -237,29 +235,28 @@ public class Client { return multi_result_map; } - /* - public Map>> predict( - Map> feed, + public Map> predict( + Map feed, Iterable fetch) { return predict(feed, fetch, false); } - +/* public PredictFuture async_predict( - Map> feed, + Map feed, Iterable fetch) { return async_predict(feed, fetch, false); } - - public Map>> predict( - Map> feed, +*/ + public Map> predict( + Map feed, Iterable fetch, Boolean need_variant_tag) { - List>> feed_batch - = new ArrayList>>(); + List> feed_batch + = new ArrayList>(); feed_batch.add(feed); return predict(feed_batch, fetch, need_variant_tag); } - +/* public PredictFuture async_predict( Map> feed, Iterable fetch, @@ -269,22 +266,21 @@ public class Client { feed_batch.add(feed); return async_predict(feed_batch, fetch, need_variant_tag); } - */ - - public Map>> predict( - List>> feed_batch, +*/ + public Map> predict( + List> feed_batch, Iterable fetch) { return predict(feed_batch, fetch, false); } - + /* public PredictFuture async_predict( List>> feed_batch, Iterable fetch) { return async_predict(feed_batch, fetch, false); } - - public Map>> predict( - List>> feed_batch, +*/ + public Map> predict( + List> feed_batch, Iterable fetch, Boolean need_variant_tag) { InferenceRequest req = _packInferenceRequest(feed_batch, fetch); @@ -297,7 +293,7 @@ public class Client { return Collections.emptyMap(); } } - + /* public PredictFuture async_predict( List>> feed_batch, Iterable fetch, @@ -311,24 +307,26 @@ public class Client { resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag); }); } + */ public static void main( String[] args ) { + float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, + 0.0582f, -0.0727f, -0.1583f, -0.0584f, + 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; + INDArray npdata = Nd4j.create(data); + System.out.println(npdata); + Client client = new Client(); List endpoints = new ArrayList() .add("182.61.111.54:9393"); Client.connect(endpoints); - List>> feed_batch - = new ArrayList>>() - .add(new HashMap>() - .put("x", new ArrayList() - .add(0.0137f).add(-0.1136f).add(0.2553f) - .add(-0.0692f).add(0.0582f).add(-0.0727f) - .add(-0.1583f).add(-0.0584f).add(0.6283f) - .add(0.4919f).add(0.1856f).add(0.0795f) - .add(-0.0332f))); + List> feed_batch + = new ArrayList>() + .add(new HashMap() + .put("x", npdata)); List fetch = new ArrayList() .add("price"); - Map>> fetch_map + Map> fetch_map = client.predict(feed_batch, fetch); System.out.println( "Hello World!" ); } -- GitLab