Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Serving
提交
00dd5546
S
Serving
项目概览
PaddlePaddle
/
Serving
1 年多 前同步成功
通知
186
Star
833
Fork
253
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
105
列表
看板
标记
里程碑
合并请求
10
Wiki
2
Wiki
分析
仓库
DevOps
项目成员
Pages
S
Serving
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
105
Issue
105
列表
看板
标记
里程碑
合并请求
10
合并请求
10
Pages
分析
分析
仓库分析
DevOps
Wiki
2
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
00dd5546
编写于
8月 03, 2021
作者:
H
HexToString
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update java
上级
324f4196
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
276 addition
and
94 deletion
+276
-94
java/examples/src/main/java/PaddleServingClientExample.java
java/examples/src/main/java/PaddleServingClientExample.java
+82
-17
java/src/main/java/io/paddle/serving/client/HttpClient.java
java/src/main/java/io/paddle/serving/client/HttpClient.java
+124
-35
python/examples/fit_a_line/test_httpclient.py
python/examples/fit_a_line/test_httpclient.py
+1
-1
python/paddle_serving_client/httpclient.py
python/paddle_serving_client/httpclient.py
+69
-41
未找到文件。
java/examples/src/main/java/PaddleServingClientExample.java
浏览文件 @
00dd5546
...
@@ -11,7 +11,7 @@ import org.nd4j.linalg.factory.Nd4j;
...
@@ -11,7 +11,7 @@ import org.nd4j.linalg.factory.Nd4j;
import
java.util.*
;
import
java.util.*
;
public
class
PaddleServingClientExample
{
public
class
PaddleServingClientExample
{
boolean
fit_a_line
()
{
boolean
fit_a_line
(
String
model_config_path
)
{
float
[]
data
=
{
0.0137f
,
-
0.1136f
,
0.2553f
,
-
0.0692f
,
float
[]
data
=
{
0.0137f
,
-
0.1136f
,
0.2553f
,
-
0.0692f
,
0.0582f
,
-
0.0727f
,
-
0.1583f
,
-
0.0584f
,
0.0582f
,
-
0.0727f
,
-
0.1583f
,
-
0.0584f
,
0.6283f
,
0.4919f
,
0.1856f
,
0.0795f
,
-
0.0332f
};
0.6283f
,
0.4919f
,
0.1856f
,
0.0795f
,
-
0.0332f
};
...
@@ -25,15 +25,69 @@ public class PaddleServingClientExample {
...
@@ -25,15 +25,69 @@ public class PaddleServingClientExample {
List
<
String
>
fetch
=
Arrays
.
asList
(
"price"
);
List
<
String
>
fetch
=
Arrays
.
asList
(
"price"
);
HttpClient
client
=
new
HttpClient
();
HttpClient
client
=
new
HttpClient
();
client
.
setIP
(
"
172.17.0.2
"
);
client
.
setIP
(
"
0.0.0.0
"
);
client
.
setPort
(
"9393"
);
client
.
setPort
(
"9393"
);
client
.
loadClientConfig
(
model_config_path
);
String
result
=
client
.
predict
(
feed_data
,
fetch
,
true
,
0
);
String
result
=
client
.
predict
(
feed_data
,
fetch
,
true
,
0
);
System
.
out
.
println
(
result
);
System
.
out
.
println
(
result
);
return
true
;
return
true
;
}
}
boolean
yolov4
(
String
filename
)
{
boolean
encrypt
(
String
model_config_path
,
String
keyFilePath
)
{
float
[]
data
=
{
0.0137f
,
-
0.1136f
,
0.2553f
,
-
0.0692f
,
0.0582f
,
-
0.0727f
,
-
0.1583f
,
-
0.0584f
,
0.6283f
,
0.4919f
,
0.1856f
,
0.0795f
,
-
0.0332f
};
INDArray
npdata
=
Nd4j
.
createFromArray
(
data
);
long
[]
batch_shape
=
{
1
,
13
};
INDArray
batch_npdata
=
npdata
.
reshape
(
batch_shape
);
HashMap
<
String
,
Object
>
feed_data
=
new
HashMap
<
String
,
Object
>()
{{
put
(
"x"
,
batch_npdata
);
}};
List
<
String
>
fetch
=
Arrays
.
asList
(
"price"
);
HttpClient
client
=
new
HttpClient
();
client
.
setIP
(
"0.0.0.0"
);
client
.
setPort
(
"9393"
);
client
.
loadClientConfig
(
model_config_path
);
client
.
use_key
(
keyFilePath
);
try
{
Thread
.
sleep
(
1000
*
3
);
// 休眠3秒,等待Server启动
}
catch
(
Exception
e
)
{
//TODO: handle exception
}
String
result
=
client
.
predict
(
feed_data
,
fetch
,
true
,
0
);
System
.
out
.
println
(
result
);
return
true
;
}
boolean
compress
(
String
model_config_path
)
{
float
[]
data
=
{
0.0137f
,
-
0.1136f
,
0.2553f
,
-
0.0692f
,
0.0582f
,
-
0.0727f
,
-
0.1583f
,
-
0.0584f
,
0.6283f
,
0.4919f
,
0.1856f
,
0.0795f
,
-
0.0332f
};
INDArray
npdata
=
Nd4j
.
createFromArray
(
data
);
long
[]
batch_shape
=
{
500
,
13
};
INDArray
batch_npdata
=
npdata
.
broadcast
(
batch_shape
);
HashMap
<
String
,
Object
>
feed_data
=
new
HashMap
<
String
,
Object
>()
{{
put
(
"x"
,
batch_npdata
);
}};
List
<
String
>
fetch
=
Arrays
.
asList
(
"price"
);
HttpClient
client
=
new
HttpClient
();
client
.
setIP
(
"0.0.0.0"
);
client
.
setPort
(
"9393"
);
client
.
loadClientConfig
(
model_config_path
);
client
.
set_request_compress
(
true
);
client
.
set_response_compress
(
true
);
String
result
=
client
.
predict
(
feed_data
,
fetch
,
true
,
0
);
System
.
out
.
println
(
result
);
return
true
;
}
boolean
yolov4
(
String
model_config_path
,
String
filename
)
{
// https://deeplearning4j.konduit.ai/
// https://deeplearning4j.konduit.ai/
int
height
=
608
;
int
height
=
608
;
int
width
=
608
;
int
width
=
608
;
...
@@ -74,14 +128,15 @@ public class PaddleServingClientExample {
...
@@ -74,14 +128,15 @@ public class PaddleServingClientExample {
}};
}};
List
<
String
>
fetch
=
Arrays
.
asList
(
"save_infer_model/scale_0.tmp_0"
);
List
<
String
>
fetch
=
Arrays
.
asList
(
"save_infer_model/scale_0.tmp_0"
);
HttpClient
client
=
new
HttpClient
();
HttpClient
client
=
new
HttpClient
();
client
.
setIP
(
"
172.17.0.2
"
);
client
.
setIP
(
"
0.0.0.0
"
);
client
.
setPort
(
"9393"
);
client
.
setPort
(
"9393"
);
client
.
loadClientConfig
(
model_config_path
);
String
result
=
client
.
predict
(
feed_data
,
fetch
,
true
,
0
);
String
result
=
client
.
predict
(
feed_data
,
fetch
,
true
,
0
);
System
.
out
.
println
(
result
);
System
.
out
.
println
(
result
);
return
true
;
return
true
;
}
}
boolean
bert
()
{
boolean
bert
(
String
model_config_path
)
{
float
[]
input_mask
=
{
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
};
float
[]
input_mask
=
{
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
};
long
[]
position_ids
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
long
[]
position_ids
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
long
[]
input_ids
=
{
101
,
6843
,
3241
,
749
,
8024
,
7662
,
2533
,
1391
,
2533
,
2523
,
7676
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
long
[]
input_ids
=
{
101
,
6843
,
3241
,
749
,
8024
,
7662
,
2533
,
1391
,
2533
,
2523
,
7676
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
...
@@ -95,14 +150,15 @@ public class PaddleServingClientExample {
...
@@ -95,14 +150,15 @@ public class PaddleServingClientExample {
}};
}};
List
<
String
>
fetch
=
Arrays
.
asList
(
"pooled_output"
);
List
<
String
>
fetch
=
Arrays
.
asList
(
"pooled_output"
);
HttpClient
client
=
new
HttpClient
();
HttpClient
client
=
new
HttpClient
();
client
.
setIP
(
"
172.17.0.2
"
);
client
.
setIP
(
"
0.0.0.0
"
);
client
.
setPort
(
"9393"
);
client
.
setPort
(
"9393"
);
client
.
loadClientConfig
(
model_config_path
);
String
result
=
client
.
predict
(
feed_data
,
fetch
,
true
,
0
);
String
result
=
client
.
predict
(
feed_data
,
fetch
,
true
,
0
);
System
.
out
.
println
(
result
);
System
.
out
.
println
(
result
);
return
true
;
return
true
;
}
}
boolean
cube_local
()
{
boolean
cube_local
(
String
model_config_path
)
{
long
[]
embedding_14
=
{
250644
};
long
[]
embedding_14
=
{
250644
};
long
[]
embedding_2
=
{
890346
};
long
[]
embedding_2
=
{
890346
};
long
[]
embedding_10
=
{
3939
};
long
[]
embedding_10
=
{
3939
};
...
@@ -164,8 +220,9 @@ public class PaddleServingClientExample {
...
@@ -164,8 +220,9 @@ public class PaddleServingClientExample {
}};
}};
List
<
String
>
fetch
=
Arrays
.
asList
(
"prob"
);
List
<
String
>
fetch
=
Arrays
.
asList
(
"prob"
);
HttpClient
client
=
new
HttpClient
();
HttpClient
client
=
new
HttpClient
();
client
.
setIP
(
"
172.17.0.2
"
);
client
.
setIP
(
"
0.0.0.0
"
);
client
.
setPort
(
"9393"
);
client
.
setPort
(
"9393"
);
client
.
loadClientConfig
(
model_config_path
);
String
result
=
client
.
predict
(
feed_data
,
fetch
,
true
,
0
);
String
result
=
client
.
predict
(
feed_data
,
fetch
,
true
,
0
);
System
.
out
.
println
(
result
);
System
.
out
.
println
(
result
);
return
true
;
return
true
;
...
@@ -177,25 +234,33 @@ public class PaddleServingClientExample {
...
@@ -177,25 +234,33 @@ public class PaddleServingClientExample {
PaddleServingClientExample
e
=
new
PaddleServingClientExample
();
PaddleServingClientExample
e
=
new
PaddleServingClientExample
();
boolean
succ
=
false
;
boolean
succ
=
false
;
if
(
args
.
length
<
1
)
{
if
(
args
.
length
<
2
)
{
System
.
out
.
println
(
"Usage: java -cp <jar> PaddleServingClientExample <test-type>."
);
System
.
out
.
println
(
"Usage: java -cp <jar> PaddleServingClientExample <test-type>
<configPath>
."
);
System
.
out
.
println
(
"<test-type>: fit_a_line bert cube_local yolov4"
);
System
.
out
.
println
(
"<test-type>: fit_a_line bert cube_local yolov4
encrypt
"
);
return
;
return
;
}
}
String
testType
=
args
[
0
];
String
testType
=
args
[
0
];
System
.
out
.
format
(
"[Example] %s\n"
,
testType
);
System
.
out
.
format
(
"[Example] %s\n"
,
testType
);
if
(
"fit_a_line"
.
equals
(
testType
))
{
if
(
"fit_a_line"
.
equals
(
testType
))
{
succ
=
e
.
fit_a_line
();
succ
=
e
.
fit_a_line
(
args
[
1
]);
}
else
if
(
"compress"
.
equals
(
testType
))
{
succ
=
e
.
compress
(
args
[
1
]);
}
else
if
(
"bert"
.
equals
(
testType
))
{
}
else
if
(
"bert"
.
equals
(
testType
))
{
succ
=
e
.
bert
();
succ
=
e
.
bert
(
args
[
1
]
);
}
else
if
(
"cube_local"
.
equals
(
testType
))
{
}
else
if
(
"cube_local"
.
equals
(
testType
))
{
succ
=
e
.
cube_local
();
succ
=
e
.
cube_local
(
args
[
1
]
);
}
else
if
(
"yolov4"
.
equals
(
testType
))
{
}
else
if
(
"yolov4"
.
equals
(
testType
))
{
if
(
args
.
length
<
2
)
{
if
(
args
.
length
<
3
)
{
System
.
out
.
println
(
"Usage: java -cp <jar> PaddleServingClientExample yolov4 <image-filepath>."
);
System
.
out
.
println
(
"Usage: java -cp <jar> PaddleServingClientExample yolov4 <configPath> <image-filepath>."
);
return
;
}
succ
=
e
.
yolov4
(
args
[
1
],
args
[
2
]);
}
else
if
(
"encrypt"
.
equals
(
testType
))
{
if
(
args
.
length
<
3
)
{
System
.
out
.
println
(
"Usage: java -cp <jar> PaddleServingClientExample encrypt <configPath> <keyPath>."
);
return
;
return
;
}
}
succ
=
e
.
yolov4
(
args
[
1
]);
succ
=
e
.
encrypt
(
args
[
1
],
args
[
2
]);
}
else
{
}
else
{
System
.
out
.
format
(
"test-type(%s) not match.\n"
,
testType
);
System
.
out
.
format
(
"test-type(%s) not match.\n"
,
testType
);
return
;
return
;
...
...
java/src/main/java/io/paddle/serving/client/HttpClient.java
浏览文件 @
00dd5546
...
@@ -28,6 +28,7 @@ import org.apache.http.impl.client.CloseableHttpClient;
...
@@ -28,6 +28,7 @@ import org.apache.http.impl.client.CloseableHttpClient;
import
org.apache.http.impl.client.HttpClients
;
import
org.apache.http.impl.client.HttpClients
;
import
org.apache.http.message.BasicNameValuePair
;
import
org.apache.http.message.BasicNameValuePair
;
import
org.apache.http.util.EntityUtils
;
import
org.apache.http.util.EntityUtils
;
import
org.apache.http.entity.InputStreamEntity
;
import
org.json.*
;
import
org.json.*
;
...
@@ -97,6 +98,7 @@ public class HttpClient {
...
@@ -97,6 +98,7 @@ public class HttpClient {
private
String
serviceName
;
private
String
serviceName
;
private
boolean
request_compress_flag
;
private
boolean
request_compress_flag
;
private
boolean
response_compress_flag
;
private
boolean
response_compress_flag
;
private
String
GLOG_v
;
public
HttpClient
()
{
public
HttpClient
()
{
feedNames_
=
null
;
feedNames_
=
null
;
...
@@ -115,6 +117,7 @@ public class HttpClient {
...
@@ -115,6 +117,7 @@ public class HttpClient {
serviceName
=
"/GeneralModelService/inference"
;
serviceName
=
"/GeneralModelService/inference"
;
request_compress_flag
=
false
;
request_compress_flag
=
false
;
response_compress_flag
=
false
;
response_compress_flag
=
false
;
GLOG_v
=
System
.
getenv
(
"GLOG_v"
);
feedTypeToDataKey_
=
new
HashMap
<
Integer
,
String
>();
feedTypeToDataKey_
=
new
HashMap
<
Integer
,
String
>();
feedTypeToDataKey_
.
put
(
0
,
"int64_data"
);
feedTypeToDataKey_
.
put
(
0
,
"int64_data"
);
...
@@ -206,7 +209,7 @@ public class HttpClient {
...
@@ -206,7 +209,7 @@ public class HttpClient {
String
encrypt_url
=
"http://"
+
this
.
ip
+
":"
+
this
.
port
;
String
encrypt_url
=
"http://"
+
this
.
ip
+
":"
+
this
.
port
;
try
{
try
{
byte
[]
data
=
Files
.
readAllBytes
(
Paths
.
get
(
keyFilePath
));
byte
[]
data
=
Files
.
readAllBytes
(
Paths
.
get
(
keyFilePath
));
key_str
=
new
String
(
data
,
"utf-8"
);
key_str
=
Base64
.
getEncoder
().
encodeToString
(
data
);
}
catch
(
Exception
e
)
{
}
catch
(
Exception
e
)
{
System
.
out
.
format
(
"Open key file failed: %s\n"
,
e
.
toString
());
System
.
out
.
format
(
"Open key file failed: %s\n"
,
e
.
toString
());
}
}
...
@@ -237,16 +240,20 @@ public class HttpClient {
...
@@ -237,16 +240,20 @@ public class HttpClient {
this
.
response_compress_flag
=
response_compress_flag
;
this
.
response_compress_flag
=
response_compress_flag
;
}
}
public
static
String
compress
(
String
str
,
String
inEncoding
)
throws
IOException
{
public
byte
[]
compress
(
String
str
)
{
if
(
str
==
null
||
str
.
length
()
==
0
)
{
if
(
str
==
null
||
str
.
length
()
==
0
)
{
return
str
;
return
null
;
}
}
ByteArrayOutputStream
out
=
new
ByteArrayOutputStream
();
ByteArrayOutputStream
out
=
new
ByteArrayOutputStream
();
GZIPOutputStream
gzip
=
new
GZIPOutputStream
(
out
);
GZIPOutputStream
gzip
;
gzip
.
write
(
str
.
getBytes
(
inEncoding
));
try
{
gzip
.
close
();
gzip
=
new
GZIPOutputStream
(
out
);
return
out
.
toString
(
"ISO-8859-1"
);
gzip
.
write
(
str
.
getBytes
(
"UTF-8"
));
gzip
.
close
();
}
catch
(
Exception
e
)
{
e
.
printStackTrace
();
}
return
out
.
toByteArray
();
}
}
// 帮助用户封装Http请求的接口,用户只需要传递FeedData,Lod,Fetchlist即可。
// 帮助用户封装Http请求的接口,用户只需要传递FeedData,Lod,Fetchlist即可。
...
@@ -302,36 +309,83 @@ public class HttpClient {
...
@@ -302,36 +309,83 @@ public class HttpClient {
Object
objectValue
=
mapEntry
.
getValue
();
Object
objectValue
=
mapEntry
.
getValue
();
String
feed_alias_name
=
mapEntry
.
getKey
();
String
feed_alias_name
=
mapEntry
.
getKey
();
String
feed_real_name
=
feedRealNames_
.
get
(
feed_alias_name
);
String
feed_real_name
=
feedRealNames_
.
get
(
feed_alias_name
);
List
<
Integer
>
shape
=
feedShapes_
.
get
(
feed_alias_name
);
List
<
Integer
>
shape
=
new
ArrayList
<
Integer
>(
feedShapes_
.
get
(
feed_alias_name
)
);
int
element_type
=
feedTypes_
.
get
(
feed_alias_name
);
int
element_type
=
feedTypes_
.
get
(
feed_alias_name
);
jsonTensor
.
put
(
"alias_name"
,
feed_alias_name
);
jsonTensor
.
put
(
"name"
,
feed_real_name
);
jsonTensor
.
put
(
"elem_type"
,
element_type
);
// 处理数据与shape
String
protoDataKey
=
feedTypeToDataKey_
.
get
(
element_type
);
String
protoDataKey
=
feedTypeToDataKey_
.
get
(
element_type
);
Object
feedLodValue
=
feedLod
.
get
(
feed_alias_name
);
// 如果是INDArray类型,先转为一维.
// 如果是INDArray类型,先转为一维,再objectValue.ToString.
// 此时shape为INDArray的shape
// 如果是String或List,则直接objectValue.ToString.
if
(
objectValue
instanceof
INDArray
){
if
(
objectValue
.
getClass
().
equals
(
INDArray
.
class
)){
INDArray
tempIndArray
=
(
INDArray
)
objectValue
;
long
[]
flattened_shape
=
{-
1
};
long
[]
indarrayShape
=
tempIndArray
.
shape
();
Class
<?>
classLongArray
=
flattened_shape
.
getClass
();
Method
methodReshape
=
mapEntry
.
getValue
().
getClass
().
getMethod
(
"reshape"
,
classLongArray
);
Method
methodShape
=
mapEntry
.
getValue
().
getClass
().
getMethod
(
"shape"
);
long
[]
indarrayShape
=
(
long
[])
methodShape
.
invoke
(
objectValue
);
shape
.
clear
();
shape
.
clear
();
for
(
long
dim:
indarrayShape
){
for
(
long
dim:
indarrayShape
){
shape
.
add
((
int
)
dim
);
shape
.
add
((
int
)
dim
);
}
}
objectValue
=
methodReshape
.
invoke
(
objectValue
,
flattened_shape
);
objectValue
=
tempIndArray
.
data
().
asDouble
();
}
else
if
(
objectValue
.
getClass
().
isArray
()){
// 如果是数组类型,则无须处理,直接使用即可。
// 且数组无法嵌套,此时batch无法从数据中获取
// 默认batch维度为1,或者feedVar的shape信息中已包含batch
}
else
if
(
objectValue
instanceof
List
){
// 如果为list,可能存在嵌套,此时需要展平
// 如果batchFlag为True,则认为是嵌套list
// 此时取最外层为batch的维度
if
(
batchFlag
)
{
List
<?>
list
=
new
ArrayList
<>();
list
=
new
ArrayList
<>((
Collection
<?>)
objectValue
);
// 在index=0处,加上batch
shape
.
add
(
0
,
list
.
size
());
}
objectValue
=
recursiveExtract
(
objectValue
);
}
else
{
// 此时认为是传入的单个String或者Int等
// 此时无法获取batch信息,故对shape不处理
// 由于Proto中为Repeated,需要把数据包装成list
if
(
objectValue
instanceof
String
){
if
(
feedTypes_
.
get
(
protoDataKey
)!=
ElementType
.
Bytes_type
.
ordinal
()){
throw
new
Exception
(
"feedvar is not string-type,feed can`t be a single string."
);
}
}
else
{
if
(
feedTypes_
.
get
(
protoDataKey
)==
ElementType
.
Bytes_type
.
ordinal
()){
throw
new
Exception
(
"feedvar is string-type,feed, feed can`t be a single int or others."
);
}
}
List
<
Object
>
list
=
new
ArrayList
<>();
list
.
add
(
objectValue
);
objectValue
=
list
;
}
}
if
(
batchFlag
){
jsonTensor
.
put
(
protoDataKey
,
objectValue
);
if
(!
batchFlag
){
// 在index=0处,加上batch=1
// 在index=0处,加上batch=1
shape
.
add
(
0
,
1
);
shape
.
add
(
0
,
1
);
}
}
jsonTensor
.
put
(
"alias_name"
,
feed_alias_name
);
jsonTensor
.
put
(
"name"
,
feed_real_name
);
jsonTensor
.
put
(
"shape"
,
shape
);
jsonTensor
.
put
(
"shape"
,
shape
);
jsonTensor
.
put
(
"elem_type"
,
element_type
);
// 处理lod信息,支持INDArray Array Iterable
jsonTensor
.
put
(
protoDataKey
,
objectValue
);
Object
feedLodValue
=
null
;
if
(
feedLodValue
!=
null
)
{
if
(
feedLod
!=
null
){
jsonTensor
.
put
(
"lod"
,
feedLodValue
);
feedLodValue
=
feedLod
.
get
(
feed_alias_name
);
if
(
feedLodValue
!=
null
)
{
if
(
feedLodValue
instanceof
INDArray
){
INDArray
tempIndArray
=
(
INDArray
)
feedLodValue
;
feedLodValue
=
tempIndArray
.
data
().
asInt
();
}
else
if
(
feedLodValue
.
getClass
().
isArray
()){
// 如果是数组类型,则无须处理,直接使用即可。
}
else
if
(
feedLodValue
instanceof
Iterable
){
// 如果为list,可能存在嵌套,此时需要展平
feedLodValue
=
recursiveExtract
(
feedLodValue
);
}
else
{
throw
new
Exception
(
"Lod must be INDArray or Array or Iterable."
);
}
jsonTensor
.
put
(
"lod"
,
feedLodValue
);
}
}
}
jsonTensorArray
.
put
(
jsonTensor
);
jsonTensorArray
.
put
(
jsonTensor
);
}
}
...
@@ -343,6 +397,9 @@ public class HttpClient {
...
@@ -343,6 +397,9 @@ public class HttpClient {
jsonRequest
.
put
(
"log_id"
,
log_id
);
jsonRequest
.
put
(
"log_id"
,
log_id
);
jsonRequest
.
put
(
"fetch_var_names"
,
jsonFetchList
);
jsonRequest
.
put
(
"fetch_var_names"
,
jsonFetchList
);
jsonRequest
.
put
(
"tensor"
,
jsonTensorArray
);
jsonRequest
.
put
(
"tensor"
,
jsonTensorArray
);
if
(
GLOG_v
!=
null
){
System
.
out
.
format
(
"------- Final jsonRequest: %s\n"
,
jsonRequest
.
toString
());
}
return
doPost
(
server_url
,
jsonRequest
.
toString
());
return
doPost
(
server_url
,
jsonRequest
.
toString
());
}
}
...
@@ -361,29 +418,41 @@ public class HttpClient {
...
@@ -361,29 +418,41 @@ public class HttpClient {
.
build
();
.
build
();
// 为httpPost实例设置配置
// 为httpPost实例设置配置
httpPost
.
setConfig
(
requestConfig
);
httpPost
.
setConfig
(
requestConfig
);
httpPost
.
setHeader
(
"Content-Type"
,
"application/json;charset=utf-8"
);
// 设置请求头
// 设置请求头
httpPost
.
addHeader
(
"Content-Type"
,
"application/json"
);
if
(
response_compress_flag
){
if
(
response_compress_flag
){
httpPost
.
addHeader
(
"Accept-encoding"
,
"gzip"
);
httpPost
.
addHeader
(
"Accept-encoding"
,
"gzip"
);
}
if
(
GLOG_v
!=
null
){
if
(
request_compress_flag
&&
strPostData
.
length
()>
512
){
System
.
out
.
format
(
"------- Accept-encoding gzip: \n"
);
try
{
strPostData
=
compress
(
strPostData
,
"UTF-8"
);
httpPost
.
addHeader
(
"Content-Encoding"
,
"gzip"
);
}
catch
(
IOException
e
)
{
e
.
printStackTrace
();
}
}
}
}
try
{
try
{
httpPost
.
setEntity
(
new
StringEntity
(
strPostData
,
"UTF-8"
));
if
(
request_compress_flag
&&
strPostData
.
length
()>
1024
){
try
{
byte
[]
gzipEncrypt
=
compress
(
strPostData
);
httpPost
.
setEntity
(
new
InputStreamEntity
(
new
ByteArrayInputStream
(
gzipEncrypt
),
gzipEncrypt
.
length
));
httpPost
.
addHeader
(
"Content-Encoding"
,
"gzip"
);
}
catch
(
Exception
e
)
{
e
.
printStackTrace
();
}
}
else
{
httpPost
.
setEntity
(
new
StringEntity
(
strPostData
,
"UTF-8"
));
}
// httpClient对象执行post请求,并返回响应参数对象
// httpClient对象执行post请求,并返回响应参数对象
httpResponse
=
httpClient
.
execute
(
httpPost
);
httpResponse
=
httpClient
.
execute
(
httpPost
);
// 从响应对象中获取响应内容
// 从响应对象中获取响应内容
HttpEntity
entity
=
httpResponse
.
getEntity
();
HttpEntity
entity
=
httpResponse
.
getEntity
();
Header
header
=
entity
.
getContentEncoding
();
Header
header
=
entity
.
getContentEncoding
();
if
(
GLOG_v
!=
null
){
System
.
out
.
format
(
"------- response header: %s\n"
,
header
);
}
if
(
header
!=
null
&&
header
.
getValue
().
equalsIgnoreCase
(
"gzip"
)){
//判断返回内容是否为gzip压缩格式
if
(
header
!=
null
&&
header
.
getValue
().
equalsIgnoreCase
(
"gzip"
)){
//判断返回内容是否为gzip压缩格式
GzipDecompressingEntity
gzipEntity
=
new
GzipDecompressingEntity
(
entity
);
GzipDecompressingEntity
gzipEntity
=
new
GzipDecompressingEntity
(
entity
);
result
=
EntityUtils
.
toString
(
gzipEntity
);
result
=
EntityUtils
.
toString
(
gzipEntity
);
if
(
GLOG_v
!=
null
){
System
.
out
.
format
(
"------- degzip response: %s\n"
,
result
);
}
}
else
{
}
else
{
result
=
EntityUtils
.
toString
(
entity
);
result
=
EntityUtils
.
toString
(
entity
);
}
}
...
@@ -410,5 +479,25 @@ public class HttpClient {
...
@@ -410,5 +479,25 @@ public class HttpClient {
}
}
return
result
;
return
result
;
}
}
public
List
<
Object
>
recursiveExtract
(
Object
stuff
)
{
List
<
Object
>
mylist
=
new
ArrayList
<
Object
>();
if
(
stuff
instanceof
Iterable
)
{
for
(
Object
o
:
(
Iterable
<
?
>)
stuff
)
{
mylist
.
addAll
(
recursiveExtract
(
o
));
}
}
else
if
(
stuff
instanceof
Map
)
{
for
(
Object
o
:
((
Map
<?,
?
extends
Object
>)
stuff
).
values
())
{
mylist
.
addAll
(
recursiveExtract
(
o
));
}
}
else
{
mylist
.
add
(
stuff
);
}
return
mylist
;
}
}
}
python/examples/fit_a_line/test_httpclient.py
浏览文件 @
00dd5546
...
@@ -21,7 +21,7 @@ import time
...
@@ -21,7 +21,7 @@ import time
client
=
HttpClient
()
client
=
HttpClient
()
client
.
load_client_config
(
sys
.
argv
[
1
])
client
.
load_client_config
(
sys
.
argv
[
1
])
# if you want to enable Encrypt Module,uncommenting the following line
# if you want to enable Encrypt Module,uncommenting the following line
#
client.use_key("./key")
client
.
use_key
(
"./key"
)
client
.
set_response_compress
(
True
)
client
.
set_response_compress
(
True
)
client
.
set_request_compress
(
True
)
client
.
set_request_compress
(
True
)
fetch_list
=
client
.
get_fetch_names
()
fetch_list
=
client
.
get_fetch_names
()
...
...
python/paddle_serving_client/httpclient.py
浏览文件 @
00dd5546
...
@@ -42,6 +42,23 @@ def list_flatten(items, ignore_types=(str, bytes)):
...
@@ -42,6 +42,23 @@ def list_flatten(items, ignore_types=(str, bytes)):
yield
x
yield
x
def
data_bytes_number
(
datalist
):
total_bytes_number
=
0
if
isinstance
(
datalist
,
list
):
if
len
(
datalist
)
==
0
:
return
total_bytes_number
else
:
for
data
in
datalist
:
if
isinstance
(
data
,
str
):
total_bytes_number
=
total_bytes_number
+
len
(
data
)
else
:
total_bytes_number
=
total_bytes_number
+
4
*
len
(
datalist
)
break
else
:
raise
ValueError
(
"In the Function data_bytes_number(), data must be list."
)
class
HttpClient
(
object
):
class
HttpClient
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
ip
=
"0.0.0.0"
,
ip
=
"0.0.0.0"
,
...
@@ -157,7 +174,8 @@ class HttpClient(object):
...
@@ -157,7 +174,8 @@ class HttpClient(object):
def
get_fetch_names
(
self
):
def
get_fetch_names
(
self
):
return
self
.
fetch_names_
return
self
.
fetch_names_
# feed 支持Numpy类型,Json-String,以及直接List、tuple
# feed 支持Numpy类型,以及直接List、tuple
# 不支持str类型,因为proto中为repeated.
def
predict
(
self
,
def
predict
(
self
,
feed
=
None
,
feed
=
None
,
fetch
=
None
,
fetch
=
None
,
...
@@ -179,7 +197,7 @@ class HttpClient(object):
...
@@ -179,7 +197,7 @@ class HttpClient(object):
if
isinstance
(
feed
,
dict
):
if
isinstance
(
feed
,
dict
):
feed_batch
.
append
(
feed
)
feed_batch
.
append
(
feed
)
elif
isinstance
(
feed
,
(
list
,
str
,
tuple
)):
elif
isinstance
(
feed
,
(
list
,
str
,
tuple
)):
# if input is a list or str, and the number of feed_var is 1.
# if input is a list or str
or tuple
, and the number of feed_var is 1.
# create a temp_dict { key = feed_var_name, value = list}
# create a temp_dict { key = feed_var_name, value = list}
# put the temp_dict into the feed_batch.
# put the temp_dict into the feed_batch.
if
len
(
self
.
feed_names_
)
!=
1
:
if
len
(
self
.
feed_names_
)
!=
1
:
...
@@ -230,46 +248,55 @@ class HttpClient(object):
...
@@ -230,46 +248,55 @@ class HttpClient(object):
data_value
=
feed_i
[
key
]
data_value
=
feed_i
[
key
]
data_key
=
proto_data_key_list
[
elem_type
]
data_key
=
proto_data_key_list
[
elem_type
]
# 输入不是string类型
# feed_i[key] 可以是np.ndarray
if
self
.
feed_types_
[
key
]
!=
bytes_type
:
# 也可以是list或tuple
# feed_i[key] 可以是np.ndarray
# 当np.ndarray需要处理为list
# 也可以是string或list或tuple
if
isinstance
(
feed_i
[
key
],
np
.
ndarray
):
# 当np.ndarray需要处理为list
shape_lst
=
[]
if
isinstance
(
feed_i
[
key
],
np
.
ndarray
):
# 0维numpy 需要在外层再加一个[]
shape_lst
=
[]
if
feed_i
[
key
].
ndim
==
0
:
# 0维numpy 需要在外层再加一个[]
data_value
=
[
feed_i
[
key
].
tolist
()]
if
feed_i
[
key
].
ndim
==
0
:
shape_lst
.
append
(
1
)
data_value
=
[
feed_i
[
key
].
tolist
()]
else
:
shape_lst
.
append
(
1
)
shape_lst
.
extend
(
list
(
feed_i
[
key
].
shape
))
else
:
shape
=
shape_lst
shape_lst
.
extend
(
list
(
feed_i
[
key
].
shape
))
data_value
=
feed_i
[
key
].
flatten
().
tolist
()
shape
=
shape_lst
# 当Batch为False,shape字段前插一个1,表示batch维
data_value
=
feed_i
[
key
].
flatten
().
tolist
()
# 当Batch为True,则直接使用numpy.shape作为batch维度
# 当Batch为False,shape字段前插一个1,表示batch维
if
batch
==
False
:
# 当Batch为True,则直接使用numpy.shape作为batch维度
shape
.
insert
(
0
,
1
)
if
batch
==
False
:
shape
.
insert
(
0
,
1
)
# 当是list或tuple时,需要把多层嵌套展开
elif
isinstance
(
feed_i
[
key
],
(
list
,
tuple
)):
# 当是list或tuple时,需要把多层嵌套展开
# 当Batch为False,shape字段前插一个1,表示batch维
if
isinstance
(
feed_i
[
key
],
(
list
,
tuple
)):
# 当Batch为True, 由于list并不像numpy那样规整,所以
# 当Batch为False,shape字段前插一个1,表示batch维
# 无法获取shape,此时取第一维度作为Batch维度.
# 当Batch为True, 由于list并不像numpy那样规整,所以
# 插入到feedVar.shape前面.
# 无法获取shape,此时取第一维度作为Batch维度.
if
batch
==
False
:
# 插入到feedVar.shape前面.
shape
.
insert
(
0
,
1
)
if
batch
==
False
:
else
:
shape
.
insert
(
0
,
1
)
shape
.
insert
(
0
,
len
(
feed_i
[
key
]))
else
:
feed_i
[
key
]
=
[
x
for
x
in
list_flatten
(
feed_i
[
key
])]
shape
.
insert
(
0
,
len
(
feed_i
[
key
]))
feed_i
[
key
]
=
[
x
for
x
in
list_flatten
(
feed_i
[
key
])]
data_value
=
feed_i
[
key
]
'''
this is comment, for coder to understand.
#if input is string, feed is not numpy.
else:
shape = self.feed_shapes_[key]
data_value
=
feed_i
[
key
]
data_value
=
feed_i
[
key
]
'''
else
:
total_data_number
=
total_data_number
+
len
(
data_value
)
# 输入可能是单个的str或int值等
# 此时先统一处理为一个list
# 由于输入比较特殊,shape保持原feedvar中不变
data_value
=
[]
data_value
.
append
(
feed_i
[
key
])
if
isinstance
(
feed_i
[
key
],
str
):
if
self
.
feed_types_
[
key
]
!=
bytes_type
:
raise
ValueError
(
"feedvar is not string-type,feed can`t be a single string."
)
else
:
if
self
.
feed_types_
[
key
]
==
bytes_type
:
raise
ValueError
(
"feedvar is string-type,feed, feed can`t be a single int or others."
)
total_data_number
=
total_data_number
+
data_bytes_number
(
data_value
)
Request
[
"tensor"
][
index
][
"elem_type"
]
=
elem_type
Request
[
"tensor"
][
index
][
"elem_type"
]
=
elem_type
Request
[
"tensor"
][
index
][
"shape"
]
=
shape
Request
[
"tensor"
][
index
][
"shape"
]
=
shape
Request
[
"tensor"
][
index
][
data_key
]
=
data_value
Request
[
"tensor"
][
index
][
data_key
]
=
data_value
...
@@ -285,6 +312,7 @@ class HttpClient(object):
...
@@ -285,6 +312,7 @@ class HttpClient(object):
web_url
=
"http://"
+
self
.
ip
+
":"
+
self
.
server_port
+
self
.
service_name
web_url
=
"http://"
+
self
.
ip
+
":"
+
self
.
server_port
+
self
.
service_name
postData
=
json
.
dumps
(
Request
)
postData
=
json
.
dumps
(
Request
)
headers
=
{}
headers
=
{}
# 当数据区长度大于512字节时才压缩.
if
self
.
try_request_gzip
and
total_data_number
>
512
:
if
self
.
try_request_gzip
and
total_data_number
>
512
:
postData
=
gzip
.
compress
(
bytes
(
postData
,
'utf-8'
))
postData
=
gzip
.
compress
(
bytes
(
postData
,
'utf-8'
))
headers
[
"Content-Encoding"
]
=
"gzip"
headers
[
"Content-Encoding"
]
=
"gzip"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录