提交 e3f171a3 编写于 作者: L LDOUBLEV

fix det inference bug and optimize save path

上级 ab8fd7d9
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
epoch_num: 1200 epoch_num: 1200
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 2 print_batch_step: 2
save_model_dir: output save_model_dir: ./output/det_db/
save_epoch_step: 200 save_epoch_step: 200
eval_batch_step: 5000 eval_batch_step: 5000
train_batch_size_per_card: 16 train_batch_size_per_card: 16
...@@ -13,7 +13,7 @@ Global: ...@@ -13,7 +13,7 @@ Global:
reader_yml: ./configs/det/det_db_icdar15_reader.yml reader_yml: ./configs/det/det_db_icdar15_reader.yml
pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/ pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/
checkpoints: checkpoints:
save_res_path: ./output/predicts_db.txt save_res_path: ./output/det_db/predicts_db.txt
save_inference_dir: save_inference_dir:
Architecture: Architecture:
......
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
epoch_num: 1200 epoch_num: 1200
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 2 print_batch_step: 2
save_model_dir: output save_model_dir: ./output/det_db/
save_epoch_step: 200 save_epoch_step: 200
eval_batch_step: 5000 eval_batch_step: 5000
train_batch_size_per_card: 8 train_batch_size_per_card: 8
...@@ -12,7 +12,9 @@ Global: ...@@ -12,7 +12,9 @@ Global:
image_shape: [3, 640, 640] image_shape: [3, 640, 640]
reader_yml: ./configs/det/det_db_icdar15_reader.yml reader_yml: ./configs/det/det_db_icdar15_reader.yml
pretrain_weights: ./pretrain_models/ResNet50_vd_pretrained/ pretrain_weights: ./pretrain_models/ResNet50_vd_pretrained/
save_res_path: ./output/predicts_db.txt save_res_path: ./output/det_db/predicts_db.txt
checkpoints:
save_inference_dir:
Architecture: Architecture:
function: ppocr.modeling.architectures.det_model,DetModel function: ppocr.modeling.architectures.det_model,DetModel
......
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
epoch_num: 100000 epoch_num: 100000
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 5 print_batch_step: 5
save_model_dir: output save_model_dir: ./output/det_east/
save_epoch_step: 200 save_epoch_step: 200
eval_batch_step: 5000 eval_batch_step: 5000
train_batch_size_per_card: 16 train_batch_size_per_card: 16
...@@ -12,7 +12,9 @@ Global: ...@@ -12,7 +12,9 @@ Global:
image_shape: [3, 512, 512] image_shape: [3, 512, 512]
reader_yml: ./configs/det/det_east_icdar15_reader.yml reader_yml: ./configs/det/det_east_icdar15_reader.yml
pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/ pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/
save_res_path: ./output/predicts_east.txt checkpoints:
save_res_path: ./output/det_east/predicts_east.txt
save_inference_dir:
Architecture: Architecture:
function: ppocr.modeling.architectures.det_model,DetModel function: ppocr.modeling.architectures.det_model,DetModel
......
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
epoch_num: 100000 epoch_num: 100000
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 5 print_batch_step: 5
save_model_dir: output save_model_dir: ./output/det_east/
save_epoch_step: 200 save_epoch_step: 200
eval_batch_step: 5000 eval_batch_step: 5000
train_batch_size_per_card: 8 train_batch_size_per_card: 8
...@@ -12,7 +12,9 @@ Global: ...@@ -12,7 +12,9 @@ Global:
image_shape: [3, 512, 512] image_shape: [3, 512, 512]
reader_yml: ./configs/det/det_east_icdar15_reader.yml reader_yml: ./configs/det/det_east_icdar15_reader.yml
pretrain_weights: ./pretrain_models/ResNet50_vd_pretrained/ pretrain_weights: ./pretrain_models/ResNet50_vd_pretrained/
save_res_path: ./output/predicts_east.txt save_res_path: ./output/det_east/predicts_east.txt
checkpoints:
save_inference_dir:
Architecture: Architecture:
function: ppocr.modeling.architectures.det_model,DetModel function: ppocr.modeling.architectures.det_model,DetModel
......
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
epoch_num: 72 epoch_num: 72
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
save_model_dir: output save_model_dir: output/rec
save_epoch_step: 3 save_epoch_step: 3
eval_batch_step: 2000 eval_batch_step: 2000
train_batch_size_per_card: 256 train_batch_size_per_card: 256
...@@ -15,6 +15,8 @@ Global: ...@@ -15,6 +15,8 @@ Global:
loss_type: ctc loss_type: ctc
reader_yml: ./configs/rec/rec_benchmark_reader.yml reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights: pretrain_weights:
checkpoints:
save_inference_dir:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
epoch_num: 72 epoch_num: 72
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
save_model_dir: output save_model_dir: output/rec
save_epoch_step: 3 save_epoch_step: 3
eval_batch_step: 2000 eval_batch_step: 2000
train_batch_size_per_card: 256 train_batch_size_per_card: 256
...@@ -15,6 +15,8 @@ Global: ...@@ -15,6 +15,8 @@ Global:
loss_type: ctc loss_type: ctc
reader_yml: ./configs/rec/rec_benchmark_reader.yml reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights: pretrain_weights:
checkpoints:
save_inference_dir:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
epoch_num: 72 epoch_num: 72
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
save_model_dir: output save_model_dir: output/rec
save_epoch_step: 3 save_epoch_step: 3
eval_batch_step: 2000 eval_batch_step: 2000
train_batch_size_per_card: 256 train_batch_size_per_card: 256
...@@ -15,6 +15,8 @@ Global: ...@@ -15,6 +15,8 @@ Global:
loss_type: attention loss_type: attention
reader_yml: ./configs/rec/rec_benchmark_reader.yml reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights: pretrain_weights:
checkpoints:
save_inference_dir:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
epoch_num: 72 epoch_num: 72
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
save_model_dir: output save_model_dir: output/rec
save_epoch_step: 3 save_epoch_step: 3
eval_batch_step: 2000 eval_batch_step: 2000
train_batch_size_per_card: 256 train_batch_size_per_card: 256
...@@ -15,6 +15,9 @@ Global: ...@@ -15,6 +15,9 @@ Global:
loss_type: ctc loss_type: ctc
reader_yml: ./configs/rec/rec_benchmark_reader.yml reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights: pretrain_weights:
checkpoints:
save_inference_dir:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
epoch_num: 72 epoch_num: 72
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
save_model_dir: output save_model_dir: output/rec
save_epoch_step: 3 save_epoch_step: 3
eval_batch_step: 2000 eval_batch_step: 2000
train_batch_size_per_card: 256 train_batch_size_per_card: 256
...@@ -15,6 +15,8 @@ Global: ...@@ -15,6 +15,8 @@ Global:
loss_type: ctc loss_type: ctc
reader_yml: ./configs/rec/rec_benchmark_reader.yml reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights: pretrain_weights:
checkpoints:
save_inference_dir:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
epoch_num: 72 epoch_num: 72
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
save_model_dir: output save_model_dir: output/rec
save_epoch_step: 3 save_epoch_step: 3
eval_batch_step: 2000 eval_batch_step: 2000
train_batch_size_per_card: 256 train_batch_size_per_card: 256
...@@ -15,6 +15,8 @@ Global: ...@@ -15,6 +15,8 @@ Global:
loss_type: ctc loss_type: ctc
reader_yml: ./configs/rec/rec_benchmark_reader.yml reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights: pretrain_weights:
checkpoints:
save_inference_dir:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
epoch_num: 72 epoch_num: 72
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
save_model_dir: output save_model_dir: output/rec
save_epoch_step: 3 save_epoch_step: 3
eval_batch_step: 2000 eval_batch_step: 2000
train_batch_size_per_card: 256 train_batch_size_per_card: 256
...@@ -15,6 +15,8 @@ Global: ...@@ -15,6 +15,8 @@ Global:
loss_type: attention loss_type: attention
reader_yml: ./configs/rec/rec_benchmark_reader.yml reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights: pretrain_weights:
checkpoints:
save_inference_dir:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
epoch_num: 72 epoch_num: 72
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
save_model_dir: output save_model_dir: output/rec
save_epoch_step: 3 save_epoch_step: 3
eval_batch_step: 2000 eval_batch_step: 2000
train_batch_size_per_card: 256 train_batch_size_per_card: 256
...@@ -15,6 +15,8 @@ Global: ...@@ -15,6 +15,8 @@ Global:
loss_type: ctc loss_type: ctc
reader_yml: ./configs/rec/rec_benchmark_reader.yml reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights: pretrain_weights:
checkpoints:
save_inference_dir:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -196,7 +196,7 @@ class DBHead(object): ...@@ -196,7 +196,7 @@ class DBHead(object):
fuse = fluid.layers.concat(input=[p5, p4, p3, p2], axis=1) fuse = fluid.layers.concat(input=[p5, p4, p3, p2], axis=1)
shrink_maps = self.binarize(fuse) shrink_maps = self.binarize(fuse)
if mode != "train": if mode != "train":
return shrink_maps return {"maps", shrink_maps}
threshold_maps = self.thresh(fuse) threshold_maps = self.thresh(fuse)
binary_maps = self.step_function(shrink_maps, threshold_maps) binary_maps = self.step_function(shrink_maps, threshold_maps)
y = fluid.layers.concat( y = fluid.layers.concat(
......
...@@ -128,6 +128,7 @@ class DBPostProcess(object): ...@@ -128,6 +128,7 @@ class DBPostProcess(object):
def __call__(self, outs_dict, ratio_list): def __call__(self, outs_dict, ratio_list):
pred = outs_dict['maps'] pred = outs_dict['maps']
pred = pred[:, 0, :, :] pred = pred[:, 0, :, :]
segmentation = pred > self.thresh segmentation = pred > self.thresh
......
...@@ -24,6 +24,7 @@ import copy ...@@ -24,6 +24,7 @@ import copy
import numpy as np import numpy as np
import math import math
import time import time
import sys
class TextDetector(object): class TextDetector(object):
...@@ -52,10 +53,10 @@ class TextDetector(object): ...@@ -52,10 +53,10 @@ class TextDetector(object):
utility.create_predictor(args, mode="det") utility.create_predictor(args, mode="det")
def order_points_clockwise(self, pts): def order_points_clockwise(self, pts):
####### """
## https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
########
# sort the points based on their x-coordinates # sort the points based on their x-coordinates
"""
xSorted = pts[np.argsort(pts[:, 0]), :] xSorted = pts[np.argsort(pts[:, 0]), :]
# grab the left-most and right-most points from the sorted # grab the left-most and right-most points from the sorted
...@@ -141,7 +142,7 @@ class TextDetector(object): ...@@ -141,7 +142,7 @@ class TextDetector(object):
outs_dict['f_score'] = outputs[0] outs_dict['f_score'] = outputs[0]
outs_dict['f_geo'] = outputs[1] outs_dict['f_geo'] = outputs[1]
else: else:
outs_dict['maps'] = [outputs[0]] outs_dict['maps'] = outputs[0]
dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list]) dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list])
dt_boxes = dt_boxes_list[0] dt_boxes = dt_boxes_list[0]
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
......
...@@ -219,6 +219,8 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict): ...@@ -219,6 +219,8 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
eval_batch_step = config['Global']['eval_batch_step'] eval_batch_step = config['Global']['eval_batch_step']
save_epoch_step = config['Global']['save_epoch_step'] save_epoch_step = config['Global']['save_epoch_step']
save_model_dir = config['Global']['save_model_dir'] save_model_dir = config['Global']['save_model_dir']
if not os.path.exists(save_model_dir):
os.makedirs(save_model_dir)
train_stats = TrainingStats(log_smooth_window, train_stats = TrainingStats(log_smooth_window,
train_info_dict['fetch_name_list']) train_info_dict['fetch_name_list'])
best_eval_hmean = -1 best_eval_hmean = -1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册