提交 b6f7530b 编写于 作者: B barrierye

Merge branch 'develop' of https://github.com/PaddlePaddle/Serving into pipeline-update

# Blazeface
## Get Model
```
python -m paddle_serving_app.package --get_model blazeface
tar -xzvf blazeface.tar.gz
```
## RPC Service
### Start Service
```
python -m paddle_serving_server.serve --model serving_server --port 9494
```
### Client Prediction
```
python test_client.py serving_client/serving_client_conf.prototxt test.jpg
```
the result is in `output` folder, including a json file and image file with bounding boxes.
......@@ -13,19 +13,26 @@
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
from paddle_serving_app.reader import *
import sys
import numpy as np
preprocess = Sequential([
File2Image(),
Normalize([104, 117, 123], [127.502231, 127.502231, 127.502231], False)
])
postprocess = BlazeFacePostprocess("label_list.txt", "output")
client = Client()
client.load_client_config("ocr_rec_client/serving_client_conf.prototxt")
client.connect(["127.0.0.1:9292"])
image_file_list = ["./test_rec.jpg"]
img = cv2.imread(image_file_list[0])
ocr_reader = OCRReader()
feed = {"image": ocr_reader.preprocess([img])}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
fetch_map = client.predict(feed=feed, fetch=fetch)
rec_res = ocr_reader.postprocess(fetch_map)
print(image_file_list[0])
print(rec_res[0][0])
client.load_client_config(sys.argv[1])
client.connect(['127.0.0.1:9494'])
im_0 = preprocess(sys.argv[2])
tmp = Transpose((2, 0, 1))
im = tmp(im_0)
fetch_map = client.predict(
feed={"image": im}, fetch=["detection_output_0.tmp_0"])
fetch_map["image"] = sys.argv[2]
fetch_map["im_shape"] = im_0.shape
postprocess(fetch_map)
......@@ -29,6 +29,7 @@ args = benchmark_args()
def single_func(idx, resource):
client = Client()
print([resource["endpoint"][idx % len(resource["endpoint"])]])
client.load_client_config('ctr_client_conf/serving_client_conf.prototxt')
client.connect(['127.0.0.1:9292'])
batch = 1
......@@ -40,27 +41,29 @@ def single_func(idx, resource):
]
reader = dataset.infer_reader(test_filelists[len(test_filelists) - 40:],
batch, buf_size)
args.batch_size = 1
if args.request == "rpc":
fetch = ["prob"]
print("Start Time")
start = time.time()
itr = 1000
for ei in range(itr):
if args.batch_size == 1:
data = reader().next()
feed_dict = {}
feed_dict['dense_input'] = data[0][0]
for i in range(1, 27):
feed_dict["embedding_{}.tmp_0".format(i - 1)] = data[0][i]
result = client.predict(feed=feed_dict, fetch=fetch)
if args.batch_size > 0:
feed_batch = []
for bi in range(args.batch_size):
data = reader().next()
feed_dict = {}
feed_dict['dense_input'] = data[0][0]
for i in range(1, 27):
feed_dict["embedding_{}.tmp_0".format(i - 1)] = data[0][
i]
feed_batch.append(feed_dict)
result = client.predict(feed=feed_batch, fetch=fetch)
else:
print("unsupport batch size {}".format(args.batch_size))
elif args.request == "http":
raise ("Not support http service.")
end = time.time()
qps = itr / (end - start)
qps = itr * args.batch_size / (end - start)
return [[end - start, qps]]
......@@ -70,6 +73,7 @@ if __name__ == '__main__':
#result = single_func(0, {"endpoint": endpoint_list})
result = multi_thread_runner.run(single_func, args.thread,
{"endpoint": endpoint_list})
print(result)
avg_cost = 0
qps = 0
for i in range(args.thread):
......
rm profile_log
batch_size=1
export FLAGS_profile_client=1
export FLAGS_profile_server=1
for thread_num in 1 2 4 8 16
do
$PYTHONROOT/bin/python benchmark.py --thread $thread_num --model ctr_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1
for batch_size in 1 4 16 64 256
do
$PYTHONROOT/bin/python benchmark.py --thread $thread_num --batch_size $batch_size --model serving_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1
echo "batch size : $batch_size"
echo "thread num : $thread_num"
echo "========================================"
echo "batch size : $batch_size" >> profile_log
$PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log
tail -n 2 profile >> profile_log
done
done
# -*- coding: utf-8 -*-
#
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_client import Client
import sys
import os
import criteo as criteo
import time
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args
from paddle_serving_client.metric import auc
args = benchmark_args()
def single_func(idx, resource):
client = Client()
print([resource["endpoint"][idx % len(resource["endpoint"])]])
client.load_client_config('ctr_client_conf/serving_client_conf.prototxt')
client.connect(['127.0.0.1:9292'])
batch = 1
buf_size = 100
dataset = criteo.CriteoDataset()
dataset.setup(1000001)
test_filelists = [
"./raw_data/part-%d" % x for x in range(len(os.listdir("./raw_data")))
]
reader = dataset.infer_reader(test_filelists[len(test_filelists) - 40:],
batch, buf_size)
if args.request == "rpc":
fetch = ["prob"]
start = time.time()
itr = 1000
for ei in range(itr):
if args.batch_size > 1:
feed_batch = []
for bi in range(args.batch_size):
data = reader().next()
feed_dict = {}
feed_dict['dense_input'] = data[0][0]
for i in range(1, 27):
feed_dict["embedding_{}.tmp_0".format(i - 1)] = data[0][
i]
feed_batch.append(feed_dict)
result = client.predict(feed=feed_batch, fetch=fetch)
else:
print("unsupport batch size {}".format(args.batch_size))
elif args.request == "http":
raise ("Not support http service.")
end = time.time()
qps = itr * args.batch_size / (end - start)
return [[end - start, qps]]
if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner()
endpoint_list = ["127.0.0.1:9292"]
#result = single_func(0, {"endpoint": endpoint_list})
result = multi_thread_runner.run(single_func, args.thread,
{"endpoint": endpoint_list})
print(result)
avg_cost = 0
qps = 0
for i in range(args.thread):
avg_cost += result[0][i * 2 + 0]
qps += result[0][i * 2 + 1]
avg_cost = avg_cost / args.thread
print("average total cost {} s.".format(avg_cost))
print("qps {} ins/s".format(qps))
rm profile_log
for thread_num in 1 2 4 8 16
do
for batch_size in 1 2 4 8 16 32 64 128 256 512
do
$PYTHONROOT/bin/python benchmark_batch.py --thread $thread_num --batch_size $batch_size --model serving_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1
echo "========================================"
echo "batch size : $batch_size" >> profile_log
$PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log
tail -n 2 profile >> profile_log
done
done
......@@ -16,7 +16,5 @@
mkdir -p cube_model
mkdir -p cube/data
./seq_generator ctr_serving_model/SparseFeatFactors ./cube_model/feature
./cube/cube-builder -dict_name=test_dict -job_mode=base -last_version=0 -cur_version=0 -depend_version=0 -input_path=./cube_model -output_path=${PWD}/cube/data -shard_num=1 -only_build=false
mv ./cube/data/0_0/test_dict_part0/* ./cube/data/
cd cube && ./cube
cd cube && ./cube
......@@ -4,18 +4,42 @@
```
python -m paddle_serving_app.package --get_model ocr_rec
tar -xzvf ocr_rec.tar.gz
python -m paddle_serving_app.package --get_model ocr_det
tar -xzvf ocr_det.tar.gz
```
## RPC Service
### Start Service
For the following two code block, please check your devices and pick one
for GPU device
```
python -m paddle_serving_server_gpu.serve --model ocr_rec_model --port 9292 --gpu_id 0
python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0
```
for CPU device
```
python -m paddle_serving_server.serve --model ocr_rec_model --port 9292
python -m paddle_serving_server.serve --model ocr_det_model --port 9293
```
### Client Prediction
```
python test_ocr_rec_client.py
python ocr_rpc_client.py
```
## Web Service
### Start Service
```
python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0
python ocr_web_server.py
```
### Client Prediction
```
sh ocr_web_client.sh
```
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
import sys
import numpy as np
import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, File2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes
import time
import re
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
def get_rotate_crop_image(img, points):
#img = cv2.imread(img)
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
img_crop_width = int(np.linalg.norm(points[0] - points[1]))
img_crop_height = int(np.linalg.norm(points[0] - points[3]))
pts_std = np.float32([[0, 0], [img_crop_width, 0], \
[img_crop_width, img_crop_height], [0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img_crop,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
def read_det_box_file(filename):
with open(filename, 'r') as f:
line = f.readline()
a, b, c = int(line.split(' ')[0]), int(line.split(' ')[1]), int(
line.split(' ')[2])
dt_boxes = np.zeros((a, b, c)).astype(np.float32)
line = f.readline()
for i in range(a):
for j in range(b):
line = f.readline()
dt_boxes[i, j, 0], dt_boxes[i, j, 1] = float(
line.split(' ')[0]), float(line.split(' ')[1])
line = f.readline()
def resize_norm_img(img, max_wh_ratio):
import math
imgC, imgH, imgW = 3, 32, 320
imgW = int(32 * max_wh_ratio)
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def main():
client1 = Client()
client1.load_client_config("ocr_det_client/serving_client_conf.prototxt")
client1.connect(["127.0.0.1:9293"])
client2 = Client()
client2.load_client_config("ocr_rec_client/serving_client_conf.prototxt")
client2.connect(["127.0.0.1:9292"])
read_image_file = File2Image()
preprocess = Sequential([
ResizeByFactor(32, 960), Div(255),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
(2, 0, 1))
])
post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
})
filter_func = FilterBoxes(10, 10)
ocr_reader = OCRReader()
files = [
"./imgs/{}".format(f) for f in os.listdir('./imgs')
if re.match(r'[0-9]+.*\.jpg|[0-9]+.*\.png', f)
]
#files = ["2.jpg"]*30
#files = ["rctw/rctw/train/images/image_{}.jpg".format(i) for i in range(500)]
time_all = 0
time_det_all = 0
time_rec_all = 0
for name in files:
#print(name)
im = read_image_file(name)
ori_h, ori_w, _ = im.shape
time1 = time.time()
img = preprocess(im)
_, new_h, new_w = img.shape
ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w]
#print(new_h, new_w, ori_h, ori_w)
time_before_det = time.time()
outputs = client1.predict(feed={"image": img}, fetch=["concat_1.tmp_0"])
time_after_det = time.time()
time_det_all += (time_after_det - time_before_det)
#print(outputs)
dt_boxes_list = post_func(outputs["concat_1.tmp_0"], [ratio_list])
dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w])
dt_boxes = sorted_boxes(dt_boxes)
feed_list = []
img_list = []
max_wh_ratio = 0
for i, dtbox in enumerate(dt_boxes):
boximg = get_rotate_crop_image(im, dt_boxes[i])
img_list.append(boximg)
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for img in img_list:
norm_img = resize_norm_img(img, max_wh_ratio)
#norm_img = norm_img[np.newaxis, :]
feed = {"image": norm_img}
feed_list.append(feed)
#fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
fetch = ["ctc_greedy_decoder_0.tmp_0"]
time_before_rec = time.time()
if len(feed_list) == 0:
continue
fetch_map = client2.predict(feed=feed_list, fetch=fetch)
time_after_rec = time.time()
time_rec_all += (time_after_rec - time_before_rec)
rec_res = ocr_reader.postprocess(fetch_map)
#for res in rec_res:
# print(res[0].encode("utf-8"))
time2 = time.time()
time_all += (time2 - time1)
print("rpc+det time: {}".format(time_all / len(files)))
if __name__ == '__main__':
main()
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"image": "https://paddle-serving.bj.bcebos.com/others/1.jpg"}], "fetch": ["res"]}' http://127.0.0.1:9292/ocr/prediction
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
import sys
import numpy as np
import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes
from paddle_serving_server_gpu.web_service import WebService
import time
import re
class OCRService(WebService):
def init_det_client(self, det_port, det_client_config):
self.det_preprocess = Sequential([
ResizeByFactor(32, 960), Div(255),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
(2, 0, 1))
])
self.det_client = Client()
self.det_client.load_client_config(det_client_config)
self.det_client.connect(["127.0.0.1:{}".format(det_port)])
def preprocess(self, feed=[], fetch=[]):
img_url = feed[0]["image"]
#print(feed, img_url)
read_from_url = URL2Image()
im = read_from_url(img_url)
ori_h, ori_w, _ = im.shape
det_img = self.det_preprocess(im)
#print("det_img", det_img, det_img.shape)
det_out = self.det_client.predict(
feed={"image": det_img}, fetch=["concat_1.tmp_0"])
#print("det_out", det_out)
def sorted_boxes(dt_boxes):
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
def get_rotate_crop_image(img, points):
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
img_crop_width = int(np.linalg.norm(points[0] - points[1]))
img_crop_height = int(np.linalg.norm(points[0] - points[3]))
pts_std = np.float32([[0, 0], [img_crop_width, 0], \
[img_crop_width, img_crop_height], [0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img_crop,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
def resize_norm_img(img, max_wh_ratio):
import math
imgC, imgH, imgW = 3, 32, 320
imgW = int(32 * max_wh_ratio)
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
_, new_h, new_w = det_img.shape
filter_func = FilterBoxes(10, 10)
post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
})
ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w]
dt_boxes_list = post_func(det_out["concat_1.tmp_0"], [ratio_list])
dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w])
dt_boxes = sorted_boxes(dt_boxes)
feed_list = []
img_list = []
max_wh_ratio = 0
for i, dtbox in enumerate(dt_boxes):
boximg = get_rotate_crop_image(im, dt_boxes[i])
img_list.append(boximg)
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for img in img_list:
norm_img = resize_norm_img(img, max_wh_ratio)
feed = {"image": norm_img}
feed_list.append(feed)
fetch = ["ctc_greedy_decoder_0.tmp_0"]
#print("feed_list", feed_list)
return feed_list, fetch
def postprocess(self, feed={}, fetch=[], fetch_map=None):
#print(fetch_map)
ocr_reader = OCRReader()
rec_res = ocr_reader.postprocess(fetch_map)
res_lst = []
for res in rec_res:
res_lst.append(res[0])
fetch_map["res"] = res_lst
del fetch_map["ctc_greedy_decoder_0.tmp_0"]
del fetch_map["ctc_greedy_decoder_0.tmp_0.lod"]
return fetch_map
ocr_service = OCRService(name="ocr")
ocr_service.load_model_config("ocr_rec_model")
ocr_service.prepare_server(workdir="workdir", port=9292)
ocr_service.init_det_client(
det_port=9293,
det_client_config="ocr_det_client/serving_client_conf.prototxt")
ocr_service.run_rpc_service()
ocr_service.run_web_service()
......@@ -24,14 +24,15 @@ class ServingModels(object):
"SentimentAnalysis"] = ["senta_bilstm", "senta_bow", "senta_cnn"]
self.model_dict["SemanticRepresentation"] = ["ernie"]
self.model_dict["ChineseWordSegmentation"] = ["lac"]
self.model_dict["ObjectDetection"] = ["faster_rcnn", "yolov4"]
self.model_dict[
"ObjectDetection"] = ["faster_rcnn", "yolov4", "blazeface"]
self.model_dict["ImageSegmentation"] = [
"unet", "deeplabv3", "deeplabv3+cityscapes"
]
self.model_dict["ImageClassification"] = [
"resnet_v2_50_imagenet", "mobilenet_v2_imagenet"
]
self.model_dict["TextDetection"] = ["ocr_detection"]
self.model_dict["TextDetection"] = ["ocr_det"]
self.model_dict["OCR"] = ["ocr_rec"]
image_class_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ImageClassification/"
......
......@@ -29,6 +29,7 @@ def normalize(img, mean, std, channel_first):
else:
img_mean = np.array(mean).reshape((1, 1, 3))
img_std = np.array(std).reshape((1, 1, 3))
img = np.array(img).astype("float32")
img -= img_mean
img /= img_std
return img
......
......@@ -440,6 +440,30 @@ class RCNNPostprocess(object):
self.label_file, self.output_dir)
class BlazeFacePostprocess(RCNNPostprocess):
def clip_bbox(self, bbox, im_size=None):
h = 1. if im_size is None else im_size[0]
w = 1. if im_size is None else im_size[1]
xmin = max(min(bbox[0], w), 0.)
ymin = max(min(bbox[1], h), 0.)
xmax = max(min(bbox[2], w), 0.)
ymax = max(min(bbox[3], h), 0.)
return xmin, ymin, xmax, ymax
def _get_bbox_result(self, fetch_map, fetch_name, clsid2catid):
result = {}
is_bbox_normalized = True #for blaze face, set true here
output = fetch_map[fetch_name]
lod = [fetch_map[fetch_name + '.lod']]
lengths = self._offset_to_lengths(lod)
np_data = np.array(output)
result['bbox'] = (np_data, lengths)
result['im_id'] = np.array([[0]])
result["im_shape"] = np.array(fetch_map["im_shape"]).astype(np.int32)
bbox_results = self._bbox2out([result], clsid2catid, is_bbox_normalized)
return bbox_results
class Sequential(object):
"""
Args:
......
......@@ -182,22 +182,26 @@ class OCRReader(object):
return norm_img_batch[0]
def postprocess(self, outputs):
def postprocess(self, outputs, with_score=False):
rec_res = []
rec_idx_lod = outputs["ctc_greedy_decoder_0.tmp_0.lod"]
predict_lod = outputs["softmax_0.tmp_0.lod"]
rec_idx_batch = outputs["ctc_greedy_decoder_0.tmp_0"]
if with_score:
predict_lod = outputs["softmax_0.tmp_0.lod"]
for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno]
end = rec_idx_lod[rno + 1]
rec_idx_tmp = rec_idx_batch[beg:end, 0]
preds_text = self.char_ops.decode(rec_idx_tmp)
beg = predict_lod[rno]
end = predict_lod[rno + 1]
probs = outputs["softmax_0.tmp_0"][beg:end, :]
ind = np.argmax(probs, axis=1)
blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0]
score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res.append([preds_text, score])
if with_score:
beg = predict_lod[rno]
end = predict_lod[rno + 1]
probs = outputs["softmax_0.tmp_0"][beg:end, :]
ind = np.argmax(probs, axis=1)
blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0]
score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res.append([preds_text, score])
else:
rec_res.append([preds_text])
return rec_res
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册