提交 27e45738 编写于 作者: B barrierye

succ run

上级 7d06d541
......@@ -14,6 +14,10 @@
syntax = "proto2";
option java_multiple_files = true;
option java_package = "io.paddle.serving.grpc";
option java_outer_classname = "ServingProto";
message Tensor {
optional bytes data = 1;
repeated int32 int_data = 2;
......
......@@ -18,7 +18,6 @@ import org.nd4j.linalg.factory.Nd4j;
import io.paddle.serving.grpc.*;
import io.paddle.serving.configure.*;
public class Client {
private ManagedChannel channel_;
private MultiLangGeneralModelServiceGrpc.MultiLangGeneralModelServiceBlockingStub blockingStub_;
......@@ -66,12 +65,13 @@ public class Client {
}
public Boolean connect(List<String> endpoints) {
// String target = "ipv4:" + String.join(",", endpoints);
// TODO
//String target = "ipv4:" + String.join(",", endpoints);
String target = endpoints.get(0);
// TODO: max_receive_message_length and max_send_message_length
try {
channel_ = ManagedChannelBuilder.forTarget(target)
.defaultLoadBalancingPolicy("round_robin")
.maxInboundMessageSize(Integer.MAX_VALUE)
.usePlaintext()
.build();
blockingStub_ = MultiLangGeneralModelServiceGrpc.newBlockingStub(channel_);
......@@ -82,17 +82,13 @@ public class Client {
}
GetClientConfigRequest get_client_config_req = GetClientConfigRequest.newBuilder().build();
GetClientConfigResponse resp;
System.out.println("try to call getClientConfig");
try {
resp = blockingStub_.getClientConfig(get_client_config_req);
} catch (StatusRuntimeException e) {
System.out.format("Get Client config failed: %s", e.toString());
return false;
}
System.out.println("succ call get client");
System.out.println(resp);
String model_config_str = resp.getClientConfigStr();
System.out.println("model_config_str: " + model_config_str);
_parseModelConfig(model_config_str);
return true;
}
......@@ -112,6 +108,7 @@ public class Client {
feedShapes_ = new HashMap<String, List<Integer>>();
fetchTypes_ = new HashMap<String, Integer>();
lodTensorSet_ = new HashSet<String>();
feedTensorLen_ = new HashMap<String, Integer>();
List<FeedVar> feed_var_list = model_conf.getFeedVarList();
for (FeedVar feed_var : feed_var_list) {
......@@ -166,9 +163,6 @@ public class Client {
INDArray variable = feed_data.get(name);
long[] flattened_shape = {-1};
INDArray flattened_list = variable.reshape(flattened_shape);
for (Map.Entry<String, Integer> entry : feedTypes_.entrySet()) {
System.out.println("Key = " + entry.getKey() + ", Value = " + entry.getValue());
}
int v_type = feedTypes_.get(name);
NdIndexIterator iter = new NdIndexIterator(flattened_list.shape());
if (v_type == 0) { // int64
......@@ -184,15 +178,17 @@ public class Client {
tensor_builder.addFloatData(x);
}
} else if (v_type == 2) { // int32
throw new IllegalArgumentException("error tensor value type.");
//TODO
/*
while (iter.hasNext()) {
long[] next_index = iter.next();
int x = flattened_list.getInt(next_index);
// TODO: long to int?
// the interface of INDArray is strange:
// https://deeplearning4j.org/api/latest/org/nd4j/linalg/api/ndarray/INDArray.html
int[] int_next_index = new int[next_index.length];
for(int i = 0; i < next_index.length; i++) {
int_next_index[i] = (int)next_index[i];
}
int x = flattened_list.getInt(int_next_index);
tensor_builder.addIntData(x);
}*/
}
} else {
throw new IllegalArgumentException("error tensor value type.");
}
......@@ -204,7 +200,7 @@ public class Client {
return req_builder.build();
}
private HashMap<String, HashMap<String, INDArray>>
private Map<String, HashMap<String, INDArray>>
_unpackInferenceResponse(
InferenceResponse resp,
Iterable<String> fetch,
......@@ -213,7 +209,7 @@ public class Client {
resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag);
}
private static HashMap<String, HashMap<String, INDArray>>
private static Map<String, HashMap<String, INDArray>>
_staticUnpackInferenceResponse(
InferenceResponse resp,
Iterable<String> fetch,
......@@ -227,6 +223,7 @@ public class Client {
HashMap<String, HashMap<String, INDArray>> multi_result_map
= new HashMap<String, HashMap<String, INDArray>>();
for (ModelOutput model_result: resp.getOutputsList()) {
String engine_name = model_result.getEngineName();
FetchInst inst = model_result.getInsts(0);
HashMap<String, INDArray> result_map
= new HashMap<String, INDArray>();
......@@ -270,25 +267,32 @@ public class Client {
}
index += 1;
}
multi_result_map.put(engine_name, result_map);
}
// TODO: tag
return multi_result_map;
}
public Map<String, HashMap<String, INDArray>> predict(
public Map<String, INDArray> predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch) {
return predict(feed, fetch, false);
}
/*
public Map<String, HashMap<String, INDArray>> ensemble_predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch) {
return ensemble_predict(feed, fetch, false);
}
public PredictFuture async_predict(
Map<String, INDArray> feed,
HashMap<String, INDArray> feed,
Iterable<String> fetch) {
return async_predict(feed, fetch, false);
}
*/
public Map<String, HashMap<String, INDArray>> predict(
public Map<String, INDArray> predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch,
Boolean need_variant_tag) {
......@@ -297,30 +301,69 @@ public class Client {
feed_batch.add(feed);
return predict(feed_batch, fetch, need_variant_tag);
}
/*
public Map<String, HashMap<String, INDArray>> ensemble_predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch,
Boolean need_variant_tag) {
List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>();
feed_batch.add(feed);
return ensemble_predict(feed_batch, fetch, need_variant_tag);
}
public PredictFuture async_predict(
Map<String, List<? extends Number>> feed,
HashMap<String, INDArray> feed,
Iterable<String> fetch,
Boolean need_variant_tag) {
List<? extends Map<String, List<? extends Number>>> feed_batch
= new ArrayList<? extends Map<String, List<? extends Number>>>();
List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>();
feed_batch.add(feed);
return async_predict(feed_batch, fetch, need_variant_tag);
}
*/
public Map<String, HashMap<String, INDArray>> predict(
public Map<String, INDArray> predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) {
return predict(feed_batch, fetch, false);
}
/*
public Map<String, HashMap<String, INDArray>> ensemble_predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) {
return ensemble_predict(feed_batch, fetch, false);
}
public PredictFuture async_predict(
List<? extends Map<String, List<? extends Number>>> feed_batch,
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) {
return async_predict(feed_batch, fetch, false);
}
*/
public Map<String, HashMap<String, INDArray>> predict(
public Map<String, INDArray> predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
Boolean need_variant_tag) {
InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
try {
InferenceResponse resp = blockingStub_.inference(req);
Map<String, HashMap<String, INDArray>> ensemble_result
= _unpackInferenceResponse(resp, fetch, need_variant_tag);
List<Map.Entry<String, HashMap<String, INDArray>>> list
= new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>(
ensemble_result.entrySet());
if (list.size() != 1) {
System.out.format("grpc failed: please use ensemble_predict impl.");
return null;
}
return list.get(0).getValue();
} catch (StatusRuntimeException e) {
System.out.format("grpc failed: %s", e.toString());
return null;
}
}
public Map<String, HashMap<String, INDArray>> ensemble_predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
Boolean need_variant_tag) {
......@@ -331,31 +374,30 @@ public class Client {
resp, fetch, need_variant_tag);
} catch (StatusRuntimeException e) {
System.out.format("grpc failed: %s", e.toString());
return Collections.emptyMap();
return null;
}
}
/*
public PredictFuture async_predict(
List<Map<String, List<? extends Number>>> feed_batch,
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
Boolean need_variant_tag) {
InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
ListenableFuture<InferenceResponse> future = futureStub_.inference(req);
return new PredictFuture(
future,
PredictFuture predict_future = new PredictFuture(future,
(InferenceResponse resp) -> {
return Client._staticUnpackInferenceResponse(
resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag);
});
}
*/
);
return predict_future;
}
public static void main( String[] args ) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.create(data);
System.out.println(npdata);
Client client = new Client();
List<String> endpoints = Arrays.asList("localhost:9393");
......@@ -368,30 +410,58 @@ public class Client {
= new HashMap<String, INDArray>() {{
put("x", npdata);
}};
List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>() {{
add(feed_data);
}};
List<String> fetch = Arrays.asList("price");
Map<String, INDArray> fetch_map = client.predict(feed_data, fetch);
for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
/*
Map<String, HashMap<String, INDArray>> fetch_map
= client.predict(feed_batch, fetch);
System.out.println( "Hello World!" );
= client.predict(feed_data, fetch);
for (Map.Entry<String, HashMap<String, INDArray>> entry : fetch_map.entrySet()) {
System.out.println("Model = " + entry.getKey());
HashMap<String, INDArray> tt = entry.getValue();
for (Map.Entry<String, INDArray> e : tt.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
}
*/
}
}
class PredictFuture {
private ListenableFuture<InferenceResponse> callFuture_;
private Function<InferenceResponse,
Map<String, ? extends Map<String, List<? extends Number>>>> callBackFunc_;
Map<String, HashMap<String, INDArray>>> callBackFunc_;
PredictFuture(ListenableFuture<InferenceResponse> call_future,
Function<InferenceResponse,
Map<String, ? extends Map<String, List<? extends Number>>>> call_back_func) {
Map<String, HashMap<String, INDArray>>> call_back_func) {
callFuture_ = call_future;
callBackFunc_ = call_back_func;
}
public Map<String, ? extends Map<String, List<? extends Number>>> get() throws Exception {
public Map<String, INDArray> get() throws Exception {
InferenceResponse resp = null;
try {
resp = callFuture_.get();
} catch (Exception e) {
System.out.format("grpc failed: %s", e.toString());
return null;
}
Map<String, HashMap<String, INDArray>> ensemble_result
= callBackFunc_.apply(resp);
List<Map.Entry<String, HashMap<String, INDArray>>> list
= new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>(
ensemble_result.entrySet());
if (list.size() != 1) {
System.out.format("grpc failed: please use get_ensemble impl.");
return null;
}
return list.get(0).getValue();
}
public Map<String, HashMap<String, INDArray>> get_ensemble() throws Exception {
InferenceResponse resp = null;
try {
resp = callFuture_.get();
......
......@@ -18,8 +18,6 @@ option java_multiple_files = true;
option java_package = "io.paddle.serving.grpc";
option java_outer_classname = "ServingProto";
package paddle.serving.grpc;
message Tensor {
optional bytes data = 1;
repeated int32 int_data = 2;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册