Unverified commit ee911f6b authored by Guanghua Yu, committed by GitHub

[Dygraph]update mask_rcnn_fpn model (#1718)

* update mask_rcnn_fpn model

* update smooth_l1_loss

* update smooth_l1 and sigmoid_cross_entropy_with_logits

* fix relu name
Parent 31628f60
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
import paddle
from ppdet.core.workspace import register
from .meta_arch import BaseArch
@@ -114,7 +128,7 @@ class MaskRCNN(BaseArch):
loss_mask = self.mask_head.get_loss(self.mask_head_out, mask_targets)
loss.update(loss_mask)
total_loss = fluid.layers.sums(list(loss.values()))
total_loss = paddle.add_n(list(loss.values()))
loss.update({'loss': total_loss})
return loss
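
# Editor's sketch (not part of the commit): paddle.add_n is a drop-in
# replacement for the deprecated fluid.layers.sums; both reduce a list of
# tensors by elementwise addition. Values below are made up for illustration.
import paddle
loss = {'loss_rpn': paddle.to_tensor(0.5), 'loss_bbox': paddle.to_tensor(1.2)}
total = paddle.add_n(list(loss.values()))  # Tensor holding 1.7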
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import Layer, Sequential
from paddle.fluid.dygraph import Conv2D, Pool2D, BatchNorm
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
from paddle import ParamAttr
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm
from paddle.nn import MaxPool2D
from ppdet.core.workspace import register, serializable
from paddle.fluid.regularizer import L2Decay
from paddle.regularizer import L2Decay
from .name_adapter import NameAdapter
from numbers import Integral
class ConvNormLayer(Layer):
class ConvNormLayer(nn.Layer):
def __init__(self,
ch_in,
ch_out,
@@ -24,19 +41,18 @@ class ConvNormLayer(Layer):
lr=1.0,
name=None):
super(ConvNormLayer, self).__init__()
assert norm_type in ['bn', 'affine_channel']
assert norm_type in ['bn', 'sync_bn']
self.norm_type = norm_type
self.act = act
self.conv = Conv2D(
num_channels=ch_in,
num_filters=ch_out,
filter_size=filter_size,
in_channels=ch_in,
out_channels=ch_out,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=1,
act=None,
param_attr=ParamAttr(
weight_attr=ParamAttr(
learning_rate=lr, name=name + "_weights"),
bias_attr=False)
@@ -53,30 +69,16 @@ class ConvNormLayer(Layer):
name=bn_name + "_offset",
trainable=False if freeze_norm else True)
if norm_type in ['bn', 'sync_bn']:
global_stats = True if freeze_norm else False
self.norm = BatchNorm(
num_channels=ch_out,
act=act,
param_attr=param_attr,
bias_attr=bias_attr,
use_global_stats=global_stats,
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
norm_params = self.norm.parameters()
elif norm_type == 'affine_channel':
self.scale = fluid.layers.create_parameter(
shape=[ch_out],
dtype='float32',
attr=param_attr,
default_initializer=Constant(1.))
self.offset = fluid.layers.create_parameter(
shape=[ch_out],
dtype='float32',
attr=bias_attr,
default_initializer=Constant(0.))
norm_params = [self.scale, self.offset]
global_stats = True if freeze_norm else False
self.norm = BatchNorm(
ch_out,
act=act,
param_attr=param_attr,
bias_attr=bias_attr,
use_global_stats=global_stats,
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
norm_params = self.norm.parameters()
if freeze_norm:
for param in norm_params:
@@ -86,13 +88,10 @@ class ConvNormLayer(Layer):
out = self.conv(inputs)
if self.norm_type == 'bn':
out = self.norm(out)
elif self.norm_type == 'affine_channel':
out = fluid.layers.affine_channel(
out, scale=self.scale, bias=self.offset, act=self.act)
return out
class BottleNeck(Layer):
class BottleNeck(nn.Layer):
def __init__(self,
ch_in,
ch_out,
@@ -176,12 +175,13 @@ class BottleNeck(Layer):
out = self.branch2b(out)
out = self.branch2c(out)
out = fluid.layers.elementwise_add(x=short, y=out, act='relu')
out = paddle.add(x=short, y=out)
out = F.relu(out)
return out
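
# Editor's sketch (not part of the commit): the fused
# fluid.layers.elementwise_add(x, y, act='relu') is equivalent to a plain
# add followed by an explicit ReLU, as done above.
import paddle
import paddle.nn.functional as F
short = paddle.to_tensor([-1.0, 2.0])
out = paddle.to_tensor([0.5, 1.0])
res = F.relu(paddle.add(short, out))  # [0.0, 3.0]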
class Blocks(Layer):
class Blocks(nn.Layer):
def __init__(self,
ch_in,
ch_out,
@@ -226,7 +226,7 @@ ResNet_cfg = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}
@register
@serializable
class ResNet(Layer):
class ResNet(nn.Layer):
def __init__(self,
depth=50,
variant='b',
@@ -265,7 +265,7 @@ class ResNet(Layer):
]
else:
conv_def = [[3, 64, 7, 2, conv1_name]]
self.conv1 = Sequential()
self.conv1 = nn.Sequential()
for (c_in, c_out, k, s, _name) in conv_def:
self.conv1.add_sublayer(
_name,
@@ -282,8 +282,7 @@ class ResNet(Layer):
lr=lr_mult,
name=_name))
self.pool = Pool2D(
pool_type='max', pool_size=3, pool_stride=2, pool_padding=1)
self.pool = MaxPool2D(kernel_size=3, stride=2, padding=1)
ch_in_list = [64, 256, 512, 1024]
ch_out_list = [64, 128, 256, 512]
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Layer
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Xavier
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from ppdet.core.workspace import register
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import ReLU
from paddle.nn.initializer import Normal, XavierUniform
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register
from ppdet.modeling import ops
@register
class TwoFCHead(Layer):
class TwoFCHead(nn.Layer):
__shared__ = ['num_stages']
@@ -21,48 +35,47 @@ class TwoFCHead(Layer):
self.num_stages = num_stages
fan = in_dim * resolution * resolution
self.fc6_list = []
self.fc6_relu_list = []
self.fc7_list = []
self.fc7_relu_list = []
for stage in range(num_stages):
fc6_name = 'fc6_{}'.format(stage)
fc7_name = 'fc7_{}'.format(stage)
fc6 = self.add_sublayer(
fc6_name,
Linear(
nn.Linear(
in_dim * resolution * resolution,
mlp_dim,
act='relu',
param_attr=ParamAttr(
#name='fc6_w',
initializer=Xavier(fan_out=fan)),
weight_attr=ParamAttr(
initializer=XavierUniform(fan_out=fan)),
bias_attr=ParamAttr(
#name='fc6_b',
learning_rate=2.,
regularizer=L2Decay(0.))))
learning_rate=2., regularizer=L2Decay(0.))))
fc6_relu = self.add_sublayer(fc6_name + 'act', ReLU())
fc7 = self.add_sublayer(
fc7_name,
Linear(
nn.Linear(
mlp_dim,
mlp_dim,
act='relu',
param_attr=ParamAttr(
#name='fc7_w',
initializer=Xavier()),
weight_attr=ParamAttr(initializer=XavierUniform()),
bias_attr=ParamAttr(
#name='fc7_b',
learning_rate=2.,
regularizer=L2Decay(0.))))
learning_rate=2., regularizer=L2Decay(0.))))
fc7_relu = self.add_sublayer(fc7_name + 'act', ReLU())
self.fc6_list.append(fc6)
self.fc6_relu_list.append(fc6_relu)
self.fc7_list.append(fc7)
self.fc7_relu_list.append(fc7_relu)
def forward(self, rois_feat, stage=0):
rois_feat = fluid.layers.flatten(rois_feat)
rois_feat = paddle.flatten(rois_feat, start_axis=1, stop_axis=-1)
fc6 = self.fc6_list[stage](rois_feat)
fc7 = self.fc7_list[stage](fc6)
return fc7
fc6_relu = self.fc6_relu_list[stage](fc6)
fc7 = self.fc7_list[stage](fc6_relu)
fc7_relu = self.fc7_relu_list[stage](fc7)
return fc7_relu
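
# Editor's sketch (names hypothetical, not part of the commit): nn.Linear
# has no `act` argument, so the activations become explicit ReLU sublayers
# registered under their own names, matching the pattern above.
import paddle
import paddle.nn as nn

class TinyHead(nn.Layer):
    def __init__(self):
        super(TinyHead, self).__init__()
        self.fc = self.add_sublayer('fc6_0', nn.Linear(16, 8))
        self.fc_relu = self.add_sublayer('fc6_0act', nn.ReLU())

    def forward(self, x):
        return self.fc_relu(self.fc(x))

feat = TinyHead()(paddle.rand([4, 16]))  # shape [4, 8]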
@register
class BBoxFeat(Layer):
class BBoxFeat(nn.Layer):
__inject__ = ['roi_extractor', 'head_feat']
def __init__(self, roi_extractor, head_feat):
@@ -77,7 +90,7 @@ class BBoxFeat(Layer):
@register
class BBoxHead(Layer):
class BBoxHead(nn.Layer):
__shared__ = ['num_classes', 'num_stages']
__inject__ = ['bbox_feat']
@@ -105,40 +118,30 @@ class BBoxHead(Layer):
delta_name = 'bbox_delta_{}'.format(stage)
bbox_score = self.add_sublayer(
score_name,
fluid.dygraph.Linear(
input_dim=in_feat,
output_dim=1 * self.num_classes,
act=None,
param_attr=ParamAttr(
#name='cls_score_w',
initializer=Normal(
loc=0.0, scale=0.01)),
nn.Linear(
in_feat,
1 * self.num_classes,
weight_attr=ParamAttr(initializer=Normal(
mean=0.0, std=0.01)),
bias_attr=ParamAttr(
#name='cls_score_b',
learning_rate=2.,
regularizer=L2Decay(0.))))
learning_rate=2., regularizer=L2Decay(0.))))
bbox_delta = self.add_sublayer(
delta_name,
fluid.dygraph.Linear(
input_dim=in_feat,
output_dim=4 * self.delta_dim,
act=None,
param_attr=ParamAttr(
#name='bbox_pred_w',
initializer=Normal(
loc=0.0, scale=0.001)),
nn.Linear(
in_feat,
4 * self.delta_dim,
weight_attr=ParamAttr(initializer=Normal(
mean=0.0, std=0.001)),
bias_attr=ParamAttr(
#name='bbox_pred_b',
learning_rate=2.,
regularizer=L2Decay(0.))))
learning_rate=2., regularizer=L2Decay(0.))))
self.bbox_score_list.append(bbox_score)
self.bbox_delta_list.append(bbox_delta)
def forward(self, body_feats, rois, spatial_scale, stage=0):
bbox_feat = self.bbox_feat(body_feats, rois, spatial_scale, stage)
if self.with_pool:
bbox_feat = fluid.layers.pool2d(
bbox_feat = F.pool2d(
bbox_feat, pool_type='avg', global_pooling=True)
bbox_head_out = []
scores = self.bbox_score_list[stage](bbox_feat)
@@ -148,20 +151,19 @@ class BBoxHead(Layer):
def _get_head_loss(self, score, delta, target):
# bbox cls
labels_int64 = fluid.layers.cast(
x=target['labels_int32'], dtype='int64')
labels_int64 = paddle.cast(x=target['labels_int32'], dtype='int64')
labels_int64.stop_gradient = True
loss_bbox_cls = fluid.layers.softmax_with_cross_entropy(
loss_bbox_cls = F.softmax_with_cross_entropy(
logits=score, label=labels_int64)
loss_bbox_cls = fluid.layers.reduce_mean(loss_bbox_cls)
loss_bbox_cls = paddle.mean(loss_bbox_cls)
# bbox reg
loss_bbox_reg = fluid.layers.smooth_l1(
x=delta,
y=target['bbox_targets'],
loss_bbox_reg = ops.smooth_l1(
input=delta,
label=target['bbox_targets'],
inside_weight=target['bbox_inside_weights'],
outside_weight=target['bbox_outside_weights'],
sigma=1.0)
loss_bbox_reg = fluid.layers.reduce_mean(loss_bbox_reg)
loss_bbox_reg = paddle.mean(loss_bbox_reg)
return loss_bbox_cls, loss_bbox_reg
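
# Editor's sketch (not part of the commit) of the classification half of
# _get_head_loss: softmax cross-entropy over per-RoI class scores, then a
# mean over RoIs. Shapes are made up for illustration.
import paddle
import paddle.nn.functional as F
scores = paddle.rand([8, 81])                              # 8 RoIs, 81 classes
labels = paddle.randint(0, 81, shape=[8, 1], dtype='int64')
loss_cls = paddle.mean(
    F.softmax_with_cross_entropy(logits=scores, label=labels))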
def get_loss(self, bbox_head_out, targets):
......
import paddle.fluid as fluid
from paddle.fluid.dygraph import Layer, Sequential
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import MSRA
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Conv2DTranspose
import paddle
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn import Layer, Sequential
from paddle.nn import Conv2D, Conv2DTranspose, ReLU
from paddle.nn.initializer import KaimingNormal
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register
from ppdet.modeling import ops
@register
@@ -37,35 +52,27 @@ class MaskFeat(Layer):
mask_conv.add_sublayer(
conv_name,
Conv2D(
num_channels=feat_in if j == 0 else feat_out,
num_filters=feat_out,
filter_size=3,
act='relu',
in_channels=feat_in if j == 0 else feat_out,
out_channels=feat_out,
kernel_size=3,
padding=1,
param_attr=ParamAttr(
#name=conv_name+'_w',
initializer=MSRA(
uniform=False, fan_in=fan_conv)),
weight_attr=ParamAttr(
initializer=KaimingNormal(fan_in=fan_conv)),
bias_attr=ParamAttr(
#name=conv_name+'_b',
learning_rate=2.,
regularizer=L2Decay(0.))))
learning_rate=2., regularizer=L2Decay(0.))))
mask_conv.add_sublayer(conv_name + 'act', ReLU())
mask_conv.add_sublayer(
'conv5_mask',
Conv2DTranspose(
num_channels=self.feat_in,
num_filters=self.feat_out,
filter_size=2,
in_channels=self.feat_in,
out_channels=self.feat_out,
kernel_size=2,
stride=2,
act='relu',
param_attr=ParamAttr(
#name='conv5_mask_w',
initializer=MSRA(
uniform=False, fan_in=fan_deconv)),
weight_attr=ParamAttr(
initializer=KaimingNormal(fan_in=fan_deconv)),
bias_attr=ParamAttr(
#name='conv5_mask_b',
learning_rate=2.,
regularizer=L2Decay(0.))))
learning_rate=2., regularizer=L2Decay(0.))))
mask_conv.add_sublayer('conv5_mask' + 'act', ReLU())
upsample = self.add_sublayer(name, mask_conv)
self.upsample_module.append(upsample)
@@ -77,7 +84,7 @@ class MaskFeat(Layer):
spatial_scale,
stage=0):
if self.share_bbox_feat:
rois_feat = fluid.layers.gather(bbox_feat, mask_index)
rois_feat = paddle.gather(bbox_feat, mask_index)
else:
rois_feat = self.mask_roi_extractor(body_feats, bboxes,
spatial_scale)
@@ -107,18 +114,14 @@ class MaskHead(Layer):
self.mask_fcn_logits.append(
self.add_sublayer(
name,
fluid.dygraph.Conv2D(
num_channels=self.feat_in,
num_filters=self.num_classes,
filter_size=1,
param_attr=ParamAttr(
#name='mask_fcn_logits_w',
initializer=MSRA(
uniform=False, fan_in=self.num_classes)),
Conv2D(
in_channels=self.feat_in,
out_channels=self.num_classes,
kernel_size=1,
weight_attr=ParamAttr(initializer=KaimingNormal(
fan_in=self.num_classes)),
bias_attr=ParamAttr(
#name='mask_fcn_logits_b',
learning_rate=2.,
regularizer=L2Decay(0.0)))))
learning_rate=2., regularizer=L2Decay(0.0)))))
def forward_train(self,
body_feats,
@@ -150,14 +153,13 @@ class MaskHead(Layer):
for idx, num in enumerate(bbox_num):
for n in range(num):
im_info_expand.append(im_info[idx, -1])
im_info_expand = fluid.layers.concat(im_info_expand)
scaled_bbox = fluid.layers.elementwise_mul(
bbox[:, 2:], im_info_expand, axis=0)
im_info_expand = paddle.concat(im_info_expand)
scaled_bbox = paddle.multiply(bbox[:, 2:], im_info_expand, axis=0)
scaled_bboxes = (scaled_bbox, bbox_num)
mask_feat = self.mask_feat(body_feats, scaled_bboxes, bbox_feat,
mask_index, spatial_scale, stage)
mask_logit = self.mask_fcn_logits[stage](mask_feat)
mask_head_out = fluid.layers.sigmoid(mask_logit)
mask_head_out = F.sigmoid(mask_logit)
return mask_head_out
def forward(self,
@@ -179,12 +181,14 @@ class MaskHead(Layer):
return mask_head_out
def get_loss(self, mask_head_out, mask_target):
mask_logits = fluid.layers.flatten(mask_head_out)
mask_label = fluid.layers.cast(x=mask_target, dtype='float32')
mask_logits = paddle.flatten(mask_head_out, start_axis=1, stop_axis=-1)
mask_label = paddle.cast(x=mask_target, dtype='float32')
mask_label.stop_gradient = True
loss_mask = fluid.layers.sigmoid_cross_entropy_with_logits(
x=mask_logits, label=mask_label, ignore_index=-1, normalize=True)
loss_mask = fluid.layers.reduce_sum(loss_mask)
loss_mask = ops.sigmoid_cross_entropy_with_logits(
input=mask_logits,
label=mask_label,
ignore_index=-1,
normalize=True)
loss_mask = paddle.sum(loss_mask)
return {'loss_mask': loss_mask}
import paddle.fluid as fluid
from paddle.fluid.dygraph import Layer
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.dygraph.nn import Conv2D
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal
from paddle.regularizer import L2Decay
from paddle.nn import Conv2D
from ppdet.core.workspace import register
from ppdet.modeling import ops
@register
class RPNFeat(Layer):
class RPNFeat(nn.Layer):
def __init__(self, feat_in=1024, feat_out=1024):
super(RPNFeat, self).__init__()
# rpn feat is shared with each level
self.rpn_conv = Conv2D(
num_channels=feat_in,
num_filters=feat_out,
filter_size=3,
in_channels=feat_in,
out_channels=feat_out,
kernel_size=3,
padding=1,
act='relu',
param_attr=ParamAttr(
#name="conv_rpn_fpn2_w",
initializer=Normal(
loc=0., scale=0.01)),
weight_attr=ParamAttr(initializer=Normal(
mean=0., std=0.01)),
bias_attr=ParamAttr(
#name="conv_rpn_fpn2_b",
learning_rate=2.,
regularizer=L2Decay(0.)))
learning_rate=2., regularizer=L2Decay(0.)))
def forward(self, inputs, feats):
rpn_feats = []
for feat in feats:
rpn_feats.append(self.rpn_conv(feat))
rpn_feats.append(F.relu(self.rpn_conv(feat)))
return rpn_feats
@register
class RPNHead(Layer):
class RPNHead(nn.Layer):
__inject__ = ['rpn_feat']
def __init__(self, rpn_feat, anchor_per_position=15, rpn_channel=1024):
@@ -46,35 +58,25 @@ class RPNHead(Layer):
# rpn head is shared with each level
# rpn roi classification scores
self.rpn_rois_score = Conv2D(
num_channels=rpn_channel,
num_filters=anchor_per_position,
filter_size=1,
in_channels=rpn_channel,
out_channels=anchor_per_position,
kernel_size=1,
padding=0,
act=None,
param_attr=ParamAttr(
#name="rpn_cls_logits_fpn2_w",
initializer=Normal(
loc=0., scale=0.01)),
weight_attr=ParamAttr(initializer=Normal(
mean=0., std=0.01)),
bias_attr=ParamAttr(
#name="rpn_cls_logits_fpn2_b",
learning_rate=2.,
regularizer=L2Decay(0.)))
learning_rate=2., regularizer=L2Decay(0.)))
# rpn roi bbox regression deltas
self.rpn_rois_delta = Conv2D(
num_channels=rpn_channel,
num_filters=4 * anchor_per_position,
filter_size=1,
in_channels=rpn_channel,
out_channels=4 * anchor_per_position,
kernel_size=1,
padding=0,
act=None,
param_attr=ParamAttr(
#name="rpn_bbox_pred_fpn2_w",
initializer=Normal(
loc=0., scale=0.01)),
weight_attr=ParamAttr(initializer=Normal(
mean=0., std=0.01)),
bias_attr=ParamAttr(
#name="rpn_bbox_pred_fpn2_b",
learning_rate=2.,
regularizer=L2Decay(0.)))
learning_rate=2., regularizer=L2Decay(0.)))
def forward(self, inputs, feats):
rpn_feats = self.rpn_feat(inputs, feats)
@@ -87,28 +89,26 @@ class RPNHead(Layer):
def get_loss(self, loss_inputs):
# cls loss
score_tgt = fluid.layers.cast(
score_tgt = paddle.cast(
x=loss_inputs['rpn_score_target'], dtype='float32')
score_tgt.stop_gradient = True
loss_rpn_cls = fluid.layers.sigmoid_cross_entropy_with_logits(
x=loss_inputs['rpn_score_pred'], label=score_tgt)
loss_rpn_cls = fluid.layers.reduce_mean(
loss_rpn_cls, name='loss_rpn_cls')
loss_rpn_cls = ops.sigmoid_cross_entropy_with_logits(
input=loss_inputs['rpn_score_pred'], label=score_tgt)
loss_rpn_cls = paddle.mean(loss_rpn_cls, name='loss_rpn_cls')
# reg loss
loc_tgt = fluid.layers.cast(
x=loss_inputs['rpn_rois_target'], dtype='float32')
loc_tgt = paddle.cast(x=loss_inputs['rpn_rois_target'], dtype='float32')
loc_tgt.stop_gradient = True
loss_rpn_reg = fluid.layers.smooth_l1(
x=loss_inputs['rpn_rois_pred'],
y=loc_tgt,
sigma=3.0,
loss_rpn_reg = ops.smooth_l1(
input=loss_inputs['rpn_rois_pred'],
label=loc_tgt,
inside_weight=loss_inputs['rpn_rois_weight'],
outside_weight=loss_inputs['rpn_rois_weight'])
loss_rpn_reg = fluid.layers.reduce_sum(loss_rpn_reg)
score_shape = fluid.layers.shape(score_tgt)
score_shape = fluid.layers.cast(x=score_shape, dtype='float32')
norm = fluid.layers.reduce_prod(score_shape)
outside_weight=loss_inputs['rpn_rois_weight'],
sigma=3.0, )
loss_rpn_reg = paddle.sum(loss_rpn_reg)
score_shape = paddle.shape(score_tgt)
score_shape = paddle.cast(score_shape, dtype='float32')
norm = paddle.prod(score_shape)
norm.stop_gradient = True
loss_rpn_reg = loss_rpn_reg / norm
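
# Editor's sketch (not part of the commit): the regression loss above is
# normalized by the total number of score targets, computed as the product
# of the target tensor's shape dimensions.
import paddle
score_tgt = paddle.zeros([256, 1])
norm = paddle.prod(paddle.cast(paddle.shape(score_tgt), 'float32'))  # 256.0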
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Layer
from paddle.fluid.dygraph import Conv2D, Pool2D, BatchNorm
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Xavier
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn import Layer
from paddle.nn import Conv2D
from paddle.nn.initializer import XavierUniform
from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
@@ -32,33 +48,27 @@ class FPN(Layer):
lateral = self.add_sublayer(
lateral_name,
Conv2D(
num_channels=in_c,
num_filters=out_channel,
filter_size=1,
param_attr=ParamAttr(
#name=lateral_name+'_w',
initializer=Xavier(fan_out=in_c)),
in_channels=in_c,
out_channels=out_channel,
kernel_size=1,
weight_attr=ParamAttr(
initializer=XavierUniform(fan_out=in_c)),
bias_attr=ParamAttr(
#name=lateral_name+'_b',
learning_rate=2.,
regularizer=L2Decay(0.))))
learning_rate=2., regularizer=L2Decay(0.))))
self.lateral_convs.append(lateral)
fpn_name = 'fpn_res{}_sum'.format(i + 2)
fpn_conv = self.add_sublayer(
fpn_name,
Conv2D(
num_channels=out_channel,
num_filters=out_channel,
filter_size=3,
in_channels=out_channel,
out_channels=out_channel,
kernel_size=3,
padding=1,
param_attr=ParamAttr(
#name=fpn_name+'_w',
initializer=Xavier(fan_out=fan)),
weight_attr=ParamAttr(
initializer=XavierUniform(fan_out=fan)),
bias_attr=ParamAttr(
#name=fpn_name+'_b',
learning_rate=2.,
regularizer=L2Decay(0.))))
learning_rate=2., regularizer=L2Decay(0.))))
self.fpn_convs.append(fpn_conv)
self.min_level = min_level
@@ -71,14 +81,17 @@ class FPN(Layer):
laterals.append(self.lateral_convs[lvl](body_feats[lvl]))
for lvl in range(self.max_level - 1, self.min_level, -1):
upsample = fluid.layers.resize_nearest(laterals[lvl], scale=2.)
upsample = F.interpolate(
laterals[lvl],
scale_factor=2.,
mode='nearest', )
laterals[lvl - 1] = laterals[lvl - 1] + upsample
fpn_output = []
for lvl in range(self.min_level, self.max_level):
fpn_output.append(self.fpn_convs[lvl](laterals[lvl]))
extension = fluid.layers.pool2d(fpn_output[-1], 1, 'max', pool_stride=2)
extension = F.max_pool2d(fpn_output[-1], 1, stride=2)
spatial_scale = self.spatial_scale + [self.spatial_scale[-1] * 0.5]
fpn_output.append(extension)
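
# Editor's sketch (not part of the commit): the top-down pathway upsamples
# with nearest-neighbor interpolation, and the extra level is a stride-2
# max pool of the last FPN output, halving its resolution.
import paddle
import paddle.nn.functional as F
p = paddle.rand([1, 256, 16, 16])
up = F.interpolate(p, scale_factor=2., mode='nearest')  # [1, 256, 32, 32]
p6 = F.max_pool2d(p, kernel_size=1, stride=2)           # [1, 256, 8, 8]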
......
@@ -13,6 +13,7 @@
# limitations under the License.
import paddle
import paddle.nn.functional as F
from paddle.fluid.framework import Variable, in_dygraph_mode
from paddle.fluid import core
@@ -1509,3 +1510,30 @@ def generate_proposals(scores,
return rpn_rois, rpn_roi_probs, rpn_rois_num
else:
return rpn_rois, rpn_roi_probs
def sigmoid_cross_entropy_with_logits(input,
label,
ignore_index=-100,
normalize=False):
output = F.binary_cross_entropy_with_logits(input, label, reduction='none')
mask_tensor = paddle.cast(label != ignore_index, 'float32')
output = paddle.multiply(output, mask_tensor)
output = paddle.reshape(output, shape=[output.shape[0], -1])
if normalize:
sum_valid_mask = paddle.sum(mask_tensor)
output = output / sum_valid_mask
return output
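
# Editor's sketch (not part of the commit): entries whose label equals
# ignore_index contribute zero loss; with normalize=True the remaining
# per-element losses are divided by the count of valid (non-ignored) labels.
import paddle
logits = paddle.to_tensor([[2.0, -1.0], [0.5, 0.3]])
labels = paddle.to_tensor([[1.0, -1.0], [0.0, 1.0]])  # -1 marks ignored entries
loss = sigmoid_cross_entropy_with_logits(
    logits, labels, ignore_index=-1, normalize=True)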
def smooth_l1(input, label, inside_weight=None, outside_weight=None,
sigma=None):
input_new = paddle.multiply(input, inside_weight)
label_new = paddle.multiply(label, inside_weight)
delta = 1 / (sigma * sigma)
out = F.smooth_l1_loss(input_new, label_new, reduction='none', delta=delta)
out = paddle.multiply(out, outside_weight)
out = out / delta
out = paddle.reshape(out, shape=[out.shape[0], -1])
out = paddle.sum(out, axis=1)
return out
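
# Editor's sketch (not part of the commit): despite the None defaults,
# inside_weight and outside_weight are required by this implementation
# (paddle.multiply with None would fail). delta = 1 / sigma**2 sets the
# smooth-L1 transition point, and the final division by delta restores the
# scaling of the old fluid.layers.smooth_l1.
import paddle
pred = paddle.rand([4, 8])
tgt = paddle.rand([4, 8])
w = paddle.ones([4, 8])
loss = smooth_l1(pred, tgt, inside_weight=w, outside_weight=w, sigma=3.0)
# loss has shape [4]: one summed value per row, as consumed by RPNHead.get_loss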