From da50ea4bb43ab7ecf19ec6246e2d33c297fc5aa5 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Mon, 8 Feb 2021 17:51:50 +0800 Subject: [PATCH] fix java Client indarray shape bug by HexToString --- .../src/main/java/PaddleServingClientExample.java | 14 ++++++++++---- .../main/java/io/paddle/serving/client/Client.java | 9 ++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/java/examples/src/main/java/PaddleServingClientExample.java b/java/examples/src/main/java/PaddleServingClientExample.java index cdc11df1..5f5e3ff6 100644 --- a/java/examples/src/main/java/PaddleServingClientExample.java +++ b/java/examples/src/main/java/PaddleServingClientExample.java @@ -16,9 +16,11 @@ public class PaddleServingClientExample { 0.0582f, -0.0727f, -0.1583f, -0.0584f, 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; INDArray npdata = Nd4j.createFromArray(data); + long[] batch_shape = {1,13}; + INDArray batch_npdata = npdata.reshape(batch_shape); HashMap feed_data = new HashMap() {{ - put("x", npdata); + put("x", batch_npdata); }}; List fetch = Arrays.asList("price"); @@ -69,12 +71,16 @@ public class PaddleServingClientExample { // Div(255.0) INDArray image = RGBimage.divi(255.0); - + long[] batch_shape = {1,image.shape()[0],image.shape()[1],image.shape()[2]}; + INDArray batch_image = image.reshape(batch_shape); + INDArray im_size = Nd4j.createFromArray(new int[]{height, width}); + long[] batch_size_shape = {1,2}; + INDArray batch_im_size = im_size.reshape(batch_size_shape); HashMap feed_data = new HashMap() {{ - put("image", image); - put("im_size", im_size); + put("image", batch_image); + put("im_size", batch_im_size); }}; List fetch = Arrays.asList("save_infer_model/scale_0.tmp_0"); diff --git a/java/src/main/java/io/paddle/serving/client/Client.java b/java/src/main/java/io/paddle/serving/client/Client.java index 742d4f91..aae7e6f8 100644 --- a/java/src/main/java/io/paddle/serving/client/Client.java +++ b/java/src/main/java/io/paddle/serving/client/Client.java @@ -4,6 +4,9 @@ import java.util.*; import java.util.function.Function; import java.lang.management.ManagementFactory; import java.lang.management.RuntimeMXBean; +import java.util.stream.Collectors; +import java.util.List; +import java.util.ArrayList; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; @@ -238,7 +241,11 @@ public class Client { } else { throw new IllegalArgumentException("error tensor value type."); } - tensor_builder.addAllShape(feedShapes_.get(name)); + long[] longArray = variable.shape(); + int[] intArray = Arrays.stream(longArray).mapToInt(i -> (int) i).toArray(); + List indarrayShapeList = Arrays.stream(intArray).boxed().collect(Collectors.toList()); + //tensor_builder.addAllShape(feedShapes_.get(name)); + tensor_builder.addAllShape(indarrayShapeList); inst_builder.addTensorArray(tensor_builder.build()); } req_builder.addInsts(inst_builder.build()); -- GitLab