From 27e45738b1574dd9258abd40dd11cbe72882a725 Mon Sep 17 00:00:00 2001 From: barrierye Date: Mon, 13 Jul 2020 18:58:48 +0800 Subject: [PATCH] succ run --- .../multi_lang_general_model_service.proto | 4 + .../java/io/paddle/serving/client/Client.java | 178 ++++++++++++------ .../multi_lang_general_model_service.proto | 2 - 3 files changed, 128 insertions(+), 56 deletions(-) diff --git a/core/configure/proto/multi_lang_general_model_service.proto b/core/configure/proto/multi_lang_general_model_service.proto index 2a8a8bc1..b83450ae 100644 --- a/core/configure/proto/multi_lang_general_model_service.proto +++ b/core/configure/proto/multi_lang_general_model_service.proto @@ -14,6 +14,10 @@ syntax = "proto2"; +option java_multiple_files = true; +option java_package = "io.paddle.serving.grpc"; +option java_outer_classname = "ServingProto"; + message Tensor { optional bytes data = 1; repeated int32 int_data = 2; diff --git a/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java b/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java index 61d1d7f3..958580b5 100644 --- a/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java +++ b/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java @@ -18,7 +18,6 @@ import org.nd4j.linalg.factory.Nd4j; import io.paddle.serving.grpc.*; import io.paddle.serving.configure.*; - public class Client { private ManagedChannel channel_; private MultiLangGeneralModelServiceGrpc.MultiLangGeneralModelServiceBlockingStub blockingStub_; @@ -66,12 +65,13 @@ public class Client { } public Boolean connect(List endpoints) { - // String target = "ipv4:" + String.join(",", endpoints); + // TODO + //String target = "ipv4:" + String.join(",", endpoints); String target = endpoints.get(0); - // TODO: max_receive_message_length and max_send_message_length try { channel_ = ManagedChannelBuilder.forTarget(target) .defaultLoadBalancingPolicy("round_robin") + .maxInboundMessageSize(Integer.MAX_VALUE) .usePlaintext() .build(); blockingStub_ = MultiLangGeneralModelServiceGrpc.newBlockingStub(channel_); @@ -82,17 +82,13 @@ public class Client { } GetClientConfigRequest get_client_config_req = GetClientConfigRequest.newBuilder().build(); GetClientConfigResponse resp; - System.out.println("try to call getClientConfig"); try { resp = blockingStub_.getClientConfig(get_client_config_req); } catch (StatusRuntimeException e) { System.out.format("Get Client config failed: %s", e.toString()); return false; } - System.out.println("succ call get client"); - System.out.println(resp); String model_config_str = resp.getClientConfigStr(); - System.out.println("model_config_str: " + model_config_str); _parseModelConfig(model_config_str); return true; } @@ -112,6 +108,7 @@ public class Client { feedShapes_ = new HashMap>(); fetchTypes_ = new HashMap(); lodTensorSet_ = new HashSet(); + feedTensorLen_ = new HashMap(); List feed_var_list = model_conf.getFeedVarList(); for (FeedVar feed_var : feed_var_list) { @@ -166,9 +163,6 @@ public class Client { INDArray variable = feed_data.get(name); long[] flattened_shape = {-1}; INDArray flattened_list = variable.reshape(flattened_shape); - for (Map.Entry entry : feedTypes_.entrySet()) { - System.out.println("Key = " + entry.getKey() + ", Value = " + entry.getValue()); - } int v_type = feedTypes_.get(name); NdIndexIterator iter = new NdIndexIterator(flattened_list.shape()); if (v_type == 0) { // int64 @@ -184,15 +178,17 @@ public class Client { tensor_builder.addFloatData(x); } } else if (v_type == 2) { // int32 - throw new IllegalArgumentException("error tensor value type."); - //TODO - /* while (iter.hasNext()) { long[] next_index = iter.next(); - int x = flattened_list.getInt(next_index); - // TODO: long to int? + // the interface of INDArray is strange: + // https://deeplearning4j.org/api/latest/org/nd4j/linalg/api/ndarray/INDArray.html + int[] int_next_index = new int[next_index.length]; + for(int i = 0; i < next_index.length; i++) { + int_next_index[i] = (int)next_index[i]; + } + int x = flattened_list.getInt(int_next_index); tensor_builder.addIntData(x); - }*/ + } } else { throw new IllegalArgumentException("error tensor value type."); } @@ -204,7 +200,7 @@ public class Client { return req_builder.build(); } - private HashMap> + private Map> _unpackInferenceResponse( InferenceResponse resp, Iterable fetch, @@ -213,7 +209,7 @@ public class Client { resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag); } - private static HashMap> + private static Map> _staticUnpackInferenceResponse( InferenceResponse resp, Iterable fetch, @@ -227,6 +223,7 @@ public class Client { HashMap> multi_result_map = new HashMap>(); for (ModelOutput model_result: resp.getOutputsList()) { + String engine_name = model_result.getEngineName(); FetchInst inst = model_result.getInsts(0); HashMap result_map = new HashMap(); @@ -270,25 +267,32 @@ public class Client { } index += 1; } + multi_result_map.put(engine_name, result_map); } // TODO: tag return multi_result_map; } - public Map> predict( + public Map predict( HashMap feed, Iterable fetch) { return predict(feed, fetch, false); } -/* + + public Map> ensemble_predict( + HashMap feed, + Iterable fetch) { + return ensemble_predict(feed, fetch, false); + } + public PredictFuture async_predict( - Map feed, + HashMap feed, Iterable fetch) { return async_predict(feed, fetch, false); } -*/ - public Map> predict( + + public Map predict( HashMap feed, Iterable fetch, Boolean need_variant_tag) { @@ -297,30 +301,69 @@ public class Client { feed_batch.add(feed); return predict(feed_batch, fetch, need_variant_tag); } -/* + + public Map> ensemble_predict( + HashMap feed, + Iterable fetch, + Boolean need_variant_tag) { + List> feed_batch + = new ArrayList>(); + feed_batch.add(feed); + return ensemble_predict(feed_batch, fetch, need_variant_tag); + } + public PredictFuture async_predict( - Map> feed, + HashMap feed, Iterable fetch, Boolean need_variant_tag) { - List>> feed_batch - = new ArrayList>>(); + List> feed_batch + = new ArrayList>(); feed_batch.add(feed); return async_predict(feed_batch, fetch, need_variant_tag); } -*/ - public Map> predict( + + public Map predict( List> feed_batch, Iterable fetch) { return predict(feed_batch, fetch, false); } - /* + + public Map> ensemble_predict( + List> feed_batch, + Iterable fetch) { + return ensemble_predict(feed_batch, fetch, false); + } + public PredictFuture async_predict( - List>> feed_batch, + List> feed_batch, Iterable fetch) { return async_predict(feed_batch, fetch, false); } -*/ - public Map> predict( + + public Map predict( + List> feed_batch, + Iterable fetch, + Boolean need_variant_tag) { + InferenceRequest req = _packInferenceRequest(feed_batch, fetch); + try { + InferenceResponse resp = blockingStub_.inference(req); + Map> ensemble_result + = _unpackInferenceResponse(resp, fetch, need_variant_tag); + List>> list + = new ArrayList>>( + ensemble_result.entrySet()); + if (list.size() != 1) { + System.out.format("grpc failed: please use ensemble_predict impl."); + return null; + } + return list.get(0).getValue(); + } catch (StatusRuntimeException e) { + System.out.format("grpc failed: %s", e.toString()); + return null; + } + } + + public Map> ensemble_predict( List> feed_batch, Iterable fetch, Boolean need_variant_tag) { @@ -331,31 +374,30 @@ public class Client { resp, fetch, need_variant_tag); } catch (StatusRuntimeException e) { System.out.format("grpc failed: %s", e.toString()); - return Collections.emptyMap(); + return null; } } - /* + public PredictFuture async_predict( - List>> feed_batch, + List> feed_batch, Iterable fetch, Boolean need_variant_tag) { InferenceRequest req = _packInferenceRequest(feed_batch, fetch); ListenableFuture future = futureStub_.inference(req); - return new PredictFuture( - future, - (InferenceResponse resp) -> { - return Client._staticUnpackInferenceResponse( - resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag); - }); + PredictFuture predict_future = new PredictFuture(future, + (InferenceResponse resp) -> { + return Client._staticUnpackInferenceResponse( + resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag); + } + ); + return predict_future; } - */ public static void main( String[] args ) { float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, 0.0582f, -0.0727f, -0.1583f, -0.0584f, 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; INDArray npdata = Nd4j.create(data); - System.out.println(npdata); Client client = new Client(); List endpoints = Arrays.asList("localhost:9393"); @@ -368,30 +410,58 @@ public class Client { = new HashMap() {{ put("x", npdata); }}; - List> feed_batch - = new ArrayList>() {{ - add(feed_data); - }}; List fetch = Arrays.asList("price"); + Map fetch_map = client.predict(feed_data, fetch); + for (Map.Entry e : fetch_map.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); + } + /* Map> fetch_map - = client.predict(feed_batch, fetch); - System.out.println( "Hello World!" ); + = client.predict(feed_data, fetch); + for (Map.Entry> entry : fetch_map.entrySet()) { + System.out.println("Model = " + entry.getKey()); + HashMap tt = entry.getValue(); + for (Map.Entry e : tt.entrySet()) { + System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); + } + } + */ } } class PredictFuture { private ListenableFuture callFuture_; private Function>>> callBackFunc_; + Map>> callBackFunc_; PredictFuture(ListenableFuture call_future, Function>>> call_back_func) { + Map>> call_back_func) { callFuture_ = call_future; callBackFunc_ = call_back_func; } - public Map>> get() throws Exception { + public Map get() throws Exception { + InferenceResponse resp = null; + try { + resp = callFuture_.get(); + } catch (Exception e) { + System.out.format("grpc failed: %s", e.toString()); + return null; + } + Map> ensemble_result + = callBackFunc_.apply(resp); + List>> list + = new ArrayList>>( + ensemble_result.entrySet()); + if (list.size() != 1) { + System.out.format("grpc failed: please use get_ensemble impl."); + return null; + } + return list.get(0).getValue(); + } + + public Map> get_ensemble() throws Exception { InferenceResponse resp = null; try { resp = callFuture_.get(); @@ -401,4 +471,4 @@ class PredictFuture { } return callBackFunc_.apply(resp); } -} +} diff --git a/java/paddle-serving-sdk-java/src/main/proto/multi_lang_general_model_service.proto b/java/paddle-serving-sdk-java/src/main/proto/multi_lang_general_model_service.proto index efdaf4fb..89902d1d 100644 --- a/java/paddle-serving-sdk-java/src/main/proto/multi_lang_general_model_service.proto +++ b/java/paddle-serving-sdk-java/src/main/proto/multi_lang_general_model_service.proto @@ -18,8 +18,6 @@ option java_multiple_files = true; option java_package = "io.paddle.serving.grpc"; option java_outer_classname = "ServingProto"; -package paddle.serving.grpc; - message Tensor { optional bytes data = 1; repeated int32 int_data = 2; -- GitLab