未验证 提交 a9d70931 编写于 作者: S shaohua.zhang 提交者: GitHub

fix some errors and bugs (#1185)

fix some errors in pip
上级 03616470
...@@ -87,8 +87,8 @@ def download_with_progressbar(url, save_path): ...@@ -87,8 +87,8 @@ def download_with_progressbar(url, save_path):
progress_bar.update(len(data)) progress_bar.update(len(data))
file.write(data) file.write(data)
progress_bar.close() progress_bar.close()
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
logger.error("ERROR, something went wrong") logger.error("Something went wrong while downloading models")
sys.exit(0) sys.exit(0)
...@@ -157,7 +157,6 @@ def parse_args(): ...@@ -157,7 +157,6 @@ def parse_args():
parser.add_argument("--use_space_char", type=bool, default=True) parser.add_argument("--use_space_char", type=bool, default=True)
# params for text classifier # params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
parser.add_argument("--cls_model_dir", type=str, default=None) parser.add_argument("--cls_model_dir", type=str, default=None)
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
parser.add_argument("--label_list", type=list, default=['0', '180']) parser.add_argument("--label_list", type=list, default=['0', '180'])
...@@ -171,7 +170,7 @@ def parse_args(): ...@@ -171,7 +170,7 @@ def parse_args():
parser.add_argument("--lang", type=str, default='ch') parser.add_argument("--lang", type=str, default='ch')
parser.add_argument("--det", type=str2bool, default=True) parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True) parser.add_argument("--rec", type=str2bool, default=True)
parser.add_argument("--cls", type=str2bool, default=False) parser.add_argument("--use_angle_cls", type=str2bool, default=True)
return parser.parse_args() return parser.parse_args()
...@@ -206,7 +205,6 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -206,7 +205,6 @@ class PaddleOCR(predict_system.TextSystem):
maybe_download(postprocess_params.det_model_dir, model_urls['det']) maybe_download(postprocess_params.det_model_dir, model_urls['det'])
maybe_download(postprocess_params.rec_model_dir, maybe_download(postprocess_params.rec_model_dir,
model_urls['rec'][lang]['url']) model_urls['rec'][lang]['url'])
if self.use_angle_cls:
maybe_download(postprocess_params.cls_model_dir, model_urls['cls']) maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL: if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
...@@ -231,9 +229,6 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -231,9 +229,6 @@ class PaddleOCR(predict_system.TextSystem):
rec: use text recognition or not, if false, only det will be exec. default is True rec: use text recognition or not, if false, only det will be exec. default is True
""" """
assert isinstance(img, (np.ndarray, list, str)) assert isinstance(img, (np.ndarray, list, str))
if cls and not self.use_angle_cls:
print('cls should be false when use_angle_cls is false')
exit(-1)
self.use_angle_cls = cls self.use_angle_cls = cls
if isinstance(img, str): if isinstance(img, str):
image_file = img image_file = img
...@@ -275,6 +270,7 @@ def main(): ...@@ -275,6 +270,7 @@ def main():
result = ocr_engine.ocr(img_path, result = ocr_engine.ocr(img_path,
det=args.det, det=args.det,
rec=args.rec, rec=args.rec,
cls=args.cls) cls=args.use_angle_cls)
if result is not None:
for line in result: for line in result:
print(line) print(line)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册