Commit ac252a34 authored by Z zhiboniu, committed by zhiboniu

plate run ok

Parent 3f5ff9e6
This diff has been collapsed.
@@ -28,11 +28,10 @@ import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)
from utils import Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, create_inputs
from vechile_plateutils import create_predictor, get_infer_gpuid, argsparser, get_rotate_crop_image
from infer import get_test_images, print_arguments
from vechile_plateutils import create_predictor, get_infer_gpuid, argsparser, get_rotate_crop_image, draw_boxes
from vecplatepostprocess import build_post_process
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, Resize_Mult32
from preprocess import preprocess, NormalizeImage, Permute, Resize_Mult32
class PlateDetector(object):
@@ -93,7 +92,8 @@ class PlateDetector(object):
for im_path in image_list:
im, im_info = preprocess(im_path, preprocess_ops)
input_im_lst.append(im)
input_im_info_lst.append(im_info['im_shape'])
input_im_info_lst.append(im_info['im_shape'] /
im_info['scale_factor'])
return np.stack(input_im_lst, axis=0), input_im_info_lst
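In this preprocess convention, im_info['im_shape'] is the shape after resizing and im_info['scale_factor'] is the per-axis resize ratio, so dividing the two recovers the source image size that the detected boxes later need to be filtered against. A minimal sketch of that arithmetic, with made-up numbers:

import numpy as np

# Hypothetical numbers: a 720x1280 frame resized to 608x608 for detection.
im_shape = np.array([608.0, 608.0])                   # (h, w) after Resize
scale_factor = np.array([608.0 / 720, 608.0 / 1280])  # (h_ratio, w_ratio)

original_shape = im_shape / scale_factor              # -> [ 720. 1280.]
print(original_shape)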
@@ -136,18 +136,15 @@ class PlateDetector(object):
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def predict_image(self, img):
def predict_image(self, img_list):
st = time.time()
if self.args.run_benchmark:
self.autolog.times.start()
img, shape_list = self.preprocess(img)
img, shape_list = self.preprocess(img_list)
if img is None:
return None, 0
# img = np.expand_dims(img, axis=0)
# shape_list = np.expand_dims(shape_list, axis=0)
# img = img.copy()
if self.args.run_benchmark:
self.autolog.times.stamp()
@@ -166,27 +163,28 @@ class PlateDetector(object):
#self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points']
dt_boxes = self.filter_tag_det_res(dt_boxes, shape_list[0])
dt_batch_boxes = []
for idx in range(len(post_result)):
org_shape = img_list[idx].shape
dt_boxes = post_result[idx]['points']
dt_boxes = self.filter_tag_det_res(dt_boxes, org_shape)
dt_batch_boxes.append(dt_boxes)
if self.args.run_benchmark:
self.autolog.times.end(stamp=True)
et = time.time()
return dt_boxes, et - st
return dt_batch_boxes, et - st
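predict_image now post-processes every image in the batch against its own original shape and returns one box array per input image. The original shape matters because the detected quadrilaterals must stay inside the source frame; clip_boxes below is a hypothetical stand-in sketching that part of what filter_tag_det_res is expected to do (the real filter presumably also reorders points and drops degenerate boxes):

import numpy as np

def clip_boxes(dt_boxes, image_shape):
    # Keep every quadrilateral inside the original image bounds.
    # dt_boxes: (N, 4, 2) corner points; image_shape: the original img.shape.
    h, w = image_shape[:2]
    clipped = []
    for box in dt_boxes:
        box = np.array(box, dtype=np.float32)
        box[:, 0] = np.clip(box[:, 0], 0, w - 1)  # x coordinates
        box[:, 1] = np.clip(box[:, 1], 0, h - 1)  # y coordinates
        clipped.append(box)
    return np.array(clipped)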
class TextRecognizer(object):
def __init__(self,
FLAGS,
input_shape=[3, 48, 320],
batch_size=8,
rec_algorithm="SVTR",
word_dict_path="rec_word_dict.txt",
use_gpu=True,
benchmark=False):
self.rec_image_shape = input_shape
self.rec_batch_num = batch_size
self.rec_algorithm = rec_algorithm
def __init__(self, FLAGS, use_gpu=True, benchmark=False):
self.rec_image_shape = [
int(v) for v in FLAGS.rec_image_shape.split(",")
]
self.rec_batch_num = FLAGS.rec_batch_num
self.rec_algorithm = FLAGS.rec_algorithm
word_dict_path = FLAGS.word_dict_path
isuse_space_char = True
postprocess_params = {
@@ -398,7 +396,7 @@ class TextRecognizer(object):
return padding_im, resize_shape, pad_shape, valid_ratio
def __call__(self, img_list):
def predict_text(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
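Only the start of predict_text is shown here, but the aspect-ratio pass it opens with (width_list) is the usual trick for batched recognition: sorting crops by w/h keeps padding inside each batch small. A minimal sketch of that batching idea, assuming width_list holds one ratio per crop (batch_indices_by_ratio is a hypothetical helper):

import numpy as np

def batch_indices_by_ratio(width_list, batch_size=6):
    # Sort crops by aspect ratio, then cut the sorted order into batches,
    # so crops padded into the same batch have similar widths.
    order = np.argsort(np.array(width_list))
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]

# batch_indices_by_ratio([3.2, 1.1, 2.8, 4.0], batch_size=2)
# -> [array([1, 2]), array([0, 3])]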
@@ -549,28 +547,36 @@ class TextRecognizer(object):
class PlateRecognizer(object):
def __init__(self):
self.batch_size = 8
use_gpu = FLAGS.device.lower() == "gpu"
self.platedetector = PlateDetector(FLAGS)
self.textrecognizer = TextRecognizer(
FLAGS,
input_shape=[3, 48, 320],
batch_size=8,
rec_algorithm="SVTR",
word_dict_path="rec_word_dict.txt",
use_gpu=True,
benchmark=False)
FLAGS, use_gpu=use_gpu, benchmark=FLAGS.run_benchmark)
def get_platelicense(self, image_list):
plate_text_list = []
plateboxes, det_time = self.platedetector.predict_image(image_list)
for idx, boxes_pcar in enumerate(plateboxes):
plate_images = get_rotate_crop_image(image_list[idx], boxes_pcar)
print(plate_images.shape)
plate_texts = self.textrecognizer(plate_images)
plate_text_list.append(plate_texts)
import pdb
pdb.set_trace()
return results
for box in boxes_pcar:
plate_images = get_rotate_crop_image(image_list[idx], box)
plate_texts = self.textrecognizer.predict_text([plate_images])
plate_text_list.append(plate_texts)
print("plate text:{}".format(plate_texts))
newimg = draw_boxes(image_list[idx], boxes_pcar)
cv2.imwrite("vechile_plate.jpg", newimg)
return self.check_plate(plate_text_list)
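Each detected box is warped to an upright crop with get_rotate_crop_image before recognition; that helper lives in vechile_plateutils and is not shown in this diff, so perspective_crop below is only a hypothetical sketch of what such a perspective crop typically looks like:

import cv2
import numpy as np

def perspective_crop(img, points):
    # Warp a clockwise quadrilateral (4x2 float array) into an upright patch.
    points = np.array(points, dtype=np.float32)
    w = int(max(np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3])))
    h = int(max(np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2])))
    dst = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
    M = cv2.getPerspectiveTransform(points, dst)
    return cv2.warpPerspective(img, M, (w, h), flags=cv2.INTER_CUBIC)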
def check_plate(self, text_list):
simcode = [
'浙', '粤', '京', '津', '冀', '晋', '蒙', '辽', '黑', '沪', '吉', '苏', '皖',
'赣', '鲁', '豫', '鄂', '湘', '桂', '琼', '渝', '川', '贵', '云', '藏', '陕',
'甘', '青', '宁'
]
for text_info in text_list:
# import pdb;pdb.set_trace()
text = text_info[0][0][0]
if len(text) > 2 and text[0] in simcode and len(text) < 10:
print("text:{} length:{}".format(text, len(text)))
return text
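check_plate keeps the first recognised string whose leading character is a province abbreviation and whose length is plausible for a plate. The triple indexing text_info[0][0][0] suggests each element of text_list is the full return value of predict_text for one crop; that nesting is not shown in this diff, so the example below is an assumption for illustration only:

# Assumed recognizer output for one crop: (results, elapse), where
# results is a list of (text, confidence) pairs.
fake_output = ([("浙A12345", 0.98)], 0.012)

text = fake_output[0][0][0]   # -> "浙A12345"
print(text, len(text))        # length 7 passes the 2 < len(text) < 10 check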
def main():
@@ -581,17 +587,16 @@ def main():
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
for img in img_list:
image = cv2.imread(img)
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
results = detector.get_platelicense([image])
if not FLAGS.run_benchmark:
detector.det_times.info(average=True)
else:
if FLAGS.run_benchmark:
mems = {
'cpu_rss_mb': detector.cpu_mem / len(img_list),
'gpu_rss_mb': detector.gpu_mem / len(img_list),
'gpu_util': detector.gpu_util * 100 / len(img_list)
}
perf_info = detector.det_times.report(average=True)
perf_info = detector.self.autolog.times.report(average=True)
model_dir = FLAGS.model_dir
mode = FLAGS.run_mode
model_info = {
......
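In the benchmark branch of main(), detector.cpu_mem, gpu_mem and gpu_util are treated as running sums over the processed images and averaged before reporting; the snippet below only restates that accounting with made-up numbers:

# Hypothetical running sums collected over a 4-image benchmark run.
cpu_mem_sum, gpu_mem_sum, gpu_util_sum = 3200.0, 1800.0, 2.4
num_images = 4

mems = {
    'cpu_rss_mb': cpu_mem_sum / num_images,       # mean host RSS per image
    'gpu_rss_mb': gpu_mem_sum / num_images,       # mean device memory per image
    'gpu_util': gpu_util_sum * 100 / num_images,  # mean utilisation in percent
}
print(mems)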
@@ -40,6 +40,10 @@ def argsparser():
parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument(
"--word_dict_path",
type=str,
default="deploy/pphuman/rec_word_dict.txt")
parser.add_argument(
"--image_file", type=str, default=None, help="Path of image file.")
parser.add_argument(
......
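The new --word_dict_path flag lets the recognizer's character dictionary be overridden from the command line. A usage sketch, assuming argsparser() returns the configured argparse parser as the imports in the main script suggest (the image name below is a placeholder):

from vechile_plateutils import argsparser

FLAGS = argsparser().parse_args([
    "--image_file", "demo.jpg",
    "--rec_image_shape", "3, 48, 320",
    "--word_dict_path", "deploy/pphuman/rec_word_dict.txt",
])
print(FLAGS.word_dict_path)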