未验证 提交 f38a22c0 编写于 作者: M MissPenguin 提交者: GitHub

Merge pull request #1449 from WenmuZhou/tree_doc

[Dygraph] change DBHead output to dict and update db config
...@@ -2,11 +2,11 @@ Global: ...@@ -2,11 +2,11 @@ Global:
use_gpu: true use_gpu: true
epoch_num: 1200 epoch_num: 1200
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 2 print_batch_step: 10
save_model_dir: ./output/db_mv3/ save_model_dir: ./output/db_mv3/
save_epoch_step: 1200 save_epoch_step: 1200
# evaluation is run every 5000 iterations after the 4000th iteration # evaluation is run every 2000 iterations
eval_batch_step: [4000, 5000] eval_batch_step: [0, 2000]
# if pretrained_model is saved in static mode, load_static_weights must set to True # if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights: True load_static_weights: True
cal_metric_during_train: False cal_metric_during_train: False
...@@ -39,7 +39,7 @@ Loss: ...@@ -39,7 +39,7 @@ Loss:
alpha: 5 alpha: 5
beta: 10 beta: 10
ohem_ratio: 3 ohem_ratio: 3
Optimizer: Optimizer:
name: Adam name: Adam
beta1: 0.9 beta1: 0.9
...@@ -100,7 +100,7 @@ Train: ...@@ -100,7 +100,7 @@ Train:
loader: loader:
shuffle: True shuffle: True
drop_last: False drop_last: False
batch_size_per_card: 4 batch_size_per_card: 16
num_workers: 8 num_workers: 8
Eval: Eval:
...@@ -128,4 +128,4 @@ Eval: ...@@ -128,4 +128,4 @@ Eval:
shuffle: False shuffle: False
drop_last: False drop_last: False
batch_size_per_card: 1 # must be 1 batch_size_per_card: 1 # must be 1
num_workers: 2 num_workers: 8
\ No newline at end of file \ No newline at end of file
...@@ -5,8 +5,8 @@ Global: ...@@ -5,8 +5,8 @@ Global:
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/det_r50_vd/ save_model_dir: ./output/det_r50_vd/
save_epoch_step: 1200 save_epoch_step: 1200
# evaluation is run every 5000 iterations after the 4000th iteration # evaluation is run every 2000 iterations
eval_batch_step: [5000,4000] eval_batch_step: [0,2000]
# if pretrained_model is saved in static mode, load_static_weights must set to True # if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights: True load_static_weights: True
cal_metric_during_train: False cal_metric_during_train: False
......
...@@ -47,11 +47,12 @@ class DBLoss(nn.Layer): ...@@ -47,11 +47,12 @@ class DBLoss(nn.Layer):
negative_ratio=ohem_ratio) negative_ratio=ohem_ratio)
def forward(self, predicts, labels): def forward(self, predicts, labels):
predict_maps = predicts['maps']
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[ label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[
1:] 1:]
shrink_maps = predicts[:, 0, :, :] shrink_maps = predict_maps[:, 0, :, :]
threshold_maps = predicts[:, 1, :, :] threshold_maps = predict_maps[:, 1, :, :]
binary_maps = predicts[:, 2, :, :] binary_maps = predict_maps[:, 2, :, :]
loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map, loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map,
label_shrink_mask) label_shrink_mask)
......
...@@ -120,9 +120,9 @@ class DBHead(nn.Layer): ...@@ -120,9 +120,9 @@ class DBHead(nn.Layer):
def forward(self, x): def forward(self, x):
shrink_maps = self.binarize(x) shrink_maps = self.binarize(x)
if not self.training: if not self.training:
return shrink_maps return {'maps': shrink_maps}
threshold_maps = self.thresh(x) threshold_maps = self.thresh(x)
binary_maps = self.step_function(shrink_maps, threshold_maps) binary_maps = self.step_function(shrink_maps, threshold_maps)
y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1) y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1)
return y return {'maps': y}
...@@ -40,7 +40,8 @@ class DBPostProcess(object): ...@@ -40,7 +40,8 @@ class DBPostProcess(object):
self.max_candidates = max_candidates self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio self.unclip_ratio = unclip_ratio
self.min_size = 3 self.min_size = 3
self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]]) self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]])
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
''' '''
...@@ -132,7 +133,8 @@ class DBPostProcess(object): ...@@ -132,7 +133,8 @@ class DBPostProcess(object):
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, pred, shape_list): def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps']
if isinstance(pred, paddle.Tensor): if isinstance(pred, paddle.Tensor):
pred = pred.numpy() pred = pred.numpy()
pred = pred[:, 0, :, :] pred = pred[:, 0, :, :]
......
...@@ -65,12 +65,12 @@ class TextDetector(object): ...@@ -65,12 +65,12 @@ class TextDetector(object):
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = True postprocess_params["use_dilation"] = True
elif self.det_algorithm == "EAST": elif self.det_algorithm == "EAST":
postprocess_params['name'] = 'EASTPostProcess' postprocess_params['name'] = 'EASTPostProcess'
postprocess_params["score_thresh"] = args.det_east_score_thresh postprocess_params["score_thresh"] = args.det_east_score_thresh
postprocess_params["cover_thresh"] = args.det_east_cover_thresh postprocess_params["cover_thresh"] = args.det_east_cover_thresh
postprocess_params["nms_thresh"] = args.det_east_nms_thresh postprocess_params["nms_thresh"] = args.det_east_nms_thresh
elif self.det_algorithm == "SAST": elif self.det_algorithm == "SAST":
postprocess_params['name'] = 'SASTPostProcess' postprocess_params['name'] = 'SASTPostProcess'
postprocess_params["score_thresh"] = args.det_sast_score_thresh postprocess_params["score_thresh"] = args.det_sast_score_thresh
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
self.det_sast_polygon = args.det_sast_polygon self.det_sast_polygon = args.det_sast_polygon
...@@ -177,8 +177,10 @@ class TextDetector(object): ...@@ -177,8 +177,10 @@ class TextDetector(object):
preds['f_score'] = outputs[1] preds['f_score'] = outputs[1]
preds['f_tco'] = outputs[2] preds['f_tco'] = outputs[2]
preds['f_tvo'] = outputs[3] preds['f_tvo'] = outputs[3]
elif self.det_algorithm == 'DB':
preds['maps'] = outputs[0]
else: else:
preds = outputs[0] raise NotImplementedError
post_result = self.postprocess_op(preds, shape_list) post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points'] dt_boxes = post_result[0]['points']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册