提交 64553f83 编写于 作者: B barrierye

support model ensemble

上级 573c1a23
...@@ -58,7 +58,7 @@ public class Client { ...@@ -58,7 +58,7 @@ public class Client {
try { try {
resp = blockingStub_.setTimeout(timeout_req); resp = blockingStub_.setTimeout(timeout_req);
} catch (StatusRuntimeException e) { } catch (StatusRuntimeException e) {
System.out.format("Set RPC timeout failed: %s", e.toString()); System.out.format("Set RPC timeout failed: %s\n", e.toString());
return false; return false;
} }
return resp.getErrCode() == 0; return resp.getErrCode() == 0;
...@@ -77,7 +77,7 @@ public class Client { ...@@ -77,7 +77,7 @@ public class Client {
blockingStub_ = MultiLangGeneralModelServiceGrpc.newBlockingStub(channel_); blockingStub_ = MultiLangGeneralModelServiceGrpc.newBlockingStub(channel_);
futureStub_ = MultiLangGeneralModelServiceGrpc.newFutureStub(channel_); futureStub_ = MultiLangGeneralModelServiceGrpc.newFutureStub(channel_);
} catch (Exception e) { } catch (Exception e) {
System.out.format("Connect failed: %s", e.toString()); System.out.format("Connect failed: %s\n", e.toString());
return false; return false;
} }
GetClientConfigRequest get_client_config_req = GetClientConfigRequest.newBuilder().build(); GetClientConfigRequest get_client_config_req = GetClientConfigRequest.newBuilder().build();
...@@ -85,7 +85,7 @@ public class Client { ...@@ -85,7 +85,7 @@ public class Client {
try { try {
resp = blockingStub_.getClientConfig(get_client_config_req); resp = blockingStub_.getClientConfig(get_client_config_req);
} catch (StatusRuntimeException e) { } catch (StatusRuntimeException e) {
System.out.format("Get Client config failed: %s", e.toString()); System.out.format("Get Client config failed: %s\n", e.toString());
return false; return false;
} }
String model_config_str = resp.getClientConfigStr(); String model_config_str = resp.getClientConfigStr();
...@@ -98,7 +98,7 @@ public class Client { ...@@ -98,7 +98,7 @@ public class Client {
try { try {
com.google.protobuf.TextFormat.getParser().merge(model_config_str, model_conf_builder); com.google.protobuf.TextFormat.getParser().merge(model_config_str, model_conf_builder);
} catch (com.google.protobuf.TextFormat.ParseException e) { } catch (com.google.protobuf.TextFormat.ParseException e) {
System.out.format("Parse client config failed: %s", e.toString()); System.out.format("Parse client config failed: %s\n", e.toString());
} }
GeneralModelConfig model_conf = model_conf_builder.build(); GeneralModelConfig model_conf = model_conf_builder.build();
...@@ -164,7 +164,9 @@ public class Client { ...@@ -164,7 +164,9 @@ public class Client {
long[] flattened_shape = {-1}; long[] flattened_shape = {-1};
INDArray flattened_list = variable.reshape(flattened_shape); INDArray flattened_list = variable.reshape(flattened_shape);
int v_type = feedTypes_.get(name); int v_type = feedTypes_.get(name);
System.out.println(flattened_list);
NdIndexIterator iter = new NdIndexIterator(flattened_list.shape()); NdIndexIterator iter = new NdIndexIterator(flattened_list.shape());
//System.out.format("name: %s, type: %d\n", name, v_type);
if (v_type == 0) { // int64 if (v_type == 0) { // int64
while (iter.hasNext()) { while (iter.hasNext()) {
long[] next_index = iter.next(); long[] next_index = iter.next();
...@@ -231,39 +233,50 @@ public class Client { ...@@ -231,39 +233,50 @@ public class Client {
for (String name: fetch) { for (String name: fetch) {
Tensor variable = inst.getTensorArray(index); Tensor variable = inst.getTensorArray(index);
int v_type = fetchTypes.get(name); int v_type = fetchTypes.get(name);
INDArray data = null;
if (v_type == 0) { // int64 if (v_type == 0) { // int64
List<Long> list = variable.getInt64DataList(); List<Long> list = variable.getInt64DataList();
long[] array = new long[list.size()]; long[] array = new long[list.size()];
for (int i = 0; i < list.size(); i++) { for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i); array[i] = list.get(i);
} }
result_map.put(name, Nd4j.create(array)); data = Nd4j.createFromArray(array);
} else if (v_type == 1) { // float32 } else if (v_type == 1) { // float32
List<Float> list = variable.getFloatDataList(); List<Float> list = variable.getFloatDataList();
float[] array = new float[list.size()]; float[] array = new float[list.size()];
for (int i = 0; i < list.size(); i++) { for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i); array[i] = list.get(i);
} }
result_map.put(name, Nd4j.create(array)); data = Nd4j.createFromArray(array);
} else if (v_type == 2) { // int32 } else if (v_type == 2) { // int32
List<Integer> list = variable.getIntDataList(); List<Integer> list = variable.getIntDataList();
int[] array = new int[list.size()]; int[] array = new int[list.size()];
for (int i = 0; i < list.size(); i++) { for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i); array[i] = list.get(i);
} }
result_map.put(name, Nd4j.create(array)); data = Nd4j.createFromArray(array);
} else { } else {
throw new IllegalArgumentException("error tensor value type."); throw new IllegalArgumentException("error tensor value type.");
} }
// TODO: shape // shape
List<Integer> shape_lsit = variable.getShapeList();
int[] shape_array = new int[shape_lsit.size()];
for (int i = 0; i < shape_lsit.size(); ++i) {
shape_array[i] = shape_lsit.get(i);
}
data = data.reshape(shape_array);
// put data to result_map
result_map.put(name, data);
// lod
if (lodTensorSet.contains(name)) { if (lodTensorSet.contains(name)) {
List<Integer> list = variable.getLodList(); List<Integer> list = variable.getLodList();
int[] array = new int[list.size()]; int[] array = new int[list.size()];
for (int i = 0; i < list.size(); i++) { for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i); array[i] = list.get(i);
} }
result_map.put(name + ".lod", Nd4j.create(array)); result_map.put(name + ".lod", Nd4j.createFromArray(array));
} }
index += 1; index += 1;
} }
...@@ -353,12 +366,12 @@ public class Client { ...@@ -353,12 +366,12 @@ public class Client {
= new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>( = new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>(
ensemble_result.entrySet()); ensemble_result.entrySet());
if (list.size() != 1) { if (list.size() != 1) {
System.out.format("grpc failed: please use ensemble_predict impl."); System.out.format("grpc failed: please use ensemble_predict impl.\n");
return null; return null;
} }
return list.get(0).getValue(); return list.get(0).getValue();
} catch (StatusRuntimeException e) { } catch (StatusRuntimeException e) {
System.out.format("grpc failed: %s", e.toString()); System.out.format("grpc failed: %s\n", e.toString());
return null; return null;
} }
} }
...@@ -373,7 +386,7 @@ public class Client { ...@@ -373,7 +386,7 @@ public class Client {
return _unpackInferenceResponse( return _unpackInferenceResponse(
resp, fetch, need_variant_tag); resp, fetch, need_variant_tag);
} catch (StatusRuntimeException e) { } catch (StatusRuntimeException e) {
System.out.format("grpc failed: %s", e.toString()); System.out.format("grpc failed: %s\n", e.toString());
return null; return null;
} }
} }
...@@ -394,30 +407,44 @@ public class Client { ...@@ -394,30 +407,44 @@ public class Client {
} }
public static void main( String[] args ) { public static void main( String[] args ) {
/*
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f, 0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.create(data); INDArray npdata = Nd4j.createFromArray(data);
Client client = new Client();
List<String> endpoints = Arrays.asList("localhost:9393");
boolean succ = client.connect(endpoints);
if (succ != true) {
System.out.println("connect failed.");
return;
}
HashMap<String, INDArray> feed_data HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{ = new HashMap<String, INDArray>() {{
put("x", npdata); put("x", npdata);
}}; }};
List<String> fetch = Arrays.asList("price"); List<String> fetch = Arrays.asList("price");
*/
/*
Map<String, INDArray> fetch_map = client.predict(feed_data, fetch); Map<String, INDArray> fetch_map = client.predict(feed_data, fetch);
for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) { for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
} }
/* */
//int[] data = {8, 233, 52, 601};
long[] data = {8, 233, 52, 601};
INDArray npdata = Nd4j.createFromArray(data);
//System.out.println(npdata);
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("words", npdata);
}};
List<String> fetch = Arrays.asList("prediction");
Client client = new Client();
List<String> endpoints = Arrays.asList("localhost:9393");
boolean succ = client.connect(endpoints);
if (succ != true) {
System.out.println("connect failed.");
return;
}
Map<String, HashMap<String, INDArray>> fetch_map Map<String, HashMap<String, INDArray>> fetch_map
= client.predict(feed_data, fetch); = client.ensemble_predict(feed_data, fetch);
for (Map.Entry<String, HashMap<String, INDArray>> entry : fetch_map.entrySet()) { for (Map.Entry<String, HashMap<String, INDArray>> entry : fetch_map.entrySet()) {
System.out.println("Model = " + entry.getKey()); System.out.println("Model = " + entry.getKey());
HashMap<String, INDArray> tt = entry.getValue(); HashMap<String, INDArray> tt = entry.getValue();
...@@ -425,7 +452,7 @@ public class Client { ...@@ -425,7 +452,7 @@ public class Client {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue()); System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
} }
} }
*/
} }
} }
...@@ -446,7 +473,7 @@ class PredictFuture { ...@@ -446,7 +473,7 @@ class PredictFuture {
try { try {
resp = callFuture_.get(); resp = callFuture_.get();
} catch (Exception e) { } catch (Exception e) {
System.out.format("grpc failed: %s", e.toString()); System.out.format("grpc failed: %s\n", e.toString());
return null; return null;
} }
Map<String, HashMap<String, INDArray>> ensemble_result Map<String, HashMap<String, INDArray>> ensemble_result
...@@ -455,18 +482,18 @@ class PredictFuture { ...@@ -455,18 +482,18 @@ class PredictFuture {
= new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>( = new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>(
ensemble_result.entrySet()); ensemble_result.entrySet());
if (list.size() != 1) { if (list.size() != 1) {
System.out.format("grpc failed: please use get_ensemble impl."); System.out.format("grpc failed: please use get_ensemble impl.\n");
return null; return null;
} }
return list.get(0).getValue(); return list.get(0).getValue();
} }
public Map<String, HashMap<String, INDArray>> get_ensemble() throws Exception { public Map<String, HashMap<String, INDArray>> ensemble_get() throws Exception {
InferenceResponse resp = null; InferenceResponse resp = null;
try { try {
resp = callFuture_.get(); resp = callFuture_.get();
} catch (Exception e) { } catch (Exception e) {
System.out.format("grpc failed: %s", e.toString()); System.out.format("grpc failed: %s\n", e.toString());
return null; return null;
} }
return callBackFunc_.apply(resp); return callBackFunc_.apply(resp);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册