Commit a179a85e, authored by meixiaowei

upload maskrcnn scripts

Parent b29fab3e
# MaskRcnn Example
## Description
MaskRcnn is a two-stage target detection network. It uses a region proposal network (RPN) that shares the convolutional features of the whole image with the detection network, so computing region proposals is almost cost-free. The whole network further merges the RPN and the MaskRcnn detection branches into a single network by sharing the convolutional features.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the dataset COCO2017.
- This example uses COCO2017 as the training dataset by default; you can also use your own dataset.
1. If the COCO dataset is used, **set the dataset to `coco` when running the script.**
Install Cython and pycocotools; you can also install mmcv to process the data.
```
pip install Cython
pip install pycocotools
pip install mmcv
```
Then change `COCO_ROOT` and any other settings you need in `config.py`. The directory structure is as follows:
```
.
└─cocodataset
  ├─annotations
    ├─instances_train2017.json
    └─instances_val2017.json
  ├─val2017
  └─train2017
```
2. If your own dataset is used, **set the dataset to `other` when running the script.**
Organize the dataset information into a TXT file in which each row looks as follows:
```
train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2
```
Each row is the annotation of one image, with fields separated by spaces. The first column is the relative path of the image; the remaining columns are boxes with class information in the format [xmin,ymin,xmax,ymax,class]. The image is read from the path formed by joining `IMAGE_DIR` (the dataset directory) with the relative path listed in `ANNO_PATH` (the path of the TXT file). `IMAGE_DIR` and `ANNO_PATH` are set in `config.py`.
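As an illustration of the row format above, the following minimal sketch parses one row and builds the full image path. The function name `parse_annotation_row` and the `/data/my_dataset` directory are hypothetical and are not taken from this example's `dataset.py`; the sketch only assumes the layout described here.

```python
import os


def parse_annotation_row(row, image_dir):
    """Split one TXT row into an image path and a list of boxes.

    Hypothetical helper: assumes space-separated fields whose first entry is
    the relative image path and whose remaining entries are
    "xmin,ymin,xmax,ymax,class" integers.
    """
    fields = row.split()
    image_path = os.path.join(image_dir, fields[0])
    boxes = [tuple(int(v) for v in field.split(",")) for field in fields[1:]]
    return image_path, boxes


row = "train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2"
image_path, boxes = parse_annotation_row(row, "/data/my_dataset")
print(image_path)
print(boxes)
```

Each box tuple keeps the class id as its last element, mirroring the [xmin,ymin,xmax,ymax,class] order of the file.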
## Example structure
```shell
.
└─MaskRcnn
  ├─README.md
  ├─scripts
    ├─run_download_process_data.sh
    ├─run_standalone_train.sh
    ├─run_train.sh
    └─run_eval.sh
  ├─src
    ├─MaskRcnn
      ├─__init__.py
      ├─anchor_generator.py
      ├─bbox_assign_sample.py
      ├─bbox_assign_sample_stage2.py
      ├─mask_rcnn_r50.py
      ├─fpn_neck.py
      ├─proposal_generator.py
      ├─rcnn_cls.py
      ├─rcnn_mask.py
      ├─resnet50.py
      ├─roi_align.py
      └─rpn.py
    ├─config.py
    ├─dataset.py
    ├─lr_schedule.py
    ├─network_define.py
    └─util.py
  ├─eval.py
  └─train.py
```
## Running the example
### Train
#### Usage
```
# distributed training
sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [PRETRAINED_MODEL]
# standalone training
sh run_standalone_train.sh [PRETRAINED_MODEL]
```
> For details on rank_table.json (the file passed as `MINDSPORE_HCCL_CONFIG_PATH`), refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
#### Result
Training results are stored in the example path, in folders whose names begin with "train" or "train_parallel". The checkpoint files are saved there, and loss.log contains results like the following.
```
# distributed training result (8p)
epoch: 1 step: 7393 ,rpn_loss: 0.10626, rcnn_loss: 0.81592, rpn_cls_loss: 0.05862, rpn_reg_loss: 0.04761, rcnn_cls_loss: 0.32642, rcnn_reg_loss: 0.15503, rcnn_mask_loss: 0.33447, total_loss: 0.92218
epoch: 2 step: 7393 ,rpn_loss: 0.00911, rcnn_loss: 0.34082, rpn_cls_loss: 0.00341, rpn_reg_loss: 0.00571, rcnn_cls_loss: 0.07440, rcnn_reg_loss: 0.05872, rcnn_mask_loss: 0.20764, total_loss: 0.34993
epoch: 3 step: 7393 ,rpn_loss: 0.02087, rcnn_loss: 0.98633, rpn_cls_loss: 0.00665, rpn_reg_loss: 0.01422, rcnn_cls_loss: 0.35913, rcnn_reg_loss: 0.21375, rcnn_mask_loss: 0.41382, total_loss: 1.00720
...
epoch: 10 step: 7393 ,rpn_loss: 0.02122, rcnn_loss: 0.55176, rpn_cls_loss: 0.00620, rpn_reg_loss: 0.01503, rcnn_cls_loss: 0.12708, rcnn_reg_loss: 0.10254, rcnn_mask_loss: 0.32227, total_loss: 0.57298
epoch: 11 step: 7393 ,rpn_loss: 0.03772, rcnn_loss: 0.60791, rpn_cls_loss: 0.03058, rpn_reg_loss: 0.00713, rcnn_cls_loss: 0.23987, rcnn_reg_loss: 0.11743, rcnn_mask_loss: 0.25049, total_loss: 0.64563
epoch: 12 step: 7393 ,rpn_loss: 0.06482, rcnn_loss: 0.47681, rpn_cls_loss: 0.04770, rpn_reg_loss: 0.01709, rcnn_cls_loss: 0.16492, rcnn_reg_loss: 0.04990, rcnn_mask_loss: 0.26196, total_loss: 0.54163
```
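To inspect these values programmatically, a log line can be turned into a dictionary. The sketch below is an assumption-laden illustration: `parse_loss_line` is a hypothetical helper, not part of this example, and it relies only on the `name: value` layout visible in the sample lines above.

```python
import re


def parse_loss_line(line):
    """Parse one loss.log line into a dict (hypothetical helper).

    Assumes every field has the form "name: number", as in the sample log.
    """
    record = {}
    for key, value in re.findall(r"(\w+):\s*([-+]?\d+(?:\.\d+)?)", line):
        # epoch and step are counters; everything else is a float loss value
        record[key] = int(value) if key in ("epoch", "step") else float(value)
    return record


sample = ("epoch: 1 step: 7393 ,rpn_loss: 0.10626, rcnn_loss: 0.81592, "
          "rpn_cls_loss: 0.05862, rpn_reg_loss: 0.04761, rcnn_cls_loss: 0.32642, "
          "rcnn_reg_loss: 0.15503, rcnn_mask_loss: 0.33447, total_loss: 0.92218")
record = parse_loss_line(sample)
print(record["epoch"], record["step"], record["total_loss"])
```

In the sample lines shown above, `total_loss` equals `rpn_loss + rcnn_loss` up to rounding, which gives a quick sanity check on parsed records.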
### Evaluation
#### Usage
```
# infer
sh run_eval.sh [ANN_FILE] [CHECKPOINT_PATH]
```
> The checkpoint is produced during training.
#### Result
Inference results are stored in the example path, in a folder named "eval". The log file in that folder contains results like the following.
```
Evaluate annotation type *bbox*
Accumulating evaluation results...
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.366
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.591
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.393
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.241
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.405
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.454
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.304
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.492
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.521
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.372
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.560
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.637
Evaluate annotation type *segm*
Accumulating evaluation results...
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.318
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.546
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.332
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.165
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.348
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.449
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.272
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.421
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.440
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.292
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.479
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.558
```
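The IoU thresholds in the tables above refer to the intersection over union between a predicted region and a ground-truth region. For boxes, the quantity can be sketched as follows. The name `box_iou` is hypothetical, and the sketch treats coordinates as continuous [xmin,ymin,xmax,ymax] values; the evaluation itself is done by pycocotools, whose internal conventions may differ.

```python
def box_iou(box_a, box_b):
    """Intersection over union of two [xmin, ymin, xmax, ymax] boxes."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    # width and height of the overlap, clamped at zero when boxes are disjoint
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return inter / union if union > 0 else 0.0


iou = box_iou((0, 0, 10, 10), (5, 5, 15, 15))
print(round(iou, 3))  # 25 / 175
```

With an overlap of 25 out of a union of 175, this pair would not count as a match at IoU=0.50.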
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation for MaskRcnn"""
import os
import argparse
import time
import random
import numpy as np
from pycocotools.coco import COCO
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine as de
from src.MaskRcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from src.config import config
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
from src.util import coco_eval, bbox2result_1image, results2json, get_seg_masks
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
parser = argparse.ArgumentParser(description="MaskRcnn evaluation")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
parser.add_argument("--ann_file", type=str, default="val.json", help="Ann file, default is val.json.")
parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id)
def MaskRcnn_eval(dataset_path, ckpt_path, ann_file):
    """MaskRcnn evaluation."""
    ds = create_maskrcnn_dataset(dataset_path, batch_size=config.test_batch_size, is_training=False)
    net = Mask_Rcnn_Resnet50(config)
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    eval_iter = 0
    total = ds.get_dataset_size()
    outputs = []
    dataset_coco = COCO(ann_file)
    print("\n========================================\n")
    print("total images num: ", total)
    print("Processing, please wait a moment.")
    max_num = 128
    for data in ds.create_dict_iterator():
        eval_iter = eval_iter + 1
        img_data = data['image']
        img_metas = data['image_shape']
        gt_bboxes = data['box']
        gt_labels = data['label']
        gt_num = data['valid_num']
        gt_mask = data["mask"]
        start = time.time()
        # run net
        output = net(Tensor(img_data), Tensor(img_metas), Tensor(gt_bboxes), Tensor(gt_labels), Tensor(gt_num),
                     Tensor(gt_mask))
        end = time.time()
        print("Iter {} cost time {}".format(eval_iter, end - start))
        # output
        all_bbox = output[0]
        all_label = output[1]
        all_mask = output[2]
        all_mask_fb = output[3]
        for j in range(config.test_batch_size):
            all_bbox_squee = np.squeeze(all_bbox.asnumpy()[j, :, :])
            all_label_squee = np.squeeze(all_label.asnumpy()[j, :, :])
            all_mask_squee = np.squeeze(all_mask.asnumpy()[j, :, :])
            all_mask_fb_squee = np.squeeze(all_mask_fb.asnumpy()[j, :, :, :])
            all_bboxes_tmp_mask = all_bbox_squee[all_mask_squee, :]
            all_labels_tmp_mask = all_label_squee[all_mask_squee]
            all_mask_fb_tmp_mask = all_mask_fb_squee[all_mask_squee, :, :]
            if all_bboxes_tmp_mask.shape[0] > max_num:
                # keep only the max_num highest-scoring detections
                inds = np.argsort(-all_bboxes_tmp_mask[:, -1])
                inds = inds[:max_num]
                all_bboxes_tmp_mask = all_bboxes_tmp_mask[inds]
                all_labels_tmp_mask = all_labels_tmp_mask[inds]
                all_mask_fb_tmp_mask = all_mask_fb_tmp_mask[inds]
            bbox_results = bbox2result_1image(all_bboxes_tmp_mask, all_labels_tmp_mask, config.num_classes)
            segm_results = get_seg_masks(all_mask_fb_tmp_mask, all_bboxes_tmp_mask, all_labels_tmp_mask, img_metas[j],
                                         True, config.num_classes)
            outputs.append((bbox_results, segm_results))
    eval_types = ["bbox", "segm"]
    result_files = results2json(dataset_coco, outputs, "./results.pkl")
    coco_eval(result_files, eval_types, dataset_coco, single_result=False)
if __name__ == '__main__':
    prefix = "MaskRcnn_eval.mindrecord"
    mindrecord_dir = config.mindrecord_dir
    mindrecord_file = os.path.join(mindrecord_dir, prefix)
    # build the MindRecord file on first use
    if not os.path.exists(mindrecord_file):
        if not os.path.isdir(mindrecord_dir):
            os.makedirs(mindrecord_dir)
        if args_opt.dataset == "coco":
            if os.path.isdir(config.coco_root):
                print("Create Mindrecord.")
                data_to_mindrecord_byte_image("coco", False, prefix, file_num=1)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("coco_root does not exist.")
        else:
            if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH):
                print("Create Mindrecord.")
                data_to_mindrecord_byte_image("other", False, prefix, file_num=1)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("IMAGE_DIR or ANNO_PATH does not exist.")
    print("Start Eval!")
    MaskRcnn_eval(mindrecord_file, args_opt.checkpoint_path, args_opt.ann_file)
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
    echo "Usage: sh run_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [PRETRAINED_PATH]"
    exit 1
fi
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
echo $PATH1
echo $PATH2
if [ ! -f $PATH1 ]
then
    echo "error: MINDSPORE_HCCL_CONFIG_PATH=$PATH1 is not a file"
    exit 1
fi
if [ ! -f $PATH2 ]
then
    echo "error: PRETRAINED_PATH=$PATH2 is not a file"
    exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH=$PATH1
export RANK_TABLE_FILE=$PATH1
echo 3 > /proc/sys/vm/drop_caches
cpus=`cat /proc/cpuinfo| grep "processor"| wc -l`
avg=`expr $cpus \/ $RANK_SIZE`
gap=`expr $avg \- 1`
for((i=0; i<${DEVICE_NUM}; i++))
do
    start=`expr $i \* $avg`
    end=`expr $start \+ $gap`
    cmdopt=$start"-"$end
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
    taskset -c $cmdopt python train.py --do_train=True --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM \
        --pre_trained=$PATH2 &> log &
    cd ..
done
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
    echo "Usage: sh run_eval.sh [ANN_FILE] [CHECKPOINT_PATH]"
    exit 1
fi
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
echo $PATH1
echo $PATH2
if [ ! -f $PATH1 ]
then
    echo "error: ANN_FILE=$PATH1 is not a file"
    exit 1
fi
if [ ! -f $PATH2 ]
then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export DEVICE_ID=0
export RANK_ID=0
if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start eval for device $DEVICE_ID"
python eval.py --device_id=$DEVICE_ID --ann_file=$PATH1 --checkpoint_path=$PATH2 &> log &
cd ..
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
    echo "Usage: sh run_standalone_train.sh [PRETRAINED_PATH]"
    exit 1
fi
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
    echo "error: PRETRAINED_PATH=$PATH1 is not a file"
    exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --do_train=True --device_id=$DEVICE_ID --pre_trained=$PATH1 &> log &
cd ..
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn Init."""
from .resnet50 import ResNetFea, ResidualBlockUsing
from .bbox_assign_sample import BboxAssignSample
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
from .fpn_neck import FeatPyramidNeck
from .proposal_generator import Proposal
from .rcnn_cls import RcnnCls
from .rcnn_mask import RcnnMask
from .rpn import RPN
from .roi_align import SingleRoIExtractor
from .anchor_generator import AnchorGenerator
__all__ = [
"ResNetFea", "BboxAssignSample", "BboxAssignSampleForRcnn",
"FeatPyramidNeck", "Proposal", "RcnnCls", "RcnnMask",
"RPN", "SingleRoIExtractor", "AnchorGenerator", "ResidualBlockUsing"
]
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn anchor generator."""
import numpy as np
class AnchorGenerator():
    """Anchor generator for MaskRcnn."""
    def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
        """Anchor generator init method."""
        self.base_size = base_size
        self.scales = np.array(scales)
        self.ratios = np.array(ratios)
        self.scale_major = scale_major
        self.ctr = ctr
        self.base_anchors = self.gen_base_anchors()

    def gen_base_anchors(self):
        """Generate the base anchors, one per scale/ratio combination."""
        w = self.base_size
        h = self.base_size
        if self.ctr is None:
            x_ctr = 0.5 * (w - 1)
            y_ctr = 0.5 * (h - 1)
        else:
            x_ctr, y_ctr = self.ctr
        h_ratios = np.sqrt(self.ratios)
        w_ratios = 1 / h_ratios
        if self.scale_major:
            ws = (w * w_ratios[:, None] * self.scales[None, :]).reshape(-1)
            hs = (h * h_ratios[:, None] * self.scales[None, :]).reshape(-1)
        else:
            ws = (w * self.scales[:, None] * w_ratios[None, :]).reshape(-1)
            hs = (h * self.scales[:, None] * h_ratios[None, :]).reshape(-1)
        base_anchors = np.stack(
            [
                x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
                x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
            ],
            axis=-1).round()
        return base_anchors

    def _meshgrid(self, x, y, row_major=True):
        """Generate grid."""
        xx = np.repeat(x.reshape(1, len(x)), len(y), axis=0).reshape(-1)
        yy = np.repeat(y, len(x))
        if row_major:
            return xx, yy
        return yy, xx

    def grid_anchors(self, featmap_size, stride=16):
        """Generate anchor list."""
        base_anchors = self.base_anchors
        feat_h, feat_w = featmap_size
        shift_x = np.arange(0, feat_w) * stride
        shift_y = np.arange(0, feat_h) * stride
        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
        shifts = shifts.astype(base_anchors.dtype)
        # first feat_w elements correspond to the first row of shifts
        # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
        # shifted anchors (K, A, 4), reshape to (K*A, 4)
        all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
        all_anchors = all_anchors.reshape(-1, 4)
        return all_anchors
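The base-anchor arithmetic in `gen_base_anchors` can be followed by hand with a small pure-Python sketch. It mirrors only the `scale_major=True`, `ctr=None` path; the function name `base_anchors_sketch` and the input values are hypothetical and were chosen for illustration, not taken from `config.py`.

```python
import math


def base_anchors_sketch(base_size, scales, ratios):
    """Pure-Python sketch of the scale_major=True, ctr=None branch."""
    x_ctr = y_ctr = 0.5 * (base_size - 1)
    anchors = []
    for ratio in ratios:            # outer index: aspect ratio (height / width)
        h_ratio = math.sqrt(ratio)
        w_ratio = 1 / h_ratio
        for scale in scales:        # inner index: scale
            w = base_size * w_ratio * scale
            h = base_size * h_ratio * scale
            anchors.append([
                round(x_ctr - 0.5 * (w - 1)), round(y_ctr - 0.5 * (h - 1)),
                round(x_ctr + 0.5 * (w - 1)), round(y_ctr + 0.5 * (h - 1)),
            ])
    return anchors


anchors = base_anchors_sketch(base_size=4, scales=[8], ratios=[0.5, 1.0, 2.0])
for anchor in anchors:
    print(anchor)
```

A ratio below 1 yields a wide box and a ratio above 1 a tall box, all centred on the same point. The sketch returns Python integers, whereas the NumPy version keeps rounded floats.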
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn positive and negative sample screening for RPN."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
class BboxAssignSample(nn.Cell):
    """
    Bbox assigner and sampler definition.

    Args:
        config (dict): Config.
        batch_size (int): Batchsize.
        num_bboxes (int): The anchor nums.
        add_gt_as_proposals (bool): add gt bboxes as proposals flag.

    Returns:
        Tensor, output tensor.
        bbox_targets: bbox location, (batch_size, num_bboxes, 4)
        bbox_weights: bbox weights, (batch_size, num_bboxes, 1)
        labels: label for every bboxes, (batch_size, num_bboxes, 1)
        label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1)

    Examples:
        BboxAssignSample(config, 2, 1024, True)
    """
    def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
        super(BboxAssignSample, self).__init__()
        cfg = config
        self.batch_size = batch_size
        self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16)
        self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16)
        self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16)
        self.zero_thr = Tensor(0.0, mstype.float16)
        self.num_bboxes = num_bboxes
        self.num_gts = cfg.num_gts
        self.num_expected_pos = cfg.num_expected_pos
        self.num_expected_neg = cfg.num_expected_neg
        self.add_gt_as_proposals = add_gt_as_proposals
        if self.add_gt_as_proposals:
            self.label_inds = Tensor(np.arange(1, self.num_gts + 1))
        self.concat = P.Concat(axis=0)
        self.max_gt = P.ArgMaxWithValue(axis=0)
        self.max_anchor = P.ArgMaxWithValue(axis=1)
        self.sum_inds = P.ReduceSum()
        self.iou = P.IOU()
        self.greaterequal = P.GreaterEqual()
        self.greater = P.Greater()
        self.select = P.Select()
        self.gatherND = P.GatherNd()
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()
        self.logicaland = P.LogicalAnd()
        self.less = P.Less()
        self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
        self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
        self.reshape = P.Reshape()
        self.equal = P.Equal()
        self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0))
        self.scatterNdUpdate = P.ScatterNdUpdate()
        self.scatterNd = P.ScatterNd()
        self.logicalnot = P.LogicalNot()
        self.tile = P.Tile()
        self.zeros_like = P.ZerosLike()
        self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
        self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
        self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
        self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
        self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
        self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool_))
        self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
        self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
        self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))

    def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):
        # replace invalid gt boxes and invalid anchors with sentinel boxes
        gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
                                  (self.num_gts, 1)), (1, 4)), mstype.bool_), gt_bboxes_i, self.check_gt_one)
        bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
                             (self.num_bboxes, 1)), (1, 4)), mstype.bool_), bboxes, self.check_anchor_two)
        overlaps = self.iou(bboxes, gt_bboxes_i)
        max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
        _, max_overlaps_w_ac = self.max_anchor(overlaps)
        neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr), \
                                              self.less(max_overlaps_w_gt, self.neg_iou_thr))
        assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)
        pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr)
        assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
                                        max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
        assigned_gt_inds4 = assigned_gt_inds3
        for j in range(self.num_gts):
            max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
            overlaps_w_gt_j = self.squeeze(overlaps[j:j+1:1, ::])
            pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou), \
                                         self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j))
            assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4)
        assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores)
        pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
        pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
        pos_check_valid = self.sum_inds(pos_check_valid, -1)
        valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
        pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
        pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
        pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32)
        pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1))
        neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))
        num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16)
        num_pos = self.sum_inds(num_pos, -1)
        unvalid_pos_index = self.less(self.range_pos_size, num_pos)
        valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
        pos_bboxes_ = self.gatherND(bboxes, pos_index)
        pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
        pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
        pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
        valid_pos_index = self.cast(valid_pos_index, mstype.int32)
        valid_neg_index = self.cast(valid_neg_index, mstype.int32)
        bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4))
        bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,))
        labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,))
        total_index = self.concat((pos_index, neg_index))
        total_valid_index = self.concat((valid_pos_index, valid_neg_index))
        label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,))
        return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \
               labels_total, self.cast(label_weights_total, mstype.bool_)
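The assignment step at the start of `construct` (before sampling) can be read as the following plain-Python sketch. It is an illustration under assumptions, not the operator graph itself: `assign_anchors` is a hypothetical name, the threshold defaults are placeholders rather than values from `config.py`, and invalid-box masking and random sampling are left out.

```python
def assign_anchors(overlaps, neg_iou_thr=0.3, pos_iou_thr=0.7, min_pos_iou=0.3):
    """Assign each anchor a code from an IoU table.

    overlaps[g][a] is the IoU between ground truth g and anchor a.
    Returns one code per anchor: -1 = ignored, 0 = negative,
    k > 0 = positive and matched to ground truth k - 1.
    Threshold defaults are placeholder assumptions.
    """
    num_gts = len(overlaps)
    num_anchors = len(overlaps[0])
    assigned = [-1] * num_anchors
    for a in range(num_anchors):
        column = [overlaps[g][a] for g in range(num_gts)]
        best_iou = max(column)
        best_gt = column.index(best_iou)
        if 0 <= best_iou < neg_iou_thr:
            assigned[a] = 0                  # clearly background
        if best_iou >= pos_iou_thr:
            assigned[a] = best_gt + 1        # clearly matches one ground truth
    # every ground truth also claims its best anchor(s) if the IoU is high enough
    for g in range(num_gts):
        best_for_gt = max(overlaps[g])
        if best_for_gt >= min_pos_iou:
            for a in range(num_anchors):
                if overlaps[g][a] == best_for_gt:
                    assigned[a] = g + 1
    return assigned


overlaps = [
    [0.8, 0.1, 0.5, 0.4],    # IoU of ground truth 0 with anchors 0..3
    [0.2, 0.05, 0.6, 0.45],  # IoU of ground truth 1 with anchors 0..3
]
codes = assign_anchors(overlaps)
print(codes)
```

In this toy table, anchor 2 falls between the two thresholds but is still made positive because it is the best anchor for ground truth 1, which is what the per-ground-truth loop in `construct` is for.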
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn positive and negative sample screening for Rcnn."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
class BboxAssignSampleForRcnn(nn.Cell):
"""
Bbox assigner and sampler defination.
Args:
config (dict): Config.
batch_size (int): Batchsize.
num_bboxes (int): The anchor nums.
add_gt_as_proposals (bool): add gt bboxes as proposals flag.
Returns:
Tensor, multiple output tensors.
Examples:
BboxAssignSampleForRcnn(config, 2, 1024, True)
"""
def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
super(BboxAssignSampleForRcnn, self).__init__()
cfg = config
self.batch_size = batch_size
self.neg_iou_thr = cfg.neg_iou_thr_stage2
self.pos_iou_thr = cfg.pos_iou_thr_stage2
self.min_pos_iou = cfg.min_pos_iou_stage2
self.num_gts = cfg.num_gts
self.num_bboxes = num_bboxes
self.num_expected_pos = cfg.num_expected_pos_stage2
self.num_expected_neg = cfg.num_expected_neg_stage2
self.num_expected_total = cfg.num_expected_total_stage2
self.add_gt_as_proposals = add_gt_as_proposals
self.label_inds = Tensor(np.arange(1, self.num_gts + 1).astype(np.int32))
self.add_gt_as_proposals_valid = Tensor(np.array(self.add_gt_as_proposals * np.ones(self.num_gts),
dtype=np.int32))
self.concat = P.Concat(axis=0)
self.max_gt = P.ArgMaxWithValue(axis=0)
self.max_anchor = P.ArgMaxWithValue(axis=1)
self.sum_inds = P.ReduceSum()
self.iou = P.IOU()
self.greaterequal = P.GreaterEqual()
self.greater = P.Greater()
self.select = P.Select()
self.gatherND = P.GatherNd()
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.logicaland = P.LogicalAnd()
self.less = P.Less()
self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos)
self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
self.reshape = P.Reshape()
self.equal = P.Equal()
self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(10.0, 10.0, 5.0, 5.0))
self.concat_axis1 = P.Concat(axis=1)
self.logicalnot = P.LogicalNot()
self.tile = P.Tile()
# Check
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
# Init tensor
self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32))
self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32))
self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32))
self.gt_ignores = Tensor(np.array(-1 * np.ones(self.num_gts), dtype=np.int32))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=np.float16))
self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.uint8))
self.reshape_shape_pos = (self.num_expected_pos, 1)
self.reshape_shape_neg = (self.num_expected_neg, 1)
self.scalar_zero = Tensor(0.0, dtype=mstype.float16)
self.scalar_neg_iou_thr = Tensor(self.neg_iou_thr, dtype=mstype.float16)
self.scalar_pos_iou_thr = Tensor(self.pos_iou_thr, dtype=mstype.float16)
self.scalar_min_pos_iou = Tensor(self.min_pos_iou, dtype=mstype.float16)
self.expand_dims = P.ExpandDims()
self.split = P.Split(axis=1, output_num=4)
self.concat_last_axis = P.Concat(axis=-1)
self.round = P.Round()
self.image_h_w = Tensor([cfg.img_height, cfg.img_width, cfg.img_height, cfg.img_width], dtype=mstype.float16)
self.range = nn.Range(start=0, limit=cfg.num_expected_pos_stage2)
self.crop_and_resize = P.CropAndResize()
self.mask_shape = (cfg.mask_shape[0], cfg.mask_shape[1])
self.squeeze_mask_last = P.Squeeze(axis=-1)
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids, gt_masks_i):
gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \
(self.num_gts, 1)), (1, 4)), mstype.bool_), \
gt_bboxes_i, self.check_gt_one)
bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \
(self.num_bboxes, 1)), (1, 4)), mstype.bool_), \
bboxes, self.check_anchor_two)
# overlaps has shape (num_gts, num_bboxes): dim 0 indexes ground-truth boxes, dim 1 indexes candidate bboxes
overlaps = self.iou(bboxes, gt_bboxes_i)
max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps)
_, max_overlaps_w_ac = self.max_anchor(overlaps)
neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt,
self.scalar_zero),
self.less(max_overlaps_w_gt,
self.scalar_neg_iou_thr))
assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds)
pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.scalar_pos_iou_thr)
assigned_gt_inds3 = self.select(pos_sample_iou_mask, \
max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2)
for j in range(self.num_gts):
max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1]
overlaps_w_ac_j = overlaps[j:j+1:1, ::]
temp1 = self.greaterequal(max_overlaps_w_ac_j, self.scalar_min_pos_iou)
temp2 = self.squeeze(self.equal(overlaps_w_ac_j, max_overlaps_w_ac_j))
pos_mask_j = self.logicaland(temp1, temp2)
assigned_gt_inds3 = self.select(pos_mask_j, (j+1)*self.assigned_gt_ones, assigned_gt_inds3)
assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds3, self.assigned_gt_ignores)
bboxes = self.concat((gt_bboxes_i, bboxes))
label_inds_valid = self.select(gt_valids, self.label_inds, self.gt_ignores)
label_inds_valid = label_inds_valid * self.add_gt_as_proposals_valid
assigned_gt_inds5 = self.concat((label_inds_valid, assigned_gt_inds5))
# Get pos index
pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0))
pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16)
pos_check_valid = self.sum_inds(pos_check_valid, -1)
valid_pos_index = self.less(self.range_pos_size, pos_check_valid)
pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1))
num_pos = self.sum_inds(self.cast(self.logicalnot(valid_pos_index), mstype.float16), -1)
valid_pos_index = self.cast(valid_pos_index, mstype.int32)
pos_index = self.reshape(pos_index, self.reshape_shape_pos)
valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)
pos_index = pos_index * valid_pos_index
pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos)
pos_assigned_gt_index = pos_assigned_gt_index * valid_pos_index
pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index)
# Get neg index
neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0))
unvalid_pos_index = self.less(self.range_pos_size, num_pos)
valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index)
neg_index = self.reshape(neg_index, self.reshape_shape_neg)
valid_neg_index = self.cast(valid_neg_index, mstype.int32)
valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg)
neg_index = neg_index * valid_neg_index
pos_bboxes_ = self.gatherND(bboxes, pos_index)
neg_bboxes_ = self.gatherND(bboxes, neg_index)
pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, self.reshape_shape_pos)
pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index)
pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_)
# Assign positive ROIs to gt masks.
# Pick the matching foreground/background mask for each ROI.
roi_pos_masks_fb = self.gatherND(gt_masks_i, pos_assigned_gt_index)
pos_masks_fb = self.cast(roi_pos_masks_fb, mstype.float32)
# compute mask targets
x1, y1, x2, y2 = self.split(pos_bboxes_)
boxes = self.concat_last_axis((y1, x1, y2, x2))
# normalized box coordinate
boxes = boxes / self.image_h_w
box_ids = self.range()
pos_masks_fb = self.expand_dims(pos_masks_fb, -1)
boxes = self.cast(boxes, mstype.float32)
pos_masks_fb = self.crop_and_resize(pos_masks_fb, boxes, box_ids, self.mask_shape)
# Remove the extra dimension from masks.
pos_masks_fb = self.squeeze_mask_last(pos_masks_fb)
# Round the resized gt mask targets to 0 or 1 so they can be used with binary cross-entropy loss.
pos_masks_fb = self.round(pos_masks_fb)
pos_masks_fb = self.cast(pos_masks_fb, mstype.float16)
total_bboxes = self.concat((pos_bboxes_, neg_bboxes_))
total_deltas = self.concat((pos_bbox_targets_, self.bboxs_neg_mask))
total_labels = self.concat((pos_gt_labels, self.labels_neg_mask))
valid_pos_index = self.reshape(valid_pos_index, self.reshape_shape_pos)
valid_neg_index = self.reshape(valid_neg_index, self.reshape_shape_neg)
total_mask = self.concat((valid_pos_index, valid_neg_index))
return total_bboxes, total_deltas, total_labels, total_mask, pos_bboxes_, pos_masks_fb, \
pos_gt_labels, valid_pos_index
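The IoU-based assignment at the start of `construct` above can be sketched in plain NumPy. This is a minimal illustration, not the MindSpore implementation: the function name `assign_gt_indices` and the default thresholds are assumptions, and the sketch skips the validity masks and the random sampling step. It only reproduces the labelling convention used above (-1 = ignored, 0 = negative, k + 1 = matched to ground truth k).

```python
import numpy as np


def assign_gt_indices(overlaps, neg_iou_thr=0.5, pos_iou_thr=0.5, min_pos_iou=0.5):
    """Label each candidate box from an IoU matrix of shape (num_gts, num_bboxes).

    Returns an int32 array of shape (num_bboxes,):
    -1 = ignored, 0 = negative, k + 1 = assigned to ground truth k.
    Function name and default thresholds are hypothetical.
    """
    num_gts, num_bboxes = overlaps.shape
    max_per_box = overlaps.max(axis=0)        # best IoU of each box over all gts
    argmax_per_box = overlaps.argmax(axis=0)  # which gt gives that best IoU
    max_per_gt = overlaps.max(axis=1)         # best IoU of each gt over all boxes

    assigned = np.full(num_bboxes, -1, dtype=np.int32)
    # Boxes whose best IoU is below the negative threshold become background.
    assigned[(max_per_box >= 0) & (max_per_box < neg_iou_thr)] = 0
    # Boxes whose best IoU reaches the positive threshold take that gt.
    pos = max_per_box >= pos_iou_thr
    assigned[pos] = argmax_per_box[pos] + 1
    # Every gt also claims its best-matching box(es) if the IoU is high enough.
    for j in range(num_gts):
        if max_per_gt[j] >= min_pos_iou:
            assigned[overlaps[j] == max_per_gt[j]] = j + 1
    return assigned


example_overlaps = np.array([[0.7, 0.1, 0.4],
                             [0.2, 0.05, 0.45]])
example_assigned = assign_gt_indices(example_overlaps, neg_iou_thr=0.3,
                                     pos_iou_thr=0.5, min_pos_iou=0.4)
```

With these example thresholds, the third box stays below the positive threshold but is still claimed by the second ground truth, because it is that ground truth's best match.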
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn feature pyramid network."""
import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
def bias_init_zeros(shape):
    """Bias init method."""
    return Tensor(np.zeros(shape, dtype=np.float16))


def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
    """Conv2D wrapper."""
    shape = (out_channels, in_channels, kernel_size, kernel_size)
    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor()
    shape_bias = (out_channels,)
    biass = bias_init_zeros(shape_bias)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=biass)
class FeatPyramidNeck(nn.Cell):
    """
    Feature pyramid network cell, usually used as the network neck.

    Applies convolutions to multiple input feature maps and outputs
    feature maps that all have the same channel size. If the required
    number of outputs is larger than the number of inputs, extra max
    pooling is added for further downsampling.

    Args:
        in_channels (tuple) - Channel size of input feature maps.
        out_channels (int) - Channel size output.
        num_outs (int) - Num of output features.

    Returns:
        Tuple, with tensors of same channel size.

    Examples:
        neck = FeatPyramidNeck([100,200,300], 50, 4)
        input_data = (normal(0,0.1,(1,c,1280//(4*2**i), 768//(4*2**i)),
                      dtype=np.float32) \
                      for i, c in enumerate(config.fpn_in_channels))
        x = neck(input_data)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs):
        super(FeatPyramidNeck, self).__init__()
        self.num_outs = num_outs
        self.in_channels = in_channels
        self.fpn_layer = len(self.in_channels)

        assert not self.num_outs < len(in_channels)

        self.lateral_convs_list_ = []
        self.fpn_convs_ = []

        for channel in in_channels:
            l_conv = _conv(channel, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='valid')
            fpn_conv = _conv(out_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same')
            self.lateral_convs_list_.append(l_conv)
            self.fpn_convs_.append(fpn_conv)
        self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_)
        self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_)
        # Fixed target sizes for the three top-down upsampling steps.
        self.interpolate1 = P.ResizeNearestNeighbor((48, 80))
        self.interpolate2 = P.ResizeNearestNeighbor((96, 160))
        self.interpolate3 = P.ResizeNearestNeighbor((192, 320))
        self.maxpool = P.MaxPool(ksize=1, strides=2, padding="same")

    def construct(self, inputs):
        # Lateral 1x1 convolutions bring every level to out_channels.
        x = ()
        for i in range(self.fpn_layer):
            x += (self.lateral_convs_list[i](inputs[i]),)

        # Top-down pathway: start from the coarsest level and add the upsampled map to each finer level.
        y = (x[3],)
        y = y + (x[2] + self.interpolate1(y[self.fpn_layer - 4]),)
        y = y + (x[1] + self.interpolate2(y[self.fpn_layer - 3]),)
        y = y + (x[0] + self.interpolate3(y[self.fpn_layer - 2]),)

        # Reverse so that the finest level comes first again.
        z = ()
        for i in range(self.fpn_layer - 1, -1, -1):
            z = z + (y[i],)

        outs = ()
        for i in range(self.fpn_layer):
            outs = outs + (self.fpn_convs_list[i](z[i]),)

        # Extra output levels are produced by stride-2 pooling of the coarsest output.
        for i in range(self.num_outs - self.fpn_layer):
            outs = outs + (self.maxpool(outs[3]),)
        return outs
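The top-down pathway in `construct` above can be illustrated with a small NumPy sketch. It assumes each level is exactly half the spatial size of the previous one, so nearest-neighbour resizing reduces to repeating rows and columns. The helper names `upsample_nearest_2x` and `top_down_merge` are hypothetical and are not part of the MindSpore code.

```python
import numpy as np


def upsample_nearest_2x(x):
    """Nearest-neighbour 2x upsampling of an (N, C, H, W) array."""
    return x.repeat(2, axis=2).repeat(2, axis=3)


def top_down_merge(laterals):
    """Merge lateral feature maps from coarse to fine.

    laterals: list of (N, C, H, W) arrays, finest level first, each level
    half the height and width of the previous one (an assumption of this sketch).
    Returns the merged maps in the same finest-first order.
    """
    merged = [laterals[-1]]  # the coarsest level is passed through unchanged
    for lateral in reversed(laterals[:-1]):
        merged.append(lateral + upsample_nearest_2x(merged[-1]))
    return merged[::-1]


example_laterals = [np.ones((1, 2, 8, 8)), np.ones((1, 2, 4, 4)), np.ones((1, 2, 2, 2))]
example_merged = top_down_merge(example_laterals)
```

Each finer level accumulates the contributions of all coarser levels, which is why the finest map in the example ends up with the largest values.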
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn proposal generator."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
class Proposal(nn.Cell):
"""
Proposal subnet.
Args:
config (dict): Config.
batch_size (int): Batch size.
num_classes (int): Class number.
use_sigmoid_cls (bool): Select sigmoid or softmax function.
target_means (tuple): Means for encode function. Default: (.0, .0, .0, .0).
target_stds (tuple): Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0).
Returns:
Tuple, (proposals_tuple, masks_tuple), each holding one tensor per image in the batch.
Examples:
Proposal(config = config, batch_size = 1, num_classes = 81, use_sigmoid_cls = True, \
target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0))
"""
def __init__(self,
config,
batch_size,
num_classes,
use_sigmoid_cls,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0)
):
super(Proposal, self).__init__()
cfg = config
self.batch_size = batch_size
self.num_classes = num_classes
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = use_sigmoid_cls
if self.use_sigmoid_cls:
self.cls_out_channels = num_classes - 1
self.activation = P.Sigmoid()
self.reshape_shape = (-1, 1)
else:
self.cls_out_channels = num_classes
self.activation = P.Softmax(axis=1)
self.reshape_shape = (-1, 2)
if self.cls_out_channels <= 0:
raise ValueError('num_classes={} is too small'.format(num_classes))
self.num_pre = cfg.rpn_proposal_nms_pre
self.min_box_size = cfg.rpn_proposal_min_bbox_size
self.nms_thr = cfg.rpn_proposal_nms_thr
self.nms_post = cfg.rpn_proposal_nms_post
self.nms_across_levels = cfg.rpn_proposal_nms_across_levels
self.max_num = cfg.rpn_proposal_max_num
self.num_levels = cfg.fpn_num_outs
# Op Define
self.squeeze = P.Squeeze()
self.reshape = P.Reshape()
self.cast = P.Cast()
self.feature_shapes = cfg.feature_shapes
self.transpose_shape = (1, 2, 0)
self.decode = P.BoundingBoxDecode(max_shape=(cfg.img_height, cfg.img_width), \
means=self.target_means, \
stds=self.target_stds)
self.nms = P.NMSWithMask(self.nms_thr)
self.concat_axis0 = P.Concat(axis=0)
self.concat_axis1 = P.Concat(axis=1)
self.split = P.Split(axis=1, output_num=5)
self.min = P.Minimum()
self.gatherND = P.GatherNd()
self.slice = P.Slice()
self.select = P.Select()
self.greater = P.Greater()
self.transpose = P.Transpose()
self.tile = P.Tile()
self.set_train_local(config, training=True)
self.multi_10 = Tensor(10.0, mstype.float16)
def set_train_local(self, config, training=True):
"""Set training flag."""
self.training_local = training
cfg = config
self.topK_stage1 = ()
self.topK_shape = ()
total_max_topk_input = 0
if not self.training_local:
self.num_pre = cfg.rpn_nms_pre
self.min_box_size = cfg.rpn_min_bbox_min_size
self.nms_thr = cfg.rpn_nms_thr
self.nms_post = cfg.rpn_nms_post
self.nms_across_levels = cfg.rpn_nms_across_levels
self.max_num = cfg.rpn_max_num
for shp in self.feature_shapes:
k_num = min(self.num_pre, (shp[0] * shp[1] * 3))
total_max_topk_input += k_num
self.topK_stage1 += (k_num,)
self.topK_shape += ((k_num, 1),)
self.topKv2 = P.TopK(sorted=True)
self.topK_shape_stage2 = (self.max_num, 1)
self.min_float_num = -65536.0
self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16))
def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list):
proposals_tuple = ()
masks_tuple = ()
for img_id in range(self.batch_size):
cls_score_list = ()
bbox_pred_list = ()
for i in range(self.num_levels):
rpn_cls_score_i = self.squeeze(rpn_cls_score_total[i][img_id:img_id+1:1, ::, ::, ::])
rpn_bbox_pred_i = self.squeeze(rpn_bbox_pred_total[i][img_id:img_id+1:1, ::, ::, ::])
cls_score_list = cls_score_list + (rpn_cls_score_i,)
bbox_pred_list = bbox_pred_list + (rpn_bbox_pred_i,)
proposals, masks = self.get_bboxes_single(cls_score_list, bbox_pred_list, anchor_list)
proposals_tuple += (proposals,)
masks_tuple += (masks,)
return proposals_tuple, masks_tuple
def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors):
"""Get proposal bounding boxes and their validity masks for a single image."""
mlvl_proposals = ()
mlvl_mask = ()
for idx in range(self.num_levels):
rpn_cls_score = self.transpose(cls_scores[idx], self.transpose_shape)
rpn_bbox_pred = self.transpose(bbox_preds[idx], self.transpose_shape)
anchors = mlvl_anchors[idx]
rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape)
rpn_cls_score = self.activation(rpn_cls_score)
rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 0::]), mstype.float16)
rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16)
scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.topK_stage1[idx])
topk_inds = self.reshape(topk_inds, self.topK_shape[idx])
bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds)
anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16)
proposals_decode = self.decode(anchors_sorted, bboxes_sorted)
proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape[idx])))
proposals, _, mask_valid = self.nms(proposals_decode)
mlvl_proposals = mlvl_proposals + (proposals,)
mlvl_mask = mlvl_mask + (mask_valid,)
proposals = self.concat_axis0(mlvl_proposals)
masks = self.concat_axis0(mlvl_mask)
_, _, _, _, scores = self.split(proposals)
scores = self.squeeze(scores)
topk_mask = self.cast(self.topK_mask, mstype.float16)
scores_using = self.select(masks, scores, topk_mask)
_, topk_inds = self.topKv2(scores_using, self.max_num)
topk_inds = self.reshape(topk_inds, self.topK_shape_stage2)
proposals = self.gatherND(proposals, topk_inds)
masks = self.gatherND(masks, topk_inds)
return proposals, masks
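Two steps of the proposal code above lend themselves to a short NumPy sketch: the per-level top-k size computed in `set_train_local`, and the masked top-k at the end of `get_bboxes_single`, where invalid entries are pushed to the bottom by replacing their scores with a very small value. The helper names, the `num_anchors` parameter and the fill value are assumptions made for illustration.

```python
import numpy as np


def topk_sizes(feature_shapes, num_pre, num_anchors=3):
    """Number of candidates kept per pyramid level before NMS."""
    return tuple(min(num_pre, h * w * num_anchors) for h, w in feature_shapes)


def masked_topk(scores, valid, k, fill=float(np.finfo(np.float16).min)):
    """Indices of the k highest scores, ranking invalid entries last.

    The fill value (most negative finite float16) is an assumption of this sketch.
    Returns the selected indices and the validity flag of each selected entry.
    """
    masked = np.where(valid, scores, fill)
    order = np.argsort(-masked, kind="stable")[:k]
    return order, valid[order]


example_sizes = topk_sizes(((192, 320), (96, 160), (48, 80), (24, 40), (12, 20)), 1000)
example_order, example_valid = masked_topk(
    np.array([0.9, 0.8, 0.7, 0.6]),
    np.array([True, False, True, True]),
    k=2,
)
```

The validity flags are returned alongside the indices because, as in the code above, fewer than `k` valid entries may exist and downstream code needs to know which selected slots are padding.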
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn Rcnn classification and box regression network."""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
class DenseNoTranpose(nn.Cell):
    """Dense layer whose weight is stored as (input_channels, output_channels), so MatMul needs no transpose."""

    def __init__(self, input_channels, output_channels, weight_init):
        super(DenseNoTranpose, self).__init__()
        self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16),
                                name="weight")
        self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16).to_tensor(), name="bias")
        self.matmul = P.MatMul(transpose_b=False)
        self.bias_add = P.BiasAdd()

    def construct(self, x):
        output = self.bias_add(self.matmul(x, self.weight), self.bias)
        return output
class FpnCls(nn.Cell):
"""Dense layers of the classification and box regression head."""
def __init__(self, input_channels, output_channels, num_classes, pool_size):
super(FpnCls, self).__init__()
representation_size = input_channels * pool_size * pool_size
shape_0 = (output_channels, representation_size)
weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16).to_tensor()
shape_1 = (output_channels, output_channels)
weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16).to_tensor()
self.shared_fc_0 = DenseNoTranpose(representation_size, output_channels, weights_0)
self.shared_fc_1 = DenseNoTranpose(output_channels, output_channels, weights_1)
cls_weight = initializer('Normal', shape=[num_classes, output_channels][::-1],
dtype=mstype.float16).to_tensor()
reg_weight = initializer('Normal', shape=[num_classes * 4, output_channels][::-1],
dtype=mstype.float16).to_tensor()
self.cls_scores = DenseNoTranpose(output_channels, num_classes, cls_weight)
self.reg_scores = DenseNoTranpose(output_channels, num_classes * 4, reg_weight)
self.relu = P.ReLU()
self.flatten = P.Flatten()
def construct(self, x):
# two shared fully connected layers
x = self.flatten(x)
x = self.relu(self.shared_fc_0(x))
x = self.relu(self.shared_fc_1(x))
# classifier head
cls_scores = self.cls_scores(x)
# bbox head
reg_scores = self.reg_scores(x)
return cls_scores, reg_scores
class RcnnCls(nn.Cell):
"""
Rcnn for classification and box regression subnet.
Args:
config (dict) - Config.
batch_size (int) - Batch size.
num_classes (int) - Class number.
target_means (tuple) - Means for encode function. Default: (0., 0., 0., 0.).
target_stds (tuple) - Stds for encode function. Default: (0.1, 0.1, 0.2, 0.2).
Returns:
Tuple, (loss_cls, loss_reg) when training, otherwise (cls_scores, reg_scores).
Examples:
RcnnCls(config=config, batch_size=2, num_classes=81, \
target_means=(0., 0., 0., 0.), target_stds=(0.1, 0.1, 0.2, 0.2))
"""
def __init__(self,
config,
batch_size,
num_classes,
target_means=(0., 0., 0., 0.),
target_stds=(0.1, 0.1, 0.2, 0.2)
):
super(RcnnCls, self).__init__()
cfg = config
self.rcnn_loss_cls_weight = Tensor(np.array(cfg.rcnn_loss_cls_weight).astype(np.float16))
self.rcnn_loss_reg_weight = Tensor(np.array(cfg.rcnn_loss_reg_weight).astype(np.float16))
self.rcnn_fc_out_channels = cfg.rcnn_fc_out_channels
self.target_means = target_means
self.target_stds = target_stds
self.num_classes = num_classes
self.in_channels = cfg.rcnn_in_channels
self.train_batch_size = batch_size
self.test_batch_size = cfg.test_batch_size
self.fpn_cls = FpnCls(self.in_channels, self.rcnn_fc_out_channels, self.num_classes, cfg.roi_layer["out_size"])
self.relu = P.ReLU()
self.logicaland = P.LogicalAnd()
self.loss_cls = P.SoftmaxCrossEntropyWithLogits()
self.loss_bbox = P.SmoothL1Loss(sigma=1.0)
self.loss_mask = P.SigmoidCrossEntropyWithLogits()
self.reshape = P.Reshape()
self.onehot = P.OneHot()
self.greater = P.Greater()
self.cast = P.Cast()
self.sum_loss = P.ReduceSum()
self.tile = P.Tile()
self.expandims = P.ExpandDims()
self.gather = P.GatherNd()
self.argmax = P.ArgMaxWithValue(axis=1)
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.value = Tensor(1.0, mstype.float16)
self.num_bboxes = (cfg.num_expected_pos_stage2 + cfg.num_expected_neg_stage2) * batch_size
rmv_first = np.ones((self.num_bboxes, self.num_classes))
rmv_first[:, 0] = np.zeros((self.num_bboxes,))
self.rmv_first_tensor = Tensor(rmv_first.astype(np.float16))
self.num_bboxes_test = cfg.rpn_max_num * cfg.test_batch_size
def construct(self, featuremap, bbox_targets, labels, mask):
x_cls, x_reg = self.fpn_cls(featuremap)
if self.training:
bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels
labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), mstype.float16)
bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1))
loss_cls, loss_reg = self.loss(x_cls, x_reg,
bbox_targets, bbox_weights,
labels,
mask)
out = (loss_cls, loss_reg)
else:
out = (x_cls, x_reg)
return out
def loss(self, cls_score, bbox_pred, bbox_targets, bbox_weights, labels, weights):
"""Loss method."""
# loss_cls
loss_cls, _ = self.loss_cls(cls_score, labels)
weights = self.cast(weights, mstype.float16)
loss_cls = loss_cls * weights
loss_cls = self.sum_loss(loss_cls, (0,)) / self.sum_loss(weights, (0,))
# loss_reg
bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value),
mstype.float16)
bbox_weights = bbox_weights * self.rmv_first_tensor  # zero out the background class (index 0)
pos_bbox_pred = self.reshape(bbox_pred, (self.num_bboxes, -1, 4))
loss_reg = self.loss_bbox(pos_bbox_pred, bbox_targets)
loss_reg = self.sum_loss(loss_reg, (2,))
loss_reg = loss_reg * bbox_weights
loss_reg = loss_reg / self.sum_loss(weights, (0,))
loss_reg = self.sum_loss(loss_reg, (0, 1))
return loss_cls, loss_reg
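The classification part of `loss` above is a softmax cross-entropy that is weighted per sample and normalised by the sum of the weights, so padded samples contribute nothing. A minimal NumPy sketch of that reduction follows. The function name is hypothetical, and it takes integer labels rather than the one-hot tensors used in the MindSpore code.

```python
import numpy as np


def weighted_softmax_ce(logits, labels, weights):
    """Weighted mean softmax cross-entropy.

    logits: (N, C) float array, labels: (N,) int array, weights: (N,) 0/1 array.
    Samples with weight 0 are excluded from both numerator and denominator.
    """
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    per_sample = -log_probs[np.arange(len(labels)), labels]
    return (per_sample * weights).sum() / weights.sum()


example_loss = weighted_softmax_ce(
    np.zeros((2, 4)),
    np.array([1, 2]),
    np.array([1.0, 0.0]),
)
```

With uniform logits over four classes, the single valid sample gives a loss of log(4), regardless of the masked-out second sample.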
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn Rcnn for mask network."""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
def _conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='pad'):
    """Conv2D wrapper."""
    shape = (out_channels, in_channels, kernel_size, kernel_size)
    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor()
    shape_bias = (out_channels,)
    bias = Tensor(np.zeros(shape_bias, dtype=np.float16))
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=bias)


def _convTanspose(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='pad'):
    """ConvTranspose wrapper."""
    shape = (out_channels, in_channels, kernel_size, kernel_size)
    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor()
    shape_bias = (out_channels,)
    bias = Tensor(np.zeros(shape_bias, dtype=np.float16))
    return nn.Conv2dTranspose(in_channels, out_channels,
                              kernel_size=kernel_size, stride=stride, padding=padding,
                              pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=bias)
class FpnMask(nn.Cell):
    """Conv layers of the mask head: four 3x3 convs, a 2x transposed conv, and a 1x1 conv producing one mask per class."""

    def __init__(self, input_channels, output_channels, num_classes):
        super(FpnMask, self).__init__()
        self.mask_conv1 = _conv(input_channels, output_channels, kernel_size=3, pad_mode="same")
        self.mask_relu1 = P.ReLU()

        self.mask_conv2 = _conv(output_channels, output_channels, kernel_size=3, pad_mode="same")
        self.mask_relu2 = P.ReLU()

        self.mask_conv3 = _conv(output_channels, output_channels, kernel_size=3, pad_mode="same")
        self.mask_relu3 = P.ReLU()

        self.mask_conv4 = _conv(output_channels, output_channels, kernel_size=3, pad_mode="same")
        self.mask_relu4 = P.ReLU()

        self.mask_deconv5 = _convTanspose(output_channels, output_channels, kernel_size=2, stride=2, pad_mode="valid")
        self.mask_relu5 = P.ReLU()
        self.mask_conv6 = _conv(output_channels, num_classes, kernel_size=1, stride=1, pad_mode="valid")

    def construct(self, x):
        x = self.mask_conv1(x)
        x = self.mask_relu1(x)

        x = self.mask_conv2(x)
        x = self.mask_relu2(x)

        x = self.mask_conv3(x)
        x = self.mask_relu3(x)

        x = self.mask_conv4(x)
        x = self.mask_relu4(x)

        x = self.mask_deconv5(x)
        x = self.mask_relu5(x)

        x = self.mask_conv6(x)

        return x
class RcnnMask(nn.Cell):
"""
Rcnn for mask subnet.
Args:
config (dict) - Config.
batch_size (int) - Batch size.
num_classes (int) - Class number.
target_means (tuple) - Means for encode function. Default: (0., 0., 0., 0.).
target_stds (tuple) - Stds for encode function. Default: (0.1, 0.1, 0.2, 0.2).
Returns:
Tensor, the mask loss when training, otherwise the predicted mask logits.
Examples:
RcnnMask(config=config, batch_size=2, num_classes=81, \
target_means=(0., 0., 0., 0.), target_stds=(0.1, 0.1, 0.2, 0.2))
"""
def __init__(self,
config,
batch_size,
num_classes,
target_means=(0., 0., 0., 0.),
target_stds=(0.1, 0.1, 0.2, 0.2)
):
super(RcnnMask, self).__init__()
cfg = config
self.rcnn_loss_mask_fb_weight = Tensor(np.array(cfg.rcnn_loss_mask_fb_weight).astype(np.float16))
self.rcnn_mask_out_channels = cfg.rcnn_mask_out_channels
self.target_means = target_means
self.target_stds = target_stds
self.num_classes = num_classes
self.in_channels = cfg.rcnn_in_channels
self.fpn_mask = FpnMask(self.in_channels, self.rcnn_mask_out_channels, self.num_classes)
self.logicaland = P.LogicalAnd()
self.loss_mask = P.SigmoidCrossEntropyWithLogits()
self.onehot = P.OneHot()
self.greater = P.Greater()
self.cast = P.Cast()
self.sum_loss = P.ReduceSum()
self.tile = P.Tile()
self.expandims = P.ExpandDims()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.num_bboxes = cfg.num_expected_pos_stage2 * batch_size
rmv_first = np.ones((self.num_bboxes, self.num_classes))
rmv_first[:, 0] = np.zeros((self.num_bboxes,))
self.rmv_first_tensor = Tensor(rmv_first.astype(np.float16))
self.mean_loss = P.ReduceMean()
def construct(self, mask_featuremap, labels=None, mask=None, mask_fb_targets=None):
x_mask_fb = self.fpn_mask(mask_featuremap)
if self.training:
bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels
mask_fb_targets = self.tile(self.expandims(mask_fb_targets, 1), (1, self.num_classes, 1, 1))
loss_mask_fb = self.loss(x_mask_fb, bbox_weights, mask, mask_fb_targets)
out = loss_mask_fb
else:
out = x_mask_fb
return out
def loss(self, masks_fb_pred, bbox_weights, weights, masks_fb_targets):
"""Loss method."""
weights = self.cast(weights, mstype.float16)
bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value),
mstype.float16)
bbox_weights = bbox_weights * self.rmv_first_tensor  # zero out the background class (index 0)
# loss_mask_fb
masks_fb_targets = self.cast(masks_fb_targets, mstype.float16)
loss_mask_fb = self.loss_mask(masks_fb_pred, masks_fb_targets)
loss_mask_fb = self.mean_loss(loss_mask_fb, (2, 3))
loss_mask_fb = loss_mask_fb * bbox_weights
loss_mask_fb = loss_mask_fb / self.sum_loss(weights, (0,))
loss_mask_fb = self.sum_loss(loss_mask_fb, (0, 1))
return loss_mask_fb
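The mask loss above computes a per-pixel sigmoid cross-entropy for every class channel, averages it over the mask, keeps only the channel of each ROI's foreground label, and divides by the number of valid ROIs. The NumPy sketch below follows those steps under simplifying assumptions: integer labels, float 0/1 weights, and a hypothetical function name `mask_loss`.

```python
import numpy as np


def mask_loss(logits, targets, labels, weights):
    """Class-selected sigmoid cross-entropy for mask prediction.

    logits: (N, C, H, W) mask logits, one channel per class.
    targets: (N, H, W) binary ground-truth masks.
    labels: (N,) int class labels, 0 meaning background.
    weights: (N,) 0/1 validity flags.
    """
    n, c = logits.shape[:2]
    t = np.broadcast_to(targets[:, None], logits.shape)
    # Numerically stable form of -t*log(sigmoid(x)) - (1-t)*log(1-sigmoid(x)).
    bce = np.maximum(logits, 0) - logits * t + np.log1p(np.exp(-np.abs(logits)))
    per_class = bce.mean(axis=(2, 3))  # (N, C)
    # One-hot selection of the labelled class; background and invalid ROIs select nothing.
    select = np.zeros((n, c))
    fg = (labels > 0) & (weights > 0)
    select[np.arange(n)[fg], labels[fg]] = 1.0
    return (per_class * select).sum() / weights.sum()


example_mask_loss = mask_loss(
    np.zeros((2, 3, 2, 2)),
    np.ones((2, 2, 2)),
    np.array([1, 0]),
    np.array([1.0, 1.0]),
)
```

In the example, only the first ROI has a foreground label, so the result is log(2) divided by the two valid ROIs.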
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Resnet50 backbone."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
def weight_init_ones(shape):
    """Constant weight init: every element is 0.01."""
    return Tensor((np.ones(shape, dtype=np.float32) * 0.01).astype(np.float16))


def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
    """Conv2D wrapper."""
    shape = (out_channels, in_channels, kernel_size, kernel_size)
    weights = weight_init_ones(shape)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     pad_mode=pad_mode, weight_init=weights, has_bias=False)


def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=True):
    """Batchnorm2D wrapper."""
    gamma_init = Tensor(np.ones(out_chls).astype(np.float16))
    beta_init = Tensor(np.zeros(out_chls).astype(np.float16))
    moving_mean_init = Tensor(np.zeros(out_chls).astype(np.float16))
    moving_var_init = Tensor(np.ones(out_chls).astype(np.float16))
    return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init,
                          beta_init=beta_init, moving_mean_init=moving_mean_init,
                          moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics)
class ResNetFea(nn.Cell):
"""
ResNet architecture.
Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
weights_update (bool): Weight update flag.
Returns:
Tuple, the feature maps (c2, c3, c4, c5) produced by the four residual stages.
Examples:
>>> ResNetFea(ResidualBlockUsing,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> False)
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
weights_update=False):
super(ResNetFea, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("layer_nums, in_channels and out_channels "
"must each have length 4!")
bn_training = False
self.conv1 = _conv(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
self.bn1 = _BatchNorm2dInit(64, affine=bn_training, use_batch_statistics=bn_training)
self.relu = P.ReLU()
self.maxpool = P.MaxPool(ksize=3, strides=2, padding="SAME")
self.weights_update = weights_update
if not self.weights_update:
self.conv1.weight.requires_grad = False
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=1,
training=bn_training,
weights_update=self.weights_update)
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=2,
training=bn_training,
weights_update=True)
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=2,
training=bn_training,
weights_update=True)
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=2,
training=bn_training,
weights_update=True)
def _make_layer(self, block, layer_num, in_channel, out_channel, stride, training=False, weights_update=False):
"""Make block layer."""
layers = []
down_sample = False
if stride != 1 or in_channel != out_channel:
down_sample = True
resblk = block(in_channel,
out_channel,
stride=stride,
down_sample=down_sample,
training=training,
weights_update=weights_update)
layers.append(resblk)
for _ in range(1, layer_num):
resblk = block(out_channel, out_channel, stride=1, training=training, weights_update=weights_update)
layers.append(resblk)
return nn.SequentialCell(layers)
def construct(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
c1 = self.maxpool(x)
c2 = self.layer1(c1)
identity = c2
if not self.weights_update:
identity = F.stop_gradient(c2)
c3 = self.layer2(identity)
c4 = self.layer3(c3)
c5 = self.layer4(c4)
return identity, c3, c4, c5
class ResidualBlockUsing(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channels (int) - Input channel.
out_channels (int) - Output channel.
stride (int) - Stride size for the initial convolutional layer. Default: 1.
down_sample (bool) - If to do the downsample in block. Default: False.
momentum (float) - Momentum for batchnorm layer. Default: 0.1.
training (bool) - Training flag. Default: False.
weights_update (bool) - Weights update flag. Default: False.
Returns:
Tensor, output tensor.
Examples:
ResidualBlockUsing(3, 256, stride=2, down_sample=True)
"""
expansion = 4
def __init__(self,
in_channels,
out_channels,
stride=1,
down_sample=False,
momentum=0.1,
training=False,
weights_update=False):
super(ResidualBlockUsing, self).__init__()
self.affine = weights_update
out_chls = out_channels // self.expansion
self.conv1 = _conv(in_channels, out_chls, kernel_size=1, stride=1, padding=0)
self.bn1 = _BatchNorm2dInit(out_chls, momentum=momentum, affine=self.affine, use_batch_statistics=training)
self.conv2 = _conv(out_chls, out_chls, kernel_size=3, stride=stride, padding=1)
self.bn2 = _BatchNorm2dInit(out_chls, momentum=momentum, affine=self.affine, use_batch_statistics=training)
self.conv3 = _conv(out_chls, out_channels, kernel_size=1, stride=1, padding=0)
self.bn3 = _BatchNorm2dInit(out_channels, momentum=momentum, affine=self.affine, use_batch_statistics=training)
if training:
self.bn1 = self.bn1.set_train()
self.bn2 = self.bn2.set_train()
self.bn3 = self.bn3.set_train()
if not weights_update:
self.conv1.weight.requires_grad = False
self.conv2.weight.requires_grad = False
self.conv3.weight.requires_grad = False
self.relu = P.ReLU()
self.downsample = down_sample
if self.downsample:
self.conv_down_sample = _conv(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)
self.bn_down_sample = _BatchNorm2dInit(out_channels, momentum=momentum, affine=self.affine,
use_batch_statistics=training)
if training:
self.bn_down_sample = self.bn_down_sample.set_train()
if not weights_update:
self.conv_down_sample.weight.requires_grad = False
self.add = P.TensorAdd()
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample:
identity = self.conv_down_sample(identity)
identity = self.bn_down_sample(identity)
out = self.add(out, identity)
out = self.relu(out)
return out
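The layer bookkeeping done by `_make_layer` can be sketched without MindSpore. The helper `plan_stage` below is a hypothetical name introduced only for illustration and is not part of the repository. It assumes the rule visible in the code above: only the first block of a stage may change the stride or the channel count, and the bottleneck width is `out_channel // expansion`.

```python
def plan_stage(layer_num, in_channel, out_channel, stride, expansion=4):
    """Describe the blocks of one ResNet stage as plain dictionaries."""
    blocks = []
    for i in range(layer_num):
        first = i == 0
        blocks.append({
            # Only the first block sees the stage input channels and stride.
            "in": in_channel if first else out_channel,
            "mid": out_channel // expansion,
            "out": out_channel,
            "stride": stride if first else 1,
            # A projection shortcut is needed when the shape changes.
            "down_sample": first and (stride != 1 or in_channel != out_channel),
        })
    return blocks


# Values copied from the resnet_* entries of the example config.
layer_nums = [3, 4, 6, 3]
in_channels = [64, 256, 512, 1024]
out_channels = [256, 512, 1024, 2048]
strides = [1, 2, 2, 2]

stages = [plan_stage(n, i, o, s)
          for n, i, o, s in zip(layer_nums, in_channels, out_channels, strides)]
for idx, stage in enumerate(stages, start=1):
    print("layer%d: %d blocks, first block %s" % (idx, len(stage), stage[0]))
```

Note that the first stage keeps stride 1 but still needs the projection shortcut, because its channel count grows from 64 to 256.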
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn ROIAlign module."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.nn import layer as L
from mindspore.common.tensor import Tensor
class ROIAlign(nn.Cell):
"""
Extract RoI features from multiple feature maps.
Args:
out_size_h (int) - RoI height.
out_size_w (int) - RoI width.
spatial_scale (float) - RoI spatial scale.
sample_num (int) - RoI sample number.
roi_align_mode (int) - RoI align mode.
"""
def __init__(self,
out_size_h,
out_size_w,
spatial_scale,
sample_num=0,
roi_align_mode=1):
super(ROIAlign, self).__init__()
self.out_size = (out_size_h, out_size_w)
self.spatial_scale = float(spatial_scale)
self.sample_num = int(sample_num)
self.align_op = P.ROIAlign(self.out_size[0], self.out_size[1],
self.spatial_scale, self.sample_num, roi_align_mode)
def construct(self, features, rois):
return self.align_op(features, rois)
def __repr__(self):
format_str = self.__class__.__name__
format_str += '(out_size={}, spatial_scale={}, sample_num={})'.format(
self.out_size, self.spatial_scale, self.sample_num)
return format_str
class SingleRoIExtractor(nn.Cell):
"""
Extract RoI features from a single level feature map.
If there are multiple input feature levels, each RoI is mapped to a level
according to its scale.
Args:
config (dict): Config.
roi_layer (dict): Specify RoI layer type and arguments.
out_channels (int): Output channels of RoI layers.
featmap_strides (list): Strides of input feature maps.
batch_size (int): Batch size.
finest_scale (int): Scale threshold of mapping to level 0.
mask (bool): Whether the RoIAlign output feeds the mask branch (True) or the cls branch (False).
"""
def __init__(self,
config,
roi_layer,
out_channels,
featmap_strides,
batch_size=1,
finest_scale=56,
mask=False):
super(SingleRoIExtractor, self).__init__()
cfg = config
self.train_batch_size = batch_size
self.out_channels = out_channels
self.featmap_strides = featmap_strides
self.num_levels = len(self.featmap_strides)
self.out_size = roi_layer['mask_out_size'] if mask else roi_layer['out_size']
self.mask = mask
self.sample_num = roi_layer['sample_num']
self.roi_layers = self.build_roi_layers(self.featmap_strides)
self.roi_layers = L.CellList(self.roi_layers)
self.sqrt = P.Sqrt()
self.log = P.Log()
self.finest_scale_ = finest_scale
self.clamp = C.clip_by_value
self.cast = P.Cast()
self.equal = P.Equal()
self.select = P.Select()
_mode_16 = False
self.dtype = np.float16 if _mode_16 else np.float32
self.ms_dtype = mstype.float16 if _mode_16 else mstype.float32
self.set_train_local(cfg, training=True)
def set_train_local(self, config, training=True):
"""Set training flag."""
self.training_local = training
cfg = config
# Init tensor
roi_sample_num = cfg.num_expected_pos_stage2 if self.mask else cfg.roi_sample_num
self.batch_size = roi_sample_num if self.training_local else cfg.rpn_max_num
self.batch_size = self.train_batch_size*self.batch_size \
if self.training_local else cfg.test_batch_size*self.batch_size
self.ones = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype))
finest_scale = np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * self.finest_scale_
self.finest_scale = Tensor(finest_scale)
self.epslion = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype)*self.dtype(1e-6))
self.zeros = Tensor(np.array(np.zeros((self.batch_size, 1)), dtype=np.int32))
self.max_levels = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=np.int32)*(self.num_levels-1))
self.twos = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * 2)
self.res_ = Tensor(np.array(np.zeros((self.batch_size, self.out_channels,
self.out_size, self.out_size)), dtype=self.dtype))
def num_inputs(self):
return len(self.featmap_strides)
def init_weights(self):
pass
def log2(self, value):
return self.log(value) / self.log(self.twos)
def build_roi_layers(self, featmap_strides):
roi_layers = []
for s in featmap_strides:
layer_cls = ROIAlign(self.out_size, self.out_size,
spatial_scale=1 / s,
sample_num=self.sample_num,
roi_align_mode=0)
roi_layers.append(layer_cls)
return roi_layers
def _c_map_roi_levels(self, rois):
"""Map rois to corresponding feature levels by scales.
- scale < finest_scale * 2: level 0
- finest_scale * 2 <= scale < finest_scale * 4: level 1
- finest_scale * 4 <= scale < finest_scale * 8: level 2
- scale >= finest_scale * 8: level 3
Args:
rois (Tensor): Input RoIs, shape (k, 5).
Returns:
Tensor: Level index (0-based) of each RoI, shape (k, 1).
"""
scale = self.sqrt(rois[::, 3:4:1] - rois[::, 1:2:1] + self.ones) * \
self.sqrt(rois[::, 4:5:1] - rois[::, 2:3:1] + self.ones)
target_lvls = self.log2(scale / self.finest_scale + self.epslion)
target_lvls = P.Floor()(target_lvls)
target_lvls = self.cast(target_lvls, mstype.int32)
target_lvls = self.clamp(target_lvls, self.zeros, self.max_levels)
return target_lvls
def construct(self, rois, feat1, feat2, feat3, feat4):
feats = (feat1, feat2, feat3, feat4)
res = self.res_
target_lvls = self._c_map_roi_levels(rois)
for i in range(self.num_levels):
mask = self.equal(target_lvls, P.ScalarToArray()(i))
mask = P.Reshape()(mask, (-1, 1, 1, 1))
roi_feats_t = self.roi_layers[i](feats[i], rois)
mask = self.cast(P.Tile()(self.cast(mask, mstype.int32), (1, self.out_channels, self.out_size, self.out_size)),
mstype.bool_)
res = self.select(mask, roi_feats_t, res)
return res
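The level-mapping rule in `_c_map_roi_levels` can be checked on single boxes with plain Python. This is a minimal sketch, not the repository's implementation: `map_roi_to_level` is a hypothetical helper, and it assumes the RoI columns are `(batch_index, x1, y1, x2, y2)`, which is what the slicing in that method suggests.

```python
import math


def map_roi_to_level(x1, y1, x2, y2, finest_scale=56, num_levels=4):
    """Return the 0-based pyramid level for one box."""
    # Geometric mean of the box sides, using the same +1 pixel convention.
    scale = math.sqrt(x2 - x1 + 1) * math.sqrt(y2 - y1 + 1)
    # The small epsilon keeps the logarithm defined and absorbs rounding
    # error exactly at a level boundary.
    level = math.floor(math.log2(scale / finest_scale + 1e-6))
    # Clamp to the levels that actually exist.
    return min(max(level, 0), num_levels - 1)


boxes = [(0, 0, 31, 31), (0, 0, 111, 111), (0, 0, 223, 223), (0, 0, 999, 999)]
levels = [map_roi_to_level(*box) for box in boxes]
print(levels)
```

Small boxes land on the finest feature map and very large boxes are clamped to the coarsest one, so every RoI is pooled from exactly one level.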
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RPN for MaskRCNN"""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.ops import functional as F
from mindspore.common.initializer import initializer
from .bbox_assign_sample import BboxAssignSample
class RpnRegClsBlock(nn.Cell):
"""
Rpn reg cls block for rpn layer
Args:
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.
weight_conv (Tensor) - weight init for rpn conv.
bias_conv (Tensor) - bias init for rpn conv.
weight_cls (Tensor) - weight init for rpn cls conv.
bias_cls (Tensor) - bias init for rpn cls conv.
weight_reg (Tensor) - weight init for rpn reg conv.
bias_reg (Tensor) - bias init for rpn reg conv.
Returns:
Tensor, output tensor.
"""
def __init__(self,
in_channels,
feat_channels,
num_anchors,
cls_out_channels,
weight_conv,
bias_conv,
weight_cls,
bias_cls,
weight_reg,
bias_reg):
super(RpnRegClsBlock, self).__init__()
self.rpn_conv = nn.Conv2d(in_channels, feat_channels, kernel_size=3, stride=1, pad_mode='same',
has_bias=True, weight_init=weight_conv, bias_init=bias_conv)
self.relu = nn.ReLU()
self.rpn_cls = nn.Conv2d(feat_channels, num_anchors * cls_out_channels, kernel_size=1, pad_mode='valid',
has_bias=True, weight_init=weight_cls, bias_init=bias_cls)
self.rpn_reg = nn.Conv2d(feat_channels, num_anchors * 4, kernel_size=1, pad_mode='valid',
has_bias=True, weight_init=weight_reg, bias_init=bias_reg)
def construct(self, x):
x = self.relu(self.rpn_conv(x))
x1 = self.rpn_cls(x)
x2 = self.rpn_reg(x)
return x1, x2
class RPN(nn.Cell):
"""
Region proposal network (RPN).
Args:
config (dict) - Config.
batch_size (int) - Batchsize.
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.
Returns:
Tuple, tuple of output tensor.
Examples:
RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024,
num_anchors=3, cls_out_channels=512)
"""
def __init__(self,
config,
batch_size,
in_channels,
feat_channels,
num_anchors,
cls_out_channels):
super(RPN, self).__init__()
cfg_rpn = config
self.num_bboxes = cfg_rpn.num_bboxes
self.slice_index = ()
self.feature_anchor_shape = ()
self.slice_index += (0,)
index = 0
for shape in cfg_rpn.feature_shapes:
self.slice_index += (self.slice_index[index] + shape[0] * shape[1] * num_anchors,)
self.feature_anchor_shape += (shape[0] * shape[1] * num_anchors * batch_size,)
index += 1
self.num_anchors = num_anchors
self.batch_size = batch_size
self.test_batch_size = cfg_rpn.test_batch_size
self.num_layers = 5
self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16))
self.rpn_convs_list = nn.layer.CellList(self._make_rpn_layer(self.num_layers, in_channels, feat_channels,
num_anchors, cls_out_channels))
self.transpose = P.Transpose()
self.reshape = P.Reshape()
self.concat = P.Concat(axis=0)
self.fill = P.Fill()
self.placeh1 = Tensor(np.ones((1,)).astype(np.float16))
self.trans_shape = (0, 2, 3, 1)
self.reshape_shape_reg = (-1, 4)
self.reshape_shape_cls = (-1,)
self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16))
self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16))
self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16))
self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False)
self.CheckValid = P.CheckValid()
self.sum_loss = P.ReduceSum()
self.loss_cls = P.SigmoidCrossEntropyWithLogits()
self.loss_bbox = P.SmoothL1Loss(sigma=1.0/9.0)
self.squeeze = P.Squeeze()
self.cast = P.Cast()
self.tile = P.Tile()
self.zeros_like = P.ZerosLike()
self.loss = Tensor(np.zeros((1,)).astype(np.float16))
self.clsloss = Tensor(np.zeros((1,)).astype(np.float16))
self.regloss = Tensor(np.zeros((1,)).astype(np.float16))
def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels):
"""
make rpn layer for rpn proposal network
Args:
num_layers (int) - layer num.
in_channels (int) - Input channels of shared convolution.
feat_channels (int) - Output channels of shared convolution.
num_anchors (int) - The anchor number.
cls_out_channels (int) - Output channels of classification convolution.
Returns:
List, list of RpnRegClsBlock cells.
"""
rpn_layer = []
shp_weight_conv = (feat_channels, in_channels, 3, 3)
shp_bias_conv = (feat_channels,)
weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16).to_tensor()
bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16).to_tensor()
shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1)
shp_bias_cls = (num_anchors * cls_out_channels,)
weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16).to_tensor()
bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16).to_tensor()
shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1)
shp_bias_reg = (num_anchors * 4,)
weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16).to_tensor()
bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16).to_tensor()
for _ in range(num_layers):
rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels,
weight_conv, bias_conv, weight_cls,
bias_cls, weight_reg, bias_reg))
for i in range(1, num_layers):
rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight
rpn_layer[i].rpn_cls.weight = rpn_layer[0].rpn_cls.weight
rpn_layer[i].rpn_reg.weight = rpn_layer[0].rpn_reg.weight
rpn_layer[i].rpn_conv.bias = rpn_layer[0].rpn_conv.bias
rpn_layer[i].rpn_cls.bias = rpn_layer[0].rpn_cls.bias
rpn_layer[i].rpn_reg.bias = rpn_layer[0].rpn_reg.bias
return rpn_layer
def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids):
loss_print = ()
rpn_cls_score = ()
rpn_bbox_pred = ()
rpn_cls_score_total = ()
rpn_bbox_pred_total = ()
for i in range(self.num_layers):
x1, x2 = self.rpn_convs_list[i](inputs[i])
rpn_cls_score_total = rpn_cls_score_total + (x1,)
rpn_bbox_pred_total = rpn_bbox_pred_total + (x2,)
x1 = self.transpose(x1, self.trans_shape)
x1 = self.reshape(x1, self.reshape_shape_cls)
x2 = self.transpose(x2, self.trans_shape)
x2 = self.reshape(x2, self.reshape_shape_reg)
rpn_cls_score = rpn_cls_score + (x1,)
rpn_bbox_pred = rpn_bbox_pred + (x2,)
loss = self.loss
clsloss = self.clsloss
regloss = self.regloss
bbox_targets = ()
bbox_weights = ()
labels = ()
label_weights = ()
output = ()
if self.training:
for i in range(self.batch_size):
multi_level_flags = ()
anchor_list_tuple = ()
for j in range(self.num_layers):
res = self.cast(self.CheckValid(anchor_list[j], self.squeeze(img_metas[i:i + 1:1, ::])),
mstype.int32)
multi_level_flags = multi_level_flags + (res,)
anchor_list_tuple = anchor_list_tuple + (anchor_list[j],)
valid_flag_list = self.concat(multi_level_flags)
anchor_using_list = self.concat(anchor_list_tuple)
gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i,
gt_labels_i,
self.cast(valid_flag_list,
mstype.bool_),
anchor_using_list, gt_valids_i)
bbox_weight = self.cast(bbox_weight, mstype.float16)
label = self.cast(label, mstype.float16)
label_weight = self.cast(label_weight, mstype.float16)
for j in range(self.num_layers):
begin = self.slice_index[j]
end = self.slice_index[j + 1]
stride = 1
bbox_targets += (bbox_target[begin:end:stride, ::],)
bbox_weights += (bbox_weight[begin:end:stride],)
labels += (label[begin:end:stride],)
label_weights += (label_weight[begin:end:stride],)
for i in range(self.num_layers):
bbox_target_using = ()
bbox_weight_using = ()
label_using = ()
label_weight_using = ()
for j in range(self.batch_size):
bbox_target_using += (bbox_targets[i + (self.num_layers * j)],)
bbox_weight_using += (bbox_weights[i + (self.num_layers * j)],)
label_using += (labels[i + (self.num_layers * j)],)
label_weight_using += (label_weights[i + (self.num_layers * j)],)
bbox_target_with_batchsize = self.concat(bbox_target_using)
bbox_weight_with_batchsize = self.concat(bbox_weight_using)
label_with_batchsize = self.concat(label_using)
label_weight_with_batchsize = self.concat(label_weight_using)
# Targets are treated as constants: stop gradients from flowing into them.
bbox_target_ = F.stop_gradient(bbox_target_with_batchsize)
bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize)
label_ = F.stop_gradient(label_with_batchsize)
label_weight_ = F.stop_gradient(label_weight_with_batchsize)
cls_score_i = rpn_cls_score[i]
reg_score_i = rpn_bbox_pred[i]
loss_cls = self.loss_cls(cls_score_i, label_)
loss_cls_item = loss_cls * label_weight_
loss_cls_item = self.sum_loss(loss_cls_item, (0,)) / self.num_expected_total
loss_reg = self.loss_bbox(reg_score_i, bbox_target_)
bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape[i], 1)), (1, 4))
loss_reg = loss_reg * bbox_weight_
loss_reg_item = self.sum_loss(loss_reg, (1,))
loss_reg_item = self.sum_loss(loss_reg_item, (0,)) / self.num_expected_total
loss_total = self.rpn_loss_cls_weight * loss_cls_item + self.rpn_loss_reg_weight * loss_reg_item
loss += loss_total
loss_print += (loss_total, loss_cls_item, loss_reg_item)
clsloss += loss_cls_item
regloss += loss_reg_item
output = (loss, rpn_cls_score_total, rpn_bbox_pred_total, clsloss, regloss, loss_print)
else:
output = (self.placeh1, rpn_cls_score_total, rpn_bbox_pred_total, self.placeh1, self.placeh1, self.placeh1)
return output
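The `slice_index` bookkeeping in `RPN.__init__` is a running total of anchors per feature level, later used to cut the flat per-image targets back into per-level pieces. The sketch below illustrates that idea on a toy pyramid; `anchor_slice_index` and `split_by_level` are hypothetical helper names, and plain lists stand in for tensors.

```python
def anchor_slice_index(feature_shapes, num_anchors):
    """Cumulative anchor counts: entry i is where level i starts."""
    index = [0]
    for height, width in feature_shapes:
        index.append(index[-1] + height * width * num_anchors)
    return index


def split_by_level(flat_values, slice_index):
    """Cut a flat per-anchor sequence into one chunk per feature level."""
    return [flat_values[begin:end]
            for begin, end in zip(slice_index[:-1], slice_index[1:])]


# Toy pyramid: a 2x2 map and a 1x1 map with 3 anchors per location.
toy_shapes = [(2, 2), (1, 1)]
toy_index = anchor_slice_index(toy_shapes, num_anchors=3)
toy_targets = list(range(toy_index[-1]))
per_level = split_by_level(toy_targets, toy_index)
print(toy_index, [len(chunk) for chunk in per_level])
```

Because the chunks are produced image by image, the loss loop then regroups them by level before concatenating across the batch.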
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Network config settings, used by train.py and eval.py.
"""
from easydict import EasyDict as ed
config = ed({
"img_width": 1280,
"img_height": 768,
"keep_ratio": False,
"flip_ratio": 0.5,
"photo_ratio": 0.5,
"expand_ratio": 1.0,
"max_instance_count": 128,
"mask_shape": (28, 28),
# anchor
"feature_shapes": [(192, 320), (96, 160), (48, 80), (24, 40), (12, 20)],
"anchor_scales": [8],
"anchor_ratios": [0.5, 1.0, 2.0],
"anchor_strides": [4, 8, 16, 32, 64],
"num_anchors": 3,
# resnet
"resnet_block": [3, 4, 6, 3],
"resnet_in_channels": [64, 256, 512, 1024],
"resnet_out_channels": [256, 512, 1024, 2048],
# fpn
"fpn_in_channels": [256, 512, 1024, 2048],
"fpn_out_channels": 256,
"fpn_num_outs": 5,
# rpn
"rpn_in_channels": 256,
"rpn_feat_channels": 256,
"rpn_loss_cls_weight": 1.0,
"rpn_loss_reg_weight": 1.0,
"rpn_cls_out_channels": 1,
"rpn_target_means": [0., 0., 0., 0.],
"rpn_target_stds": [1.0, 1.0, 1.0, 1.0],
# bbox_assign_sampler
"neg_iou_thr": 0.3,
"pos_iou_thr": 0.7,
"min_pos_iou": 0.3,
"num_bboxes": 245520,
"num_gts": 128,
"num_expected_neg": 256,
"num_expected_pos": 128,
# proposal
"activate_num_classes": 2,
"use_sigmoid_cls": True,
# roi_align
"roi_layer": dict(type='RoIAlign', out_size=7, mask_out_size=14, sample_num=2),
"roi_align_out_channels": 256,
"roi_align_featmap_strides": [4, 8, 16, 32],
"roi_align_finest_scale": 56,
"roi_sample_num": 640,
# bbox_assign_sampler_stage2
"neg_iou_thr_stage2": 0.5,
"pos_iou_thr_stage2": 0.5,
"min_pos_iou_stage2": 0.5,
"num_bboxes_stage2": 2000,
"num_expected_pos_stage2": 128,
"num_expected_neg_stage2": 512,
"num_expected_total_stage2": 512,
# rcnn
"rcnn_num_layers": 2,
"rcnn_in_channels": 256,
"rcnn_fc_out_channels": 1024,
"rcnn_mask_out_channels": 256,
"rcnn_loss_cls_weight": 1,
"rcnn_loss_reg_weight": 1,
"rcnn_loss_mask_fb_weight": 1,
"rcnn_target_means": [0., 0., 0., 0.],
"rcnn_target_stds": [0.1, 0.1, 0.2, 0.2],
# train proposal
"rpn_proposal_nms_across_levels": False,
"rpn_proposal_nms_pre": 2000,
"rpn_proposal_nms_post": 2000,
"rpn_proposal_max_num": 2000,
"rpn_proposal_nms_thr": 0.7,
"rpn_proposal_min_bbox_size": 0,
# test proposal
"rpn_nms_across_levels": False,
"rpn_nms_pre": 1000,
"rpn_nms_post": 1000,
"rpn_max_num": 1000,
"rpn_nms_thr": 0.7,
"rpn_min_bbox_min_size": 0,
"test_score_thr": 0.05,
"test_iou_thr": 0.5,
"test_max_per_img": 100,
"test_batch_size": 2,
"rpn_head_loss_type": "CrossEntropyLoss",
"rpn_head_use_sigmoid": True,
"rpn_head_weight": 1.0,
"mask_thr_binary": 0.5,
# LR
"base_lr": 0.02,
"base_step": 58633,
"total_epoch": 13,
"warmup_step": 500,
"warmup_mode": "linear",
"warmup_ratio": 1/3.0,
"sgd_step": [8, 11],
"sgd_momentum": 0.9,
# train
"batch_size": 2,
"loss_scale": 1,
"momentum": 0.91,
"weight_decay": 1e-4,
"epoch_size": 12,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./checkpoint",
"mindrecord_dir": "/home/mxw/mask_rcnn/scripts/MindRecord_COCO2017_Train",
"coco_root": "/home/mxw/coco2017/",
"train_data_type": "train2017",
"val_data_type": "val2017",
"instance_set": "annotations/instances_{}.json",
"coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush'),
"num_classes": 81
})
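Several anchor-related entries in this config depend on each other, so a quick consistency check can help when changing the image size. The sketch below assumes that each feature map is the input size divided by its anchor stride and that every location carries `num_anchors` anchors; `feature_shapes_for` is a hypothetical helper, not part of the repository.

```python
def feature_shapes_for(img_height, img_width, strides):
    """Feature-map (height, width) for each stride, using integer division."""
    return [(img_height // s, img_width // s) for s in strides]


# Values copied from the config above.
img_height, img_width = 768, 1280
anchor_strides = [4, 8, 16, 32, 64]
num_anchors = 3

shapes = feature_shapes_for(img_height, img_width, anchor_strides)
total_anchors = sum(h * w * num_anchors for h, w in shapes)
print(shapes, total_anchors)
```

With the values above, the derived shapes match `feature_shapes` and the anchor total matches `num_bboxes`, so all three entries need to be updated together if `img_width` or `img_height` changes.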
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr generator for maskrcnn"""
import math
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
"""Linearly ramp the learning rate from init_lr towards base_lr over warmup_steps."""
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
learning_rate = float(init_lr) + lr_inc * current_step
return learning_rate
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
"""Cosine-decay the learning rate from base_lr, counting steps after warmup."""
base = float(current_step - warmup_steps) / float(decay_steps)
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
return learning_rate
def dynamic_lr(config, rank_size=1):
"""dynamic learning rate generator"""
base_lr = config.base_lr
base_step = (config.base_step // rank_size) + rank_size
total_steps = int(base_step * config.total_epoch)
warmup_steps = int(config.warmup_step)
lr = []
for i in range(total_steps):
if i < warmup_steps:
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
else:
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
return lr
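The schedule produced by `dynamic_lr` can be reproduced in a few lines to inspect its shape. This is a standalone sketch under stated assumptions: `warmup_cosine_schedule` is a hypothetical name, and the step counts in the demo are small illustrative values rather than the ones from the config.

```python
import math


def warmup_cosine_schedule(base_lr, total_steps, warmup_steps, warmup_ratio):
    """Per-step learning rates: linear warmup followed by cosine decay."""
    init_lr = base_lr * warmup_ratio
    schedule = []
    for step in range(total_steps):
        if step < warmup_steps:
            # Linear ramp from init_lr towards base_lr.
            lr = init_lr + (base_lr - init_lr) * step / warmup_steps
        else:
            # Cosine decay; the denominator is total_steps, as in dynamic_lr,
            # so the rate does not quite reach zero at the last step.
            progress = (step - warmup_steps) / total_steps
            lr = (1 + math.cos(progress * math.pi)) / 2 * base_lr
        schedule.append(lr)
    return schedule


demo_lr = warmup_cosine_schedule(base_lr=0.02, total_steps=1000,
                                 warmup_steps=100, warmup_ratio=1 / 3.0)
print(demo_lr[0], demo_lr[100], demo_lr[-1])
```

The rate starts at `base_lr * warmup_ratio`, peaks at `base_lr` on the first step after warmup, and decays monotonically from there.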
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaskRcnn training network wrapper."""
import time
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import ParameterTuple
from mindspore.train.callback import Callback
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
time_stamp_init = False
time_stamp_first = 0
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF, terminate training.
Note:
If per_print_times is 0, do not print the loss.
Args:
per_print_times (int): Print the loss every per_print_times steps. Default: 1.
"""
def __init__(self, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("per_print_times must be int and >= 0.")
self._per_print_times = per_print_times
self.count = 0
self.rpn_loss_sum = 0
self.rcnn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rcnn_cls_loss_sum = 0
self.rcnn_reg_loss_sum = 0
self.rcnn_mask_loss_sum = 0
global time_stamp_init, time_stamp_first
if not time_stamp_init:
time_stamp_first = time.time()
time_stamp_init = True
def step_end(self, run_context):
cb_params = run_context.original_args()
rpn_loss = cb_params.net_outputs[0].asnumpy()
rcnn_loss = cb_params.net_outputs[1].asnumpy()
rpn_cls_loss = cb_params.net_outputs[2].asnumpy()
rpn_reg_loss = cb_params.net_outputs[3].asnumpy()
rcnn_cls_loss = cb_params.net_outputs[4].asnumpy()
rcnn_reg_loss = cb_params.net_outputs[5].asnumpy()
rcnn_mask_loss = cb_params.net_outputs[6].asnumpy()
self.count += 1
self.rpn_loss_sum += float(rpn_loss)
self.rcnn_loss_sum += float(rcnn_loss)
self.rpn_cls_loss_sum += float(rpn_cls_loss)
self.rpn_reg_loss_sum += float(rpn_reg_loss)
self.rcnn_cls_loss_sum += float(rcnn_cls_loss)
self.rcnn_reg_loss_sum += float(rcnn_reg_loss)
self.rcnn_mask_loss_sum += float(rcnn_mask_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
if self._per_print_times != 0 and self.count >= self._per_print_times:
global time_stamp_first
time_stamp_current = time.time()
rpn_loss = self.rpn_loss_sum/self.count
rcnn_loss = self.rcnn_loss_sum/self.count
rpn_cls_loss = self.rpn_cls_loss_sum/self.count
rpn_reg_loss = self.rpn_reg_loss_sum/self.count
rcnn_cls_loss = self.rcnn_cls_loss_sum/self.count
rcnn_reg_loss = self.rcnn_reg_loss_sum/self.count
rcnn_mask_loss = self.rcnn_mask_loss_sum/self.count
total_loss = rpn_loss + rcnn_loss
with open("./loss.log", "a+") as loss_file:
loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rcnn_loss: %.5f, rpn_cls_loss: %.5f, "
"rpn_reg_loss: %.5f, rcnn_cls_loss: %.5f, rcnn_reg_loss: %.5f, rcnn_mask_loss: %.5f, "
"total_loss: %.5f" %
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss,
rcnn_cls_loss, rcnn_reg_loss, rcnn_mask_loss, total_loss))
loss_file.write("\n")
self.count = 0
self.rpn_loss_sum = 0
self.rcnn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rcnn_cls_loss_sum = 0
self.rcnn_reg_loss_sum = 0
self.rcnn_mask_loss_sum = 0
class LossNet(nn.Cell):
"""MaskRcnn loss method"""
def __init__(self):
super(LossNet, self).__init__()
def construct(self, x1, x2, x3, x4, x5, x6, x7):
return x1 + x2
class WithLossCell(nn.Cell):
"""
Wrap the network with loss function to compute loss.
Args:
backbone (Cell): The target network to wrap.
loss_fn (Cell): The loss function used to compute loss.
"""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask):
loss1, loss2, loss3, loss4, loss5, loss6, loss7 = self._backbone(x, img_shape, gt_bboxe, gt_label,
gt_num, gt_mask)
return self._loss_fn(loss1, loss2, loss3, loss4, loss5, loss6, loss7)
@property
def backbone_network(self):
"""
Get the backbone network.
Returns:
Cell, return backbone network.
"""
return self._backbone
class TrainOneStepCell(nn.Cell):
"""
Network training wrapper class.
Appends an optimizer to the training network so that calling the construct function
creates the backward graph.
Args:
network (Cell): The training network.
network_backbone (Cell): The forward network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The scaling value fed to the backward pass. Default value is 1.0.
reduce_flag (bool): Whether to reduce gradients across devices. Default value is False.
mean (bool): Whether the reduction averages the gradients. Default value is True.
degree (int): Device number. Default value is None.
"""
def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.backbone = network_backbone
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation('grad',
get_by_list=True,
sens_param=True)
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16))
self.reduce_flag = reduce_flag
self.hyper_map = C.HyperMap()
if reduce_flag:
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask):
weights = self.weights
loss1, loss2, loss3, loss4, loss5, loss6, loss7 = self.backbone(x, img_shape, gt_bboxe, gt_label,
gt_num, gt_mask)
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask, self.sens)
if self.reduce_flag:
grads = self.grad_reducer(grads)
return F.depend(loss1, self.optimizer(grads)), loss2, loss3, loss4, loss5, loss6, loss7
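The `sens` value handed to `self.grad` seeds backpropagation, so every gradient comes out multiplied by it. The training script passes the same `config.loss_scale` to the optimizer, which is expected to divide the factor back out before updating the weights. Below is a minimal pure-Python sketch of that round trip. It assumes a toy loss `w ** 2` and an illustrative scale of 1024, and its helper names (`grad_of_square`, `scaled_grad`, `unscale`) are hypothetical, not MindSpore APIs.

```python
# Toy illustration of loss scaling; all names here are hypothetical.

def grad_of_square(w):
    """Analytic gradient of the toy loss L(w) = w ** 2."""
    return 2.0 * w


def scaled_grad(w, sens):
    """Backpropagate with an initial sensitivity `sens` instead of 1.0.

    By the chain rule the result equals sens * dL/dw.
    """
    return sens * grad_of_square(w)


def unscale(grad, loss_scale):
    """Undo the scaling before the weight update, as an optimizer would."""
    return grad / loss_scale


w = 3.0
loss_scale = 1024.0  # assumed value; the source reads it from config.loss_scale
g_scaled = scaled_grad(w, sens=loss_scale)
g_true = unscale(g_scaled, loss_scale)
print(g_scaled, g_true)
```

Scaling keeps small float16 gradients from underflowing during the backward pass, and the update is unchanged as long as the two factors match.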
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""coco eval for maskrcnn"""
import json
import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from pycocotools import mask as maskUtils
import mmcv
from src.config import config
_init_value = np.array(0.0)
summary_init = {
'Precision/mAP': _init_value,
'Precision/mAP@.50IOU': _init_value,
'Precision/mAP@.75IOU': _init_value,
'Precision/mAP (small)': _init_value,
'Precision/mAP (medium)': _init_value,
'Precision/mAP (large)': _init_value,
'Recall/AR@1': _init_value,
'Recall/AR@10': _init_value,
'Recall/AR@100': _init_value,
'Recall/AR@100 (small)': _init_value,
'Recall/AR@100 (medium)': _init_value,
'Recall/AR@100 (large)': _init_value,
}
def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000), single_result=False):
"""coco eval for maskrcnn"""
anns = json.load(open(result_files['bbox']))
if not anns:
return summary_init
if mmcv.is_str(coco):
coco = COCO(coco)
assert isinstance(coco, COCO)
for res_type in result_types:
result_file = result_files[res_type]
assert result_file.endswith('.json')
coco_dets = coco.loadRes(result_file)
gt_img_ids = coco.getImgIds()
det_img_ids = coco_dets.getImgIds()
iou_type = 'bbox' if res_type == 'proposal' else res_type
cocoEval = COCOeval(coco, coco_dets, iou_type)
if res_type == 'proposal':
cocoEval.params.useCats = 0
cocoEval.params.maxDets = list(max_dets)
tgt_ids = gt_img_ids if not single_result else det_img_ids
if single_result:
res_dict = dict()
for id_i in tgt_ids:
cocoEval = COCOeval(coco, coco_dets, iou_type)
if res_type == 'proposal':
cocoEval.params.useCats = 0
cocoEval.params.maxDets = list(max_dets)
cocoEval.params.imgIds = [id_i]
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
res_dict.update({coco.imgs[id_i]['file_name']: cocoEval.stats[1]})
cocoEval = COCOeval(coco, coco_dets, iou_type)
if res_type == 'proposal':
cocoEval.params.useCats = 0
cocoEval.params.maxDets = list(max_dets)
cocoEval.params.imgIds = tgt_ids
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
summary_metrics = {
'Precision/mAP': cocoEval.stats[0],
'Precision/mAP@.50IOU': cocoEval.stats[1],
'Precision/mAP@.75IOU': cocoEval.stats[2],
'Precision/mAP (small)': cocoEval.stats[3],
'Precision/mAP (medium)': cocoEval.stats[4],
'Precision/mAP (large)': cocoEval.stats[5],
'Recall/AR@1': cocoEval.stats[6],
'Recall/AR@10': cocoEval.stats[7],
'Recall/AR@100': cocoEval.stats[8],
'Recall/AR@100 (small)': cocoEval.stats[9],
'Recall/AR@100 (medium)': cocoEval.stats[10],
'Recall/AR@100 (large)': cocoEval.stats[11],
}
return summary_metrics
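def xyxy2xywh(bbox):
"""Convert a [x1, y1, x2, y2, ...] box to COCO [x, y, w, h], treating both corners as inclusive pixels."""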
_bbox = bbox.tolist()
return [
_bbox[0],
_bbox[1],
_bbox[2] - _bbox[0] + 1,
_bbox[3] - _bbox[1] + 1,
]
def bbox2result_1image(bboxes, labels, num_classes):
"""Convert detection results to a list of numpy arrays.
Args:
bboxes (Tensor): shape (n, 5)
labels (Tensor): shape (n, )
num_classes (int): class number, including background class
Returns:
list(ndarray): bbox results of each class
"""
if bboxes.shape[0] == 0:
result = [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes - 1)]
else:
result = [bboxes[labels == i, :] for i in range(num_classes - 1)]
return result
def proposal2json(dataset, results):
"""convert proposal to json mode"""
img_ids = dataset.getImgIds()
json_results = []
dataset_len = len(img_ids)
for idx in range(dataset_len):
img_id = img_ids[idx]
bboxes = results[idx]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = 1
json_results.append(data)
return json_results
def det2json(dataset, results):
"""convert det to json mode"""
cat_ids = dataset.getCatIds()
img_ids = dataset.getImgIds()
json_results = []
dataset_len = len(img_ids)
for idx in range(dataset_len):
img_id = img_ids[idx]
if idx == len(results): break
result = results[idx]
for label, result_label in enumerate(result):
bboxes = result_label
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = cat_ids[label]
json_results.append(data)
return json_results
def segm2json(dataset, results):
"""convert segm to json mode"""
cat_ids = dataset.getCatIds()
img_ids = dataset.getImgIds()
bbox_json_results = []
segm_json_results = []
dataset_len = len(img_ids)
assert dataset_len == len(results)
for idx in range(dataset_len):
img_id = img_ids[idx]
if idx == len(results): break
det, seg = results[idx]
for label, det_label in enumerate(det):
bboxes = det_label
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = cat_ids[label]
bbox_json_results.append(data)
if len(seg) == 2:
segms = seg[0][label]
mask_score = seg[1][label]
else:
segms = seg[label]
mask_score = [bbox[4] for bbox in bboxes]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['score'] = float(mask_score[i])
data['category_id'] = cat_ids[label]
segms[i]['counts'] = segms[i]['counts'].decode()
data['segmentation'] = segms[i]
segm_json_results.append(data)
return bbox_json_results, segm_json_results
def results2json(dataset, results, out_file):
"""convert results to json files and return their paths"""
result_files = dict()
if isinstance(results[0], list):
json_results = det2json(dataset, results)
result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox')
result_files['proposal'] = '{}.{}.json'.format(out_file, 'bbox')
mmcv.dump(json_results, result_files['bbox'])
elif isinstance(results[0], tuple):
json_results = segm2json(dataset, results)
result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox')
result_files['segm'] = '{}.{}.json'.format(out_file, 'segm')
mmcv.dump(json_results[0], result_files['bbox'])
mmcv.dump(json_results[1], result_files['segm'])
elif isinstance(results[0], np.ndarray):
json_results = proposal2json(dataset, results)
result_files['proposal'] = '{}.{}.json'.format(out_file, 'proposal')
mmcv.dump(json_results, result_files['proposal'])
else:
raise TypeError('invalid type of results')
return result_files
def get_seg_masks(mask_pred, det_bboxes, det_labels, img_meta, rescale, num_classes):
"""Get segmentation masks from mask_pred and bboxes"""
mask_pred = mask_pred.astype(np.float32)
cls_segms = [[] for _ in range(num_classes - 1)]
bboxes = det_bboxes[:, :4]
labels = det_labels + 1
ori_shape = img_meta[:2].astype(np.int32)
scale_factor = img_meta[2:].astype(np.int32)
if rescale:
img_h, img_w = ori_shape[:2]
else:
img_h = np.round(ori_shape[0] * scale_factor[0]).astype(np.int32)
img_w = np.round(ori_shape[1] * scale_factor[1]).astype(np.int32)
scale_factor = 1.0
for i in range(bboxes.shape[0]):
bbox = (bboxes[i, :] / 1.0).astype(np.int32)
label = labels[i]
w = max(bbox[2] - bbox[0] + 1, 1)
h = max(bbox[3] - bbox[1] + 1, 1)
w = min(w, img_w - bbox[0])
h = min(h, img_h - bbox[1])
mask_pred_ = mask_pred[i, :, :]
im_mask = np.zeros((img_h, img_w), dtype=np.uint8)
bbox_mask = mmcv.imresize(mask_pred_, (w, h))
bbox_mask = (bbox_mask > config.mask_thr_binary).astype(np.uint8)
im_mask[bbox[1]:bbox[1] + h, bbox[0]:bbox[0] + w] = bbox_mask
rle = maskUtils.encode(
np.array(im_mask[:, :, np.newaxis], order='F'))[0]
cls_segms[label - 1].append(rle)
return cls_segms
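The box conversion used by the `*2json` helpers above can be checked by hand. The following is a minimal sketch in pure Python, using plain lists instead of NumPy arrays. The function names `xyxy_to_xywh` and `detection_record` and the sample values are assumptions for illustration only. The `+ 1` follows the source's convention of counting both corner pixels as part of the box.

```python
def xyxy_to_xywh(box):
    """Convert [x1, y1, x2, y2, ...] to [x, y, w, h], counting both corner pixels."""
    x1, y1, x2, y2 = box[:4]
    return [x1, y1, x2 - x1 + 1, y2 - y1 + 1]


def detection_record(image_id, category_id, box_with_score):
    """Build one COCO-style result dict from [x1, y1, x2, y2, score]."""
    return {
        "image_id": image_id,
        "bbox": xyxy_to_xywh(box_with_score),
        "score": float(box_with_score[4]),
        "category_id": category_id,
    }


# Hypothetical detection: image 42, category 1, score 0.9.
record = detection_record(image_id=42, category_id=1,
                          box_with_score=[10.0, 20.0, 50.0, 80.0, 0.9])
print(record)
```

A box spanning x from 10 to 50 is therefore reported as 41 pixels wide, not 40.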
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train MaskRcnn and get checkpoint files."""
import os
import argparse
import random
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import SGD
import mindspore.dataset.engine as de
from src.MaskRcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
from src.config import config
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
from src.lr_schedule import dynamic_lr
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
def str2bool(value):
    """Parse "true"/"false" style flags; type=bool would turn any non-empty string, even "False", into True."""
    return str(value).lower() in ("true", "1", "yes")
parser = argparse.ArgumentParser(description="MaskRcnn training")
parser.add_argument("--only_create_dataset", type=str2bool, default=False, help="If set to true, only create "
"Mindrecord, default is false.")
parser.add_argument("--run_distribute", type=str2bool, default=False, help="Run distributed, default is false.")
parser.add_argument("--do_train", type=str2bool, default=True, help="Do train or not, default is true.")
parser.add_argument("--do_eval", type=str2bool, default=False, help="Do eval or not, default is false.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
parser.add_argument("--pre_trained", type=str, default="", help="Pretrain file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default is 0.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id)
if __name__ == '__main__':
print("Start train for maskrcnn!")
if not args_opt.do_eval and args_opt.run_distribute:
rank = args_opt.rank_id
device_num = args_opt.device_num
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True, parameter_broadcast=True)
init()
else:
rank = 0
device_num = 1
print("Start create dataset!")
# Mindrecord files are generated in config.mindrecord_dir,
# with file names MaskRcnn.mindrecord0, 1, ... file_num.
prefix = "MaskRcnn.mindrecord"
mindrecord_dir = config.mindrecord_dir
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
if not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if args_opt.dataset == "coco":
if os.path.isdir(config.coco_root):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("coco", True, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("coco_root does not exist.")
else:
if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("other", True, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("IMAGE_DIR or ANNO_PATH does not exist.")
if not args_opt.only_create_dataset:
loss_scale = float(config.loss_scale)
# When creating the MindDataset, use the first mindrecord file, such as MaskRcnn.mindrecord0.
dataset = create_maskrcnn_dataset(mindrecord_file, batch_size=config.batch_size,
device_num=device_num, rank_id=rank)
dataset_size = dataset.get_dataset_size()
print("total images num: ", dataset_size)
print("Create dataset done!")
net = Mask_Rcnn_Resnet50(config=config)
net = net.set_train()
load_path = args_opt.pre_trained
if load_path != "":
param_dict = load_checkpoint(load_path)
for item in list(param_dict.keys()):
if not (item.startswith('backbone') or item.startswith('rcnn_mask')):
param_dict.pop(item)
load_param_into_net(net, param_dict)
loss = LossNet()
lr = Tensor(dynamic_lr(config, rank_size=device_num), mstype.float32)
opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
net_with_loss = WithLossCell(net, loss)
if args_opt.run_distribute:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True,
mean=True, degree=device_num)
else:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale)
time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossCallBack()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix='mask_rcnn', directory=config.save_checkpoint_path, config=ckptconfig)
cb += [ckpoint_cb]
model = Model(net)
model.train(config.epoch_size, dataset, callbacks=cb)
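When a pre-trained checkpoint is given, the training script above keeps only the parameters whose names start with `backbone` or `rcnn_mask` before loading them into the network. The sketch below shows that filtering step on a plain dictionary. The helper name `filter_pretrained` and the parameter names in `checkpoint` are made up for illustration, and the sketch returns a new dictionary instead of popping keys in place as the source does.

```python
def filter_pretrained(param_dict, prefixes=("backbone", "rcnn_mask")):
    """Return a copy of param_dict keeping only keys that start with one of prefixes."""
    return {name: value for name, value in param_dict.items()
            if name.startswith(prefixes)}


# Hypothetical checkpoint contents; real values would be parameter tensors.
checkpoint = {
    "backbone.conv1.weight": 1,
    "rcnn_mask.fc.weight": 2,
    "rpn_head.cls.weight": 3,
    "global_step": 4,
}
kept = filter_pretrained(checkpoint)
print(sorted(kept))
```

Dropping the remaining entries means the other heads and any optimizer state start from fresh initialization instead of being restored from the checkpoint.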