From 9db077cb56d7daa09ca968633a710f6be514a67e Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Thu, 14 May 2020 11:04:16 +0800 Subject: [PATCH] batch_size is forced to be set to 1 in RCNN --- docs/apis/models.md | 4 ++-- paddlex/cv/models/faster_rcnn.py | 8 ++++++-- paddlex/cv/models/mask_rcnn.py | 7 ++++++- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/docs/apis/models.md b/docs/apis/models.md index 524fabb..009999e 100644 --- a/docs/apis/models.md +++ b/docs/apis/models.md @@ -228,7 +228,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec > **参数:** > > > - **eval_dataset** (paddlex.datasets): 验证数据读取器。 -> > - **batch_size** (int): 验证数据批大小。默认为1。 +> > - **batch_size** (int): 验证数据批大小。默认为1。当前只支持设置为1。 > > - **epoch_id** (int): 当前评估模型所在的训练轮数。 > > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,根据用户传入的Dataset自动选择,如为VOCDetection,则`metric`为'VOC'; 如为COCODetection,则`metric`为'COCO'。 > > - **return_details** (bool): 是否返回详细信息。默认值为False。 @@ -309,7 +309,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_ > **参数:** > > > - **eval_dataset** (paddlex.datasets): 验证数据读取器。 -> > - **batch_size** (int): 验证数据批大小。默认为1。 +> > - **batch_size** (int): 验证数据批大小。默认为1。当前只支持设置为1。 > > - **epoch_id** (int): 当前评估模型所在的训练轮数。 > > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,根据用户传入的Dataset自动选择,如为VOCDetection,则`metric`为'VOC'; 如为COCODetection,则`metric`为'COCO'。 > > - **return_details** (bool): 是否返回详细信息。默认值为False。 diff --git a/paddlex/cv/models/faster_rcnn.py b/paddlex/cv/models/faster_rcnn.py index 3b27bf0..3b7144f 100644 --- a/paddlex/cv/models/faster_rcnn.py +++ b/paddlex/cv/models/faster_rcnn.py @@ -259,7 +259,7 @@ class FasterRCNN(BaseAPI): Args: eval_dataset (paddlex.datasets): 验证数据读取器。 - batch_size (int): 验证数据批大小。默认为1。 + batch_size (int): 验证数据批大小。默认为1。当前只支持设置为1。 epoch_id (int): 当前评估模型所在的训练轮数。 metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None, 根据用户传入的Dataset自动选择,如为VOCDetection,则metric为'VOC'; @@ -288,7 +288,11 @@ class FasterRCNN(BaseAPI): "eval_dataset should be datasets.VOCDetection or datasets.COCODetection." ) assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'" - + if batch_size > 1: + batch_size = 1 + logging.warning( + "Faster RCNN supports batch_size=1 only during evaluating, so batch_size is forced to be set to 1." + ) dataset = eval_dataset.generator( batch_size=batch_size, drop_last=False) diff --git a/paddlex/cv/models/mask_rcnn.py b/paddlex/cv/models/mask_rcnn.py index 3766384..ba5da33 100644 --- a/paddlex/cv/models/mask_rcnn.py +++ b/paddlex/cv/models/mask_rcnn.py @@ -225,7 +225,7 @@ class MaskRCNN(FasterRCNN): Args: eval_dataset (paddlex.datasets): 验证数据读取器。 - batch_size (int): 验证数据批大小。默认为1。 + batch_size (int): 验证数据批大小。默认为1。当前只支持设置为1。 epoch_id (int): 当前评估模型所在的训练轮数。 metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None, 根据用户传入的Dataset自动选择,如为VOCDetection,则metric为'VOC'; @@ -253,6 +253,11 @@ class MaskRCNN(FasterRCNN): raise Exception( "eval_dataset should be datasets.COCODetection.") assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'" + if batch_size > 1: + batch_size = 1 + logging.warning( + "Mask RCNN supports batch_size=1 only during evaluating, so batch_size is forced to be set to 1." + ) data_generator = eval_dataset.generator( batch_size=batch_size, drop_last=False) -- GitLab