PaddlePaddle / Serving
Commit 26fce3da (unverified)

Merge branch 'develop' into patch-9

Authored by Jiawei Wang on May 11, 2021; committed via GitHub on May 11, 2021.
Parents: 1c174fc2, 2318fe1b

Showing 38 changed files with 1221 additions and 35 deletions (+1221 / -35).
Changed files:

python/examples/fit_a_line/benchmark.py  (+9 -2)
python/examples/pipeline/PaddleDetection/faster_rcnn/000000570688.jpg  (+0 -0)
python/examples/pipeline/PaddleDetection/faster_rcnn/README.md  (+18 -0)
python/examples/pipeline/PaddleDetection/faster_rcnn/benchmark.py  (+93 -0)
python/examples/pipeline/PaddleDetection/faster_rcnn/benchmark.sh  (+36 -0)
python/examples/pipeline/PaddleDetection/faster_rcnn/benchmark_config.yaml  (+32 -0)
python/examples/pipeline/PaddleDetection/faster_rcnn/config.yml  (+17 -0)
python/examples/pipeline/PaddleDetection/faster_rcnn/label_list.txt  (+80 -0)
python/examples/pipeline/PaddleDetection/faster_rcnn/pipeline_http_client.py  (+35 -0)
python/examples/pipeline/PaddleDetection/faster_rcnn/web_service.py  (+68 -0)
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/000000570688.jpg  (+0 -0)
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/README.md  (+20 -0)
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/benchmark.py  (+93 -0)
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/benchmark.sh  (+36 -0)
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/benchmark_config.yaml  (+32 -0)
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/config.yml  (+17 -0)
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/label_list.txt  (+80 -0)
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/pipeline_http_client.py  (+35 -0)
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/web_service.py  (+69 -0)
python/examples/pipeline/PaddleDetection/yolov3/000000570688.jpg  (+0 -0)
python/examples/pipeline/PaddleDetection/yolov3/README.md  (+20 -0)
python/examples/pipeline/PaddleDetection/yolov3/benchmark.py  (+93 -0)
python/examples/pipeline/PaddleDetection/yolov3/benchmark.sh  (+36 -0)
python/examples/pipeline/PaddleDetection/yolov3/benchmark_config.yaml  (+32 -0)
python/examples/pipeline/PaddleDetection/yolov3/config.yml  (+17 -0)
python/examples/pipeline/PaddleDetection/yolov3/label_list.txt  (+80 -0)
python/examples/pipeline/PaddleDetection/yolov3/pipeline_http_client.py  (+35 -0)
python/examples/pipeline/PaddleDetection/yolov3/web_service.py  (+68 -0)
python/paddle_serving_app/reader/image_reader.py  (+4 -1)
tools/Dockerfile.cuda10.1-cudnn7.devel  (+4 -4)
tools/Dockerfile.cuda10.2-cudnn8.devel  (+4 -4)
tools/Dockerfile.cuda11-cudnn8.devel  (+5 -5)
tools/Dockerfile.devel  (+3 -3)
tools/Dockerfile.runtime_template  (+5 -1)
tools/dockerfiles/build_scripts/install_trt.sh  (+5 -15)
tools/dockerfiles/build_scripts/install_whl.sh  (+16 -0)
tools/dockerfiles/build_scripts/soft_link.sh  (+22 -0)
tools/generate_runtime_docker.sh  (+2 -0)
python/examples/fit_a_line/benchmark.py (modified)

```diff
@@ -30,6 +30,7 @@ def single_func(idx, resource):
             paddle.dataset.uci_housing.train(), buf_size=500),
         batch_size=1)
     total_number = sum(1 for _ in train_reader())
+    latency_list = []
     if args.request == "rpc":
         client = Client()
@@ -37,9 +38,12 @@ def single_func(idx, resource):
         client.connect([args.endpoint])
         start = time.time()
         for data in train_reader():
+            l_start = time.time()
             fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"])
+            l_end = time.time()
+            latency_list.append(l_end * 1000 - l_start * 1000)
         end = time.time()
-        return [[end - start], [total_number]]
+        return [[end - start], latency_list, [total_number]]
     elif args.request == "http":
         train_reader = paddle.batch(
             paddle.reader.shuffle(
@@ -47,11 +51,14 @@ def single_func(idx, resource):
             batch_size=1)
         start = time.time()
         for data in train_reader():
+            l_start = time.time()
             r = requests.post(
                 'http://{}/uci/prediction'.format(args.endpoint),
                 data={"x": data[0]})
+            l_end = time.time()
+            latency_list.append(l_end * 1000 - l_start * 1000)
         end = time.time()
-        return [[end - start], [total_number]]
+        return [[end - start], latency_list, [total_number]]

     start = time.time()
```
python/examples/pipeline/PaddleDetection/faster_rcnn/000000570688.jpg (new file, mode 100644; binary JPEG, 135.1 KB)
python/examples/pipeline/PaddleDetection/faster_rcnn/README.md (new file, mode 100644)
# Faster RCNN model on Pipeline Paddle Serving
### Get The Faster RCNN Model
```
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/faster_rcnn_r50_fpn_1x_coco.tar
```
### Start the service
```
tar xf faster_rcnn_r50_fpn_1x_coco.tar
python web_service.py
```
### Perform prediction
```
python pipeline_http_client.py
```
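
For a quick smoke test without the bundled client script, note that the pipeline's HTTP endpoint accepts a JSON body with parallel `key`/`value` lists whose values are base64-encoded images. A minimal request sketch, mirroring `pipeline_http_client.py` below and assuming the default `http_port: 18082` from `config.yml`:

```python
# Minimal request sketch against the faster_rcnn pipeline endpoint;
# assumes web_service.py is already running on the default port 18082.
import base64
import json

import requests

with open("000000570688.jpg", "rb") as f:
    image = base64.b64encode(f.read()).decode("utf8")

payload = {"key": ["image"], "value": [image]}
resp = requests.post("http://127.0.0.1:18082/faster_rcnn/prediction",
                     data=json.dumps(payload))
print(resp.json())
```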
python/examples/pipeline/PaddleDetection/faster_rcnn/benchmark.py (new file, mode 100644)

```python
import sys
import os
import yaml
import requests
import time
import json
import cv2
import base64
try:
    from paddle_serving_server_gpu.pipeline import PipelineClient
except ImportError:
    from paddle_serving_server.pipeline import PipelineClient
import numpy as np
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args, show_latency


def cv2_to_base64(image):
    return base64.b64encode(image).decode('utf8')


def parse_benchmark(filein, fileout):
    with open(filein, "r") as fin:
        res = yaml.load(fin)
        del_list = []
        for key in res["DAG"].keys():
            if "call" in key:
                del_list.append(key)
        for key in del_list:
            del res["DAG"][key]
        with open(fileout, "w") as fout:
            yaml.dump(res, fout, default_flow_style=False)


def gen_yml(device, gpu_id):
    fin = open("config.yml", "r")
    config = yaml.load(fin)
    fin.close()
    config["dag"]["tracer"] = {"interval_s": 30}
    if device == "gpu":
        config["op"]["faster_rcnn"]["local_service_conf"]["device_type"] = 1
        config["op"]["faster_rcnn"]["local_service_conf"]["devices"] = gpu_id
    with open("config2.yml", "w") as fout:
        yaml.dump(config, fout, default_flow_style=False)


def run_http(idx, batch_size):
    print("start thread ({})".format(idx))
    url = "http://127.0.0.1:18082/faster_rcnn/prediction"
    with open(os.path.join(".", "000000570688.jpg"), 'rb') as file:
        image_data1 = file.read()
    image = cv2_to_base64(image_data1)
    start = time.time()
    while True:
        data = {"key": [], "value": []}
        for j in range(batch_size):
            data["key"].append("image_" + str(j))
            data["value"].append(image)
        r = requests.post(url=url, data=json.dumps(data))
        end = time.time()
        if end - start > 70:
            print("70s end")
            break
    return [[end - start]]


def multithread_http(thread, batch_size):
    multi_thread_runner = MultiThreadRunner()
    result = multi_thread_runner.run(run_http, thread, batch_size)


def run_rpc(thread, batch_size):
    pass


def multithread_rpc(thread, batch_size):
    multi_thread_runner = MultiThreadRunner()
    result = multi_thread_runner.run(run_rpc, thread, batch_size)


if __name__ == "__main__":
    if sys.argv[1] == "yaml":
        mode = sys.argv[2]  # brpc/ local predictor
        thread = int(sys.argv[3])
        device = sys.argv[4]
        gpu_id = sys.argv[5]
        gen_yml(device, gpu_id)
    elif sys.argv[1] == "run":
        mode = sys.argv[2]  # http/ rpc
        thread = int(sys.argv[3])
        batch_size = int(sys.argv[4])
        if mode == "http":
            multithread_http(thread, batch_size)
        elif mode == "rpc":
            multithread_rpc(thread, batch_size)
    elif sys.argv[1] == "dump":
        filein = sys.argv[2]
        fileout = sys.argv[3]
        parse_benchmark(filein, fileout)
```
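
The benchmark script is driven entirely by positional arguments; `benchmark.sh` below wires the three subcommands together, and they can also be run by hand:

```bash
# Phase 1: rewrite config.yml into config2.yml for GPU 0 (local_predictor mode).
python3 benchmark.py yaml local_predictor 1 gpu 0
# Phase 2: benchmark over HTTP with 1 client thread and batch size 1
# (each thread posts requests in a loop for roughly 70 seconds).
python3 benchmark.py run http 1 1
# Phase 3: condense the pipeline tracer log into a clean YAML report.
python3 benchmark.py dump benchmark.log benchmark.tmp
```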
python/examples/pipeline/PaddleDetection/faster_rcnn/benchmark.sh (new file, mode 100644)

```bash
export FLAGS_profile_pipeline=1
alias python3="python3.7"
modelname="faster_rcnn_r50_fpn_1x_coco"
gpu_id="0"
benchmark_config_filename="benchmark_config.yaml"

# HTTP
ps -ef | grep web_service | awk '{print $2}' | xargs kill -9
sleep 3
python3 benchmark.py yaml local_predictor 1 gpu $gpu_id
rm -rf profile_log_$modelname
for thread_num in 1
do
    for batch_size in 1
    do
        echo "#----FasterRCNN thread num: $thread_num batch size: $batch_size mode:http ----" >> profile_log_$modelname
        rm -rf PipelineServingLogs
        rm -rf cpu_utilization.py
        python3 web_service.py > web.log 2>&1 &
        sleep 3
        nvidia-smi --id=${gpu_id} --query-compute-apps=used_memory --format=csv -lms 100 > gpu_use.log 2>&1 &
        nvidia-smi --id=${gpu_id} --query-gpu=utilization.gpu --format=csv -lms 100 > gpu_utilization.log 2>&1 &
        echo "import psutil\ncpu_utilization=psutil.cpu_percent(1,False)\nprint('CPU_UTILIZATION:', cpu_utilization)\n" > cpu_utilization.py
        python3 benchmark.py run http $thread_num $batch_size
        python3 cpu_utilization.py >> profile_log_$modelname
        python3 -m paddle_serving_server_gpu.profiler >> profile_log_$modelname
        ps -ef | grep web_service | awk '{print $2}' | xargs kill -9
        ps -ef | grep nvidia-smi | awk '{print $2}' | xargs kill -9
        python3 benchmark.py dump benchmark.log benchmark.tmp
        mv benchmark.tmp benchmark.log
        awk 'BEGIN {max = 0} {if(NR>1){if ($modelname > max) max=$modelname}} END {print "GPU_MEM:", max}' gpu_use.log >> profile_log_$modelname
        awk 'BEGIN {max = 0} {if(NR>1){if ($modelname > max) max=$modelname}} END {print "GPU_UTIL:", max}' gpu_utilization.log >> profile_log_$modelname
        cat benchmark.log >> profile_log_$modelname
        python3 -m paddle_serving_server_gpu.parse_profile --benchmark_cfg $benchmark_config_filename --benchmark_log profile_log_$modelname
        #rm -rf gpu_use.log gpu_utilization.log
    done
done
```
python/examples/pipeline/PaddleDetection/faster_rcnn/benchmark_config.yaml (new file, mode 100644)

```yaml
cuda_version: "10.1"
cudnn_version: "7.6"
trt_version: "6.0"
python_version: "3.7"
gcc_version: "8.2"
paddle_version: "2.0.2"

cpu: "Xeon 6148"
gpu: "P4"
xpu: "None"
api: ""
owner: "wangjiawei04"

model_name: "faster_rcnn"
model_type: "static"
model_source: "paddledetection"
model_url: ""

batch_size: 1
num_of_samples: 1000
input_shape: "3, 480, 640"

runtime_device: "gpu"
ir_optim: true
enable_memory_optim: true
enable_tensorrt: false
precision: "fp32"
enable_mkldnn: true
cpu_math_library_num_threads: ""
```
python/examples/pipeline/PaddleDetection/faster_rcnn/config.yml (new file, mode 100644)

```yaml
dag:
    is_thread_op: false
    tracer:
        interval_s: 30
http_port: 18082
op:
    faster_rcnn:
        local_service_conf:
            client_type: local_predictor
            concurrency: 2
            device_type: 1
            devices: '2'
            fetch_list:
            - save_infer_model/scale_0.tmp_1
            model_config: serving_server/
rpc_port: 9998
worker_num: 20
```
python/examples/pipeline/PaddleDetection/faster_rcnn/label_list.txt (new file, mode 100644; the 80 COCO class names, one per line):
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
python/examples/pipeline/PaddleDetection/faster_rcnn/pipeline_http_client.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from paddle_serving_server.pipeline import PipelineClient
import numpy as np
import requests
import json
import cv2
import base64
import os


def cv2_to_base64(image):
    return base64.b64encode(image).decode('utf8')


url = "http://127.0.0.1:18082/faster_rcnn/prediction"
with open(os.path.join(".", "000000570688.jpg"), 'rb') as file:
    image_data1 = file.read()
image = cv2_to_base64(image_data1)

for i in range(1):
    data = {"key": ["image"], "value": [image]}
    r = requests.post(url=url, data=json.dumps(data))

print(r.json())
```
python/examples/pipeline/PaddleDetection/faster_rcnn/web_service.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_server.web_service import WebService, Op
import logging
import numpy as np
import sys
import cv2
from paddle_serving_app.reader import *
import base64


class FasterRCNNOp(Op):
    def init_op(self):
        self.img_preprocess = Sequential([
            BGR2RGB(), Div(255.0),
            Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False),
            Resize((640, 640)), Transpose((2, 0, 1))
        ])
        self.img_postprocess = RCNNPostprocess("label_list.txt", "output")

    def preprocess(self, input_dicts, data_id, log_id):
        (_, input_dict), = input_dicts.items()
        imgs = []
        #print("keys", input_dict.keys())
        for key in input_dict.keys():
            data = base64.b64decode(input_dict[key].encode('utf8'))
            data = np.fromstring(data, np.uint8)
            im = cv2.imdecode(data, cv2.IMREAD_COLOR)
            im = self.img_preprocess(im)
            imgs.append({
                "image": im[np.newaxis, :],
                "im_shape": np.array(list(im.shape[1:])).reshape(-1)[np.newaxis, :],
                "scale_factor": np.array([1.0, 1.0]).reshape(-1)[np.newaxis, :],
            })
        feed_dict = {
            "image": np.concatenate([x["image"] for x in imgs], axis=0),
            "im_shape": np.concatenate([x["im_shape"] for x in imgs], axis=0),
            "scale_factor": np.concatenate(
                [x["scale_factor"] for x in imgs], axis=0)
        }
        #for key in feed_dict.keys():
        #    print(key, feed_dict[key].shape)
        return feed_dict, False, None, ""

    def postprocess(self, input_dicts, fetch_dict, log_id):
        #print(fetch_dict)
        res_dict = {
            "bbox_result": str(self.img_postprocess(fetch_dict, visualize=False))
        }
        return res_dict, None, ""


class FasterRCNNService(WebService):
    def get_pipeline_response(self, read_op):
        faster_rcnn_op = FasterRCNNOp(name="faster_rcnn", input_ops=[read_op])
        return faster_rcnn_op


fasterrcnn_service = FasterRCNNService(name="faster_rcnn")
fasterrcnn_service.prepare_pipeline_config("config2.yml")
fasterrcnn_service.run_service()
```
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/000000570688.jpg (new file, mode 100644; binary JPEG, 135.1 KB)
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/README.md (new file, mode 100644)

# PPYOLO model on Pipeline Paddle Serving

([简体中文](./README_CN.md)|English)

### Get Model
```
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/ppyolo_mbv3_large_coco.tar
```
### Start the service
```
tar xf ppyolo_mbv3_large_coco.tar
python web_service.py
```
### Perform prediction
```
python pipeline_http_client.py
```
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/benchmark.py (new file, mode 100644)

```python
import sys
import os
import yaml
import requests
import time
import json
import cv2
import base64
try:
    from paddle_serving_server_gpu.pipeline import PipelineClient
except ImportError:
    from paddle_serving_server.pipeline import PipelineClient
import numpy as np
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args, show_latency


def cv2_to_base64(image):
    return base64.b64encode(image).decode('utf8')


def parse_benchmark(filein, fileout):
    with open(filein, "r") as fin:
        res = yaml.load(fin)
        del_list = []
        for key in res["DAG"].keys():
            if "call" in key:
                del_list.append(key)
        for key in del_list:
            del res["DAG"][key]
        with open(fileout, "w") as fout:
            yaml.dump(res, fout, default_flow_style=False)


def gen_yml(device, gpu_id):
    fin = open("config.yml", "r")
    config = yaml.load(fin)
    fin.close()
    config["dag"]["tracer"] = {"interval_s": 30}
    if device == "gpu":
        config["op"]["ppyolo_mbv3"]["local_service_conf"]["device_type"] = 1
        config["op"]["ppyolo_mbv3"]["local_service_conf"]["devices"] = gpu_id
    with open("config2.yml", "w") as fout:
        yaml.dump(config, fout, default_flow_style=False)


def run_http(idx, batch_size):
    print("start thread ({})".format(idx))
    url = "http://127.0.0.1:18082/ppyolo_mbv3/prediction"
    with open(os.path.join(".", "000000570688.jpg"), 'rb') as file:
        image_data1 = file.read()
    image = cv2_to_base64(image_data1)
    start = time.time()
    while True:
        data = {"key": [], "value": []}
        for j in range(batch_size):
            data["key"].append("image_" + str(j))
            data["value"].append(image)
        r = requests.post(url=url, data=json.dumps(data))
        end = time.time()
        if end - start > 70:
            print("70s end")
            break
    return [[end - start]]


def multithread_http(thread, batch_size):
    multi_thread_runner = MultiThreadRunner()
    result = multi_thread_runner.run(run_http, thread, batch_size)


def run_rpc(thread, batch_size):
    pass


def multithread_rpc(thread, batch_size):
    multi_thread_runner = MultiThreadRunner()
    result = multi_thread_runner.run(run_rpc, thread, batch_size)


if __name__ == "__main__":
    if sys.argv[1] == "yaml":
        mode = sys.argv[2]  # brpc/ local predictor
        thread = int(sys.argv[3])
        device = sys.argv[4]
        gpu_id = sys.argv[5]
        gen_yml(device, gpu_id)
    elif sys.argv[1] == "run":
        mode = sys.argv[2]  # http/ rpc
        thread = int(sys.argv[3])
        batch_size = int(sys.argv[4])
        if mode == "http":
            multithread_http(thread, batch_size)
        elif mode == "rpc":
            multithread_rpc(thread, batch_size)
    elif sys.argv[1] == "dump":
        filein = sys.argv[2]
        fileout = sys.argv[3]
        parse_benchmark(filein, fileout)
```
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/benchmark.sh (new file, mode 100644)

```bash
export FLAGS_profile_pipeline=1
alias python3="python3.7"
modelname="ppyolo_mbv3_large"
gpu_id="0"
benchmark_config_filename="benchmark_config.yaml"

# HTTP
ps -ef | grep web_service | awk '{print $2}' | xargs kill -9
sleep 3
python3 benchmark.py yaml local_predictor 1 gpu $gpu_id
rm -rf profile_log_$modelname
for thread_num in 1
do
    for batch_size in 1
    do
        echo "#----PPyolo thread num: $thread_num batch size: $batch_size mode:http ----" >> profile_log_$modelname
        rm -rf PipelineServingLogs
        rm -rf cpu_utilization.py
        python3 web_service.py > web.log 2>&1 &
        sleep 3
        nvidia-smi --id=${gpu_id} --query-compute-apps=used_memory --format=csv -lms 100 > gpu_use.log 2>&1 &
        nvidia-smi --id=${gpu_id} --query-gpu=utilization.gpu --format=csv -lms 100 > gpu_utilization.log 2>&1 &
        echo "import psutil\ncpu_utilization=psutil.cpu_percent(1,False)\nprint('CPU_UTILIZATION:', cpu_utilization)\n" > cpu_utilization.py
        python3 benchmark.py run http $thread_num $batch_size
        python3 cpu_utilization.py >> profile_log_$modelname
        python3 -m paddle_serving_server_gpu.profiler >> profile_log_$modelname
        ps -ef | grep web_service | awk '{print $2}' | xargs kill -9
        python3 benchmark.py dump benchmark.log benchmark.tmp
        mv benchmark.tmp benchmark.log
        awk 'BEGIN {max = 0} {if(NR>1){if ($modelname > max) max=$modelname}} END {print "GPU_MEM:", max}' gpu_use.log >> profile_log_$modelname
        awk 'BEGIN {max = 0} {if(NR>1){if ($modelname > max) max=$modelname}} END {print "GPU_UTIL:", max}' gpu_utilization.log >> profile_log_$modelname
        cat benchmark.log >> profile_log_$modelname
        python3 -m paddle_serving_server_gpu.parse_profile --benchmark_cfg $benchmark_config_filename --benchmark_log profile_log_$modelname
        #rm -rf gpu_use.log gpu_utilization.log
    done
done
```
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/benchmark_config.yaml (new file, mode 100644)

```yaml
cuda_version: "10.1"
cudnn_version: "7.6"
trt_version: "6.0"
python_version: "3.7"
gcc_version: "8.2"
paddle_version: "2.0.2"

cpu: "Xeon 6148"
gpu: "P4"
xpu: "None"
api: ""
owner: "wangjiawei04"

model_name: "ppyolo"
model_type: "static"
model_source: "paddledetection"
model_url: ""

batch_size: 1
num_of_samples: 1000
input_shape: "3, 480, 640"

runtime_device: "gpu"
ir_optim: true
enable_memory_optim: true
enable_tensorrt: false
precision: "fp32"
enable_mkldnn: true
cpu_math_library_num_threads: ""
```
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/config.yml (new file, mode 100644)

```yaml
dag:
    is_thread_op: false
    tracer:
        interval_s: 30
http_port: 18082
op:
    ppyolo_mbv3:
        local_service_conf:
            client_type: local_predictor
            concurrency: 10
            device_type: 1
            devices: '2'
            fetch_list:
            - save_infer_model/scale_0.tmp_1
            model_config: serving_server/
rpc_port: 9998
worker_num: 20
```
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/label_list.txt (new file, mode 100644; the 80 COCO class names, one per line):
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/pipeline_http_client.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from paddle_serving_server.pipeline import PipelineClient
import numpy as np
import requests
import json
import cv2
import base64
import os


def cv2_to_base64(image):
    return base64.b64encode(image).decode('utf8')


url = "http://127.0.0.1:18082/ppyolo_mbv3/prediction"
with open(os.path.join(".", "000000570688.jpg"), 'rb') as file:
    image_data1 = file.read()
image = cv2_to_base64(image_data1)

for i in range(1):
    data = {"key": ["image"], "value": [image]}
    r = requests.post(url=url, data=json.dumps(data))

print(r.json())
```
python/examples/pipeline/PaddleDetection/ppyolo_mbv3/web_service.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_server.web_service import WebService, Op
import logging
import numpy as np
import sys
import cv2
from paddle_serving_app.reader import *
import base64


class PPYoloMbvOp(Op):
    def init_op(self):
        self.img_preprocess = Sequential([
            BGR2RGB(), Div(255.0),
            Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False),
            Resize((320, 320)), Transpose((2, 0, 1))
        ])
        self.img_postprocess = RCNNPostprocess("label_list.txt", "output")

    def preprocess(self, input_dicts, data_id, log_id):
        (_, input_dict), = input_dicts.items()
        imgs = []
        #print("keys", input_dict.keys())
        for key in input_dict.keys():
            data = base64.b64decode(input_dict[key].encode('utf8'))
            data = np.fromstring(data, np.uint8)
            im = cv2.imdecode(data, cv2.IMREAD_COLOR)
            im = self.img_preprocess(im)
            imgs.append({
                "image": im[np.newaxis, :],
                "im_shape": np.array(list(im.shape[1:])).reshape(-1)[np.newaxis, :],
                "scale_factor": np.array([1.0, 1.0]).reshape(-1)[np.newaxis, :],
            })
        feed_dict = {
            "image": np.concatenate([x["image"] for x in imgs], axis=0),
            "im_shape": np.concatenate([x["im_shape"] for x in imgs], axis=0),
            "scale_factor": np.concatenate(
                [x["scale_factor"] for x in imgs], axis=0)
        }
        for key in feed_dict.keys():
            print(key, feed_dict[key].shape)
        return feed_dict, False, None, ""

    def postprocess(self, input_dicts, fetch_dict, log_id):
        #print(fetch_dict)
        res_dict = {
            "bbox_result": str(self.img_postprocess(fetch_dict, visualize=False))
        }
        return res_dict, None, ""


class PPYoloMbv(WebService):
    def get_pipeline_response(self, read_op):
        ppyolo_mbv3_op = PPYoloMbvOp(name="ppyolo_mbv3", input_ops=[read_op])
        return ppyolo_mbv3_op


ppyolo_mbv3_service = PPYoloMbv(name="ppyolo_mbv3")
ppyolo_mbv3_service.prepare_pipeline_config("config2.yml")
ppyolo_mbv3_service.run_service()
```
python/examples/pipeline/PaddleDetection/yolov3/000000570688.jpg (new file, mode 100644; binary JPEG, 135.1 KB)
python/examples/pipeline/PaddleDetection/yolov3/README.md (new file, mode 100644)

# YOLOv3 model on Pipeline Paddle Serving

([简体中文](./README_CN.md)|English)

### Get Model
```
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/yolov3_darknet53_270e_coco.tar
```
### Start the service
```
tar xf yolov3_darknet53_270e_coco.tar
python web_service.py
```
### Perform prediction
```
python pipeline_http_client.py
```
python/examples/pipeline/PaddleDetection/yolov3/benchmark.py (new file, mode 100644)

```python
import sys
import os
import yaml
import requests
import time
import json
import cv2
import base64
try:
    from paddle_serving_server_gpu.pipeline import PipelineClient
except ImportError:
    from paddle_serving_server.pipeline import PipelineClient
import numpy as np
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args, show_latency


def cv2_to_base64(image):
    return base64.b64encode(image).decode('utf8')


def parse_benchmark(filein, fileout):
    with open(filein, "r") as fin:
        res = yaml.load(fin)
        del_list = []
        for key in res["DAG"].keys():
            if "call" in key:
                del_list.append(key)
        for key in del_list:
            del res["DAG"][key]
        with open(fileout, "w") as fout:
            yaml.dump(res, fout, default_flow_style=False)


def gen_yml(device, gpu_id):
    fin = open("config.yml", "r")
    config = yaml.load(fin)
    fin.close()
    config["dag"]["tracer"] = {"interval_s": 30}
    if device == "gpu":
        config["op"]["yolov3"]["local_service_conf"]["device_type"] = 1
        config["op"]["yolov3"]["local_service_conf"]["devices"] = gpu_id
    with open("config2.yml", "w") as fout:
        yaml.dump(config, fout, default_flow_style=False)


def run_http(idx, batch_size):
    print("start thread ({})".format(idx))
    url = "http://127.0.0.1:18082/yolov3/prediction"
    with open(os.path.join(".", "000000570688.jpg"), 'rb') as file:
        image_data1 = file.read()
    image = cv2_to_base64(image_data1)
    start = time.time()
    while True:
        data = {"key": [], "value": []}
        for j in range(batch_size):
            data["key"].append("image_" + str(j))
            data["value"].append(image)
        r = requests.post(url=url, data=json.dumps(data))
        end = time.time()
        if end - start > 70:
            print("70s end")
            break
    return [[end - start]]


def multithread_http(thread, batch_size):
    multi_thread_runner = MultiThreadRunner()
    result = multi_thread_runner.run(run_http, thread, batch_size)


def run_rpc(thread, batch_size):
    pass


def multithread_rpc(thread, batch_size):
    multi_thread_runner = MultiThreadRunner()
    result = multi_thread_runner.run(run_rpc, thread, batch_size)


if __name__ == "__main__":
    if sys.argv[1] == "yaml":
        mode = sys.argv[2]  # brpc/ local predictor
        thread = int(sys.argv[3])
        device = sys.argv[4]
        gpu_id = sys.argv[5]
        gen_yml(device, gpu_id)
    elif sys.argv[1] == "run":
        mode = sys.argv[2]  # http/ rpc
        thread = int(sys.argv[3])
        batch_size = int(sys.argv[4])
        if mode == "http":
            multithread_http(thread, batch_size)
        elif mode == "rpc":
            multithread_rpc(thread, batch_size)
    elif sys.argv[1] == "dump":
        filein = sys.argv[2]
        fileout = sys.argv[3]
        parse_benchmark(filein, fileout)
```
python/examples/pipeline/PaddleDetection/yolov3/benchmark.sh (new file, mode 100644)

```bash
export FLAGS_profile_pipeline=1
alias python3="python3.7"
modelname="yolov3_darknet53_270e_coco"
gpu_id="0"
benchmark_config_filename="benchmark_config.yaml"

# HTTP
ps -ef | grep web_service | awk '{print $2}' | xargs kill -9
sleep 3
python3 benchmark.py yaml local_predictor 1 cpu
rm -rf profile_log_$modelname
for thread_num in 1 8 16
do
    for batch_size in 1
    do
        echo "#----Yolov3 thread num: $thread_num batch size: $batch_size mode:http ----" >> profile_log_$modelname
        rm -rf PipelineServingLogs
        rm -rf cpu_utilization.py
        python3 web_service.py > web.log 2>&1 &
        sleep 3
        nvidia-smi --id=${gpu_id} --query-compute-apps=used_memory --format=csv -lms 100 > gpu_use.log 2>&1 &
        nvidia-smi --id=${gpu_id} --query-gpu=utilization.gpu --format=csv -lms 100 > gpu_utilization.log 2>&1 &
        echo "import psutil\ncpu_utilization=psutil.cpu_percent(1,False)\nprint('CPU_UTILIZATION:', cpu_utilization)\n" > cpu_utilization.py
        python3 benchmark.py run http $thread_num $batch_size
        python3 cpu_utilization.py >> profile_log_$modelname
        python3 -m paddle_serving_server_gpu.profiler >> profile_log_$modelname
        ps -ef | grep web_service | awk '{print $2}' | xargs kill -9
        python3 benchmark.py dump benchmark.log benchmark.tmp
        mv benchmark.tmp benchmark.log
        awk 'BEGIN {max = 0} {if(NR>1){if ($modelname > max) max=$modelname}} END {print "GPU_MEM:", max}' gpu_use.log >> profile_log_$modelname
        awk 'BEGIN {max = 0} {if(NR>1){if ($modelname > max) max=$modelname}} END {print "GPU_UTIL:", max}' gpu_utilization.log >> profile_log_$modelname
        cat benchmark.log >> profile_log_$modelname
        python3 -m paddle_serving_server_gpu.parse_profile --benchmark_cfg $benchmark_config_filename --benchmark_log profile_log_$modelname
        #rm -rf gpu_use.log gpu_utilization.log
    done
done
```
python/examples/pipeline/PaddleDetection/yolov3/benchmark_config.yaml (new file, mode 100644)

```yaml
cuda_version: "10.1"
cudnn_version: "7.6"
trt_version: "6.0"
python_version: "3.7"
gcc_version: "8.2"
paddle_version: "2.0.2"

cpu: "Xeon 6148"
gpu: "P4"
xpu: "None"
api: ""
owner: "wangjiawei04"

model_name: "yolov3"
model_type: "static"
model_source: "paddledetection"
model_url: ""

batch_size: 1
num_of_samples: 1000
input_shape: "3, 480, 640"

runtime_device: "gpu"
ir_optim: true
enable_memory_optim: true
enable_tensorrt: false
precision: "fp32"
enable_mkldnn: true
cpu_math_library_num_threads: ""
```
python/examples/pipeline/PaddleDetection/yolov3/config.yml (new file, mode 100644)

```yaml
dag:
    is_thread_op: false
    tracer:
        interval_s: 30
http_port: 18082
op:
    yolov3:
        local_service_conf:
            client_type: local_predictor
            concurrency: 10
            device_type: 1
            devices: '2'
            fetch_list:
            - save_infer_model/scale_0.tmp_1
            model_config: serving_server/
rpc_port: 9998
worker_num: 20
```
python/examples/pipeline/PaddleDetection/yolov3/label_list.txt (new file, mode 100644; the 80 COCO class names, one per line):
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
python/examples/pipeline/PaddleDetection/yolov3/pipeline_http_client.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from paddle_serving_server.pipeline import PipelineClient
import numpy as np
import requests
import json
import cv2
import base64
import os


def cv2_to_base64(image):
    return base64.b64encode(image).decode('utf8')


url = "http://127.0.0.1:18082/yolov3/prediction"
with open(os.path.join(".", "000000570688.jpg"), 'rb') as file:
    image_data1 = file.read()
image = cv2_to_base64(image_data1)

for i in range(1):
    data = {"key": ["image"], "value": [image]}
    r = requests.post(url=url, data=json.dumps(data))

print(r.json())
```
python/examples/pipeline/PaddleDetection/yolov3/web_service.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_server.web_service import WebService, Op
import logging
import numpy as np
import sys
import cv2
from paddle_serving_app.reader import *
import base64


class Yolov3Op(Op):
    def init_op(self):
        self.img_preprocess = Sequential([
            BGR2RGB(), Div(255.0),
            Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False),
            Resize((640, 640)), Transpose((2, 0, 1))
        ])
        self.img_postprocess = RCNNPostprocess("label_list.txt", "output")

    def preprocess(self, input_dicts, data_id, log_id):
        (_, input_dict), = input_dicts.items()
        imgs = []
        #print("keys", input_dict.keys())
        for key in input_dict.keys():
            data = base64.b64decode(input_dict[key].encode('utf8'))
            data = np.fromstring(data, np.uint8)
            im = cv2.imdecode(data, cv2.IMREAD_COLOR)
            im = self.img_preprocess(im)
            imgs.append({
                "image": im[np.newaxis, :],
                "im_shape": np.array(list(im.shape[1:])).reshape(-1)[np.newaxis, :],
                "scale_factor": np.array([1.0, 1.0]).reshape(-1)[np.newaxis, :],
            })
        feed_dict = {
            "image": np.concatenate([x["image"] for x in imgs], axis=0),
            "im_shape": np.concatenate([x["im_shape"] for x in imgs], axis=0),
            "scale_factor": np.concatenate(
                [x["scale_factor"] for x in imgs], axis=0)
        }
        #for key in feed_dict.keys():
        #    print(key, feed_dict[key].shape)
        return feed_dict, False, None, ""

    def postprocess(self, input_dicts, fetch_dict, log_id):
        #print(fetch_dict)
        res_dict = {
            "bbox_result": str(self.img_postprocess(fetch_dict, visualize=False))
        }
        return res_dict, None, ""


class Yolov3Service(WebService):
    def get_pipeline_response(self, read_op):
        yolov3_op = Yolov3Op(name="yolov3", input_ops=[read_op])
        return yolov3_op


yolov3_service = Yolov3Service(name="yolov3")
yolov3_service.prepare_pipeline_config("config2.yml")
yolov3_service.run_service()
```
python/paddle_serving_app/reader/image_reader.py (modified)

```diff
@@ -415,7 +415,7 @@ class RCNNPostprocess(object):
         out_path = os.path.join(self.output_dir, image_path)
         image.save(out_path, quality=95)

-    def __call__(self, image_with_bbox):
+    def __call__(self, image_with_bbox, visualize=True):
         fetch_name = ""
         for key in image_with_bbox:
             if key == "image":
@@ -427,6 +427,8 @@ class RCNNPostprocess(object):
             self.clsid2catid)
         if os.path.isdir(self.output_dir) is False:
             os.mkdir(self.output_dir)
+        if visualize is False:
+            return bbox_result
         self.visualize(image_with_bbox["image"], bbox_result,
                        self.catid2name, len(self.label_list))
         if os.path.isdir(self.output_dir) is False:
@@ -434,6 +436,7 @@ class RCNNPostprocess(object):
         bbox_file = os.path.join(self.output_dir, 'bbox.json')
         with open(bbox_file, 'w') as f:
             json.dump(bbox_result, f, indent=4)
+        return bbox_result

     def __repr__(self):
         return self.__class__.__name__ + "label_file: {1}, output_dir: {2}".format(
```
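
The new `visualize` flag lets callers fetch the post-processed boxes without writing an annotated image to disk, which is how the pipeline `web_service.py` examples above call it. A short sketch of both modes; `fetch_dict` is a placeholder for the dict those examples pass in (the decoded image plus the model's fetch outputs):

```python
# Sketch of the two calling modes enabled by this change.
from paddle_serving_app.reader import RCNNPostprocess

postprocessor = RCNNPostprocess("label_list.txt", "output")

bbox_result = postprocessor(fetch_dict, visualize=False)  # boxes only, nothing written
postprocessor(fetch_dict)  # default visualize=True also saves the annotated image
```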
tools/Dockerfile.cuda10.1-cudnn7.devel (modified)

```diff
@@ -104,7 +104,7 @@ ENV PATH=usr/local/go/bin:/root/go/bin:${PATH}
 # Downgrade TensorRT
 COPY tools/dockerfiles/build_scripts /build_scripts
-RUN bash /build_scripts/install_trt.sh
+RUN bash /build_scripts/install_trt.sh cuda10.1
 RUN rm -rf /build_scripts

 # git credential to skip password typing
@@ -132,9 +132,9 @@ RUN wget https://paddle-ci.gz.bcebos.com/ccache-3.7.9.tar.gz && \
     make -j8 && make install && \
     ln -s /usr/local/ccache-3.7.9/bin/ccache /usr/local/bin/ccache

-RUN python3.8 -m pip install --upgrade pip requests && \
-    python3.7 -m pip install --upgrade pip requests && \
-    python3.6 -m pip install --upgrade pip requests
+RUN python3.8 -m pip install --upgrade pip==21.1.1 requests && \
+    python3.7 -m pip install --upgrade pip==21.1.1 requests && \
+    python3.6 -m pip install --upgrade pip==21.1.1 requests

 RUN wget https://paddle-serving.bj.bcebos.com/others/centos_ssl.tar && \
     tar xf centos_ssl.tar && rm -rf centos_ssl.tar && \
```
tools/Dockerfile.cuda10.2-cudnn8.devel (modified)

```diff
@@ -104,7 +104,7 @@ ENV PATH=usr/local/go/bin:/root/go/bin:${PATH}
 # Downgrade TensorRT
 COPY tools/dockerfiles/build_scripts /build_scripts
-RUN bash /build_scripts/install_trt.sh
+RUN bash /build_scripts/install_trt.sh cuda10.2
 RUN rm -rf /build_scripts

 # git credential to skip password typing
@@ -132,9 +132,9 @@ RUN wget https://paddle-ci.gz.bcebos.com/ccache-3.7.9.tar.gz && \
     make -j8 && make install && \
     ln -s /usr/local/ccache-3.7.9/bin/ccache /usr/local/bin/ccache

-RUN python3.8 -m pip install --upgrade pip requests && \
-    python3.7 -m pip install --upgrade pip requests && \
-    python3.6 -m pip install --upgrade pip requests
+RUN python3.8 -m pip install --upgrade pip==21.1.1 requests && \
+    python3.7 -m pip install --upgrade pip==21.1.1 requests && \
+    python3.6 -m pip install --upgrade pip==21.1.1 requests

 RUN wget https://paddle-serving.bj.bcebos.com/others/centos_ssl.tar && \
     tar xf centos_ssl.tar && rm -rf centos_ssl.tar && \
```
tools/Dockerfile.cuda11.2-cudnn8.devel → tools/Dockerfile.cuda11-cudnn8.devel (renamed and modified)

```diff
 # A image for building paddle binaries
 # Use cuda devel base image for both cpu and gpu environment
 # When you modify it, please be aware of cudnn-runtime version
-FROM nvidia/cuda:11.2.0-cudnn8-devel-ubuntu16.04
+FROM nvidia/cuda:11.0.3-cudnn8-devel-ubuntu16.04
 MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com>

 # ENV variables
@@ -104,7 +104,7 @@ ENV PATH=usr/local/go/bin:/root/go/bin:${PATH}
 # Downgrade TensorRT
 COPY tools/dockerfiles/build_scripts /build_scripts
-RUN bash /build_scripts/install_trt.sh
+RUN bash /build_scripts/install_trt.sh cuda11
 RUN rm -rf /build_scripts

 # git credential to skip password typing
@@ -132,9 +132,9 @@ RUN wget https://paddle-ci.gz.bcebos.com/ccache-3.7.9.tar.gz && \
     make -j8 && make install && \
     ln -s /usr/local/ccache-3.7.9/bin/ccache /usr/local/bin/ccache

-RUN python3.8 -m pip install --upgrade pip requests && \
-    python3.7 -m pip install --upgrade pip requests && \
-    python3.6 -m pip install --upgrade pip requests
+RUN python3.8 -m pip install --upgrade pip==21.1.1 requests && \
+    python3.7 -m pip install --upgrade pip==21.1.1 requests && \
+    python3.6 -m pip install --upgrade pip==21.1.1 requests

 RUN wget https://paddle-serving.bj.bcebos.com/others/centos_ssl.tar && \
     tar xf centos_ssl.tar && rm -rf centos_ssl.tar && \
```
tools/Dockerfile.devel (modified)

```diff
@@ -132,9 +132,9 @@ RUN wget https://paddle-ci.gz.bcebos.com/ccache-3.7.9.tar.gz && \
     make -j8 && make install && \
     ln -s /usr/local/ccache-3.7.9/bin/ccache /usr/local/bin/ccache

-RUN python3.8 -m pip install --upgrade pip requests && \
-    python3.7 -m pip install --upgrade pip requests && \
-    python3.6 -m pip install --upgrade pip requests
+RUN python3.8 -m pip install --upgrade pip==21.1.1 requests && \
+    python3.7 -m pip install --upgrade pip==21.1.1 requests && \
+    python3.6 -m pip install --upgrade pip==21.1.1 requests

 RUN wget https://paddle-serving.bj.bcebos.com/others/centos_ssl.tar && \
     tar xf centos_ssl.tar && rm -rf centos_ssl.tar && \
```
tools/Dockerfile.runtime_template (modified)

```diff
@@ -30,10 +30,14 @@ WORKDIR /home
 COPY tools/dockerfiles/build_scripts /build_scripts
 RUN bash /build_scripts/install_whl.sh <<serving_version>> <<paddle_version>> <<run_env>> <<python_version>> && rm -rf /build_scripts

+WORKDIR /home
+COPY tools/dockerfiles/build_scripts /build_scripts
+RUN bash /build_scripts/soft_link.sh <<run_env>>
+
 # install tensorrt
 WORKDIR /home
 COPY tools/dockerfiles/build_scripts /build_scripts
-RUN bash /build_scripts/install_trt.sh && rm -rf /build_scripts
+RUN bash /build_scripts/install_trt.sh <<run_env>> && rm -rf /build_scripts

 # install go
 RUN wget -qO- https://dl.google.com/go/go1.14.linux-amd64.tar.gz | \
```
tools/dockerfiles/build_scripts/install_trt.sh (modified)

```diff
@@ -14,31 +14,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-VERSION=$(nvcc --version | grep release | grep -oEi "release ([0-9]+)\.([0-9])" | sed "s/release //")
-if [[ "$VERSION" == "10.1" ]];then
+VERSION=$1
+
+if [[ "$VERSION" == "cuda10.1" ]];then
     wget -q https://paddle-ci.gz.bcebos.com/TRT/TensorRT6-cuda10.1-cudnn7.tar.gz --no-check-certificate
     tar -zxf TensorRT6-cuda10.1-cudnn7.tar.gz -C /usr/local
     cp -rf /usr/local/TensorRT6-cuda10.1-cudnn7/include/* /usr/include/ && cp -rf /usr/local/TensorRT6-cuda10.1-cudnn7/lib/* /usr/lib/
     echo "cuda10.1 trt install ==============>>>>>>>>>>>>"
     rm TensorRT6-cuda10.1-cudnn7.tar.gz
-elif [[ "$VERSION" == "11.0" ]];then
+elif [[ "$VERSION" == "cuda11" ]];then
     wget -q https://paddle-ci.cdn.bcebos.com/TRT/TensorRT-7.1.3.4.Ubuntu-16.04.x86_64-gnu.cuda-11.0.cudnn8.0.tar.gz --no-check-certificate
     tar -zxf TensorRT-7.1.3.4.Ubuntu-16.04.x86_64-gnu.cuda-11.0.cudnn8.0.tar.gz -C /usr/local
     cp -rf /usr/local/TensorRT-7.1.3.4/include/* /usr/include/ && cp -rf /usr/local/TensorRT-7.1.3.4/lib/* /usr/lib/
     rm TensorRT-7.1.3.4.Ubuntu-16.04.x86_64-gnu.cuda-11.0.cudnn8.0.tar.gz
-elif [[ "$VERSION" == "10.2" ]];then
+elif [[ "$VERSION" == "cuda10.2" ]];then
     wget https://paddle-ci.gz.bcebos.com/TRT/TensorRT7-cuda10.2-cudnn8.tar.gz --no-check-certificate
     tar -zxf TensorRT7-cuda10.2-cudnn8.tar.gz -C /usr/local
     cp -rf /usr/local/TensorRT-7.1.3.4/include/* /usr/include/ && cp -rf /usr/local/TensorRT-7.1.3.4/lib/* /usr/lib/
     rm TensorRT7-cuda10.2-cudnn8.tar.gz
-elif [[ "$VERSION" == "10.0" ]];then
-    wget -q https://paddle-ci.gz.bcebos.com/TRT/TensorRT6-cuda10.0-cudnn7.tar.gz --no-check-certificate
-    tar -zxf TensorRT6-cuda10.0-cudnn7.tar.gz -C /usr/local
-    cp -rf /usr/local/TensorRT6-cuda10.0-cudnn7/include/* /usr/include/ && cp -rf /usr/local/TensorRT6-cuda10.0-cudnn7/lib/* /usr/lib/
-    rm TensorRT6-cuda10.0-cudnn7.tar.gz
-elif [[ "$VERSION" == "9.0" ]];then
-    wget -q https://paddle-ci.gz.bcebos.com/TRT/TensorRT6-cuda9.0-cudnn7.tar.gz --no-check-certificate
-    tar -zxf TensorRT6-cuda9.0-cudnn7.tar.gz -C /usr/local
-    cp -rf /usr/local/TensorRT6-cuda9.0-cudnn7/include/* /usr/include/ && cp -rf /usr/local/TensorRT6-cuda9.0-cudnn7/lib/* /usr/lib/
-    rm TensorRT6-cuda9.0-cudnn7.tar.gz
 fi
```
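
`install_trt.sh` now takes the target CUDA environment as an explicit first argument instead of probing `nvcc`, which is what the updated devel Dockerfiles above rely on. For example:

```bash
# Install the TensorRT build matching a CUDA 10.2 image, as
# Dockerfile.cuda10.2-cudnn8.devel now does.
bash /build_scripts/install_trt.sh cuda10.2
```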
tools/dockerfiles/build_scripts/install_whl.sh (modified)

```diff
@@ -40,6 +40,9 @@ if [[ $SERVING_VERSION == "0.5.0" ]]; then
     elif [[ "$RUN_ENV" == "cuda10.2" ]];then
         server_release="paddle-serving-server-gpu==$SERVING_VERSION.post102"
         serving_bin="https://paddle-serving.bj.bcebos.com/bin/serving-gpu-102-${SERVING_VERSION}.tar.gz"
+    elif [[ "$RUN_ENV" == "cuda11" ]];then
+        server_release="paddle-serving-server-gpu==$SERVING_VERSION.post11"
+        serving_bin="https://paddle-serving.bj.bcebos.com/bin/serving-gpu-cuda11-${SERVING_VERSION}.tar.gz"
     fi
     client_release="paddle-serving-client==$SERVING_VERSION"
     app_release="paddle-serving-app==0.3.1"
@@ -53,6 +56,9 @@ elif [[ $SERVING_VERSION == "0.6.0" ]]; then
     elif [[ "$RUN_ENV" == "cuda10.2" ]];then
         server_release="https://paddle-serving.bj.bcebos.com/test-dev/whl/paddle_serving_server_gpu-$SERVING_VERSION.post102-py3-none-any.whl"
         serving_bin="https://paddle-serving.bj.bcebos.com/test-dev/bin/serving-gpu-102-$SERVING_VERSION.tar.gz"
+    elif [[ "$RUN_ENV" == "cuda11" ]];then
+        server_release="https://paddle-serving.bj.bcebos.com/test-dev/whl/paddle_serving_server_gpu-$SERVING_VERSION.post11-py3-none-any.whl"
+        serving_bin="https://paddle-serving.bj.bcebos.com/test-dev/bin/serving-gpu-cuda11-$SERVING_VERSION.tar.gz"
     fi
     client_release="https://paddle-serving.bj.bcebos.com/test-dev/whl/paddle_serving_client-$SERVING_VERSION-cp$CPYTHON-none-any.whl"
     app_release="https://paddle-serving.bj.bcebos.com/test-dev/whl/paddle_serving_app-$SERVING_VERSION-py3-none-any.whl"
@@ -88,6 +94,16 @@ elif [[ "$RUN_ENV" == "cuda10.2" ]];then
     echo "export SERVING_BIN=$PWD/serving_bin/serving" >> /root/.bashrc
     rm -rf serving-gpu-102-${SERVING_VERSION}.tar.gz
     cd -
+elif [[ "$RUN_ENV" == "cuda11" ]];then
+    python$PYTHON_VERSION -m pip install $client_release $app_release $server_release
+    python$PYTHON_VERSION -m pip install paddlepaddle-gpu==${PADDLE_VERSION}
+    cd /usr/local/
+    wget $serving_bin
+    tar xf serving-gpu-cuda11-${SERVING_VERSION}.tar.gz
+    mv $PWD/serving-gpu-cuda11-${SERVING_VERSION} $PWD/serving_bin
+    echo "export SERVING_BIN=$PWD/serving_bin/serving" >> /root/.bashrc
+    rm -rf serving-gpu-cuda11-${SERVING_VERSION}.tar.gz
+    cd -
 fi
```
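
`Dockerfile.runtime_template` invokes this script as `install_whl.sh <<serving_version>> <<paddle_version>> <<run_env>> <<python_version>>`, and this patch makes `cuda11` a valid `run_env`. A hypothetical invocation (the version numbers here are illustrative, not taken from the diff):

```bash
# Illustrative only: install the Serving wheels and binary for a CUDA 11
# runtime image with Python 3.8; 0.6.0 and 2.1.0 are example version values.
bash /build_scripts/install_whl.sh 0.6.0 2.1.0 cuda11 3.8
```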
tools/dockerfiles/build_scripts/soft_link.sh (new file, mode 100644)

```bash
RUN_ENV=$1

if [[ "$RUN_ENV" == "cuda10.1" ]];then
    ln -sf /usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudart.so.10.1 /usr/lib/libcudart.so && \
    ln -sf /usr/local/cuda-10.1/targets/x86_64-linux/lib/libcusolver.so.10 /usr/lib/libcusolver.so && \
    ln -sf /usr/lib/x86_64-linux-gnu/libcuda.so /usr/lib/libcuda.so && \
    ln -sf /usr/lib/x86_64-linux-gnu/libcublas.so.10 /usr/lib/libcublas.so && \
    ln -sf /usr/lib/x86_64-linux-gnu/libcudnn.so.7 /usr/lib/libcudnn.so
elif [[ "$RUN_ENV" == "cuda10.2" ]];then
    ln -sf /usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudart.so.10.2 /usr/lib/libcudart.so && \
    ln -sf /usr/local/cuda-10.2/targets/x86_64-linux/lib/libcusolver.so.10 /usr/lib/libcusolver.so && \
    ln -sf /usr/lib/x86_64-linux-gnu/libcuda.so /usr/lib/libcuda.so && \
    ln -sf /usr/lib/x86_64-linux-gnu/libcublas.so.10 /usr/lib/libcublas.so && \
    ln -sf /usr/lib/x86_64-linux-gnu/libcudnn.so.8 /usr/lib/libcudnn.so
elif [[ "$RUN_ENV" == "cuda11" ]];then
    ln -sf /usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudart.so.11.0 /usr/lib/libcudart.so && \
    ln -sf /usr/local/cuda-11.0/targets/x86_64-linux/lib/libcusolver.so.10 /usr/lib/libcusolver.so && \
    ln -sf /usr/lib/x86_64-linux-gnu/libcuda.so /usr/lib/libcuda.so && \
    ln -sf /usr/local/cuda-11.0/targets/x86_64-linux/lib/libcublas.so.11 /usr/lib/libcublas.so && \
    ln -sf /usr/lib/x86_64-linux-gnu/libcudnn.so.8 /usr/lib/libcudnn.so
fi
```
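
`Dockerfile.runtime_template` calls this script right after installing the wheels (`RUN bash /build_scripts/soft_link.sh <<run_env>>`), so the runtime image exposes the CUDA and cuDNN libraries under the unversioned names the serving binary expects. For example:

```bash
# Create the unversioned CUDA/cuDNN symlinks for a cuda10.1 runtime image.
bash /build_scripts/soft_link.sh cuda10.1
```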
tools/generate_runtime_docker.sh (modified)

```diff
@@ -66,6 +66,8 @@ function run
         base_image="nvidia\/cuda:10.1-cudnn7-runtime-ubuntu16.04"
     elif [ $env == "cuda10.2" ];then
         base_image="nvidia\/cuda:10.2-cudnn8-runtime-ubuntu16.04"
+    elif [ $env == "cuda11" ];then
+        base_image="nvidia\/cuda:11.0.3-cudnn8-runtime-ubuntu16.04"
     fi
     echo "base image: $base_image"
     echo "named arg: python: $python"
```