提交 75d44eb5 编写于 作者: J Jiawei Wang 提交者: wangjiawei04

Merge pull request #911 from wangjiawei04/example_bugs

fix bugs in examples
上级 1af347ab
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import *
import numpy as np
preprocess = Sequential([
File2Image(), BGR2RGB(), Div(255.0),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False),
Resize(800, 1333), Transpose((2, 0, 1)), PadStride(32)
])
postprocess = RCNNPostprocess("label_list.txt", "output")
client = Client()
client.load_client_config("serving_client/serving_client_conf.prototxt")
client.connect(['127.0.0.1:9292'])
im = preprocess('000000570688.jpg')
fetch_map = client.predict(
feed={
"image": im,
"im_info": np.array(list(im.shape[1:]) + [1.0]),
"im_shape": np.array(list(im.shape[1:]) + [1.0])
},
fetch=["multiclass_nms_0.tmp_0"],
batch=True)
fetch_map["image"] = '000000570688.jpg'
print(fetch_map)
postprocess(fetch_map)
print(fetch_map)
...@@ -29,13 +29,11 @@ from paddle_serving_app.reader import ChineseBertReader ...@@ -29,13 +29,11 @@ from paddle_serving_app.reader import ChineseBertReader
from paddle_serving_app.reader import * from paddle_serving_app.reader import *
import numpy as np import numpy as np
args = benchmark_args() args = benchmark_args()
def single_func(idx, resource): def single_func(idx, resource):
img="./000000570688.jpg" img = "./000000570688.jpg"
profile_flags = False profile_flags = False
latency_flags = False latency_flags = False
if os.getenv("FLAGS_profile_client"): if os.getenv("FLAGS_profile_client"):
...@@ -67,9 +65,11 @@ def single_func(idx, resource): ...@@ -67,9 +65,11 @@ def single_func(idx, resource):
for bi in range(args.batch_size): for bi in range(args.batch_size):
print("1111batch") print("1111batch")
print(bi) print(bi)
feed_batch.append({"image": im, feed_batch.append({
"image": im,
"im_info": np.array(list(im.shape[1:]) + [1.0]), "im_info": np.array(list(im.shape[1:]) + [1.0]),
"im_shape": np.array(list(im.shape[1:]) + [1.0])}) "im_shape": np.array(list(im.shape[1:]) + [1.0])
})
# im = preprocess(img) # im = preprocess(img)
b_end = time.time() b_end = time.time()
...@@ -81,8 +81,7 @@ def single_func(idx, resource): ...@@ -81,8 +81,7 @@ def single_func(idx, resource):
int(round(b_end * 1000000)))) int(round(b_end * 1000000))))
#result = client.predict(feed=feed_batch, fetch=fetch) #result = client.predict(feed=feed_batch, fetch=fetch)
fetch_map = client.predict( fetch_map = client.predict(
feed=feed_batch, feed=feed_batch, fetch=["multiclass_nms"])
fetch=["multiclass_nms"])
fetch_map["image"] = img fetch_map["image"] = img
postprocess(fetch_map) postprocess(fetch_map)
...@@ -102,13 +101,12 @@ def single_func(idx, resource): ...@@ -102,13 +101,12 @@ def single_func(idx, resource):
if __name__ == '__main__': if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner() multi_thread_runner = MultiThreadRunner()
endpoint_list = [ endpoint_list = ["127.0.0.1:7777"]
"127.0.0.1:7777"
]
turns = 10 turns = 10
start = time.time() start = time.time()
result = multi_thread_runner.run( result = multi_thread_runner.run(
single_func, args.thread, {"endpoint": endpoint_list,"turns": turns}) single_func, args.thread, {"endpoint": endpoint_list,
"turns": turns})
end = time.time() end = time.time()
total_cost = end - start total_cost = end - start
......
background
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
...@@ -53,7 +53,9 @@ class OCRService(WebService): ...@@ -53,7 +53,9 @@ class OCRService(WebService):
self.ori_h, self.ori_w, _ = im.shape self.ori_h, self.ori_w, _ = im.shape
det_img = self.det_preprocess(im) det_img = self.det_preprocess(im)
_, self.new_h, self.new_w = det_img.shape _, self.new_h, self.new_w = det_img.shape
return {"image": det_img[np.newaxis, :].copy()}, ["concat_1.tmp_0"] return {
"image": det_img[np.newaxis, :].copy()
}, ["concat_1.tmp_0"], True
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
det_out = fetch_map["concat_1.tmp_0"] det_out = fetch_map["concat_1.tmp_0"]
......
...@@ -54,7 +54,7 @@ class OCRService(WebService): ...@@ -54,7 +54,7 @@ class OCRService(WebService):
det_img = self.det_preprocess(im) det_img = self.det_preprocess(im)
_, self.new_h, self.new_w = det_img.shape _, self.new_h, self.new_w = det_img.shape
print(det_img) print(det_img)
return {"image": det_img}, ["concat_1.tmp_0"] return {"image": det_img}, ["concat_1.tmp_0"], False
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
det_out = fetch_map["concat_1.tmp_0"] det_out = fetch_map["concat_1.tmp_0"]
......
...@@ -42,10 +42,9 @@ class OCRService(WebService): ...@@ -42,10 +42,9 @@ class OCRService(WebService):
self.det_client = LocalPredictor() self.det_client = LocalPredictor()
if sys.argv[1] == 'gpu': if sys.argv[1] == 'gpu':
self.det_client.load_model_config( self.det_client.load_model_config(
det_model_config, gpu=True, profile=False) det_model_config, use_gpu=True, gpu_id=1)
elif sys.argv[1] == 'cpu': elif sys.argv[1] == 'cpu':
self.det_client.load_model_config( self.det_client.load_model_config(det_model_config)
det_model_config, gpu=False, profile=False)
self.ocr_reader = OCRReader() self.ocr_reader = OCRReader()
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
...@@ -58,7 +57,7 @@ class OCRService(WebService): ...@@ -58,7 +57,7 @@ class OCRService(WebService):
det_img = det_img[np.newaxis, :] det_img = det_img[np.newaxis, :]
det_img = det_img.copy() det_img = det_img.copy()
det_out = self.det_client.predict( det_out = self.det_client.predict(
feed={"image": det_img}, fetch=["concat_1.tmp_0"]) feed={"image": det_img}, fetch=["concat_1.tmp_0"], batch=True)
filter_func = FilterBoxes(10, 10) filter_func = FilterBoxes(10, 10)
post_func = DBPostProcess({ post_func = DBPostProcess({
"thresh": 0.3, "thresh": 0.3,
...@@ -91,7 +90,7 @@ class OCRService(WebService): ...@@ -91,7 +90,7 @@ class OCRService(WebService):
imgs[id] = norm_img imgs[id] = norm_img
feed = {"image": imgs.copy()} feed = {"image": imgs.copy()}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
return feed, fetch return feed, fetch, True
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True)
...@@ -107,7 +106,8 @@ ocr_service.load_model_config("ocr_rec_model") ...@@ -107,7 +106,8 @@ ocr_service.load_model_config("ocr_rec_model")
ocr_service.prepare_server(workdir="workdir", port=9292) ocr_service.prepare_server(workdir="workdir", port=9292)
ocr_service.init_det_debugger(det_model_config="ocr_det_model") ocr_service.init_det_debugger(det_model_config="ocr_det_model")
if sys.argv[1] == 'gpu': if sys.argv[1] == 'gpu':
ocr_service.run_debugger_service(gpu=True) ocr_service.set_gpus("2")
ocr_service.run_debugger_service()
elif sys.argv[1] == 'cpu': elif sys.argv[1] == 'cpu':
ocr_service.run_debugger_service() ocr_service.run_debugger_service()
ocr_service.run_web_service() ocr_service.run_web_service()
...@@ -36,4 +36,5 @@ for img_file in os.listdir(test_img_dir): ...@@ -36,4 +36,5 @@ for img_file in os.listdir(test_img_dir):
image = cv2_to_base64(image_data1) image = cv2_to_base64(image_data1)
data = {"feed": [{"image": image}], "fetch": ["res"]} data = {"feed": [{"image": image}], "fetch": ["res"]}
r = requests.post(url=url, headers=headers, data=json.dumps(data)) r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r)
print(r.json()) print(r.json())
...@@ -50,7 +50,7 @@ class OCRService(WebService): ...@@ -50,7 +50,7 @@ class OCRService(WebService):
ori_h, ori_w, _ = im.shape ori_h, ori_w, _ = im.shape
det_img = self.det_preprocess(im) det_img = self.det_preprocess(im)
det_out = self.det_client.predict( det_out = self.det_client.predict(
feed={"image": det_img}, fetch=["concat_1.tmp_0"]) feed={"image": det_img}, fetch=["concat_1.tmp_0"], batch=False)
_, new_h, new_w = det_img.shape _, new_h, new_w = det_img.shape
filter_func = FilterBoxes(10, 10) filter_func = FilterBoxes(10, 10)
post_func = DBPostProcess({ post_func = DBPostProcess({
...@@ -77,10 +77,10 @@ class OCRService(WebService): ...@@ -77,10 +77,10 @@ class OCRService(WebService):
max_wh_ratio = max(max_wh_ratio, wh_ratio) max_wh_ratio = max(max_wh_ratio, wh_ratio)
for img in img_list: for img in img_list:
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
feed = {"image": norm_img} feed_list.append(norm_img[np.newaxis, :])
feed_list.append(feed) feed_batch = {"image": np.concatenate(feed_list, axis=0)}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
return feed_list, fetch return feed_batch, fetch, True
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True)
......
...@@ -52,7 +52,7 @@ class OCRService(WebService): ...@@ -52,7 +52,7 @@ class OCRService(WebService):
imgs[i] = norm_img imgs[i] = norm_img
feed = {"image": imgs.copy()} feed = {"image": imgs.copy()}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
return feed, fetch return feed, fetch, True
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True)
......
...@@ -51,10 +51,17 @@ class OCRService(WebService): ...@@ -51,10 +51,17 @@ class OCRService(WebService):
max_wh_ratio = max(max_wh_ratio, wh_ratio) max_wh_ratio = max(max_wh_ratio, wh_ratio)
for img in img_list: for img in img_list:
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
feed = {"image": norm_img} #feed = {"image": norm_img}
feed_list.append(feed) feed_list.append(norm_img)
if len(feed_list) == 1:
feed_batch = {
"image": np.concatenate(
feed_list, axis=0)[np.newaxis, :]
}
else:
feed_batch = {"image": np.concatenate(feed_list, axis=0)}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
return feed_list, fetch return feed_batch, fetch, True
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True)
......
...@@ -189,7 +189,7 @@ class WebService(object): ...@@ -189,7 +189,7 @@ class WebService(object):
from paddle_serving_app.local_predict import LocalPredictor from paddle_serving_app.local_predict import LocalPredictor
self.client = LocalPredictor() self.client = LocalPredictor()
self.client.load_model_config( self.client.load_model_config(
"{}".format(self.model_config), gpu=False, profile=False) "{}".format(self.model_config), use_gpu=False)
def run_web_service(self): def run_web_service(self):
print("This API will be deprecated later. Please do not use it") print("This API will be deprecated later. Please do not use it")
......
...@@ -250,7 +250,7 @@ class WebService(object): ...@@ -250,7 +250,7 @@ class WebService(object):
from paddle_serving_app.local_predict import LocalPredictor from paddle_serving_app.local_predict import LocalPredictor
self.client = LocalPredictor() self.client = LocalPredictor()
self.client.load_model_config( self.client.load_model_config(
"{}".format(self.model_config), gpu=gpu, profile=False) "{}".format(self.model_config), use_gpu=True, gpu_id=self.gpus[0])
def run_web_service(self): def run_web_service(self):
print("This API will be deprecated later. Please do not use it") print("This API will be deprecated later. Please do not use it")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册