提交 a8d1f2db 编写于 作者: W WenmuZhou

update

上级 41c2af49
...@@ -5,8 +5,8 @@ Global: ...@@ -5,8 +5,8 @@ Global:
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/db_mv3/ save_model_dir: ./output/db_mv3/
save_epoch_step: 1200 save_epoch_step: 1200
# evaluation is run every 5000 iterations after the 4000th iteration # evaluation is run every 2000 iterations
eval_batch_step: [4000, 5000] eval_batch_step: [0, 2000]
# if pretrained_model is saved in static mode, load_static_weights must set to True # if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights: True load_static_weights: True
cal_metric_during_train: False cal_metric_during_train: False
......
...@@ -5,8 +5,8 @@ Global: ...@@ -5,8 +5,8 @@ Global:
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/det_r50_vd/ save_model_dir: ./output/det_r50_vd/
save_epoch_step: 1200 save_epoch_step: 1200
# evaluation is run every 5000 iterations after the 4000th iteration # evaluation is run every 2000 iterations
eval_batch_step: [4000,5000] eval_batch_step: [0,2000]
# if pretrained_model is saved in static mode, load_static_weights must set to True # if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights: True load_static_weights: True
cal_metric_during_train: False cal_metric_during_train: False
......
...@@ -47,12 +47,12 @@ class DBLoss(nn.Layer): ...@@ -47,12 +47,12 @@ class DBLoss(nn.Layer):
negative_ratio=ohem_ratio) negative_ratio=ohem_ratio)
def forward(self, predicts, labels): def forward(self, predicts, labels):
predicts = predicts['maps'] predict_maps = predicts['maps']
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[ label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[
1:] 1:]
shrink_maps = predicts[:, 0, :, :] shrink_maps = predict_maps[:, 0, :, :]
threshold_maps = predicts[:, 1, :, :] threshold_maps = predict_maps[:, 1, :, :]
binary_maps = predicts[:, 2, :, :] binary_maps = predict_maps[:, 2, :, :]
loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map, loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map,
label_shrink_mask) label_shrink_mask)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册