提交 8d9cfade 编写于 作者: W wangjingyeye

add polygons

上级 c26e7aee
...@@ -83,14 +83,10 @@ class DetectionIoUEvaluator(object): ...@@ -83,14 +83,10 @@ class DetectionIoUEvaluator(object):
evaluationLog = "" evaluationLog = ""
# print(len(gt))
for n in range(len(gt)): for n in range(len(gt)):
points = gt[n]['points'] points = gt[n]['points']
# transcription = gt[n]['text']
dontCare = gt[n]['ignore'] dontCare = gt[n]['ignore']
# points = Polygon(points) if not Polygon(points).is_valid:
# points = points.buffer(0)
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue continue
gtPol = points gtPol = points
...@@ -105,9 +101,7 @@ class DetectionIoUEvaluator(object): ...@@ -105,9 +101,7 @@ class DetectionIoUEvaluator(object):
for n in range(len(pred)): for n in range(len(pred)):
points = pred[n]['points'] points = pred[n]['points']
# points = Polygon(points) if not Polygon(points).is_valid:
# points = points.buffer(0)
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue continue
detPol = points detPol = points
...@@ -191,8 +185,6 @@ class DetectionIoUEvaluator(object): ...@@ -191,8 +185,6 @@ class DetectionIoUEvaluator(object):
methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \ methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
methodRecall * methodPrecision / ( methodRecall * methodPrecision / (
methodRecall + methodPrecision) methodRecall + methodPrecision)
# print(methodRecall, methodPrecision, methodHmean)
# sys.exit(-1)
methodMetrics = { methodMetrics = {
'precision': methodPrecision, 'precision': methodPrecision,
'recall': methodRecall, 'recall': methodRecall,
......
...@@ -38,6 +38,7 @@ class DBPostProcess(object): ...@@ -38,6 +38,7 @@ class DBPostProcess(object):
unclip_ratio=2.0, unclip_ratio=2.0,
use_dilation=False, use_dilation=False,
score_mode="fast", score_mode="fast",
use_polygon=False,
**kwargs): **kwargs):
self.thresh = thresh self.thresh = thresh
self.box_thresh = box_thresh self.box_thresh = box_thresh
...@@ -45,6 +46,7 @@ class DBPostProcess(object): ...@@ -45,6 +46,7 @@ class DBPostProcess(object):
self.unclip_ratio = unclip_ratio self.unclip_ratio = unclip_ratio
self.min_size = 3 self.min_size = 3
self.score_mode = score_mode self.score_mode = score_mode
self.use_polygon = use_polygon
assert score_mode in [ assert score_mode in [
"slow", "fast" "slow", "fast"
], "Score mode must be in [slow, fast] but got: {}".format(score_mode) ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
...@@ -52,6 +54,56 @@ class DBPostProcess(object): ...@@ -52,6 +54,56 @@ class DBPostProcess(object):
self.dilation_kernel = None if not use_dilation else np.array( self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]]) [[1, 1], [1, 1]])
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
boxes = []
scores = []
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours[:self.max_candidates]:
epsilon = 0.002 * cv2.arcLength(contour, True)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape((-1, 2))
# print(points)
if points.shape[0] < 4:
continue
score = self.box_score_fast(pred, points.reshape(-1, 2))
if self.box_thresh > score:
continue
if points.shape[0] > 2:
box = self.unclip(points, self.unclip_ratio)
if len(box) > 1:
continue
else:
continue
box = box.reshape(-1, 2)
# print(box)
_, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.tolist())
scores.append(score)
# print(boxes)
return boxes, scores
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
''' '''
_bitmap: single map with shape (1, H, W), _bitmap: single map with shape (1, H, W),
...@@ -85,7 +137,7 @@ class DBPostProcess(object): ...@@ -85,7 +137,7 @@ class DBPostProcess(object):
if self.box_thresh > score: if self.box_thresh > score:
continue continue
box = self.unclip(points).reshape(-1, 1, 2) box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box) box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2: if sside < self.min_size + 2:
continue continue
...@@ -99,8 +151,7 @@ class DBPostProcess(object): ...@@ -99,8 +151,7 @@ class DBPostProcess(object):
scores.append(score) scores.append(score)
return np.array(boxes, dtype=np.int16), scores return np.array(boxes, dtype=np.int16), scores
def unclip(self, box): def unclip(self, box, unclip_ratio):
unclip_ratio = self.unclip_ratio
poly = Polygon(box) poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset() offset = pyclipper.PyclipperOffset()
...@@ -185,8 +236,12 @@ class DBPostProcess(object): ...@@ -185,8 +236,12 @@ class DBPostProcess(object):
self.dilation_kernel) self.dilation_kernel)
else: else:
mask = segmentation[batch_index] mask = segmentation[batch_index]
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, if self.use_polygon:
src_w, src_h) boxes, scores = self.polygons_from_bitmap(pred[batch_index],
mask, src_w, src_h)
else:
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
src_w, src_h)
boxes_batch.append({'points': boxes}) boxes_batch.append({'points': boxes})
return boxes_batch return boxes_batch
...@@ -202,6 +257,7 @@ class DistillationDBPostProcess(object): ...@@ -202,6 +257,7 @@ class DistillationDBPostProcess(object):
unclip_ratio=1.5, unclip_ratio=1.5,
use_dilation=False, use_dilation=False,
score_mode="fast", score_mode="fast",
use_polygon=False,
**kwargs): **kwargs):
self.model_name = model_name self.model_name = model_name
self.key = key self.key = key
...@@ -211,7 +267,8 @@ class DistillationDBPostProcess(object): ...@@ -211,7 +267,8 @@ class DistillationDBPostProcess(object):
max_candidates=max_candidates, max_candidates=max_candidates,
unclip_ratio=unclip_ratio, unclip_ratio=unclip_ratio,
use_dilation=use_dilation, use_dilation=use_dilation,
score_mode=score_mode) score_mode=score_mode,
use_polygon=use_polygon)
def __call__(self, predicts, shape_list): def __call__(self, predicts, shape_list):
results = {} results = {}
......
...@@ -44,7 +44,7 @@ def draw_det_res(dt_boxes, config, img, img_name, save_path): ...@@ -44,7 +44,7 @@ def draw_det_res(dt_boxes, config, img, img_name, save_path):
import cv2 import cv2
src_im = img src_im = img
for box in dt_boxes: for box in dt_boxes:
box = box.astype(np.int32).reshape((-1, 1, 2)) box = np.array(box).astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
if not os.path.exists(save_path): if not os.path.exists(save_path):
os.makedirs(save_path) os.makedirs(save_path)
...@@ -106,7 +106,7 @@ def main(): ...@@ -106,7 +106,7 @@ def main():
dt_boxes_list = [] dt_boxes_list = []
for box in boxes: for box in boxes:
tmp_json = {"transcription": ""} tmp_json = {"transcription": ""}
tmp_json['points'] = box.tolist() tmp_json['points'] = list(box)
dt_boxes_list.append(tmp_json) dt_boxes_list.append(tmp_json)
det_box_json[k] = dt_boxes_list det_box_json[k] = dt_boxes_list
save_det_path = os.path.dirname(config['Global'][ save_det_path = os.path.dirname(config['Global'][
...@@ -118,7 +118,7 @@ def main(): ...@@ -118,7 +118,7 @@ def main():
# write result # write result
for box in boxes: for box in boxes:
tmp_json = {"transcription": ""} tmp_json = {"transcription": ""}
tmp_json['points'] = box.tolist() tmp_json['points'] = list(box)
dt_boxes_json.append(tmp_json) dt_boxes_json.append(tmp_json)
save_det_path = os.path.dirname(config['Global'][ save_det_path = os.path.dirname(config['Global'][
'save_res_path']) + "/det_results/" 'save_res_path']) + "/det_results/"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册