提交 00dd5546 编写于 作者: H HexToString

update java

上级 324f4196
...@@ -11,7 +11,7 @@ import org.nd4j.linalg.factory.Nd4j; ...@@ -11,7 +11,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.*; import java.util.*;
public class PaddleServingClientExample { public class PaddleServingClientExample {
boolean fit_a_line() { boolean fit_a_line(String model_config_path) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f, 0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
...@@ -25,15 +25,69 @@ public class PaddleServingClientExample { ...@@ -25,15 +25,69 @@ public class PaddleServingClientExample {
List<String> fetch = Arrays.asList("price"); List<String> fetch = Arrays.asList("price");
HttpClient client = new HttpClient(); HttpClient client = new HttpClient();
client.setIP("172.17.0.2"); client.setIP("0.0.0.0");
client.setPort("9393"); client.setPort("9393");
client.loadClientConfig(model_config_path);
String result = client.predict(feed_data, fetch, true, 0); String result = client.predict(feed_data, fetch, true, 0);
System.out.println(result); System.out.println(result);
return true; return true;
} }
boolean yolov4(String filename) { boolean encrypt(String model_config_path,String keyFilePath) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
long[] batch_shape = {1,13};
INDArray batch_npdata = npdata.reshape(batch_shape);
HashMap<String, Object> feed_data
= new HashMap<String, Object>() {{
put("x", batch_npdata);
}};
List<String> fetch = Arrays.asList("price");
HttpClient client = new HttpClient();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
client.use_key(keyFilePath);
try {
Thread.sleep(1000*3); // 休眠3秒,等待Server启动
} catch (Exception e) {
//TODO: handle exception
}
String result = client.predict(feed_data, fetch, true, 0);
System.out.println(result);
return true;
}
boolean compress(String model_config_path) {
float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
0.0582f, -0.0727f, -0.1583f, -0.0584f,
0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
INDArray npdata = Nd4j.createFromArray(data);
long[] batch_shape = {500,13};
INDArray batch_npdata = npdata.broadcast(batch_shape);
HashMap<String, Object> feed_data
= new HashMap<String, Object>() {{
put("x", batch_npdata);
}};
List<String> fetch = Arrays.asList("price");
HttpClient client = new HttpClient();
client.setIP("0.0.0.0");
client.setPort("9393");
client.loadClientConfig(model_config_path);
client.set_request_compress(true);
client.set_response_compress(true);
String result = client.predict(feed_data, fetch, true, 0);
System.out.println(result);
return true;
}
boolean yolov4(String model_config_path,String filename) {
// https://deeplearning4j.konduit.ai/ // https://deeplearning4j.konduit.ai/
int height = 608; int height = 608;
int width = 608; int width = 608;
...@@ -74,14 +128,15 @@ public class PaddleServingClientExample { ...@@ -74,14 +128,15 @@ public class PaddleServingClientExample {
}}; }};
List<String> fetch = Arrays.asList("save_infer_model/scale_0.tmp_0"); List<String> fetch = Arrays.asList("save_infer_model/scale_0.tmp_0");
HttpClient client = new HttpClient(); HttpClient client = new HttpClient();
client.setIP("172.17.0.2"); client.setIP("0.0.0.0");
client.setPort("9393"); client.setPort("9393");
client.loadClientConfig(model_config_path);
String result = client.predict(feed_data, fetch, true, 0); String result = client.predict(feed_data, fetch, true, 0);
System.out.println(result); System.out.println(result);
return true; return true;
} }
boolean bert() { boolean bert(String model_config_path) {
float[] input_mask = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; float[] input_mask = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
long[] position_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; long[] position_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
long[] input_ids = {101, 6843, 3241, 749, 8024, 7662, 2533, 1391, 2533, 2523, 7676, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; long[] input_ids = {101, 6843, 3241, 749, 8024, 7662, 2533, 1391, 2533, 2523, 7676, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
...@@ -95,14 +150,15 @@ public class PaddleServingClientExample { ...@@ -95,14 +150,15 @@ public class PaddleServingClientExample {
}}; }};
List<String> fetch = Arrays.asList("pooled_output"); List<String> fetch = Arrays.asList("pooled_output");
HttpClient client = new HttpClient(); HttpClient client = new HttpClient();
client.setIP("172.17.0.2"); client.setIP("0.0.0.0");
client.setPort("9393"); client.setPort("9393");
client.loadClientConfig(model_config_path);
String result = client.predict(feed_data, fetch, true, 0); String result = client.predict(feed_data, fetch, true, 0);
System.out.println(result); System.out.println(result);
return true; return true;
} }
boolean cube_local() { boolean cube_local(String model_config_path) {
long[] embedding_14 = {250644}; long[] embedding_14 = {250644};
long[] embedding_2 = {890346}; long[] embedding_2 = {890346};
long[] embedding_10 = {3939}; long[] embedding_10 = {3939};
...@@ -164,8 +220,9 @@ public class PaddleServingClientExample { ...@@ -164,8 +220,9 @@ public class PaddleServingClientExample {
}}; }};
List<String> fetch = Arrays.asList("prob"); List<String> fetch = Arrays.asList("prob");
HttpClient client = new HttpClient(); HttpClient client = new HttpClient();
client.setIP("172.17.0.2"); client.setIP("0.0.0.0");
client.setPort("9393"); client.setPort("9393");
client.loadClientConfig(model_config_path);
String result = client.predict(feed_data, fetch, true, 0); String result = client.predict(feed_data, fetch, true, 0);
System.out.println(result); System.out.println(result);
return true; return true;
...@@ -177,25 +234,33 @@ public class PaddleServingClientExample { ...@@ -177,25 +234,33 @@ public class PaddleServingClientExample {
PaddleServingClientExample e = new PaddleServingClientExample(); PaddleServingClientExample e = new PaddleServingClientExample();
boolean succ = false; boolean succ = false;
if (args.length < 1) { if (args.length < 2) {
System.out.println("Usage: java -cp <jar> PaddleServingClientExample <test-type>."); System.out.println("Usage: java -cp <jar> PaddleServingClientExample <test-type> <configPath>.");
System.out.println("<test-type>: fit_a_line bert cube_local yolov4"); System.out.println("<test-type>: fit_a_line bert cube_local yolov4 encrypt");
return; return;
} }
String testType = args[0]; String testType = args[0];
System.out.format("[Example] %s\n", testType); System.out.format("[Example] %s\n", testType);
if ("fit_a_line".equals(testType)) { if ("fit_a_line".equals(testType)) {
succ = e.fit_a_line(); succ = e.fit_a_line(args[1]);
} else if ("compress".equals(testType)) {
succ = e.compress(args[1]);
} else if ("bert".equals(testType)) { } else if ("bert".equals(testType)) {
succ = e.bert(); succ = e.bert(args[1]);
} else if ("cube_local".equals(testType)) { } else if ("cube_local".equals(testType)) {
succ = e.cube_local(); succ = e.cube_local(args[1]);
} else if ("yolov4".equals(testType)) { } else if ("yolov4".equals(testType)) {
if (args.length < 2) { if (args.length < 3) {
System.out.println("Usage: java -cp <jar> PaddleServingClientExample yolov4 <image-filepath>."); System.out.println("Usage: java -cp <jar> PaddleServingClientExample yolov4 <configPath> <image-filepath>.");
return;
}
succ = e.yolov4(args[1],args[2]);
} else if ("encrypt".equals(testType)) {
if (args.length < 3) {
System.out.println("Usage: java -cp <jar> PaddleServingClientExample encrypt <configPath> <keyPath>.");
return; return;
} }
succ = e.yolov4(args[1]); succ = e.encrypt(args[1],args[2]);
} else { } else {
System.out.format("test-type(%s) not match.\n", testType); System.out.format("test-type(%s) not match.\n", testType);
return; return;
......
...@@ -28,6 +28,7 @@ import org.apache.http.impl.client.CloseableHttpClient; ...@@ -28,6 +28,7 @@ import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients; import org.apache.http.impl.client.HttpClients;
import org.apache.http.message.BasicNameValuePair; import org.apache.http.message.BasicNameValuePair;
import org.apache.http.util.EntityUtils; import org.apache.http.util.EntityUtils;
import org.apache.http.entity.InputStreamEntity;
import org.json.*; import org.json.*;
...@@ -97,6 +98,7 @@ public class HttpClient { ...@@ -97,6 +98,7 @@ public class HttpClient {
private String serviceName; private String serviceName;
private boolean request_compress_flag; private boolean request_compress_flag;
private boolean response_compress_flag; private boolean response_compress_flag;
private String GLOG_v;
public HttpClient() { public HttpClient() {
feedNames_ = null; feedNames_ = null;
...@@ -115,6 +117,7 @@ public class HttpClient { ...@@ -115,6 +117,7 @@ public class HttpClient {
serviceName = "/GeneralModelService/inference"; serviceName = "/GeneralModelService/inference";
request_compress_flag = false; request_compress_flag = false;
response_compress_flag = false; response_compress_flag = false;
GLOG_v = System.getenv("GLOG_v");
feedTypeToDataKey_ = new HashMap<Integer, String>(); feedTypeToDataKey_ = new HashMap<Integer, String>();
feedTypeToDataKey_.put(0, "int64_data"); feedTypeToDataKey_.put(0, "int64_data");
...@@ -206,7 +209,7 @@ public class HttpClient { ...@@ -206,7 +209,7 @@ public class HttpClient {
String encrypt_url = "http://" + this.ip + ":" +this.port; String encrypt_url = "http://" + this.ip + ":" +this.port;
try { try {
byte[] data = Files.readAllBytes(Paths.get(keyFilePath)); byte[] data = Files.readAllBytes(Paths.get(keyFilePath));
key_str = new String(data, "utf-8"); key_str = Base64.getEncoder().encodeToString(data);
} catch (Exception e) { } catch (Exception e) {
System.out.format("Open key file failed: %s\n", e.toString()); System.out.format("Open key file failed: %s\n", e.toString());
} }
...@@ -237,16 +240,20 @@ public class HttpClient { ...@@ -237,16 +240,20 @@ public class HttpClient {
this.response_compress_flag = response_compress_flag; this.response_compress_flag = response_compress_flag;
} }
public static String compress(String str,String inEncoding) throws IOException { public byte[] compress(String str) {
if (str == null || str.length() == 0) { if (str == null || str.length() == 0) {
return str; return null;
} }
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
GZIPOutputStream gzip = new GZIPOutputStream(out); GZIPOutputStream gzip;
gzip.write(str.getBytes(inEncoding)); try {
gzip.close(); gzip = new GZIPOutputStream(out);
return out.toString("ISO-8859-1"); gzip.write(str.getBytes("UTF-8"));
gzip.close();
} catch (Exception e) {
e.printStackTrace();
}
return out.toByteArray();
} }
// 帮助用户封装Http请求的接口,用户只需要传递FeedData,Lod,Fetchlist即可。 // 帮助用户封装Http请求的接口,用户只需要传递FeedData,Lod,Fetchlist即可。
...@@ -302,36 +309,83 @@ public class HttpClient { ...@@ -302,36 +309,83 @@ public class HttpClient {
Object objectValue = mapEntry.getValue(); Object objectValue = mapEntry.getValue();
String feed_alias_name = mapEntry.getKey(); String feed_alias_name = mapEntry.getKey();
String feed_real_name = feedRealNames_.get(feed_alias_name); String feed_real_name = feedRealNames_.get(feed_alias_name);
List<Integer> shape = feedShapes_.get(feed_alias_name); List<Integer> shape = new ArrayList<Integer>(feedShapes_.get(feed_alias_name));
int element_type = feedTypes_.get(feed_alias_name); int element_type = feedTypes_.get(feed_alias_name);
jsonTensor.put("alias_name", feed_alias_name);
jsonTensor.put("name", feed_real_name);
jsonTensor.put("elem_type", element_type);
// 处理数据与shape
String protoDataKey = feedTypeToDataKey_.get(element_type); String protoDataKey = feedTypeToDataKey_.get(element_type);
Object feedLodValue = feedLod.get(feed_alias_name); // 如果是INDArray类型,先转为一维.
// 如果是INDArray类型,先转为一维,再objectValue.ToString. // 此时shape为INDArray的shape
// 如果是String或List,则直接objectValue.ToString. if(objectValue instanceof INDArray){
if(objectValue.getClass().equals(INDArray.class)){ INDArray tempIndArray = (INDArray)objectValue;
long[] flattened_shape = {-1}; long[] indarrayShape = tempIndArray.shape();
Class<?> classLongArray = flattened_shape.getClass();
Method methodReshape = mapEntry.getValue().getClass().getMethod("reshape", classLongArray);
Method methodShape = mapEntry.getValue().getClass().getMethod("shape");
long[] indarrayShape = (long[])methodShape.invoke(objectValue);
shape.clear(); shape.clear();
for(long dim:indarrayShape){ for(long dim:indarrayShape){
shape.add((int)dim); shape.add((int)dim);
} }
objectValue = methodReshape.invoke(objectValue,flattened_shape); objectValue = tempIndArray.data().asDouble();
}else if(objectValue.getClass().isArray()){
// 如果是数组类型,则无须处理,直接使用即可。
// 且数组无法嵌套,此时batch无法从数据中获取
// 默认batch维度为1,或者feedVar的shape信息中已包含batch
}else if(objectValue instanceof List){
// 如果为list,可能存在嵌套,此时需要展平
// 如果batchFlag为True,则认为是嵌套list
// 此时取最外层为batch的维度
if (batchFlag) {
List<?> list = new ArrayList<>();
list = new ArrayList<>((Collection<?>)objectValue);
// 在index=0处,加上batch
shape.add(0, list.size());
}
objectValue = recursiveExtract(objectValue);
}else{
// 此时认为是传入的单个String或者Int等
// 此时无法获取batch信息,故对shape不处理
// 由于Proto中为Repeated,需要把数据包装成list
if(objectValue instanceof String){
if(feedTypes_.get(protoDataKey)!= ElementType.Bytes_type.ordinal()){
throw new Exception("feedvar is not string-type,feed can`t be a single string.");
}
}else{
if(feedTypes_.get(protoDataKey)== ElementType.Bytes_type.ordinal()){
throw new Exception("feedvar is string-type,feed, feed can`t be a single int or others.");
}
}
List<Object> list = new ArrayList<>();
list.add(objectValue);
objectValue = list;
} }
if(batchFlag){ jsonTensor.put(protoDataKey,objectValue);
if(!batchFlag){
// 在index=0处,加上batch=1 // 在index=0处,加上batch=1
shape.add(0, 1); shape.add(0, 1);
} }
jsonTensor.put("alias_name", feed_alias_name);
jsonTensor.put("name", feed_real_name);
jsonTensor.put("shape", shape); jsonTensor.put("shape", shape);
jsonTensor.put("elem_type", element_type); // 处理lod信息,支持INDArray Array Iterable
jsonTensor.put(protoDataKey,objectValue); Object feedLodValue = null;
if(feedLodValue != null) { if(feedLod != null){
jsonTensor.put("lod", feedLodValue); feedLodValue = feedLod.get(feed_alias_name);
if(feedLodValue != null) {
if(feedLodValue instanceof INDArray){
INDArray tempIndArray = (INDArray)feedLodValue;
feedLodValue = tempIndArray.data().asInt();
}else if(feedLodValue.getClass().isArray()){
// 如果是数组类型,则无须处理,直接使用即可。
}else if(feedLodValue instanceof Iterable){
// 如果为list,可能存在嵌套,此时需要展平
feedLodValue = recursiveExtract(feedLodValue);
}else{
throw new Exception("Lod must be INDArray or Array or Iterable.");
}
jsonTensor.put("lod", feedLodValue);
}
} }
jsonTensorArray.put(jsonTensor); jsonTensorArray.put(jsonTensor);
} }
...@@ -343,6 +397,9 @@ public class HttpClient { ...@@ -343,6 +397,9 @@ public class HttpClient {
jsonRequest.put("log_id",log_id); jsonRequest.put("log_id",log_id);
jsonRequest.put("fetch_var_names", jsonFetchList); jsonRequest.put("fetch_var_names", jsonFetchList);
jsonRequest.put("tensor",jsonTensorArray); jsonRequest.put("tensor",jsonTensorArray);
if(GLOG_v != null){
System.out.format("------- Final jsonRequest: %s\n", jsonRequest.toString());
}
return doPost(server_url, jsonRequest.toString()); return doPost(server_url, jsonRequest.toString());
} }
...@@ -361,29 +418,41 @@ public class HttpClient { ...@@ -361,29 +418,41 @@ public class HttpClient {
.build(); .build();
// 为httpPost实例设置配置 // 为httpPost实例设置配置
httpPost.setConfig(requestConfig); httpPost.setConfig(requestConfig);
httpPost.setHeader("Content-Type", "application/json;charset=utf-8");
// 设置请求头 // 设置请求头
httpPost.addHeader("Content-Type", "application/json");
if(response_compress_flag){ if(response_compress_flag){
httpPost.addHeader("Accept-encoding", "gzip"); httpPost.addHeader("Accept-encoding", "gzip");
} if(GLOG_v != null){
if(request_compress_flag && strPostData.length()>512){ System.out.format("------- Accept-encoding gzip: \n");
try{
strPostData = compress(strPostData,"UTF-8");
httpPost.addHeader("Content-Encoding", "gzip");
} catch (IOException e) {
e.printStackTrace();
} }
} }
try { try {
httpPost.setEntity(new StringEntity(strPostData, "UTF-8")); if(request_compress_flag && strPostData.length()>1024){
try{
byte[] gzipEncrypt = compress(strPostData);
httpPost.setEntity(new InputStreamEntity(new ByteArrayInputStream(gzipEncrypt), gzipEncrypt.length));
httpPost.addHeader("Content-Encoding", "gzip");
} catch (Exception e) {
e.printStackTrace();
}
}else{
httpPost.setEntity(new StringEntity(strPostData, "UTF-8"));
}
// httpClient对象执行post请求,并返回响应参数对象 // httpClient对象执行post请求,并返回响应参数对象
httpResponse = httpClient.execute(httpPost); httpResponse = httpClient.execute(httpPost);
// 从响应对象中获取响应内容 // 从响应对象中获取响应内容
HttpEntity entity = httpResponse.getEntity(); HttpEntity entity = httpResponse.getEntity();
Header header = entity.getContentEncoding(); Header header = entity.getContentEncoding();
if(GLOG_v != null){
System.out.format("------- response header: %s\n", header);
}
if(header != null && header.getValue().equalsIgnoreCase("gzip")){ //判断返回内容是否为gzip压缩格式 if(header != null && header.getValue().equalsIgnoreCase("gzip")){ //判断返回内容是否为gzip压缩格式
GzipDecompressingEntity gzipEntity = new GzipDecompressingEntity(entity); GzipDecompressingEntity gzipEntity = new GzipDecompressingEntity(entity);
result = EntityUtils.toString(gzipEntity); result = EntityUtils.toString(gzipEntity);
if(GLOG_v != null){
System.out.format("------- degzip response: %s\n", result);
}
}else{ }else{
result = EntityUtils.toString(entity); result = EntityUtils.toString(entity);
} }
...@@ -410,5 +479,25 @@ public class HttpClient { ...@@ -410,5 +479,25 @@ public class HttpClient {
} }
return result; return result;
} }
public List<Object> recursiveExtract(Object stuff) {
List<Object> mylist = new ArrayList<Object>();
if(stuff instanceof Iterable) {
for(Object o : (Iterable< ? >)stuff) {
mylist.addAll(recursiveExtract(o));
}
} else if(stuff instanceof Map) {
for(Object o : ((Map<?, ? extends Object>) stuff).values()) {
mylist.addAll(recursiveExtract(o));
}
} else {
mylist.add(stuff);
}
return mylist;
}
} }
...@@ -21,7 +21,7 @@ import time ...@@ -21,7 +21,7 @@ import time
client = HttpClient() client = HttpClient()
client.load_client_config(sys.argv[1]) client.load_client_config(sys.argv[1])
# if you want to enable Encrypt Module,uncommenting the following line # if you want to enable Encrypt Module,uncommenting the following line
#client.use_key("./key") client.use_key("./key")
client.set_response_compress(True) client.set_response_compress(True)
client.set_request_compress(True) client.set_request_compress(True)
fetch_list = client.get_fetch_names() fetch_list = client.get_fetch_names()
......
...@@ -42,6 +42,23 @@ def list_flatten(items, ignore_types=(str, bytes)): ...@@ -42,6 +42,23 @@ def list_flatten(items, ignore_types=(str, bytes)):
yield x yield x
def data_bytes_number(datalist):
total_bytes_number = 0
if isinstance(datalist, list):
if len(datalist) == 0:
return total_bytes_number
else:
for data in datalist:
if isinstance(data, str):
total_bytes_number = total_bytes_number + len(data)
else:
total_bytes_number = total_bytes_number + 4 * len(datalist)
break
else:
raise ValueError(
"In the Function data_bytes_number(), data must be list.")
class HttpClient(object): class HttpClient(object):
def __init__(self, def __init__(self,
ip="0.0.0.0", ip="0.0.0.0",
...@@ -157,7 +174,8 @@ class HttpClient(object): ...@@ -157,7 +174,8 @@ class HttpClient(object):
def get_fetch_names(self): def get_fetch_names(self):
return self.fetch_names_ return self.fetch_names_
# feed 支持Numpy类型,Json-String,以及直接List、tuple # feed 支持Numpy类型,以及直接List、tuple
# 不支持str类型,因为proto中为repeated.
def predict(self, def predict(self,
feed=None, feed=None,
fetch=None, fetch=None,
...@@ -179,7 +197,7 @@ class HttpClient(object): ...@@ -179,7 +197,7 @@ class HttpClient(object):
if isinstance(feed, dict): if isinstance(feed, dict):
feed_batch.append(feed) feed_batch.append(feed)
elif isinstance(feed, (list, str, tuple)): elif isinstance(feed, (list, str, tuple)):
# if input is a list or str, and the number of feed_var is 1. # if input is a list or str or tuple, and the number of feed_var is 1.
# create a temp_dict { key = feed_var_name, value = list} # create a temp_dict { key = feed_var_name, value = list}
# put the temp_dict into the feed_batch. # put the temp_dict into the feed_batch.
if len(self.feed_names_) != 1: if len(self.feed_names_) != 1:
...@@ -230,46 +248,55 @@ class HttpClient(object): ...@@ -230,46 +248,55 @@ class HttpClient(object):
data_value = feed_i[key] data_value = feed_i[key]
data_key = proto_data_key_list[elem_type] data_key = proto_data_key_list[elem_type]
# 输入不是string类型 # feed_i[key] 可以是np.ndarray
if self.feed_types_[key] != bytes_type: # 也可以是list或tuple
# feed_i[key] 可以是np.ndarray # 当np.ndarray需要处理为list
# 也可以是string或list或tuple if isinstance(feed_i[key], np.ndarray):
# 当np.ndarray需要处理为list shape_lst = []
if isinstance(feed_i[key], np.ndarray): # 0维numpy 需要在外层再加一个[]
shape_lst = [] if feed_i[key].ndim == 0:
# 0维numpy 需要在外层再加一个[] data_value = [feed_i[key].tolist()]
if feed_i[key].ndim == 0: shape_lst.append(1)
data_value = [feed_i[key].tolist()] else:
shape_lst.append(1) shape_lst.extend(list(feed_i[key].shape))
else: shape = shape_lst
shape_lst.extend(list(feed_i[key].shape)) data_value = feed_i[key].flatten().tolist()
shape = shape_lst # 当Batch为False,shape字段前插一个1,表示batch维
data_value = feed_i[key].flatten().tolist() # 当Batch为True,则直接使用numpy.shape作为batch维度
# 当Batch为False,shape字段前插一个1,表示batch维 if batch == False:
# 当Batch为True,则直接使用numpy.shape作为batch维度 shape.insert(0, 1)
if batch == False:
shape.insert(0, 1) # 当是list或tuple时,需要把多层嵌套展开
elif isinstance(feed_i[key], (list, tuple)):
# 当是list或tuple时,需要把多层嵌套展开 # 当Batch为False,shape字段前插一个1,表示batch维
if isinstance(feed_i[key], (list, tuple)): # 当Batch为True, 由于list并不像numpy那样规整,所以
# 当Batch为False,shape字段前插一个1,表示batch维 # 无法获取shape,此时取第一维度作为Batch维度.
# 当Batch为True, 由于list并不像numpy那样规整,所以 # 插入到feedVar.shape前面.
# 无法获取shape,此时取第一维度作为Batch维度. if batch == False:
# 插入到feedVar.shape前面. shape.insert(0, 1)
if batch == False: else:
shape.insert(0, 1) shape.insert(0, len(feed_i[key]))
else: feed_i[key] = [x for x in list_flatten(feed_i[key])]
shape.insert(0, len(feed_i[key]))
feed_i[key] = [x for x in list_flatten(feed_i[key])]
data_value = feed_i[key]
'''
this is comment, for coder to understand.
#if input is string, feed is not numpy.
else:
shape = self.feed_shapes_[key]
data_value = feed_i[key] data_value = feed_i[key]
''' else:
total_data_number = total_data_number + len(data_value) # 输入可能是单个的str或int值等
# 此时先统一处理为一个list
# 由于输入比较特殊,shape保持原feedvar中不变
data_value = []
data_value.append(feed_i[key])
if isinstance(feed_i[key], str):
if self.feed_types_[key] != bytes_type:
raise ValueError(
"feedvar is not string-type,feed can`t be a single string."
)
else:
if self.feed_types_[key] == bytes_type:
raise ValueError(
"feedvar is string-type,feed, feed can`t be a single int or others."
)
total_data_number = total_data_number + data_bytes_number(
data_value)
Request["tensor"][index]["elem_type"] = elem_type Request["tensor"][index]["elem_type"] = elem_type
Request["tensor"][index]["shape"] = shape Request["tensor"][index]["shape"] = shape
Request["tensor"][index][data_key] = data_value Request["tensor"][index][data_key] = data_value
...@@ -285,6 +312,7 @@ class HttpClient(object): ...@@ -285,6 +312,7 @@ class HttpClient(object):
web_url = "http://" + self.ip + ":" + self.server_port + self.service_name web_url = "http://" + self.ip + ":" + self.server_port + self.service_name
postData = json.dumps(Request) postData = json.dumps(Request)
headers = {} headers = {}
# 当数据区长度大于512字节时才压缩.
if self.try_request_gzip and total_data_number > 512: if self.try_request_gzip and total_data_number > 512:
postData = gzip.compress(bytes(postData, 'utf-8')) postData = gzip.compress(bytes(postData, 'utf-8'))
headers["Content-Encoding"] = "gzip" headers["Content-Encoding"] = "gzip"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册