diff --git a/configs/ppyolo/README.md b/configs/ppyolo/README.md
index 92790e947960189707258f454277a5f085c9e8e3..768a5416954d44522dba4ea46d50c15a74909c72 100644
--- a/configs/ppyolo/README.md
+++ b/configs/ppyolo/README.md
@@ -53,13 +53,14 @@ PP-YOLO improved performance and speed of YOLOv3 with following methods:
**Notes:**
-- PP-YOLO is trained on COCO train2017 datast and evaluated on val2017 & test-dev2017 dataset,Box APtest is evaluation results of `mAP(IoU=0.5:0.95)`.
+- PP-YOLO is trained on COCO train2017 dataset and evaluated on val2017 & test-dev2017 dataset,Box APtest is evaluation results of `mAP(IoU=0.5:0.95)`.
- PP-YOLO used 8 GPUs for training and mini-batch size as 24 on each GPU, if GPU number and mini-batch size is changed, learning rate and iteration times should be adjusted according [FAQ](../../docs/FAQ.md).
- PP-YOLO inference speed is tesed on single Tesla V100 with batch size as 1, CUDA 10.2, CUDNN 7.5.1, TensorRT 5.1.2.2 in TensorRT mode.
- PP-YOLO FP32 inference speed testing uses inference model exported by `tools/export_model.py` and benchmarked by running `depoly/python/infer.py` with `--run_benchmark`. All testing results do not contains the time cost of data reading and post-processing(NMS), which is same as [YOLOv4(AlexyAB)](https://github.com/AlexeyAB/darknet) in testing method.
- TensorRT FP16 inference speed testing exclude the time cost of bounding-box decoding(`yolo_box`) part comparing with FP32 testing above, which means that data reading, bounding-box decoding and post-processing(NMS) is excluded(test method same as [YOLOv4(AlexyAB)](https://github.com/AlexeyAB/darknet) too)
- YOLOv4(AlexyAB) performance and inference speed is copy from single Tesla V100 testing results in [YOLOv4 github repo](https://github.com/AlexeyAB/darknet), Tesla V100 TensorRT FP16 inference speed is testing with tkDNN configuration and TensorRT 5.1.2.2 on single Tesla V100 based on [AlexyAB/darknet repo](https://github.com/AlexeyAB/darknet).
- Download and configuration of YOLOv4(AlexyAB) is reproduced model of YOLOv4 in PaddleDetection, whose evaluation performance is same as YOLOv4(AlexyAB), and finetune training is supported in PaddleDetection currently, reproducing by training from backbone pretrain weights is on working, see [PaddleDetection YOLOv4](../yolov4/README.md) for details.
+- PP-YOLO trained with `batch_size=24` in each GPU with memory as 32G, configuation yaml with `batch_size=12` which can be trained on GPU with memory as 16G is provided as `ppyolo_2x_bs12.yml`, training with `batch_size=12` reached `mAP(IoU=0.5:0.95) = 45.1%` on COCO val2017 dataset, download weights by [ppyolo_2x_bs12 model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_2x_bs12.pdparams)
### PP-YOLO for mobile
diff --git a/configs/ppyolo/README_cn.md b/configs/ppyolo/README_cn.md
index 7f3fb5104f4e9d3f0529ad58a4f266b4d463c982..42dada318bb523ecbc67be8c63422a4efb4248b5 100644
--- a/configs/ppyolo/README_cn.md
+++ b/configs/ppyolo/README_cn.md
@@ -61,6 +61,7 @@ PP-YOLO从如下方面优化和提升YOLOv3模型的精度和速度:
- YOLOv4(AlexyAB)模型精度和V100 FP32推理速度数据使用[YOLOv4 github库](https://github.com/AlexeyAB/darknet)提供的单卡V100上精度速度测试数据,V100 TensorRT FP16推理速度为使用[AlexyAB/darknet](https://github.com/AlexeyAB/darknet)库中tkDNN配置于单卡V100,TensorRT 5.1.2.2的测试结果。
- PP-YOLO模型推理速度测试采用单卡V100,batch size=1进行测试,使用CUDA 10.2, CUDNN 7.5.1,TensorRT推理速度测试使用TensorRT 5.1.2.2。
- YOLOv4(AlexyAB)行`模型下载`和`配置文件`为PaddleDetection复现的YOLOv4模型,目前评估精度已对齐,支持finetune,训练精度对齐中,可参见[PaddleDetection YOLOv4 模型](../yolov4/README.md)
+- PP-YOLO使用每GPU `batch_size=24`训练,需要使用显存为32G的GPU,我们也提供了`batch_size=12`的可以在显存为16G的GPU上训练的配置文件`ppyolo_2x_bs12.yml`,使用这个配置文件训练在COCO val2017数据集上评估结果为`mAP(IoU=0.5:0.95) = 45.1%`,可通过[ppyolo_2x_bs12模型](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_2x_bs12.pdparams)下载权重。
### PP-YOLO 移动端模型
diff --git a/configs/ppyolo/ppyolo_2x_bs12.yml b/configs/ppyolo/ppyolo_2x_bs12.yml
new file mode 100644
index 0000000000000000000000000000000000000000..fcab34d1fcd80f2b5fa53e92fe0ce4c47527253b
--- /dev/null
+++ b/configs/ppyolo/ppyolo_2x_bs12.yml
@@ -0,0 +1,93 @@
+architecture: YOLOv3
+use_gpu: true
+max_iters: 500000
+log_smooth_window: 100
+log_iter: 100
+save_dir: output
+snapshot_iter: 10000
+metric: COCO
+pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
+weights: output/ppyolo/model_final
+num_classes: 80
+use_fine_grained_loss: true
+use_ema: true
+ema_decay: 0.9998
+
+YOLOv3:
+ backbone: ResNet
+ yolo_head: YOLOv3Head
+ use_fine_grained_loss: true
+
+ResNet:
+ norm_type: sync_bn
+ freeze_at: 0
+ freeze_norm: false
+ norm_decay: 0.
+ depth: 50
+ feature_maps: [3, 4, 5]
+ variant: d
+ dcn_v2_stages: [5]
+
+YOLOv3Head:
+ anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
+ anchors: [[10, 13], [16, 30], [33, 23],
+ [30, 61], [62, 45], [59, 119],
+ [116, 90], [156, 198], [373, 326]]
+ norm_decay: 0.
+ coord_conv: true
+ iou_aware: true
+ iou_aware_factor: 0.4
+ scale_x_y: 1.05
+ spp: true
+ yolo_loss: YOLOv3Loss
+ nms: MatrixNMS
+ drop_block: true
+
+YOLOv3Loss:
+ ignore_thresh: 0.7
+ scale_x_y: 1.05
+ label_smooth: false
+ use_fine_grained_loss: true
+ iou_loss: IouLoss
+ iou_aware_loss: IouAwareLoss
+
+IouLoss:
+ loss_weight: 2.5
+ max_height: 608
+ max_width: 608
+
+IouAwareLoss:
+ loss_weight: 1.0
+ max_height: 608
+ max_width: 608
+
+MatrixNMS:
+ background_label: -1
+ keep_top_k: 100
+ normalized: false
+ score_threshold: 0.01
+ post_threshold: 0.01
+
+LearningRate:
+ base_lr: 0.005
+ schedulers:
+ - !PiecewiseDecay
+ gamma: 0.1
+ milestones:
+ - 400000
+ - 450000
+ - !LinearWarmup
+ start_factor: 0.
+ steps: 4000
+
+OptimizerBuilder:
+ optimizer:
+ momentum: 0.9
+ type: Momentum
+ regularizer:
+ factor: 0.0005
+ type: L2
+
+_READER_: 'ppyolo_reader.yml'
+TrainReader:
+ batch_size: 12
diff --git a/configs/ppyolo/ppyolo_reader.yml b/configs/ppyolo/ppyolo_reader.yml
index 295ddbaf9f265b0c9ee2e752f49983890518596a..f03e47216f9ce2003e66c2032511acf05e6ec1d9 100644
--- a/configs/ppyolo/ppyolo_reader.yml
+++ b/configs/ppyolo/ppyolo_reader.yml
@@ -17,6 +17,7 @@ TrainReader:
beta: 1.5
- !ColorDistort {}
- !RandomExpand
+ ratio: 2.0
fill_value: [123.675, 116.28, 103.53]
- !RandomCrop {}
- !RandomFlipImage
diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py
index c32f9fb9f4d3b2c4dad09b3de2e33f7ff82ed017..8a772fbb98ec0b275da68d44156163511c4ba7c5 100644
--- a/ppdet/data/transform/operators.py
+++ b/ppdet/data/transform/operators.py
@@ -2576,9 +2576,7 @@ class DebugVisibleImage(BaseOperator):
x1 = round(keypoint[2 * j]).astype(np.int32)
y1 = round(keypoint[2 * j + 1]).astype(np.int32)
draw.ellipse(
- (x1, y1, x1 + 5, y1i + 5),
- fill='green',
- outline='green')
+ (x1, y1, x1 + 5, y1 + 5), fill='green', outline='green')
save_path = os.path.join(self.output_dir, out_file_name)
image.save(save_path, quality=95)
return sample
diff --git a/ppdet/modeling/losses/yolo_loss.py b/ppdet/modeling/losses/yolo_loss.py
index e978eb992ec0aca9038e61ebb367cd47e6885a7b..9445f6a231ee9eabc6e8a54fab07c5b6e0128a0a 100644
--- a/ppdet/modeling/losses/yolo_loss.py
+++ b/ppdet/modeling/losses/yolo_loss.py
@@ -23,6 +23,9 @@ try:
except Exception:
from collections import Sequence
+import logging
+logger = logging.getLogger(__name__)
+
__all__ = ['YOLOv3Loss']
@@ -41,16 +44,18 @@ class YOLOv3Loss(object):
__inject__ = ['iou_loss', 'iou_aware_loss']
__shared__ = ['use_fine_grained_loss', 'train_batch_size']
- def __init__(self,
- train_batch_size=8,
- ignore_thresh=0.7,
- label_smooth=True,
- use_fine_grained_loss=False,
- iou_loss=None,
- iou_aware_loss=None,
- downsample=[32, 16, 8],
- scale_x_y=1.,
- match_score=False):
+ def __init__(
+ self,
+ train_batch_size=8,
+ batch_size=-1, # stub for backward compatable
+ ignore_thresh=0.7,
+ label_smooth=True,
+ use_fine_grained_loss=False,
+ iou_loss=None,
+ iou_aware_loss=None,
+ downsample=[32, 16, 8],
+ scale_x_y=1.,
+ match_score=False):
self._train_batch_size = train_batch_size
self._ignore_thresh = ignore_thresh
self._label_smooth = label_smooth
@@ -61,6 +66,11 @@ class YOLOv3Loss(object):
self.scale_x_y = scale_x_y
self.match_score = match_score
+ if batch_size != -1:
+ logger.warn(
+ "config YOLOv3Loss.batch_size is deprecated, "
+ "training batch size should be set by TrainReader.batch_size")
+
def __call__(self, outputs, gt_box, gt_label, gt_score, targets, anchors,
anchor_masks, mask_anchors, num_classes, prefix_name):
if self._use_fine_grained_loss:
diff --git a/tools/train.py b/tools/train.py
index dd2edbd4383524aa038347ea00997c997b980e93..b2632f4934eb21823e0b8b50a676404202456dd1 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -255,8 +255,9 @@ def main():
train_stats.update(stats)
logs = train_stats.log()
if it % cfg.log_iter == 0 and (not FLAGS.dist or trainer_id == 0):
- strs = 'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format(
- it, np.mean(outs[-1]), logs, time_cost, eta)
+ ips = float(cfg['TrainReader']['batch_size']) / time_cost
+ strs = 'iter: {}, lr: {:.6f}, {}, batch_cost: {:.5f} s, eta: {}, ips: {:.5f} images/sec'.format(
+ it, np.mean(outs[-1]), logs, time_cost, eta, ips)
logger.info(strs)
# NOTE : profiler tools, used for benchmark