未验证 提交 d2d2abbc 编写于 作者: D dyning 提交者: GitHub

Merge pull request #356 from littletomatodonkey/fix_eval_comment

Fix eval comment and backbone stride
...@@ -219,7 +219,7 @@ PostProcessor::BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap, ...@@ -219,7 +219,7 @@ PostProcessor::BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
std::vector<std::vector<std::vector<int>>> boxes; std::vector<std::vector<std::vector<int>>> boxes;
for (int _i = 0; _i < num_contours; _i++) { for (int _i = 0; _i < num_contours; _i++) {
if (contours[_i].size() <= 0) { if (contours[_i].size() <= 2) {
continue; continue;
} }
float ssid; float ssid;
......
...@@ -82,7 +82,7 @@ def main(): ...@@ -82,7 +82,7 @@ def main():
'fetch_name_list':eval_fetch_name_list,\ 'fetch_name_list':eval_fetch_name_list,\
'fetch_varname_list':eval_fetch_varname_list} 'fetch_varname_list':eval_fetch_varname_list}
metrics = eval_det_run(exe, config, eval_info_dict, "eval") metrics = eval_det_run(exe, config, eval_info_dict, "eval")
print("Eval result", metrics) logger.info("Eval result: {}".format(metrics))
else: else:
reader_type = config['Global']['reader_yml'] reader_type = config['Global']['reader_yml']
if "benchmark" not in reader_type: if "benchmark" not in reader_type:
...@@ -92,7 +92,7 @@ def main(): ...@@ -92,7 +92,7 @@ def main():
'fetch_name_list': eval_fetch_name_list, \ 'fetch_name_list': eval_fetch_name_list, \
'fetch_varname_list': eval_fetch_varname_list} 'fetch_varname_list': eval_fetch_varname_list}
metrics = eval_rec_run(exe, config, eval_info_dict, "eval") metrics = eval_rec_run(exe, config, eval_info_dict, "eval")
print("Eval result:", metrics) logger.info("Eval result: {}".format(metrics))
else: else:
eval_info_dict = {'program':eval_program,\ eval_info_dict = {'program':eval_program,\
'fetch_name_list':eval_fetch_name_list,\ 'fetch_name_list':eval_fetch_name_list,\
......
...@@ -75,6 +75,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode): ...@@ -75,6 +75,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
char_ops, preds, preds_lod, labels, labels_lod, is_remove_duplicate) char_ops, preds, preds_lod, labels, labels_lod, is_remove_duplicate)
total_acc_num += acc_num total_acc_num += acc_num
total_sample_num += sample_num total_sample_num += sample_num
logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
total_batch_num += 1 total_batch_num += 1
avg_acc = total_acc_num * 1.0 / total_sample_num avg_acc = total_acc_num * 1.0 / total_sample_num
metrics = {'avg_acc': avg_acc, "total_acc_num": total_acc_num, \ metrics = {'avg_acc': avg_acc, "total_acc_num": total_acc_num, \
......
...@@ -70,7 +70,7 @@ def draw_det_res(dt_boxes, config, img, img_name): ...@@ -70,7 +70,7 @@ def draw_det_res(dt_boxes, config, img, img_name):
def main(): def main():
config = program.load_config(FLAGS.config) config = program.load_config(FLAGS.config)
program.merge_config(FLAGS.opt) program.merge_config(FLAGS.opt)
print(config) logger.info(config)
# check if set use_gpu=True in paddlepaddle cpu version # check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu'] use_gpu = config['Global']['use_gpu']
......
...@@ -84,7 +84,7 @@ def main(): ...@@ -84,7 +84,7 @@ def main():
if len(infer_list) == 0: if len(infer_list) == 0:
logger.info("Can not find img in infer_img dir.") logger.info("Can not find img in infer_img dir.")
for i in range(max_img_num): for i in range(max_img_num):
print("infer_img:%s" % infer_list[i]) logger.info("infer_img:%s" % infer_list[i])
img = next(blobs) img = next(blobs)
predict = exe.run(program=eval_prog, predict = exe.run(program=eval_prog,
feed={"image": img}, feed={"image": img},
...@@ -115,9 +115,9 @@ def main(): ...@@ -115,9 +115,9 @@ def main():
preds = preds.reshape(-1) preds = preds.reshape(-1)
preds_text = char_ops.decode(preds) preds_text = char_ops.decode(preds)
print("\t index:", preds) logger.info("\t index: {}".format(preds))
print("\t word :", preds_text) logger.info("\t word : {}".format(preds_text))
print("\t score :", score) logger.info("\t score: {}".format(score))
# save for inference model # save for inference model
target_var = [] target_var = []
......
...@@ -41,19 +41,19 @@ def draw_server_result(image_file, res): ...@@ -41,19 +41,19 @@ def draw_server_result(image_file, res):
if len(res) == 0: if len(res) == 0:
return np.array(image) return np.array(image)
keys = res[0].keys() keys = res[0].keys()
if 'text_region' not in keys: # for ocr_rec, draw function is invalid if 'text_region' not in keys: # for ocr_rec, draw function is invalid
print("draw function is invalid for ocr_rec!") logger.info("draw function is invalid for ocr_rec!")
return None return None
elif 'text' not in keys: # for ocr_det elif 'text' not in keys: # for ocr_det
print("draw text boxes only!") logger.info("draw text boxes only!")
boxes = [] boxes = []
for dno in range(len(res)): for dno in range(len(res)):
boxes.append(res[dno]['text_region']) boxes.append(res[dno]['text_region'])
boxes = np.array(boxes) boxes = np.array(boxes)
draw_img = draw_boxes(image, boxes) draw_img = draw_boxes(image, boxes)
return draw_img return draw_img
else: # for ocr_system else: # for ocr_system
print("draw boxes and texts!") logger.info("draw boxes and texts!")
boxes = [] boxes = []
texts = [] texts = []
scores = [] scores = []
...@@ -63,7 +63,8 @@ def draw_server_result(image_file, res): ...@@ -63,7 +63,8 @@ def draw_server_result(image_file, res):
scores.append(res[dno]['confidence']) scores.append(res[dno]['confidence'])
boxes = np.array(boxes) boxes = np.array(boxes)
scores = np.array(scores) scores = np.array(scores)
draw_img = draw_ocr(image, boxes, texts, scores, draw_txt=True, drop_score=0.5) draw_img = draw_ocr(
image, boxes, texts, scores, draw_txt=True, drop_score=0.5)
return draw_img return draw_img
...@@ -81,13 +82,13 @@ def main(url, image_path): ...@@ -81,13 +82,13 @@ def main(url, image_path):
# 发送HTTP请求 # 发送HTTP请求
starttime = time.time() starttime = time.time()
data = {'images':[cv2_to_base64(img)]} data = {'images': [cv2_to_base64(img)]}
r = requests.post(url=url, headers=headers, data=json.dumps(data)) r = requests.post(url=url, headers=headers, data=json.dumps(data))
elapse = time.time() - starttime elapse = time.time() - starttime
total_time += elapse total_time += elapse
print("Predict time of %s: %.3fs" % (image_file, elapse)) logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
res = r.json()["results"][0] res = r.json()["results"][0]
print(res) logger.info(res)
if is_visualize: if is_visualize:
draw_img = draw_server_result(image_file, res) draw_img = draw_server_result(image_file, res)
...@@ -98,16 +99,17 @@ def main(url, image_path): ...@@ -98,16 +99,17 @@ def main(url, image_path):
cv2.imwrite( cv2.imwrite(
os.path.join(draw_img_save, os.path.basename(image_file)), os.path.join(draw_img_save, os.path.basename(image_file)),
draw_img[:, :, ::-1]) draw_img[:, :, ::-1])
print("The visualized image saved in {}".format( logger.info("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file)))) os.path.join(draw_img_save, os.path.basename(image_file))))
cnt += 1 cnt += 1
if cnt % 100 == 0: if cnt % 100 == 0:
print(cnt, "processed") logger.info("{} processed".format(cnt))
print("avg time cost: ", float(total_time)/cnt) logger.info("avg time cost: {}".format(float(total_time) / cnt))
if __name__ == '__main__':
if __name__ == '__main__':
if len(sys.argv) != 3: if len(sys.argv) != 3:
print("Usage: %s server_url image_path" % sys.argv[0]) logger.info("Usage: %s server_url image_path" % sys.argv[0])
else: else:
server_url = sys.argv[1] server_url = sys.argv[1]
image_path = sys.argv[2] image_path = sys.argv[2]
......
...@@ -118,7 +118,7 @@ def main(): ...@@ -118,7 +118,7 @@ def main():
def test_reader(): def test_reader():
config = program.load_config(FLAGS.config) config = program.load_config(FLAGS.config)
program.merge_config(FLAGS.opt) program.merge_config(FLAGS.opt)
print(config) logger.info(config)
train_reader = reader_main(config=config, mode="train") train_reader = reader_main(config=config, mode="train")
import time import time
starttime = time.time() starttime = time.time()
...@@ -129,7 +129,7 @@ def test_reader(): ...@@ -129,7 +129,7 @@ def test_reader():
if count % 1 == 0: if count % 1 == 0:
batch_time = time.time() - starttime batch_time = time.time() - starttime
starttime = time.time() starttime = time.time()
print("reader:", count, len(data), batch_time) logger.info("reader:", count, len(data), batch_time)
except Exception as e: except Exception as e:
logger.info(e) logger.info(e)
logger.info("finish reader: {}, Success!".format(count)) logger.info("finish reader: {}, Success!".format(count))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册