From 64553f837531cda47ddc93ac2c742c7c1e976972 Mon Sep 17 00:00:00 2001 From: barrierye Date: Mon, 13 Jul 2020 23:05:32 +0800 Subject: [PATCH] support model ensemble --- .../java/io/paddle/serving/client/Client.java | 83 ++++++++++++------- 1 file changed, 55 insertions(+), 28 deletions(-) diff --git a/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java b/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java index 958580b5..15f9e000 100644 --- a/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java +++ b/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java @@ -58,7 +58,7 @@ public class Client { try { resp = blockingStub_.setTimeout(timeout_req); } catch (StatusRuntimeException e) { - System.out.format("Set RPC timeout failed: %s", e.toString()); + System.out.format("Set RPC timeout failed: %s\n", e.toString()); return false; } return resp.getErrCode() == 0; @@ -77,7 +77,7 @@ public class Client { blockingStub_ = MultiLangGeneralModelServiceGrpc.newBlockingStub(channel_); futureStub_ = MultiLangGeneralModelServiceGrpc.newFutureStub(channel_); } catch (Exception e) { - System.out.format("Connect failed: %s", e.toString()); + System.out.format("Connect failed: %s\n", e.toString()); return false; } GetClientConfigRequest get_client_config_req = GetClientConfigRequest.newBuilder().build(); @@ -85,7 +85,7 @@ public class Client { try { resp = blockingStub_.getClientConfig(get_client_config_req); } catch (StatusRuntimeException e) { - System.out.format("Get Client config failed: %s", e.toString()); + System.out.format("Get Client config failed: %s\n", e.toString()); return false; } String model_config_str = resp.getClientConfigStr(); @@ -98,7 +98,7 @@ public class Client { try { com.google.protobuf.TextFormat.getParser().merge(model_config_str, model_conf_builder); } catch (com.google.protobuf.TextFormat.ParseException e) { - System.out.format("Parse client config failed: %s", e.toString()); + System.out.format("Parse client config failed: %s\n", e.toString()); } GeneralModelConfig model_conf = model_conf_builder.build(); @@ -164,7 +164,9 @@ public class Client { long[] flattened_shape = {-1}; INDArray flattened_list = variable.reshape(flattened_shape); int v_type = feedTypes_.get(name); + System.out.println(flattened_list); NdIndexIterator iter = new NdIndexIterator(flattened_list.shape()); + //System.out.format("name: %s, type: %d\n", name, v_type); if (v_type == 0) { // int64 while (iter.hasNext()) { long[] next_index = iter.next(); @@ -231,39 +233,50 @@ public class Client { for (String name: fetch) { Tensor variable = inst.getTensorArray(index); int v_type = fetchTypes.get(name); + INDArray data = null; if (v_type == 0) { // int64 List list = variable.getInt64DataList(); long[] array = new long[list.size()]; for (int i = 0; i < list.size(); i++) { array[i] = list.get(i); } - result_map.put(name, Nd4j.create(array)); + data = Nd4j.createFromArray(array); } else if (v_type == 1) { // float32 List list = variable.getFloatDataList(); float[] array = new float[list.size()]; for (int i = 0; i < list.size(); i++) { array[i] = list.get(i); } - result_map.put(name, Nd4j.create(array)); + data = Nd4j.createFromArray(array); } else if (v_type == 2) { // int32 List list = variable.getIntDataList(); int[] array = new int[list.size()]; for (int i = 0; i < list.size(); i++) { array[i] = list.get(i); } - result_map.put(name, Nd4j.create(array)); + data = Nd4j.createFromArray(array); } else { throw new IllegalArgumentException("error tensor value type."); } - // TODO: shape + // shape + List shape_lsit = variable.getShapeList(); + int[] shape_array = new int[shape_lsit.size()]; + for (int i = 0; i < shape_lsit.size(); ++i) { + shape_array[i] = shape_lsit.get(i); + } + data = data.reshape(shape_array); + + // put data to result_map + result_map.put(name, data); + // lod if (lodTensorSet.contains(name)) { List list = variable.getLodList(); int[] array = new int[list.size()]; for (int i = 0; i < list.size(); i++) { array[i] = list.get(i); } - result_map.put(name + ".lod", Nd4j.create(array)); + result_map.put(name + ".lod", Nd4j.createFromArray(array)); } index += 1; } @@ -353,12 +366,12 @@ public class Client { = new ArrayList>>( ensemble_result.entrySet()); if (list.size() != 1) { - System.out.format("grpc failed: please use ensemble_predict impl."); + System.out.format("grpc failed: please use ensemble_predict impl.\n"); return null; } return list.get(0).getValue(); } catch (StatusRuntimeException e) { - System.out.format("grpc failed: %s", e.toString()); + System.out.format("grpc failed: %s\n", e.toString()); return null; } } @@ -373,7 +386,7 @@ public class Client { return _unpackInferenceResponse( resp, fetch, need_variant_tag); } catch (StatusRuntimeException e) { - System.out.format("grpc failed: %s", e.toString()); + System.out.format("grpc failed: %s\n", e.toString()); return null; } } @@ -394,30 +407,44 @@ public class Client { } public static void main( String[] args ) { + /* float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, 0.0582f, -0.0727f, -0.1583f, -0.0584f, 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; - INDArray npdata = Nd4j.create(data); - - Client client = new Client(); - List endpoints = Arrays.asList("localhost:9393"); - boolean succ = client.connect(endpoints); - if (succ != true) { - System.out.println("connect failed."); - return; - } + INDArray npdata = Nd4j.createFromArray(data); HashMap feed_data = new HashMap() {{ put("x", npdata); }}; List fetch = Arrays.asList("price"); + */ + /* Map fetch_map = client.predict(feed_data, fetch); for (Map.Entry e : fetch_map.entrySet()) { System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); } - /* + */ + + //int[] data = {8, 233, 52, 601}; + long[] data = {8, 233, 52, 601}; + INDArray npdata = Nd4j.createFromArray(data); + //System.out.println(npdata); + HashMap feed_data + = new HashMap() {{ + put("words", npdata); + }}; + List fetch = Arrays.asList("prediction"); + + Client client = new Client(); + List endpoints = Arrays.asList("localhost:9393"); + boolean succ = client.connect(endpoints); + if (succ != true) { + System.out.println("connect failed."); + return; + } + Map> fetch_map - = client.predict(feed_data, fetch); + = client.ensemble_predict(feed_data, fetch); for (Map.Entry> entry : fetch_map.entrySet()) { System.out.println("Model = " + entry.getKey()); HashMap tt = entry.getValue(); @@ -425,7 +452,7 @@ public class Client { System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); } } - */ + } } @@ -446,7 +473,7 @@ class PredictFuture { try { resp = callFuture_.get(); } catch (Exception e) { - System.out.format("grpc failed: %s", e.toString()); + System.out.format("grpc failed: %s\n", e.toString()); return null; } Map> ensemble_result @@ -455,18 +482,18 @@ class PredictFuture { = new ArrayList>>( ensemble_result.entrySet()); if (list.size() != 1) { - System.out.format("grpc failed: please use get_ensemble impl."); + System.out.format("grpc failed: please use get_ensemble impl.\n"); return null; } return list.get(0).getValue(); } - public Map> get_ensemble() throws Exception { + public Map> ensemble_get() throws Exception { InferenceResponse resp = null; try { resp = callFuture_.get(); } catch (Exception e) { - System.out.format("grpc failed: %s", e.toString()); + System.out.format("grpc failed: %s\n", e.toString()); return null; } return callBackFunc_.apply(resp); -- GitLab