Unverified commit ee911f6b, authored by Guanghua Yu, committed by GitHub

[Dygraph] update mask_rcnn_fpn model (#1718)

* update mask_rcnn_fpn model

* update smooth_l1_loss

* update smooth_l1 and sigmoid_cross_entropy_with_logits

* fix relu name

Parent: 31628f60
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from paddle import fluid
+import paddle
from ppdet.core.workspace import register
from .meta_arch import BaseArch
@@ -114,7 +128,7 @@ class MaskRCNN(BaseArch):
        loss_mask = self.mask_head.get_loss(self.mask_head_out, mask_targets)
        loss.update(loss_mask)
-        total_loss = fluid.layers.sums(list(loss.values()))
+        total_loss = paddle.add_n(list(loss.values()))
        loss.update({'loss': total_loss})
        return loss
...
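Note: fluid.layers.sums is replaced by paddle.add_n, which sums a list of equal-shape tensors elementwise. A minimal sketch of the pattern used in get_loss above (the loss values are illustrative, assuming Paddle 2.0+):

import paddle

# illustrative loss dict of scalar tensors, mirroring MaskRCNN.get_loss
loss = {
    'loss_rpn_cls': paddle.to_tensor(0.3),
    'loss_mask': paddle.to_tensor(0.7),
}
total_loss = paddle.add_n(list(loss.values()))  # scalar tensor, 1.0
loss.update({'loss': total_loss})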
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
-import paddle.fluid as fluid
-from paddle.fluid.dygraph import Layer, Sequential
-from paddle.fluid.dygraph import Conv2D, Pool2D, BatchNorm
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.initializer import Constant
+from paddle import ParamAttr
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import Conv2D, BatchNorm
+from paddle.nn import MaxPool2D
from ppdet.core.workspace import register, serializable
-from paddle.fluid.regularizer import L2Decay
+from paddle.regularizer import L2Decay
from .name_adapter import NameAdapter
from numbers import Integral

-class ConvNormLayer(Layer):
+class ConvNormLayer(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
@@ -24,19 +41,18 @@ class ConvNormLayer(Layer):
                 lr=1.0,
                 name=None):
        super(ConvNormLayer, self).__init__()
-        assert norm_type in ['bn', 'affine_channel']
+        assert norm_type in ['bn', 'sync_bn']
        self.norm_type = norm_type
        self.act = act

        self.conv = Conv2D(
-            num_channels=ch_in,
-            num_filters=ch_out,
-            filter_size=filter_size,
+            in_channels=ch_in,
+            out_channels=ch_out,
+            kernel_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=1,
-            act=None,
-            param_attr=ParamAttr(
+            weight_attr=ParamAttr(
                learning_rate=lr, name=name + "_weights"),
            bias_attr=False)
@@ -53,30 +69,16 @@ class ConvNormLayer(Layer):
            name=bn_name + "_offset",
            trainable=False if freeze_norm else True)

-        if norm_type in ['bn', 'sync_bn']:
-            global_stats = True if freeze_norm else False
-            self.norm = BatchNorm(
-                num_channels=ch_out,
-                act=act,
-                param_attr=param_attr,
-                bias_attr=bias_attr,
-                use_global_stats=global_stats,
-                moving_mean_name=bn_name + '_mean',
-                moving_variance_name=bn_name + '_variance')
-            norm_params = self.norm.parameters()
-        elif norm_type == 'affine_channel':
-            self.scale = fluid.layers.create_parameter(
-                shape=[ch_out],
-                dtype='float32',
-                attr=param_attr,
-                default_initializer=Constant(1.))
-            self.offset = fluid.layers.create_parameter(
-                shape=[ch_out],
-                dtype='float32',
-                attr=bias_attr,
-                default_initializer=Constant(0.))
-            norm_params = [self.scale, self.offset]
+        global_stats = True if freeze_norm else False
+        self.norm = BatchNorm(
+            ch_out,
+            act=act,
+            param_attr=param_attr,
+            bias_attr=bias_attr,
+            use_global_stats=global_stats,
+            moving_mean_name=bn_name + '_mean',
+            moving_variance_name=bn_name + '_variance')
+        norm_params = self.norm.parameters()

        if freeze_norm:
            for param in norm_params:
@@ -86,13 +88,10 @@ class ConvNormLayer(Layer):
        out = self.conv(inputs)
        if self.norm_type == 'bn':
            out = self.norm(out)
-        elif self.norm_type == 'affine_channel':
-            out = fluid.layers.affine_channel(
-                out, scale=self.scale, bias=self.offset, act=self.act)
        return out

-class BottleNeck(Layer):
+class BottleNeck(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
@@ -176,12 +175,13 @@ class BottleNeck(Layer):
        out = self.branch2b(out)
        out = self.branch2c(out)

-        out = fluid.layers.elementwise_add(x=short, y=out, act='relu')
+        out = paddle.add(x=short, y=out)
+        out = F.relu(out)

        return out

-class Blocks(Layer):
+class Blocks(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
@@ -226,7 +226,7 @@ ResNet_cfg = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}

@register
@serializable
-class ResNet(Layer):
+class ResNet(nn.Layer):
    def __init__(self,
                 depth=50,
                 variant='b',
@@ -265,7 +265,7 @@ class ResNet(Layer):
            ]
        else:
            conv_def = [[3, 64, 7, 2, conv1_name]]
-        self.conv1 = Sequential()
+        self.conv1 = nn.Sequential()
        for (c_in, c_out, k, s, _name) in conv_def:
            self.conv1.add_sublayer(
                _name,
@@ -282,8 +282,7 @@ class ResNet(Layer):
                        lr=lr_mult,
                        name=_name))

-        self.pool = Pool2D(
-            pool_type='max', pool_size=3, pool_stride=2, pool_padding=1)
+        self.pool = MaxPool2D(kernel_size=3, stride=2, padding=1)

        ch_in_list = [64, 256, 512, 1024]
        ch_out_list = [64, 128, 256, 512]
...
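Note: Paddle 2.0 drops the fused act= argument from elementwise ops, so the residual connection in BottleNeck becomes an explicit paddle.add followed by F.relu. A minimal equivalence sketch (tensor values are illustrative):

import paddle
import paddle.nn.functional as F

short = paddle.to_tensor([[-2.0, 0.5]])
conv_out = paddle.to_tensor([[1.0, 2.5]])
# old: fluid.layers.elementwise_add(x=short, y=conv_out, act='relu')
out = F.relu(paddle.add(short, conv_out))  # [[0.0, 3.0]]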
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
-import paddle.fluid as fluid
-from paddle.fluid.dygraph import Layer
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.initializer import Normal, Xavier
-from paddle.fluid.regularizer import L2Decay
-from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
-from ppdet.core.workspace import register
+from paddle import ParamAttr
+import paddle.nn as nn
import paddle.nn.functional as F
+from paddle.nn import ReLU
+from paddle.nn.initializer import Normal, XavierUniform
+from paddle.regularizer import L2Decay
+from ppdet.core.workspace import register
+from ppdet.modeling import ops

@register
-class TwoFCHead(Layer):
+class TwoFCHead(nn.Layer):

    __shared__ = ['num_stages']
@@ -21,48 +35,47 @@ class TwoFCHead(Layer):
        self.num_stages = num_stages
        fan = in_dim * resolution * resolution
        self.fc6_list = []
+        self.fc6_relu_list = []
        self.fc7_list = []
+        self.fc7_relu_list = []
        for stage in range(num_stages):
            fc6_name = 'fc6_{}'.format(stage)
            fc7_name = 'fc7_{}'.format(stage)
            fc6 = self.add_sublayer(
                fc6_name,
-                Linear(
+                nn.Linear(
                    in_dim * resolution * resolution,
                    mlp_dim,
-                    act='relu',
-                    param_attr=ParamAttr(
-                        #name='fc6_w',
-                        initializer=Xavier(fan_out=fan)),
+                    weight_attr=ParamAttr(
+                        initializer=XavierUniform(fan_out=fan)),
                    bias_attr=ParamAttr(
-                        #name='fc6_b',
-                        learning_rate=2.,
-                        regularizer=L2Decay(0.))))
+                        learning_rate=2., regularizer=L2Decay(0.))))
+            fc6_relu = self.add_sublayer(fc6_name + 'act', ReLU())
            fc7 = self.add_sublayer(
                fc7_name,
-                Linear(
+                nn.Linear(
                    mlp_dim,
                    mlp_dim,
-                    act='relu',
-                    param_attr=ParamAttr(
-                        #name='fc7_w',
-                        initializer=Xavier()),
+                    weight_attr=ParamAttr(initializer=XavierUniform()),
                    bias_attr=ParamAttr(
-                        #name='fc7_b',
-                        learning_rate=2.,
-                        regularizer=L2Decay(0.))))
+                        learning_rate=2., regularizer=L2Decay(0.))))
+            fc7_relu = self.add_sublayer(fc7_name + 'act', ReLU())
            self.fc6_list.append(fc6)
+            self.fc6_relu_list.append(fc6_relu)
            self.fc7_list.append(fc7)
+            self.fc7_relu_list.append(fc7_relu)

    def forward(self, rois_feat, stage=0):
-        rois_feat = fluid.layers.flatten(rois_feat)
+        rois_feat = paddle.flatten(rois_feat, start_axis=1, stop_axis=-1)
        fc6 = self.fc6_list[stage](rois_feat)
-        fc7 = self.fc7_list[stage](fc6)
-        return fc7
+        fc6_relu = self.fc6_relu_list[stage](fc6)
+        fc7 = self.fc7_list[stage](fc6_relu)
+        fc7_relu = self.fc7_relu_list[stage](fc7)
+        return fc7_relu

@register
-class BBoxFeat(Layer):
+class BBoxFeat(nn.Layer):
    __inject__ = ['roi_extractor', 'head_feat']

    def __init__(self, roi_extractor, head_feat):
@@ -77,7 +90,7 @@ class BBoxFeat(Layer):

@register
-class BBoxHead(Layer):
+class BBoxHead(nn.Layer):
    __shared__ = ['num_classes', 'num_stages']
    __inject__ = ['bbox_feat']
@@ -105,40 +118,30 @@ class BBoxHead(Layer):
            delta_name = 'bbox_delta_{}'.format(stage)
            bbox_score = self.add_sublayer(
                score_name,
-                fluid.dygraph.Linear(
-                    input_dim=in_feat,
-                    output_dim=1 * self.num_classes,
-                    act=None,
-                    param_attr=ParamAttr(
-                        #name='cls_score_w',
-                        initializer=Normal(
-                            loc=0.0, scale=0.01)),
+                nn.Linear(
+                    in_feat,
+                    1 * self.num_classes,
+                    weight_attr=ParamAttr(initializer=Normal(
+                        mean=0.0, std=0.01)),
                    bias_attr=ParamAttr(
-                        #name='cls_score_b',
-                        learning_rate=2.,
-                        regularizer=L2Decay(0.))))
+                        learning_rate=2., regularizer=L2Decay(0.))))
            bbox_delta = self.add_sublayer(
                delta_name,
-                fluid.dygraph.Linear(
-                    input_dim=in_feat,
-                    output_dim=4 * self.delta_dim,
-                    act=None,
-                    param_attr=ParamAttr(
-                        #name='bbox_pred_w',
-                        initializer=Normal(
-                            loc=0.0, scale=0.001)),
+                nn.Linear(
+                    in_feat,
+                    4 * self.delta_dim,
+                    weight_attr=ParamAttr(initializer=Normal(
+                        mean=0.0, std=0.001)),
                    bias_attr=ParamAttr(
-                        #name='bbox_pred_b',
-                        learning_rate=2.,
-                        regularizer=L2Decay(0.))))
+                        learning_rate=2., regularizer=L2Decay(0.))))
            self.bbox_score_list.append(bbox_score)
            self.bbox_delta_list.append(bbox_delta)

    def forward(self, body_feats, rois, spatial_scale, stage=0):
        bbox_feat = self.bbox_feat(body_feats, rois, spatial_scale, stage)
        if self.with_pool:
-            bbox_feat = fluid.layers.pool2d(
+            bbox_feat = F.pool2d(
                bbox_feat, pool_type='avg', global_pooling=True)
        bbox_head_out = []
        scores = self.bbox_score_list[stage](bbox_feat)
@@ -148,20 +151,19 @@ class BBoxHead(Layer):

    def _get_head_loss(self, score, delta, target):
        # bbox cls
-        labels_int64 = fluid.layers.cast(
-            x=target['labels_int32'], dtype='int64')
+        labels_int64 = paddle.cast(x=target['labels_int32'], dtype='int64')
        labels_int64.stop_gradient = True
-        loss_bbox_cls = fluid.layers.softmax_with_cross_entropy(
+        loss_bbox_cls = F.softmax_with_cross_entropy(
            logits=score, label=labels_int64)
-        loss_bbox_cls = fluid.layers.reduce_mean(loss_bbox_cls)
+        loss_bbox_cls = paddle.mean(loss_bbox_cls)
        # bbox reg
-        loss_bbox_reg = fluid.layers.smooth_l1(
-            x=delta,
-            y=target['bbox_targets'],
+        loss_bbox_reg = ops.smooth_l1(
+            input=delta,
+            label=target['bbox_targets'],
            inside_weight=target['bbox_inside_weights'],
            outside_weight=target['bbox_outside_weights'],
            sigma=1.0)
-        loss_bbox_reg = fluid.layers.reduce_mean(loss_bbox_reg)
+        loss_bbox_reg = paddle.mean(loss_bbox_reg)
        return loss_bbox_cls, loss_bbox_reg

    def get_loss(self, bbox_head_out, targets):
...
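Note: nn.Linear likewise no longer takes act='relu', which is why TwoFCHead now registers each activation as its own sublayer (named fc6_{stage}act and fc7_{stage}act, per the "fix relu name" commit message). A standalone sketch of the pattern, with illustrative shapes:

import paddle
import paddle.nn as nn

fc6 = nn.Linear(256 * 7 * 7, 1024)  # illustrative in/out dims
fc6_relu = nn.ReLU()                # explicit activation sublayer

rois_feat = paddle.rand([8, 256, 7, 7])
x = paddle.flatten(rois_feat, start_axis=1, stop_axis=-1)  # [8, 12544]
out = fc6_relu(fc6(x))                                     # [8, 1024]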
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
-import paddle.fluid as fluid
-from paddle.fluid.dygraph import Layer, Sequential
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.initializer import MSRA
-from paddle.fluid.regularizer import L2Decay
-from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Conv2DTranspose
+import paddle
+import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.nn import Layer, Sequential
+from paddle.nn import Conv2D, Conv2DTranspose, ReLU
+from paddle.nn.initializer import KaimingNormal
+from paddle.regularizer import L2Decay
from ppdet.core.workspace import register
+from ppdet.modeling import ops

@register
@@ -37,35 +52,27 @@ class MaskFeat(Layer):
            mask_conv.add_sublayer(
                conv_name,
                Conv2D(
-                    num_channels=feat_in if j == 0 else feat_out,
-                    num_filters=feat_out,
-                    filter_size=3,
-                    act='relu',
+                    in_channels=feat_in if j == 0 else feat_out,
+                    out_channels=feat_out,
+                    kernel_size=3,
                    padding=1,
-                    param_attr=ParamAttr(
-                        #name=conv_name+'_w',
-                        initializer=MSRA(
-                            uniform=False, fan_in=fan_conv)),
+                    weight_attr=ParamAttr(
+                        initializer=KaimingNormal(fan_in=fan_conv)),
                    bias_attr=ParamAttr(
-                        #name=conv_name+'_b',
-                        learning_rate=2.,
-                        regularizer=L2Decay(0.))))
+                        learning_rate=2., regularizer=L2Decay(0.))))
+            mask_conv.add_sublayer(conv_name + 'act', ReLU())
        mask_conv.add_sublayer(
            'conv5_mask',
            Conv2DTranspose(
-                num_channels=self.feat_in,
-                num_filters=self.feat_out,
-                filter_size=2,
+                in_channels=self.feat_in,
+                out_channels=self.feat_out,
+                kernel_size=2,
                stride=2,
-                act='relu',
-                param_attr=ParamAttr(
-                    #name='conv5_mask_w',
-                    initializer=MSRA(
-                        uniform=False, fan_in=fan_deconv)),
+                weight_attr=ParamAttr(
+                    initializer=KaimingNormal(fan_in=fan_deconv)),
                bias_attr=ParamAttr(
-                    #name='conv5_mask_b',
-                    learning_rate=2.,
-                    regularizer=L2Decay(0.))))
+                    learning_rate=2., regularizer=L2Decay(0.))))
+        mask_conv.add_sublayer('conv5_mask' + 'act', ReLU())
        upsample = self.add_sublayer(name, mask_conv)
        self.upsample_module.append(upsample)
@@ -77,7 +84,7 @@ class MaskFeat(Layer):
                spatial_scale,
                stage=0):
        if self.share_bbox_feat:
-            rois_feat = fluid.layers.gather(bbox_feat, mask_index)
+            rois_feat = paddle.gather(bbox_feat, mask_index)
        else:
            rois_feat = self.mask_roi_extractor(body_feats, bboxes,
                                                spatial_scale)
@@ -107,18 +114,14 @@ class MaskHead(Layer):
            self.mask_fcn_logits.append(
                self.add_sublayer(
                    name,
-                    fluid.dygraph.Conv2D(
-                        num_channels=self.feat_in,
-                        num_filters=self.num_classes,
-                        filter_size=1,
-                        param_attr=ParamAttr(
-                            #name='mask_fcn_logits_w',
-                            initializer=MSRA(
-                                uniform=False, fan_in=self.num_classes)),
+                    Conv2D(
+                        in_channels=self.feat_in,
+                        out_channels=self.num_classes,
+                        kernel_size=1,
+                        weight_attr=ParamAttr(initializer=KaimingNormal(
+                            fan_in=self.num_classes)),
                        bias_attr=ParamAttr(
-                            #name='mask_fcn_logits_b',
-                            learning_rate=2.,
-                            regularizer=L2Decay(0.0)))))
+                            learning_rate=2., regularizer=L2Decay(0.0)))))

    def forward_train(self,
                      body_feats,
@@ -150,14 +153,13 @@ class MaskHead(Layer):
            for idx, num in enumerate(bbox_num):
                for n in range(num):
                    im_info_expand.append(im_info[idx, -1])
-            im_info_expand = fluid.layers.concat(im_info_expand)
-            scaled_bbox = fluid.layers.elementwise_mul(
-                bbox[:, 2:], im_info_expand, axis=0)
+            im_info_expand = paddle.concat(im_info_expand)
+            scaled_bbox = paddle.multiply(bbox[:, 2:], im_info_expand, axis=0)
            scaled_bboxes = (scaled_bbox, bbox_num)
        mask_feat = self.mask_feat(body_feats, scaled_bboxes, bbox_feat,
                                   mask_index, spatial_scale, stage)
        mask_logit = self.mask_fcn_logits[stage](mask_feat)
-        mask_head_out = fluid.layers.sigmoid(mask_logit)
+        mask_head_out = F.sigmoid(mask_logit)
        return mask_head_out

    def forward(self,
@@ -179,12 +181,14 @@ class MaskHead(Layer):
        return mask_head_out

    def get_loss(self, mask_head_out, mask_target):
-        mask_logits = fluid.layers.flatten(mask_head_out)
-        mask_label = fluid.layers.cast(x=mask_target, dtype='float32')
+        mask_logits = paddle.flatten(mask_head_out, start_axis=1, stop_axis=-1)
+        mask_label = paddle.cast(x=mask_target, dtype='float32')
        mask_label.stop_gradient = True
-        loss_mask = fluid.layers.sigmoid_cross_entropy_with_logits(
-            x=mask_logits, label=mask_label, ignore_index=-1, normalize=True)
-        loss_mask = fluid.layers.reduce_sum(loss_mask)
+        loss_mask = ops.sigmoid_cross_entropy_with_logits(
+            input=mask_logits,
+            label=mask_label,
+            ignore_index=-1,
+            normalize=True)
+        loss_mask = paddle.sum(loss_mask)
        return {'loss_mask': loss_mask}
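Note: the mask loss now routes through the ops.sigmoid_cross_entropy_with_logits compat wrapper (defined in ops.py at the end of this diff), which masks out ignore_index entries and, with normalize=True, divides by the number of valid elements before the final sum. A sketch with illustrative shapes:

import paddle
from ppdet.modeling import ops

mask_logits = paddle.rand([4, 28 * 28])            # flattened per-RoI logits
mask_label = paddle.ones([4, 28 * 28], 'float32')  # illustrative targets
loss = ops.sigmoid_cross_entropy_with_logits(
    input=mask_logits, label=mask_label, ignore_index=-1, normalize=True)
loss_mask = paddle.sum(loss)  # scalar, as in MaskHead.get_loss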
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
-import paddle.fluid as fluid
-from paddle.fluid.dygraph import Layer
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.initializer import Normal
-from paddle.fluid.regularizer import L2Decay
-from paddle.fluid.dygraph.nn import Conv2D
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.nn.initializer import Normal
+from paddle.regularizer import L2Decay
+from paddle.nn import Conv2D
from ppdet.core.workspace import register
+from ppdet.modeling import ops

@register
-class RPNFeat(Layer):
+class RPNFeat(nn.Layer):
    def __init__(self, feat_in=1024, feat_out=1024):
        super(RPNFeat, self).__init__()
        # rpn feat is shared with each level
        self.rpn_conv = Conv2D(
-            num_channels=feat_in,
-            num_filters=feat_out,
-            filter_size=3,
+            in_channels=feat_in,
+            out_channels=feat_out,
+            kernel_size=3,
            padding=1,
-            act='relu',
-            param_attr=ParamAttr(
-                #name="conv_rpn_fpn2_w",
-                initializer=Normal(
-                    loc=0., scale=0.01)),
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0., std=0.01)),
            bias_attr=ParamAttr(
-                #name="conv_rpn_fpn2_b",
-                learning_rate=2.,
-                regularizer=L2Decay(0.)))
+                learning_rate=2., regularizer=L2Decay(0.)))

    def forward(self, inputs, feats):
        rpn_feats = []
        for feat in feats:
-            rpn_feats.append(self.rpn_conv(feat))
+            rpn_feats.append(F.relu(self.rpn_conv(feat)))
        return rpn_feats

@register
-class RPNHead(Layer):
+class RPNHead(nn.Layer):
    __inject__ = ['rpn_feat']

    def __init__(self, rpn_feat, anchor_per_position=15, rpn_channel=1024):
@@ -46,35 +58,25 @@ class RPNHead(Layer):
        # rpn head is shared with each level
        # rpn roi classification scores
        self.rpn_rois_score = Conv2D(
-            num_channels=rpn_channel,
-            num_filters=anchor_per_position,
-            filter_size=1,
+            in_channels=rpn_channel,
+            out_channels=anchor_per_position,
+            kernel_size=1,
            padding=0,
-            act=None,
-            param_attr=ParamAttr(
-                #name="rpn_cls_logits_fpn2_w",
-                initializer=Normal(
-                    loc=0., scale=0.01)),
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0., std=0.01)),
            bias_attr=ParamAttr(
-                #name="rpn_cls_logits_fpn2_b",
-                learning_rate=2.,
-                regularizer=L2Decay(0.)))
+                learning_rate=2., regularizer=L2Decay(0.)))
        # rpn roi bbox regression deltas
        self.rpn_rois_delta = Conv2D(
-            num_channels=rpn_channel,
-            num_filters=4 * anchor_per_position,
-            filter_size=1,
+            in_channels=rpn_channel,
+            out_channels=4 * anchor_per_position,
+            kernel_size=1,
            padding=0,
-            act=None,
-            param_attr=ParamAttr(
-                #name="rpn_bbox_pred_fpn2_w",
-                initializer=Normal(
-                    loc=0., scale=0.01)),
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0., std=0.01)),
            bias_attr=ParamAttr(
-                #name="rpn_bbox_pred_fpn2_b",
-                learning_rate=2.,
-                regularizer=L2Decay(0.)))
+                learning_rate=2., regularizer=L2Decay(0.)))

    def forward(self, inputs, feats):
        rpn_feats = self.rpn_feat(inputs, feats)
@@ -87,28 +89,26 @@ class RPNHead(Layer):

    def get_loss(self, loss_inputs):
        # cls loss
-        score_tgt = fluid.layers.cast(
+        score_tgt = paddle.cast(
            x=loss_inputs['rpn_score_target'], dtype='float32')
        score_tgt.stop_gradient = True
-        loss_rpn_cls = fluid.layers.sigmoid_cross_entropy_with_logits(
-            x=loss_inputs['rpn_score_pred'], label=score_tgt)
-        loss_rpn_cls = fluid.layers.reduce_mean(
-            loss_rpn_cls, name='loss_rpn_cls')
+        loss_rpn_cls = ops.sigmoid_cross_entropy_with_logits(
+            input=loss_inputs['rpn_score_pred'], label=score_tgt)
+        loss_rpn_cls = paddle.mean(loss_rpn_cls, name='loss_rpn_cls')

        # reg loss
-        loc_tgt = fluid.layers.cast(
-            x=loss_inputs['rpn_rois_target'], dtype='float32')
+        loc_tgt = paddle.cast(x=loss_inputs['rpn_rois_target'], dtype='float32')
        loc_tgt.stop_gradient = True
-        loss_rpn_reg = fluid.layers.smooth_l1(
-            x=loss_inputs['rpn_rois_pred'],
-            y=loc_tgt,
-            sigma=3.0,
+        loss_rpn_reg = ops.smooth_l1(
+            input=loss_inputs['rpn_rois_pred'],
+            label=loc_tgt,
            inside_weight=loss_inputs['rpn_rois_weight'],
-            outside_weight=loss_inputs['rpn_rois_weight'])
-        loss_rpn_reg = fluid.layers.reduce_sum(loss_rpn_reg)
-        score_shape = fluid.layers.shape(score_tgt)
-        score_shape = fluid.layers.cast(x=score_shape, dtype='float32')
-        norm = fluid.layers.reduce_prod(score_shape)
+            outside_weight=loss_inputs['rpn_rois_weight'],
+            sigma=3.0, )
+        loss_rpn_reg = paddle.sum(loss_rpn_reg)
+        score_shape = paddle.shape(score_tgt)
+        score_shape = paddle.cast(score_shape, dtype='float32')
+        norm = paddle.prod(score_shape)
        norm.stop_gradient = True
        loss_rpn_reg = loss_rpn_reg / norm
...
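Note: the RPN regression loss is summed and then normalized by the total number of score targets, computed as the product of the score tensor's shape. A minimal sketch of that normalization (values illustrative):

import paddle

score_tgt = paddle.ones([2, 15, 32, 32])  # illustrative anchor score targets
loss_rpn_reg = paddle.to_tensor(120.0)    # illustrative summed reg loss
norm = paddle.prod(paddle.cast(paddle.shape(score_tgt), 'float32'))
norm.stop_gradient = True
loss_rpn_reg = loss_rpn_reg / norm        # average over all anchor positions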
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
+import paddle
import paddle.fluid as fluid
-from paddle.fluid.dygraph import Layer
-from paddle.fluid.dygraph import Conv2D, Pool2D, BatchNorm
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.initializer import Xavier
+import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.nn import Layer
+from paddle.nn import Conv2D
+from paddle.nn.initializer import XavierUniform
from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
@@ -32,33 +48,27 @@ class FPN(Layer):
            lateral = self.add_sublayer(
                lateral_name,
                Conv2D(
-                    num_channels=in_c,
-                    num_filters=out_channel,
-                    filter_size=1,
-                    param_attr=ParamAttr(
-                        #name=lateral_name+'_w',
-                        initializer=Xavier(fan_out=in_c)),
+                    in_channels=in_c,
+                    out_channels=out_channel,
+                    kernel_size=1,
+                    weight_attr=ParamAttr(
+                        initializer=XavierUniform(fan_out=in_c)),
                    bias_attr=ParamAttr(
-                        #name=lateral_name+'_b',
-                        learning_rate=2.,
-                        regularizer=L2Decay(0.))))
+                        learning_rate=2., regularizer=L2Decay(0.))))
            self.lateral_convs.append(lateral)

            fpn_name = 'fpn_res{}_sum'.format(i + 2)
            fpn_conv = self.add_sublayer(
                fpn_name,
                Conv2D(
-                    num_channels=out_channel,
-                    num_filters=out_channel,
-                    filter_size=3,
+                    in_channels=out_channel,
+                    out_channels=out_channel,
+                    kernel_size=3,
                    padding=1,
-                    param_attr=ParamAttr(
-                        #name=fpn_name+'_w',
-                        initializer=Xavier(fan_out=fan)),
+                    weight_attr=ParamAttr(
+                        initializer=XavierUniform(fan_out=fan)),
                    bias_attr=ParamAttr(
-                        #name=fpn_name+'_b',
-                        learning_rate=2.,
-                        regularizer=L2Decay(0.))))
+                        learning_rate=2., regularizer=L2Decay(0.))))
            self.fpn_convs.append(fpn_conv)

        self.min_level = min_level
@@ -71,14 +81,17 @@ class FPN(Layer):
            laterals.append(self.lateral_convs[lvl](body_feats[lvl]))

        for lvl in range(self.max_level - 1, self.min_level, -1):
-            upsample = fluid.layers.resize_nearest(laterals[lvl], scale=2.)
+            upsample = F.interpolate(
+                laterals[lvl],
+                scale_factor=2.,
+                mode='nearest', )
            laterals[lvl - 1] = laterals[lvl - 1] + upsample

        fpn_output = []
        for lvl in range(self.min_level, self.max_level):
            fpn_output.append(self.fpn_convs[lvl](laterals[lvl]))

-        extension = fluid.layers.pool2d(fpn_output[-1], 1, 'max', pool_stride=2)
+        extension = F.max_pool2d(fpn_output[-1], 1, stride=2)

        spatial_scale = self.spatial_scale + [self.spatial_scale[-1] * 0.5]
        fpn_output.append(extension)
...
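Note: fluid.layers.resize_nearest(..., scale=2.) maps to F.interpolate with mode='nearest' and scale_factor=2., and the extra coarsest FPN level is built with a stride-2, kernel-1 max pool. A sketch of the two calls (shapes illustrative, assuming Paddle 2.0+):

import paddle
import paddle.nn.functional as F

lateral_p5 = paddle.rand([1, 256, 16, 16])
lateral_p4 = paddle.rand([1, 256, 32, 32])
upsample = F.interpolate(lateral_p5, scale_factor=2., mode='nearest')
merged = lateral_p4 + upsample                 # [1, 256, 32, 32]
extension = F.max_pool2d(merged, 1, stride=2)  # [1, 256, 16, 16]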
@@ -13,6 +13,7 @@
# limitations under the License.

import paddle
+import paddle.nn.functional as F
from paddle.fluid.framework import Variable, in_dygraph_mode
from paddle.fluid import core
@@ -1509,3 +1510,30 @@ def generate_proposals(scores,
        return rpn_rois, rpn_roi_probs, rpn_rois_num
    else:
        return rpn_rois, rpn_roi_probs
+
+
+def sigmoid_cross_entropy_with_logits(input,
+                                      label,
+                                      ignore_index=-100,
+                                      normalize=False):
+    output = F.binary_cross_entropy_with_logits(input, label, reduction='none')
+    mask_tensor = paddle.cast(label != ignore_index, 'float32')
+    output = paddle.multiply(output, mask_tensor)
+    output = paddle.reshape(output, shape=[output.shape[0], -1])
+    if normalize:
+        sum_valid_mask = paddle.sum(mask_tensor)
+        output = output / sum_valid_mask
+    return output
+
+
+def smooth_l1(input, label, inside_weight=None, outside_weight=None,
+              sigma=None):
+    input_new = paddle.multiply(input, inside_weight)
+    label_new = paddle.multiply(label, inside_weight)
+    delta = 1 / (sigma * sigma)
+    out = F.smooth_l1_loss(input_new, label_new, reduction='none', delta=delta)
+    out = paddle.multiply(out, outside_weight)
+    out = out / delta
+    out = paddle.reshape(out, shape=[out.shape[0], -1])
+    out = paddle.sum(out, axis=1)
+    return out
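Note: ops.smooth_l1 recovers the fluid.layers.smooth_l1 semantics from F.smooth_l1_loss. With delta = 1/sigma^2, Paddle's smooth_l1_loss gives 0.5*z^2 for |z| < delta and delta*|z| - 0.5*delta^2 otherwise; dividing by delta yields the classic branches 0.5*(sigma*z)^2 and |z| - 0.5/sigma^2, summed per row. A minimal check (values illustrative, assuming Paddle 2.0+):

import paddle
from ppdet.modeling import ops

pred = paddle.to_tensor([[0.1, 2.0]])
label = paddle.zeros([1, 2])
w = paddle.ones_like(pred)
# sigma=1.0 -> delta=1.0: 0.5*0.1^2 + (2.0 - 0.5) = 1.505
out = ops.smooth_l1(pred, label, inside_weight=w, outside_weight=w, sigma=1.0)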