提交 24d76a63 编写于 作者: H HexToString 提交者: ShiningZhang

fix proto change

上级 f00e91f7
...@@ -42,7 +42,7 @@ python3.6 -m paddle_serving_server.serve --model uci_housing_model --thread 10 - ...@@ -42,7 +42,7 @@ python3.6 -m paddle_serving_server.serve --model uci_housing_model --thread 10 -
为了方便用户快速的使用Http方式请求Server端预测服务,我们已经将常用的Http请求的数据体封装、压缩、请求加密等功能封装为一个HttpClient类提供给用户,方便用户使用。 为了方便用户快速的使用Http方式请求Server端预测服务,我们已经将常用的Http请求的数据体封装、压缩、请求加密等功能封装为一个HttpClient类提供给用户,方便用户使用。
使用HttpClient最简单只需要三步,1、创建一个HttpClient对象。2、加载Client端的prototxt配置文件(本例中为python/examples/fit_a_line/目录下的uci_housing_client/serving_client_conf.prototxt),3、调用Predict函数,通过Http方式请求预测服务。 使用HttpClient最简单只需要四步,1、创建一个HttpClient对象。2、加载Client端的prototxt配置文件(本例中为python/examples/fit_a_line/目录下的uci_housing_client/serving_client_conf.prototxt)。3、调用connect函数。4、调用Predict函数,通过Http方式请求预测服务。
此外,您可以根据自己的需要配置Server端IP、Port、服务名称(此服务名称需要与[`core/general-server/proto/general_model_service.proto`](../core/general-server/proto/general_model_service.proto)文件中的Service服务名和rpc方法名对应,即`GeneralModelService`字段和`inference`字段),设置Request数据体压缩,设置Response支持压缩传输,模型加密预测(需要配置Server端使用模型加密)、设置响应超时时间等功能。 此外,您可以根据自己的需要配置Server端IP、Port、服务名称(此服务名称需要与[`core/general-server/proto/general_model_service.proto`](../core/general-server/proto/general_model_service.proto)文件中的Service服务名和rpc方法名对应,即`GeneralModelService`字段和`inference`字段),设置Request数据体压缩,设置Response支持压缩传输,模型加密预测(需要配置Server端使用模型加密)、设置响应超时时间等功能。
...@@ -103,7 +103,7 @@ repeated int32 numbers = 1; ...@@ -103,7 +103,7 @@ repeated int32 numbers = 1;
``` ```
#### elem_type #### elem_type
表示数据类型,0 means int64, 1 means float32, 2 means int32, 3 means bytes(string) 表示数据类型,0 means int64, 1 means float32, 2 means int32, 20 means bytes(string)
#### fetch_var_names #### fetch_var_names
......
...@@ -59,9 +59,20 @@ import java.util.zip.GZIPInputStream; ...@@ -59,9 +59,20 @@ import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream; import java.util.zip.GZIPOutputStream;
enum ElementType class ElementType {
{ public static final int Int64_type = 0;
Int64_type, Float32_type, Int32_type, Bytes_type; public static final int Float32_type = 1;
public static final int Int32_type = 2;
public static final int String_type = 20;
public static final Map<Integer, String> feedTypeToDataKey_;
static
{
feedTypeToDataKey_ = new HashMap<Integer, String>();
feedTypeToDataKey_.put(ElementType.Int64_type, "int64_data");
feedTypeToDataKey_.put(ElementType.Float32_type, "float_data");
feedTypeToDataKey_.put(ElementType.Int32_type, "int_data");
feedTypeToDataKey_.put(ElementType.String_type, "data");
}
} }
class Profiler { class Profiler {
...@@ -104,7 +115,6 @@ public class Client { ...@@ -104,7 +115,6 @@ public class Client {
private Map<String, Integer> feedTypes_; private Map<String, Integer> feedTypes_;
private Map<String, List<Integer>> feedShapes_; private Map<String, List<Integer>> feedShapes_;
private Map<String, Integer> feedNameToIndex_; private Map<String, Integer> feedNameToIndex_;
private Map<Integer, String> feedTypeToDataKey_;
private List<String> fetchNames_; private List<String> fetchNames_;
private Map<String, Integer> fetchTypes_; private Map<String, Integer> fetchTypes_;
private Set<String> lodTensorSet_; private Set<String> lodTensorSet_;
...@@ -147,12 +157,6 @@ public class Client { ...@@ -147,12 +157,6 @@ public class Client {
channel_ = null; channel_ = null;
blockingStub_ = null; blockingStub_ = null;
feedTypeToDataKey_ = new HashMap<Integer, String>();
feedTypeToDataKey_.put(0, "int64_data");
feedTypeToDataKey_.put(1, "float_data");
feedTypeToDataKey_.put(2, "int_data");
feedTypeToDataKey_.put(3, "data");
profiler_ = new Profiler(); profiler_ = new Profiler();
boolean is_profile = false; boolean is_profile = false;
String FLAGS_profile_client = System.getenv("FLAGS_profile_client"); String FLAGS_profile_client = System.getenv("FLAGS_profile_client");
...@@ -525,7 +529,7 @@ public class Client { ...@@ -525,7 +529,7 @@ public class Client {
jsonTensor.put("elem_type", element_type); jsonTensor.put("elem_type", element_type);
// 处理数据与shape // 处理数据与shape
String protoDataKey = feedTypeToDataKey_.get(element_type); String protoDataKey = ElementType.feedTypeToDataKey_.get(element_type);
// 如果是INDArray类型,先转为一维. // 如果是INDArray类型,先转为一维.
// 此时shape为INDArray的shape // 此时shape为INDArray的shape
if(objectValue instanceof INDArray){ if(objectValue instanceof INDArray){
...@@ -535,11 +539,11 @@ public class Client { ...@@ -535,11 +539,11 @@ public class Client {
for(long dim:indarrayShape){ for(long dim:indarrayShape){
shape.add((int)dim); shape.add((int)dim);
} }
if(element_type == ElementType.Int64_type.ordinal()){ if(element_type == ElementType.Int64_type){
objectValue = tempIndArray.data().asLong(); objectValue = tempIndArray.data().asLong();
}else if(element_type == ElementType.Int32_type.ordinal()){ }else if(element_type == ElementType.Int32_type){
objectValue = tempIndArray.data().asInt(); objectValue = tempIndArray.data().asInt();
}else if(element_type == ElementType.Float32_type.ordinal()){ }else if(element_type == ElementType.Float32_type){
objectValue = tempIndArray.data().asFloat(); objectValue = tempIndArray.data().asFloat();
}else{ }else{
throw new Exception("INDArray 类型不支持"); throw new Exception("INDArray 类型不支持");
...@@ -564,11 +568,11 @@ public class Client { ...@@ -564,11 +568,11 @@ public class Client {
// 此时无法获取batch信息,故对shape不处理 // 此时无法获取batch信息,故对shape不处理
// 由于Proto中为Repeated,需要把数据包装成list // 由于Proto中为Repeated,需要把数据包装成list
if(objectValue instanceof String){ if(objectValue instanceof String){
if(feedTypes_.get(protoDataKey)!= ElementType.Bytes_type.ordinal()){ if(feedTypes_.get(protoDataKey)!= ElementType.String_type){
throw new Exception("feedvar is not string-type,feed can`t be a single string."); throw new Exception("feedvar is not string-type,feed can`t be a single string.");
} }
}else{ }else{
if(feedTypes_.get(protoDataKey)== ElementType.Bytes_type.ordinal()){ if(feedTypes_.get(protoDataKey)== ElementType.String_type){
throw new Exception("feedvar is string-type,feed, feed can`t be a single int or others."); throw new Exception("feedvar is string-type,feed, feed can`t be a single int or others.");
} }
} }
...@@ -662,17 +666,17 @@ public class Client { ...@@ -662,17 +666,17 @@ public class Client {
for(long dim:indarrayShape){ for(long dim:indarrayShape){
shape.add((int)dim); shape.add((int)dim);
} }
if(element_type == ElementType.Int64_type.ordinal()){ if(element_type == ElementType.Int64_type){
List<Long> iter = Arrays.stream(tempIndArray.data().asLong()).boxed().collect(Collectors.toList()); List<Long> iter = Arrays.stream(tempIndArray.data().asLong()).boxed().collect(Collectors.toList());
tensor_builder.addAllInt64Data(iter); tensor_builder.addAllInt64Data(iter);
}else if(element_type == ElementType.Int32_type.ordinal()){ }else if(element_type == ElementType.Int32_type){
List<Integer> iter = Arrays.stream(tempIndArray.data().asInt()).boxed().collect(Collectors.toList()); List<Integer> iter = Arrays.stream(tempIndArray.data().asInt()).boxed().collect(Collectors.toList());
tensor_builder.addAllIntData(iter); tensor_builder.addAllIntData(iter);
}else if(element_type == ElementType.Float32_type.ordinal()){ }else if(element_type == ElementType.Float32_type){
List<Float> iter = Arrays.asList(ArrayUtils.toObject(tempIndArray.data().asFloat())); List<Float> iter = Arrays.asList(ArrayUtils.toObject(tempIndArray.data().asFloat()));
tensor_builder.addAllFloatData(iter); tensor_builder.addAllFloatData(iter);
...@@ -684,13 +688,13 @@ public class Client { ...@@ -684,13 +688,13 @@ public class Client {
// 如果是数组类型,则无须处理,直接使用即可。 // 如果是数组类型,则无须处理,直接使用即可。
// 且数组无法嵌套,此时batch无法从数据中获取 // 且数组无法嵌套,此时batch无法从数据中获取
// 默认batch维度为1,或者feedVar的shape信息中已包含batch // 默认batch维度为1,或者feedVar的shape信息中已包含batch
if(element_type == ElementType.Int64_type.ordinal()){ if(element_type == ElementType.Int64_type){
List<Long> iter = Arrays.stream((long[])objectValue).boxed().collect(Collectors.toList()); List<Long> iter = Arrays.stream((long[])objectValue).boxed().collect(Collectors.toList());
tensor_builder.addAllInt64Data(iter); tensor_builder.addAllInt64Data(iter);
}else if(element_type == ElementType.Int32_type.ordinal()){ }else if(element_type == ElementType.Int32_type){
List<Integer> iter = Arrays.stream((int[])objectValue).boxed().collect(Collectors.toList()); List<Integer> iter = Arrays.stream((int[])objectValue).boxed().collect(Collectors.toList());
tensor_builder.addAllIntData(iter); tensor_builder.addAllIntData(iter);
}else if(element_type == ElementType.Float32_type.ordinal()){ }else if(element_type == ElementType.Float32_type){
List<Float> iter = Arrays.asList(ArrayUtils.toObject((float[])objectValue)); List<Float> iter = Arrays.asList(ArrayUtils.toObject((float[])objectValue));
tensor_builder.addAllFloatData(iter); tensor_builder.addAllFloatData(iter);
}else{ }else{
...@@ -707,11 +711,11 @@ public class Client { ...@@ -707,11 +711,11 @@ public class Client {
// 在index=0处,加上batch // 在index=0处,加上batch
shape.add(0, list.size()); shape.add(0, list.size());
} }
if(element_type == ElementType.Int64_type.ordinal()){ if(element_type == ElementType.Int64_type){
tensor_builder.addAllInt64Data((List<Long>)(List)recursiveExtract(objectValue)); tensor_builder.addAllInt64Data((List<Long>)(List)recursiveExtract(objectValue));
}else if(element_type == ElementType.Int32_type.ordinal()){ }else if(element_type == ElementType.Int32_type){
tensor_builder.addAllIntData((List<Integer>)(List)recursiveExtract(objectValue)); tensor_builder.addAllIntData((List<Integer>)(List)recursiveExtract(objectValue));
}else if(element_type == ElementType.Float32_type.ordinal()){ }else if(element_type == ElementType.Float32_type){
tensor_builder.addAllFloatData((List<Float>)(List)recursiveExtract(objectValue)); tensor_builder.addAllFloatData((List<Float>)(List)recursiveExtract(objectValue));
}else{ }else{
// 看接口是String还是Bytes // 看接口是String还是Bytes
...@@ -723,11 +727,11 @@ public class Client { ...@@ -723,11 +727,11 @@ public class Client {
// 由于Proto中为Repeated,需要把数据包装成list // 由于Proto中为Repeated,需要把数据包装成list
List<Object> tempList = new ArrayList<>(); List<Object> tempList = new ArrayList<>();
tempList.add(objectValue); tempList.add(objectValue);
if(element_type == ElementType.Int64_type.ordinal()){ if(element_type == ElementType.Int64_type){
tensor_builder.addAllInt64Data((List<Long>)(List)tempList); tensor_builder.addAllInt64Data((List<Long>)(List)tempList);
}else if(element_type == ElementType.Int32_type.ordinal()){ }else if(element_type == ElementType.Int32_type){
tensor_builder.addAllIntData((List<Integer>)(List)tempList); tensor_builder.addAllIntData((List<Integer>)(List)tempList);
}else if(element_type == ElementType.Float32_type.ordinal()){ }else if(element_type == ElementType.Float32_type){
tensor_builder.addAllFloatData((List<Float>)(List)tempList); tensor_builder.addAllFloatData((List<Float>)(List)tempList);
}else{ }else{
// 看接口是String还是Bytes // 看接口是String还是Bytes
......
...@@ -119,7 +119,7 @@ The pre-processing and post-processing is in the C + + server part, the image's ...@@ -119,7 +119,7 @@ The pre-processing and post-processing is in the C + + server part, the image's
so the value of parameter `feed_var` which is in the file `ocr_det_client/serving_client_conf.prototxt` should be changed. so the value of parameter `feed_var` which is in the file `ocr_det_client/serving_client_conf.prototxt` should be changed.
for this case, `feed_type` should be 3(which means the data type is string),`shape` should be 1. for this case, `feed_type` should be 20(which means the data type is string),`shape` should be 1.
By passing in multiple client folder paths, the client can be started for multi model prediction. By passing in multiple client folder paths, the client can be started for multi model prediction.
``` ```
......
...@@ -118,7 +118,7 @@ python3 -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --por ...@@ -118,7 +118,7 @@ python3 -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --por
`ocr_det_client/serving_client_conf.prototxt``feed_var`字段 `ocr_det_client/serving_client_conf.prototxt``feed_var`字段
对于本示例而言,`feed_type`应修改为3(数据类型为string),`shape`为1. 对于本示例而言,`feed_type`应修改为20(数据类型为string),`shape`为1.
通过在客户端启动后加入多个client模型的client配置文件夹路径,启动client进行预测。 通过在客户端启动后加入多个client模型的client配置文件夹路径,启动client进行预测。
``` ```
......
...@@ -38,7 +38,12 @@ float32_type = 1 ...@@ -38,7 +38,12 @@ float32_type = 1
int32_type = 2 int32_type = 2
bytes_type = 20 bytes_type = 20
# this is corresponding to the proto # this is corresponding to the proto
proto_data_key_list = ["int64_data", "float_data", "int_data", "data"] proto_data_key_list = {
0: "int64_data",
1: "float_data",
2: "int_data",
20: "data"
}
def list_flatten(items, ignore_types=(str, bytes)): def list_flatten(items, ignore_types=(str, bytes)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册