提交 4ab7e478 编写于 作者: B barrierye

update code

上级 3697e7c8
...@@ -12,6 +12,7 @@ import com.google.common.util.concurrent.Futures; ...@@ -12,6 +12,7 @@ import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListenableFuture;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import io.paddle.serving.grpc.*; import io.paddle.serving.grpc.*;
...@@ -65,7 +66,8 @@ public class Client { ...@@ -65,7 +66,8 @@ public class Client {
} }
public Boolean connect(List<String> endpoints) { public Boolean connect(List<String> endpoints) {
String target = "ipv4:" + String.join(",", endpoints); // String target = "ipv4:" + String.join(",", endpoints);
String target = endpoints.get(0);
// TODO: max_receive_message_length and max_send_message_length // TODO: max_receive_message_length and max_send_message_length
try { try {
channel_ = ManagedChannelBuilder.forTarget(target) channel_ = ManagedChannelBuilder.forTarget(target)
...@@ -80,13 +82,17 @@ public class Client { ...@@ -80,13 +82,17 @@ public class Client {
} }
GetClientConfigRequest get_client_config_req = GetClientConfigRequest.newBuilder().build(); GetClientConfigRequest get_client_config_req = GetClientConfigRequest.newBuilder().build();
GetClientConfigResponse resp; GetClientConfigResponse resp;
System.out.println("try to call getClientConfig");
try { try {
resp = blockingStub_.getClientConfig(get_client_config_req); resp = blockingStub_.getClientConfig(get_client_config_req);
} catch (StatusRuntimeException e) { } catch (StatusRuntimeException e) {
System.out.format("Get Client config failed: %s", e.toString()); System.out.format("Get Client config failed: %s", e.toString());
return false; return false;
} }
System.out.println("succ call get client");
System.out.println(resp);
String model_config_str = resp.getClientConfigStr(); String model_config_str = resp.getClientConfigStr();
System.out.println("model_config_str: " + model_config_str);
_parseModelConfig(model_config_str); _parseModelConfig(model_config_str);
return true; return true;
} }
...@@ -144,7 +150,7 @@ public class Client { ...@@ -144,7 +150,7 @@ public class Client {
} }
private InferenceRequest _packInferenceRequest( private InferenceRequest _packInferenceRequest(
List<Map<String, INDArray>> feed_batch, List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) throws IllegalArgumentException { Iterable<String> fetch) throws IllegalArgumentException {
List<String> feed_var_names = new ArrayList<String>(); List<String> feed_var_names = new ArrayList<String>();
feed_var_names.addAll(feed_batch.get(0).keySet()); feed_var_names.addAll(feed_batch.get(0).keySet());
...@@ -153,25 +159,40 @@ public class Client { ...@@ -153,25 +159,40 @@ public class Client {
.addAllFeedVarNames(feed_var_names) .addAllFeedVarNames(feed_var_names)
.addAllFetchVarNames(fetch) .addAllFetchVarNames(fetch)
.setIsPython(false); .setIsPython(false);
for (Map<String, INDArray> feed_data: feed_batch) { for (HashMap<String, INDArray> feed_data: feed_batch) {
FeedInst.Builder inst_builder = FeedInst.newBuilder(); FeedInst.Builder inst_builder = FeedInst.newBuilder();
for (String name: feed_var_names) { for (String name: feed_var_names) {
Tensor.Builder tensor_builder = Tensor.newBuilder(); Tensor.Builder tensor_builder = Tensor.newBuilder();
INDArray variable = feed_data.get(name); INDArray variable = feed_data.get(name);
INDArray flattened_list = variable.reshape({-1}); long[] flattened_shape = {-1};
INDArray flattened_list = variable.reshape(flattened_shape);
for (Map.Entry<String, Integer> entry : feedTypes_.entrySet()) {
System.out.println("Key = " + entry.getKey() + ", Value = " + entry.getValue());
}
int v_type = feedTypes_.get(name); int v_type = feedTypes_.get(name);
NdIndexIterator iter = new NdIndexIterator(flattened_list.shape());
if (v_type == 0) { // int64 if (v_type == 0) { // int64
for (long x: flattened_list) { while (iter.hasNext()) {
long[] next_index = iter.next();
long x = flattened_list.getLong(next_index);
tensor_builder.addInt64Data(x); tensor_builder.addInt64Data(x);
} }
} else if (v_type == 1) { // float32 } else if (v_type == 1) { // float32
for (float x: flattened_list) { while (iter.hasNext()) {
long[] next_index = iter.next();
float x = flattened_list.getFloat(next_index);
tensor_builder.addFloatData(x); tensor_builder.addFloatData(x);
} }
} else if (v_type == 2) { // int32 } else if (v_type == 2) { // int32
for (int x: flattened_list) { throw new IllegalArgumentException("error tensor value type.");
//TODO
/*
while (iter.hasNext()) {
long[] next_index = iter.next();
int x = flattened_list.getInt(next_index);
// TODO: long to int?
tensor_builder.addIntData(x); tensor_builder.addIntData(x);
} }*/
} else { } else {
throw new IllegalArgumentException("error tensor value type."); throw new IllegalArgumentException("error tensor value type.");
} }
...@@ -183,7 +204,7 @@ public class Client { ...@@ -183,7 +204,7 @@ public class Client {
return req_builder.build(); return req_builder.build();
} }
private Map<String, ? extends Map<String, INDArray>> private HashMap<String, HashMap<String, INDArray>>
_unpackInferenceResponse( _unpackInferenceResponse(
InferenceResponse resp, InferenceResponse resp,
Iterable<String> fetch, Iterable<String> fetch,
...@@ -192,7 +213,7 @@ public class Client { ...@@ -192,7 +213,7 @@ public class Client {
resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag); resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag);
} }
private static Map<String, ? extends Map<String, INDArray>> private static HashMap<String, HashMap<String, INDArray>>
_staticUnpackInferenceResponse( _staticUnpackInferenceResponse(
InferenceResponse resp, InferenceResponse resp,
Iterable<String> fetch, Iterable<String> fetch,
...@@ -203,29 +224,49 @@ public class Client { ...@@ -203,29 +224,49 @@ public class Client {
return null; return null;
} }
String tag = resp.getTag(); String tag = resp.getTag();
Map<String, ? extends Map<String, INDArray>> multi_result_map HashMap<String, HashMap<String, INDArray>> multi_result_map
= new HashMap<String, HashMap<String, INDArray>>(); = new HashMap<String, HashMap<String, INDArray>>();
for (ModelOutput model_result: resp.getOutputsList()) { for (ModelOutput model_result: resp.getOutputsList()) {
FetchInst inst = model_result.getInsts(0); FetchInst inst = model_result.getInsts(0);
Map<String, INDArray> result_map HashMap<String, INDArray> result_map
= new HashMap<String, INDArray>(); = new HashMap<String, INDArray>();
int index = 0; int index = 0;
for (String name: fetch) { for (String name: fetch) {
Tensor variable = inst.getTensorArray(index); Tensor variable = inst.getTensorArray(index);
int v_type = fetchTypes.get(name); int v_type = fetchTypes.get(name);
if (v_type == 0) { // int64 if (v_type == 0) { // int64
result_map.put(name, variable.getInt64DataList()); List<Long> list = variable.getInt64DataList();
long[] array = new long[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i);
}
result_map.put(name, Nd4j.create(array));
} else if (v_type == 1) { // float32 } else if (v_type == 1) { // float32
result_map.put(name, variable.getFloatDataList()); List<Float> list = variable.getFloatDataList();
float[] array = new float[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i);
}
result_map.put(name, Nd4j.create(array));
} else if (v_type == 2) { // int32 } else if (v_type == 2) { // int32
result_map.put(name, variable.getIntDataList()); List<Integer> list = variable.getIntDataList();
int[] array = new int[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i);
}
result_map.put(name, Nd4j.create(array));
} else { } else {
throw new IllegalArgumentException("error tensor value type."); throw new IllegalArgumentException("error tensor value type.");
} }
// TODO: shape // TODO: shape
if (lodTensorSet.contains(name)) { if (lodTensorSet.contains(name)) {
result_map.put(name + ".lod", variable.getLodList()); List<Integer> list = variable.getLodList();
int[] array = new int[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i);
}
result_map.put(name + ".lod", Nd4j.create(array));
} }
index += 1; index += 1;
} }
...@@ -235,8 +276,8 @@ public class Client { ...@@ -235,8 +276,8 @@ public class Client {
return multi_result_map; return multi_result_map;
} }
public Map<String, ? extends Map<String, INDArray>> predict( public Map<String, HashMap<String, INDArray>> predict(
Map<String, INDArray> feed, HashMap<String, INDArray> feed,
Iterable<String> fetch) { Iterable<String> fetch) {
return predict(feed, fetch, false); return predict(feed, fetch, false);
} }
...@@ -247,12 +288,12 @@ public class Client { ...@@ -247,12 +288,12 @@ public class Client {
return async_predict(feed, fetch, false); return async_predict(feed, fetch, false);
} }
*/ */
public Map<String, ? extends Map<String, INDArray>> predict( public Map<String, HashMap<String, INDArray>> predict(
Map<String, INDArray> feed, HashMap<String, INDArray> feed,
Iterable<String> fetch, Iterable<String> fetch,
Boolean need_variant_tag) { Boolean need_variant_tag) {
List<? extends Map<String, INDArray>> feed_batch List<HashMap<String, INDArray>> feed_batch
= new ArrayList<? extends Map<String, INDArray>>(); = new ArrayList<HashMap<String, INDArray>>();
feed_batch.add(feed); feed_batch.add(feed);
return predict(feed_batch, fetch, need_variant_tag); return predict(feed_batch, fetch, need_variant_tag);
} }
...@@ -267,8 +308,8 @@ public class Client { ...@@ -267,8 +308,8 @@ public class Client {
return async_predict(feed_batch, fetch, need_variant_tag); return async_predict(feed_batch, fetch, need_variant_tag);
} }
*/ */
public Map<String, ? extends Map<String, INDArray>> predict( public Map<String, HashMap<String, INDArray>> predict(
List<? extends Map<String, INDArray>> feed_batch, List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) { Iterable<String> fetch) {
return predict(feed_batch, fetch, false); return predict(feed_batch, fetch, false);
} }
...@@ -279,8 +320,8 @@ public class Client { ...@@ -279,8 +320,8 @@ public class Client {
return async_predict(feed_batch, fetch, false); return async_predict(feed_batch, fetch, false);
} }
*/ */
public Map<String, ? extends Map<String, INDArray>> predict( public Map<String, HashMap<String, INDArray>> predict(
List<? extends Map<String, INDArray>> feed_batch, List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch, Iterable<String> fetch,
Boolean need_variant_tag) { Boolean need_variant_tag) {
InferenceRequest req = _packInferenceRequest(feed_batch, fetch); InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
...@@ -317,16 +358,22 @@ public class Client { ...@@ -317,16 +358,22 @@ public class Client {
System.out.println(npdata); System.out.println(npdata);
Client client = new Client(); Client client = new Client();
List<String> endpoints = new ArrayList<String>() List<String> endpoints = Arrays.asList("localhost:9393");
.add("182.61.111.54:9393"); boolean succ = client.connect(endpoints);
Client.connect(endpoints); if (succ != true) {
System.out.println("connect failed.");
return;
}
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("x", npdata);
}};
List<HashMap<String, INDArray>> feed_batch List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>() = new ArrayList<HashMap<String, INDArray>>() {{
.add(new HashMap<String, INDArray>() add(feed_data);
.put("x", npdata)); }};
List<String> fetch = new ArrayList<String>() List<String> fetch = Arrays.asList("price");
.add("price"); Map<String, HashMap<String, INDArray>> fetch_map
Map<String, ? extends Map<String, INDArray>> fetch_map
= client.predict(feed_batch, fetch); = client.predict(feed_batch, fetch);
System.out.println( "Hello World!" ); System.out.println( "Hello World!" );
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册