diff --git a/java/examples/src/main/java/PaddleServingClientExample.java b/java/examples/src/main/java/PaddleServingClientExample.java index cdc11df130095d668734ae0a23adb12ef735b2ea..5f5e3ff655e7450d12f562229ae4cb2481ab4a54 100644 --- a/java/examples/src/main/java/PaddleServingClientExample.java +++ b/java/examples/src/main/java/PaddleServingClientExample.java @@ -16,9 +16,11 @@ public class PaddleServingClientExample { 0.0582f, -0.0727f, -0.1583f, -0.0584f, 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; INDArray npdata = Nd4j.createFromArray(data); + long[] batch_shape = {1,13}; + INDArray batch_npdata = npdata.reshape(batch_shape); HashMap feed_data = new HashMap() {{ - put("x", npdata); + put("x", batch_npdata); }}; List fetch = Arrays.asList("price"); @@ -69,12 +71,16 @@ public class PaddleServingClientExample { // Div(255.0) INDArray image = RGBimage.divi(255.0); - + long[] batch_shape = {1,image.shape()[0],image.shape()[1],image.shape()[2]}; + INDArray batch_image = image.reshape(batch_shape); + INDArray im_size = Nd4j.createFromArray(new int[]{height, width}); + long[] batch_size_shape = {1,2}; + INDArray batch_im_size = im_size.reshape(batch_size_shape); HashMap feed_data = new HashMap() {{ - put("image", image); - put("im_size", im_size); + put("image", batch_image); + put("im_size", batch_im_size); }}; List fetch = Arrays.asList("save_infer_model/scale_0.tmp_0"); diff --git a/java/src/main/java/io/paddle/serving/client/Client.java b/java/src/main/java/io/paddle/serving/client/Client.java index 742d4f91ce17555a2ea96f2a629717228ba18cef..aae7e6f8f50d4ca2baca877f2e51c8e71eb64af8 100644 --- a/java/src/main/java/io/paddle/serving/client/Client.java +++ b/java/src/main/java/io/paddle/serving/client/Client.java @@ -4,6 +4,9 @@ import java.util.*; import java.util.function.Function; import java.lang.management.ManagementFactory; import java.lang.management.RuntimeMXBean; +import java.util.stream.Collectors; +import java.util.List; +import java.util.ArrayList; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; @@ -238,7 +241,11 @@ public class Client { } else { throw new IllegalArgumentException("error tensor value type."); } - tensor_builder.addAllShape(feedShapes_.get(name)); + long[] longArray = variable.shape(); + int[] intArray = Arrays.stream(longArray).mapToInt(i -> (int) i).toArray(); + List indarrayShapeList = Arrays.stream(intArray).boxed().collect(Collectors.toList()); + //tensor_builder.addAllShape(feedShapes_.get(name)); + tensor_builder.addAllShape(indarrayShapeList); inst_builder.addTensorArray(tensor_builder.build()); } req_builder.addInsts(inst_builder.build());