Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Serving
提交
00dd5546
S
Serving
项目概览
PaddlePaddle
/
Serving
1 年多 前同步成功
通知
186
Star
833
Fork
253
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
105
列表
看板
标记
里程碑
合并请求
10
Wiki
2
Wiki
分析
仓库
DevOps
项目成员
Pages
S
Serving
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
105
Issue
105
列表
看板
标记
里程碑
合并请求
10
合并请求
10
Pages
分析
分析
仓库分析
DevOps
Wiki
2
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
00dd5546
编写于
3年前
作者:
H
HexToString
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update java
上级
324f4196
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
276 addition
and
94 deletion
+276
-94
java/examples/src/main/java/PaddleServingClientExample.java
java/examples/src/main/java/PaddleServingClientExample.java
+82
-17
java/src/main/java/io/paddle/serving/client/HttpClient.java
java/src/main/java/io/paddle/serving/client/HttpClient.java
+124
-35
python/examples/fit_a_line/test_httpclient.py
python/examples/fit_a_line/test_httpclient.py
+1
-1
python/paddle_serving_client/httpclient.py
python/paddle_serving_client/httpclient.py
+69
-41
未找到文件。
java/examples/src/main/java/PaddleServingClientExample.java
浏览文件 @
00dd5546
...
...
@@ -11,7 +11,7 @@ import org.nd4j.linalg.factory.Nd4j;
import
java.util.*
;
public
class
PaddleServingClientExample
{
boolean
fit_a_line
()
{
boolean
fit_a_line
(
String
model_config_path
)
{
float
[]
data
=
{
0.0137f
,
-
0.1136f
,
0.2553f
,
-
0.0692f
,
0.0582f
,
-
0.0727f
,
-
0.1583f
,
-
0.0584f
,
0.6283f
,
0.4919f
,
0.1856f
,
0.0795f
,
-
0.0332f
};
...
...
@@ -25,15 +25,69 @@ public class PaddleServingClientExample {
List
<
String
>
fetch
=
Arrays
.
asList
(
"price"
);
HttpClient
client
=
new
HttpClient
();
client
.
setIP
(
"
172.17.0.2
"
);
client
.
setIP
(
"
0.0.0.0
"
);
client
.
setPort
(
"9393"
);
client
.
loadClientConfig
(
model_config_path
);
String
result
=
client
.
predict
(
feed_data
,
fetch
,
true
,
0
);
System
.
out
.
println
(
result
);
return
true
;
}
boolean
yolov4
(
String
filename
)
{
boolean
encrypt
(
String
model_config_path
,
String
keyFilePath
)
{
float
[]
data
=
{
0.0137f
,
-
0.1136f
,
0.2553f
,
-
0.0692f
,
0.0582f
,
-
0.0727f
,
-
0.1583f
,
-
0.0584f
,
0.6283f
,
0.4919f
,
0.1856f
,
0.0795f
,
-
0.0332f
};
INDArray
npdata
=
Nd4j
.
createFromArray
(
data
);
long
[]
batch_shape
=
{
1
,
13
};
INDArray
batch_npdata
=
npdata
.
reshape
(
batch_shape
);
HashMap
<
String
,
Object
>
feed_data
=
new
HashMap
<
String
,
Object
>()
{{
put
(
"x"
,
batch_npdata
);
}};
List
<
String
>
fetch
=
Arrays
.
asList
(
"price"
);
HttpClient
client
=
new
HttpClient
();
client
.
setIP
(
"0.0.0.0"
);
client
.
setPort
(
"9393"
);
client
.
loadClientConfig
(
model_config_path
);
client
.
use_key
(
keyFilePath
);
try
{
Thread
.
sleep
(
1000
*
3
);
// 休眠3秒,等待Server启动
}
catch
(
Exception
e
)
{
//TODO: handle exception
}
String
result
=
client
.
predict
(
feed_data
,
fetch
,
true
,
0
);
System
.
out
.
println
(
result
);
return
true
;
}
boolean
compress
(
String
model_config_path
)
{
float
[]
data
=
{
0.0137f
,
-
0.1136f
,
0.2553f
,
-
0.0692f
,
0.0582f
,
-
0.0727f
,
-
0.1583f
,
-
0.0584f
,
0.6283f
,
0.4919f
,
0.1856f
,
0.0795f
,
-
0.0332f
};
INDArray
npdata
=
Nd4j
.
createFromArray
(
data
);
long
[]
batch_shape
=
{
500
,
13
};
INDArray
batch_npdata
=
npdata
.
broadcast
(
batch_shape
);
HashMap
<
String
,
Object
>
feed_data
=
new
HashMap
<
String
,
Object
>()
{{
put
(
"x"
,
batch_npdata
);
}};
List
<
String
>
fetch
=
Arrays
.
asList
(
"price"
);
HttpClient
client
=
new
HttpClient
();
client
.
setIP
(
"0.0.0.0"
);
client
.
setPort
(
"9393"
);
client
.
loadClientConfig
(
model_config_path
);
client
.
set_request_compress
(
true
);
client
.
set_response_compress
(
true
);
String
result
=
client
.
predict
(
feed_data
,
fetch
,
true
,
0
);
System
.
out
.
println
(
result
);
return
true
;
}
boolean
yolov4
(
String
model_config_path
,
String
filename
)
{
// https://deeplearning4j.konduit.ai/
int
height
=
608
;
int
width
=
608
;
...
...
@@ -74,14 +128,15 @@ public class PaddleServingClientExample {
}};
List
<
String
>
fetch
=
Arrays
.
asList
(
"save_infer_model/scale_0.tmp_0"
);
HttpClient
client
=
new
HttpClient
();
client
.
setIP
(
"
172.17.0.2
"
);
client
.
setIP
(
"
0.0.0.0
"
);
client
.
setPort
(
"9393"
);
client
.
loadClientConfig
(
model_config_path
);
String
result
=
client
.
predict
(
feed_data
,
fetch
,
true
,
0
);
System
.
out
.
println
(
result
);
return
true
;
}
boolean
bert
()
{
boolean
bert
(
String
model_config_path
)
{
float
[]
input_mask
=
{
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
1.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
,
0.0f
};
long
[]
position_ids
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
long
[]
input_ids
=
{
101
,
6843
,
3241
,
749
,
8024
,
7662
,
2533
,
1391
,
2533
,
2523
,
7676
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
...
...
@@ -95,14 +150,15 @@ public class PaddleServingClientExample {
}};
List
<
String
>
fetch
=
Arrays
.
asList
(
"pooled_output"
);
HttpClient
client
=
new
HttpClient
();
client
.
setIP
(
"
172.17.0.2
"
);
client
.
setIP
(
"
0.0.0.0
"
);
client
.
setPort
(
"9393"
);
client
.
loadClientConfig
(
model_config_path
);
String
result
=
client
.
predict
(
feed_data
,
fetch
,
true
,
0
);
System
.
out
.
println
(
result
);
return
true
;
}
boolean
cube_local
()
{
boolean
cube_local
(
String
model_config_path
)
{
long
[]
embedding_14
=
{
250644
};
long
[]
embedding_2
=
{
890346
};
long
[]
embedding_10
=
{
3939
};
...
...
@@ -164,8 +220,9 @@ public class PaddleServingClientExample {
}};
List
<
String
>
fetch
=
Arrays
.
asList
(
"prob"
);
HttpClient
client
=
new
HttpClient
();
client
.
setIP
(
"
172.17.0.2
"
);
client
.
setIP
(
"
0.0.0.0
"
);
client
.
setPort
(
"9393"
);
client
.
loadClientConfig
(
model_config_path
);
String
result
=
client
.
predict
(
feed_data
,
fetch
,
true
,
0
);
System
.
out
.
println
(
result
);
return
true
;
...
...
@@ -177,25 +234,33 @@ public class PaddleServingClientExample {
PaddleServingClientExample
e
=
new
PaddleServingClientExample
();
boolean
succ
=
false
;
if
(
args
.
length
<
1
)
{
System
.
out
.
println
(
"Usage: java -cp <jar> PaddleServingClientExample <test-type>."
);
System
.
out
.
println
(
"<test-type>: fit_a_line bert cube_local yolov4"
);
if
(
args
.
length
<
2
)
{
System
.
out
.
println
(
"Usage: java -cp <jar> PaddleServingClientExample <test-type>
<configPath>
."
);
System
.
out
.
println
(
"<test-type>: fit_a_line bert cube_local yolov4
encrypt
"
);
return
;
}
String
testType
=
args
[
0
];
System
.
out
.
format
(
"[Example] %s\n"
,
testType
);
if
(
"fit_a_line"
.
equals
(
testType
))
{
succ
=
e
.
fit_a_line
();
succ
=
e
.
fit_a_line
(
args
[
1
]);
}
else
if
(
"compress"
.
equals
(
testType
))
{
succ
=
e
.
compress
(
args
[
1
]);
}
else
if
(
"bert"
.
equals
(
testType
))
{
succ
=
e
.
bert
();
succ
=
e
.
bert
(
args
[
1
]
);
}
else
if
(
"cube_local"
.
equals
(
testType
))
{
succ
=
e
.
cube_local
();
succ
=
e
.
cube_local
(
args
[
1
]
);
}
else
if
(
"yolov4"
.
equals
(
testType
))
{
if
(
args
.
length
<
2
)
{
System
.
out
.
println
(
"Usage: java -cp <jar> PaddleServingClientExample yolov4 <image-filepath>."
);
if
(
args
.
length
<
3
)
{
System
.
out
.
println
(
"Usage: java -cp <jar> PaddleServingClientExample yolov4 <configPath> <image-filepath>."
);
return
;
}
succ
=
e
.
yolov4
(
args
[
1
],
args
[
2
]);
}
else
if
(
"encrypt"
.
equals
(
testType
))
{
if
(
args
.
length
<
3
)
{
System
.
out
.
println
(
"Usage: java -cp <jar> PaddleServingClientExample encrypt <configPath> <keyPath>."
);
return
;
}
succ
=
e
.
yolov4
(
args
[
1
]);
succ
=
e
.
encrypt
(
args
[
1
],
args
[
2
]);
}
else
{
System
.
out
.
format
(
"test-type(%s) not match.\n"
,
testType
);
return
;
...
...
This diff is collapsed.
Click to expand it.
java/src/main/java/io/paddle/serving/client/HttpClient.java
浏览文件 @
00dd5546
...
...
@@ -28,6 +28,7 @@ import org.apache.http.impl.client.CloseableHttpClient;
import
org.apache.http.impl.client.HttpClients
;
import
org.apache.http.message.BasicNameValuePair
;
import
org.apache.http.util.EntityUtils
;
import
org.apache.http.entity.InputStreamEntity
;
import
org.json.*
;
...
...
@@ -97,6 +98,7 @@ public class HttpClient {
private
String
serviceName
;
private
boolean
request_compress_flag
;
private
boolean
response_compress_flag
;
private
String
GLOG_v
;
public
HttpClient
()
{
feedNames_
=
null
;
...
...
@@ -115,6 +117,7 @@ public class HttpClient {
serviceName
=
"/GeneralModelService/inference"
;
request_compress_flag
=
false
;
response_compress_flag
=
false
;
GLOG_v
=
System
.
getenv
(
"GLOG_v"
);
feedTypeToDataKey_
=
new
HashMap
<
Integer
,
String
>();
feedTypeToDataKey_
.
put
(
0
,
"int64_data"
);
...
...
@@ -206,7 +209,7 @@ public class HttpClient {
String
encrypt_url
=
"http://"
+
this
.
ip
+
":"
+
this
.
port
;
try
{
byte
[]
data
=
Files
.
readAllBytes
(
Paths
.
get
(
keyFilePath
));
key_str
=
new
String
(
data
,
"utf-8"
);
key_str
=
Base64
.
getEncoder
().
encodeToString
(
data
);
}
catch
(
Exception
e
)
{
System
.
out
.
format
(
"Open key file failed: %s\n"
,
e
.
toString
());
}
...
...
@@ -237,16 +240,20 @@ public class HttpClient {
this
.
response_compress_flag
=
response_compress_flag
;
}
public
static
String
compress
(
String
str
,
String
inEncoding
)
throws
IOException
{
public
byte
[]
compress
(
String
str
)
{
if
(
str
==
null
||
str
.
length
()
==
0
)
{
return
str
;
return
null
;
}
ByteArrayOutputStream
out
=
new
ByteArrayOutputStream
();
GZIPOutputStream
gzip
=
new
GZIPOutputStream
(
out
);
gzip
.
write
(
str
.
getBytes
(
inEncoding
));
GZIPOutputStream
gzip
;
try
{
gzip
=
new
GZIPOutputStream
(
out
);
gzip
.
write
(
str
.
getBytes
(
"UTF-8"
));
gzip
.
close
();
return
out
.
toString
(
"ISO-8859-1"
);
}
catch
(
Exception
e
)
{
e
.
printStackTrace
();
}
return
out
.
toByteArray
();
}
// 帮助用户封装Http请求的接口,用户只需要传递FeedData,Lod,Fetchlist即可。
...
...
@@ -302,37 +309,84 @@ public class HttpClient {
Object
objectValue
=
mapEntry
.
getValue
();
String
feed_alias_name
=
mapEntry
.
getKey
();
String
feed_real_name
=
feedRealNames_
.
get
(
feed_alias_name
);
List
<
Integer
>
shape
=
feedShapes_
.
get
(
feed_alias_name
);
List
<
Integer
>
shape
=
new
ArrayList
<
Integer
>(
feedShapes_
.
get
(
feed_alias_name
)
);
int
element_type
=
feedTypes_
.
get
(
feed_alias_name
);
jsonTensor
.
put
(
"alias_name"
,
feed_alias_name
);
jsonTensor
.
put
(
"name"
,
feed_real_name
);
jsonTensor
.
put
(
"elem_type"
,
element_type
);
// 处理数据与shape
String
protoDataKey
=
feedTypeToDataKey_
.
get
(
element_type
);
Object
feedLodValue
=
feedLod
.
get
(
feed_alias_name
);
// 如果是INDArray类型,先转为一维,再objectValue.ToString.
// 如果是String或List,则直接objectValue.ToString.
if
(
objectValue
.
getClass
().
equals
(
INDArray
.
class
)){
long
[]
flattened_shape
=
{-
1
};
Class
<?>
classLongArray
=
flattened_shape
.
getClass
();
Method
methodReshape
=
mapEntry
.
getValue
().
getClass
().
getMethod
(
"reshape"
,
classLongArray
);
Method
methodShape
=
mapEntry
.
getValue
().
getClass
().
getMethod
(
"shape"
);
long
[]
indarrayShape
=
(
long
[])
methodShape
.
invoke
(
objectValue
);
// 如果是INDArray类型,先转为一维.
// 此时shape为INDArray的shape
if
(
objectValue
instanceof
INDArray
){
INDArray
tempIndArray
=
(
INDArray
)
objectValue
;
long
[]
indarrayShape
=
tempIndArray
.
shape
();
shape
.
clear
();
for
(
long
dim:
indarrayShape
){
shape
.
add
((
int
)
dim
);
}
objectValue
=
methodReshape
.
invoke
(
objectValue
,
flattened_shape
);
objectValue
=
tempIndArray
.
data
().
asDouble
();
}
else
if
(
objectValue
.
getClass
().
isArray
()){
// 如果是数组类型,则无须处理,直接使用即可。
// 且数组无法嵌套,此时batch无法从数据中获取
// 默认batch维度为1,或者feedVar的shape信息中已包含batch
}
else
if
(
objectValue
instanceof
List
){
// 如果为list,可能存在嵌套,此时需要展平
// 如果batchFlag为True,则认为是嵌套list
// 此时取最外层为batch的维度
if
(
batchFlag
)
{
List
<?>
list
=
new
ArrayList
<>();
list
=
new
ArrayList
<>((
Collection
<?>)
objectValue
);
// 在index=0处,加上batch
shape
.
add
(
0
,
list
.
size
());
}
objectValue
=
recursiveExtract
(
objectValue
);
}
else
{
// 此时认为是传入的单个String或者Int等
// 此时无法获取batch信息,故对shape不处理
// 由于Proto中为Repeated,需要把数据包装成list
if
(
objectValue
instanceof
String
){
if
(
feedTypes_
.
get
(
protoDataKey
)!=
ElementType
.
Bytes_type
.
ordinal
()){
throw
new
Exception
(
"feedvar is not string-type,feed can`t be a single string."
);
}
if
(
batchFlag
){
}
else
{
if
(
feedTypes_
.
get
(
protoDataKey
)==
ElementType
.
Bytes_type
.
ordinal
()){
throw
new
Exception
(
"feedvar is string-type,feed, feed can`t be a single int or others."
);
}
}
List
<
Object
>
list
=
new
ArrayList
<>();
list
.
add
(
objectValue
);
objectValue
=
list
;
}
jsonTensor
.
put
(
protoDataKey
,
objectValue
);
if
(!
batchFlag
){
// 在index=0处,加上batch=1
shape
.
add
(
0
,
1
);
}
jsonTensor
.
put
(
"alias_name"
,
feed_alias_name
);
jsonTensor
.
put
(
"name"
,
feed_real_name
);
jsonTensor
.
put
(
"shape"
,
shape
);
jsonTensor
.
put
(
"elem_type"
,
element_type
);
jsonTensor
.
put
(
protoDataKey
,
objectValue
);
// 处理lod信息,支持INDArray Array Iterable
Object
feedLodValue
=
null
;
if
(
feedLod
!=
null
){
feedLodValue
=
feedLod
.
get
(
feed_alias_name
);
if
(
feedLodValue
!=
null
)
{
if
(
feedLodValue
instanceof
INDArray
){
INDArray
tempIndArray
=
(
INDArray
)
feedLodValue
;
feedLodValue
=
tempIndArray
.
data
().
asInt
();
}
else
if
(
feedLodValue
.
getClass
().
isArray
()){
// 如果是数组类型,则无须处理,直接使用即可。
}
else
if
(
feedLodValue
instanceof
Iterable
){
// 如果为list,可能存在嵌套,此时需要展平
feedLodValue
=
recursiveExtract
(
feedLodValue
);
}
else
{
throw
new
Exception
(
"Lod must be INDArray or Array or Iterable."
);
}
jsonTensor
.
put
(
"lod"
,
feedLodValue
);
}
}
jsonTensorArray
.
put
(
jsonTensor
);
}
}
...
...
@@ -343,6 +397,9 @@ public class HttpClient {
jsonRequest
.
put
(
"log_id"
,
log_id
);
jsonRequest
.
put
(
"fetch_var_names"
,
jsonFetchList
);
jsonRequest
.
put
(
"tensor"
,
jsonTensorArray
);
if
(
GLOG_v
!=
null
){
System
.
out
.
format
(
"------- Final jsonRequest: %s\n"
,
jsonRequest
.
toString
());
}
return
doPost
(
server_url
,
jsonRequest
.
toString
());
}
...
...
@@ -361,29 +418,41 @@ public class HttpClient {
.
build
();
// 为httpPost实例设置配置
httpPost
.
setConfig
(
requestConfig
);
httpPost
.
setHeader
(
"Content-Type"
,
"application/json;charset=utf-8"
);
// 设置请求头
httpPost
.
addHeader
(
"Content-Type"
,
"application/json"
);
if
(
response_compress_flag
){
httpPost
.
addHeader
(
"Accept-encoding"
,
"gzip"
);
if
(
GLOG_v
!=
null
){
System
.
out
.
format
(
"------- Accept-encoding gzip: \n"
);
}
}
if
(
request_compress_flag
&&
strPostData
.
length
()>
512
){
try
{
if
(
request_compress_flag
&&
strPostData
.
length
()>
1024
){
try
{
strPostData
=
compress
(
strPostData
,
"UTF-8"
);
byte
[]
gzipEncrypt
=
compress
(
strPostData
);
httpPost
.
setEntity
(
new
InputStreamEntity
(
new
ByteArrayInputStream
(
gzipEncrypt
),
gzipEncrypt
.
length
));
httpPost
.
addHeader
(
"Content-Encoding"
,
"gzip"
);
}
catch
(
IO
Exception
e
)
{
}
catch
(
Exception
e
)
{
e
.
printStackTrace
();
}
}
try
{
}
else
{
httpPost
.
setEntity
(
new
StringEntity
(
strPostData
,
"UTF-8"
));
}
// httpClient对象执行post请求,并返回响应参数对象
httpResponse
=
httpClient
.
execute
(
httpPost
);
// 从响应对象中获取响应内容
HttpEntity
entity
=
httpResponse
.
getEntity
();
Header
header
=
entity
.
getContentEncoding
();
if
(
GLOG_v
!=
null
){
System
.
out
.
format
(
"------- response header: %s\n"
,
header
);
}
if
(
header
!=
null
&&
header
.
getValue
().
equalsIgnoreCase
(
"gzip"
)){
//判断返回内容是否为gzip压缩格式
GzipDecompressingEntity
gzipEntity
=
new
GzipDecompressingEntity
(
entity
);
result
=
EntityUtils
.
toString
(
gzipEntity
);
if
(
GLOG_v
!=
null
){
System
.
out
.
format
(
"------- degzip response: %s\n"
,
result
);
}
}
else
{
result
=
EntityUtils
.
toString
(
entity
);
}
...
...
@@ -410,5 +479,25 @@ public class HttpClient {
}
return
result
;
}
public
List
<
Object
>
recursiveExtract
(
Object
stuff
)
{
List
<
Object
>
mylist
=
new
ArrayList
<
Object
>();
if
(
stuff
instanceof
Iterable
)
{
for
(
Object
o
:
(
Iterable
<
?
>)
stuff
)
{
mylist
.
addAll
(
recursiveExtract
(
o
));
}
}
else
if
(
stuff
instanceof
Map
)
{
for
(
Object
o
:
((
Map
<?,
?
extends
Object
>)
stuff
).
values
())
{
mylist
.
addAll
(
recursiveExtract
(
o
));
}
}
else
{
mylist
.
add
(
stuff
);
}
return
mylist
;
}
}
This diff is collapsed.
Click to expand it.
python/examples/fit_a_line/test_httpclient.py
浏览文件 @
00dd5546
...
...
@@ -21,7 +21,7 @@ import time
client
=
HttpClient
()
client
.
load_client_config
(
sys
.
argv
[
1
])
# if you want to enable Encrypt Module,uncommenting the following line
#
client.use_key("./key")
client
.
use_key
(
"./key"
)
client
.
set_response_compress
(
True
)
client
.
set_request_compress
(
True
)
fetch_list
=
client
.
get_fetch_names
()
...
...
This diff is collapsed.
Click to expand it.
python/paddle_serving_client/httpclient.py
浏览文件 @
00dd5546
...
...
@@ -42,6 +42,23 @@ def list_flatten(items, ignore_types=(str, bytes)):
yield
x
def
data_bytes_number
(
datalist
):
total_bytes_number
=
0
if
isinstance
(
datalist
,
list
):
if
len
(
datalist
)
==
0
:
return
total_bytes_number
else
:
for
data
in
datalist
:
if
isinstance
(
data
,
str
):
total_bytes_number
=
total_bytes_number
+
len
(
data
)
else
:
total_bytes_number
=
total_bytes_number
+
4
*
len
(
datalist
)
break
else
:
raise
ValueError
(
"In the Function data_bytes_number(), data must be list."
)
class
HttpClient
(
object
):
def
__init__
(
self
,
ip
=
"0.0.0.0"
,
...
...
@@ -157,7 +174,8 @@ class HttpClient(object):
def
get_fetch_names
(
self
):
return
self
.
fetch_names_
# feed 支持Numpy类型,Json-String,以及直接List、tuple
# feed 支持Numpy类型,以及直接List、tuple
# 不支持str类型,因为proto中为repeated.
def
predict
(
self
,
feed
=
None
,
fetch
=
None
,
...
...
@@ -179,7 +197,7 @@ class HttpClient(object):
if
isinstance
(
feed
,
dict
):
feed_batch
.
append
(
feed
)
elif
isinstance
(
feed
,
(
list
,
str
,
tuple
)):
# if input is a list or str, and the number of feed_var is 1.
# if input is a list or str
or tuple
, and the number of feed_var is 1.
# create a temp_dict { key = feed_var_name, value = list}
# put the temp_dict into the feed_batch.
if
len
(
self
.
feed_names_
)
!=
1
:
...
...
@@ -230,10 +248,8 @@ class HttpClient(object):
data_value
=
feed_i
[
key
]
data_key
=
proto_data_key_list
[
elem_type
]
# 输入不是string类型
if
self
.
feed_types_
[
key
]
!=
bytes_type
:
# feed_i[key] 可以是np.ndarray
# 也可以是string或
list或tuple
# 也可以是
list或tuple
# 当np.ndarray需要处理为list
if
isinstance
(
feed_i
[
key
],
np
.
ndarray
):
shape_lst
=
[]
...
...
@@ -251,7 +267,7 @@ class HttpClient(object):
shape
.
insert
(
0
,
1
)
# 当是list或tuple时,需要把多层嵌套展开
if
isinstance
(
feed_i
[
key
],
(
list
,
tuple
)):
el
if
isinstance
(
feed_i
[
key
],
(
list
,
tuple
)):
# 当Batch为False,shape字段前插一个1,表示batch维
# 当Batch为True, 由于list并不像numpy那样规整,所以
# 无法获取shape,此时取第一维度作为Batch维度.
...
...
@@ -262,14 +278,25 @@ class HttpClient(object):
shape
.
insert
(
0
,
len
(
feed_i
[
key
]))
feed_i
[
key
]
=
[
x
for
x
in
list_flatten
(
feed_i
[
key
])]
data_value
=
feed_i
[
key
]
'''
this is comment, for coder to understand.
#if input is string, feed is not numpy.
else
:
shape = self.feed_shapes_[key]
data_value = feed_i[key]
'''
total_data_number
=
total_data_number
+
len
(
data_value
)
# 输入可能是单个的str或int值等
# 此时先统一处理为一个list
# 由于输入比较特殊,shape保持原feedvar中不变
data_value
=
[]
data_value
.
append
(
feed_i
[
key
])
if
isinstance
(
feed_i
[
key
],
str
):
if
self
.
feed_types_
[
key
]
!=
bytes_type
:
raise
ValueError
(
"feedvar is not string-type,feed can`t be a single string."
)
else
:
if
self
.
feed_types_
[
key
]
==
bytes_type
:
raise
ValueError
(
"feedvar is string-type,feed, feed can`t be a single int or others."
)
total_data_number
=
total_data_number
+
data_bytes_number
(
data_value
)
Request
[
"tensor"
][
index
][
"elem_type"
]
=
elem_type
Request
[
"tensor"
][
index
][
"shape"
]
=
shape
Request
[
"tensor"
][
index
][
data_key
]
=
data_value
...
...
@@ -285,6 +312,7 @@ class HttpClient(object):
web_url
=
"http://"
+
self
.
ip
+
":"
+
self
.
server_port
+
self
.
service_name
postData
=
json
.
dumps
(
Request
)
headers
=
{}
# 当数据区长度大于512字节时才压缩.
if
self
.
try_request_gzip
and
total_data_number
>
512
:
postData
=
gzip
.
compress
(
bytes
(
postData
,
'utf-8'
))
headers
[
"Content-Encoding"
]
=
"gzip"
...
...
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录