未验证 提交 01227c98 编写于 作者: D dyning 提交者: GitHub

Merge pull request #328 from LDOUBLEV/fixocr

fix bug det train
...@@ -96,6 +96,7 @@ class EvalTestReader(object): ...@@ -96,6 +96,7 @@ class EvalTestReader(object):
img = cv2.imread(img_path) img = cv2.imread(img_path)
if img is None: if img is None:
logger.info("{} does not exist!".format(img_path)) logger.info("{} does not exist!".format(img_path))
continue
elif len(list(img.shape)) == 2 or img.shape[2] == 1: elif len(list(img.shape)) == 2 or img.shape[2] == 1:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
outs = process_function(img) outs = process_function(img)
......
...@@ -256,15 +256,15 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict): ...@@ -256,15 +256,15 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
t2 = time.time() t2 = time.time()
train_batch_elapse = t2 - t1 train_batch_elapse = t2 - t1
train_stats.update(stats) train_stats.update(stats)
if train_batch_id > start_eval_step and (train_batch_id -start_eval_step) \ if train_batch_id > 0 and train_batch_id \
% print_batch_step == 0: % print_batch_step == 0:
logs = train_stats.log() logs = train_stats.log()
strs = 'epoch: {}, iter: {}, {}, time: {:.3f}'.format( strs = 'epoch: {}, iter: {}, {}, time: {:.3f}'.format(
epoch, train_batch_id, logs, train_batch_elapse) epoch, train_batch_id, logs, train_batch_elapse)
logger.info(strs) logger.info(strs)
if train_batch_id > 0 and\ if train_batch_id > start_eval_step and\
train_batch_id % eval_batch_step == 0: (train_batch_id - start_eval_step) % eval_batch_step == 0:
metrics = eval_det_run(exe, config, eval_info_dict, "eval") metrics = eval_det_run(exe, config, eval_info_dict, "eval")
hmean = metrics['hmean'] hmean = metrics['hmean']
if hmean >= best_eval_hmean: if hmean >= best_eval_hmean:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册