未验证 提交 d64a4c3f 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #756 from LDOUBLEV/fixocr

advance db post_process
...@@ -108,9 +108,11 @@ void DBDetector::Run(cv::Mat &img, ...@@ -108,9 +108,11 @@ void DBDetector::Run(cv::Mat &img,
const double maxvalue = 255; const double maxvalue = 255;
cv::Mat bit_map; cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY); cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
cv::Mat dilation_map;
cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2,2));
cv::dilate(bit_map, dilation_map, dila_ele);
boxes = post_processor_.BoxesFromBitmap( boxes = post_processor_.BoxesFromBitmap(
pred_map, bit_map, this->det_db_box_thresh_, this->det_db_unclip_ratio_); pred_map, dilation_map, this->det_db_box_thresh_, this->det_db_unclip_ratio_);
boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg); boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);
......
...@@ -294,7 +294,7 @@ PostProcessor::FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes, ...@@ -294,7 +294,7 @@ PostProcessor::FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
pow(boxes[n][0][1] - boxes[n][1][1], 2))); pow(boxes[n][0][1] - boxes[n][1][1], 2)));
rect_height = int(sqrt(pow(boxes[n][0][0] - boxes[n][3][0], 2) + rect_height = int(sqrt(pow(boxes[n][0][0] - boxes[n][3][0], 2) +
pow(boxes[n][0][1] - boxes[n][3][1], 2))); pow(boxes[n][0][1] - boxes[n][3][1], 2)));
if (rect_width <= 10 || rect_height <= 10) if (rect_width <= 4 || rect_height <= 4)
continue; continue;
root_points.push_back(boxes[n]); root_points.push_back(boxes[n]);
} }
......
...@@ -10,7 +10,7 @@ use_zero_copy_run 1 ...@@ -10,7 +10,7 @@ use_zero_copy_run 1
max_side_len 960 max_side_len 960
det_db_thresh 0.3 det_db_thresh 0.3
det_db_box_thresh 0.5 det_db_box_thresh 0.5
det_db_unclip_ratio 2.0 det_db_unclip_ratio 1.6
det_model_dir ./inference/det_db det_model_dir ./inference/det_db
# cls config # cls config
......
max_side_len 960 max_side_len 960
det_db_thresh 0.3 det_db_thresh 0.3
det_db_box_thresh 0.5 det_db_box_thresh 0.5
det_db_unclip_ratio 2.0 det_db_unclip_ratio 1.6
\ No newline at end of file \ No newline at end of file
...@@ -293,7 +293,7 @@ FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes, float ratio_h, ...@@ -293,7 +293,7 @@ FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes, float ratio_h,
rect_height = rect_height =
static_cast<int>(sqrt(pow(boxes[n][0][0] - boxes[n][3][0], 2) + static_cast<int>(sqrt(pow(boxes[n][0][0] - boxes[n][3][0], 2) +
pow(boxes[n][0][1] - boxes[n][3][1], 2))); pow(boxes[n][0][1] - boxes[n][3][1], 2)));
if (rect_width <= 10 || rect_height <= 10) if (rect_width <= 4 || rect_height <= 4)
continue; continue;
root_points.push_back(boxes[n]); root_points.push_back(boxes[n]);
} }
......
...@@ -289,8 +289,10 @@ RunDetModel(std::shared_ptr<PaddlePredictor> predictor, cv::Mat img, ...@@ -289,8 +289,10 @@ RunDetModel(std::shared_ptr<PaddlePredictor> predictor, cv::Mat img,
const double maxvalue = 255; const double maxvalue = 255;
cv::Mat bit_map; cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY); cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
cv::Mat dilation_map;
auto boxes = BoxesFromBitmap(pred_map, bit_map, Config); cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2,2));
cv::dilate(bit_map, dilation_map, dila_ele);
auto boxes = BoxesFromBitmap(pred_map, dilation_map, Config);
std::vector<std::vector<std::vector<int>>> filter_boxes = std::vector<std::vector<std::vector<int>>> filter_boxes =
FilterTagDetRes(boxes, ratio_hw[0], ratio_hw[1], srcimg); FilterTagDetRes(boxes, ratio_hw[0], ratio_hw[1], srcimg);
......
...@@ -37,6 +37,7 @@ class DBPostProcess(object): ...@@ -37,6 +37,7 @@ class DBPostProcess(object):
self.max_candidates = params['max_candidates'] self.max_candidates = params['max_candidates']
self.unclip_ratio = params['unclip_ratio'] self.unclip_ratio = params['unclip_ratio']
self.min_size = 3 self.min_size = 3
self.dilation_kernel = np.array([[1, 1], [1, 1]])
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
''' '''
...@@ -140,8 +141,9 @@ class DBPostProcess(object): ...@@ -140,8 +141,9 @@ class DBPostProcess(object):
boxes_batch = [] boxes_batch = []
for batch_index in range(pred.shape[0]): for batch_index in range(pred.shape[0]):
height, width = pred.shape[-2:] height, width = pred.shape[-2:]
tmp_boxes, tmp_scores = self.boxes_from_bitmap(
pred[batch_index], segmentation[batch_index], width, height) mask = cv2.dilate(np.array(segmentation[batch_index]).astype(np.uint8), self.dilation_kernel)
tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index], mask, width, height)
boxes = [] boxes = []
for k in range(len(tmp_boxes)): for k in range(len(tmp_boxes)):
......
...@@ -47,7 +47,7 @@ def parse_args(): ...@@ -47,7 +47,7 @@ def parse_args():
# DB parmas # DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3) parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5) parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0) parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
# EAST parmas # EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
...@@ -64,7 +64,7 @@ def parse_args(): ...@@ -64,7 +64,7 @@ def parse_args():
parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch') parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=30) parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25) parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument( parser.add_argument(
"--rec_char_dict_path", "--rec_char_dict_path",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册