提交 27e45738 编写于 作者: B barrierye

succ run

上级 7d06d541
...@@ -14,6 +14,10 @@ ...@@ -14,6 +14,10 @@
syntax = "proto2"; syntax = "proto2";
option java_multiple_files = true;
option java_package = "io.paddle.serving.grpc";
option java_outer_classname = "ServingProto";
message Tensor { message Tensor {
optional bytes data = 1; optional bytes data = 1;
repeated int32 int_data = 2; repeated int32 int_data = 2;
......
...@@ -18,7 +18,6 @@ import org.nd4j.linalg.factory.Nd4j; ...@@ -18,7 +18,6 @@ import org.nd4j.linalg.factory.Nd4j;
import io.paddle.serving.grpc.*; import io.paddle.serving.grpc.*;
import io.paddle.serving.configure.*; import io.paddle.serving.configure.*;
public class Client { public class Client {
private ManagedChannel channel_; private ManagedChannel channel_;
private MultiLangGeneralModelServiceGrpc.MultiLangGeneralModelServiceBlockingStub blockingStub_; private MultiLangGeneralModelServiceGrpc.MultiLangGeneralModelServiceBlockingStub blockingStub_;
...@@ -66,12 +65,13 @@ public class Client { ...@@ -66,12 +65,13 @@ public class Client {
} }
public Boolean connect(List<String> endpoints) { public Boolean connect(List<String> endpoints) {
// String target = "ipv4:" + String.join(",", endpoints); // TODO
//String target = "ipv4:" + String.join(",", endpoints);
String target = endpoints.get(0); String target = endpoints.get(0);
// TODO: max_receive_message_length and max_send_message_length
try { try {
channel_ = ManagedChannelBuilder.forTarget(target) channel_ = ManagedChannelBuilder.forTarget(target)
.defaultLoadBalancingPolicy("round_robin") .defaultLoadBalancingPolicy("round_robin")
.maxInboundMessageSize(Integer.MAX_VALUE)
.usePlaintext() .usePlaintext()
.build(); .build();
blockingStub_ = MultiLangGeneralModelServiceGrpc.newBlockingStub(channel_); blockingStub_ = MultiLangGeneralModelServiceGrpc.newBlockingStub(channel_);
...@@ -82,17 +82,13 @@ public class Client { ...@@ -82,17 +82,13 @@ public class Client {
} }
GetClientConfigRequest get_client_config_req = GetClientConfigRequest.newBuilder().build(); GetClientConfigRequest get_client_config_req = GetClientConfigRequest.newBuilder().build();
GetClientConfigResponse resp; GetClientConfigResponse resp;
System.out.println("try to call getClientConfig");
try { try {
resp = blockingStub_.getClientConfig(get_client_config_req); resp = blockingStub_.getClientConfig(get_client_config_req);
} catch (StatusRuntimeException e) { } catch (StatusRuntimeException e) {
System.out.format("Get Client config failed: %s", e.toString()); System.out.format("Get Client config failed: %s", e.toString());
return false; return false;
} }
System.out.println("succ call get client");
System.out.println(resp);
String model_config_str = resp.getClientConfigStr(); String model_config_str = resp.getClientConfigStr();
System.out.println("model_config_str: " + model_config_str);
_parseModelConfig(model_config_str); _parseModelConfig(model_config_str);
return true; return true;
} }
...@@ -112,6 +108,7 @@ public class Client { ...@@ -112,6 +108,7 @@ public class Client {
feedShapes_ = new HashMap<String, List<Integer>>(); feedShapes_ = new HashMap<String, List<Integer>>();
fetchTypes_ = new HashMap<String, Integer>(); fetchTypes_ = new HashMap<String, Integer>();
lodTensorSet_ = new HashSet<String>(); lodTensorSet_ = new HashSet<String>();
feedTensorLen_ = new HashMap<String, Integer>();
List<FeedVar> feed_var_list = model_conf.getFeedVarList(); List<FeedVar> feed_var_list = model_conf.getFeedVarList();
for (FeedVar feed_var : feed_var_list) { for (FeedVar feed_var : feed_var_list) {
...@@ -166,9 +163,6 @@ public class Client { ...@@ -166,9 +163,6 @@ public class Client {
INDArray variable = feed_data.get(name); INDArray variable = feed_data.get(name);
long[] flattened_shape = {-1}; long[] flattened_shape = {-1};
INDArray flattened_list = variable.reshape(flattened_shape); INDArray flattened_list = variable.reshape(flattened_shape);
for (Map.Entry<String, Integer> entry : feedTypes_.entrySet()) {
System.out.println("Key = " + entry.getKey() + ", Value = " + entry.getValue());
}
int v_type = feedTypes_.get(name); int v_type = feedTypes_.get(name);
NdIndexIterator iter = new NdIndexIterator(flattened_list.shape()); NdIndexIterator iter = new NdIndexIterator(flattened_list.shape());
if (v_type == 0) { // int64 if (v_type == 0) { // int64
...@@ -184,15 +178,17 @@ public class Client { ...@@ -184,15 +178,17 @@ public class Client {
tensor_builder.addFloatData(x); tensor_builder.addFloatData(x);
} }
} else if (v_type == 2) { // int32 } else if (v_type == 2) { // int32
throw new IllegalArgumentException("error tensor value type.");
//TODO
/*
while (iter.hasNext()) { while (iter.hasNext()) {
long[] next_index = iter.next(); long[] next_index = iter.next();
int x = flattened_list.getInt(next_index); // the interface of INDArray is strange:
// TODO: long to int? // https://deeplearning4j.org/api/latest/org/nd4j/linalg/api/ndarray/INDArray.html
int[] int_next_index = new int[next_index.length];
for(int i = 0; i < next_index.length; i++) {
int_next_index[i] = (int)next_index[i];
}
int x = flattened_list.getInt(int_next_index);
tensor_builder.addIntData(x); tensor_builder.addIntData(x);
}*/ }
} else { } else {
throw new IllegalArgumentException("error tensor value type."); throw new IllegalArgumentException("error tensor value type.");
} }
...@@ -204,7 +200,7 @@ public class Client { ...@@ -204,7 +200,7 @@ public class Client {
return req_builder.build(); return req_builder.build();
} }
private HashMap<String, HashMap<String, INDArray>> private Map<String, HashMap<String, INDArray>>
_unpackInferenceResponse( _unpackInferenceResponse(
InferenceResponse resp, InferenceResponse resp,
Iterable<String> fetch, Iterable<String> fetch,
...@@ -213,7 +209,7 @@ public class Client { ...@@ -213,7 +209,7 @@ public class Client {
resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag); resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag);
} }
private static HashMap<String, HashMap<String, INDArray>> private static Map<String, HashMap<String, INDArray>>
_staticUnpackInferenceResponse( _staticUnpackInferenceResponse(
InferenceResponse resp, InferenceResponse resp,
Iterable<String> fetch, Iterable<String> fetch,
...@@ -227,6 +223,7 @@ public class Client { ...@@ -227,6 +223,7 @@ public class Client {
HashMap<String, HashMap<String, INDArray>> multi_result_map HashMap<String, HashMap<String, INDArray>> multi_result_map
= new HashMap<String, HashMap<String, INDArray>>(); = new HashMap<String, HashMap<String, INDArray>>();
for (ModelOutput model_result: resp.getOutputsList()) { for (ModelOutput model_result: resp.getOutputsList()) {
String engine_name = model_result.getEngineName();
FetchInst inst = model_result.getInsts(0); FetchInst inst = model_result.getInsts(0);
HashMap<String, INDArray> result_map HashMap<String, INDArray> result_map
= new HashMap<String, INDArray>(); = new HashMap<String, INDArray>();
...@@ -270,25 +267,32 @@ public class Client { ...@@ -270,25 +267,32 @@ public class Client {
} }
index += 1; index += 1;
} }
multi_result_map.put(engine_name, result_map);
} }
// TODO: tag // TODO: tag
return multi_result_map; return multi_result_map;
} }
public Map<String, HashMap<String, INDArray>> predict( public Map<String, INDArray> predict(
HashMap<String, INDArray> feed, HashMap<String, INDArray> feed,
Iterable<String> fetch) { Iterable<String> fetch) {
return predict(feed, fetch, false); return predict(feed, fetch, false);
} }
/*
public Map<String, HashMap<String, INDArray>> ensemble_predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch) {
return ensemble_predict(feed, fetch, false);
}
public PredictFuture async_predict( public PredictFuture async_predict(
Map<String, INDArray> feed, HashMap<String, INDArray> feed,
Iterable<String> fetch) { Iterable<String> fetch) {
return async_predict(feed, fetch, false); return async_predict(feed, fetch, false);
} }
*/
public Map<String, HashMap<String, INDArray>> predict( public Map<String, INDArray> predict(
HashMap<String, INDArray> feed, HashMap<String, INDArray> feed,
Iterable<String> fetch, Iterable<String> fetch,
Boolean need_variant_tag) { Boolean need_variant_tag) {
...@@ -297,30 +301,69 @@ public class Client { ...@@ -297,30 +301,69 @@ public class Client {
feed_batch.add(feed); feed_batch.add(feed);
return predict(feed_batch, fetch, need_variant_tag); return predict(feed_batch, fetch, need_variant_tag);
} }
/*
public Map<String, HashMap<String, INDArray>> ensemble_predict(
HashMap<String, INDArray> feed,
Iterable<String> fetch,
Boolean need_variant_tag) {
List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>();
feed_batch.add(feed);
return ensemble_predict(feed_batch, fetch, need_variant_tag);
}
public PredictFuture async_predict( public PredictFuture async_predict(
Map<String, List<? extends Number>> feed, HashMap<String, INDArray> feed,
Iterable<String> fetch, Iterable<String> fetch,
Boolean need_variant_tag) { Boolean need_variant_tag) {
List<? extends Map<String, List<? extends Number>>> feed_batch List<HashMap<String, INDArray>> feed_batch
= new ArrayList<? extends Map<String, List<? extends Number>>>(); = new ArrayList<HashMap<String, INDArray>>();
feed_batch.add(feed); feed_batch.add(feed);
return async_predict(feed_batch, fetch, need_variant_tag); return async_predict(feed_batch, fetch, need_variant_tag);
} }
*/
public Map<String, HashMap<String, INDArray>> predict( public Map<String, INDArray> predict(
List<HashMap<String, INDArray>> feed_batch, List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) { Iterable<String> fetch) {
return predict(feed_batch, fetch, false); return predict(feed_batch, fetch, false);
} }
/*
public Map<String, HashMap<String, INDArray>> ensemble_predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) {
return ensemble_predict(feed_batch, fetch, false);
}
public PredictFuture async_predict( public PredictFuture async_predict(
List<? extends Map<String, List<? extends Number>>> feed_batch, List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch) { Iterable<String> fetch) {
return async_predict(feed_batch, fetch, false); return async_predict(feed_batch, fetch, false);
} }
*/
public Map<String, HashMap<String, INDArray>> predict( public Map<String, INDArray> predict(
List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch,
Boolean need_variant_tag) {
InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
try {
InferenceResponse resp = blockingStub_.inference(req);
Map<String, HashMap<String, INDArray>> ensemble_result
= _unpackInferenceResponse(resp, fetch, need_variant_tag);
List<Map.Entry<String, HashMap<String, INDArray>>> list
= new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>(
ensemble_result.entrySet());
if (list.size() != 1) {
System.out.format("grpc failed: please use ensemble_predict impl.");
return null;
}
return list.get(0).getValue();
} catch (StatusRuntimeException e) {
System.out.format("grpc failed: %s", e.toString());
return null;
}
}
public Map<String, HashMap<String, INDArray>> ensemble_predict(
List<HashMap<String, INDArray>> feed_batch, List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch, Iterable<String> fetch,
Boolean need_variant_tag) { Boolean need_variant_tag) {
...@@ -331,31 +374,30 @@ public class Client { ...@@ -331,31 +374,30 @@ public class Client {
resp, fetch, need_variant_tag); resp, fetch, need_variant_tag);
} catch (StatusRuntimeException e) { } catch (StatusRuntimeException e) {
System.out.format("grpc failed: %s", e.toString()); System.out.format("grpc failed: %s", e.toString());
return Collections.emptyMap(); return null;
} }
} }
/*
public PredictFuture async_predict( public PredictFuture async_predict(
List<Map<String, List<? extends Number>>> feed_batch, List<HashMap<String, INDArray>> feed_batch,
Iterable<String> fetch, Iterable<String> fetch,
Boolean need_variant_tag) { Boolean need_variant_tag) {
InferenceRequest req = _packInferenceRequest(feed_batch, fetch); InferenceRequest req = _packInferenceRequest(feed_batch, fetch);
ListenableFuture<InferenceResponse> future = futureStub_.inference(req); ListenableFuture<InferenceResponse> future = futureStub_.inference(req);
return new PredictFuture( PredictFuture predict_future = new PredictFuture(future,
future,
(InferenceResponse resp) -> { (InferenceResponse resp) -> {
return Client._staticUnpackInferenceResponse( return Client._staticUnpackInferenceResponse(
resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag); resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag);
});
} }
*/ );
return predict_future;
}
public static void main( String[] args ) { public static void main( String[] args ) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f, 0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.create(data); INDArray npdata = Nd4j.create(data);
System.out.println(npdata);
Client client = new Client(); Client client = new Client();
List<String> endpoints = Arrays.asList("localhost:9393"); List<String> endpoints = Arrays.asList("localhost:9393");
...@@ -368,30 +410,58 @@ public class Client { ...@@ -368,30 +410,58 @@ public class Client {
= new HashMap<String, INDArray>() {{ = new HashMap<String, INDArray>() {{
put("x", npdata); put("x", npdata);
}}; }};
List<HashMap<String, INDArray>> feed_batch
= new ArrayList<HashMap<String, INDArray>>() {{
add(feed_data);
}};
List<String> fetch = Arrays.asList("price"); List<String> fetch = Arrays.asList("price");
Map<String, INDArray> fetch_map = client.predict(feed_data, fetch);
for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
/*
Map<String, HashMap<String, INDArray>> fetch_map Map<String, HashMap<String, INDArray>> fetch_map
= client.predict(feed_batch, fetch); = client.predict(feed_data, fetch);
System.out.println( "Hello World!" ); for (Map.Entry<String, HashMap<String, INDArray>> entry : fetch_map.entrySet()) {
System.out.println("Model = " + entry.getKey());
HashMap<String, INDArray> tt = entry.getValue();
for (Map.Entry<String, INDArray> e : tt.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
}
*/
} }
} }
class PredictFuture { class PredictFuture {
private ListenableFuture<InferenceResponse> callFuture_; private ListenableFuture<InferenceResponse> callFuture_;
private Function<InferenceResponse, private Function<InferenceResponse,
Map<String, ? extends Map<String, List<? extends Number>>>> callBackFunc_; Map<String, HashMap<String, INDArray>>> callBackFunc_;
PredictFuture(ListenableFuture<InferenceResponse> call_future, PredictFuture(ListenableFuture<InferenceResponse> call_future,
Function<InferenceResponse, Function<InferenceResponse,
Map<String, ? extends Map<String, List<? extends Number>>>> call_back_func) { Map<String, HashMap<String, INDArray>>> call_back_func) {
callFuture_ = call_future; callFuture_ = call_future;
callBackFunc_ = call_back_func; callBackFunc_ = call_back_func;
} }
public Map<String, ? extends Map<String, List<? extends Number>>> get() throws Exception { public Map<String, INDArray> get() throws Exception {
InferenceResponse resp = null;
try {
resp = callFuture_.get();
} catch (Exception e) {
System.out.format("grpc failed: %s", e.toString());
return null;
}
Map<String, HashMap<String, INDArray>> ensemble_result
= callBackFunc_.apply(resp);
List<Map.Entry<String, HashMap<String, INDArray>>> list
= new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>(
ensemble_result.entrySet());
if (list.size() != 1) {
System.out.format("grpc failed: please use get_ensemble impl.");
return null;
}
return list.get(0).getValue();
}
public Map<String, HashMap<String, INDArray>> get_ensemble() throws Exception {
InferenceResponse resp = null; InferenceResponse resp = null;
try { try {
resp = callFuture_.get(); resp = callFuture_.get();
......
...@@ -18,8 +18,6 @@ option java_multiple_files = true; ...@@ -18,8 +18,6 @@ option java_multiple_files = true;
option java_package = "io.paddle.serving.grpc"; option java_package = "io.paddle.serving.grpc";
option java_outer_classname = "ServingProto"; option java_outer_classname = "ServingProto";
package paddle.serving.grpc;
message Tensor { message Tensor {
optional bytes data = 1; optional bytes data = 1;
repeated int32 int_data = 2; repeated int32 int_data = 2;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册