未验证 提交 d2d2abbc 编写于 作者: D dyning 提交者: GitHub

Merge pull request #356 from littletomatodonkey/fix_eval_comment

Fix eval comment and backbone stride
......@@ -219,7 +219,7 @@ PostProcessor::BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
std::vector<std::vector<std::vector<int>>> boxes;
for (int _i = 0; _i < num_contours; _i++) {
if (contours[_i].size() <= 0) {
if (contours[_i].size() <= 2) {
continue;
}
float ssid;
......
......@@ -82,7 +82,7 @@ def main():
'fetch_name_list':eval_fetch_name_list,\
'fetch_varname_list':eval_fetch_varname_list}
metrics = eval_det_run(exe, config, eval_info_dict, "eval")
print("Eval result", metrics)
logger.info("Eval result: {}".format(metrics))
else:
reader_type = config['Global']['reader_yml']
if "benchmark" not in reader_type:
......@@ -92,7 +92,7 @@ def main():
'fetch_name_list': eval_fetch_name_list, \
'fetch_varname_list': eval_fetch_varname_list}
metrics = eval_rec_run(exe, config, eval_info_dict, "eval")
print("Eval result:", metrics)
logger.info("Eval result: {}".format(metrics))
else:
eval_info_dict = {'program':eval_program,\
'fetch_name_list':eval_fetch_name_list,\
......
......@@ -75,6 +75,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
char_ops, preds, preds_lod, labels, labels_lod, is_remove_duplicate)
total_acc_num += acc_num
total_sample_num += sample_num
logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
total_batch_num += 1
avg_acc = total_acc_num * 1.0 / total_sample_num
metrics = {'avg_acc': avg_acc, "total_acc_num": total_acc_num, \
......
......@@ -70,7 +70,7 @@ def draw_det_res(dt_boxes, config, img, img_name):
def main():
config = program.load_config(FLAGS.config)
program.merge_config(FLAGS.opt)
print(config)
logger.info(config)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
......
......@@ -84,7 +84,7 @@ def main():
if len(infer_list) == 0:
logger.info("Can not find img in infer_img dir.")
for i in range(max_img_num):
print("infer_img:%s" % infer_list[i])
logger.info("infer_img:%s" % infer_list[i])
img = next(blobs)
predict = exe.run(program=eval_prog,
feed={"image": img},
......@@ -115,9 +115,9 @@ def main():
preds = preds.reshape(-1)
preds_text = char_ops.decode(preds)
print("\t index:", preds)
print("\t word :", preds_text)
print("\t score :", score)
logger.info("\t index: {}".format(preds))
logger.info("\t word : {}".format(preds_text))
logger.info("\t score: {}".format(score))
# save for inference model
target_var = []
......
......@@ -41,19 +41,19 @@ def draw_server_result(image_file, res):
if len(res) == 0:
return np.array(image)
keys = res[0].keys()
if 'text_region' not in keys: # for ocr_rec, draw function is invalid
print("draw function is invalid for ocr_rec!")
if 'text_region' not in keys: # for ocr_rec, draw function is invalid
logger.info("draw function is invalid for ocr_rec!")
return None
elif 'text' not in keys: # for ocr_det
print("draw text boxes only!")
elif 'text' not in keys: # for ocr_det
logger.info("draw text boxes only!")
boxes = []
for dno in range(len(res)):
boxes.append(res[dno]['text_region'])
boxes = np.array(boxes)
draw_img = draw_boxes(image, boxes)
return draw_img
else: # for ocr_system
print("draw boxes and texts!")
else: # for ocr_system
logger.info("draw boxes and texts!")
boxes = []
texts = []
scores = []
......@@ -63,7 +63,8 @@ def draw_server_result(image_file, res):
scores.append(res[dno]['confidence'])
boxes = np.array(boxes)
scores = np.array(scores)
draw_img = draw_ocr(image, boxes, texts, scores, draw_txt=True, drop_score=0.5)
draw_img = draw_ocr(
image, boxes, texts, scores, draw_txt=True, drop_score=0.5)
return draw_img
......@@ -81,13 +82,13 @@ def main(url, image_path):
# 发送HTTP请求
starttime = time.time()
data = {'images':[cv2_to_base64(img)]}
data = {'images': [cv2_to_base64(img)]}
r = requests.post(url=url, headers=headers, data=json.dumps(data))
elapse = time.time() - starttime
total_time += elapse
print("Predict time of %s: %.3fs" % (image_file, elapse))
logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
res = r.json()["results"][0]
print(res)
logger.info(res)
if is_visualize:
draw_img = draw_server_result(image_file, res)
......@@ -98,16 +99,17 @@ def main(url, image_path):
cv2.imwrite(
os.path.join(draw_img_save, os.path.basename(image_file)),
draw_img[:, :, ::-1])
print("The visualized image saved in {}".format(
logger.info("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file))))
cnt += 1
if cnt % 100 == 0:
print(cnt, "processed")
print("avg time cost: ", float(total_time)/cnt)
logger.info("{} processed".format(cnt))
logger.info("avg time cost: {}".format(float(total_time) / cnt))
if __name__ == '__main__':
if __name__ == '__main__':
if len(sys.argv) != 3:
print("Usage: %s server_url image_path" % sys.argv[0])
logger.info("Usage: %s server_url image_path" % sys.argv[0])
else:
server_url = sys.argv[1]
image_path = sys.argv[2]
......
......@@ -118,7 +118,7 @@ def main():
def test_reader():
config = program.load_config(FLAGS.config)
program.merge_config(FLAGS.opt)
print(config)
logger.info(config)
train_reader = reader_main(config=config, mode="train")
import time
starttime = time.time()
......@@ -129,7 +129,7 @@ def test_reader():
if count % 1 == 0:
batch_time = time.time() - starttime
starttime = time.time()
print("reader:", count, len(data), batch_time)
logger.info("reader:", count, len(data), batch_time)
except Exception as e:
logger.info(e)
logger.info("finish reader: {}, Success!".format(count))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册