From a8d1f2db94f0ebd2fe2c48527b71379915e3d2f8 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Wed, 16 Dec 2020 13:06:48 +0800 Subject: [PATCH] update --- configs/det/det_mv3_db.yml | 4 ++-- configs/det/det_r50_vd_db.yml | 4 ++-- ppocr/losses/det_db_loss.py | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/configs/det/det_mv3_db.yml b/configs/det/det_mv3_db.yml index a4ed569f..5c8a0923 100644 --- a/configs/det/det_mv3_db.yml +++ b/configs/det/det_mv3_db.yml @@ -5,8 +5,8 @@ Global: print_batch_step: 10 save_model_dir: ./output/db_mv3/ save_epoch_step: 1200 - # evaluation is run every 5000 iterations after the 4000th iteration - eval_batch_step: [4000, 5000] + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] # if pretrained_model is saved in static mode, load_static_weights must set to True load_static_weights: True cal_metric_during_train: False diff --git a/configs/det/det_r50_vd_db.yml b/configs/det/det_r50_vd_db.yml index 386a7970..f1188fe3 100644 --- a/configs/det/det_r50_vd_db.yml +++ b/configs/det/det_r50_vd_db.yml @@ -5,8 +5,8 @@ Global: print_batch_step: 10 save_model_dir: ./output/det_r50_vd/ save_epoch_step: 1200 - # evaluation is run every 5000 iterations after the 4000th iteration - eval_batch_step: [4000,5000] + # evaluation is run every 2000 iterations + eval_batch_step: [0,2000] # if pretrained_model is saved in static mode, load_static_weights must set to True load_static_weights: True cal_metric_during_train: False diff --git a/ppocr/losses/det_db_loss.py b/ppocr/losses/det_db_loss.py index 3e2aa063..b079aabf 100755 --- a/ppocr/losses/det_db_loss.py +++ b/ppocr/losses/det_db_loss.py @@ -47,12 +47,12 @@ class DBLoss(nn.Layer): negative_ratio=ohem_ratio) def forward(self, predicts, labels): - predicts = predicts['maps'] + predict_maps = predicts['maps'] label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[ 1:] - shrink_maps = predicts[:, 0, :, :] - threshold_maps = predicts[:, 1, :, :] - binary_maps = predicts[:, 2, :, :] + shrink_maps = predict_maps[:, 0, :, :] + threshold_maps = predict_maps[:, 1, :, :] + binary_maps = predict_maps[:, 2, :, :] loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map, label_shrink_mask) -- GitLab