提交 8446df46 编写于 作者: B breezedeus

replace print with logger

上级 e9adf01d
......@@ -22,6 +22,7 @@ from __future__ import print_function
import sys
import os
import time
import logging
import argparse
from operator import itemgetter
from pathlib import Path
......@@ -32,6 +33,10 @@ import Levenshtein
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from cnocr import CnOcr
from cnocr.utils import set_logger
logger = set_logger(log_level=logging.INFO)
def evaluate():
......@@ -76,7 +81,7 @@ def evaluate():
bad_cnt = 0
badcases = []
while start_idx < len(fn_labels_list):
print('start_idx: ', start_idx)
logger.info('start_idx: %d', start_idx)
batch = fn_labels_list[start_idx : start_idx + args.batch_size]
batch_img_fns = []
batch_labels = []
......@@ -95,7 +100,7 @@ def evaluate():
batch_preds, batch_labels, batch_img_fns, alphabet
):
if args.verbose:
print('\t'.join(bad_info))
logger.info('\t'.join(bad_info))
distance = Levenshtein.distance(bad_info[1], bad_info[2])
bad_info.insert(0, distance)
badcases.append(bad_info)
......@@ -133,7 +138,7 @@ def evaluate():
for word, num in redundant_cnt.most_common():
f.write('\t'.join([word, str(num)]) + '\n')
print(
logger.info(
"number of total cases: %d, time cost per image: %f, number of bad cases: %d"
% (len(fn_labels_list), model_time_cost / len(fn_labels_list), bad_cnt)
)
......
......@@ -21,11 +21,16 @@ from __future__ import print_function
import sys
import os
import logging
import argparse
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from cnocr import CnOcr
from cnocr.utils import set_logger
logger = set_logger(log_level=logging.INFO)
def main():
......@@ -48,7 +53,7 @@ def main():
res = ocr.ocr_for_single_line(args.file)
else:
res = ocr.ocr(args.file)
print("Predicted Chars:", res)
logger.info("Predicted Chars: %s", res)
if __name__ == '__main__':
......
......@@ -27,7 +27,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from cnocr.__version__ import __version__
from cnocr.consts import EMB_MODEL_TYPES, SEQ_MODEL_TYPES
from cnocr.utils import data_dir
from cnocr.utils import data_dir, set_logger
from cnocr.hyperparams.cn_hyperparams import CnHyperparams
from cnocr.data_utils.data_iter import GrayImageIter
from cnocr.data_utils.aug import FgBgFlipAug
......@@ -36,6 +36,9 @@ from cnocr.fit.ctc_metrics import CtcMetrics
from cnocr.fit.fit import fit
logger = set_logger(log_level=logging.INFO)
def parse_args():
# Parse command line arguments
parser = argparse.ArgumentParser()
......@@ -114,7 +117,7 @@ def train_cnocr(args):
logging.basicConfig(level=logging.DEBUG, format=head)
args.model_name = args.emb_model_type + '-' + args.seq_model_type
out_dir = os.path.join(args.out_model_dir, args.model_name)
print('save models to dir: %s' % out_dir, flush=True)
logger.info('save models to dir: %s' % out_dir)
if not os.path.exists(out_dir):
os.makedirs(out_dir)
args.prefix = os.path.join(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册