From 00dd554626f18d9925f70e21866719e956b5963c Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Tue, 3 Aug 2021 14:06:19 +0000 Subject: [PATCH] update java --- .../main/java/PaddleServingClientExample.java | 99 +++++++++-- .../io/paddle/serving/client/HttpClient.java | 159 ++++++++++++++---- python/examples/fit_a_line/test_httpclient.py | 2 +- python/paddle_serving_client/httpclient.py | 110 +++++++----- 4 files changed, 276 insertions(+), 94 deletions(-) diff --git a/java/examples/src/main/java/PaddleServingClientExample.java b/java/examples/src/main/java/PaddleServingClientExample.java index 806284c9..b15f9dbb 100755 --- a/java/examples/src/main/java/PaddleServingClientExample.java +++ b/java/examples/src/main/java/PaddleServingClientExample.java @@ -11,7 +11,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.*; public class PaddleServingClientExample { - boolean fit_a_line() { + boolean fit_a_line(String model_config_path) { float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, 0.0582f, -0.0727f, -0.1583f, -0.0584f, 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; @@ -25,15 +25,69 @@ public class PaddleServingClientExample { List fetch = Arrays.asList("price"); HttpClient client = new HttpClient(); - client.setIP("172.17.0.2"); + client.setIP("0.0.0.0"); client.setPort("9393"); + client.loadClientConfig(model_config_path); String result = client.predict(feed_data, fetch, true, 0); System.out.println(result); return true; } - boolean yolov4(String filename) { + boolean encrypt(String model_config_path,String keyFilePath) { + float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, + 0.0582f, -0.0727f, -0.1583f, -0.0584f, + 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; + INDArray npdata = Nd4j.createFromArray(data); + long[] batch_shape = {1,13}; + INDArray batch_npdata = npdata.reshape(batch_shape); + HashMap feed_data + = new HashMap() {{ + put("x", batch_npdata); + }}; + List fetch = Arrays.asList("price"); + + HttpClient client = new HttpClient(); + client.setIP("0.0.0.0"); + client.setPort("9393"); + client.loadClientConfig(model_config_path); + client.use_key(keyFilePath); + try { + Thread.sleep(1000*3); // 休眠3秒,等待Server启动 + } catch (Exception e) { + //TODO: handle exception + } + String result = client.predict(feed_data, fetch, true, 0); + + System.out.println(result); + return true; + } + + boolean compress(String model_config_path) { + float[] data = {0.0137f, -0.1136f, 0.2553f, -0.0692f, + 0.0582f, -0.0727f, -0.1583f, -0.0584f, + 0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f}; + INDArray npdata = Nd4j.createFromArray(data); + long[] batch_shape = {500,13}; + INDArray batch_npdata = npdata.broadcast(batch_shape); + HashMap feed_data + = new HashMap() {{ + put("x", batch_npdata); + }}; + List fetch = Arrays.asList("price"); + + HttpClient client = new HttpClient(); + client.setIP("0.0.0.0"); + client.setPort("9393"); + client.loadClientConfig(model_config_path); + client.set_request_compress(true); + client.set_response_compress(true); + String result = client.predict(feed_data, fetch, true, 0); + System.out.println(result); + return true; + } + + boolean yolov4(String model_config_path,String filename) { // https://deeplearning4j.konduit.ai/ int height = 608; int width = 608; @@ -74,14 +128,15 @@ public class PaddleServingClientExample { }}; List fetch = Arrays.asList("save_infer_model/scale_0.tmp_0"); HttpClient client = new HttpClient(); - client.setIP("172.17.0.2"); + client.setIP("0.0.0.0"); client.setPort("9393"); + client.loadClientConfig(model_config_path); String result = client.predict(feed_data, fetch, true, 0); System.out.println(result); return true; } - boolean bert() { + boolean bert(String model_config_path) { float[] input_mask = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; long[] position_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; long[] input_ids = {101, 6843, 3241, 749, 8024, 7662, 2533, 1391, 2533, 2523, 7676, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; @@ -95,14 +150,15 @@ public class PaddleServingClientExample { }}; List fetch = Arrays.asList("pooled_output"); HttpClient client = new HttpClient(); - client.setIP("172.17.0.2"); + client.setIP("0.0.0.0"); client.setPort("9393"); + client.loadClientConfig(model_config_path); String result = client.predict(feed_data, fetch, true, 0); System.out.println(result); return true; } - boolean cube_local() { + boolean cube_local(String model_config_path) { long[] embedding_14 = {250644}; long[] embedding_2 = {890346}; long[] embedding_10 = {3939}; @@ -164,8 +220,9 @@ public class PaddleServingClientExample { }}; List fetch = Arrays.asList("prob"); HttpClient client = new HttpClient(); - client.setIP("172.17.0.2"); + client.setIP("0.0.0.0"); client.setPort("9393"); + client.loadClientConfig(model_config_path); String result = client.predict(feed_data, fetch, true, 0); System.out.println(result); return true; @@ -177,25 +234,33 @@ public class PaddleServingClientExample { PaddleServingClientExample e = new PaddleServingClientExample(); boolean succ = false; - if (args.length < 1) { - System.out.println("Usage: java -cp PaddleServingClientExample ."); - System.out.println(": fit_a_line bert cube_local yolov4"); + if (args.length < 2) { + System.out.println("Usage: java -cp PaddleServingClientExample ."); + System.out.println(": fit_a_line bert cube_local yolov4 encrypt"); return; } String testType = args[0]; System.out.format("[Example] %s\n", testType); if ("fit_a_line".equals(testType)) { - succ = e.fit_a_line(); + succ = e.fit_a_line(args[1]); + } else if ("compress".equals(testType)) { + succ = e.compress(args[1]); } else if ("bert".equals(testType)) { - succ = e.bert(); + succ = e.bert(args[1]); } else if ("cube_local".equals(testType)) { - succ = e.cube_local(); + succ = e.cube_local(args[1]); } else if ("yolov4".equals(testType)) { - if (args.length < 2) { - System.out.println("Usage: java -cp PaddleServingClientExample yolov4 ."); + if (args.length < 3) { + System.out.println("Usage: java -cp PaddleServingClientExample yolov4 ."); + return; + } + succ = e.yolov4(args[1],args[2]); + } else if ("encrypt".equals(testType)) { + if (args.length < 3) { + System.out.println("Usage: java -cp PaddleServingClientExample encrypt ."); return; } - succ = e.yolov4(args[1]); + succ = e.encrypt(args[1],args[2]); } else { System.out.format("test-type(%s) not match.\n", testType); return; diff --git a/java/src/main/java/io/paddle/serving/client/HttpClient.java b/java/src/main/java/io/paddle/serving/client/HttpClient.java index 69376e1e..5e481845 100644 --- a/java/src/main/java/io/paddle/serving/client/HttpClient.java +++ b/java/src/main/java/io/paddle/serving/client/HttpClient.java @@ -28,6 +28,7 @@ import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClients; import org.apache.http.message.BasicNameValuePair; import org.apache.http.util.EntityUtils; +import org.apache.http.entity.InputStreamEntity; import org.json.*; @@ -97,6 +98,7 @@ public class HttpClient { private String serviceName; private boolean request_compress_flag; private boolean response_compress_flag; + private String GLOG_v; public HttpClient() { feedNames_ = null; @@ -115,6 +117,7 @@ public class HttpClient { serviceName = "/GeneralModelService/inference"; request_compress_flag = false; response_compress_flag = false; + GLOG_v = System.getenv("GLOG_v"); feedTypeToDataKey_ = new HashMap(); feedTypeToDataKey_.put(0, "int64_data"); @@ -206,7 +209,7 @@ public class HttpClient { String encrypt_url = "http://" + this.ip + ":" +this.port; try { byte[] data = Files.readAllBytes(Paths.get(keyFilePath)); - key_str = new String(data, "utf-8"); + key_str = Base64.getEncoder().encodeToString(data); } catch (Exception e) { System.out.format("Open key file failed: %s\n", e.toString()); } @@ -237,16 +240,20 @@ public class HttpClient { this.response_compress_flag = response_compress_flag; } - public static String compress(String str,String inEncoding) throws IOException { + public byte[] compress(String str) { if (str == null || str.length() == 0) { - return str; + return null; } ByteArrayOutputStream out = new ByteArrayOutputStream(); - GZIPOutputStream gzip = new GZIPOutputStream(out); - gzip.write(str.getBytes(inEncoding)); - gzip.close(); - return out.toString("ISO-8859-1"); - + GZIPOutputStream gzip; + try { + gzip = new GZIPOutputStream(out); + gzip.write(str.getBytes("UTF-8")); + gzip.close(); + } catch (Exception e) { + e.printStackTrace(); + } + return out.toByteArray(); } // 帮助用户封装Http请求的接口,用户只需要传递FeedData,Lod,Fetchlist即可。 @@ -302,36 +309,83 @@ public class HttpClient { Object objectValue = mapEntry.getValue(); String feed_alias_name = mapEntry.getKey(); String feed_real_name = feedRealNames_.get(feed_alias_name); - List shape = feedShapes_.get(feed_alias_name); + List shape = new ArrayList(feedShapes_.get(feed_alias_name)); int element_type = feedTypes_.get(feed_alias_name); + + jsonTensor.put("alias_name", feed_alias_name); + jsonTensor.put("name", feed_real_name); + jsonTensor.put("elem_type", element_type); + + // 处理数据与shape String protoDataKey = feedTypeToDataKey_.get(element_type); - Object feedLodValue = feedLod.get(feed_alias_name); - // 如果是INDArray类型,先转为一维,再objectValue.ToString. - // 如果是String或List,则直接objectValue.ToString. - if(objectValue.getClass().equals(INDArray.class)){ - long[] flattened_shape = {-1}; - Class classLongArray = flattened_shape.getClass(); - Method methodReshape = mapEntry.getValue().getClass().getMethod("reshape", classLongArray); - Method methodShape = mapEntry.getValue().getClass().getMethod("shape"); - long[] indarrayShape = (long[])methodShape.invoke(objectValue); + // 如果是INDArray类型,先转为一维. + // 此时shape为INDArray的shape + if(objectValue instanceof INDArray){ + INDArray tempIndArray = (INDArray)objectValue; + long[] indarrayShape = tempIndArray.shape(); shape.clear(); for(long dim:indarrayShape){ shape.add((int)dim); } - objectValue = methodReshape.invoke(objectValue,flattened_shape); + objectValue = tempIndArray.data().asDouble(); + + }else if(objectValue.getClass().isArray()){ + // 如果是数组类型,则无须处理,直接使用即可。 + // 且数组无法嵌套,此时batch无法从数据中获取 + // 默认batch维度为1,或者feedVar的shape信息中已包含batch + }else if(objectValue instanceof List){ + // 如果为list,可能存在嵌套,此时需要展平 + // 如果batchFlag为True,则认为是嵌套list + // 此时取最外层为batch的维度 + if (batchFlag) { + List list = new ArrayList<>(); + list = new ArrayList<>((Collection)objectValue); + // 在index=0处,加上batch + shape.add(0, list.size()); + } + objectValue = recursiveExtract(objectValue); + }else{ + // 此时认为是传入的单个String或者Int等 + // 此时无法获取batch信息,故对shape不处理 + // 由于Proto中为Repeated,需要把数据包装成list + if(objectValue instanceof String){ + if(feedTypes_.get(protoDataKey)!= ElementType.Bytes_type.ordinal()){ + throw new Exception("feedvar is not string-type,feed can`t be a single string."); + } + }else{ + if(feedTypes_.get(protoDataKey)== ElementType.Bytes_type.ordinal()){ + throw new Exception("feedvar is string-type,feed, feed can`t be a single int or others."); + } + } + List list = new ArrayList<>(); + list.add(objectValue); + objectValue = list; } - if(batchFlag){ + jsonTensor.put(protoDataKey,objectValue); + if(!batchFlag){ // 在index=0处,加上batch=1 shape.add(0, 1); } - jsonTensor.put("alias_name", feed_alias_name); - jsonTensor.put("name", feed_real_name); jsonTensor.put("shape", shape); - jsonTensor.put("elem_type", element_type); - jsonTensor.put(protoDataKey,objectValue); - if(feedLodValue != null) { - jsonTensor.put("lod", feedLodValue); + // 处理lod信息,支持INDArray Array Iterable + Object feedLodValue = null; + if(feedLod != null){ + feedLodValue = feedLod.get(feed_alias_name); + if(feedLodValue != null) { + if(feedLodValue instanceof INDArray){ + INDArray tempIndArray = (INDArray)feedLodValue; + feedLodValue = tempIndArray.data().asInt(); + }else if(feedLodValue.getClass().isArray()){ + // 如果是数组类型,则无须处理,直接使用即可。 + }else if(feedLodValue instanceof Iterable){ + // 如果为list,可能存在嵌套,此时需要展平 + feedLodValue = recursiveExtract(feedLodValue); + }else{ + throw new Exception("Lod must be INDArray or Array or Iterable."); + } + jsonTensor.put("lod", feedLodValue); + } } jsonTensorArray.put(jsonTensor); } @@ -343,6 +397,9 @@ public class HttpClient { jsonRequest.put("log_id",log_id); jsonRequest.put("fetch_var_names", jsonFetchList); jsonRequest.put("tensor",jsonTensorArray); + if(GLOG_v != null){ + System.out.format("------- Final jsonRequest: %s\n", jsonRequest.toString()); + } return doPost(server_url, jsonRequest.toString()); } @@ -361,29 +418,41 @@ public class HttpClient { .build(); // 为httpPost实例设置配置 httpPost.setConfig(requestConfig); + httpPost.setHeader("Content-Type", "application/json;charset=utf-8"); // 设置请求头 - httpPost.addHeader("Content-Type", "application/json"); if(response_compress_flag){ httpPost.addHeader("Accept-encoding", "gzip"); - } - if(request_compress_flag && strPostData.length()>512){ - try{ - strPostData = compress(strPostData,"UTF-8"); - httpPost.addHeader("Content-Encoding", "gzip"); - } catch (IOException e) { - e.printStackTrace(); + if(GLOG_v != null){ + System.out.format("------- Accept-encoding gzip: \n"); } } + try { - httpPost.setEntity(new StringEntity(strPostData, "UTF-8")); + if(request_compress_flag && strPostData.length()>1024){ + try{ + byte[] gzipEncrypt = compress(strPostData); + httpPost.setEntity(new InputStreamEntity(new ByteArrayInputStream(gzipEncrypt), gzipEncrypt.length)); + httpPost.addHeader("Content-Encoding", "gzip"); + } catch (Exception e) { + e.printStackTrace(); + } + }else{ + httpPost.setEntity(new StringEntity(strPostData, "UTF-8")); + } // httpClient对象执行post请求,并返回响应参数对象 httpResponse = httpClient.execute(httpPost); // 从响应对象中获取响应内容 HttpEntity entity = httpResponse.getEntity(); Header header = entity.getContentEncoding(); + if(GLOG_v != null){ + System.out.format("------- response header: %s\n", header); + } if(header != null && header.getValue().equalsIgnoreCase("gzip")){ //判断返回内容是否为gzip压缩格式 GzipDecompressingEntity gzipEntity = new GzipDecompressingEntity(entity); result = EntityUtils.toString(gzipEntity); + if(GLOG_v != null){ + System.out.format("------- degzip response: %s\n", result); + } }else{ result = EntityUtils.toString(entity); } @@ -410,5 +479,25 @@ public class HttpClient { } return result; } + + public List recursiveExtract(Object stuff) { + + List mylist = new ArrayList(); + + if(stuff instanceof Iterable) { + for(Object o : (Iterable< ? >)stuff) { + mylist.addAll(recursiveExtract(o)); + } + } else if(stuff instanceof Map) { + for(Object o : ((Map) stuff).values()) { + mylist.addAll(recursiveExtract(o)); + } + } else { + mylist.add(stuff); + } + + return mylist; + } + } diff --git a/python/examples/fit_a_line/test_httpclient.py b/python/examples/fit_a_line/test_httpclient.py index cd993042..082b65a7 100644 --- a/python/examples/fit_a_line/test_httpclient.py +++ b/python/examples/fit_a_line/test_httpclient.py @@ -21,7 +21,7 @@ import time client = HttpClient() client.load_client_config(sys.argv[1]) # if you want to enable Encrypt Module,uncommenting the following line -#client.use_key("./key") +client.use_key("./key") client.set_response_compress(True) client.set_request_compress(True) fetch_list = client.get_fetch_names() diff --git a/python/paddle_serving_client/httpclient.py b/python/paddle_serving_client/httpclient.py index 84ba5c4f..356d27bd 100644 --- a/python/paddle_serving_client/httpclient.py +++ b/python/paddle_serving_client/httpclient.py @@ -42,6 +42,23 @@ def list_flatten(items, ignore_types=(str, bytes)): yield x +def data_bytes_number(datalist): + total_bytes_number = 0 + if isinstance(datalist, list): + if len(datalist) == 0: + return total_bytes_number + else: + for data in datalist: + if isinstance(data, str): + total_bytes_number = total_bytes_number + len(data) + else: + total_bytes_number = total_bytes_number + 4 * len(datalist) + break + else: + raise ValueError( + "In the Function data_bytes_number(), data must be list.") + + class HttpClient(object): def __init__(self, ip="0.0.0.0", @@ -157,7 +174,8 @@ class HttpClient(object): def get_fetch_names(self): return self.fetch_names_ - # feed 支持Numpy类型,Json-String,以及直接List、tuple + # feed 支持Numpy类型,以及直接List、tuple + # 不支持str类型,因为proto中为repeated. def predict(self, feed=None, fetch=None, @@ -179,7 +197,7 @@ class HttpClient(object): if isinstance(feed, dict): feed_batch.append(feed) elif isinstance(feed, (list, str, tuple)): - # if input is a list or str, and the number of feed_var is 1. + # if input is a list or str or tuple, and the number of feed_var is 1. # create a temp_dict { key = feed_var_name, value = list} # put the temp_dict into the feed_batch. if len(self.feed_names_) != 1: @@ -230,46 +248,55 @@ class HttpClient(object): data_value = feed_i[key] data_key = proto_data_key_list[elem_type] - # 输入不是string类型 - if self.feed_types_[key] != bytes_type: - # feed_i[key] 可以是np.ndarray - # 也可以是string或list或tuple - # 当np.ndarray需要处理为list - if isinstance(feed_i[key], np.ndarray): - shape_lst = [] - # 0维numpy 需要在外层再加一个[] - if feed_i[key].ndim == 0: - data_value = [feed_i[key].tolist()] - shape_lst.append(1) - else: - shape_lst.extend(list(feed_i[key].shape)) - shape = shape_lst - data_value = feed_i[key].flatten().tolist() - # 当Batch为False,shape字段前插一个1,表示batch维 - # 当Batch为True,则直接使用numpy.shape作为batch维度 - if batch == False: - shape.insert(0, 1) - - # 当是list或tuple时,需要把多层嵌套展开 - if isinstance(feed_i[key], (list, tuple)): - # 当Batch为False,shape字段前插一个1,表示batch维 - # 当Batch为True, 由于list并不像numpy那样规整,所以 - # 无法获取shape,此时取第一维度作为Batch维度. - # 插入到feedVar.shape前面. - if batch == False: - shape.insert(0, 1) - else: - shape.insert(0, len(feed_i[key])) - feed_i[key] = [x for x in list_flatten(feed_i[key])] - data_value = feed_i[key] - ''' - this is comment, for coder to understand. - #if input is string, feed is not numpy. - else: - shape = self.feed_shapes_[key] + # feed_i[key] 可以是np.ndarray + # 也可以是list或tuple + # 当np.ndarray需要处理为list + if isinstance(feed_i[key], np.ndarray): + shape_lst = [] + # 0维numpy 需要在外层再加一个[] + if feed_i[key].ndim == 0: + data_value = [feed_i[key].tolist()] + shape_lst.append(1) + else: + shape_lst.extend(list(feed_i[key].shape)) + shape = shape_lst + data_value = feed_i[key].flatten().tolist() + # 当Batch为False,shape字段前插一个1,表示batch维 + # 当Batch为True,则直接使用numpy.shape作为batch维度 + if batch == False: + shape.insert(0, 1) + + # 当是list或tuple时,需要把多层嵌套展开 + elif isinstance(feed_i[key], (list, tuple)): + # 当Batch为False,shape字段前插一个1,表示batch维 + # 当Batch为True, 由于list并不像numpy那样规整,所以 + # 无法获取shape,此时取第一维度作为Batch维度. + # 插入到feedVar.shape前面. + if batch == False: + shape.insert(0, 1) + else: + shape.insert(0, len(feed_i[key])) + feed_i[key] = [x for x in list_flatten(feed_i[key])] data_value = feed_i[key] - ''' - total_data_number = total_data_number + len(data_value) + else: + # 输入可能是单个的str或int值等 + # 此时先统一处理为一个list + # 由于输入比较特殊,shape保持原feedvar中不变 + data_value = [] + data_value.append(feed_i[key]) + if isinstance(feed_i[key], str): + if self.feed_types_[key] != bytes_type: + raise ValueError( + "feedvar is not string-type,feed can`t be a single string." + ) + else: + if self.feed_types_[key] == bytes_type: + raise ValueError( + "feedvar is string-type,feed, feed can`t be a single int or others." + ) + + total_data_number = total_data_number + data_bytes_number( + data_value) Request["tensor"][index]["elem_type"] = elem_type Request["tensor"][index]["shape"] = shape Request["tensor"][index][data_key] = data_value @@ -285,6 +312,7 @@ class HttpClient(object): web_url = "http://" + self.ip + ":" + self.server_port + self.service_name postData = json.dumps(Request) headers = {} + # 当数据区长度大于512字节时才压缩. if self.try_request_gzip and total_data_number > 512: postData = gzip.compress(bytes(postData, 'utf-8')) headers["Content-Encoding"] = "gzip" -- GitLab