提交 16752dd7 编写于 作者: T tink2123

modified default shape

上级 8f7518cd
...@@ -131,7 +131,7 @@ class TextRecognizer(object): ...@@ -131,7 +131,7 @@ class TextRecognizer(object):
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image padding_im[:, :, 0:resized_w] = resized_image
return padding_im return padding_im
def resize_norm_img_svtr(self, img, image_shape): def resize_norm_img_svtr(self, img, image_shape):
imgC, imgH, imgW = image_shape imgC, imgH, imgW = image_shape
...@@ -274,7 +274,7 @@ class TextRecognizer(object): ...@@ -274,7 +274,7 @@ class TextRecognizer(object):
wh_ratio = w * 1.0 / h wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio) max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no): for ino in range(beg_img_no, end_img_no):
if self.rec_algorithm == "SAR": if self.rec_algorithm == "SAR":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar( norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
img_list[indices[ino]], self.rec_image_shape) img_list[indices[ino]], self.rec_image_shape)
...@@ -296,8 +296,8 @@ class TextRecognizer(object): ...@@ -296,8 +296,8 @@ class TextRecognizer(object):
gsrm_slf_attn_bias2_list.append(norm_img[4]) gsrm_slf_attn_bias2_list.append(norm_img[4])
norm_img_batch.append(norm_img[0]) norm_img_batch.append(norm_img[0])
elif self.rec_algorithm == "SVTR": elif self.rec_algorithm == "SVTR":
norm_img = self.resize_norm_img_svtr( norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
img_list[indices[ino]], self.rec_image_shape) self.rec_image_shape)
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img) norm_img_batch.append(norm_img)
else: else:
...@@ -405,9 +405,13 @@ def main(args): ...@@ -405,9 +405,13 @@ def main(args):
valid_image_file_list = [] valid_image_file_list = []
img_list = [] img_list = []
logger.info(
"In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
"if you are using an older PP-OCR, please set --rec_image_shape='3,32,320'"
)
# warmup 2 times # warmup 2 times
if args.warmup: if args.warmup:
img = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8) img = np.random.uniform(0, 255, [48, 320, 3]).astype(np.uint8)
for i in range(2): for i in range(2):
res = text_recognizer([img] * int(args.rec_batch_num)) res = text_recognizer([img] * int(args.rec_batch_num))
......
...@@ -59,7 +59,7 @@ class TextSystem(object): ...@@ -59,7 +59,7 @@ class TextSystem(object):
for bno in range(bbox_num): for bno in range(bbox_num):
cv2.imwrite( cv2.imwrite(
os.path.join(output_dir, os.path.join(output_dir,
f"mg_crop_{bno+self.crop_image_res_index}.jpg"), f"mg_crop_{bno+self.crop_image_res_index}.jpg "),
img_crop_list[bno]) img_crop_list[bno])
logger.debug(f"{bno}, {rec_res[bno]}") logger.debug(f"{bno}, {rec_res[bno]}")
self.crop_image_res_index += bbox_num self.crop_image_res_index += bbox_num
...@@ -133,6 +133,9 @@ def main(args): ...@@ -133,6 +133,9 @@ def main(args):
os.makedirs(draw_img_save_dir, exist_ok=True) os.makedirs(draw_img_save_dir, exist_ok=True)
save_results = [] save_results = []
logger.info("In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
"if you are using an older PP-OCR, please set --rec_image_shape='3,32,320'")
# warm up 10 times # warm up 10 times
if args.warmup: if args.warmup:
img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
......
...@@ -81,7 +81,7 @@ def init_args(): ...@@ -81,7 +81,7 @@ def init_args():
# params for text recognizer # params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
parser.add_argument("--rec_batch_num", type=int, default=6) parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25) parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument( parser.add_argument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册