diff --git a/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java b/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java index ac1218f747b6acc742eafedb544524cb8b38a414..083eb56edc5146d82dcc3733e1760f5598644f49 100644 --- a/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java +++ b/java/paddle-serving-sdk-java/src/main/java/io/paddle/serving/client/Client.java @@ -61,7 +61,7 @@ public class Client { return resp.getErrCode() == 0; } - public Boolean connect(String[] endpoints) { + public Boolean connect(List endpoints) { String target = "ipv4:" + String.join(",", endpoints); // TODO: max_receive_message_length and max_send_message_length try { @@ -233,14 +233,61 @@ public class Client { } } + // TODO: tag return multi_result_map; } + /* public Map>> predict( - List>> feed, + Map> feed, + Iterable fetch) { + return predict(feed, fetch, false); + } + + public PredictFuture async_predict( + Map> feed, + Iterable fetch) { + return async_predict(feed, fetch, false); + } + + public Map>> predict( + Map> feed, Iterable fetch, Boolean need_variant_tag) { - InferenceRequest req = _packInferenceRequest(feed, fetch); + List>> feed_batch + = new ArrayList>>(); + feed_batch.add(feed); + return predict(feed_batch, fetch, need_variant_tag); + } + + public PredictFuture async_predict( + Map> feed, + Iterable fetch, + Boolean need_variant_tag) { + List>> feed_batch + = new ArrayList>>(); + feed_batch.add(feed); + return async_predict(feed_batch, fetch, need_variant_tag); + } + */ + + public Map>> predict( + List>> feed_batch, + Iterable fetch) { + return predict(feed_batch, fetch, false); + } + + public PredictFuture async_predict( + List>> feed_batch, + Iterable fetch) { + return async_predict(feed_batch, fetch, false); + } + + public Map>> predict( + List>> feed_batch, + Iterable fetch, + Boolean need_variant_tag) { + InferenceRequest req = _packInferenceRequest(feed_batch, fetch); try { InferenceResponse resp = blockingStub_.inference(req); return _unpackInferenceResponse( @@ -252,15 +299,11 @@ public class Client { } public PredictFuture async_predict( - List>> feed, + List>> feed_batch, Iterable fetch, Boolean need_variant_tag) { - InferenceRequest req = _packInferenceRequest(feed, fetch); + InferenceRequest req = _packInferenceRequest(feed_batch, fetch); ListenableFuture future = futureStub_.inference(req); - // Function>>> - // call_back_func = partial( - // Client::_unpackInferenceResponse, fetch, need_variant_tag); return new PredictFuture( future, (InferenceResponse resp) -> { @@ -270,6 +313,23 @@ public class Client { } public static void main( String[] args ) { + Client client = new Client(); + List endpoints = new ArrayList() + .add("182.61.111.54:9393"); + Client.connect(endpoints); + List>> feed_batch + = new ArrayList>>() + .add(new HashMap>() + .put("x", new ArrayList() + .add(0.0137f).add(-0.1136f).add(0.2553f) + .add(-0.0692f).add(0.0582f).add(-0.0727f) + .add(-0.1583f).add(-0.0584f).add(0.6283f) + .add(0.4919f).add(0.1856f).add(0.0795f) + .add(-0.0332f))); + List fetch = new ArrayList() + .add("price"); + Map>> fetch_map + = client.predict(feed_batch, fetch); System.out.println( "Hello World!" ); } }