提交 950153ec 编写于 作者: B barrierye

[WIP] use INDArray

上级 808fbc92
...@@ -56,6 +56,8 @@ ...@@ -56,6 +56,8 @@
<grpc.version>1.27.2</grpc.version> <grpc.version>1.27.2</grpc.version>
<protobuf.version>3.11.0</protobuf.version> <protobuf.version>3.11.0</protobuf.version>
<protoc.version>3.11.0</protoc.version> <protoc.version>3.11.0</protoc.version>
<nd4j.backend>nd4j-native</nd4j.backend>
<nd4j.version>1.0.0-beta7</nd4j.version>
<maven.compiler.source>1.8</maven.compiler.source> <maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target> <maven.compiler.target>1.8</maven.compiler.target>
</properties> </properties>
...@@ -144,6 +146,11 @@ ...@@ -144,6 +146,11 @@
<artifactId>log4j-slf4j-impl</artifactId> <artifactId>log4j-slf4j-impl</artifactId>
<version>2.12.1</version> <version>2.12.1</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>${nd4j.backend}</artifactId>
<version>${nd4j.version}</version>
</dependency>
</dependencies> </dependencies>
......
...@@ -11,6 +11,9 @@ import com.google.common.util.concurrent.FutureCallback; ...@@ -11,6 +11,9 @@ import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListenableFuture;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import io.paddle.serving.grpc.*; import io.paddle.serving.grpc.*;
import io.paddle.serving.configure.*; import io.paddle.serving.configure.*;
...@@ -140,13 +143,8 @@ public class Client { ...@@ -140,13 +143,8 @@ public class Client {
} }
} }
private List<? extends Number> _flattenList(List<? extends Number> x) {
// TODO
return x;
}
private InferenceRequest _packInferenceRequest( private InferenceRequest _packInferenceRequest(
List<Map<String, List<? extends Number>>> feed_batch, List<Map<String, INDArray>> feed_batch,
Iterable<String> fetch) throws IllegalArgumentException { Iterable<String> fetch) throws IllegalArgumentException {
List<String> feed_var_names = new ArrayList<String>(); List<String> feed_var_names = new ArrayList<String>();
feed_var_names.addAll(feed_batch.get(0).keySet()); feed_var_names.addAll(feed_batch.get(0).keySet());
...@@ -155,24 +153,24 @@ public class Client { ...@@ -155,24 +153,24 @@ public class Client {
.addAllFeedVarNames(feed_var_names) .addAllFeedVarNames(feed_var_names)
.addAllFetchVarNames(fetch) .addAllFetchVarNames(fetch)
.setIsPython(false); .setIsPython(false);
for (Map<String, List<? extends Number>> feed_data: feed_batch) { for (Map<String, INDArray> feed_data: feed_batch) {
FeedInst.Builder inst_builder = FeedInst.newBuilder(); FeedInst.Builder inst_builder = FeedInst.newBuilder();
for (String name: feed_var_names) { for (String name: feed_var_names) {
Tensor.Builder tensor_builder = Tensor.newBuilder(); Tensor.Builder tensor_builder = Tensor.newBuilder();
List<? extends Number> variable = feed_data.get(name); INDArray variable = feed_data.get(name);
List<? extends Number> flattened_list = _flattenList(variable); INDArray flattened_list = variable.reshape({-1});
int v_type = feedTypes_.get(name); int v_type = feedTypes_.get(name);
if (v_type == 0) { // int64 if (v_type == 0) { // int64
for (Number x: flattened_list) { for (long x: flattened_list) {
tensor_builder.addInt64Data((long)x); tensor_builder.addInt64Data(x);
} }
} else if (v_type == 1) { // float32 } else if (v_type == 1) { // float32
for (Number x: flattened_list) { for (float x: flattened_list) {
tensor_builder.addFloatData((float)x); tensor_builder.addFloatData(x);
} }
} else if (v_type == 2) { // int32 } else if (v_type == 2) { // int32
for (Number x: flattened_list) { for (int x: flattened_list) {
tensor_builder.addIntData((int)x); tensor_builder.addIntData(x);
} }
} else { } else {
throw new IllegalArgumentException("error tensor value type."); throw new IllegalArgumentException("error tensor value type.");
...@@ -185,7 +183,7 @@ public class Client { ...@@ -185,7 +183,7 @@ public class Client {
return req_builder.build(); return req_builder.build();
} }
private Map<String, ? extends Map<String, List<? extends Number>>> private Map<String, ? extends Map<String, INDArray>>
_unpackInferenceResponse( _unpackInferenceResponse(
InferenceResponse resp, InferenceResponse resp,
Iterable<String> fetch, Iterable<String> fetch,
...@@ -194,7 +192,7 @@ public class Client { ...@@ -194,7 +192,7 @@ public class Client {
resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag); resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag);
} }
private static Map<String, ? extends Map<String, List<? extends Number>>> private static Map<String, ? extends Map<String, INDArray>>
_staticUnpackInferenceResponse( _staticUnpackInferenceResponse(
InferenceResponse resp, InferenceResponse resp,
Iterable<String> fetch, Iterable<String> fetch,
...@@ -205,12 +203,12 @@ public class Client { ...@@ -205,12 +203,12 @@ public class Client {
return null; return null;
} }
String tag = resp.getTag(); String tag = resp.getTag();
Map<String, ? extends Map<String, List<? extends Number>>> multi_result_map Map<String, ? extends Map<String, INDArray>> multi_result_map
= new HashMap<String, HashMap<String, List<? extends Number>>>(); = new HashMap<String, HashMap<String, INDArray>>();
for (ModelOutput model_result: resp.getOutputsList()) { for (ModelOutput model_result: resp.getOutputsList()) {
FetchInst inst = model_result.getInsts(0); FetchInst inst = model_result.getInsts(0);
Map<String, List<? extends Number>> result_map Map<String, INDArray> result_map
= new HashMap<String, List<? extends Number>>(); = new HashMap<String, INDArray>();
int index = 0; int index = 0;
for (String name: fetch) { for (String name: fetch) {
Tensor variable = inst.getTensorArray(index); Tensor variable = inst.getTensorArray(index);
...@@ -237,29 +235,28 @@ public class Client { ...@@ -237,29 +235,28 @@ public class Client {
return multi_result_map; return multi_result_map;
} }
/* public Map<String, ? extends Map<String, INDArray>> predict(
public Map<String, ? extends Map<String, List<? extends Number>>> predict( Map<String, INDArray> feed,
Map<String, List<? extends Number>> feed,
Iterable<String> fetch) { Iterable<String> fetch) {
return predict(feed, fetch, false); return predict(feed, fetch, false);
} }
/*
public PredictFuture async_predict( public PredictFuture async_predict(
Map<String, List<? extends Number>> feed, Map<String, INDArray> feed,
Iterable<String> fetch) { Iterable<String> fetch) {
return async_predict(feed, fetch, false); return async_predict(feed, fetch, false);
} }
*/
public Map<String, ? extends Map<String, List<? extends Number>>> predict( public Map<String, ? extends Map<String, INDArray>> predict(
Map<String, List<? extends Number>> feed, Map<String, INDArray> feed,
Iterable<String> fetch, Iterable<String> fetch,
Boolean need_variant_tag) { Boolean need_variant_tag) {
List<? extends Map<String, List<? extends Number>>> feed_batch List<? extends Map<String, INDArray>> feed_batch
= new ArrayList<? extends Map<String, List<? extends Number>>>(); = new ArrayList<? extends Map<String, INDArray>>();
feed_batch.add(feed); feed_batch.add(feed);
return predict(feed_batch, fetch, need_variant_tag); return predict(feed_batch, fetch, need_variant_tag);
} }
/*
public PredictFuture async_predict( public PredictFuture async_predict(
Map<String, List<? extends Number>> feed, Map<String, List<? extends Number>> feed,
Iterable<String> fetch, Iterable<String> fetch,
...@@ -269,22 +266,21 @@ public class Client { ...@@ -269,22 +266,21 @@ public class Client {
feed_batch.add(feed); feed_batch.add(feed);
return async_predict(feed_batch, fetch, need_variant_tag); return async_predict(feed_batch, fetch, need_variant_tag);
} }
*/ */
public Map<String, ? extends Map<String, INDArray>> predict(
public Map<String, ? extends Map<String, List<? extends Number>>> predict( List<? extends Map<String, INDArray>> feed_batch,
List<? extends Map<String, List<? extends Number>>> feed_batch,
Iterable<String> fetch) { Iterable<String> fetch) {
return predict(feed_batch, fetch, false); return predict(feed_batch, fetch, false);
} }
/*
public PredictFuture async_predict( public PredictFuture async_predict(
List<? extends Map<String, List<? extends Number>>> feed_batch, List<? extends Map<String, List<? extends Number>>> feed_batch,
Iterable<String> fetch) { Iterable<String> fetch) {
return async_predict(feed_batch, fetch, false); return async_predict(feed_batch, fetch, false);
} }
*/
public Map<String, ? extends Map<String, List<? extends Number>>> predict( public Map<String, ? extends Map<String, INDArray>> predict(
List<? extends Map<String, List<? extends Number>>> feed_batch, List<? extends Map<String, INDArray>> feed_batch,
Iterable<String> fetch, Iterable<String> fetch,
Boolean need_variant_tag) { Boolean need_variant_tag) {
InferenceRequest req = _packInferenceRequest(feed_batch, fetch); InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
...@@ -297,7 +293,7 @@ public class Client { ...@@ -297,7 +293,7 @@ public class Client {
return Collections.emptyMap(); return Collections.emptyMap();
} }
} }
/*
public PredictFuture async_predict( public PredictFuture async_predict(
List<Map<String, List<? extends Number>>> feed_batch, List<Map<String, List<? extends Number>>> feed_batch,
Iterable<String> fetch, Iterable<String> fetch,
...@@ -311,24 +307,26 @@ public class Client { ...@@ -311,24 +307,26 @@ public class Client {
resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag); resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag);
}); });
} }
*/
public static void main( String[] args ) { public static void main( String[] args ) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.create(data);
System.out.println(npdata);
Client client = new Client(); Client client = new Client();
List<String> endpoints = new ArrayList<String>() List<String> endpoints = new ArrayList<String>()
.add("182.61.111.54:9393"); .add("182.61.111.54:9393");
Client.connect(endpoints); Client.connect(endpoints);
List<HashMap<String, List<Float>>> feed_batch List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, List<Float>>>() = new ArrayList<HashMap<String, INDArray>>()
.add(new HashMap<String, ArrayList<Float>>() .add(new HashMap<String, INDArray>()
.put("x", new ArrayList<Float>() .put("x", npdata));
.add(0.0137f).add(-0.1136f).add(0.2553f)
.add(-0.0692f).add(0.0582f).add(-0.0727f)
.add(-0.1583f).add(-0.0584f).add(0.6283f)
.add(0.4919f).add(0.1856f).add(0.0795f)
.add(-0.0332f)));
List<String> fetch = new ArrayList<String>() List<String> fetch = new ArrayList<String>()
.add("price"); .add("price");
Map<String, ? extends Map<String, List<?>>> fetch_map Map<String, ? extends Map<String, INDArray>> fetch_map
= client.predict(feed_batch, fetch); = client.predict(feed_batch, fetch);
System.out.println( "Hello World!" ); System.out.println( "Hello World!" );
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册