From 13ed5e22e88f05c68752cb4e3761119e0db6558d Mon Sep 17 00:00:00 2001
From: breezedeus
Date: Fri, 29 May 2020 17:14:12 +0800
Subject: [PATCH] add `set_cand_alphabet`

---
 cnocr/cn_ocr.py     | 18 ++++++++++++++----
 tests/test_cnocr.py | 20 +++++++++++++++++---
 2 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/cnocr/cn_ocr.py b/cnocr/cn_ocr.py
index e340ca6..b7c820e 100644
--- a/cnocr/cn_ocr.py
+++ b/cnocr/cn_ocr.py
@@ -160,14 +160,12 @@ class CnOcr(object):
         root = os.path.join(root, MODEL_VERSION)
         self._model_dir = os.path.join(root, self._model_name)
         self._assert_and_prepare_model_files()
-        self._alphabet, inv_alph_dict = read_charset(
+        self._alphabet, self._inv_alph_dict = read_charset(
             os.path.join(self._model_dir, 'label_cn.txt')
         )
 
         self._cand_alph_idx = None
-        if cand_alphabet is not None:
-            self._cand_alph_idx = [0] + [inv_alph_dict[word] for word in cand_alphabet]
-            self._cand_alph_idx.sort()
+        self.set_cand_alphabet(cand_alphabet)
 
         self._hp = Hyperparams()
         self._hp._loss_type = None  # infer mode
@@ -214,6 +212,18 @@ class CnOcr(object):
         )
         return mod
 
+    def set_cand_alphabet(self, cand_alphabet):
+        """
+        设置待识别字符的候选集合。
+        :param cand_alphabet: 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围
+        :return: None
+        """
+        if cand_alphabet is None:
+            self._cand_alph_idx = None
+        else:
+            self._cand_alph_idx = [0] + [self._inv_alph_dict[word] for word in cand_alphabet]
+            self._cand_alph_idx.sort()
+
     def ocr(self, img_fp):
         """
         :param img_fp: image file path; or color image mx.nd.NDArray or np.ndarray,
diff --git a/tests/test_cnocr.py b/tests/test_cnocr.py
index 8349e58..5fd9f50 100644
--- a/tests/test_cnocr.py
+++ b/tests/test_cnocr.py
@@ -177,9 +177,7 @@ def test_gray_aug(img_fp, expected):
     print(res_img.shape, res_img.dtype)
 
 
-def test_cand_alphabet():
-    from cnocr import NUMBERS
-
+def test_cand_alphabet1():
     img_fp = os.path.join(example_dir, 'hybrid.png')
 
     ocr = CnOcr(name='instance1')
@@ -195,6 +193,22 @@ def test_cand_alphabet():
     assert len(pred) == 1 and pred[0] == '012345678'
 
 
+def test_cand_alphabet2():
+    img_fp = os.path.join(example_dir, 'hybrid.png')
+
+    ocr = CnOcr(name='instance1')
+    pred = ocr.ocr(img_fp)
+    pred = [''.join(line_p) for line_p in pred]
+    print("Predicted Chars:", pred)
+    assert len(pred) == 1 and pred[0] == 'o12345678'
+
+    ocr.set_cand_alphabet(NUMBERS)
+    pred = ocr.ocr(img_fp)
+    pred = [''.join(line_p) for line_p in pred]
+    print("Predicted Chars:", pred)
+    assert len(pred) == 1 and pred[0] == '012345678'
+
+
 INSTANCE_ID = 0
 
 
-- 
GitLab