diff --git a/ppdet/modeling/architectures/yolo.py b/ppdet/modeling/architectures/yolo.py index bf6c19ecfe26a07c42e75d38e4acd9b1c443b499..6c0444480b1de27c96fba217e531d75005e92d70 100644 --- a/ppdet/modeling/architectures/yolo.py +++ b/ppdet/modeling/architectures/yolo.py @@ -20,6 +20,16 @@ class YOLOv3(BaseArch): yolo_head='YOLOv3Head', post_process='BBoxPostProcess', data_format='NCHW'): + """ + YOLOv3 network, see https://arxiv.org/abs/1804.02767 + + Args: + backbone (nn.Layer): backbone instance + neck (nn.Layer): neck instance + yolo_head (nn.Layer): anchor_head instance + post_process (object): `BBoxPostProcess` instance + data_format (str): data format, NCHW or NHWC + """ super(YOLOv3, self).__init__(data_format=data_format) self.backbone = backbone self.neck = neck diff --git a/ppdet/modeling/backbones/darknet.py b/ppdet/modeling/backbones/darknet.py index ab748c66b9d525082b11adb1c498f5b6603be0ba..9bf0cdaa9ca759cfa199534ce513c870eaf935a3 100755 --- a/ppdet/modeling/backbones/darknet.py +++ b/ppdet/modeling/backbones/darknet.py @@ -37,6 +37,22 @@ class ConvBNLayer(nn.Layer): act="leaky", name=None, data_format='NCHW'): + """ + conv + bn + activation layer + + Args: + ch_in (int): input channel + ch_out (int): output channel + filter_size (int): filter size, default 3 + stride (int): stride, default 1 + groups (int): number of groups of conv layer, default 1 + padding (int): padding size, default 0 + norm_type (str): batch norm type, default bn + norm_decay (float): decay for weight and bias of batch norm layer, default 0. 
+ act (str): activation function type, default 'leaky', which means leaky_relu + name (str): layer name + data_format (str): data format, NCHW or NHWC + """ super(ConvBNLayer, self).__init__() self.conv = nn.Conv2D( @@ -75,6 +91,20 @@ class DownSample(nn.Layer): norm_decay=0., name=None, data_format='NCHW'): + """ + downsample layer + + Args: + ch_in (int): input channel + ch_out (int): output channel + filter_size (int): filter size, default 3 + stride (int): stride, default 2 + padding (int): padding size, default 1 + norm_type (str): batch norm type, default bn + norm_decay (float): decay for weight and bias of batch norm layer, default 0. + name (str): layer name + data_format (str): data format, NCHW or NHWC + """ super(DownSample, self).__init__() @@ -103,6 +133,18 @@ class BasicBlock(nn.Layer): norm_decay=0., name=None, data_format='NCHW'): + """ + BasicBlock layer of DarkNet + + Args: + ch_in (int): input channel + ch_out (int): output channel + norm_type (str): batch norm type, default bn + norm_decay (float): decay for weight and bias of batch norm layer, default 0. + name (str): layer name + data_format (str): data format, NCHW or NHWC + """ + super(BasicBlock, self).__init__() self.conv1 = ConvBNLayer( @@ -142,6 +184,18 @@ class Blocks(nn.Layer): norm_decay=0., name=None, data_format='NCHW'): + """ + Blocks layer, which consists of some BasicBlock layers + + Args: + ch_in (int): input channel + ch_out (int): output channel + count (int): number of BasicBlock layers + norm_type (str): batch norm type, default bn + norm_decay (float): decay for weight and bias of batch norm layer, default 0. 
+ name (str): layer name + data_format (str): data format, NCHW or NHWC + """ super(Blocks, self).__init__() self.basicblock0 = BasicBlock( @@ -189,6 +243,18 @@ class DarkNet(nn.Layer): norm_type='bn', norm_decay=0., data_format='NCHW'): + """ + Darknet, see https://pjreddie.com/darknet/yolo/ + + Args: + depth (int): depth of network + freeze_at (int): freeze the backbone at which stage + filter_size (int): filter size, default 3 + return_idx (list): index of stages whose feature maps are returned + norm_type (str): batch norm type, default bn + norm_decay (float): decay for weight and bias of batch norm layer, default 0. + data_format (str): data format, NCHW or NHWC + """ super(DarkNet, self).__init__() self.depth = depth self.freeze_at = freeze_at diff --git a/ppdet/modeling/heads/yolo_head.py b/ppdet/modeling/heads/yolo_head.py index 3516da4108ddac38c93480c401c6b40af5b9ef05..fa8f9579762cf37548d794391a2751423587ef18 100644 --- a/ppdet/modeling/heads/yolo_head.py +++ b/ppdet/modeling/heads/yolo_head.py @@ -28,6 +28,18 @@ class YOLOv3Head(nn.Layer): iou_aware=False, iou_aware_factor=0.4, data_format='NCHW'): + """ + Head for YOLOv3 network + + Args: + num_classes (int): number of foreground classes + anchors (list): anchors + anchor_masks (list): anchor masks + loss (object): YOLOv3Loss instance + iou_aware (bool): whether to use iou_aware + iou_aware_factor (float): iou aware factor + data_format (str): data format, NCHW or NHWC + """ super(YOLOv3Head, self).__init__() self.num_classes = num_classes self.loss = loss diff --git a/ppdet/modeling/losses/yolo_loss.py b/ppdet/modeling/losses/yolo_loss.py index 9579acf9f9a1d6f27bed7431e6dbe769ce7edbf6..657959cd7e55cf43d6362f03e1a4c1204b814c07 100644 --- a/ppdet/modeling/losses/yolo_loss.py +++ b/ppdet/modeling/losses/yolo_loss.py @@ -46,6 +46,18 @@ class YOLOv3Loss(nn.Layer): scale_x_y=1., iou_loss=None, iou_aware_loss=None): + """ + YOLOv3Loss layer + + Args: + num_classes (int): number of foreground classes + 
ignore_thresh (float): threshold to ignore confidence loss + label_smooth (bool): whether to use label smoothing + downsample (list): downsample ratio for each detection block + scale_x_y (float): scale_x_y factor + iou_loss (object): IoULoss instance + iou_aware_loss (object): IouAwareLoss instance + """ super(YOLOv3Loss, self).__init__() self.num_classes = num_classes self.ignore_thresh = ignore_thresh diff --git a/ppdet/modeling/necks/yolo_fpn.py b/ppdet/modeling/necks/yolo_fpn.py index 456bfae2097f3187ced647367986d20fee2730a5..9b8a6d40aa82036f2a473f24d3ea5fae6e2481c0 100644 --- a/ppdet/modeling/necks/yolo_fpn.py +++ b/ppdet/modeling/necks/yolo_fpn.py @@ -27,6 +27,16 @@ __all__ = ['YOLOv3FPN', 'PPYOLOFPN'] class YoloDetBlock(nn.Layer): def __init__(self, ch_in, channel, norm_type, name, data_format='NCHW'): + """ + YoloDetBlock layer for yolov3, see https://arxiv.org/abs/1804.02767 + + Args: + ch_in (int): input channel + channel (int): base channel + norm_type (str): batch norm type + name (str): layer name + data_format (str): data format, NCHW or NHWC + """ super(YoloDetBlock, self).__init__() self.ch_in = ch_in self.channel = channel @@ -78,6 +88,17 @@ class SPP(nn.Layer): norm_type, name, data_format='NCHW'): + """ + SPP layer, which consists of four pooling layers followed by conv layer + + Args: + ch_in (int): input channel of conv layer + ch_out (int): output channel of conv layer + k (int): kernel size of conv layer + norm_type (str): batch norm type + name (str): layer name + data_format (str): data format, NCHW or NHWC + """ super(SPP, self).__init__() self.pool = [] for size in pool_size: @@ -110,6 +131,15 @@ class SPP(nn.Layer): class DropBlock(nn.Layer): def __init__(self, block_size, keep_prob, name, data_format='NCHW'): + """ + DropBlock layer, see https://arxiv.org/abs/1810.12890 + + Args: + block_size (int): block size + keep_prob (float): keep probability + name (str): layer name + data_format (str): data format, NCHW or NHWC + """ super(DropBlock, 
self).__init__() self.block_size = block_size self.keep_prob = keep_prob @@ -149,6 +179,19 @@ class CoordConv(nn.Layer): norm_type, name, data_format='NCHW'): + """ + CoordConv layer + + Args: + ch_in (int): input channel + ch_out (int): output channel + filter_size (int): filter size, default 3 + padding (int): padding size, default 0 + norm_type (str): batch norm type + name (str): layer name + data_format (str): data format, NCHW or NHWC + + """ super(CoordConv, self).__init__() self.conv = ConvBNLayer( ch_in + 2, @@ -193,6 +236,14 @@ class CoordConv(nn.Layer): class PPYOLODetBlock(nn.Layer): def __init__(self, cfg, name, data_format='NCHW'): + """ + PPYOLODetBlock layer + + Args: + cfg (list): layer configs for this block + name (str): block name + data_format (str): data format, NCHW or NHWC + """ super(PPYOLODetBlock, self).__init__() self.conv_module = nn.Sequential() for idx, (conv_name, layer, args, kwargs) in enumerate(cfg[:-1]): @@ -220,6 +271,15 @@ class YOLOv3FPN(nn.Layer): in_channels=[256, 512, 1024], norm_type='bn', data_format='NCHW'): + """ + YOLOv3FPN layer + + Args: + in_channels (list): input channels for fpn + norm_type (str): batch norm type, default bn + data_format (str): data format, NCHW or NHWC + + """ super(YOLOv3FPN, self).__init__() assert len(in_channels) > 0, "in_channels length should > 0" self.in_channels = in_channels @@ -300,6 +360,16 @@ class PPYOLOFPN(nn.Layer): norm_type='bn', data_format='NCHW', **kwargs): + """ + PPYOLOFPN layer + + Args: + in_channels (list): input channels for fpn + norm_type (str): batch norm type, default bn + data_format (str): data format, NCHW or NHWC + kwargs: extra key-value pairs, such as parameters of DropBlock and SPP + + """ super(PPYOLOFPN, self).__init__() assert len(in_channels) > 0, "in_channels length should > 0" self.in_channels = in_channels