提交 e9adf01d 编写于 作者: B breezedeus

use MODEL_VERSION instead of __version__

上级 90076695
from .cn_ocr import CnOcr
\ No newline at end of file
from .cn_ocr import CnOcr
from .consts import MODEL_VERSION, AVAILABLE_MODELS, NUMBERS, ENG_LETTERS
......@@ -16,12 +16,12 @@
# specific language governing permissions and limitations
# under the License.
import os
import logging
import mxnet as mx
import numpy as np
from PIL import Image
from cnocr.__version__ import __version__
from cnocr.consts import AVAILABLE_MODELS
from cnocr import MODEL_VERSION, AVAILABLE_MODELS
from cnocr.hyperparams.cn_hyperparams import CnHyperparams as Hyperparams
from cnocr.fit.lstm import init_states
from cnocr.fit.ctc_metrics import CtcMetrics
......@@ -37,6 +37,9 @@ from cnocr.utils import (
from cnocr.line_split import line_split
logger = logging.getLogger(__name__)
def read_ocr_img(path):
"""
:param path: image file path
......@@ -102,7 +105,7 @@ def load_module(prefix, epoch, data_names, data_shapes, network=None):
class CnOcr(object):
MODEL_FILE_PREFIX = 'cnocr-v{}'.format(__version__)
MODEL_FILE_PREFIX = 'cnocr-v{}'.format(MODEL_VERSION)
def __init__(
self,
......@@ -125,7 +128,7 @@ class CnOcr(object):
self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX, model_name)
self._model_epoch = model_epoch or AVAILABLE_MODELS[model_name][0]
root = os.path.join(root, __version__)
root = os.path.join(root, MODEL_VERSION)
self._model_dir = os.path.join(root, self._model_name)
self._assert_and_prepare_model_files()
self._alphabet, inv_alph_dict = read_charset(
......@@ -167,7 +170,7 @@ class CnOcr(object):
prefix = os.path.join(self._model_dir, self._model_file_prefix)
data_names = ['data']
data_shapes = [(data_names[0], (hp.batch_size, 1, hp.img_height, hp.img_width))]
print('loading model parameters from dir %s' % self._model_dir)
logger.info('loading model parameters from dir %s' % self._model_dir)
mod = load_module(
prefix, self._model_epoch, data_names, data_shapes, network=network
)
......
# coding: utf-8
import os
import string
from .__version__ import __version__
# 模型版本只对应到第二层,第三层的改动表示模型兼容。
# 如: __version__ = '1.2.*',对应的 MODEL_VERSION 都是 '1.2.0'
MODEL_VERSION = '.'.join(__version__.split('.', maxsplit=2)[:2]) + '.0'
EMB_MODEL_TYPES = ['conv', 'conv-lite', 'densenet', 'densenet-lite']
SEQ_MODEL_TYPES = ['lstm', 'gru', 'fc']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册