From b457850659c43fbd4a26c4fc4b70a3709b9952d4 Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Mon, 21 Sep 2020 16:46:09 +0800 Subject: [PATCH] sync TrainReader.batch to global train_batch_size (#1415) * sync TrainReader.batch to global train_batch_size --- configs/dcn/yolov3_r50vd_dcn.yml | 5 ----- ...dcn_db_iouaware_obj365_pretrained_coco.yml | 1 - ..._dcn_db_iouloss_obj365_pretrained_coco.yml | 5 ----- ...v3_r50vd_dcn_db_obj365_pretrained_coco.yml | 5 ----- ...olov3_r50vd_dcn_obj365_pretrained_coco.yml | 5 ----- configs/ppyolo/ppyolo.yml | 1 - configs/ppyolo/ppyolo_2x.yml | 1 - configs/ppyolo/ppyolo_r18vd.yml | 1 - configs/ppyolo/ppyolo_test.yml | 1 - configs/yolov3_darknet.yml | 5 ----- configs/yolov3_darknet_voc.yml | 5 ----- configs/yolov3_darknet_voc_diouloss.yml | 1 - configs/yolov3_mobilenet_v1.yml | 5 ----- configs/yolov3_mobilenet_v1_fruit.yml | 5 ----- configs/yolov3_mobilenet_v1_voc.yml | 5 ----- configs/yolov3_mobilenet_v3.yml | 5 ----- configs/yolov3_r34.yml | 5 ----- configs/yolov3_r34_voc.yml | 5 ----- configs/yolov4/yolov4_cspdarknet.yml | 5 ----- configs/yolov4/yolov4_cspdarknet_coco.yml | 5 ----- configs/yolov4/yolov4_cspdarknet_voc.yml | 5 ----- ppdet/core/workspace.py | 9 +++++++++ ppdet/modeling/losses/yolo_loss.py | 20 +++++++++---------- 23 files changed, 19 insertions(+), 91 deletions(-) diff --git a/configs/dcn/yolov3_r50vd_dcn.yml b/configs/dcn/yolov3_r50vd_dcn.yml index 0493597b1..99815fc32 100755 --- a/configs/dcn/yolov3_r50vd_dcn.yml +++ b/configs/dcn/yolov3_r50vd_dcn.yml @@ -40,11 +40,6 @@ YOLOv3Head: score_threshold: 0.01 YOLOv3Loss: - # batch_size here is only used for fine grained loss, not used - # for training batch_size setting, training batch_size setting - # is in configs/yolov3_reader.yml TrainReader.batch_size, batch - # size here should be set as same value as TrainReader.batch_size - batch_size: 8 ignore_thresh: 0.7 label_smooth: false diff --git a/configs/dcn/yolov3_r50vd_dcn_db_iouaware_obj365_pretrained_coco.yml b/configs/dcn/yolov3_r50vd_dcn_db_iouaware_obj365_pretrained_coco.yml index 6177aaac7..63c4cad4e 100755 --- a/configs/dcn/yolov3_r50vd_dcn_db_iouaware_obj365_pretrained_coco.yml +++ b/configs/dcn/yolov3_r50vd_dcn_db_iouaware_obj365_pretrained_coco.yml @@ -44,7 +44,6 @@ YOLOv3Head: drop_block: true YOLOv3Loss: - batch_size: 8 ignore_thresh: 0.7 label_smooth: false use_fine_grained_loss: true diff --git a/configs/dcn/yolov3_r50vd_dcn_db_iouloss_obj365_pretrained_coco.yml b/configs/dcn/yolov3_r50vd_dcn_db_iouloss_obj365_pretrained_coco.yml index 5e9431453..037c52714 100755 --- a/configs/dcn/yolov3_r50vd_dcn_db_iouloss_obj365_pretrained_coco.yml +++ b/configs/dcn/yolov3_r50vd_dcn_db_iouloss_obj365_pretrained_coco.yml @@ -42,11 +42,6 @@ YOLOv3Head: drop_block: true YOLOv3Loss: - # batch_size here is only used for fine grained loss, not used - # for training batch_size setting, training batch_size setting - # is in configs/yolov3_reader.yml TrainReader.batch_size, batch - # size here should be set as same value as TrainReader.batch_size - batch_size: 8 ignore_thresh: 0.7 label_smooth: false use_fine_grained_loss: true diff --git a/configs/dcn/yolov3_r50vd_dcn_db_obj365_pretrained_coco.yml b/configs/dcn/yolov3_r50vd_dcn_db_obj365_pretrained_coco.yml index 3c69e410a..084930b96 100755 --- a/configs/dcn/yolov3_r50vd_dcn_db_obj365_pretrained_coco.yml +++ b/configs/dcn/yolov3_r50vd_dcn_db_obj365_pretrained_coco.yml @@ -43,11 +43,6 @@ YOLOv3Head: keep_prob: 0.94 YOLOv3Loss: - # batch_size here is only used for fine grained loss, not used - # for training batch_size setting, training batch_size setting - # is in configs/yolov3_reader.yml TrainReader.batch_size, batch - # size here should be set as same value as TrainReader.batch_size - batch_size: 8 ignore_thresh: 0.7 label_smooth: false use_fine_grained_loss: true diff --git a/configs/dcn/yolov3_r50vd_dcn_obj365_pretrained_coco.yml b/configs/dcn/yolov3_r50vd_dcn_obj365_pretrained_coco.yml index 014a7947e..31e781980 100755 --- a/configs/dcn/yolov3_r50vd_dcn_obj365_pretrained_coco.yml +++ b/configs/dcn/yolov3_r50vd_dcn_obj365_pretrained_coco.yml @@ -41,11 +41,6 @@ YOLOv3Head: score_threshold: 0.01 YOLOv3Loss: - # batch_size here is only used for fine grained loss, not used - # for training batch_size setting, training batch_size setting - # is in configs/yolov3_reader.yml TrainReader.batch_size, batch - # size here should be set as same value as TrainReader.batch_size - batch_size: 8 ignore_thresh: 0.7 label_smooth: false use_fine_grained_loss: true diff --git a/configs/ppyolo/ppyolo.yml b/configs/ppyolo/ppyolo.yml index 59d5faa3a..a1a9e9959 100644 --- a/configs/ppyolo/ppyolo.yml +++ b/configs/ppyolo/ppyolo.yml @@ -44,7 +44,6 @@ YOLOv3Head: drop_block: true YOLOv3Loss: - batch_size: 24 ignore_thresh: 0.7 scale_x_y: 1.05 label_smooth: false diff --git a/configs/ppyolo/ppyolo_2x.yml b/configs/ppyolo/ppyolo_2x.yml index 8c2493372..a78158867 100644 --- a/configs/ppyolo/ppyolo_2x.yml +++ b/configs/ppyolo/ppyolo_2x.yml @@ -44,7 +44,6 @@ YOLOv3Head: drop_block: true YOLOv3Loss: - batch_size: 24 ignore_thresh: 0.7 scale_x_y: 1.05 label_smooth: false diff --git a/configs/ppyolo/ppyolo_r18vd.yml b/configs/ppyolo/ppyolo_r18vd.yml index c054d5f5d..a686a2099 100755 --- a/configs/ppyolo/ppyolo_r18vd.yml +++ b/configs/ppyolo/ppyolo_r18vd.yml @@ -39,7 +39,6 @@ YOLOv3Head: drop_block: true YOLOv3Loss: - batch_size: 32 ignore_thresh: 0.7 scale_x_y: 1.05 label_smooth: false diff --git a/configs/ppyolo/ppyolo_test.yml b/configs/ppyolo/ppyolo_test.yml index 840865a0b..a9b16dd44 100644 --- a/configs/ppyolo/ppyolo_test.yml +++ b/configs/ppyolo/ppyolo_test.yml @@ -47,7 +47,6 @@ YOLOv3Head: drop_block: true YOLOv3Loss: - batch_size: 24 ignore_thresh: 0.7 scale_x_y: 1.05 label_smooth: false diff --git a/configs/yolov3_darknet.yml b/configs/yolov3_darknet.yml index b84d81103..c3b4477f9 100644 --- a/configs/yolov3_darknet.yml +++ b/configs/yolov3_darknet.yml @@ -35,11 +35,6 @@ YOLOv3Head: score_threshold: 0.01 YOLOv3Loss: - # batch_size here is only used for fine grained loss, not used - # for training batch_size setting, training batch_size setting - # is in configs/yolov3_reader.yml TrainReader.batch_size, batch - # size here should be set as same value as TrainReader.batch_size - batch_size: 8 ignore_thresh: 0.7 label_smooth: true diff --git a/configs/yolov3_darknet_voc.yml b/configs/yolov3_darknet_voc.yml index b1c48f5f6..362989c28 100644 --- a/configs/yolov3_darknet_voc.yml +++ b/configs/yolov3_darknet_voc.yml @@ -36,11 +36,6 @@ YOLOv3Head: score_threshold: 0.01 YOLOv3Loss: - # batch_size here is only used for fine grained loss, not used - # for training batch_size setting, training batch_size setting - # is in configs/yolov3_reader.yml TrainReader.batch_size, batch - # size here should be set as same value as TrainReader.batch_size - batch_size: 8 ignore_thresh: 0.7 label_smooth: false diff --git a/configs/yolov3_darknet_voc_diouloss.yml b/configs/yolov3_darknet_voc_diouloss.yml index 62c912dc1..8a006fe8b 100644 --- a/configs/yolov3_darknet_voc_diouloss.yml +++ b/configs/yolov3_darknet_voc_diouloss.yml @@ -36,7 +36,6 @@ YOLOv3Head: score_threshold: 0.01 YOLOv3Loss: - batch_size: 8 ignore_thresh: 0.7 label_smooth: false iou_loss: DiouLossYolo diff --git a/configs/yolov3_mobilenet_v1.yml b/configs/yolov3_mobilenet_v1.yml index 040f0f2c9..3325bd4fb 100644 --- a/configs/yolov3_mobilenet_v1.yml +++ b/configs/yolov3_mobilenet_v1.yml @@ -36,11 +36,6 @@ YOLOv3Head: score_threshold: 0.01 YOLOv3Loss: - # batch_size here is only used for fine grained loss, not used - # for training batch_size setting, training batch_size setting - # is in configs/yolov3_reader.yml TrainReader.batch_size, batch - # size here should be set as same value as TrainReader.batch_size - batch_size: 8 ignore_thresh: 0.7 label_smooth: true diff --git a/configs/yolov3_mobilenet_v1_fruit.yml b/configs/yolov3_mobilenet_v1_fruit.yml index 78f50206e..b9e576c29 100644 --- a/configs/yolov3_mobilenet_v1_fruit.yml +++ b/configs/yolov3_mobilenet_v1_fruit.yml @@ -38,11 +38,6 @@ YOLOv3Head: score_threshold: 0.01 YOLOv3Loss: - # batch_size here is only used for fine grained loss, not used - # for training batch_size setting, training batch_size setting - # is in configs/yolov3_reader.yml TrainReader.batch_size, batch - # size here should be set as same value as TrainReader.batch_size - batch_size: 8 ignore_thresh: 0.7 label_smooth: true diff --git a/configs/yolov3_mobilenet_v1_voc.yml b/configs/yolov3_mobilenet_v1_voc.yml index 1b7097ad3..3df184b25 100644 --- a/configs/yolov3_mobilenet_v1_voc.yml +++ b/configs/yolov3_mobilenet_v1_voc.yml @@ -37,11 +37,6 @@ YOLOv3Head: score_threshold: 0.01 YOLOv3Loss: - # batch_size here is only used for fine grained loss, not used - # for training batch_size setting, training batch_size setting - # is in configs/yolov3_reader.yml TrainReader.batch_size, batch - # size here should be set as same value as TrainReader.batch_size - batch_size: 8 ignore_thresh: 0.7 label_smooth: false diff --git a/configs/yolov3_mobilenet_v3.yml b/configs/yolov3_mobilenet_v3.yml index 223d14c49..d8526f6a8 100644 --- a/configs/yolov3_mobilenet_v3.yml +++ b/configs/yolov3_mobilenet_v3.yml @@ -38,11 +38,6 @@ YOLOv3Head: score_threshold: 0.01 YOLOv3Loss: - # batch_size here is only used for fine grained loss, not used - # for training batch_size setting, training batch_size setting - # is in configs/yolov3_reader.yml TrainReader.batch_size, batch - # size here should be set as same value as TrainReader.batch_size - batch_size: 8 ignore_thresh: 0.7 label_smooth: false diff --git a/configs/yolov3_r34.yml b/configs/yolov3_r34.yml index da887cf3d..ca4d50b4c 100644 --- a/configs/yolov3_r34.yml +++ b/configs/yolov3_r34.yml @@ -38,11 +38,6 @@ YOLOv3Head: score_threshold: 0.01 YOLOv3Loss: - # batch_size here is only used for fine grained loss, not used - # for training batch_size setting, training batch_size setting - # is in configs/yolov3_reader.yml TrainReader.batch_size, batch - # size here should be set as same value as TrainReader.batch_size - batch_size: 8 ignore_thresh: 0.7 label_smooth: true diff --git a/configs/yolov3_r34_voc.yml b/configs/yolov3_r34_voc.yml index 2d980dd0c..6aa4aa74c 100644 --- a/configs/yolov3_r34_voc.yml +++ b/configs/yolov3_r34_voc.yml @@ -39,11 +39,6 @@ YOLOv3Head: score_threshold: 0.01 YOLOv3Loss: - # batch_size here is only used for fine grained loss, not used - # for training batch_size setting, training batch_size setting - # is in configs/yolov3_reader.yml TrainReader.batch_size, batch - # size here should be set as same value as TrainReader.batch_size - batch_size: 8 ignore_thresh: 0.7 label_smooth: false diff --git a/configs/yolov4/yolov4_cspdarknet.yml b/configs/yolov4/yolov4_cspdarknet.yml index 4411b054f..cbc69d122 100644 --- a/configs/yolov4/yolov4_cspdarknet.yml +++ b/configs/yolov4/yolov4_cspdarknet.yml @@ -35,11 +35,6 @@ YOLOv4Head: scale_x_y: [1.2, 1.1, 1.05] YOLOv3Loss: - # batch_size here is only used for fine grained loss, not used - # for training batch_size setting, training batch_size setting - # is in configs/yolov3_reader.yml TrainReader.batch_size, batch - # size here should be set as same value as TrainReader.batch_size - batch_size: 4 ignore_thresh: 0.7 label_smooth: true downsample: [8,16,32] diff --git a/configs/yolov4/yolov4_cspdarknet_coco.yml b/configs/yolov4/yolov4_cspdarknet_coco.yml index 8b4a15dc5..a711a177a 100644 --- a/configs/yolov4/yolov4_cspdarknet_coco.yml +++ b/configs/yolov4/yolov4_cspdarknet_coco.yml @@ -34,11 +34,6 @@ YOLOv4Head: scale_x_y: [1.2, 1.1, 1.05] YOLOv3Loss: - # batch_size here is only used for fine grained loss, not used - # for training batch_size setting, training batch_size setting - # is in configs/yolov3_reader.yml TrainReader.batch_size, batch - # size here should be set as same value as TrainReader.batch_size - batch_size: 8 ignore_thresh: 0.7 label_smooth: true downsample: [8,16,32] diff --git a/configs/yolov4/yolov4_cspdarknet_voc.yml b/configs/yolov4/yolov4_cspdarknet_voc.yml index beefaa0f1..3f2af08a6 100644 --- a/configs/yolov4/yolov4_cspdarknet_voc.yml +++ b/configs/yolov4/yolov4_cspdarknet_voc.yml @@ -34,11 +34,6 @@ YOLOv4Head: scale_x_y: [1.2, 1.1, 1.05] YOLOv3Loss: - # batch_size here is only used for fine grained loss, not used - # for training batch_size setting, training batch_size setting - # is in configs/yolov3_reader.yml TrainReader.batch_size, batch - # size here should be set as same value as TrainReader.batch_size - batch_size: 4 ignore_thresh: 0.7 label_smooth: true downsample: [8,16,32] diff --git a/ppdet/core/workspace.py b/ppdet/core/workspace.py index b7f7370b4..a5124b5cd 100644 --- a/ppdet/core/workspace.py +++ b/ppdet/core/workspace.py @@ -97,6 +97,15 @@ def load_config(file_path): del cfg[READER_KEY] merge_config(cfg) + + # NOTE: training batch size defined only in TrainReader, sychornized + # batch size config to global, models can get batch size config + # from global config when building model. + # batch size in evaluation or inference can also be added here + if 'TrainReader' in global_config: + global_config['train_batch_size'] = global_config['TrainReader'][ + 'batch_size'] + return global_config diff --git a/ppdet/modeling/losses/yolo_loss.py b/ppdet/modeling/losses/yolo_loss.py index 6823c024b..e978eb992 100644 --- a/ppdet/modeling/losses/yolo_loss.py +++ b/ppdet/modeling/losses/yolo_loss.py @@ -32,17 +32,17 @@ class YOLOv3Loss(object): Combined loss for YOLOv3 network Args: - batch_size (int): training batch size + train_batch_size (int): training batch size ignore_thresh (float): threshold to ignore confidence loss label_smooth (bool): whether to use label smoothing use_fine_grained_loss (bool): whether use fine grained YOLOv3 loss instead of fluid.layers.yolov3_loss """ __inject__ = ['iou_loss', 'iou_aware_loss'] - __shared__ = ['use_fine_grained_loss'] + __shared__ = ['use_fine_grained_loss', 'train_batch_size'] def __init__(self, - batch_size=8, + train_batch_size=8, ignore_thresh=0.7, label_smooth=True, use_fine_grained_loss=False, @@ -51,7 +51,7 @@ class YOLOv3Loss(object): downsample=[32, 16, 8], scale_x_y=1., match_score=False): - self._batch_size = batch_size + self._train_batch_size = train_batch_size self._ignore_thresh = ignore_thresh self._label_smooth = label_smooth self._use_fine_grained_loss = use_fine_grained_loss @@ -65,7 +65,7 @@ class YOLOv3Loss(object): anchor_masks, mask_anchors, num_classes, prefix_name): if self._use_fine_grained_loss: return self._get_fine_grained_loss( - outputs, targets, gt_box, self._batch_size, num_classes, + outputs, targets, gt_box, self._train_batch_size, num_classes, mask_anchors, self._ignore_thresh) else: losses = [] @@ -95,7 +95,7 @@ class YOLOv3Loss(object): outputs, targets, gt_box, - batch_size, + train_batch_size, num_classes, mask_anchors, ignore_thresh, @@ -108,7 +108,7 @@ class YOLOv3Loss(object): targets ([Variables]): List of Variables, The targets for yolo loss calculatation. gt_box (Variable): The ground-truth boudding boxes. - batch_size (int): The training batch size + train_batch_size (int): The training batch size num_classes (int): class num of dataset mask_anchors ([[float]]): list of anchors in each output layer ignore_thresh (float): prediction bbox overlap any gt_box greater @@ -171,7 +171,7 @@ class YOLOv3Loss(object): loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3]) if self._iou_loss is not None: loss_iou = self._iou_loss(x, y, w, h, tx, ty, tw, th, anchors, - downsample, self._batch_size, + downsample, self._train_batch_size, scale_x_y) loss_iou = loss_iou * tscale_tobj loss_iou = fluid.layers.reduce_sum(loss_iou, dim=[1, 2, 3]) @@ -180,14 +180,14 @@ class YOLOv3Loss(object): if self._iou_aware_loss is not None: loss_iou_aware = self._iou_aware_loss( ioup, x, y, w, h, tx, ty, tw, th, anchors, downsample, - self._batch_size, scale_x_y) + self._train_batch_size, scale_x_y) loss_iou_aware = loss_iou_aware * tobj loss_iou_aware = fluid.layers.reduce_sum( loss_iou_aware, dim=[1, 2, 3]) loss_iou_awares.append(fluid.layers.reduce_mean(loss_iou_aware)) loss_obj_pos, loss_obj_neg = self._calc_obj_loss( - output, obj, tobj, gt_box, self._batch_size, anchors, + output, obj, tobj, gt_box, self._train_batch_size, anchors, num_classes, downsample, self._ignore_thresh, scale_x_y) loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls, tcls) -- GitLab