# Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """RPN for MaskRCNN""" import numpy as np import mindspore.nn as nn import mindspore.common.dtype as mstype from mindspore.ops import operations as P from mindspore import Tensor from mindspore.ops import functional as F from mindspore.common.initializer import initializer from .bbox_assign_sample import BboxAssignSample class RpnRegClsBlock(nn.Cell): """ Rpn reg cls block for rpn layer Args: in_channels (int) - Input channels of shared convolution. feat_channels (int) - Output channels of shared convolution. num_anchors (int) - The anchor number. cls_out_channels (int) - Output channels of classification convolution. weight_conv (Tensor) - weight init for rpn conv. bias_conv (Tensor) - bias init for rpn conv. weight_cls (Tensor) - weight init for rpn cls conv. bias_cls (Tensor) - bias init for rpn cls conv. weight_reg (Tensor) - weight init for rpn reg conv. bias_reg (Tensor) - bias init for rpn reg conv. Returns: Tensor, output tensor. """ def __init__(self, in_channels, feat_channels, num_anchors, cls_out_channels, weight_conv, bias_conv, weight_cls, bias_cls, weight_reg, bias_reg): super(RpnRegClsBlock, self).__init__() self.rpn_conv = nn.Conv2d(in_channels, feat_channels, kernel_size=3, stride=1, pad_mode='same', has_bias=True, weight_init=weight_conv, bias_init=bias_conv) self.relu = nn.ReLU() self.rpn_cls = nn.Conv2d(feat_channels, num_anchors * cls_out_channels, kernel_size=1, pad_mode='valid', has_bias=True, weight_init=weight_cls, bias_init=bias_cls) self.rpn_reg = nn.Conv2d(feat_channels, num_anchors * 4, kernel_size=1, pad_mode='valid', has_bias=True, weight_init=weight_reg, bias_init=bias_reg) def construct(self, x): x = self.relu(self.rpn_conv(x)) x1 = self.rpn_cls(x) x2 = self.rpn_reg(x) return x1, x2 class RPN(nn.Cell): """ ROI proposal network.. Args: config (dict) - Config. batch_size (int) - Batchsize. in_channels (int) - Input channels of shared convolution. feat_channels (int) - Output channels of shared convolution. num_anchors (int) - The anchor number. cls_out_channels (int) - Output channels of classification convolution. Returns: Tuple, tuple of output tensor. Examples: RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024, num_anchors=3, cls_out_channels=512) """ def __init__(self, config, batch_size, in_channels, feat_channels, num_anchors, cls_out_channels): super(RPN, self).__init__() cfg_rpn = config self.num_bboxes = cfg_rpn.num_bboxes self.slice_index = () self.feature_anchor_shape = () self.slice_index += (0,) index = 0 for shape in cfg_rpn.feature_shapes: self.slice_index += (self.slice_index[index] + shape[0] * shape[1] * num_anchors,) self.feature_anchor_shape += (shape[0] * shape[1] * num_anchors * batch_size,) index += 1 self.num_anchors = num_anchors self.batch_size = batch_size self.test_batch_size = cfg_rpn.test_batch_size self.num_layers = 5 self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16)) self.rpn_convs_list = nn.layer.CellList(self._make_rpn_layer(self.num_layers, in_channels, feat_channels, num_anchors, cls_out_channels)) self.transpose = P.Transpose() self.reshape = P.Reshape() self.concat = P.Concat(axis=0) self.fill = P.Fill() self.placeh1 = Tensor(np.ones((1,)).astype(np.float16)) self.trans_shape = (0, 2, 3, 1) self.reshape_shape_reg = (-1, 4) self.reshape_shape_cls = (-1,) self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16)) self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16)) self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16)) self.num_bboxes = cfg_rpn.num_bboxes self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False) self.CheckValid = P.CheckValid() self.sum_loss = P.ReduceSum() self.loss_cls = P.SigmoidCrossEntropyWithLogits() self.loss_bbox = P.SmoothL1Loss(beta=1.0/9.0) self.squeeze = P.Squeeze() self.cast = P.Cast() self.tile = P.Tile() self.zeros_like = P.ZerosLike() self.loss = Tensor(np.zeros((1,)).astype(np.float16)) self.clsloss = Tensor(np.zeros((1,)).astype(np.float16)) self.regloss = Tensor(np.zeros((1,)).astype(np.float16)) def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels): """ make rpn layer for rpn proposal network Args: num_layers (int) - layer num. in_channels (int) - Input channels of shared convolution. feat_channels (int) - Output channels of shared convolution. num_anchors (int) - The anchor number. cls_out_channels (int) - Output channels of classification convolution. Returns: List, list of RpnRegClsBlock cells. """ rpn_layer = [] shp_weight_conv = (feat_channels, in_channels, 3, 3) shp_bias_conv = (feat_channels,) weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16).to_tensor() bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16).to_tensor() shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1) shp_bias_cls = (num_anchors * cls_out_channels,) weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16).to_tensor() bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16).to_tensor() shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1) shp_bias_reg = (num_anchors * 4,) weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16).to_tensor() bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16).to_tensor() for i in range(num_layers): rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \ weight_conv, bias_conv, weight_cls, \ bias_cls, weight_reg, bias_reg)) for i in range(1, num_layers): rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight rpn_layer[i].rpn_cls.weight = rpn_layer[0].rpn_cls.weight rpn_layer[i].rpn_reg.weight = rpn_layer[0].rpn_reg.weight rpn_layer[i].rpn_conv.bias = rpn_layer[0].rpn_conv.bias rpn_layer[i].rpn_cls.bias = rpn_layer[0].rpn_cls.bias rpn_layer[i].rpn_reg.bias = rpn_layer[0].rpn_reg.bias return rpn_layer def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids): loss_print = () rpn_cls_score = () rpn_bbox_pred = () rpn_cls_score_total = () rpn_bbox_pred_total = () for i in range(self.num_layers): x1, x2 = self.rpn_convs_list[i](inputs[i]) rpn_cls_score_total = rpn_cls_score_total + (x1,) rpn_bbox_pred_total = rpn_bbox_pred_total + (x2,) x1 = self.transpose(x1, self.trans_shape) x1 = self.reshape(x1, self.reshape_shape_cls) x2 = self.transpose(x2, self.trans_shape) x2 = self.reshape(x2, self.reshape_shape_reg) rpn_cls_score = rpn_cls_score + (x1,) rpn_bbox_pred = rpn_bbox_pred + (x2,) loss = self.loss clsloss = self.clsloss regloss = self.regloss bbox_targets = () bbox_weights = () labels = () label_weights = () output = () if self.training: for i in range(self.batch_size): multi_level_flags = () anchor_list_tuple = () for j in range(self.num_layers): res = self.cast(self.CheckValid(anchor_list[j], self.squeeze(img_metas[i:i + 1:1, ::])), mstype.int32) multi_level_flags = multi_level_flags + (res,) anchor_list_tuple = anchor_list_tuple + (anchor_list[j],) valid_flag_list = self.concat(multi_level_flags) anchor_using_list = self.concat(anchor_list_tuple) gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::]) gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::]) gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::]) bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i, gt_labels_i, self.cast(valid_flag_list, mstype.bool_), anchor_using_list, gt_valids_i) bbox_weight = self.cast(bbox_weight, mstype.float16) label = self.cast(label, mstype.float16) label_weight = self.cast(label_weight, mstype.float16) for j in range(self.num_layers): begin = self.slice_index[j] end = self.slice_index[j + 1] stride = 1 bbox_targets += (bbox_target[begin:end:stride, ::],) bbox_weights += (bbox_weight[begin:end:stride],) labels += (label[begin:end:stride],) label_weights += (label_weight[begin:end:stride],) for i in range(self.num_layers): bbox_target_using = () bbox_weight_using = () label_using = () label_weight_using = () for j in range(self.batch_size): bbox_target_using += (bbox_targets[i + (self.num_layers * j)],) bbox_weight_using += (bbox_weights[i + (self.num_layers * j)],) label_using += (labels[i + (self.num_layers * j)],) label_weight_using += (label_weights[i + (self.num_layers * j)],) bbox_target_with_batchsize = self.concat(bbox_target_using) bbox_weight_with_batchsize = self.concat(bbox_weight_using) label_with_batchsize = self.concat(label_using) label_weight_with_batchsize = self.concat(label_weight_using) # stop bbox_target_ = F.stop_gradient(bbox_target_with_batchsize) bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize) label_ = F.stop_gradient(label_with_batchsize) label_weight_ = F.stop_gradient(label_weight_with_batchsize) cls_score_i = rpn_cls_score[i] reg_score_i = rpn_bbox_pred[i] loss_cls = self.loss_cls(cls_score_i, label_) loss_cls_item = loss_cls * label_weight_ loss_cls_item = self.sum_loss(loss_cls_item, (0,)) / self.num_expected_total loss_reg = self.loss_bbox(reg_score_i, bbox_target_) bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape[i], 1)), (1, 4)) loss_reg = loss_reg * bbox_weight_ loss_reg_item = self.sum_loss(loss_reg, (1,)) loss_reg_item = self.sum_loss(loss_reg_item, (0,)) / self.num_expected_total loss_total = self.rpn_loss_cls_weight * loss_cls_item + self.rpn_loss_reg_weight * loss_reg_item loss += loss_total loss_print += (loss_total, loss_cls_item, loss_reg_item) clsloss += loss_cls_item regloss += loss_reg_item output = (loss, rpn_cls_score_total, rpn_bbox_pred_total, clsloss, regloss, loss_print) else: output = (self.placeh1, rpn_cls_score_total, rpn_bbox_pred_total, self.placeh1, self.placeh1, self.placeh1) return output