提交 56367c7f 编写于 作者: B breezedeus

add `check_context`

上级 2d92c076
......@@ -4,7 +4,7 @@ REC_DATA_ROOT_DIR = data/sample-data-lst
# `EMB_MODEL_TYPE` 可取值:['conv', 'conv-lite-rnn', 'densenet', 'densenet-lite']
EMB_MODEL_TYPE = densenet-lite
# `SEQ_MODEL_TYPE` 可取值:['lstm', 'gru', 'fc']
SEQ_MODEL_TYPE = fc
SEQ_MODEL_TYPE = gru
MODEL_NAME = $(EMB_MODEL_TYPE)-$(SEQ_MODEL_TYPE)
# 产生 *.lst 文件
......
......@@ -33,6 +33,7 @@ from cnocr.utils import (
read_charset,
normalize_img_array,
check_model_name,
check_context,
)
from cnocr.line_split import line_split
......@@ -96,10 +97,10 @@ def load_module(prefix, epoch, data_names, data_shapes, network=None, context='c
pred_fc = sym.get_internals()['pred_fc_output']
sym = mx.sym.softmax(data=pred_fc)
if not check_context(context):
raise NotImplementedError('illegal value %s for parameter context' % context)
if isinstance(context, str):
context = mx.gpu() if context.lower() == 'gpu' else mx.cpu()
elif not isinstance(context, mx.Context):
raise NotImplementedError('illegal value %s for parameter context' % context)
mod = mx.mod.Module(
symbol=sym, context=context, data_names=data_names, label_names=None
......
......@@ -64,6 +64,16 @@ def gen_context(num_gpu):
return context
def check_context(context):
if isinstance(context, str):
return context.lower() in ('gpu', 'cpu')
if isinstance(context, list):
if len(context) < 1:
return False
return all(isinstance(ctx, mx.Context) for ctx in context)
return isinstance(context, mx.Context)
def data_dir_default():
"""
......
from mxnet.gluon.utils import download
def test_download():
url = 'https://www.dropbox.com/s/5n09nxf4x95jprk/cnocr-models-v0.1.0.zip?dl=1'
download(url, './cnocr-models.zip', overwrite=True)
\ No newline at end of file
import os
import sys
import pytest
import mxnet as mx
from mxnet.gluon.utils import download
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(1, os.path.dirname(os.path.abspath(__file__)))
from cnocr.utils import check_context
@pytest.mark.skip()
def test_download():
url = 'https://www.dropbox.com/s/5n09nxf4x95jprk/cnocr-models-v0.1.0.zip?dl=1'
download(url, './cnocr-models.zip', overwrite=True)
@pytest.mark.parametrize('context, expected', [
('gpu', True),
('cpu', True),
('', False),
('xx', False),
(mx.cpu(), True),
(mx.gpu(), True),
([mx.cpu()], True),
([mx.gpu()], True),
([mx.gpu(0), mx.gpu(1)], True),
([], False),
])
def test_check_context(context, expected):
assert check_context(context) == expected
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册