From 4ab7e4783db23b989242ad44fb21c1a5c7dc321e Mon Sep 17 00:00:00 2001 From: barrierye Date: Mon, 13 Jul 2020 15:56:37 +0800 Subject: [PATCH] update code --- .../java/io/paddle/serving/client/Client.java | 117 ++++++++++++------ 1 file changed, 82 insertions(+), 35 deletions(-) diff --git a/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java b/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java index cef718d9..61d1d7f3 100644 --- a/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java +++ b/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java @@ -12,6 +12,7 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.factory.Nd4j; import io.paddle.serving.grpc.*; @@ -65,7 +66,8 @@ public class Client { } public Boolean connect(List endpoints) { - String target = "ipv4:" + String.join(",", endpoints); + // String target = "ipv4:" + String.join(",", endpoints); + String target = endpoints.get(0); // TODO: max_receive_message_length and max_send_message_length try { channel_ = ManagedChannelBuilder.forTarget(target) @@ -80,13 +82,17 @@ public class Client { } GetClientConfigRequest get_client_config_req = GetClientConfigRequest.newBuilder().build(); GetClientConfigResponse resp; + System.out.println("try to call getClientConfig"); try { resp = blockingStub_.getClientConfig(get_client_config_req); } catch (StatusRuntimeException e) { System.out.format("Get Client config failed: %s", e.toString()); return false; } + System.out.println("succ call get client"); + System.out.println(resp); String model_config_str = resp.getClientConfigStr(); + System.out.println("model_config_str: " + model_config_str); _parseModelConfig(model_config_str); return true; } @@ -144,7 +150,7 @@ public class Client { } private InferenceRequest _packInferenceRequest( - List> feed_batch, + List> feed_batch, Iterable fetch) throws IllegalArgumentException { List feed_var_names = new ArrayList(); feed_var_names.addAll(feed_batch.get(0).keySet()); @@ -153,25 +159,40 @@ public class Client { .addAllFeedVarNames(feed_var_names) .addAllFetchVarNames(fetch) .setIsPython(false); - for (Map feed_data: feed_batch) { + for (HashMap feed_data: feed_batch) { FeedInst.Builder inst_builder = FeedInst.newBuilder(); for (String name: feed_var_names) { Tensor.Builder tensor_builder = Tensor.newBuilder(); INDArray variable = feed_data.get(name); - INDArray flattened_list = variable.reshape({-1}); + long[] flattened_shape = {-1}; + INDArray flattened_list = variable.reshape(flattened_shape); + for (Map.Entry entry : feedTypes_.entrySet()) { + System.out.println("Key = " + entry.getKey() + ", Value = " + entry.getValue()); + } int v_type = feedTypes_.get(name); + NdIndexIterator iter = new NdIndexIterator(flattened_list.shape()); if (v_type == 0) { // int64 - for (long x: flattened_list) { + while (iter.hasNext()) { + long[] next_index = iter.next(); + long x = flattened_list.getLong(next_index); tensor_builder.addInt64Data(x); } } else if (v_type == 1) { // float32 - for (float x: flattened_list) { + while (iter.hasNext()) { + long[] next_index = iter.next(); + float x = flattened_list.getFloat(next_index); tensor_builder.addFloatData(x); } } else if (v_type == 2) { // int32 - for (int x: flattened_list) { + throw new IllegalArgumentException("error tensor value type."); + //TODO + /* + while (iter.hasNext()) { + long[] next_index = iter.next(); + int x = flattened_list.getInt(next_index); + // TODO: long to int? tensor_builder.addIntData(x); - } + }*/ } else { throw new IllegalArgumentException("error tensor value type."); } @@ -183,7 +204,7 @@ public class Client { return req_builder.build(); } - private Map> + private HashMap> _unpackInferenceResponse( InferenceResponse resp, Iterable fetch, @@ -192,7 +213,7 @@ public class Client { resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag); } - private static Map> + private static HashMap> _staticUnpackInferenceResponse( InferenceResponse resp, Iterable fetch, @@ -203,29 +224,49 @@ public class Client { return null; } String tag = resp.getTag(); - Map> multi_result_map + HashMap> multi_result_map = new HashMap>(); for (ModelOutput model_result: resp.getOutputsList()) { FetchInst inst = model_result.getInsts(0); - Map result_map + HashMap result_map = new HashMap(); int index = 0; for (String name: fetch) { Tensor variable = inst.getTensorArray(index); int v_type = fetchTypes.get(name); if (v_type == 0) { // int64 - result_map.put(name, variable.getInt64DataList()); + List list = variable.getInt64DataList(); + long[] array = new long[list.size()]; + for (int i = 0; i < list.size(); i++) { + array[i] = list.get(i); + } + result_map.put(name, Nd4j.create(array)); } else if (v_type == 1) { // float32 - result_map.put(name, variable.getFloatDataList()); + List list = variable.getFloatDataList(); + float[] array = new float[list.size()]; + for (int i = 0; i < list.size(); i++) { + array[i] = list.get(i); + } + result_map.put(name, Nd4j.create(array)); } else if (v_type == 2) { // int32 - result_map.put(name, variable.getIntDataList()); + List list = variable.getIntDataList(); + int[] array = new int[list.size()]; + for (int i = 0; i < list.size(); i++) { + array[i] = list.get(i); + } + result_map.put(name, Nd4j.create(array)); } else { throw new IllegalArgumentException("error tensor value type."); } // TODO: shape if (lodTensorSet.contains(name)) { - result_map.put(name + ".lod", variable.getLodList()); + List list = variable.getLodList(); + int[] array = new int[list.size()]; + for (int i = 0; i < list.size(); i++) { + array[i] = list.get(i); + } + result_map.put(name + ".lod", Nd4j.create(array)); } index += 1; } @@ -235,8 +276,8 @@ public class Client { return multi_result_map; } - public Map> predict( - Map feed, + public Map> predict( + HashMap feed, Iterable fetch) { return predict(feed, fetch, false); } @@ -247,12 +288,12 @@ public class Client { return async_predict(feed, fetch, false); } */ - public Map> predict( - Map feed, + public Map> predict( + HashMap feed, Iterable fetch, Boolean need_variant_tag) { - List> feed_batch - = new ArrayList>(); + List> feed_batch + = new ArrayList>(); feed_batch.add(feed); return predict(feed_batch, fetch, need_variant_tag); } @@ -267,8 +308,8 @@ public class Client { return async_predict(feed_batch, fetch, need_variant_tag); } */ - public Map> predict( - List> feed_batch, + public Map> predict( + List> feed_batch, Iterable fetch) { return predict(feed_batch, fetch, false); } @@ -279,8 +320,8 @@ public class Client { return async_predict(feed_batch, fetch, false); } */ - public Map> predict( - List> feed_batch, + public Map> predict( + List> feed_batch, Iterable fetch, Boolean need_variant_tag) { InferenceRequest req = _packInferenceRequest(feed_batch, fetch); @@ -317,16 +358,22 @@ public class Client { System.out.println(npdata); Client client = new Client(); - List endpoints = new ArrayList() - .add("182.61.111.54:9393"); - Client.connect(endpoints); + List endpoints = Arrays.asList("localhost:9393"); + boolean succ = client.connect(endpoints); + if (succ != true) { + System.out.println("connect failed."); + return; + } + HashMap feed_data + = new HashMap() {{ + put("x", npdata); + }}; List> feed_batch - = new ArrayList>() - .add(new HashMap() - .put("x", npdata)); - List fetch = new ArrayList() - .add("price"); - Map> fetch_map + = new ArrayList>() {{ + add(feed_data); + }}; + List fetch = Arrays.asList("price"); + Map> fetch_map = client.predict(feed_batch, fetch); System.out.println( "Hello World!" ); } -- GitLab