提交 3461602f 编写于 作者: B barrierye

support model ensemble

上级 27e45738
......@@ -58,7 +58,7 @@ public class Client {
try {
resp = blockingStub_.setTimeout(timeout_req);
} catch (StatusRuntimeException e) {
System.out.format("Set RPC timeout failed: %s", e.toString());
System.out.format("Set RPC timeout failed: %s\n", e.toString());
return false;
}
return resp.getErrCode() == 0;
......@@ -77,7 +77,7 @@ public class Client {
blockingStub_ = MultiLangGeneralModelServiceGrpc.newBlockingStub(channel_);
futureStub_ = MultiLangGeneralModelServiceGrpc.newFutureStub(channel_);
} catch (Exception e) {
System.out.format("Connect failed: %s", e.toString());
System.out.format("Connect failed: %s\n", e.toString());
return false;
}
GetClientConfigRequest get_client_config_req = GetClientConfigRequest.newBuilder().build();
......@@ -85,7 +85,7 @@ public class Client {
try {
resp = blockingStub_.getClientConfig(get_client_config_req);
} catch (StatusRuntimeException e) {
System.out.format("Get Client config failed: %s", e.toString());
System.out.format("Get Client config failed: %s\n", e.toString());
return false;
}
String model_config_str = resp.getClientConfigStr();
......@@ -98,7 +98,7 @@ public class Client {
try {
com.google.protobuf.TextFormat.getParser().merge(model_config_str, model_conf_builder);
} catch (com.google.protobuf.TextFormat.ParseException e) {
System.out.format("Parse client config failed: %s", e.toString());
System.out.format("Parse client config failed: %s\n", e.toString());
}
GeneralModelConfig model_conf = model_conf_builder.build();
......@@ -164,7 +164,9 @@ public class Client {
long[] flattened_shape = {-1};
INDArray flattened_list = variable.reshape(flattened_shape);
int v_type = feedTypes_.get(name);
System.out.println(flattened_list);
NdIndexIterator iter = new NdIndexIterator(flattened_list.shape());
//System.out.format("name: %s, type: %d\n", name, v_type);
if (v_type == 0) { // int64
while (iter.hasNext()) {
long[] next_index = iter.next();
......@@ -231,39 +233,50 @@ public class Client {
for (String name: fetch) {
Tensor variable = inst.getTensorArray(index);
int v_type = fetchTypes.get(name);
INDArray data = null;
if (v_type == 0) { // int64
List<Long> list = variable.getInt64DataList();
long[] array = new long[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i);
}
result_map.put(name, Nd4j.create(array));
data = Nd4j.createFromArray(array);
} else if (v_type == 1) { // float32
List<Float> list = variable.getFloatDataList();
float[] array = new float[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i);
}
result_map.put(name, Nd4j.create(array));
data = Nd4j.createFromArray(array);
} else if (v_type == 2) { // int32
List<Integer> list = variable.getIntDataList();
int[] array = new int[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i);
}
result_map.put(name, Nd4j.create(array));
data = Nd4j.createFromArray(array);
} else {
throw new IllegalArgumentException("error tensor value type.");
}
// TODO: shape
// shape
List<Integer> shape_lsit = variable.getShapeList();
int[] shape_array = new int[shape_lsit.size()];
for (int i = 0; i < shape_lsit.size(); ++i) {
shape_array[i] = shape_lsit.get(i);
}
data = data.reshape(shape_array);
// put data to result_map
result_map.put(name, data);
// lod
if (lodTensorSet.contains(name)) {
List<Integer> list = variable.getLodList();
int[] array = new int[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i);
}
result_map.put(name + ".lod", Nd4j.create(array));
result_map.put(name + ".lod", Nd4j.createFromArray(array));
}
index += 1;
}
......@@ -353,12 +366,12 @@ public class Client {
= new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>(
ensemble_result.entrySet());
if (list.size() != 1) {
System.out.format("grpc failed: please use ensemble_predict impl.");
System.out.format("grpc failed: please use ensemble_predict impl.\n");
return null;
}
return list.get(0).getValue();
} catch (StatusRuntimeException e) {
System.out.format("grpc failed: %s", e.toString());
System.out.format("grpc failed: %s\n", e.toString());
return null;
}
}
......@@ -373,7 +386,7 @@ public class Client {
return _unpackInferenceResponse(
resp, fetch, need_variant_tag);
} catch (StatusRuntimeException e) {
System.out.format("grpc failed: %s", e.toString());
System.out.format("grpc failed: %s\n", e.toString());
return null;
}
}
......@@ -394,30 +407,44 @@ public class Client {
}
public static void main( String[] args ) {
/*
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.create(data);
Client client = new Client();
List<String> endpoints = Arrays.asList("localhost:9393");
boolean succ = client.connect(endpoints);
if (succ != true) {
System.out.println("connect failed.");
return;
}
INDArray npdata = Nd4j.createFromArray(data);
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("x", npdata);
}};
List<String> fetch = Arrays.asList("price");
*/
/*
Map<String, INDArray> fetch_map = client.predict(feed_data, fetch);
for (Map.Entry<String, INDArray> e : fetch_map.entrySet()) {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
/*
*/
//int[] data = {8, 233, 52, 601};
long[] data = {8, 233, 52, 601};
INDArray npdata = Nd4j.createFromArray(data);
//System.out.println(npdata);
HashMap<String, INDArray> feed_data
= new HashMap<String, INDArray>() {{
put("words", npdata);
}};
List<String> fetch = Arrays.asList("prediction");
Client client = new Client();
List<String> endpoints = Arrays.asList("localhost:9393");
boolean succ = client.connect(endpoints);
if (succ != true) {
System.out.println("connect failed.");
return;
}
Map<String, HashMap<String, INDArray>> fetch_map
= client.predict(feed_data, fetch);
= client.ensemble_predict(feed_data, fetch);
for (Map.Entry<String, HashMap<String, INDArray>> entry : fetch_map.entrySet()) {
System.out.println("Model = " + entry.getKey());
HashMap<String, INDArray> tt = entry.getValue();
......@@ -425,7 +452,7 @@ public class Client {
System.out.println("Key = " + e.getKey() + ", Value = " + e.getValue());
}
}
*/
}
}
......@@ -446,7 +473,7 @@ class PredictFuture {
try {
resp = callFuture_.get();
} catch (Exception e) {
System.out.format("grpc failed: %s", e.toString());
System.out.format("grpc failed: %s\n", e.toString());
return null;
}
Map<String, HashMap<String, INDArray>> ensemble_result
......@@ -455,18 +482,18 @@ class PredictFuture {
= new ArrayList<Map.Entry<String, HashMap<String, INDArray>>>(
ensemble_result.entrySet());
if (list.size() != 1) {
System.out.format("grpc failed: please use get_ensemble impl.");
System.out.format("grpc failed: please use get_ensemble impl.\n");
return null;
}
return list.get(0).getValue();
}
public Map<String, HashMap<String, INDArray>> get_ensemble() throws Exception {
public Map<String, HashMap<String, INDArray>> ensemble_get() throws Exception {
InferenceResponse resp = null;
try {
resp = callFuture_.get();
} catch (Exception e) {
System.out.format("grpc failed: %s", e.toString());
System.out.format("grpc failed: %s\n", e.toString());
return null;
}
return callBackFunc_.apply(resp);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册