提交 24d76a63 编写于 作者: H HexToString 提交者: ShiningZhang

fix proto change

上级 f00e91f7
......@@ -42,7 +42,7 @@ python3.6 -m paddle_serving_server.serve --model uci_housing_model --thread 10 -
为了方便用户快速的使用Http方式请求Server端预测服务,我们已经将常用的Http请求的数据体封装、压缩、请求加密等功能封装为一个HttpClient类提供给用户,方便用户使用。
使用HttpClient最简单只需要三步,1、创建一个HttpClient对象。2、加载Client端的prototxt配置文件(本例中为python/examples/fit_a_line/目录下的uci_housing_client/serving_client_conf.prototxt),3、调用Predict函数,通过Http方式请求预测服务。
使用HttpClient最简单只需要四步,1、创建一个HttpClient对象。2、加载Client端的prototxt配置文件(本例中为python/examples/fit_a_line/目录下的uci_housing_client/serving_client_conf.prototxt)。3、调用coonect函数。4、调用Predict函数,通过Http方式请求预测服务。
此外,您可以根据自己的需要配置Server端IP、Port、服务名称(此服务名称需要与[`core/general-server/proto/general_model_service.proto`](../core/general-server/proto/general_model_service.proto)文件中的Service服务名和rpc方法名对应,即`GeneralModelService`字段和`inference`字段),设置Request数据体压缩,设置Response支持压缩传输,模型加密预测(需要配置Server端使用模型加密)、设置响应超时时间等功能。
......@@ -103,7 +103,7 @@ repeated int32 numbers = 1;
```
#### elem_type
表示数据类型,0 means int64, 1 means float32, 2 means int32, 3 means bytes(string)
表示数据类型,0 means int64, 1 means float32, 2 means int32, 20 means bytes(string)
#### fetch_var_names
......
......@@ -59,9 +59,20 @@ import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
enum ElementType
{
Int64_type, Float32_type, Int32_type, Bytes_type;
class ElementType {
public static final int Int64_type = 0;
public static final int Float32_type = 1;
public static final int Int32_type = 2;
public static final int String_type = 20;
public static final Map<Integer, String> feedTypeToDataKey_;
static
{
feedTypeToDataKey_ = new HashMap<Integer, String>();
feedTypeToDataKey_.put(ElementType.Int64_type, "int64_data");
feedTypeToDataKey_.put(ElementType.Float32_type, "float_data");
feedTypeToDataKey_.put(ElementType.Int32_type, "int_data");
feedTypeToDataKey_.put(ElementType.String_type, "data");
}
}
class Profiler {
......@@ -104,7 +115,6 @@ public class Client {
private Map<String, Integer> feedTypes_;
private Map<String, List<Integer>> feedShapes_;
private Map<String, Integer> feedNameToIndex_;
private Map<Integer, String> feedTypeToDataKey_;
private List<String> fetchNames_;
private Map<String, Integer> fetchTypes_;
private Set<String> lodTensorSet_;
......@@ -147,12 +157,6 @@ public class Client {
channel_ = null;
blockingStub_ = null;
feedTypeToDataKey_ = new HashMap<Integer, String>();
feedTypeToDataKey_.put(0, "int64_data");
feedTypeToDataKey_.put(1, "float_data");
feedTypeToDataKey_.put(2, "int_data");
feedTypeToDataKey_.put(3, "data");
profiler_ = new Profiler();
boolean is_profile = false;
String FLAGS_profile_client = System.getenv("FLAGS_profile_client");
......@@ -525,7 +529,7 @@ public class Client {
jsonTensor.put("elem_type", element_type);
// 处理数据与shape
String protoDataKey = feedTypeToDataKey_.get(element_type);
String protoDataKey = ElementType.feedTypeToDataKey_.get(element_type);
// 如果是INDArray类型,先转为一维.
// 此时shape为INDArray的shape
if(objectValue instanceof INDArray){
......@@ -535,11 +539,11 @@ public class Client {
for(long dim:indarrayShape){
shape.add((int)dim);
}
if(element_type == ElementType.Int64_type.ordinal()){
if(element_type == ElementType.Int64_type){
objectValue = tempIndArray.data().asLong();
}else if(element_type == ElementType.Int32_type.ordinal()){
}else if(element_type == ElementType.Int32_type){
objectValue = tempIndArray.data().asInt();
}else if(element_type == ElementType.Float32_type.ordinal()){
}else if(element_type == ElementType.Float32_type){
objectValue = tempIndArray.data().asFloat();
}else{
throw new Exception("INDArray 类型不支持");
......@@ -564,11 +568,11 @@ public class Client {
// 此时无法获取batch信息,故对shape不处理
// 由于Proto中为Repeated,需要把数据包装成list
if(objectValue instanceof String){
if(feedTypes_.get(protoDataKey)!= ElementType.Bytes_type.ordinal()){
if(feedTypes_.get(protoDataKey)!= ElementType.String_type){
throw new Exception("feedvar is not string-type,feed can`t be a single string.");
}
}else{
if(feedTypes_.get(protoDataKey)== ElementType.Bytes_type.ordinal()){
if(feedTypes_.get(protoDataKey)== ElementType.String_type){
throw new Exception("feedvar is string-type,feed, feed can`t be a single int or others.");
}
}
......@@ -662,17 +666,17 @@ public class Client {
for(long dim:indarrayShape){
shape.add((int)dim);
}
if(element_type == ElementType.Int64_type.ordinal()){
if(element_type == ElementType.Int64_type){
List<Long> iter = Arrays.stream(tempIndArray.data().asLong()).boxed().collect(Collectors.toList());
tensor_builder.addAllInt64Data(iter);
}else if(element_type == ElementType.Int32_type.ordinal()){
}else if(element_type == ElementType.Int32_type){
List<Integer> iter = Arrays.stream(tempIndArray.data().asInt()).boxed().collect(Collectors.toList());
tensor_builder.addAllIntData(iter);
}else if(element_type == ElementType.Float32_type.ordinal()){
}else if(element_type == ElementType.Float32_type){
List<Float> iter = Arrays.asList(ArrayUtils.toObject(tempIndArray.data().asFloat()));
tensor_builder.addAllFloatData(iter);
......@@ -684,13 +688,13 @@ public class Client {
// 如果是数组类型,则无须处理,直接使用即可。
// 且数组无法嵌套,此时batch无法从数据中获取
// 默认batch维度为1,或者feedVar的shape信息中已包含batch
if(element_type == ElementType.Int64_type.ordinal()){
if(element_type == ElementType.Int64_type){
List<Long> iter = Arrays.stream((long[])objectValue).boxed().collect(Collectors.toList());
tensor_builder.addAllInt64Data(iter);
}else if(element_type == ElementType.Int32_type.ordinal()){
}else if(element_type == ElementType.Int32_type){
List<Integer> iter = Arrays.stream((int[])objectValue).boxed().collect(Collectors.toList());
tensor_builder.addAllIntData(iter);
}else if(element_type == ElementType.Float32_type.ordinal()){
}else if(element_type == ElementType.Float32_type){
List<Float> iter = Arrays.asList(ArrayUtils.toObject((float[])objectValue));
tensor_builder.addAllFloatData(iter);
}else{
......@@ -707,11 +711,11 @@ public class Client {
// 在index=0处,加上batch
shape.add(0, list.size());
}
if(element_type == ElementType.Int64_type.ordinal()){
if(element_type == ElementType.Int64_type){
tensor_builder.addAllInt64Data((List<Long>)(List)recursiveExtract(objectValue));
}else if(element_type == ElementType.Int32_type.ordinal()){
}else if(element_type == ElementType.Int32_type){
tensor_builder.addAllIntData((List<Integer>)(List)recursiveExtract(objectValue));
}else if(element_type == ElementType.Float32_type.ordinal()){
}else if(element_type == ElementType.Float32_type){
tensor_builder.addAllFloatData((List<Float>)(List)recursiveExtract(objectValue));
}else{
// 看接口是String还是Bytes
......@@ -723,11 +727,11 @@ public class Client {
// 由于Proto中为Repeated,需要把数据包装成list
List<Object> tempList = new ArrayList<>();
tempList.add(objectValue);
if(element_type == ElementType.Int64_type.ordinal()){
if(element_type == ElementType.Int64_type){
tensor_builder.addAllInt64Data((List<Long>)(List)tempList);
}else if(element_type == ElementType.Int32_type.ordinal()){
}else if(element_type == ElementType.Int32_type){
tensor_builder.addAllIntData((List<Integer>)(List)tempList);
}else if(element_type == ElementType.Float32_type.ordinal()){
}else if(element_type == ElementType.Float32_type){
tensor_builder.addAllFloatData((List<Float>)(List)tempList);
}else{
// 看接口是String还是Bytes
......
......@@ -119,7 +119,7 @@ The pre-processing and post-processing is in the C + + server part, the image's
so the value of parameter `feed_var` which is in the file `ocr_det_client/serving_client_conf.prototxt` should be changed.
for this case, `feed_type` should be 3(which means the data type is string),`shape` should be 1.
for this case, `feed_type` should be 20(which means the data type is string),`shape` should be 1.
By passing in multiple client folder paths, the client can be started for multi model prediction.
```
......
......@@ -118,7 +118,7 @@ python3 -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --por
`ocr_det_client/serving_client_conf.prototxt``feed_var`字段
对于本示例而言,`feed_type`应修改为3(数据类型为string),`shape`为1.
对于本示例而言,`feed_type`应修改为20(数据类型为string),`shape`为1.
通过在客户端启动后加入多个client模型的client配置文件夹路径,启动client进行预测。
```
......
......@@ -38,7 +38,12 @@ float32_type = 1
int32_type = 2
bytes_type = 20
# this is corresponding to the proto
proto_data_key_list = ["int64_data", "float_data", "int_data", "data"]
proto_data_key_list = {
0: "int64_data",
1: "float_data",
2: "int_data",
20: "data"
}
def list_flatten(items, ignore_types=(str, bytes)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册