diff --git a/ppdet/modeling/architecture/mask_rcnn.py b/ppdet/modeling/architecture/mask_rcnn.py
index 00c13ec02c2d468fda093990f813d63c7deb8491..eee788a555c2d94f083e0be9f1e413fcf057477f 100644
--- a/ppdet/modeling/architecture/mask_rcnn.py
+++ b/ppdet/modeling/architecture/mask_rcnn.py
@@ -1,8 +1,22 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from paddle import fluid
+import paddle
 
 from ppdet.core.workspace import register
 from .meta_arch import BaseArch
@@ -114,7 +128,7 @@ class MaskRCNN(BaseArch):
             loss_mask = self.mask_head.get_loss(self.mask_head_out,
                                                 mask_targets)
             loss.update(loss_mask)
 
-        total_loss = fluid.layers.sums(list(loss.values()))
+        total_loss = paddle.add_n(list(loss.values()))
         loss.update({'loss': total_loss})
         return loss
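The only functional change in `mask_rcnn.py` is the loss reduction: `fluid.layers.sums` is replaced by `paddle.add_n`, which sums a list of same-shape tensors element-wise. A minimal sketch of the equivalence (the loss values are illustrative):

```python
import paddle

# add_n replaces fluid.layers.sums: it element-wise sums a list of tensors,
# here collapsing the per-component loss dict into a single scalar.
loss = {
    'loss_rpn_cls': paddle.to_tensor(0.3),
    'loss_bbox_reg': paddle.to_tensor(0.7),
}
total_loss = paddle.add_n(list(loss.values()))
print(float(total_loss))  # 1.0
```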
diff --git a/ppdet/modeling/backbone/resnet.py b/ppdet/modeling/backbone/resnet.py
index c44098dde35f5aa7759ff983aa859cda258cd8cb..27a5dcccae7f3de0666293a23927d8560ce0e99e 100755
--- a/ppdet/modeling/backbone/resnet.py
+++ b/ppdet/modeling/backbone/resnet.py
@@ -1,16 +1,33 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import numpy as np
-import paddle.fluid as fluid
-from paddle.fluid.dygraph import Layer, Sequential
-from paddle.fluid.dygraph import Conv2D, Pool2D, BatchNorm
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.initializer import Constant
+from paddle import ParamAttr
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import Conv2D, BatchNorm
+from paddle.nn import MaxPool2D
+
 from ppdet.core.workspace import register, serializable
-from paddle.fluid.regularizer import L2Decay
+
+from paddle.regularizer import L2Decay
 from .name_adapter import NameAdapter
 from numbers import Integral
 
 
-class ConvNormLayer(Layer):
+class ConvNormLayer(nn.Layer):
     def __init__(self,
                  ch_in,
                  ch_out,
@@ -24,19 +41,18 @@ class ConvNormLayer(Layer):
                  lr=1.0,
                  name=None):
         super(ConvNormLayer, self).__init__()
-        assert norm_type in ['bn', 'affine_channel']
+        assert norm_type in ['bn', 'sync_bn']
         self.norm_type = norm_type
         self.act = act
 
         self.conv = Conv2D(
-            num_channels=ch_in,
-            num_filters=ch_out,
-            filter_size=filter_size,
+            in_channels=ch_in,
+            out_channels=ch_out,
+            kernel_size=filter_size,
             stride=stride,
             padding=(filter_size - 1) // 2,
             groups=1,
-            act=None,
-            param_attr=ParamAttr(
+            weight_attr=ParamAttr(
                 learning_rate=lr, name=name + "_weights"),
             bias_attr=False)
 
@@ -53,30 +69,16 @@ class ConvNormLayer(Layer):
             name=bn_name + "_offset",
             trainable=False if freeze_norm else True)
 
-        if norm_type in ['bn', 'sync_bn']:
-            global_stats = True if freeze_norm else False
-            self.norm = BatchNorm(
-                num_channels=ch_out,
-                act=act,
-                param_attr=param_attr,
-                bias_attr=bias_attr,
-                use_global_stats=global_stats,
-                moving_mean_name=bn_name + '_mean',
-                moving_variance_name=bn_name + '_variance')
-            norm_params = self.norm.parameters()
-        elif norm_type == 'affine_channel':
-            self.scale = fluid.layers.create_parameter(
-                shape=[ch_out],
-                dtype='float32',
-                attr=param_attr,
-                default_initializer=Constant(1.))
-
-            self.offset = fluid.layers.create_parameter(
-                shape=[ch_out],
-                dtype='float32',
-                attr=bias_attr,
-                default_initializer=Constant(0.))
-            norm_params = [self.scale, self.offset]
+        global_stats = True if freeze_norm else False
+        self.norm = BatchNorm(
+            ch_out,
+            act=act,
+            param_attr=param_attr,
+            bias_attr=bias_attr,
+            use_global_stats=global_stats,
+            moving_mean_name=bn_name + '_mean',
+            moving_variance_name=bn_name + '_variance')
+        norm_params = self.norm.parameters()
 
         if freeze_norm:
             for param in norm_params:
@@ -86,13 +88,10 @@ class ConvNormLayer(Layer):
         out = self.conv(inputs)
-        if self.norm_type == 'bn':
+        if self.norm_type in ['bn', 'sync_bn']:
             out = self.norm(out)
-        elif self.norm_type == 'affine_channel':
-            out = fluid.layers.affine_channel(
-                out, scale=self.scale, bias=self.offset, act=self.act)
         return out
 
 
-class BottleNeck(Layer):
+class BottleNeck(nn.Layer):
     def __init__(self,
                  ch_in,
                  ch_out,
@@ -176,12 +175,13 @@ class BottleNeck(Layer):
         out = self.branch2b(out)
         out = self.branch2c(out)
 
-        out = fluid.layers.elementwise_add(x=short, y=out, act='relu')
+        out = paddle.add(x=short, y=out)
+        out = F.relu(out)
 
         return out
 
 
-class Blocks(Layer):
+class Blocks(nn.Layer):
     def __init__(self,
                  ch_in,
                  ch_out,
@@ -226,7 +226,7 @@ ResNet_cfg = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}
 
 @register
 @serializable
-class ResNet(Layer):
+class ResNet(nn.Layer):
     def __init__(self,
                  depth=50,
                  variant='b',
@@ -265,7 +265,7 @@ class ResNet(Layer):
             ]
         else:
             conv_def = [[3, 64, 7, 2, conv1_name]]
-        self.conv1 = Sequential()
+        self.conv1 = nn.Sequential()
         for (c_in, c_out, k, s, _name) in conv_def:
             self.conv1.add_sublayer(
                 _name,
@@ -282,8 +282,7 @@ class ResNet(Layer):
                     lr=lr_mult,
                     name=_name))
 
-        self.pool = Pool2D(
-            pool_type='max', pool_size=3, pool_stride=2, pool_padding=1)
+        self.pool = MaxPool2D(kernel_size=3, stride=2, padding=1)
 
         ch_in_list = [64, 256, 512, 1024]
         ch_out_list = [64, 128, 256, 512]
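With the `affine_channel` branch gone, `ConvNormLayer` always builds a `paddle.nn.BatchNorm` (the compatibility layer that still accepts `act`, `param_attr`, and the moving-statistic names). Freezing is expressed by running on global statistics and detaching the scale/offset from gradient updates; a minimal standalone sketch of that pattern (channel count and input shape are illustrative):

```python
import paddle
from paddle.nn import BatchNorm

# Frozen BN: normalize with the stored moving mean/variance rather than
# batch statistics, and keep scale/offset out of optimizer updates.
bn = BatchNorm(64, use_global_stats=True)
for param in bn.parameters():
    param.stop_gradient = True

x = paddle.rand([2, 64, 32, 32])
y = bn(x)  # same shape; no gradients will flow into the BN parameters
```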
diff --git a/ppdet/modeling/head/bbox_head.py b/ppdet/modeling/head/bbox_head.py
index 3dd52e5f14dcbcba255076d991caf4380f73979d..6b5b8013525084a60ac5b99ad4800c7ff0018e41 100644
--- a/ppdet/modeling/head/bbox_head.py
+++ b/ppdet/modeling/head/bbox_head.py
@@ -1,16 +1,30 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid.dygraph import Layer
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.initializer import Normal, Xavier
-from paddle.fluid.regularizer import L2Decay
-from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
-from ppdet.core.workspace import register
+from paddle import ParamAttr
+import paddle.nn as nn
 import paddle.nn.functional as F
+from paddle.nn import ReLU
+from paddle.nn.initializer import Normal, XavierUniform
+from paddle.regularizer import L2Decay
+from ppdet.core.workspace import register
+from ppdet.modeling import ops
 
 
 @register
-class TwoFCHead(Layer):
+class TwoFCHead(nn.Layer):
 
     __shared__ = ['num_stages']
 
@@ -21,48 +35,47 @@
         self.num_stages = num_stages
         fan = in_dim * resolution * resolution
         self.fc6_list = []
+        self.fc6_relu_list = []
         self.fc7_list = []
+        self.fc7_relu_list = []
         for stage in range(num_stages):
             fc6_name = 'fc6_{}'.format(stage)
             fc7_name = 'fc7_{}'.format(stage)
             fc6 = self.add_sublayer(
                 fc6_name,
-                Linear(
+                nn.Linear(
                     in_dim * resolution * resolution,
                     mlp_dim,
-                    act='relu',
-                    param_attr=ParamAttr(
-                        #name='fc6_w',
-                        initializer=Xavier(fan_out=fan)),
+                    weight_attr=ParamAttr(
+                        initializer=XavierUniform(fan_out=fan)),
                     bias_attr=ParamAttr(
-                        #name='fc6_b',
-                        learning_rate=2.,
-                        regularizer=L2Decay(0.))))
+                        learning_rate=2., regularizer=L2Decay(0.))))
+            fc6_relu = self.add_sublayer(fc6_name + 'act', ReLU())
             fc7 = self.add_sublayer(
                 fc7_name,
-                Linear(
+                nn.Linear(
                     mlp_dim,
                     mlp_dim,
-                    act='relu',
-                    param_attr=ParamAttr(
-                        #name='fc7_w',
-                        initializer=Xavier()),
+                    weight_attr=ParamAttr(initializer=XavierUniform()),
                     bias_attr=ParamAttr(
-                        #name='fc7_b',
-                        learning_rate=2.,
-                        regularizer=L2Decay(0.))))
+                        learning_rate=2., regularizer=L2Decay(0.))))
+            fc7_relu = self.add_sublayer(fc7_name + 'act', ReLU())
             self.fc6_list.append(fc6)
+            self.fc6_relu_list.append(fc6_relu)
             self.fc7_list.append(fc7)
+            self.fc7_relu_list.append(fc7_relu)
 
     def forward(self, rois_feat, stage=0):
-        rois_feat = fluid.layers.flatten(rois_feat)
+        rois_feat = paddle.flatten(rois_feat, start_axis=1, stop_axis=-1)
         fc6 = self.fc6_list[stage](rois_feat)
-        fc7 = self.fc7_list[stage](fc6)
-        return fc7
+        fc6_relu = self.fc6_relu_list[stage](fc6)
+        fc7 = self.fc7_list[stage](fc6_relu)
+        fc7_relu = self.fc7_relu_list[stage](fc7)
+        return fc7_relu
 
 
 @register
-class BBoxFeat(Layer):
+class BBoxFeat(nn.Layer):
     __inject__ = ['roi_extractor', 'head_feat']
 
     def __init__(self, roi_extractor, head_feat):
@@ -77,7 +90,7 @@
 
 
 @register
-class BBoxHead(Layer):
+class BBoxHead(nn.Layer):
     __shared__ = ['num_classes', 'num_stages']
     __inject__ = ['bbox_feat']
 
@@ -105,40 +118,29 @@
             delta_name = 'bbox_delta_{}'.format(stage)
             bbox_score = self.add_sublayer(
                 score_name,
-                fluid.dygraph.Linear(
-                    input_dim=in_feat,
-                    output_dim=1 * self.num_classes,
-                    act=None,
-                    param_attr=ParamAttr(
-                        #name='cls_score_w',
-                        initializer=Normal(
-                            loc=0.0, scale=0.01)),
+                nn.Linear(
+                    in_feat,
+                    1 * self.num_classes,
+                    weight_attr=ParamAttr(initializer=Normal(
+                        mean=0.0, std=0.01)),
                     bias_attr=ParamAttr(
-                        #name='cls_score_b',
-                        learning_rate=2.,
-                        regularizer=L2Decay(0.))))
+                        learning_rate=2., regularizer=L2Decay(0.))))
 
             bbox_delta = self.add_sublayer(
                 delta_name,
-                fluid.dygraph.Linear(
-                    input_dim=in_feat,
-                    output_dim=4 * self.delta_dim,
-                    act=None,
-                    param_attr=ParamAttr(
-                        #name='bbox_pred_w',
-                        initializer=Normal(
-                            loc=0.0, scale=0.001)),
+                nn.Linear(
+                    in_feat,
+                    4 * self.delta_dim,
+                    weight_attr=ParamAttr(initializer=Normal(
+                        mean=0.0, std=0.001)),
                     bias_attr=ParamAttr(
-                        #name='bbox_pred_b',
-                        learning_rate=2.,
-                        regularizer=L2Decay(0.))))
+                        learning_rate=2., regularizer=L2Decay(0.))))
             self.bbox_score_list.append(bbox_score)
             self.bbox_delta_list.append(bbox_delta)
 
     def forward(self, body_feats, rois, spatial_scale, stage=0):
         bbox_feat = self.bbox_feat(body_feats, rois, spatial_scale, stage)
         if self.with_pool:
-            bbox_feat = fluid.layers.pool2d(
-                bbox_feat, pool_type='avg', global_pooling=True)
+            bbox_feat = F.adaptive_avg_pool2d(bbox_feat, output_size=1)
         bbox_head_out = []
         scores = self.bbox_score_list[stage](bbox_feat)
@@ -148,20 +151,19 @@
 
     def _get_head_loss(self, score, delta, target):
         # bbox cls
-        labels_int64 = fluid.layers.cast(
-            x=target['labels_int32'], dtype='int64')
+        labels_int64 = paddle.cast(x=target['labels_int32'], dtype='int64')
         labels_int64.stop_gradient = True
-        loss_bbox_cls = fluid.layers.softmax_with_cross_entropy(
+        loss_bbox_cls = F.softmax_with_cross_entropy(
             logits=score, label=labels_int64)
-        loss_bbox_cls = fluid.layers.reduce_mean(loss_bbox_cls)
+        loss_bbox_cls = paddle.mean(loss_bbox_cls)
         # bbox reg
-        loss_bbox_reg = fluid.layers.smooth_l1(
-            x=delta,
-            y=target['bbox_targets'],
+        loss_bbox_reg = ops.smooth_l1(
+            input=delta,
+            label=target['bbox_targets'],
             inside_weight=target['bbox_inside_weights'],
             outside_weight=target['bbox_outside_weights'],
             sigma=1.0)
-        loss_bbox_reg = fluid.layers.reduce_mean(loss_bbox_reg)
+        loss_bbox_reg = paddle.mean(loss_bbox_reg)
         return loss_bbox_cls, loss_bbox_reg
 
     def get_loss(self, bbox_head_out, targets):
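Because `paddle.nn.Linear` has no `act` argument, the activation that `fluid.dygraph.Linear(act='relu')` used to fuse is now registered as a separate `ReLU` sublayer, which is why `TwoFCHead` keeps parallel `fc*_list` and `fc*_relu_list`. A reduced sketch of the same pattern (dimensions are illustrative):

```python
import paddle
import paddle.nn as nn

# In the 2.0 API the activation is its own layer, not a Linear kwarg.
fc6 = nn.Linear(256 * 7 * 7, 1024)
fc6_relu = nn.ReLU()

rois_feat = paddle.rand([8, 256, 7, 7])
x = paddle.flatten(rois_feat, start_axis=1, stop_axis=-1)  # [8, 12544]
x = fc6_relu(fc6(x))                                       # [8, 1024]
```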
diff --git a/ppdet/modeling/head/mask_head.py b/ppdet/modeling/head/mask_head.py
index 363ad0a34c9883773b4a279da4fd6d26e68c37c0..93b9ba6f8377fa2cfcdb027c41225b8107a2ad41 100644
--- a/ppdet/modeling/head/mask_head.py
+++ b/ppdet/modeling/head/mask_head.py
@@ -1,11 +1,26 @@
-import paddle.fluid as fluid
-from paddle.fluid.dygraph import Layer, Sequential
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.initializer import MSRA
-from paddle.fluid.regularizer import L2Decay
-from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Conv2DTranspose
+import paddle
+import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.nn import Layer, Sequential
+from paddle.nn import Conv2D, Conv2DTranspose, ReLU
+from paddle.nn.initializer import KaimingNormal
+from paddle.regularizer import L2Decay
 
 from ppdet.core.workspace import register
+from ppdet.modeling import ops
 
 
 @register
@@ -37,35 +52,27 @@
             mask_conv.add_sublayer(
                 conv_name,
                 Conv2D(
-                    num_channels=feat_in if j == 0 else feat_out,
-                    num_filters=feat_out,
-                    filter_size=3,
-                    act='relu',
+                    in_channels=feat_in if j == 0 else feat_out,
+                    out_channels=feat_out,
+                    kernel_size=3,
                     padding=1,
-                    param_attr=ParamAttr(
-                        #name=conv_name+'_w',
-                        initializer=MSRA(
-                            uniform=False, fan_in=fan_conv)),
+                    weight_attr=ParamAttr(
+                        initializer=KaimingNormal(fan_in=fan_conv)),
                     bias_attr=ParamAttr(
-                        #name=conv_name+'_b',
-                        learning_rate=2.,
-                        regularizer=L2Decay(0.))))
+                        learning_rate=2., regularizer=L2Decay(0.))))
+            mask_conv.add_sublayer(conv_name + 'act', ReLU())
         mask_conv.add_sublayer(
             'conv5_mask',
             Conv2DTranspose(
-                num_channels=self.feat_in,
-                num_filters=self.feat_out,
-                filter_size=2,
+                in_channels=self.feat_in,
+                out_channels=self.feat_out,
+                kernel_size=2,
                 stride=2,
-                act='relu',
-                param_attr=ParamAttr(
-                    #name='conv5_mask_w',
-                    initializer=MSRA(
-                        uniform=False, fan_in=fan_deconv)),
+                weight_attr=ParamAttr(
+                    initializer=KaimingNormal(fan_in=fan_deconv)),
                 bias_attr=ParamAttr(
-                    #name='conv5_mask_b',
-                    learning_rate=2.,
-                    regularizer=L2Decay(0.))))
+                    learning_rate=2., regularizer=L2Decay(0.))))
+        mask_conv.add_sublayer('conv5_mask' + 'act', ReLU())
         upsample = self.add_sublayer(name, mask_conv)
         self.upsample_module.append(upsample)
 
@@ -77,7 +84,7 @@
                 spatial_scale,
                 stage=0):
         if self.share_bbox_feat:
-            rois_feat = fluid.layers.gather(bbox_feat, mask_index)
+            rois_feat = paddle.gather(bbox_feat, mask_index)
         else:
             rois_feat = self.mask_roi_extractor(body_feats, bboxes,
                                                 spatial_scale)
@@ -107,18 +114,14 @@
             self.mask_fcn_logits.append(
                 self.add_sublayer(
                     name,
-                    fluid.dygraph.Conv2D(
-                        num_channels=self.feat_in,
-                        num_filters=self.num_classes,
-                        filter_size=1,
-                        param_attr=ParamAttr(
-                            #name='mask_fcn_logits_w',
-                            initializer=MSRA(
-                                uniform=False, fan_in=self.num_classes)),
+                    Conv2D(
+                        in_channels=self.feat_in,
+                        out_channels=self.num_classes,
+                        kernel_size=1,
+                        weight_attr=ParamAttr(initializer=KaimingNormal(
+                            fan_in=self.num_classes)),
                         bias_attr=ParamAttr(
-                            #name='mask_fcn_logits_b',
-                            learning_rate=2.,
-                            regularizer=L2Decay(0.0)))))
+                            learning_rate=2., regularizer=L2Decay(0.0)))))
 
     def forward_train(self,
                       body_feats,
@@ -150,14 +153,13 @@
             for idx, num in enumerate(bbox_num):
                 for n in range(num):
                     im_info_expand.append(im_info[idx, -1])
-            im_info_expand = fluid.layers.concat(im_info_expand)
-            scaled_bbox = fluid.layers.elementwise_mul(
-                bbox[:, 2:], im_info_expand, axis=0)
+            im_info_expand = paddle.concat(im_info_expand)
+            scaled_bbox = paddle.multiply(bbox[:, 2:], im_info_expand, axis=0)
             scaled_bboxes = (scaled_bbox, bbox_num)
             mask_feat = self.mask_feat(body_feats, scaled_bboxes, bbox_feat,
                                        mask_index, spatial_scale, stage)
             mask_logit = self.mask_fcn_logits[stage](mask_feat)
-            mask_head_out = fluid.layers.sigmoid(mask_logit)
+            mask_head_out = F.sigmoid(mask_logit)
         return mask_head_out
 
     def forward(self,
@@ -179,12 +181,14 @@
         return mask_head_out
 
     def get_loss(self, mask_head_out, mask_target):
-        mask_logits = fluid.layers.flatten(mask_head_out)
-        mask_label = fluid.layers.cast(x=mask_target, dtype='float32')
+        mask_logits = paddle.flatten(mask_head_out, start_axis=1, stop_axis=-1)
+        mask_label = paddle.cast(x=mask_target, dtype='float32')
         mask_label.stop_gradient = True
-
-        loss_mask = fluid.layers.sigmoid_cross_entropy_with_logits(
-            x=mask_logits, label=mask_label, ignore_index=-1, normalize=True)
-        loss_mask = fluid.layers.reduce_sum(loss_mask)
+        loss_mask = ops.sigmoid_cross_entropy_with_logits(
+            input=mask_logits,
+            label=mask_label,
+            ignore_index=-1,
+            normalize=True)
+        loss_mask = paddle.sum(loss_mask)
 
         return {'loss_mask': loss_mask}
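The mask loss now routes through the `ops.sigmoid_cross_entropy_with_logits` shim added at the bottom of this patch: per-element BCE with logits, with `ignore_index` entries masked out and, when `normalize=True`, the result divided by the number of valid elements before the final sum. A rough sketch of what it computes (tensor values are illustrative):

```python
import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([[2.0, -1.0], [0.5, 3.0]])
labels = paddle.to_tensor([[1.0, 0.0], [-1.0, 1.0]])  # -1 marks ignored pixels

loss = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')
valid = paddle.cast(labels != -1, 'float32')  # zero out ignore_index entries
loss = loss * valid / paddle.sum(valid)       # normalize by the valid count
print(float(paddle.sum(loss)))                # matches the shim + paddle.sum
```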
diff --git a/ppdet/modeling/head/rpn_head.py b/ppdet/modeling/head/rpn_head.py
index 5fa4e4ee3c2eac273f24da947abad4192bf87899..64f7acc495326d4edbbff389e5351f602e67f0de 100644
--- a/ppdet/modeling/head/rpn_head.py
+++ b/ppdet/modeling/head/rpn_head.py
@@ -1,41 +1,53 @@
-import paddle.fluid as fluid
-from paddle.fluid.dygraph import Layer
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.initializer import Normal
-from paddle.fluid.regularizer import L2Decay
-from paddle.fluid.dygraph.nn import Conv2D
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.nn.initializer import Normal
+from paddle.regularizer import L2Decay
+from paddle.nn import Conv2D
+
 from ppdet.core.workspace import register
+from ppdet.modeling import ops
 
 
 @register
-class RPNFeat(Layer):
+class RPNFeat(nn.Layer):
     def __init__(self, feat_in=1024, feat_out=1024):
         super(RPNFeat, self).__init__()
         # rpn feat is shared with each level
         self.rpn_conv = Conv2D(
-            num_channels=feat_in,
-            num_filters=feat_out,
-            filter_size=3,
+            in_channels=feat_in,
+            out_channels=feat_out,
+            kernel_size=3,
             padding=1,
-            act='relu',
-            param_attr=ParamAttr(
-                #name="conv_rpn_fpn2_w",
-                initializer=Normal(
-                    loc=0., scale=0.01)),
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0., std=0.01)),
             bias_attr=ParamAttr(
-                #name="conv_rpn_fpn2_b",
-                learning_rate=2.,
-                regularizer=L2Decay(0.)))
+                learning_rate=2., regularizer=L2Decay(0.)))
 
     def forward(self, inputs, feats):
         rpn_feats = []
         for feat in feats:
-            rpn_feats.append(self.rpn_conv(feat))
+            rpn_feats.append(F.relu(self.rpn_conv(feat)))
         return rpn_feats
 
 
 @register
-class RPNHead(Layer):
+class RPNHead(nn.Layer):
     __inject__ = ['rpn_feat']
 
     def __init__(self, rpn_feat, anchor_per_position=15, rpn_channel=1024):
@@ -46,35 +58,25 @@
         # rpn head is shared with each level
         # rpn roi classification scores
         self.rpn_rois_score = Conv2D(
-            num_channels=rpn_channel,
-            num_filters=anchor_per_position,
-            filter_size=1,
+            in_channels=rpn_channel,
+            out_channels=anchor_per_position,
+            kernel_size=1,
             padding=0,
-            act=None,
-            param_attr=ParamAttr(
-                #name="rpn_cls_logits_fpn2_w",
-                initializer=Normal(
-                    loc=0., scale=0.01)),
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0., std=0.01)),
             bias_attr=ParamAttr(
-                #name="rpn_cls_logits_fpn2_b",
-                learning_rate=2.,
-                regularizer=L2Decay(0.)))
+                learning_rate=2., regularizer=L2Decay(0.)))
 
         # rpn roi bbox regression deltas
         self.rpn_rois_delta = Conv2D(
-            num_channels=rpn_channel,
-            num_filters=4 * anchor_per_position,
-            filter_size=1,
+            in_channels=rpn_channel,
+            out_channels=4 * anchor_per_position,
+            kernel_size=1,
             padding=0,
-            act=None,
-            param_attr=ParamAttr(
-                #name="rpn_bbox_pred_fpn2_w",
-                initializer=Normal(
-                    loc=0., scale=0.01)),
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0., std=0.01)),
             bias_attr=ParamAttr(
-                #name="rpn_bbox_pred_fpn2_b",
-                learning_rate=2.,
-                regularizer=L2Decay(0.)))
+                learning_rate=2., regularizer=L2Decay(0.)))
 
     def forward(self, inputs, feats):
         rpn_feats = self.rpn_feat(inputs, feats)
@@ -87,28 +89,26 @@
 
     def get_loss(self, loss_inputs):
         # cls loss
-        score_tgt = fluid.layers.cast(
+        score_tgt = paddle.cast(
             x=loss_inputs['rpn_score_target'], dtype='float32')
         score_tgt.stop_gradient = True
-        loss_rpn_cls = fluid.layers.sigmoid_cross_entropy_with_logits(
-            x=loss_inputs['rpn_score_pred'], label=score_tgt)
-        loss_rpn_cls = fluid.layers.reduce_mean(
-            loss_rpn_cls, name='loss_rpn_cls')
+        loss_rpn_cls = ops.sigmoid_cross_entropy_with_logits(
+            input=loss_inputs['rpn_score_pred'], label=score_tgt)
+        loss_rpn_cls = paddle.mean(loss_rpn_cls, name='loss_rpn_cls')
 
         # reg loss
-        loc_tgt = fluid.layers.cast(
-            x=loss_inputs['rpn_rois_target'], dtype='float32')
+        loc_tgt = paddle.cast(x=loss_inputs['rpn_rois_target'], dtype='float32')
         loc_tgt.stop_gradient = True
-        loss_rpn_reg = fluid.layers.smooth_l1(
-            x=loss_inputs['rpn_rois_pred'],
-            y=loc_tgt,
-            sigma=3.0,
+        loss_rpn_reg = ops.smooth_l1(
+            input=loss_inputs['rpn_rois_pred'],
+            label=loc_tgt,
             inside_weight=loss_inputs['rpn_rois_weight'],
-            outside_weight=loss_inputs['rpn_rois_weight'])
-        loss_rpn_reg = fluid.layers.reduce_sum(loss_rpn_reg)
-        score_shape = fluid.layers.shape(score_tgt)
-        score_shape = fluid.layers.cast(x=score_shape, dtype='float32')
-        norm = fluid.layers.reduce_prod(score_shape)
+            outside_weight=loss_inputs['rpn_rois_weight'],
+            sigma=3.0)
+        loss_rpn_reg = paddle.sum(loss_rpn_reg)
+        score_shape = paddle.shape(score_tgt)
+        score_shape = paddle.cast(score_shape, dtype='float32')
+        norm = paddle.prod(score_shape)
         norm.stop_gradient = True
         loss_rpn_reg = loss_rpn_reg / norm
 
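As with `nn.Linear`, `paddle.nn.Conv2D` no longer takes `act`, so the shared RPN conv applies `F.relu` in `forward` instead. The regression loss keeps its old normalizer, the total element count of the score target, now computed as the product of the runtime shape; a small sketch of that idiom (the shape is illustrative):

```python
import paddle

score_tgt = paddle.rand([2, 15, 32, 32])

# paddle.shape returns the runtime shape as an int32 tensor; its product
# is the element count used to normalize the summed regression loss.
norm = paddle.prod(paddle.cast(paddle.shape(score_tgt), 'float32'))
norm.stop_gradient = True
print(float(norm))  # 2 * 15 * 32 * 32 = 30720.0
```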
diff --git a/ppdet/modeling/neck/fpn.py b/ppdet/modeling/neck/fpn.py
index 321335f457f9c958cb072a59fcf72dd87e0ed1df..da9c63ace523c5f4cb79c64225f083af0de61724 100644
--- a/ppdet/modeling/neck/fpn.py
+++ b/ppdet/modeling/neck/fpn.py
@@ -1,9 +1,25 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import numpy as np
+import paddle
 import paddle.fluid as fluid
-from paddle.fluid.dygraph import Layer
-from paddle.fluid.dygraph import Conv2D, Pool2D, BatchNorm
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.initializer import Xavier
+import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.nn import Layer
+from paddle.nn import Conv2D
+from paddle.nn.initializer import XavierUniform
 from paddle.fluid.regularizer import L2Decay
 
 from ppdet.core.workspace import register, serializable
@@ -32,33 +48,27 @@ class FPN(Layer):
             lateral = self.add_sublayer(
                 lateral_name,
                 Conv2D(
-                    num_channels=in_c,
-                    num_filters=out_channel,
-                    filter_size=1,
-                    param_attr=ParamAttr(
-                        #name=lateral_name+'_w',
-                        initializer=Xavier(fan_out=in_c)),
+                    in_channels=in_c,
+                    out_channels=out_channel,
+                    kernel_size=1,
+                    weight_attr=ParamAttr(
+                        initializer=XavierUniform(fan_out=in_c)),
                     bias_attr=ParamAttr(
-                        #name=lateral_name+'_b',
-                        learning_rate=2.,
-                        regularizer=L2Decay(0.))))
+                        learning_rate=2., regularizer=L2Decay(0.))))
             self.lateral_convs.append(lateral)
 
             fpn_name = 'fpn_res{}_sum'.format(i + 2)
             fpn_conv = self.add_sublayer(
                 fpn_name,
                 Conv2D(
-                    num_channels=out_channel,
-                    num_filters=out_channel,
-                    filter_size=3,
+                    in_channels=out_channel,
+                    out_channels=out_channel,
+                    kernel_size=3,
                     padding=1,
-                    param_attr=ParamAttr(
-                        #name=fpn_name+'_w',
-                        initializer=Xavier(fan_out=fan)),
+                    weight_attr=ParamAttr(
+                        initializer=XavierUniform(fan_out=fan)),
                     bias_attr=ParamAttr(
-                        #name=fpn_name+'_b',
-                        learning_rate=2.,
-                        regularizer=L2Decay(0.))))
+                        learning_rate=2., regularizer=L2Decay(0.))))
             self.fpn_convs.append(fpn_conv)
 
         self.min_level = min_level
@@ -71,14 +81,17 @@ class FPN(Layer):
             laterals.append(self.lateral_convs[lvl](body_feats[lvl]))
 
         for lvl in range(self.max_level - 1, self.min_level, -1):
-            upsample = fluid.layers.resize_nearest(laterals[lvl], scale=2.)
+            upsample = F.interpolate(
+                laterals[lvl],
+                scale_factor=2.,
+                mode='nearest')
             laterals[lvl - 1] = laterals[lvl - 1] + upsample
 
         fpn_output = []
         for lvl in range(self.min_level, self.max_level):
             fpn_output.append(self.fpn_convs[lvl](laterals[lvl]))
 
-        extension = fluid.layers.pool2d(fpn_output[-1], 1, 'max', pool_stride=2)
+        extension = F.max_pool2d(fpn_output[-1], 1, stride=2)
 
         spatial_scale = self.spatial_scale + [self.spatial_scale[-1] * 0.5]
         fpn_output.append(extension)
diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py
index 6a2d5941dae6c82a4cc767fd3d94c2714c2b8124..6c8a55204e999aaf32cd8883718c3e6f4dfd3797 100644
--- a/ppdet/modeling/ops.py
+++ b/ppdet/modeling/ops.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import paddle
+import paddle.nn.functional as F
 from paddle.fluid.framework import Variable, in_dygraph_mode
 from paddle.fluid import core
 
@@ -1509,3 +1510,30 @@ def generate_proposals(scores,
         return rpn_rois, rpn_roi_probs, rpn_rois_num
     else:
         return rpn_rois, rpn_roi_probs
+
+
+def sigmoid_cross_entropy_with_logits(input,
+                                      label,
+                                      ignore_index=-100,
+                                      normalize=False):
+    output = F.binary_cross_entropy_with_logits(input, label, reduction='none')
+    mask_tensor = paddle.cast(label != ignore_index, 'float32')
+    output = paddle.multiply(output, mask_tensor)
+    output = paddle.reshape(output, shape=[output.shape[0], -1])
+    if normalize:
+        sum_valid_mask = paddle.sum(mask_tensor)
+        output = output / sum_valid_mask
+    return output
+
+
+def smooth_l1(input, label, inside_weight=None, outside_weight=None,
+              sigma=None):
+    input_new = paddle.multiply(input, inside_weight)
+    label_new = paddle.multiply(label, inside_weight)
+    delta = 1 / (sigma * sigma)
+    out = F.smooth_l1_loss(input_new, label_new, reduction='none', delta=delta)
+    out = paddle.multiply(out, outside_weight)
+    out = out / delta
+    out = paddle.reshape(out, shape=[out.shape[0], -1])
+    out = paddle.sum(out, axis=1)
+    return out
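`ops.smooth_l1` rebuilds `fluid.layers.smooth_l1` on top of `F.smooth_l1_loss`. Paddle's `smooth_l1_loss` is plain Huber loss (quadratic below `delta`, linear with slope `delta` above it), so choosing `delta = 1/sigma**2` and dividing the result by `delta` recovers fluid's `sigma`-parametrized form before the per-row sum. A quick sanity sketch under that assumption:

```python
import paddle
import paddle.nn.functional as F

sigma = 3.0
delta = 1 / (sigma * sigma)  # Huber threshold used by the shim

x = paddle.to_tensor([[0.05, 2.0]])
y = paddle.zeros([1, 2])

out = F.smooth_l1_loss(x, y, reduction='none', delta=delta) / delta
# |0.05| <  delta: 0.5 * (sigma * 0.05)**2 = 0.01125
# |2.0|  >= delta: 2.0 - 0.5 / sigma**2    = 1.9444...
print(out.numpy())
```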