未验证 提交 24f582fd 编写于 作者: F Feng Ni 提交者: GitHub

[doc] add fcos ttfnet coments (#2456)

* add fcos ttfnet coment, update modelzoo

* fix ttfnet coment
上级 3b02c71e
...@@ -12,9 +12,9 @@ FCOS (Fully Convolutional One-Stage Object Detection) is a fast anchor-free obje ...@@ -12,9 +12,9 @@ FCOS (Fully Convolutional One-Stage Object Detection) is a fast anchor-free obje
| 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 |推理时间(fps) | Box AP | 下载 | 配置文件 | | 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 |推理时间(fps) | Box AP | 下载 | 配置文件 |
| :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: | | :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
| ResNet50-FPN | FCOS | 2 | 1x | ---- | 39.6 | [下载链接](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/fcos/fcos_r50_fpn_1x_coco.yml) | | ResNet50-FPN | FCOS | 2 | 1x | ---- | 39.6 | [下载链接](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/fcos/fcos_r50_fpn_1x_coco.yml) |
| ResNet50-FPN | FCOS+DCN | 2 | 1x | ---- | 44.3 | [下载链接](https://paddledet.bj.bcebos.com/models/fcos_dcn_r50_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/fcos/fcos_dcn_r50_fpn_1x_coco.yml) | | ResNet50-FPN | FCOS+DCN | 2 | 1x | ---- | 44.3 | [下载链接](https://paddledet.bj.bcebos.com/models/fcos_dcn_r50_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/fcos/fcos_dcn_r50_fpn_1x_coco.yml) |
| ResNet50-FPN | FCOS+multiscale_train | 2 | 2x | ---- | 41.8 | [下载链接](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_multiscale_2x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/fcos/fcos_r50_fpn_multiscale_2x_coco.yml) | | ResNet50-FPN | FCOS+multiscale_train | 2 | 2x | ---- | 41.8 | [下载链接](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_multiscale_2x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/fcos/fcos_r50_fpn_multiscale_2x_coco.yml) |
**Notes:** **Notes:**
......
...@@ -47,7 +47,6 @@ FCOSPostProcess: ...@@ -47,7 +47,6 @@ FCOSPostProcess:
decode: decode:
name: FCOSBox name: FCOSBox
num_classes: 80 num_classes: 80
batch_size: 1
nms: nms:
name: MultiClassNMS name: MultiClassNMS
nms_top_k: 1000 nms_top_k: 1000
......
...@@ -13,7 +13,7 @@ TTFNet是一种用于实时目标检测且对训练时间友好的网络,对Ce ...@@ -13,7 +13,7 @@ TTFNet是一种用于实时目标检测且对训练时间友好的网络,对Ce
| 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 |推理时间(fps) | Box AP | 下载 | 配置文件 | | 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 |推理时间(fps) | Box AP | 下载 | 配置文件 |
| :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: | | :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
| DarkNet53 | TTFNet | 12 | 1x | ---- | 33.5 | [下载链接](https://paddledet.bj.bcebos.com/models/ttfnet_darknet53_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ttfnet/ttfnet_darknet53_1x_coco.yml) | | DarkNet53 | TTFNet | 12 | 1x | ---- | 33.5 | [下载链接](https://paddledet.bj.bcebos.com/models/ttfnet_darknet53_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ttfnet/ttfnet_darknet53_1x_coco.yml) |
## Citations ## Citations
``` ```
......
...@@ -25,6 +25,16 @@ __all__ = ['FCOS'] ...@@ -25,6 +25,16 @@ __all__ = ['FCOS']
@register @register
class FCOS(BaseArch): class FCOS(BaseArch):
"""
FCOS network, see https://arxiv.org/abs/1904.01355
Args:
backbone (object): backbone instance
neck (object): 'FPN' instance
fcos_head (object): 'FCOSHead' instance
post_process (object): 'FCOSPostProcess' instance
"""
__category__ = 'architecture' __category__ = 'architecture'
__inject__ = ['fcos_post_process'] __inject__ = ['fcos_post_process']
...@@ -70,7 +80,7 @@ class FCOS(BaseArch): ...@@ -70,7 +80,7 @@ class FCOS(BaseArch):
loss = {} loss = {}
tag_labels, tag_bboxes, tag_centerness = [], [], [] tag_labels, tag_bboxes, tag_centerness = [], [], []
for i in range(len(self.fcos_head.fpn_stride)): for i in range(len(self.fcos_head.fpn_stride)):
# reg_target, labels, scores, centerness # labels, reg_target, centerness
k_lbl = 'labels{}'.format(i) k_lbl = 'labels{}'.format(i)
if k_lbl in self.inputs: if k_lbl in self.inputs:
tag_labels.append(self.inputs[k_lbl]) tag_labels.append(self.inputs[k_lbl])
...@@ -90,6 +100,6 @@ class FCOS(BaseArch): ...@@ -90,6 +100,6 @@ class FCOS(BaseArch):
return loss return loss
def get_pred(self): def get_pred(self):
bboxes, bbox_num = self._forward() bbox_pred, bbox_num = self._forward()
output = {'bbox': bboxes, 'bbox_num': bbox_num} output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
return output return output
...@@ -28,6 +28,10 @@ from ppdet.modeling.layers import ConvNormLayer ...@@ -28,6 +28,10 @@ from ppdet.modeling.layers import ConvNormLayer
class ScaleReg(nn.Layer): class ScaleReg(nn.Layer):
"""
Parameter for scaling the regression outputs.
"""
def __init__(self): def __init__(self):
super(ScaleReg, self).__init__() super(ScaleReg, self).__init__()
self.scale_reg = self.create_parameter( self.scale_reg = self.create_parameter(
...@@ -113,12 +117,13 @@ class FCOSHead(nn.Layer): ...@@ -113,12 +117,13 @@ class FCOSHead(nn.Layer):
""" """
FCOSHead FCOSHead
Args: Args:
num_classes(int): Number of classes fcos_feat (object): Instance of 'FCOSFeat'
fpn_stride(list): The stride of each FPN Layer num_classes (int): Number of classes
prior_prob(float): Used to set the bias init for the class prediction layer fpn_stride (list): The stride of each FPN Layer
fcos_loss(object): Instance of 'FCOSLoss' prior_prob (float): Used to set the bias init for the class prediction layer
norm_reg_targets(bool): Normalization the regression target if true fcos_loss (object): Instance of 'FCOSLoss'
centerness_on_reg(bool): The prediction of centerness on regression or clssification branch norm_reg_targets (bool): Normalization the regression target if true
centerness_on_reg (bool): The prediction of centerness on regression or clssification branch
""" """
__inject__ = ['fcos_feat', 'fcos_loss'] __inject__ = ['fcos_feat', 'fcos_loss']
__shared__ = ['num_classes'] __shared__ = ['num_classes']
...@@ -199,7 +204,15 @@ class FCOSHead(nn.Layer): ...@@ -199,7 +204,15 @@ class FCOSHead(nn.Layer):
scale_reg = self.add_sublayer(feat_name, ScaleReg()) scale_reg = self.add_sublayer(feat_name, ScaleReg())
self.scales_regs.append(scale_reg) self.scales_regs.append(scale_reg)
def _compute_locatioins_by_level(self, fpn_stride, feature): def _compute_locations_by_level(self, fpn_stride, feature):
"""
Compute locations of anchor points of each FPN layer
Args:
fpn_stride (int): The stride of current FPN feature map
feature (Tensor): Tensor of current FPN feature map
Return:
Anchor points locations of current FPN feature map
"""
shape_fm = paddle.shape(feature) shape_fm = paddle.shape(feature)
shape_fm.stop_gradient = True shape_fm.stop_gradient = True
h, w = shape_fm[2], shape_fm[3] h, w = shape_fm[2], shape_fm[3]
...@@ -247,8 +260,7 @@ class FCOSHead(nn.Layer): ...@@ -247,8 +260,7 @@ class FCOSHead(nn.Layer):
if not is_training: if not is_training:
locations_list = [] locations_list = []
for fpn_stride, feature in zip(self.fpn_stride, fpn_feats): for fpn_stride, feature in zip(self.fpn_stride, fpn_feats):
location = self._compute_locatioins_by_level(fpn_stride, location = self._compute_locations_by_level(fpn_stride, feature)
feature)
locations_list.append(location) locations_list.append(location)
return locations_list, cls_logits_list, bboxes_reg_list, centerness_list return locations_list, cls_logits_list, bboxes_reg_list, centerness_list
......
...@@ -24,7 +24,15 @@ import numpy as np ...@@ -24,7 +24,15 @@ import numpy as np
@register @register
class HMHead(nn.Layer): class HMHead(nn.Layer):
"""
Args:
ch_in (int): The channel number of input Tensor.
ch_out (int): The channel number of output Tensor.
num_classes (int): Number of classes.
conv_num (int): The convolution number of hm_feat.
Return:
Heatmap head output
"""
__shared__ = ['num_classes'] __shared__ = ['num_classes']
def __init__(self, ch_in, ch_out=128, num_classes=80, conv_num=2): def __init__(self, ch_in, ch_out=128, num_classes=80, conv_num=2):
...@@ -65,6 +73,15 @@ class HMHead(nn.Layer): ...@@ -65,6 +73,15 @@ class HMHead(nn.Layer):
@register @register
class WHHead(nn.Layer): class WHHead(nn.Layer):
"""
Args:
ch_in (int): The channel number of input Tensor.
ch_out (int): The channel number of output Tensor.
conv_num (int): The convolution number of wh_feat.
Return:
Width & Height head output
"""
def __init__(self, ch_in, ch_out=64, conv_num=2): def __init__(self, ch_in, ch_out=64, conv_num=2):
super(WHHead, self).__init__() super(WHHead, self).__init__()
head_conv = nn.Sequential() head_conv = nn.Sequential()
...@@ -104,17 +121,22 @@ class TTFHead(nn.Layer): ...@@ -104,17 +121,22 @@ class TTFHead(nn.Layer):
""" """
TTFHead TTFHead
Args: Args:
in_channels(int): the channel number of input to TTFHead. in_channels (int): the channel number of input to TTFHead.
num_classes(int): the number of classes, 80 by default. num_classes (int): the number of classes, 80 by default.
hm_head_planes(int): the channel number in wh head, 128 by default. hm_head_planes (int): the channel number in heatmap head,
wh_head_planes(int): the channel number in wh head, 64 by default. 128 by default.
hm_head_conv_num(int): the number of convolution in wh head, 2 by default. wh_head_planes (int): the channel number in width & height head,
wh_head_conv_num(int): the number of convolution in wh head, 2 by default. 64 by default.
hm_loss(object): Instance of 'CTFocalLoss'. hm_head_conv_num (int): the number of convolution in heatmap head,
wh_loss(object): Instance of 'GIoULoss'. 2 by default.
wh_offset_base(flaot): the base offset of width and height, 16. by default. wh_head_conv_num (int): the number of convolution in width & height
down_ratio(int): the actual down_ratio is calculated by base_down_ratio(default 16) head, 2 by default.
and the number of upsample layers. hm_loss (object): Instance of 'CTFocalLoss'.
wh_loss (object): Instance of 'GIoULoss'.
wh_offset_base (float): the base offset of width and height,
16.0 by default.
down_ratio (int): the actual down_ratio is calculated by base_down_ratio
(default 16) and the number of upsample layers.
""" """
__shared__ = ['num_classes', 'down_ratio'] __shared__ = ['num_classes', 'down_ratio']
...@@ -154,6 +176,9 @@ class TTFHead(nn.Layer): ...@@ -154,6 +176,9 @@ class TTFHead(nn.Layer):
return hm, wh return hm, wh
def filter_box_by_weight(self, pred, target, weight): def filter_box_by_weight(self, pred, target, weight):
"""
Filter out boxes where ttf_reg_weight is 0, only keep positive samples.
"""
index = paddle.nonzero(weight > 0) index = paddle.nonzero(weight > 0)
index.stop_gradient = True index.stop_gradient = True
weight = paddle.gather_nd(weight, index) weight = paddle.gather_nd(weight, index)
......
...@@ -616,20 +616,20 @@ class AnchorGrid(object): ...@@ -616,20 +616,20 @@ class AnchorGrid(object):
@register @register
@serializable @serializable
class FCOSBox(object): class FCOSBox(object):
__shared__ = ['num_classes', 'batch_size'] __shared__ = ['num_classes']
def __init__(self, num_classes=80, batch_size=1): def __init__(self, num_classes=80):
super(FCOSBox, self).__init__() super(FCOSBox, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.batch_size = batch_size
def _merge_hw(self, inputs, ch_type="channel_first"): def _merge_hw(self, inputs, ch_type="channel_first"):
""" """
Merge h and w of the feature map into one dimension.
Args: Args:
inputs (Variables): Feature map whose H and W will be merged into one dimension inputs (Tensor): Tensor of the input feature map
ch_type (str): channel_first / channel_last ch_type (str): "channel_first" or "channel_last" style
Return: Return:
new_shape (Variables): The new shape after h and w merged into one dimension new_shape (Tensor): The new shape after h and w merged
""" """
shape_ = paddle.shape(inputs) shape_ = paddle.shape(inputs)
bs, ch, hi, wi = shape_[0], shape_[1], shape_[2], shape_[3] bs, ch, hi, wi = shape_[0], shape_[1], shape_[2], shape_[3]
...@@ -647,16 +647,18 @@ class FCOSBox(object): ...@@ -647,16 +647,18 @@ class FCOSBox(object):
def _postprocessing_by_level(self, locations, box_cls, box_reg, box_ctn, def _postprocessing_by_level(self, locations, box_cls, box_reg, box_ctn,
scale_factor): scale_factor):
""" """
Postprocess each layer of the output with corresponding locations.
Args: Args:
locations (Variables): anchor points for current layer, [H*W, 2] locations (Tensor): anchor points for current layer, [H*W, 2]
box_cls (Variables): categories prediction, [N, C, H, W], C is the number of classes box_cls (Tensor): categories prediction, [N, C, H, W],
box_reg (Variables): bounding box prediction, [N, 4, H, W] C is the number of classes
box_ctn (Variables): centerness prediction, [N, 1, H, W] box_reg (Tensor): bounding box prediction, [N, 4, H, W]
scale_factor (Variables): [h_scale, w_scale] for input images box_ctn (Tensor): centerness prediction, [N, 1, H, W]
scale_factor (Tensor): [h_scale, w_scale] for input images
Return: Return:
box_cls_ch_last (Variables): score for each category, in [N, C, M] box_cls_ch_last (Tensor): score for each category, in [N, C, M]
C is the number of classes and M is the number of anchor points C is the number of classes and M is the number of anchor points
box_reg_decoding (Variables): decoded bounding box, in [N, M, 4] box_reg_decoding (Tensor): decoded bounding box, in [N, M, 4]
last dimension is [x1, y1, x2, y2] last dimension is [x1, y1, x2, y2]
""" """
act_shape_cls = self._merge_hw(box_cls) act_shape_cls = self._merge_hw(box_cls)
...@@ -712,12 +714,18 @@ class TTFBox(object): ...@@ -712,12 +714,18 @@ class TTFBox(object):
self.down_ratio = down_ratio self.down_ratio = down_ratio
def _simple_nms(self, heat, kernel=3): def _simple_nms(self, heat, kernel=3):
"""
Use maxpool to filter the max score, get local peaks.
"""
pad = (kernel - 1) // 2 pad = (kernel - 1) // 2
hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad) hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
keep = paddle.cast(hmax == heat, 'float32') keep = paddle.cast(hmax == heat, 'float32')
return heat * keep return heat * keep
def _topk(self, scores): def _topk(self, scores):
"""
Select top k scores and decode to get xy coordinates.
"""
k = self.max_per_img k = self.max_per_img
shape_fm = paddle.shape(scores) shape_fm = paddle.shape(scores)
shape_fm.stop_gradient = True shape_fm.stop_gradient = True
......
...@@ -27,7 +27,7 @@ __all__ = ['CTFocalLoss'] ...@@ -27,7 +27,7 @@ __all__ = ['CTFocalLoss']
@serializable @serializable
class CTFocalLoss(object): class CTFocalLoss(object):
""" """
CTFocalLoss CTFocalLoss: CornerNet & CenterNet Focal Loss
Args: Args:
loss_weight (float): loss weight loss_weight (float): loss weight
gamma (float): gamma parameter for Focal Loss gamma (float): gamma parameter for Focal Loss
...@@ -41,8 +41,8 @@ class CTFocalLoss(object): ...@@ -41,8 +41,8 @@ class CTFocalLoss(object):
""" """
Calculate the loss Calculate the loss
Args: Args:
pred(Tensor): heatmap prediction pred (Tensor): heatmap prediction
target(Tensor): target for positive samples target (Tensor): target for positive samples
Return: Return:
ct_focal_loss (Tensor): Focal Loss used in CornerNet & CenterNet. ct_focal_loss (Tensor): Focal Loss used in CornerNet & CenterNet.
Note that the values in target are in [0, 1] since gaussian is Note that the values in target are in [0, 1] since gaussian is
......
...@@ -20,8 +20,8 @@ import paddle ...@@ -20,8 +20,8 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from ppdet.core.workspace import register from ppdet.core.workspace import register
from ppdet.modeling import ops
INF = 1e8
__all__ = ['FCOSLoss'] __all__ = ['FCOSLoss']
...@@ -30,33 +30,20 @@ def flatten_tensor(inputs, channel_first=False): ...@@ -30,33 +30,20 @@ def flatten_tensor(inputs, channel_first=False):
Flatten a Tensor Flatten a Tensor
Args: Args:
inputs (Tensor): 4-D Tensor with shape [N, C, H, W] or [N, H, W, C] inputs (Tensor): 4-D Tensor with shape [N, C, H, W] or [N, H, W, C]
channel_first(bool): if true the dimension order of channel_first (bool): If true the dimension order of Tensor is
Tensor is [N, C, H, W], otherwise is [N, H, W, C] [N, C, H, W], otherwise is [N, H, W, C]
Return: Return:
input_channel_last (Tensor): The flattened Tensor in channel_last style output_channel_last (Tensor): The flattened Tensor in channel_last style
""" """
if channel_first: if channel_first:
input_channel_last = paddle.transpose(inputs, perm=[0, 2, 3, 1]) input_channel_last = paddle.transpose(inputs, perm=[0, 2, 3, 1])
else: else:
input_channel_last = inputs input_channel_last = inputs
output_channel_last = paddle.flatten( output_channel_last = paddle.flatten(
input_channel_last, start_axis=0, stop_axis=2) # [N*H*W, C] input_channel_last, start_axis=0, stop_axis=2)
return output_channel_last return output_channel_last
def sigmoid_cross_entropy_with_logits_loss(inputs,
label,
ignore_index=-100,
normalize=False):
output = F.binary_cross_entropy_with_logits(inputs, label, reduction='none')
mask_tensor = paddle.cast(label != ignore_index, 'float32')
output = paddle.multiply(output, mask_tensor)
if normalize:
sum_valid_mask = paddle.sum(mask_tensor)
output = output / sum_valid_mask
return output
@register @register
class FCOSLoss(nn.Layer): class FCOSLoss(nn.Layer):
""" """
...@@ -64,8 +51,8 @@ class FCOSLoss(nn.Layer): ...@@ -64,8 +51,8 @@ class FCOSLoss(nn.Layer):
Args: Args:
loss_alpha (float): alpha in focal loss loss_alpha (float): alpha in focal loss
loss_gamma (float): gamma in focal loss loss_gamma (float): gamma in focal loss
iou_loss_type(str): location loss type, IoU/GIoU/LINEAR_IoU iou_loss_type (str): location loss type, IoU/GIoU/LINEAR_IoU
reg_weights(float): weight for location loss reg_weights (float): weight for location loss
""" """
def __init__(self, def __init__(self,
...@@ -226,7 +213,7 @@ class FCOSLoss(nn.Layer): ...@@ -226,7 +213,7 @@ class FCOSLoss(nn.Layer):
# 3. centerness: sigmoid_cross_entropy_with_logits_loss # 3. centerness: sigmoid_cross_entropy_with_logits_loss
centerness_flatten = paddle.squeeze(centerness_flatten, axis=-1) centerness_flatten = paddle.squeeze(centerness_flatten, axis=-1)
ctn_loss = sigmoid_cross_entropy_with_logits_loss(centerness_flatten, ctn_loss = ops.sigmoid_cross_entropy_with_logits(centerness_flatten,
tag_center_flatten) tag_center_flatten)
ctn_loss = ctn_loss * mask_positive_float / num_positive_fp32 ctn_loss = ctn_loss * mask_positive_float / num_positive_fp32
......
...@@ -17,7 +17,6 @@ import paddle.nn as nn ...@@ -17,7 +17,6 @@ import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import ParamAttr from paddle import ParamAttr
from paddle.nn.initializer import Constant, Uniform, Normal from paddle.nn.initializer import Constant, Uniform, Normal
from paddle.nn import Conv2D, ReLU, Sequential
from paddle import ParamAttr from paddle import ParamAttr
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from paddle.regularizer import L2Decay from paddle.regularizer import L2Decay
...@@ -28,8 +27,6 @@ from ..shape_spec import ShapeSpec ...@@ -28,8 +27,6 @@ from ..shape_spec import ShapeSpec
__all__ = ['TTFFPN'] __all__ = ['TTFFPN']
__all__ = ['TTFFPN']
class Upsample(nn.Layer): class Upsample(nn.Layer):
def __init__(self, ch_in, ch_out, name=None): def __init__(self, ch_in, ch_out, name=None):
...@@ -63,7 +60,7 @@ class Upsample(nn.Layer): ...@@ -63,7 +60,7 @@ class Upsample(nn.Layer):
class ShortCut(nn.Layer): class ShortCut(nn.Layer):
def __init__(self, layer_num, ch_out, name=None): def __init__(self, layer_num, ch_out, name=None):
super(ShortCut, self).__init__() super(ShortCut, self).__init__()
shortcut_conv = Sequential() shortcut_conv = nn.Sequential()
ch_in = ch_out * 2 ch_in = ch_out * 2
for i in range(layer_num): for i in range(layer_num):
fan_out = 3 * 3 * ch_out fan_out = 3 * 3 * ch_out
...@@ -72,7 +69,7 @@ class ShortCut(nn.Layer): ...@@ -72,7 +69,7 @@ class ShortCut(nn.Layer):
shortcut_name = name + '.conv.{}'.format(i) shortcut_name = name + '.conv.{}'.format(i)
shortcut_conv.add_sublayer( shortcut_conv.add_sublayer(
shortcut_name, shortcut_name,
Conv2D( nn.Conv2D(
in_channels=in_channels, in_channels=in_channels,
out_channels=ch_out, out_channels=ch_out,
kernel_size=3, kernel_size=3,
...@@ -81,7 +78,7 @@ class ShortCut(nn.Layer): ...@@ -81,7 +78,7 @@ class ShortCut(nn.Layer):
bias_attr=ParamAttr( bias_attr=ParamAttr(
learning_rate=2., regularizer=L2Decay(0.)))) learning_rate=2., regularizer=L2Decay(0.))))
if i < layer_num - 1: if i < layer_num - 1:
shortcut_conv.add_sublayer(shortcut_name + '.act', ReLU()) shortcut_conv.add_sublayer(shortcut_name + '.act', nn.ReLU())
self.shortcut = self.add_sublayer('short', shortcut_conv) self.shortcut = self.add_sublayer('short', shortcut_conv)
def forward(self, feat): def forward(self, feat):
......
...@@ -1558,7 +1558,6 @@ def sigmoid_cross_entropy_with_logits(input, ...@@ -1558,7 +1558,6 @@ def sigmoid_cross_entropy_with_logits(input,
output = F.binary_cross_entropy_with_logits(input, label, reduction='none') output = F.binary_cross_entropy_with_logits(input, label, reduction='none')
mask_tensor = paddle.cast(label != ignore_index, 'float32') mask_tensor = paddle.cast(label != ignore_index, 'float32')
output = paddle.multiply(output, mask_tensor) output = paddle.multiply(output, mask_tensor)
output = paddle.reshape(output, shape=[output.shape[0], -1])
if normalize: if normalize:
sum_valid_mask = paddle.sum(mask_tensor) sum_valid_mask = paddle.sum(mask_tensor)
output = output / sum_valid_mask output = output / sum_valid_mask
......
...@@ -181,6 +181,9 @@ class FCOSPostProcess(object): ...@@ -181,6 +181,9 @@ class FCOSPostProcess(object):
self.nms = nms self.nms = nms
def __call__(self, fcos_head_outs, scale_factor): def __call__(self, fcos_head_outs, scale_factor):
"""
Decode the bbox and do NMS in FCOS.
"""
locations, cls_logits, bboxes_reg, centerness = fcos_head_outs locations, cls_logits, bboxes_reg, centerness = fcos_head_outs
bboxes, score = self.decode(locations, cls_logits, bboxes_reg, bboxes, score = self.decode(locations, cls_logits, bboxes_reg,
centerness, scale_factor) centerness, scale_factor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册