提交 826a92d8 编写于 作者: B breezedeus

use param `gpu` at cmd `cnocr_evaluate`

上级 8f41bdc4
......@@ -2,6 +2,8 @@ import logging
import os
import mxnet as mx
from ..utils import gen_context
def _load_model(args):
if 'load_epoch' not in args or args.load_epoch is None:
......@@ -16,10 +18,7 @@ def _load_model(args):
def fit(network, data_train, data_val, metrics, args, hp, data_names=None):
if args.gpu > 0:
contexts = [mx.context.gpu(i) for i in range(args.gpu)]
else:
contexts = [mx.context.cpu()]
context = gen_context(args.gpu)
logging.info('hp: %s', hp)
sym, arg_params, aux_params = _load_model(args)
......@@ -32,7 +31,7 @@ def fit(network, data_train, data_val, metrics, args, hp, data_names=None):
symbol=network,
data_names=["data"] if data_names is None else data_names,
label_names=['label'],
context=contexts,
context=context,
)
begin_epoch = args.load_epoch if args.load_epoch else 0
......
......@@ -20,6 +20,7 @@ from pathlib import Path
import logging
import platform
import zipfile
import mxnet as mx
from mxnet.gluon.utils import download
from .consts import AVAILABLE_MODELS, EMB_MODEL_TYPES, SEQ_MODEL_TYPES
......@@ -55,6 +56,14 @@ def set_logger(log_file=None, log_level=logging.INFO, log_file_level=logging.NOT
return logger
def gen_context(num_gpu):
if num_gpu > 0:
context = [mx.context.gpu(i) for i in range(num_gpu)]
else:
context = [mx.context.cpu()]
return context
def data_dir_default():
"""
......
......@@ -33,7 +33,7 @@ import Levenshtein
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from cnocr import CnOcr
from cnocr.utils import set_logger
from cnocr.utils import set_logger, gen_context
logger = set_logger(log_level=logging.INFO)
......@@ -46,11 +46,10 @@ def evaluate():
)
parser.add_argument("--model-epoch", type=int, default=None, help="model epoch")
parser.add_argument(
"--context",
type=str,
default='cpu',
choices=['cpu', 'gpu'],
help="which context to run inferences",
"--gpu",
help="Number of GPUs for training [Default 0, means using cpu]",
type=int,
default=0,
)
parser.add_argument(
"-i",
......@@ -76,9 +75,10 @@ def evaluate():
help="the output directory which records the analysis results",
)
args = parser.parse_args()
context = gen_context(args.gpu)
ocr = CnOcr(
model_name=args.model_name, model_epoch=args.model_epoch, context=args.context
model_name=args.model_name, model_epoch=args.model_epoch, context=context
)
alphabet = ocr._alphabet
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册