Unverified commit a04b2a74, authored by Kaipeng Deng, committed by GitHub

Add ssd (#1782)

* add ssd_vgg16_120e_coco
Parent a8b1d8d7
metric: VOC
num_classes: 20
TrainDataset:
!VOCDataSet
dataset_dir: dataset/voc
anno_path: trainval.txt
label_list: label_list.txt
EvalDataset:
!VOCDataSet
dataset_dir: dataset/voc
anno_path: test.txt
label_list: label_list.txt
TestDataset:
!ImageFolder
anno_path: dataset/voc/label_list.txt
architecture: SSD
pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/VGG16_caffe_pretrained.pdparams
weights: output/ssd_vgg16/model_final
# Model Architecture
SSD:
# model feat info flow
backbone: VGG
ssd_head: SSDHead
# post process
post_process: BBoxPostProcess
VGG:
depth: 16
normalizations: [20., -1, -1, -1, -1, -1]
SSDHead:
in_channels: [512, 1024, 512, 256, 256, 256]
anchor_generator: AnchorGeneratorSSD
AnchorGeneratorSSD:
steps: [8, 16, 32, 64, 100, 300]
aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]]
min_ratio: 20
max_ratio: 90
min_sizes: [30.0, 60.0, 111.0, 162.0, 213.0, 264.0]
max_sizes: [60.0, 111.0, 162.0, 213.0, 264.0, 315.0]
offset: 0.5
flip: true
min_max_aspect_ratios_order: true
BBoxPostProcess:
decode:
name: SSDBox
nms:
name: MultiClassNMS
keep_top_k: 200
score_threshold: 0.01
nms_threshold: 0.45
nms_top_k: 400
nms_eta: 1.0
worker_num: 2
TrainReader:
inputs_def:
fields: ['image', 'gt_bbox', 'gt_class']
num_max_boxes: 90
sample_transforms:
- DecodeOp: {}
- RandomDistortOp: {brightness: [0.5, 1.125, 0.875], random_apply: False}
- RandomExpandOp: {fill_value: [104., 117., 123.]}
- RandomCropOp: {allow_no_crop: true}
- RandomFlipOp: {}
- NormalizeBoxOp: {}
- ResizeOp: {target_size: [300, 300], keep_ratio: False, interp: 1}
- PadBoxOp: {num_max_boxes: 90}
batch_transforms:
- NormalizeImageOp: {mean: [104., 117., 123.], std: [1., 1., 1.], is_scale: false}
- PermuteOp: {}
batch_size: 8
shuffle: true
drop_last: true
EvalReader:
inputs_def:
fields: ['image', 'im_shape', 'scale_factor', 'im_id', 'gt_bbox', 'gt_class', 'difficult']
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [300, 300], keep_ratio: False, interp: 1}
- NormalizeImageOp: {mean: [104., 117., 123.], std: [1., 1., 1.], is_scale: false}
- PermuteOp: {}
batch_size: 1
drop_empty: false
TestReader:
inputs_def:
image_shape: [3, 300, 300]
fields: ['image', 'im_shape', 'scale_factor', 'im_id']
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [300, 300], keep_ratio: False, interp: 1}
- NormalizeImageOp: {mean: [104., 117., 123.], std: [1., 1., 1.], is_scale: false}
- PermuteOp: {}
batch_size: 1
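The eval and test readers above subtract a per-channel mean with a `std` of 1 and `is_scale: false`, then permute the image. A minimal pure-Python sketch of those two steps, assuming that `PermuteOp` converts HWC to CHW and that `is_scale: false` means pixel values are not divided by 255 (both are assumptions, since the ops themselves are not part of this diff). `normalize_and_permute` is a hypothetical helper:

```python
def normalize_and_permute(image_hwc, mean, std):
    """Normalize an H x W x C nested list per channel, then return C x H x W."""
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(mean)
    chw = []
    for c in range(channels):
        plane = []
        for y in range(height):
            # (pixel - mean) / std, with no 1/255 scaling (assumed is_scale: false)
            row = [(image_hwc[y][x][c] - mean[c]) / std[c] for x in range(width)]
            plane.append(row)
        chw.append(plane)
    return chw


# A 1 x 2 image with three channels; values stay in the 0-255 range.
tiny = [[[104.0, 117.0, 123.0], [114.0, 127.0, 133.0]]]
out = normalize_and_permute(
    tiny, mean=[104.0, 117.0, 123.0], std=[1.0, 1.0, 1.0])
print(out)
```

The mean values here are the ones from the reader config; the tiny image is made up for illustration.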
_BASE_: [
'./_base_/models/ssd_vgg16_300.yml',
'./_base_/optimizers/ssd_120e.yml',
'./_base_/datasets/coco.yml',
'./_base_/readers/ssd_reader.yml',
'./_base_/runtime.yml',
]
_BASE_: [
'./_base_/models/ssd_vgg16_300.yml',
'./_base_/optimizers/ssd_240e.yml',
'./_base_/datasets/voc.yml',
'./_base_/readers/ssd_reader.yml',
'./_base_/runtime.yml',
]
......@@ -33,6 +33,7 @@ from paddle.inference import create_predictor
SUPPORT_MODELS = {
'YOLO',
'RCNN',
'SSD',
}
......@@ -73,7 +74,7 @@ class Detector(object):
def postprocess(self, np_boxes, np_masks, inputs, threshold=0.5):
# postprocess output of predictor
results = {}
if self.pred_config.arch in ['SSD', 'Face']:
if self.pred_config.arch in ['Face']:
h, w = inputs['im_shape']
scale_y, scale_x = inputs['scale_factor']
w, h = float(h) / scale_y, float(w) / scale_x
......
......@@ -41,3 +41,11 @@ Paddle provides backbone models pretrained on ImageNet. All pretrained models
| ResNet50-FPN | Cascade Faster | 1 | 1x | ---- | 41.1 | - | [download](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/cascade_rcnn_r50_fpn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/cascade_faster_rcnn_r50_fpn_1x_coco.yml) |
| ResNet50-FPN | Cascade Mask | 1 | 1x | ---- | 41.6 | 35.3 | [download](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/cascade_mask_rcnn_r50_fpn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/cascade_mask_rcnn_r50_fpn_1x_coco.yml) |
| DarkNet53 | YOLOv3 | 1 | 270e | ---- | 39.0 | - | [download](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_darknet53_270e_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_darknet53_270e_coco.yml) |
### SSD on Pascal VOC
| Backbone | Model type | Images per GPU | LR schedule | Inference time (fps) | Box AP | Download | Config |
| :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
| VGG | SSD | 8 | 240e | ---- | 78.2 | [download](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ssd_vgg16_300_240e_voc.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/ssd_vgg16_300_240e_voc.yml) |
**Note:** SSD is trained on 4 GPUs for 240 epochs.
......@@ -2,6 +2,7 @@ import copy
import traceback
import logging
import threading
import six
import sys
if sys.version_info >= (3, 0):
import queue as Queue
......@@ -118,25 +119,25 @@ class BaseDataLoader(object):
batch_sampler=None,
return_list=False,
use_prefetch=True):
self._dataset = dataset
self._dataset.parse_dataset(self.with_background)
self.dataset = dataset
self.dataset.parse_dataset(self.with_background)
# get data
self._dataset.set_out(self._sample_transforms,
copy.deepcopy(self._fields))
self.dataset.set_out(self._sample_transforms,
copy.deepcopy(self._fields))
# set kwargs
self._dataset.set_kwargs(**self.kwargs)
self.dataset.set_kwargs(**self.kwargs)
# batch sampler
if batch_sampler is None:
self._batch_sampler = DistributedBatchSampler(
self._dataset,
self.dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
drop_last=self.drop_last)
else:
self._batch_sampler = batch_sampler
loader = DataLoader(
dataset=self._dataset,
self.loader = DataLoader(
dataset=self.dataset,
batch_sampler=self._batch_sampler,
collate_fn=self._batch_transforms,
num_workers=worker_num,
......@@ -144,8 +145,29 @@ class BaseDataLoader(object):
return_list=return_list,
use_buffer_reader=use_prefetch,
use_shared_memory=False)
self.loader = iter(self.loader)
return loader, len(self._batch_sampler)
return self
def __len__(self):
return len(self._batch_sampler)
def __iter__(self):
return self
def __next__(self):
# pack {field_name: field_data} here
# looking forward to support dictionary
# data structure in paddle.io.DataLoader
try:
data = next(self.loader)
return {k: v for k, v in zip(self._fields, data)}
except StopIteration:
six.reraise(*sys.exc_info())
def next(self):
# python2 compatibility
return self.__next__()
@register
......
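The `__next__` method in the `BaseDataLoader` hunk above zips the configured field names with the tuple returned by the underlying loader. A minimal sketch of that packing pattern, with a plain Python list standing in for `paddle.io.DataLoader`; `FieldPackingLoader` and the sample batches are hypothetical:

```python
class FieldPackingLoader(object):
    """Wrap an iterable of tuples and yield {field_name: value} dicts."""

    def __init__(self, fields, batches):
        self._fields = fields
        self._batches = batches
        # The iterator is created once, so the wrapper is single-pass.
        self._iter = iter(batches)

    def __len__(self):
        return len(self._batches)

    def __iter__(self):
        return self

    def __next__(self):
        # StopIteration from the wrapped iterator propagates unchanged.
        data = next(self._iter)
        return dict(zip(self._fields, data))

    next = __next__  # Python 2 spelling of the same method


fields = ['image', 'gt_bbox', 'gt_class']
batches = [('img0', 'box0', 'cls0'), ('img1', 'box1', 'cls1')]
packed = list(FieldPackingLoader(fields, batches))
print(packed[0])
```

Downstream code can then index a batch by name (for example `batch['image']`) instead of by position, which is what the simplified `BaseArch.forward(inputs, mode)` signature later in this commit relies on.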
......@@ -14,9 +14,9 @@
from . import coco
# TODO add voc and widerface dataset
#from . import voc
from . import voc
#from . import widerface
from .coco import *
#from .voc import *
from .voc import *
#from .widerface import *
......@@ -19,14 +19,14 @@ import xml.etree.ElementTree as ET
from ppdet.core.workspace import register, serializable
from .dataset import DataSet
from .dataset import DetDataset
import logging
logger = logging.getLogger(__name__)
@register
@serializable
class VOCDataSet(DataSet):
class VOCDataSet(DetDataset):
"""
Load dataset with PascalVOC format.
......@@ -38,8 +38,6 @@ class VOCDataSet(DataSet):
image_dir (str): directory for images.
anno_path (str): voc annotation file path.
sample_num (int): number of samples to load, -1 means all.
use_default_label (bool): whether use the default mapping of
label to integer index. Default True.
label_list (str): if use_default_label is False, will load
mapping between category and class index.
"""
......@@ -49,32 +47,15 @@ class VOCDataSet(DataSet):
image_dir=None,
anno_path=None,
sample_num=-1,
use_default_label=True,
label_list='label_list.txt'):
label_list=None):
super(VOCDataSet, self).__init__(
dataset_dir=dataset_dir,
image_dir=image_dir,
anno_path=anno_path,
sample_num=sample_num,
dataset_dir=dataset_dir)
# roidbs is list of dict whose structure is:
# {
# 'im_file': im_fname, # image file name
# 'im_id': im_id, # image id
# 'h': im_h, # height of image
# 'w': im_w, # width
# 'is_crowd': is_crowd,
# 'gt_class': gt_class,
# 'gt_score': gt_score,
# 'gt_bbox': gt_bbox,
# 'difficult': difficult
# }
self.roidbs = None
# 'cname2id' is a dict to map category name to class id
self.cname2cid = None
self.use_default_label = use_default_label
sample_num=sample_num)
self.label_list = label_list
def load_roidb_and_cname2cid(self, with_background=True):
def parse_dataset(self, with_background=True):
anno_path = os.path.join(self.dataset_dir, self.anno_path)
image_dir = os.path.join(self.dataset_dir, self.image_dir)
......@@ -86,7 +67,7 @@ class VOCDataSet(DataSet):
records = []
ct = 0
cname2cid = {}
if not self.use_default_label:
if self.label_list:
label_path = os.path.join(self.dataset_dir, self.label_list)
if not os.path.exists(label_path):
raise ValueError("label_list {} does not exist".format(
......@@ -183,6 +164,9 @@ class VOCDataSet(DataSet):
logger.debug('{} samples in file {}'.format(ct, anno_path))
self.roidbs, self.cname2cid = records, cname2cid
def get_label_list(self):
return os.path.join(self.dataset_dir, self.label_list)
def pascalvoc_label(with_background=True):
labels_map = {
......
......@@ -1668,7 +1668,7 @@ class PadBoxOp(BaseOperator):
# in training, for example in op ExpandImage,
# the bbox and gt_class are expanded, but the difficult flag is not,
# so, judging by its length
if 'is_difficult' in sample:
if 'difficult' in sample:
pad_diff = np.zeros((num_max, ), dtype=np.int32)
if gt_num > 0:
pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
......
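The `PadBoxOp` hunk above pads the `difficult` flags to a fixed length of `num_max`. A sketch of the same fixed-length padding idea using plain lists, assuming that unused slots are zero-filled and that entries beyond `num_max` are dropped; `pad_to_max` is a hypothetical helper, not part of `PadBoxOp`:

```python
def pad_to_max(values, num_max, fill=0):
    """Return a list of length num_max: values truncated or right-padded."""
    kept = list(values[:num_max])
    return kept + [fill] * (num_max - len(kept))


difficult = [0, 1, 0]
padded = pad_to_max(difficult, num_max=5)
truncated = pad_to_max([1, 2, 3], num_max=2)
print(padded, truncated)
```

A fixed length lets samples with different numbers of ground-truth boxes be stacked into one batch tensor; the reader config in this commit uses `num_max_boxes: 90`.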
......@@ -10,9 +10,11 @@ from . import faster_rcnn
from . import mask_rcnn
from . import yolo
from . import cascade_rcnn
from . import ssd
from .meta_arch import *
from .faster_rcnn import *
from .mask_rcnn import *
from .yolo import *
from .cascade_rcnn import *
from .ssd import *
......@@ -15,17 +15,8 @@ class BaseArch(nn.Layer):
def __init__(self):
super(BaseArch, self).__init__()
def forward(self,
input_tensor=None,
data=None,
input_def=None,
mode='infer'):
if input_tensor is None:
assert data is not None and input_def is not None
self.inputs = self.build_inputs(data, input_def)
else:
self.inputs = input_tensor
def forward(self, inputs, mode='infer'):
self.inputs = inputs
self.inputs['mode'] = mode
self.model_arch()
......
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from ppdet.core.workspace import register
from .meta_arch import BaseArch
__all__ = ['SSD']
@register
class SSD(BaseArch):
__category__ = 'architecture'
__inject__ = ['backbone', 'neck', 'ssd_head', 'post_process']
def __init__(self, backbone, ssd_head, post_process, neck=None):
super(SSD, self).__init__()
self.backbone = backbone
self.neck = neck
self.ssd_head = ssd_head
self.post_process = post_process
def model_arch(self):
# Backbone
body_feats = self.backbone(self.inputs)
# Neck
if self.neck is not None:
body_feats, spatial_scale = self.neck(body_feats)
# SSD Head
self.ssd_head_outs, self.anchors = self.ssd_head(body_feats,
self.inputs['image'])
def get_loss(self, ):
loss = self.ssd_head.get_loss(self.ssd_head_outs, self.inputs,
self.anchors)
return {"loss": loss}
def get_pred(self, return_numpy=True):
output = {}
bbox, bbox_num = self.post_process(self.ssd_head_outs, self.anchors,
self.inputs['im_shape'],
self.inputs['scale_factor'])
outs = {
"bbox": bbox,
"bbox_num": bbox_num,
}
return outs
from . import vgg
from . import resnet
from . import darknet
from .vgg import *
from .resnet import *
from .darknet import *
from __future__ import division
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.fluid.regularizer import L2Decay
from paddle.nn import Conv2D, MaxPool2D
from ppdet.core.workspace import register, serializable
__all__ = ['VGG']
VGG_cfg = {16: [2, 2, 3, 3, 3], 19: [2, 2, 4, 4, 4]}
class ConvBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
groups,
pool_size=2,
pool_stride=2,
pool_padding=0,
name=None):
super(ConvBlock, self).__init__()
self.groups = groups
self.conv0 = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(name=name + "1_weights"),
bias_attr=ParamAttr(name=name + "1_bias"))
self.conv_out_list = []
for i in range(1, groups):
conv_out = self.add_sublayer(
'conv{}'.format(i),
Conv2D(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(
name=name + "{}_weights".format(i + 1)),
bias_attr=ParamAttr(name=name + "{}_bias".format(i + 1))))
self.conv_out_list.append(conv_out)
self.pool = MaxPool2D(
kernel_size=pool_size,
stride=pool_stride,
padding=pool_padding,
ceil_mode=True)
def forward(self, inputs):
out = self.conv0(inputs)
out = F.relu(out)
for conv_i in self.conv_out_list:
out = conv_i(out)
out = F.relu(out)
pool = self.pool(out)
return out, pool
class ExtraBlock(nn.Layer):
def __init__(self,
in_channels,
mid_channels,
out_channels,
padding,
stride,
kernel_size,
name=None):
super(ExtraBlock, self).__init__()
self.conv0 = Conv2D(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=1,
stride=1,
padding=0)
self.conv1 = Conv2D(
in_channels=mid_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding)
def forward(self, inputs):
out = self.conv0(inputs)
out = F.relu(out)
out = self.conv1(out)
out = F.relu(out)
return out
class L2NormScale(nn.Layer):
def __init__(self, num_channels, scale=1.0):
super(L2NormScale, self).__init__()
self.scale = self.create_parameter(
attr=ParamAttr(initializer=paddle.nn.initializer.Constant(scale)),
shape=[num_channels])
def forward(self, inputs):
out = F.normalize(inputs, axis=1, epsilon=1e-10)
# out = self.scale.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(
# out) * out
out = self.scale.unsqueeze(0).unsqueeze(2).unsqueeze(3) * out
return out
@register
@serializable
class VGG(nn.Layer):
def __init__(self,
depth=16,
normalizations=[20., -1, -1, -1, -1, -1],
extra_block_filters=[[256, 512, 1, 2, 3], [128, 256, 1, 2, 3],
[128, 256, 0, 1, 3],
[128, 256, 0, 1, 3]]):
super(VGG, self).__init__()
assert depth in [16, 19], \
"depth as 16/19 supported currently, but got {}".format(depth)
self.depth = depth
self.groups = VGG_cfg[depth]
self.normalizations = normalizations
self.extra_block_filters = extra_block_filters
self.conv_block_0 = ConvBlock(
3, 64, self.groups[0], 2, 2, 0, name="conv1_")
self.conv_block_1 = ConvBlock(
64, 128, self.groups[1], 2, 2, 0, name="conv2_")
self.conv_block_2 = ConvBlock(
128, 256, self.groups[2], 2, 2, 0, name="conv3_")
self.conv_block_3 = ConvBlock(
256, 512, self.groups[3], 2, 2, 0, name="conv4_")
self.conv_block_4 = ConvBlock(
512, 512, self.groups[4], 3, 1, 1, name="conv5_")
self.fc6 = Conv2D(
in_channels=512,
out_channels=1024,
kernel_size=3,
stride=1,
padding=6,
dilation=6)
self.fc7 = Conv2D(
in_channels=1024,
out_channels=1024,
kernel_size=1,
stride=1,
padding=0)
# extra block
self.extra_convs = []
last_channels = 1024
for i, v in enumerate(self.extra_block_filters):
assert len(v) == 5, "each extra_block_filters entry must have 5 elements"
extra_conv = self.add_sublayer("conv{}".format(6 + i),
ExtraBlock(last_channels, v[0], v[1],
v[2], v[3], v[4]))
last_channels = v[1]
self.extra_convs.append(extra_conv)
self.norms = []
for i, n in enumerate(self.normalizations):
if n != -1:
norm = self.add_sublayer("norm{}".format(i),
L2NormScale(
self.extra_block_filters[i][1], n))
else:
norm = None
self.norms.append(norm)
def forward(self, inputs):
outputs = []
conv, pool = self.conv_block_0(inputs['image'])
conv, pool = self.conv_block_1(pool)
conv, pool = self.conv_block_2(pool)
conv, pool = self.conv_block_3(pool)
outputs.append(conv)
conv, pool = self.conv_block_4(pool)
out = self.fc6(pool)
out = F.relu(out)
out = self.fc7(out)
out = F.relu(out)
outputs.append(out)
if not self.extra_block_filters:
return out
# extra block
for extra_conv in self.extra_convs:
out = extra_conv(out)
outputs.append(out)
for i, n in enumerate(self.normalizations):
if n != -1:
outputs[i] = self.norms[i](outputs[i])
return outputs
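For a 300x300 input, the spatial size of each of the six outputs follows from the pooling and extra-block settings in the `VGG` class above. A sketch of that arithmetic, assuming the usual output-size formulas for convolution (floor rounding) and for max pooling with `ceil_mode=True` (ceil rounding); `conv_out` and `pool_out_ceil` are hypothetical helpers:

```python
import math


def conv_out(size, kernel, stride, padding, dilation=1):
    """Output size of a convolution (floor rounding)."""
    effective = dilation * (kernel - 1) + 1
    return (size + 2 * padding - effective) // stride + 1


def pool_out_ceil(size, kernel, stride, padding):
    """Output size of max pooling with ceil_mode=True."""
    return int(math.ceil((size + 2 * padding - kernel) / float(stride))) + 1


size = 300
sizes = []
# conv_block_0..2: the 3x3 stride-1 pad-1 convs keep the size,
# and each 2x2 stride-2 pool halves it.
for _ in range(3):
    size = pool_out_ceil(size, 2, 2, 0)
sizes.append(size)                           # conv_block_3 conv output
size = pool_out_ceil(size, 2, 2, 0)          # pool of conv_block_3
size = pool_out_ceil(size, 3, 1, 1)          # pool of conv_block_4
size = conv_out(size, 3, 1, 6, dilation=6)   # fc6
size = conv_out(size, 1, 1, 0)               # fc7
sizes.append(size)
# extra_block_filters entries: [mid, out, padding, stride, kernel_size]
extra_block_filters = [[256, 512, 1, 2, 3], [128, 256, 1, 2, 3],
                       [128, 256, 0, 1, 3], [128, 256, 0, 1, 3]]
for _, _, padding, stride, kernel in extra_block_filters:
    size = conv_out(size, kernel, stride, padding)  # the 1x1 conv keeps the size
    sizes.append(size)
print(sizes)
```

Under these assumptions the script prints `[38, 19, 10, 5, 3, 1]`, one entry per element of `in_channels` in the `SSDHead` config.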
......@@ -17,9 +17,11 @@ from . import bbox_head
from . import mask_head
from . import yolo_head
from . import roi_extractor
from . import ssd_head
from .rpn_head import *
from .bbox_head import *
from .mask_head import *
from .yolo_head import *
from .roi_extractor import *
from .ssd_head import *
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
from ppdet.core.workspace import register
@register
class SSDHead(nn.Layer):
__shared__ = ['num_classes']
__inject__ = ['anchor_generator', 'loss']
def __init__(self,
num_classes=81,
in_channels=(512, 1024, 512, 256, 256, 256),
anchor_generator='AnchorGeneratorSSD',
loss='SSDLoss'):
super(SSDHead, self).__init__()
self.num_classes = num_classes
self.in_channels = in_channels
self.anchor_generator = anchor_generator
self.loss = loss
self.num_priors = self.anchor_generator.num_priors
self.box_convs = []
self.score_convs = []
for i, num_prior in enumerate(self.num_priors):
self.box_convs.append(
self.add_sublayer(
"boxes{}".format(i),
nn.Conv2D(
in_channels=in_channels[i],
out_channels=num_prior * 4,
kernel_size=3,
padding=1)))
self.score_convs.append(
self.add_sublayer(
"scores{}".format(i),
nn.Conv2D(
in_channels=in_channels[i],
out_channels=num_prior * num_classes,
kernel_size=3,
padding=1)))
def forward(self, feats, image):
box_preds = []
cls_scores = []
prior_boxes = []
for feat, box_conv, score_conv in zip(feats, self.box_convs,
self.score_convs):
box_pred = box_conv(feat)
box_pred = paddle.transpose(box_pred, [0, 2, 3, 1])
box_pred = paddle.reshape(box_pred, [0, -1, 4])
box_preds.append(box_pred)
cls_score = score_conv(feat)
cls_score = paddle.transpose(cls_score, [0, 2, 3, 1])
cls_score = paddle.reshape(cls_score, [0, -1, self.num_classes])
cls_scores.append(cls_score)
prior_boxes = self.anchor_generator(feats, image)
outputs = {}
outputs['boxes'] = box_preds
outputs['scores'] = cls_scores
return outputs, prior_boxes
def get_loss(self, inputs, targets, prior_boxes):
return self.loss(inputs, targets, prior_boxes)
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
from numbers import Integral
......@@ -24,6 +25,12 @@ from . import ops
import paddle.nn.functional as F
def _to_list(l):
if isinstance(l, (list, tuple)):
return list(l)
return [l]
@register
@serializable
class AnchorGeneratorRPN(object):
......@@ -103,6 +110,57 @@ class AnchorTargetGeneratorRPN(object):
return pred_cls_logits, pred_bbox_pred, tgt_labels, tgt_bboxes, bbox_inside_weights
@register
@serializable
class AnchorGeneratorSSD(object):
def __init__(self,
steps=[8, 16, 32, 64, 100, 300],
aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
min_ratio=15,
max_ratio=90,
min_sizes=[30.0, 60.0, 111.0, 162.0, 213.0, 264.0],
max_sizes=[60.0, 111.0, 162.0, 213.0, 264.0, 315.0],
offset=0.5,
flip=True,
clip=False,
min_max_aspect_ratios_order=False):
self.steps = steps
self.aspect_ratios = aspect_ratios
self.min_ratio = min_ratio
self.max_ratio = max_ratio
self.min_sizes = min_sizes
self.max_sizes = max_sizes
self.offset = offset
self.flip = flip
self.clip = clip
self.min_max_aspect_ratios_order = min_max_aspect_ratios_order
self.num_priors = []
for aspect_ratio, min_size, max_size in zip(aspect_ratios, min_sizes,
max_sizes):
self.num_priors.append((len(aspect_ratio) * 2 + 1) * len(
_to_list(min_size)) + len(_to_list(max_size)))
def __call__(self, inputs, image):
boxes = []
for input, min_size, max_size, aspect_ratio, step in zip(
inputs, self.min_sizes, self.max_sizes, self.aspect_ratios,
self.steps):
box, _ = ops.prior_box(
input=input,
image=image,
min_sizes=_to_list(min_size),
max_sizes=_to_list(max_size),
aspect_ratios=aspect_ratio,
flip=self.flip,
clip=self.clip,
steps=[step, step],
offset=self.offset,
min_max_aspect_ratios_order=self.min_max_aspect_ratios_order)
boxes.append(paddle.reshape(box, [-1, 4]))
return boxes
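`num_priors` in `AnchorGeneratorSSD` above is the number of prior boxes per feature-map location. A sketch of that count for the sizes and aspect ratios used in this commit, plus the total over a 300x300 input, assuming feature maps of 38, 19, 10, 5, 3 and 1 cells per side (an assumption about the backbone output that this class does not state):

```python
def to_list(value):
    """Mirror of _to_list: wrap scalars, copy lists and tuples."""
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


aspect_ratios = [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]]
min_sizes = [30.0, 60.0, 111.0, 162.0, 213.0, 264.0]
max_sizes = [60.0, 111.0, 162.0, 213.0, 264.0, 315.0]

num_priors = []
for ratios, min_size, max_size in zip(aspect_ratios, min_sizes, max_sizes):
    # Each ratio and its flipped reciprocal, plus ratio 1, per min size;
    # one extra box per max size.
    per_cell = ((len(ratios) * 2 + 1) * len(to_list(min_size)) +
                len(to_list(max_size)))
    num_priors.append(per_cell)

feature_sizes = [38, 19, 10, 5, 3, 1]  # assumed, see the note above
total = sum(n * s * s for n, s in zip(num_priors, feature_sizes))
print(num_priors, total)
```

With these inputs the per-location counts are `[4, 6, 6, 6, 4, 4]` and the total is 8732 priors. Note that the formula counts the flipped ratios unconditionally, so it matches the generated boxes only when `flip` is true.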
@register
@serializable
class ProposalGenerator(object):
......@@ -420,7 +478,12 @@ class YOLOBox(object):
self.clip_bbox = clip_bbox
self.scale_x_y = scale_x_y
def __call__(self, yolo_head_out, anchors, im_shape, scale_factor):
def __call__(self,
yolo_head_out,
anchors,
im_shape,
scale_factor,
var_weight=None):
boxes_list = []
scores_list = []
origin_shape = im_shape / scale_factor
......@@ -437,6 +500,54 @@ class YOLOBox(object):
return yolo_boxes, yolo_scores
@register
@serializable
class SSDBox(object):
def __init__(self, is_normalized=True):
self.is_normalized = is_normalized
self.norm_delta = float(not self.is_normalized)
def __call__(self,
preds,
prior_boxes,
im_shape,
scale_factor,
var_weight=None):
boxes, scores = preds['boxes'], preds['scores']
outputs = []
for box, score, prior_box in zip(boxes, scores, prior_boxes):
pb_w = prior_box[:, 2] - prior_box[:, 0] + self.norm_delta
pb_h = prior_box[:, 3] - prior_box[:, 1] + self.norm_delta
pb_x = prior_box[:, 0] + pb_w * 0.5
pb_y = prior_box[:, 1] + pb_h * 0.5
out_x = pb_x + box[:, :, 0] * pb_w * 0.1
out_y = pb_y + box[:, :, 1] * pb_h * 0.1
out_w = paddle.exp(box[:, :, 2] * 0.2) * pb_w
out_h = paddle.exp(box[:, :, 3] * 0.2) * pb_h
if self.is_normalized:
h = im_shape[:, 0] / scale_factor[:, 0]
w = im_shape[:, 1] / scale_factor[:, 1]
output = paddle.stack(
[(out_x - out_w / 2.) * w, (out_y - out_h / 2.) * h,
(out_x + out_w / 2.) * w, (out_y + out_h / 2.) * h],
axis=-1)
else:
output = paddle.stack(
[
out_x - out_w / 2., out_y - out_h / 2.,
out_x + out_w / 2. - 1., out_y + out_h / 2. - 1.
],
axis=-1)
outputs.append(output)
boxes = paddle.concat(outputs, axis=1)
scores = F.softmax(paddle.concat(scores, axis=1))
scores = paddle.transpose(scores, [0, 2, 1])
return boxes, scores
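The decode step in `SSDBox` above turns each predicted offset into corner coordinates relative to its prior box, using the fixed variances 0.1 (center) and 0.2 (size). A sketch of the same arithmetic for a single prior in plain Python; `decode_box` is a hypothetical helper and the prior and offsets are made-up values:

```python
import math


def decode_box(prior, delta, norm_delta=0.0):
    """Decode one (dx, dy, dw, dh) offset against one (x1, y1, x2, y2) prior."""
    pb_w = prior[2] - prior[0] + norm_delta
    pb_h = prior[3] - prior[1] + norm_delta
    pb_x = prior[0] + pb_w * 0.5
    pb_y = prior[1] + pb_h * 0.5
    out_x = pb_x + delta[0] * pb_w * 0.1       # center variance 0.1
    out_y = pb_y + delta[1] * pb_h * 0.1
    out_w = math.exp(delta[2] * 0.2) * pb_w    # size variance 0.2
    out_h = math.exp(delta[3] * 0.2) * pb_h
    return [out_x - out_w / 2., out_y - out_h / 2.,
            out_x + out_w / 2., out_y + out_h / 2.]


prior = [0.25, 0.25, 0.75, 0.75]
unchanged = decode_box(prior, [0., 0., 0., 0.])  # zero offsets return the prior
shifted = decode_box(prior, [1., 0., 0., 0.])    # dx = 1 moves x by 0.1 * width
print(unchanged, shifted)
```

In the normalized branch of the class, the decoded corners are additionally multiplied by the original image width and height, which it recovers as `im_shape / scale_factor`.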
@register
@serializable
class AnchorGrid(object):
......
......@@ -15,7 +15,9 @@
from . import yolo_loss
from . import iou_aware_loss
from . import iou_loss
from . import ssd_loss
from .yolo_loss import *
from .iou_aware_loss import *
from .iou_loss import *
from .ssd_loss import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
from ppdet.core.workspace import register
from ..ops import bipartite_match, box_coder, iou_similarity
__all__ = ['SSDLoss']
@register
class SSDLoss(nn.Layer):
def __init__(self,
match_type='per_prediction',
overlap_threshold=0.5,
neg_pos_ratio=3.0,
neg_overlap=0.5,
loc_loss_weight=1.0,
conf_loss_weight=1.0):
super(SSDLoss, self).__init__()
self.match_type = match_type
self.overlap_threshold = overlap_threshold
self.neg_pos_ratio = neg_pos_ratio
self.neg_overlap = neg_overlap
self.loc_loss_weight = loc_loss_weight
self.conf_loss_weight = conf_loss_weight
def _label_target_assign(self,
gt_label,
matched_indices,
neg_mask=None,
mismatch_value=0):
gt_label = gt_label.numpy()
matched_indices = matched_indices.numpy()
if neg_mask is not None:
neg_mask = neg_mask.numpy()
batch_size, num_priors = matched_indices.shape
trg_lbl = np.ones((batch_size, num_priors, 1)).astype('int32')
trg_lbl *= mismatch_value
trg_lbl_wt = np.zeros((batch_size, num_priors, 1)).astype('float32')
for i in range(batch_size):
col_ids = np.where(matched_indices[i] > -1)
col_val = matched_indices[i][col_ids]
trg_lbl[i][col_ids] = gt_label[i][col_val]
trg_lbl_wt[i][col_ids] = 1.0
if neg_mask is not None:
trg_lbl_wt += neg_mask[:, :, np.newaxis]
return paddle.to_tensor(trg_lbl), paddle.to_tensor(trg_lbl_wt)
def _bbox_target_assign(self, encoded_box, matched_indices):
encoded_box = encoded_box.numpy()
matched_indices = matched_indices.numpy()
batch_size, num_priors = matched_indices.shape
trg_bbox = np.zeros((batch_size, num_priors, 4)).astype('float32')
trg_bbox_wt = np.zeros((batch_size, num_priors, 1)).astype('float32')
for i in range(batch_size):
col_ids = np.where(matched_indices[i] > -1)
col_val = matched_indices[i][col_ids]
for v, c in zip(col_val.tolist(), col_ids[0]):
trg_bbox[i][c] = encoded_box[i][v][c]
trg_bbox_wt[i][col_ids] = 1.0
return paddle.to_tensor(trg_bbox), paddle.to_tensor(trg_bbox_wt)
def _mine_hard_example(self,
conf_loss,
matched_indices,
matched_dist,
neg_pos_ratio=3.0,
neg_overlap=0.5):
pos = (matched_indices > -1).astype(conf_loss.dtype)
num_pos = pos.sum(axis=1, keepdim=True)
neg = (matched_dist < neg_overlap).astype(conf_loss.dtype)
conf_loss = conf_loss * (1.0 - pos) * neg
loss_idx = conf_loss.argsort(axis=1, descending=True)
idx_rank = loss_idx.argsort(axis=1)
num_negs = []
for i in range(matched_indices.shape[0]):
cur_idx = loss_idx[i]
cur_num_pos = num_pos[i]
num_neg = paddle.clip(cur_num_pos * neg_pos_ratio, max=pos.shape[1])
num_negs.append(num_neg)
num_neg = paddle.stack(num_negs, axis=0).expand_as(idx_rank)
neg_mask = (idx_rank < num_neg).astype(conf_loss.dtype)
return neg_mask
def forward(self, inputs, targets, anchors):
boxes = paddle.concat(inputs['boxes'], axis=1)
scores = paddle.concat(inputs['scores'], axis=1)
prior_boxes = paddle.concat(anchors, axis=0)
gt_box = targets['gt_bbox']
gt_label = targets['gt_class'].unsqueeze(-1)
batch_size, num_priors, num_classes = scores.shape
def _reshape_to_2d(x):
return paddle.flatten(x, start_axis=2)
# 1. Find matched bounding box by prior box.
# 1.1 Compute IOU similarity between ground-truth boxes and prior boxes.
# 1.2 Compute matched bounding box by bipartite matching algorithm.
matched_indices = []
matched_dist = []
for i in range(gt_box.shape[0]):
iou = iou_similarity(gt_box[i], prior_boxes)
matched_indice, matched_d = bipartite_match(iou, self.match_type,
self.overlap_threshold)
matched_indices.append(matched_indice)
matched_dist.append(matched_d)
matched_indices = paddle.concat(matched_indices, axis=0)
matched_indices.stop_gradient = True
matched_dist = paddle.concat(matched_dist, axis=0)
matched_dist.stop_gradient = True
# 2. Compute confidence for mining hard examples
# 2.1. Get the target label based on matched indices
target_label, _ = self._label_target_assign(gt_label, matched_indices)
confidence = _reshape_to_2d(scores)
# 2.2. Compute confidence loss.
# Reshape confidence to 2D tensor.
target_label = _reshape_to_2d(target_label).astype('int64')
conf_loss = F.softmax_with_cross_entropy(confidence, target_label)
conf_loss = paddle.reshape(conf_loss, [batch_size, num_priors])
# 3. Mining hard examples
neg_mask = self._mine_hard_example(
conf_loss,
matched_indices,
matched_dist,
neg_pos_ratio=self.neg_pos_ratio,
neg_overlap=self.neg_overlap)
# 4. Assign classification and regression targets
# 4.1. Encoded bbox according to the prior boxes.
prior_box_var = paddle.to_tensor(
np.array(
[0.1, 0.1, 0.2, 0.2], dtype='float32')).reshape(
[1, 4]).expand_as(prior_boxes)
encoded_bbox = []
for i in range(gt_box.shape[0]):
encoded_bbox.append(
box_coder(
prior_box=prior_boxes,
prior_box_var=prior_box_var,
target_box=gt_box[i],
code_type='encode_center_size'))
encoded_bbox = paddle.stack(encoded_bbox, axis=0)
# 4.2. Assign regression targets
target_bbox, target_loc_weight = self._bbox_target_assign(
encoded_bbox, matched_indices)
# 4.3. Assign classification targets
target_label, target_conf_weight = self._label_target_assign(
gt_label, matched_indices, neg_mask=neg_mask)
# 5. Compute loss.
# 5.1 Compute confidence loss.
target_label = _reshape_to_2d(target_label).astype('int64')
conf_loss = F.softmax_with_cross_entropy(confidence, target_label)
target_conf_weight = _reshape_to_2d(target_conf_weight)
conf_loss = conf_loss * target_conf_weight * self.conf_loss_weight
# 5.2 Compute regression loss.
location = _reshape_to_2d(boxes)
target_bbox = _reshape_to_2d(target_bbox)
loc_loss = F.smooth_l1_loss(location, target_bbox, reduction='none')
loc_loss = paddle.sum(loc_loss, axis=-1, keepdim=True)
target_loc_weight = _reshape_to_2d(target_loc_weight)
loc_loss = loc_loss * target_loc_weight * self.loc_loss_weight
# 5.3 Compute overall weighted loss.
loss = conf_loss + loc_loss
loss = paddle.reshape(loss, [batch_size, num_priors])
loss = paddle.sum(loss, axis=1, keepdim=True)
normalizer = paddle.sum(target_loc_weight)
loss = paddle.sum(loss / normalizer)
return loss
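`_mine_hard_example` in `SSDLoss` above keeps only the highest-loss negatives, up to `neg_pos_ratio` times the number of positives, by ranking priors with a double argsort. A sketch of the same selection for a single image in plain Python; `mine_hard_negatives` is a hypothetical helper and the loss values are made up:

```python
def mine_hard_negatives(conf_loss, matched_indices, matched_dist,
                        neg_pos_ratio=3.0, neg_overlap=0.5):
    """Return a 0/1 mask over priors marking the selected hard negatives."""
    num_priors = len(conf_loss)
    pos = [1.0 if idx > -1 else 0.0 for idx in matched_indices]
    neg = [1.0 if d < neg_overlap else 0.0 for d in matched_dist]
    # Zero the loss of positives and of priors that overlap a box too much.
    masked = [l * (1.0 - p) * n for l, p, n in zip(conf_loss, pos, neg)]
    # Sort priors by descending loss; rank[i] is the position of prior i.
    order = sorted(range(num_priors), key=lambda i: -masked[i])
    rank = [0] * num_priors
    for position, prior_idx in enumerate(order):
        rank[prior_idx] = position
    num_neg = min(sum(pos) * neg_pos_ratio, num_priors)
    return [1.0 if r < num_neg else 0.0 for r in rank]


conf_loss = [0.9, 0.2, 0.8, 0.1, 0.7, 0.3]
matched_indices = [0, -1, -1, -1, -1, -1]   # only prior 0 is positive
matched_dist = [0.8, 0.1, 0.2, 0.0, 0.3, 0.1]
neg_mask = mine_hard_negatives(conf_loss, matched_indices, matched_dist)
print(neg_mask)
```

With one positive and a ratio of 3, the three negatives with the largest loss (priors 2, 4 and 5) are selected. In `forward`, this mask is added to the positive weights so that the confidence loss covers positives and mined negatives only.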
......@@ -806,15 +806,12 @@ def prior_box(input,
cur_max_sizes = max_sizes
if in_dygraph_mode():
attrs = [
'min_sizes', min_sizes, 'aspect_ratios', aspect_ratios, 'variances',
variance, 'flip', flip, 'clip', clip, 'step_w', steps[0], 'step_h',
steps[1], 'offset', offset, 'min_max_aspect_ratios_order',
min_max_aspect_ratios_order
]
if cur_max_sizes is not None:
attrs.extend('max_sizes', max_sizes)
attrs = tuple(attrs)
assert cur_max_sizes is not None
attrs = ('min_sizes', min_sizes, 'max_sizes', cur_max_sizes,
'aspect_ratios', aspect_ratios, 'variances', variance, 'flip',
flip, 'clip', clip, 'step_w', steps[0], 'step_h', steps[1],
'offset', offset, 'min_max_aspect_ratios_order',
min_max_aspect_ratios_order)
box, var = core.ops.prior_box(input, image, *attrs)
return box, var
else:
......@@ -1254,6 +1251,111 @@ def matrix_nms(bboxes,
return output
def bipartite_match(dist_matrix,
match_type=None,
dist_threshold=None,
name=None):
"""
This operator implements a greedy bipartite matching algorithm, which
obtains the matching with the maximum distance based on the input
distance matrix. For a 2-D input matrix, the algorithm can find the
matched column for each row (matched means the largest distance) and
the matched row for each column. This operator only computes matched
indices from column to row. For each instance, the number of matched
indices equals the number of columns of the input distance matrix.
**The OP only supports CPU**.
There are two outputs: matched indices and matched distance.
In short, the algorithm matches the best (maximum distance) row entity
to each column entity, and the matched indices are not duplicated in
each row of ColToRowMatchIndices. If a column entity is not matched to
any row entity, its entry in ColToRowMatchIndices is set to -1.
NOTE: the input DistMat can be a LoDTensor (with LoD) or a Tensor.
If it is a LoDTensor with LoD, the height of ColToRowMatchIndices is the batch size.
If it is a Tensor, the height of ColToRowMatchIndices is 1.
NOTE: This is a very low-level API used by the :code:`ssd_loss`
layer. Consider using :code:`ssd_loss` instead.
Args:
dist_matrix(Tensor): This input is a 2-D LoDTensor with shape
[K, M]. The data type is float32 or float64. It is pair-wise
distance matrix between the entities represented by each row and
each column. For example, assumed one entity is A with shape [K],
another entity is B with shape [M]. The dist_matrix[i][j] is the
distance between A[i] and B[j]. The bigger the distance is, the
better matching the pairs are. NOTE: This tensor can contain LoD
information to represent a batch of inputs. One instance of this
batch can contain different numbers of entities.
match_type(str, optional): The type of matching method, should be
'bipartite' or 'per_prediction'. None ('bipartite') by default.
dist_threshold(float32, optional): If `match_type` is 'per_prediction',
this threshold is to determine the extra matching bboxes based
on the maximum distance, 0.5 by default.
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually there is no need to set name;
it is None by default.
Returns:
Tuple:
match_indices(Tensor): A 2-D Tensor with shape [N, M]. The data
type is int32. N is the batch size. If match_indices[i][j] is -1, it
means B[j] does not match any entity in the i-th instance.
Otherwise, it means B[j] is matched to row
match_indices[i][j] in the i-th instance. That is, the row number
within the i-th instance is saved in match_indices[i][j].
match_distance(Tensor): A 2-D Tensor with shape [N, M]. The data
type is float32. N is the batch size. If match_indices[i][j] is -1,
match_distance[i][j] is also -1.0. Otherwise, assume
match_indices[i][j] = d, and call the row offsets of each instance
LoD. Then match_distance[i][j] =
dist_matrix[d+LoD[i]][j].
Examples:
.. code-block:: python
import paddle
from ppdet.modeling import ops
from ppdet.modeling.utils import iou_similarity
paddle.enable_static()
x = paddle.static.data(name='x', shape=[None, 4], dtype='float32')
y = paddle.static.data(name='y', shape=[None, 4], dtype='float32')
iou = iou_similarity(x=x, y=y)
matched_indices, matched_dist = ops.bipartite_match(iou)
"""
check_variable_and_dtype(dist_matrix, 'dist_matrix',
['float32', 'float64'], 'bipartite_match')
if in_dygraph_mode():
match_indices, match_distance = core.ops.bipartite_match(
dist_matrix, "match_type", match_type, "dist_threshold",
dist_threshold)
return match_indices, match_distance
helper = LayerHelper('bipartite_match', **locals())
match_indices = helper.create_variable_for_type_inference(dtype='int32')
match_distance = helper.create_variable_for_type_inference(
dtype=dist_matrix.dtype)
helper.append_op(
type='bipartite_match',
inputs={'DistMat': dist_matrix},
attrs={
'match_type': match_type,
'dist_threshold': dist_threshold,
},
outputs={
'ColToRowMatchIndices': match_indices,
'ColToRowMatchDist': match_distance
})
return match_indices, match_distance
@paddle.jit.not_to_static
def box_coder(prior_box,
prior_box_var,
......
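Review note: the matching rule described in the `bipartite_match` docstring above can be illustrated with a small pure-Python sketch. This is only an assumed greedy strategy that is consistent with the docstring (largest distance first, each row used at most once, `-1` / `-1.0` for unmatched columns); it is not the operator's actual CPU kernel, and the function name `greedy_bipartite_match` is hypothetical.

```python
def greedy_bipartite_match(dist, match_type="bipartite", dist_threshold=0.5):
    """Match each column of `dist` to at most one row, largest distance first.

    `dist` is a list of rows (K x M). Returns (match_indices, match_dist),
    both of length M, using -1 / -1.0 for unmatched columns as the
    docstring describes.
    """
    num_rows = len(dist)
    num_cols = len(dist[0]) if num_rows else 0
    match_indices = [-1] * num_cols
    match_dist = [-1.0] * num_cols
    row_used = [False] * num_rows

    # Bipartite step: visit all (row, col) pairs from largest to smallest
    # distance and accept a pair only if both its row and column are free.
    pairs = sorted(
        ((dist[i][j], i, j) for i in range(num_rows) for j in range(num_cols)),
        key=lambda p: p[0],
        reverse=True,
    )
    for d, i, j in pairs:
        if d <= 0:
            break  # remaining pairs have no overlap at all
        if match_indices[j] == -1 and not row_used[i]:
            match_indices[j] = i
            match_dist[j] = d
            row_used[i] = True

    # Optional per-prediction step: every still-unmatched column takes its
    # best row if that distance reaches the threshold (rows may be reused).
    if match_type == "per_prediction":
        for j in range(num_cols):
            if match_indices[j] != -1:
                continue
            best_row = max(range(num_rows), key=lambda i: dist[i][j], default=-1)
            if best_row != -1 and dist[best_row][j] >= dist_threshold:
                match_indices[j] = best_row
                match_dist[j] = dist[best_row][j]
    return match_indices, match_dist


dist = [[0.9, 0.6, 0.1],
        [0.8, 0.7, 0.55]]
print(greedy_bipartite_match(dist))
# -> ([0, 1, -1], [0.9, 0.7, -1.0])
print(greedy_bipartite_match(dist, match_type="per_prediction", dist_threshold=0.5))
# -> ([0, 1, 1], [0.9, 0.7, 0.55])
```

With two rows and three columns, the bipartite step can match at most two columns; the `per_prediction` step is what lets the third column reuse row 1.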
......@@ -376,6 +376,31 @@ class TestIoUSimilarity(LayerTest):
self.assertTrue(np.array_equal(iou_np, iou_dy_np))
class TestBipartiteMatch(LayerTest):
def test_bipartite_match(self):
distance = np.random.random((20, 10)).astype('float32')
with self.static_graph():
x = paddle.static.data(name='x', shape=[20, 10], dtype='float32')
match_indices, match_dist = ops.bipartite_match(
x, match_type='per_prediction', dist_threshold=0.5)
match_indices_np, match_dist_np = self.get_static_graph_result(
feed={'x': distance, },
fetch_list=[match_indices, match_dist],
with_lod=False)
with self.dynamic_graph():
x_dy = base.to_variable(distance)
match_indices_dy, match_dist_dy = ops.bipartite_match(
x_dy, match_type='per_prediction', dist_threshold=0.5)
match_indices_dy_np = match_indices_dy.numpy()
match_dist_dy_np = match_dist_dy.numpy()
self.assertTrue(np.array_equal(match_indices_np, match_indices_dy_np))
self.assertTrue(np.array_equal(match_dist_np, match_dist_dy_np))
class TestYoloBox(LayerTest):
def test_yolo_box(self):
......
......@@ -33,7 +33,7 @@ def json_eval_results(metric, json_directory=None, dataset=None):
logger.info("{} not exists!".format(v_json))
def get_infer_results(outs_res, eval_type, catid, im_info):
def get_infer_results(outs_res, eval_type, catid):
"""
Get result at the stage of inference.
The output format is dictionary containing bbox or mask result.
......@@ -45,31 +45,27 @@ def get_infer_results(outs_res, eval_type, catid, im_info):
raise ValueError(
'The number of valid detection results is zero. Please use a reasonable model and check the input data.'
)
infer_res = {}
if 'bbox' in eval_type:
box_res = []
for i, outs in enumerate(outs_res):
im_ids = im_info[i][2]
box_res += get_det_res(outs['bbox'], outs['bbox_num'], im_ids,
catid)
infer_res['bbox'] = box_res
if 'mask' in eval_type:
seg_res = []
# mask post process
for i, outs in enumerate(outs_res):
im_shape = im_info[i][0]
scale_factor = im_info[i][1]
im_ids = im_info[i][2]
mask = outs['mask']
seg_res += get_seg_res(mask, outs['bbox_num'], im_ids, catid)
infer_res['mask'] = seg_res
infer_res = {k: [] for k in eval_type}
for i, outs in enumerate(outs_res):
im_id = outs['im_id']
im_shape = outs['im_shape']
scale_factor = outs['scale_factor']
if 'bbox' in eval_type:
infer_res['bbox'] += get_det_res(outs['bbox'], outs['bbox_num'],
im_id, catid)
if 'mask' in eval_type:
# mask post process
infer_res['mask'] += get_seg_res(outs['mask'], outs['bbox_num'],
im_id, catid)
return infer_res
def eval_results(res, metric, anno_file):
def eval_results(res, metric, dataset):
"""
Evaluate the inference result.
"""
......@@ -82,7 +78,8 @@ def eval_results(res, metric, anno_file):
json.dump(res['bbox'], f)
logger.info('The bbox result is saved to bbox.json.')
bbox_stats = cocoapi_eval('bbox.json', 'bbox', anno_file=anno_file)
bbox_stats = cocoapi_eval(
'bbox.json', 'bbox', anno_file=dataset.get_anno())
eval_res.append(bbox_stats)
sys.stdout.flush()
if 'mask' in res:
......@@ -90,9 +87,14 @@ def eval_results(res, metric, anno_file):
json.dump(res['mask'], f)
logger.info('The mask result is saved to mask.json.')
seg_stats = cocoapi_eval('mask.json', 'segm', anno_file=anno_file)
seg_stats = cocoapi_eval(
'mask.json', 'segm', anno_file=dataset.get_anno())
eval_res.append(seg_stats)
sys.stdout.flush()
elif metric == 'VOC':
from ppdet.utils.voc_eval import bbox_eval
bbox_stats = bbox_eval(res, 21)
else:
raise NotImplementedError("Only COCO and VOC metrics are supported now.")
......
......@@ -63,46 +63,33 @@ def bbox_eval(results,
evaluate_difficult=evaluate_difficult)
for t in results:
bboxes = t['bbox'][0]
bbox_lengths = t['bbox'][1][0]
bboxes = t['bbox']
bbox_lengths = t['bbox_num']
if bboxes is None or bboxes.shape == (1, 1):
continue
gt_boxes = t['gt_bbox'][0]
gt_labels = t['gt_class'][0]
difficults = t['is_difficult'][0] if not evaluate_difficult \
gt_boxes = t['gt_bbox']
gt_labels = t['gt_class']
difficults = t['difficult'] if not evaluate_difficult \
else None
if len(t['gt_bbox'][1]) == 0:
# gt_bbox, gt_class, difficult read as zero padded Tensor
bbox_idx = 0
for i in range(len(gt_boxes)):
gt_box = gt_boxes[i]
gt_label = gt_labels[i]
difficult = None if difficults is None \
else difficults[i]
bbox_num = bbox_lengths[i]
bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
gt_box, gt_label, difficult = prune_zero_padding(
gt_box, gt_label, difficult)
detection_map.update(bbox, gt_box, gt_label, difficult)
bbox_idx += bbox_num
else:
# gt_box, gt_label, difficult read as LoDTensor
gt_box_lengths = t['gt_bbox'][1][0]
bbox_idx = 0
gt_box_idx = 0
for i in range(len(bbox_lengths)):
bbox_num = bbox_lengths[i]
gt_box_num = gt_box_lengths[i]
bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
gt_box = gt_boxes[gt_box_idx:gt_box_idx + gt_box_num]
gt_label = gt_labels[gt_box_idx:gt_box_idx + gt_box_num]
difficult = None if difficults is None else \
difficults[gt_box_idx: gt_box_idx + gt_box_num]
detection_map.update(bbox, gt_box, gt_label, difficult)
bbox_idx += bbox_num
gt_box_idx += gt_box_num
scale_factor = t['scale_factor'] if 'scale_factor' in t else np.ones(
(gt_boxes.shape[0], 2)).astype('float32')
bbox_idx = 0
for i in range(gt_boxes.shape[0]):
gt_box = gt_boxes[i]
h, w = scale_factor[i]
gt_box = gt_box / np.array([w, h, w, h])
gt_label = gt_labels[i]
difficult = None if difficults is None \
else difficults[i]
bbox_num = bbox_lengths[i]
bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
gt_box, gt_label, difficult = prune_zero_padding(gt_box, gt_label,
difficult)
detection_map.update(bbox, gt_box, gt_label, difficult)
bbox_idx += bbox_num
logger.info("Accumulating evaluation results...")
detection_map.accumulate()
......
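Review note: the rewritten loop in `bbox_eval` walks a flat detection array using `bbox_num` and divides each ground-truth box by `[w, h, w, h]` taken from `scale_factor`. The sketch below reproduces just that bookkeeping in pure Python. The helper name `split_and_rescale` and the `[label, score, x1, y1, x2, y2]` detection layout are assumptions for illustration, and the `prune_zero_padding` step from the diff is deliberately left out.

```python
def split_and_rescale(bboxes, bbox_num, gt_boxes, scale_factor):
    """Yield (detections, rescaled_gt_boxes) for each image in a batch.

    bboxes:       flat list of detections for the whole batch
    bbox_num:     number of detections belonging to each image
    gt_boxes:     per-image lists of [x1, y1, x2, y2] in resized coordinates
    scale_factor: per-image (h_scale, w_scale) pairs, as unpacked in the diff
    """
    start = 0
    for i, num in enumerate(bbox_num):
        # Detections are stored back to back, so slice by a running offset.
        dets = bboxes[start:start + num]
        start += num
        # Map ground truth back to original image coordinates.
        h, w = scale_factor[i]
        gts = [[x1 / w, y1 / h, x2 / w, y2 / h] for x1, y1, x2, y2 in gt_boxes[i]]
        yield dets, gts


bboxes = [
    [0, 0.9, 10.0, 20.0, 30.0, 40.0],
    [1, 0.8, 12.0, 22.0, 32.0, 42.0],
    [0, 0.7, 8.0, 8.0, 16.0, 16.0],
]
bbox_num = [2, 1]
gt_boxes = [[[10.0, 20.0, 30.0, 40.0]], [[8.0, 8.0, 16.0, 16.0]]]
scale_factor = [[2.0, 2.0], [1.0, 4.0]]

for dets, gts in split_and_rescale(bboxes, bbox_num, gt_boxes, scale_factor):
    print(len(dets), gts)
# -> 2 [[5.0, 10.0, 15.0, 20.0]]
# -> 1 [[2.0, 8.0, 4.0, 16.0]]
```

When a reader does not provide `scale_factor`, the diff falls back to an array of ones, which makes the division a no-op.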
......@@ -69,34 +69,34 @@ def run(FLAGS, cfg, place):
# Data Reader
dataset = cfg.EvalDataset
eval_loader, _ = create('EvalReader')(dataset, cfg['worker_num'], place)
eval_loader = create('EvalReader')(dataset, cfg['worker_num'], place)
extra_key = ['im_shape', 'scale_factor', 'im_id']
if cfg.metric == 'VOC':
extra_key += ['gt_bbox', 'gt_class', 'difficult']
# Run Eval
outs_res = []
start_time = time.time()
sample_num = 0
im_info = []
for iter_id, data in enumerate(eval_loader):
# forward
fields = cfg['EvalReader']['inputs_def']['fields']
model.eval()
outs = model(data=data, input_def=fields, mode='infer')
outs = model(data, mode='infer')
for key in extra_key:
outs[key] = data[key]
for key, value in outs.items():
outs[key] = value.numpy()
im_shape = data[fields.index('im_shape')].numpy()
scale_factor = data[fields.index('scale_factor')].numpy()
im_id = data[fields.index('im_id')].numpy()
im_info.append([im_shape, scale_factor, im_id])
if 'mask' in outs and 'bbox' in outs:
mask_resolution = model.mask_post_process.mask_resolution
from ppdet.py_op.post_process import mask_post_process
outs['mask'] = mask_post_process(outs, im_shape, scale_factor,
mask_resolution)
outs['mask'] = mask_post_process(
outs, outs['im_shape'], outs['scale_factor'], mask_resolution)
outs_res.append(outs)
# log
sample_num += im_shape.shape[0]
sample_num += outs['im_id'].shape[0]
if iter_id % 100 == 0:
logger.info("Eval iter: {}".format(iter_id))
......@@ -111,15 +111,22 @@ def run(FLAGS, cfg, place):
eval_type.append('mask')
# Metric
# TODO: support other metric
from ppdet.utils.coco_eval import get_category_info
anno_file = dataset.get_anno()
with_background = cfg.with_background
use_default_label = dataset.use_default_label
clsid2catid, catid2name = get_category_info(anno_file, with_background,
use_default_label)
if cfg.metric == 'COCO':
from ppdet.utils.coco_eval import get_category_info
clsid2catid, catid2name = get_category_info(
dataset.get_anno(), with_background, use_default_label)
infer_res = get_infer_results(outs_res, eval_type, clsid2catid)
elif cfg.metric == 'VOC':
from ppdet.utils.voc_eval import get_category_info
clsid2catid, catid2name = get_category_info(
dataset.get_label_list(), with_background, use_default_label)
infer_res = outs_res
infer_res = get_infer_results(outs_res, eval_type, clsid2catid, im_info)
eval_results(infer_res, cfg.metric, anno_file)
eval_results(infer_res, cfg.metric, dataset)
def main():
......
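Review note: the evaluation loop above no longer keeps a separate `im_info` list. Instead it copies the `extra_key` fields from the reader batch into the model output dict and then converts every value with `.numpy()`. A minimal sketch of that pattern follows; `FakeTensor` and `collect_outputs` are hypothetical stand-ins so the example runs without Paddle, and the sample values are made up.

```python
class FakeTensor:
    """Hypothetical stand-in for a framework tensor; only .numpy() is modelled."""

    def __init__(self, data):
        self._data = data

    def numpy(self):
        return self._data


def collect_outputs(model_outs, batch, extra_keys):
    """Merge selected reader fields into the model outputs, then unwrap them."""
    outs = dict(model_outs)  # copy so the caller's dict is left untouched
    for key in extra_keys:
        outs[key] = batch[key]
    # Convert every value in one pass so downstream code sees plain arrays.
    return {key: value.numpy() for key, value in outs.items()}


batch = {
    "image": FakeTensor("pixels"),
    "im_id": FakeTensor([7, 8]),
    "scale_factor": FakeTensor([[1.0, 1.0], [0.5, 0.5]]),
}
model_outs = {
    "bbox": FakeTensor([[0, 0.9, 1, 2, 3, 4]]),
    "bbox_num": FakeTensor([1, 0]),
}

outs = collect_outputs(model_outs, batch, ["im_id", "scale_factor"])
print(sorted(outs))
# -> ['bbox', 'bbox_num', 'im_id', 'scale_factor']
```

Keeping everything in one dict per batch is what allows `get_infer_results` and `bbox_eval` to read `im_id`, `scale_factor`, and the ground-truth fields straight from each entry of `outs_res`.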
......@@ -49,6 +49,8 @@ def parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
if metric == 'COCO':
from ppdet.utils.coco_eval import get_category_info
elif metric == 'VOC':
from ppdet.utils.voc_eval import get_category_info
else:
raise ValueError("metric only supports COCO and VOC, but received {}".format(
metric))
......
......@@ -129,7 +129,11 @@ def run(FLAGS, cfg, place):
dataset = cfg.TestDataset
test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
dataset.set_images(test_images)
test_loader, _ = create('TestReader')(dataset, cfg['worker_num'], place)
test_loader = create('TestReader')(dataset, cfg['worker_num'], place)
extra_key = ['im_shape', 'scale_factor', 'im_id']
if cfg.metric == 'VOC':
extra_key += ['gt_bbox', 'gt_class', 'difficult']
# TODO: support other metrics
imid2path = dataset.get_imid2path()
......@@ -147,24 +151,18 @@ def run(FLAGS, cfg, place):
# Run Infer
for iter_id, data in enumerate(test_loader):
# forward
fields = cfg.TestReader['inputs_def']['fields']
model.eval()
outs = model(
data=data,
input_def=cfg.TestReader['inputs_def']['fields'],
mode='infer')
outs = model(data, mode='infer')
for key in extra_key:
outs[key] = data[key]
for key, value in outs.items():
outs[key] = value.numpy()
im_shape = data[fields.index('im_shape')].numpy()
scale_factor = data[fields.index('scale_factor')].numpy()
im_ids = data[fields.index('im_id')].numpy()
im_info = [im_shape, scale_factor, im_ids]
if 'mask' in outs and 'bbox' in outs:
mask_resolution = model.mask_post_process.mask_resolution
from ppdet.py_op.post_process import mask_post_process
outs['mask'] = mask_post_process(outs, im_shape, scale_factor,
mask_resolution)
outs['mask'] = mask_post_process(
outs, outs['im_shape'], outs['scale_factor'], mask_resolution)
eval_type = []
if 'bbox' in outs:
......@@ -172,14 +170,14 @@ def run(FLAGS, cfg, place):
if 'mask' in outs:
eval_type.append('mask')
batch_res = get_infer_results([outs], eval_type, clsid2catid, [im_info])
batch_res = get_infer_results([outs], eval_type, clsid2catid)
logger.info('Infer iter {}'.format(iter_id))
bbox_res = None
mask_res = None
bbox_num = outs['bbox_num']
start = 0
for i, im_id in enumerate(im_ids):
for i, im_id in enumerate(outs['im_id']):
image_path = imid2path[int(im_id)]
image = Image.open(image_path).convert('RGB')
end = start + bbox_num[i]
......@@ -197,7 +195,7 @@ def run(FLAGS, cfg, place):
mask_res = batch_res['mask'][start:end]
image = visualize_results(image, bbox_res, mask_res,
int(im_id), catid2name,
int(outs['im_id']), catid2name,
FLAGS.draw_threshold)
# use VisualDL to log image with bbox
......
......@@ -103,8 +103,8 @@ def run(FLAGS, cfg, place):
# Data
dataset = cfg.TrainDataset
train_loader, step_per_epoch = create('TrainReader')(
dataset, cfg['worker_num'], place)
train_loader = create('TrainReader')(dataset, cfg['worker_num'], place)
step_per_epoch = len(train_loader)
# Model
model = create(cfg.architecture)
......@@ -134,7 +134,6 @@ def run(FLAGS, cfg, place):
if ParallelEnv().nranks > 1:
model = paddle.DataParallel(model)
fields = train_loader.collate_fn.output_fields
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_dir = os.path.join(cfg.save_dir, cfg_name)
# Run Train
......@@ -155,7 +154,7 @@ def run(FLAGS, cfg, place):
# Model Forward
model.train()
outputs = model(data=data, input_def=fields, mode='train')
outputs = model(data, mode='train')
# Model Backward
loss = outputs['loss']
......