提交 da50ea4b 编写于 作者: H HexToString

fix java Client indarray shape bug by HexToString

上级 32c7be83
...@@ -16,9 +16,11 @@ public class PaddleServingClientExample { ...@@ -16,9 +16,11 @@ public class PaddleServingClientExample {
0.0582f, -0.0727f, -0.1583f, -0.0584f, 0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data); INDArray npdata = Nd4j.createFromArray(data);
long[] batch_shape = {1,13};
INDArray batch_npdata = npdata.reshape(batch_shape);
HashMap<String, INDArray> feed_data HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{ = new HashMap<String, INDArray>() {{
put("x", npdata); put("x", batch_npdata);
}}; }};
List<String> fetch = Arrays.asList("price"); List<String> fetch = Arrays.asList("price");
...@@ -69,12 +71,16 @@ public class PaddleServingClientExample { ...@@ -69,12 +71,16 @@ public class PaddleServingClientExample {
// Div(255.0) // Div(255.0)
INDArray image = RGBimage.divi(255.0); INDArray image = RGBimage.divi(255.0);
long[] batch_shape = {1,image.shape()[0],image.shape()[1],image.shape()[2]};
INDArray batch_image = image.reshape(batch_shape);
INDArray im_size = Nd4j.createFromArray(new int[]{height, width}); INDArray im_size = Nd4j.createFromArray(new int[]{height, width});
long[] batch_size_shape = {1,2};
INDArray batch_im_size = im_size.reshape(batch_size_shape);
HashMap<String, INDArray> feed_data HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{ = new HashMap<String, INDArray>() {{
put("image", image); put("image", batch_image);
put("im_size", im_size); put("im_size", batch_im_size);
}}; }};
List<String> fetch = Arrays.asList("save_infer_model/scale_0.tmp_0"); List<String> fetch = Arrays.asList("save_infer_model/scale_0.tmp_0");
......
...@@ -4,6 +4,9 @@ import java.util.*; ...@@ -4,6 +4,9 @@ import java.util.*;
import java.util.function.Function; import java.util.function.Function;
import java.lang.management.ManagementFactory; import java.lang.management.ManagementFactory;
import java.lang.management.RuntimeMXBean; import java.lang.management.RuntimeMXBean;
import java.util.stream.Collectors;
import java.util.List;
import java.util.ArrayList;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder; import io.grpc.ManagedChannelBuilder;
...@@ -238,7 +241,11 @@ public class Client { ...@@ -238,7 +241,11 @@ public class Client {
} else { } else {
throw new IllegalArgumentException("error tensor value type."); throw new IllegalArgumentException("error tensor value type.");
} }
tensor_builder.addAllShape(feedShapes_.get(name)); long[] longArray = variable.shape();
int[] intArray = Arrays.stream(longArray).mapToInt(i -> (int) i).toArray();
List<Integer> indarrayShapeList = Arrays.stream(intArray).boxed().collect(Collectors.toList());
//tensor_builder.addAllShape(feedShapes_.get(name));
tensor_builder.addAllShape(indarrayShapeList);
inst_builder.addTensorArray(tensor_builder.build()); inst_builder.addTensorArray(tensor_builder.build());
} }
req_builder.addInsts(inst_builder.build()); req_builder.addInsts(inst_builder.build());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册