提交 9db077cb 编写于 作者: F FlyingQianMM

batch_size is forced to be set to 1 in RCNN

上级 3d2eecb9
......@@ -228,7 +228,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
> **参数:**
>
> > - **eval_dataset** (paddlex.datasets): 验证数据读取器。
> > - **batch_size** (int): 验证数据批大小。默认为1。
> > - **batch_size** (int): 验证数据批大小。默认为1。当前只支持设置为1。
> > - **epoch_id** (int): 当前评估模型所在的训练轮数。
> > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,根据用户传入的Dataset自动选择,如为VOCDetection,则`metric`为'VOC'; 如为COCODetection,则`metric`为'COCO'。
> > - **return_details** (bool): 是否返回详细信息。默认值为False。
......@@ -309,7 +309,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
> **参数:**
>
> > - **eval_dataset** (paddlex.datasets): 验证数据读取器。
> > - **batch_size** (int): 验证数据批大小。默认为1。
> > - **batch_size** (int): 验证数据批大小。默认为1。当前只支持设置为1。
> > - **epoch_id** (int): 当前评估模型所在的训练轮数。
> > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,根据用户传入的Dataset自动选择,如为VOCDetection,则`metric`为'VOC'; 如为COCODetection,则`metric`为'COCO'。
> > - **return_details** (bool): 是否返回详细信息。默认值为False。
......
......@@ -259,7 +259,7 @@ class FasterRCNN(BaseAPI):
Args:
eval_dataset (paddlex.datasets): 验证数据读取器。
batch_size (int): 验证数据批大小。默认为1。
batch_size (int): 验证数据批大小。默认为1。当前只支持设置为1。
epoch_id (int): 当前评估模型所在的训练轮数。
metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,
根据用户传入的Dataset自动选择,如为VOCDetection,则metric为'VOC';
......@@ -288,7 +288,11 @@ class FasterRCNN(BaseAPI):
"eval_dataset should be datasets.VOCDetection or datasets.COCODetection."
)
assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
if batch_size > 1:
batch_size = 1
logging.warning(
"Faster RCNN supports batch_size=1 only during evaluating, so batch_size is forced to be set to 1."
)
dataset = eval_dataset.generator(
batch_size=batch_size, drop_last=False)
......
......@@ -225,7 +225,7 @@ class MaskRCNN(FasterRCNN):
Args:
eval_dataset (paddlex.datasets): 验证数据读取器。
batch_size (int): 验证数据批大小。默认为1。
batch_size (int): 验证数据批大小。默认为1。当前只支持设置为1。
epoch_id (int): 当前评估模型所在的训练轮数。
metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,
根据用户传入的Dataset自动选择,如为VOCDetection,则metric为'VOC';
......@@ -253,6 +253,11 @@ class MaskRCNN(FasterRCNN):
raise Exception(
"eval_dataset should be datasets.COCODetection.")
assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
if batch_size > 1:
batch_size = 1
logging.warning(
"Mask RCNN supports batch_size=1 only during evaluating, so batch_size is forced to be set to 1."
)
data_generator = eval_dataset.generator(
batch_size=batch_size, drop_last=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册