提交 950153ec 编写于 作者: B barrierye

[WIP] use INDArray

上级 808fbc92
......@@ -56,6 +56,8 @@
<grpc.version>1.27.2</grpc.version>
<protobuf.version>3.11.0</protobuf.version>
<protoc.version>3.11.0</protoc.version>
<nd4j.backend>nd4j-native</nd4j.backend>
<nd4j.version>1.0.0-beta7</nd4j.version>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
</properties>
......@@ -144,6 +146,11 @@
<artifactId>log4j-slf4j-impl</artifactId>
<version>2.12.1</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>${nd4j.backend}</artifactId>
<version>${nd4j.version}</version>
</dependency>
</dependencies>
......
......@@ -11,6 +11,9 @@ import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import io.paddle.serving.grpc.*;
import io.paddle.serving.configure.*;
......@@ -140,13 +143,8 @@ public class Client {
}
}
private List<? extends Number> _flattenList(List<? extends Number> x) {
// TODO
return x;
}
private InferenceRequest _packInferenceRequest(
List<Map<String, List<? extends Number>>> feed_batch,
List<Map<String, INDArray>> feed_batch,
Iterable<String> fetch) throws IllegalArgumentException {
List<String> feed_var_names = new ArrayList<String>();
feed_var_names.addAll(feed_batch.get(0).keySet());
......@@ -155,24 +153,24 @@ public class Client {
.addAllFeedVarNames(feed_var_names)
.addAllFetchVarNames(fetch)
.setIsPython(false);
for (Map<String, List<? extends Number>> feed_data: feed_batch) {
for (Map<String, INDArray> feed_data: feed_batch) {
FeedInst.Builder inst_builder = FeedInst.newBuilder();
for (String name: feed_var_names) {
Tensor.Builder tensor_builder = Tensor.newBuilder();
List<? extends Number> variable = feed_data.get(name);
List<? extends Number> flattened_list = _flattenList(variable);
INDArray variable = feed_data.get(name);
INDArray flattened_list = variable.reshape({-1});
int v_type = feedTypes_.get(name);
if (v_type == 0) { // int64
for (Number x: flattened_list) {
tensor_builder.addInt64Data((long)x);
for (long x: flattened_list) {
tensor_builder.addInt64Data(x);
}
} else if (v_type == 1) { // float32
for (Number x: flattened_list) {
tensor_builder.addFloatData((float)x);
for (float x: flattened_list) {
tensor_builder.addFloatData(x);
}
} else if (v_type == 2) { // int32
for (Number x: flattened_list) {
tensor_builder.addIntData((int)x);
for (int x: flattened_list) {
tensor_builder.addIntData(x);
}
} else {
throw new IllegalArgumentException("error tensor value type.");
......@@ -185,7 +183,7 @@ public class Client {
return req_builder.build();
}
private Map<String, ? extends Map<String, List<? extends Number>>>
private Map<String, ? extends Map<String, INDArray>>
_unpackInferenceResponse(
InferenceResponse resp,
Iterable<String> fetch,
......@@ -194,7 +192,7 @@ public class Client {
resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag);
}
private static Map<String, ? extends Map<String, List<? extends Number>>>
private static Map<String, ? extends Map<String, INDArray>>
_staticUnpackInferenceResponse(
InferenceResponse resp,
Iterable<String> fetch,
......@@ -205,12 +203,12 @@ public class Client {
return null;
}
String tag = resp.getTag();
Map<String, ? extends Map<String, List<? extends Number>>> multi_result_map
= new HashMap<String, HashMap<String, List<? extends Number>>>();
Map<String, ? extends Map<String, INDArray>> multi_result_map
= new HashMap<String, HashMap<String, INDArray>>();
for (ModelOutput model_result: resp.getOutputsList()) {
FetchInst inst = model_result.getInsts(0);
Map<String, List<? extends Number>> result_map
= new HashMap<String, List<? extends Number>>();
Map<String, INDArray> result_map
= new HashMap<String, INDArray>();
int index = 0;
for (String name: fetch) {
Tensor variable = inst.getTensorArray(index);
......@@ -237,29 +235,28 @@ public class Client {
return multi_result_map;
}
/*
public Map<String, ? extends Map<String, List<? extends Number>>> predict(
Map<String, List<? extends Number>> feed,
public Map<String, ? extends Map<String, INDArray>> predict(
Map<String, INDArray> feed,
Iterable<String> fetch) {
return predict(feed, fetch, false);
}
/*
public PredictFuture async_predict(
Map<String, List<? extends Number>> feed,
Map<String, INDArray> feed,
Iterable<String> fetch) {
return async_predict(feed, fetch, false);
}
public Map<String, ? extends Map<String, List<? extends Number>>> predict(
Map<String, List<? extends Number>> feed,
*/
public Map<String, ? extends Map<String, INDArray>> predict(
Map<String, INDArray> feed,
Iterable<String> fetch,
Boolean need_variant_tag) {
List<? extends Map<String, List<? extends Number>>> feed_batch
= new ArrayList<? extends Map<String, List<? extends Number>>>();
List<? extends Map<String, INDArray>> feed_batch
= new ArrayList<? extends Map<String, INDArray>>();
feed_batch.add(feed);
return predict(feed_batch, fetch, need_variant_tag);
}
/*
public PredictFuture async_predict(
Map<String, List<? extends Number>> feed,
Iterable<String> fetch,
......@@ -269,22 +266,21 @@ public class Client {
feed_batch.add(feed);
return async_predict(feed_batch, fetch, need_variant_tag);
}
*/
public Map<String, ? extends Map<String, List<? extends Number>>> predict(
List<? extends Map<String, List<? extends Number>>> feed_batch,
*/
public Map<String, ? extends Map<String, INDArray>> predict(
List<? extends Map<String, INDArray>> feed_batch,
Iterable<String> fetch) {
return predict(feed_batch, fetch, false);
}
/*
public PredictFuture async_predict(
List<? extends Map<String, List<? extends Number>>> feed_batch,
Iterable<String> fetch) {
return async_predict(feed_batch, fetch, false);
}
public Map<String, ? extends Map<String, List<? extends Number>>> predict(
List<? extends Map<String, List<? extends Number>>> feed_batch,
*/
public Map<String, ? extends Map<String, INDArray>> predict(
List<? extends Map<String, INDArray>> feed_batch,
Iterable<String> fetch,
Boolean need_variant_tag) {
InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
......@@ -297,7 +293,7 @@ public class Client {
return Collections.emptyMap();
}
}
/*
public PredictFuture async_predict(
List<Map<String, List<? extends Number>>> feed_batch,
Iterable<String> fetch,
......@@ -311,24 +307,26 @@ public class Client {
resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag);
});
}
*/
public static void main( String[] args ) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.create(data);
System.out.println(npdata);
Client client = new Client();
List<String> endpoints = new ArrayList<String>()
.add("182.61.111.54:9393");
Client.connect(endpoints);
List<HashMap<String, List<Float>>> feed_batch
= new ArrayList<HashMap<String, List<Float>>>()
.add(new HashMap<String, ArrayList<Float>>()
.put("x", new ArrayList<Float>()
.add(0.0137f).add(-0.1136f).add(0.2553f)
.add(-0.0692f).add(0.0582f).add(-0.0727f)
.add(-0.1583f).add(-0.0584f).add(0.6283f)
.add(0.4919f).add(0.1856f).add(0.0795f)
.add(-0.0332f)));
List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>()
.add(new HashMap<String, INDArray>()
.put("x", npdata));
List<String> fetch = new ArrayList<String>()
.add("price");
Map<String, ? extends Map<String, List<?>>> fetch_map
Map<String, ? extends Map<String, INDArray>> fetch_map
= client.predict(feed_batch, fetch);
System.out.println( "Hello World!" );
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册