提交 4ab7e478 编写于 作者: B barrierye

update code

上级 3697e7c8
......@@ -12,6 +12,7 @@ import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.factory.Nd4j;
import io.paddle.serving.grpc.*;
......@@ -65,7 +66,8 @@ public class Client {
}
public Boolean connect(List<String> endpoints) {
String target = "ipv4:" + String.join(",", endpoints);
// String target = "ipv4:" + String.join(",", endpoints);
String target = endpoints.get(0);
// TODO: max_receive_message_length and max_send_message_length
try {
channel_ = ManagedChannelBuilder.forTarget(target)
......@@ -80,13 +82,17 @@ public class Client {
}
GetClientConfigRequest get_client_config_req = GetClientConfigRequest.newBuilder().build();
GetClientConfigResponse resp;
System.out.println("try to call getClientConfig");
try {
resp = blockingStub_.getClientConfig(get_client_config_req);
} catch (StatusRuntimeException e) {
System.out.format("Get Client config failed: %s", e.toString());
return false;
}
System.out.println("succ call get client");
System.out.println(resp);
String model_config_str = resp.getClientConfigStr();
System.out.println("model_config_str: " + model_config_str);
_parseModelConfig(model_config_str);
return true;
}
......@@ -144,7 +150,7 @@ public class Client {
}
private InferenceRequest _packInferenceRequest(
List<Map<String, INDArray>> feed_batch,
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) throws IllegalArgumentException {
List<String> feed_var_names = new ArrayList<String>();
feed_var_names.addAll(feed_batch.get(0).keySet());
......@@ -153,25 +159,40 @@ public class Client {
.addAllFeedVarNames(feed_var_names)
.addAllFetchVarNames(fetch)
.setIsPython(false);
for (Map<String, INDArray> feed_data: feed_batch) {
for (HashMap<String, INDArray> feed_data: feed_batch) {
FeedInst.Builder inst_builder = FeedInst.newBuilder();
for (String name: feed_var_names) {
Tensor.Builder tensor_builder = Tensor.newBuilder();
INDArray variable = feed_data.get(name);
INDArray flattened_list = variable.reshape({-1});
long[] flattened_shape = {-1};
INDArray flattened_list = variable.reshape(flattened_shape);
for (Map.Entry<String, Integer> entry : feedTypes_.entrySet()) {
System.out.println("Key = " + entry.getKey() + ", Value = " + entry.getValue());
}
int v_type = feedTypes_.get(name);
NdIndexIterator iter = new NdIndexIterator(flattened_list.shape());
if (v_type == 0) { // int64
for (long x: flattened_list) {
while (iter.hasNext()) {
long[] next_index = iter.next();
long x = flattened_list.getLong(next_index);
tensor_builder.addInt64Data(x);
}
} else if (v_type == 1) { // float32
for (float x: flattened_list) {
while (iter.hasNext()) {
long[] next_index = iter.next();
float x = flattened_list.getFloat(next_index);
tensor_builder.addFloatData(x);
}
} else if (v_type == 2) { // int32
for (int x: flattened_list) {
throw new IllegalArgumentException("error tensor value type.");
//TODO
/*
while (iter.hasNext()) {
long[] next_index = iter.next();
int x = flattened_list.getInt(next_index);
// TODO: long to int?
tensor_builder.addIntData(x);
}
}*/
} else {
throw new IllegalArgumentException("error tensor value type.");
}
......@@ -183,7 +204,7 @@ public class Client {
return req_builder.build();
}
private Map<String, ? extends Map<String, INDArray>>
private HashMap<String, HashMap<String, INDArray>>
_unpackInferenceResponse(
InferenceResponse resp,
Iterable<String> fetch,
......@@ -192,7 +213,7 @@ public class Client {
resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag);
}
private static Map<String, ? extends Map<String, INDArray>>
private static HashMap<String, HashMap<String, INDArray>>
_staticUnpackInferenceResponse(
InferenceResponse resp,
Iterable<String> fetch,
......@@ -203,29 +224,49 @@ public class Client {
return null;
}
String tag = resp.getTag();
Map<String, ? extends Map<String, INDArray>> multi_result_map
HashMap<String, HashMap<String, INDArray>> multi_result_map
= new HashMap<String, HashMap<String, INDArray>>();
for (ModelOutput model_result: resp.getOutputsList()) {
FetchInst inst = model_result.getInsts(0);
Map<String, INDArray> result_map
HashMap<String, INDArray> result_map
= new HashMap<String, INDArray>();
int index = 0;
for (String name: fetch) {
Tensor variable = inst.getTensorArray(index);
int v_type = fetchTypes.get(name);
if (v_type == 0) { // int64
result_map.put(name, variable.getInt64DataList());
List<Long> list = variable.getInt64DataList();
long[] array = new long[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i);
}
result_map.put(name, Nd4j.create(array));
} else if (v_type == 1) { // float32
result_map.put(name, variable.getFloatDataList());
List<Float> list = variable.getFloatDataList();
float[] array = new float[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i);
}
result_map.put(name, Nd4j.create(array));
} else if (v_type == 2) { // int32
result_map.put(name, variable.getIntDataList());
List<Integer> list = variable.getIntDataList();
int[] array = new int[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i);
}
result_map.put(name, Nd4j.create(array));
} else {
throw new IllegalArgumentException("error tensor value type.");
}
// TODO: shape
if (lodTensorSet.contains(name)) {
result_map.put(name + ".lod", variable.getLodList());
List<Integer> list = variable.getLodList();
int[] array = new int[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i);
}
result_map.put(name + ".lod", Nd4j.create(array));
}
index += 1;
}
......@@ -235,8 +276,8 @@ public class Client {
return multi_result_map;
}
public Map<String, ? extends Map<String, INDArray>> predict(
Map<String, INDArray> feed,
public Map<String, HashMap<String, INDArray>> predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch) {
return predict(feed, fetch, false);
}
......@@ -247,12 +288,12 @@ public class Client {
return async_predict(feed, fetch, false);
}
*/
public Map<String, ? extends Map<String, INDArray>> predict(
Map<String, INDArray> feed,
public Map<String, HashMap<String, INDArray>> predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch,
Boolean need_variant_tag) {
List<? extends Map<String, INDArray>> feed_batch
= new ArrayList<? extends Map<String, INDArray>>();
List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>();
feed_batch.add(feed);
return predict(feed_batch, fetch, need_variant_tag);
}
......@@ -267,8 +308,8 @@ public class Client {
return async_predict(feed_batch, fetch, need_variant_tag);
}
*/
public Map<String, ? extends Map<String, INDArray>> predict(
List<? extends Map<String, INDArray>> feed_batch,
public Map<String, HashMap<String, INDArray>> predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) {
return predict(feed_batch, fetch, false);
}
......@@ -279,8 +320,8 @@ public class Client {
return async_predict(feed_batch, fetch, false);
}
*/
public Map<String, ? extends Map<String, INDArray>> predict(
List<? extends Map<String, INDArray>> feed_batch,
public Map<String, HashMap<String, INDArray>> predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
Boolean need_variant_tag) {
InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
......@@ -317,16 +358,22 @@ public class Client {
System.out.println(npdata);
Client client = new Client();
List<String> endpoints = new ArrayList<String>()
.add("182.61.111.54:9393");
Client.connect(endpoints);
List<String> endpoints = Arrays.asList("localhost:9393");
boolean succ = client.connect(endpoints);
if (succ != true) {
System.out.println("connect failed.");
return;
}
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("x", npdata);
}};
List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>()
.add(new HashMap<String, INDArray>()
.put("x", npdata));
List<String> fetch = new ArrayList<String>()
.add("price");
Map<String, ? extends Map<String, INDArray>> fetch_map
= new ArrayList<HashMap<String, INDArray>>() {{
add(feed_data);
}};
List<String> fetch = Arrays.asList("price");
Map<String, HashMap<String, INDArray>> fetch_map
= client.predict(feed_batch, fetch);
System.out.println( "Hello World!" );
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册