提交 f68d61e4 编写于 作者: B barrierye

update java client

上级 add04dac
package io.paddle.serving.client;
import java.util.*;
import java.util.function.Function;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.StatusRuntimeException;
//import com.google.protobuf;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.paddle.serving.grpc.*;
import io.paddle.serving.configure.*;
......@@ -14,25 +18,33 @@ import io.paddle.serving.configure.*;
public class Client {
private ManagedChannel channel_;
private MultiLangGeneralModelServiceGrpc.MultiLangGeneralModelServiceBlockingStub blockingStub_;
private MultiLangGeneralModelServiceGrpc.MultiLangGeneralModelServiceStub asyncStub_;
private MultiLangGeneralModelServiceGrpc.MultiLangGeneralModelServiceFutureStub futureStub_;
private double rpcTimeoutS_;
private List<String> feedNames_;
private Map<String, Integer> feedTypes_;
private Map<String, List<Integer>> feedShapes_;
private List<String> fetchNames_;
private Map<String, Integer> fetchTypes_;
private Map<String, List<Integer>> fetchShapes_;
private Set<String> lodTensorSet_;
private Map<String, Integer> feedTensorLen_;
/**
 * Creates a client with no channel and no stubs; every model-metadata field
 * starts out as null and only the RPC timeout receives a non-null default.
 */
Client() {
channel_ = null;
blockingStub_ = null;
asyncStub_ = null;
futureStub_ = null;
// Default timeout; setRpcTimeoutMs stores milliseconds / 1000.0 in this field,
// so the unit here is seconds.
rpcTimeoutS_ = 2;
feedNames_ = null;
feedTypes_ = null;
feedShapes_ = null;
fetchNames_ = null;
fetchTypes_ = null;
lodTensorSet_ = null;
feedTensorLen_ = null;
}
Boolean setRpcTimeoutMs(int rpc_timeout) throws NullPointerException {
if (asyncStub_ == null || blockingStub_ == null) {
public Boolean setRpcTimeoutMs(int rpc_timeout) throws NullPointerException {
if (futureStub_ == null || blockingStub_ == null) {
throw new NullPointerException("set timeout must be set after connect.");
}
rpcTimeoutS_ = rpc_timeout / 1000.0;
......@@ -49,7 +61,7 @@ public class Client {
return resp.getErrCode() == 0;
}
Boolean connect(String[] endpoints) {
public Boolean connect(String[] endpoints) {
String target = "ipv4:" + String.join(",", endpoints);
// TODO: max_receive_message_length and max_send_message_length
try {
......@@ -58,7 +70,7 @@ public class Client {
.usePlaintext()
.build();
blockingStub_ = MultiLangGeneralModelServiceGrpc.newBlockingStub(channel_);
asyncStub_ = MultiLangGeneralModelServiceGrpc.newStub(channel_);
futureStub_ = MultiLangGeneralModelServiceGrpc.newFutureStub(channel_);
} catch (Exception e) {
System.out.format("Connect failed: %s", e.toString());
return false;
......@@ -76,7 +88,7 @@ public class Client {
return true;
}
void _parseModelConfig(String model_config_str) {
private void _parseModelConfig(String model_config_str) {
GeneralModelConfig.Builder model_conf_builder = GeneralModelConfig.newBuilder();
try {
com.google.protobuf.TextFormat.getParser().merge(model_config_str, model_conf_builder);
......@@ -85,30 +97,203 @@ public class Client {
}
GeneralModelConfig model_conf = model_conf_builder.build();
List<String> feedNames_ ;
feedNames_ = new ArrayList<String>();
fetchNames_ = new ArrayList<String>();
feedTypes_ = new HashMap<String, Integer>();
feedShapes_ = new HashMap<String, List<Integer>>();
fetchTypes_ = new HashMap<String, Integer>();
lodTensorSet_ = new HashSet<String>();
List<FeedVar> feed_var_list = model_conf.getFeedVarList();
for (FeedVar feed_var : feed_var_list) {
feedNames_.add(feed_var.getAliasName());
}
List<String> fetchNames_ ;
List<FetchVar> fetch_var_list = model_conf.getFetchVarList();
for (FetchVar fetch_var : fetch_var_list) {
fetchNames_.add(fetch_var.getAliasName());
}
feedTypes_ = new HashMap<String, Integer>();
feedShapes_ = new HashMap<String, List<Integer>>();
fetchTypes_ = new HashMap<String, Integer>();
fetchShapes_ = new HashMap<String, List<Integer>>();
lodTensorSet_ = new HashSet<String>();
for (int i = 0; i < feed_var_list.size(); ++i) {
FeedVar feed_var = feed_var_list[i];
FeedVar feed_var = feed_var_list.get(i);
String var_name = feed_var.getAliasName();
// feedTypes_[var_name] = feed_var.getFeedType();
feedTypes_.put(var_name, feed_var.getFeedType());
feedShapes_.put(var_name, feed_var.getShapeList());
if (feed_var.getIsLodTensor()) {
lodTensorSet_.add(var_name);
} else {
int counter = 1;
for (int dim : feedShapes_.get(var_name)) {
counter *= dim;
}
feedTensorLen_.put(var_name, counter);
// TODO: check shape
}
}
for (int i = 0; i < fetch_var_list.size(); i++) {
FetchVar fetch_var = fetch_var_list.get(i);
String var_name = fetch_var.getAliasName();
fetchTypes_.put(var_name, fetch_var.getFetchType());
if (fetch_var.getIsLodTensor()) {
lodTensorSet_.add(var_name);
}
}
}
/**
 * Placeholder for flattening a feed value before it is copied into a tensor.
 * Flattening is not implemented: the argument is handed back unchanged.
 *
 * @param x the feed values for one variable
 * @return the same list instance that was passed in
 */
private List<? extends Number> _flattenList(List<? extends Number> x) {
// TODO: implement real flattening; for now this is an identity function.
return x;
}
/**
 * Builds the inference request for a batch of feed maps.
 *
 * <p>The feed variable names are taken from the key set of the first map in
 * the batch; every map in the batch is then read with those same names. Each
 * value list is copied into a tensor according to the variable's configured
 * feed type (0 = int64, 1 = float32, 2 = int32) and tagged with the
 * configured feed shape.
 *
 * @param feed_batch one map of variable name to values per instance; must
 *     contain at least one map
 * @param fetch names of the variables to fetch
 * @return the fully built request
 * @throws IllegalArgumentException if the batch is empty, a feed variable is
 *     unknown, a map lacks data for a variable, or the configured type is not
 *     one of the supported codes
 */
private InferenceRequest _packInferenceRequest(
    List<Map<String, List<? extends Number>>> feed_batch,
    Iterable<String> fetch) throws IllegalArgumentException {
  if (feed_batch == null || feed_batch.isEmpty()) {
    throw new IllegalArgumentException("feed batch must not be empty.");
  }
  List<String> feed_var_names = new ArrayList<String>(feed_batch.get(0).keySet());
  InferenceRequest.Builder req_builder = InferenceRequest.newBuilder()
      .addAllFeedVarNames(feed_var_names)
      .addAllFetchVarNames(fetch)
      .setIsPython(false);
  for (Map<String, List<? extends Number>> feed_data : feed_batch) {
    FeedInst.Builder inst_builder = FeedInst.newBuilder();
    for (String name : feed_var_names) {
      // Look the type up as a boxed Integer so an unknown name is reported
      // explicitly instead of failing with an unboxing NullPointerException.
      Integer v_type = feedTypes_.get(name);
      if (v_type == null) {
        throw new IllegalArgumentException("unknown feed variable: " + name);
      }
      List<? extends Number> variable = feed_data.get(name);
      if (variable == null) {
        throw new IllegalArgumentException("missing feed data for variable: " + name);
      }
      List<? extends Number> flattened_list = _flattenList(variable);
      Tensor.Builder tensor_builder = Tensor.newBuilder();
      // Use the Number conversion methods rather than primitive casts: a cast
      // such as (long) x only succeeds when x is exactly a Long and throws
      // ClassCastException for e.g. an Integer.
      if (v_type == 0) { // int64
        for (Number x : flattened_list) {
          tensor_builder.addInt64Data(x.longValue());
        }
      } else if (v_type == 1) { // float32
        for (Number x : flattened_list) {
          tensor_builder.addFloatData(x.floatValue());
        }
      } else if (v_type == 2) { // int32
        for (Number x : flattened_list) {
          tensor_builder.addIntData(x.intValue());
        }
      } else {
        throw new IllegalArgumentException("error tensor value type.");
      }
      tensor_builder.addAllShape(feedShapes_.get(name));
      inst_builder.addTensorArray(tensor_builder.build());
    }
    req_builder.addInsts(inst_builder.build());
  }
  return req_builder.build();
}
/**
 * Instance-side entry point for decoding a response: forwards to the static
 * unpacker, supplying this client's fetch types and LoD tensor names.
 *
 * @param resp the response to decode
 * @param fetch names of the fetched variables
 * @param need_variant_tag forwarded unchanged to the static unpacker
 * @return whatever the static unpacker returns for these arguments
 * @throws IllegalArgumentException propagated from the static unpacker
 */
private Map<String, ? extends Map<String, List<? extends Number>>>
    _unpackInferenceResponse(
        InferenceResponse resp,
        Iterable<String> fetch,
        Boolean need_variant_tag) throws IllegalArgumentException {
  final Map<String, Integer> types = fetchTypes_;
  final Set<String> lodNames = lodTensorSet_;
  return _staticUnpackInferenceResponse(resp, fetch, types, lodNames, need_variant_tag);
}
/**
 * Decodes an inference response into one result map per model output.
 *
 * <p>For each model output, the first instance is read and its tensors are
 * matched to {@code fetch} by position. The value list is chosen by the
 * variable's fetch type (0 = int64, 1 = float32, 2 = int32). For variables in
 * {@code lodTensorSet} the LoD list is stored as well, under
 * {@code name + ".lod"}.
 *
 * @param resp the response to decode
 * @param fetch names of the fetched variables, in tensor order
 * @param fetchTypes fetch type code per variable name
 * @param lodTensorSet names of variables that carry LoD information
 * @param need_variant_tag currently unused by this method
 * @return the per-output result maps, or {@code null} when the response
 *     carries a non-zero error code
 * @throws IllegalArgumentException if a fetch name has no known type or its
 *     type code is not supported
 */
private static Map<String, ? extends Map<String, List<? extends Number>>>
    _staticUnpackInferenceResponse(
        InferenceResponse resp,
        Iterable<String> fetch,
        Map<String, Integer> fetchTypes,
        Set<String> lodTensorSet,
        Boolean need_variant_tag) throws IllegalArgumentException {
  if (resp.getErrCode() != 0) {
    return null;
  }
  // Declared with a concrete value type so entries can actually be inserted;
  // a "? extends" value type would reject every put().
  Map<String, Map<String, List<? extends Number>>> multi_result_map
      = new HashMap<String, Map<String, List<? extends Number>>>();
  for (ModelOutput model_result : resp.getOutputsList()) {
    FetchInst inst = model_result.getInsts(0);
    Map<String, List<? extends Number>> result_map
        = new HashMap<String, List<? extends Number>>();
    int index = 0;
    for (String name : fetch) {
      Tensor variable = inst.getTensorArray(index);
      // Boxed lookup: report an unknown name instead of an unboxing NPE.
      Integer v_type = fetchTypes.get(name);
      if (v_type == null) {
        throw new IllegalArgumentException("unknown fetch variable: " + name);
      }
      if (v_type == 0) { // int64
        result_map.put(name, variable.getInt64DataList());
      } else if (v_type == 1) { // float32
        result_map.put(name, variable.getFloatDataList());
      } else if (v_type == 2) { // int32
        result_map.put(name, variable.getIntDataList());
      } else {
        throw new IllegalArgumentException("error tensor value type.");
      }
      // TODO: shape
      if (lodTensorSet.contains(name)) {
        result_map.put(name + ".lod", variable.getLodList());
      }
      index += 1;
    }
    // Store the decoded output; without this the returned map is always empty.
    // NOTE(review): assumes ModelOutput exposes an engine_name field
    // (getEngineName) as the key for each output - confirm against the
    // service .proto definition.
    multi_result_map.put(model_result.getEngineName(), result_map);
  }
  return multi_result_map;
}
/**
 * Runs a blocking inference call and decodes the reply.
 *
 * @param feed one map of variable name to values per instance
 * @param fetch names of the variables to fetch
 * @param need_variant_tag forwarded to the response unpacker
 * @return the decoded response, or an empty map when the gRPC call fails with
 *     a {@code StatusRuntimeException}
 */
public Map<String, ? extends Map<String, List<? extends Number>>> predict(
    List<Map<String, List<? extends Number>>> feed,
    Iterable<String> fetch,
    Boolean need_variant_tag) {
  // Request packing happens outside the try block, so its exceptions propagate.
  final InferenceRequest request = _packInferenceRequest(feed, fetch);
  Map<String, ? extends Map<String, List<? extends Number>>> outcome;
  try {
    outcome = _unpackInferenceResponse(
        blockingStub_.inference(request), fetch, need_variant_tag);
  } catch (StatusRuntimeException rpcError) {
    System.out.format("grpc failed: %s", rpcError.toString());
    outcome = Collections.emptyMap();
  }
  return outcome;
}
/**
 * Starts an inference call through the future stub and returns a handle that
 * decodes the reply on demand.
 *
 * @param feed one map of variable name to values per instance
 * @param fetch names of the variables to fetch
 * @param need_variant_tag forwarded to the response unpacker
 * @return a {@code PredictFuture} wrapping the pending call and the decoder
 */
public PredictFuture async_predict(
    List<Map<String, List<? extends Number>>> feed,
    Iterable<String> fetch,
    Boolean need_variant_tag) {
  final InferenceRequest request = _packInferenceRequest(feed, fetch);
  final ListenableFuture<InferenceResponse> pending = futureStub_.inference(request);
  // The decoder reads fetchTypes_ and lodTensorSet_ when it is invoked,
  // not when it is created here.
  final Function<InferenceResponse,
      Map<String, ? extends Map<String, List<? extends Number>>>> decoder =
      resp -> Client._staticUnpackInferenceResponse(
          resp, fetch, fetchTypes_, lodTensorSet_, need_variant_tag);
  return new PredictFuture(pending, decoder);
}
/**
 * Stub entry point: prints a greeting and does not touch the client.
 *
 * @param args command-line arguments (ignored)
 */
public static void main( String[] args ) {
System.out.println( "Hello World!" );
}
}
/**
 * Pairs a pending inference call with the function that decodes its response.
 */
class PredictFuture {
  /** The in-flight call whose result is awaited by {@link #get()}. */
  private final ListenableFuture<InferenceResponse> callFuture_;
  /** Decoder applied to the response once the call has completed. */
  private final Function<InferenceResponse,
      Map<String, ? extends Map<String, List<? extends Number>>>> callBackFunc_;

  /**
   * @param call_future the pending call
   * @param call_back_func decoder to apply to the call's response
   */
  PredictFuture(ListenableFuture<InferenceResponse> call_future,
      Function<InferenceResponse,
          Map<String, ? extends Map<String, List<? extends Number>>>> call_back_func) {
    callFuture_ = call_future;
    callBackFunc_ = call_back_func;
  }

  /**
   * Waits for the call to finish and decodes its response.
   *
   * @return the decoder's result, or {@code null} when waiting for the call
   *     failed (the failure is printed to standard output)
   * @throws Exception kept in the signature for existing callers; exceptions
   *     thrown by the decoder itself propagate
   */
  public Map<String, ? extends Map<String, List<? extends Number>>> get() throws Exception {
    InferenceResponse resp = null;
    try {
      resp = callFuture_.get();
    } catch (InterruptedException e) {
      // Restore the interrupt flag so the caller's thread can still observe
      // the interruption; swallowing it would silently lose the signal.
      Thread.currentThread().interrupt();
      System.out.format("grpc failed: %s", e.toString());
      return null;
    } catch (Exception e) {
      System.out.format("grpc failed: %s", e.toString());
      return null;
    }
    return callBackFunc_.apply(resp);
  }
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册