Commit e2094c97 authored by mindspore-ci-bot, committed by Gitee

!891 Add SSD network in model zoo

Merge pull request !891 from zhaoting/SSD
# SSD Example
## Description
SSD network based on MobileNetV2, with support for training and evaluation.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Dataset
This example uses coco2017 as the training dataset by default; you can also use your own dataset.
1. If the coco dataset is used. **Set dataset to coco when running the script.**
Download coco2017: [train2017](http://images.cocodataset.org/zips/train2017.zip), [val2017](http://images.cocodataset.org/zips/val2017.zip), [test2017](http://images.cocodataset.org/zips/test2017.zip), [annotations](http://images.cocodataset.org/annotations/annotations_trainval2017.zip). Install pycocotools:
```
pip install Cython
pip install pycocotools
```
Then change `COCO_ROOT` and any other settings you need in `config.py`. The directory structure is as follows:
```
└─coco2017
    ├── annotations  # annotation jsons
    ├── train2017    # train dataset
    └── val2017      # infer dataset
```
2. If your own dataset is used. **Set dataset to other when running the script.**
Organize the dataset information into a TXT file; each row is as follows:
```
train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2
```
Each row is a space-separated image annotation: the first column is the relative path of the image, and the remaining columns are box and class information in the format [xmin,ymin,xmax,ymax,class]. Images are read from the path obtained by joining `IMAGE_DIR` (the dataset directory) with the relative path stored in `ANNO_PATH` (the TXT file path); `IMAGE_DIR` and `ANNO_PATH` are set in `config.py`. A minimal parsing sketch follows.
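The sketch below is a hypothetical helper (the shipped `anno_parser` in `dataset.py` does the same split for the box fields), showing how one row breaks down:
```
# Hypothetical sketch: parse one row of the annotation TXT file.
def parse_row(row):
    fields = row.strip().split(' ')
    rel_path = fields[0]  # relative image path, joined with IMAGE_DIR at load time
    boxes = [list(map(int, f.split(','))) for f in fields[1:]]  # [xmin, ymin, xmax, ymax, class]
    return rel_path, boxes

rel_path, boxes = parse_row("train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2")
# rel_path == "train2017/0000001.jpg"
# boxes == [[0, 259, 401, 459, 7], [35, 28, 324, 201, 2]]
```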
## Running the example
### Training
To train the model, run `train.py`. If `MINDRECORD_DIR` is empty, the script generates [mindrecord](https://www.mindspore.cn/tutorial/en/master/use/data_preparation/converting_datasets.html) files from `COCO_ROOT` (coco dataset) or from `IMAGE_DIR` and `ANNO_PATH` (own dataset); you can also pre-generate them as shown below. **Note that if MINDRECORD_DIR isn't empty, the script uses MINDRECORD_DIR instead of the raw images.**
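If you want to generate the mindrecord files ahead of time, you can run `python train.py --only_create_dataset=1` (as `run_distribute_train.sh` does), or call the converter from `dataset.py` directly. A minimal sketch, assuming it runs from the example directory with `config.py` already configured:
```
# Sketch: pre-generate mindrecord files using the converter defined in dataset.py.
from dataset import data_to_mindrecord_byte_image

# "coco" reads COCO_ROOT; "other" reads IMAGE_DIR and ANNO_PATH (all set in config.py).
data_to_mindrecord_byte_image("coco", is_training=True, prefix="ssd.mindrecord")
```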
- Standalone mode
```
python train.py --dataset coco
```
You can run ```python train.py -h``` to get more information.
- Distributed mode
```
sh run_distribute_train.sh 8 150 coco /data/hccl.json
```
The input parameters are the number of devices, epoch size, dataset mode and the [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use an absolute path.**
You will get the loss value of each step as follows:
```
epoch: 1 step: 455, loss is 5.8653416
epoch: 2 step: 455, loss is 5.4292373
epoch: 3 step: 455, loss is 5.458992
...
epoch: 148 step: 455, loss is 1.8340507
epoch: 149 step: 455, loss is 2.0876894
epoch: 150 step: 455, loss is 2.239692
```
### Evaluation
For evaluation, run `eval.py` with `--checkpoint_path`, the path of the [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file.
```
python eval.py --checkpoint_path ssd.ckpt --dataset coco
```
You can run ```python eval.py -h``` to get more information.
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config parameters for SSD models."""
class ConfigSSD:
"""
Config parameters for SSD.
Examples:
ConfigSSD().
"""
IMG_SHAPE = [300, 300]
NUM_SSD_BOXES = 1917
NEG_PRE_POSITIVE = 3
MATCH_THRESHOLD = 0.5
NUM_DEFAULT = [3, 6, 6, 6, 6, 6]
EXTRAS_IN_CHANNELS = [256, 576, 1280, 512, 256, 256]
EXTRAS_OUT_CHANNELS = [576, 1280, 512, 256, 256, 128]
EXTRAS_STRIDES = [1, 1, 2, 2, 2, 2]
EXTRAS_RATIO = [0.2, 0.2, 0.2, 0.25, 0.5, 0.25]
FEATURE_SIZE = [19, 10, 5, 3, 2, 1]
SCALES = [21, 45, 99, 153, 207, 261, 315]
ASPECT_RATIOS = [(1,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)]
STEPS = (16, 32, 64, 100, 150, 300)
PRIOR_SCALING = (0.1, 0.2)
# `MINDRECORD_DIR` and `COCO_ROOT` are better to use absolute path.
MINDRECORD_DIR = "MindRecord_COCO"
COCO_ROOT = "coco2017"
TRAIN_DATA_TYPE = "train2017"
VAL_DATA_TYPE = "val2017"
INSTANCES_SET = "annotations/instances_{}.json"
COCO_CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush')
NUM_CLASSES = len(COCO_CLASSES)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SSD dataset"""
from __future__ import division
import os
import math
import itertools as it
import numpy as np
import cv2
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as C
from mindspore.mindrecord import FileWriter
from config import ConfigSSD
config = ConfigSSD()
class GeneratDefaultBoxes():
"""
Generate default boxes for SSD, following the order of (W, H, anchor_sizes).
`self.default_boxes` has a shape of [anchor_sizes, H, W, 4], the last dimension is [x, y, w, h].
`self.default_boxes_ltrb` has the same shape as `self.default_boxes`, the last dimension is [x1, y1, x2, y2].
"""
def __init__(self):
fk = config.IMG_SHAPE[0] / np.array(config.STEPS)
self.default_boxes = []
for idex, feature_size in enumerate(config.FEATURE_SIZE):
sk1 = config.SCALES[idex] / config.IMG_SHAPE[0]
sk2 = config.SCALES[idex + 1] / config.IMG_SHAPE[0]
sk3 = math.sqrt(sk1 * sk2)
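# sk3 is the geometric mean of two adjacent scales, used for the extra 1:1 default box.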
if config.NUM_DEFAULT[idex] == 3:
all_sizes = [(0.5, 1.0), (1.0, 1.0), (1.0, 0.5)]
else:
all_sizes = [(sk1, sk1), (sk3, sk3)]
for aspect_ratio in config.ASPECT_RATIOS[idex]:
w, h = sk1 * math.sqrt(aspect_ratio), sk1 / math.sqrt(aspect_ratio)
all_sizes.append((w, h))
all_sizes.append((h, w))
assert len(all_sizes) == config.NUM_DEFAULT[idex]
for i, j in it.product(range(feature_size), repeat=2):
for w, h in all_sizes:
cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex]
box = [np.clip(k, 0, 1) for k in (cx, cy, w, h)]
self.default_boxes.append(box)
def to_ltrb(cx, cy, w, h):
return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
# For IoU calculation
self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32')
self.default_boxes = np.array(self.default_boxes, dtype='float32')
default_boxes_generator = GeneratDefaultBoxes()
default_boxes_ltrb = default_boxes_generator.default_boxes_ltrb
default_boxes = default_boxes_generator.default_boxes
x1, y1, x2, y2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1)
vol_anchors = (x2 - x1) * (y2 - y1)
matching_threshold = config.MATCH_THRESHOLD
def ssd_bboxes_encode(boxes):
"""
Labels anchors with ground truth inputs.
Args:
boxes: ground truth with shape [N, 5], for each row, it stores [xmin, ymin, xmax, ymax, cls].
Returns:
gt_loc: location ground truth with shape [num_anchors, 4].
gt_label: class ground truth with shape [num_anchors, 1].
num_matched_boxes: number of positives in an image.
"""
def jaccard_with_anchors(bbox):
"""Compute jaccard score a box and the anchors."""
# Intersection bbox and volume.
xmin = np.maximum(x1, bbox[0])
ymin = np.maximum(y1, bbox[1])
xmax = np.minimum(x2, bbox[2])
ymax = np.minimum(y2, bbox[3])
w = np.maximum(xmax - xmin, 0.)
h = np.maximum(ymax - ymin, 0.)
# Volumes.
inter_vol = h * w
union_vol = vol_anchors + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - inter_vol
jaccard = inter_vol / union_vol
return np.squeeze(jaccard)
pre_scores = np.zeros((config.NUM_SSD_BOXES), dtype=np.float32)
t_boxes = np.zeros((config.NUM_SSD_BOXES, 4), dtype=np.float32)
t_label = np.zeros((config.NUM_SSD_BOXES), dtype=np.int64)
for bbox in boxes:
label = int(bbox[4])
scores = jaccard_with_anchors(bbox)
mask = (scores > matching_threshold)
if not np.any(mask):
mask[np.argmax(scores)] = True
mask = mask & (scores > pre_scores)
pre_scores = np.maximum(pre_scores, scores)
t_label = mask * label + (1 - mask) * t_label
for i in range(4):
t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i]
index = np.nonzero(t_label)
# Transform from corner format (ltrb) to center format (cx, cy, w, h).
bboxes = np.zeros((config.NUM_SSD_BOXES, 4), dtype=np.float32)
bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2
bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]]
# Encode features.
bboxes_t = bboxes[index]
default_boxes_t = default_boxes[index]
bboxes_t[:, :2] = (bboxes_t[:, :2] - default_boxes_t[:, :2]) / (default_boxes_t[:, 2:] * config.PRIOR_SCALING[0])
bboxes_t[:, 2:4] = np.log(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4]) / config.PRIOR_SCALING[1]
bboxes[index] = bboxes_t
num_match_num = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)
return bboxes, t_label.astype(np.int32), num_match_num
def ssd_bboxes_decode(boxes, index, image_shape):
"""Decode predict boxes to [x, y, w, h]"""
boxes_t = boxes[index]
default_boxes_t = default_boxes[index]
boxes_t[:, :2] = boxes_t[:, :2] * config.PRIOR_SCALING[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2]
boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.PRIOR_SCALING[1]) * default_boxes_t[:, 2:4]
bboxes = np.zeros((len(boxes_t), 4), dtype=np.float32)
bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2
bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2
return bboxes
def preprocess_fn(image, box, is_training):
"""Preprocess function for dataset."""
def _rand(a=0., b=1.):
"""Generate random."""
return np.random.rand() * (b - a) + a
def _infer_data(image, input_shape, box):
img_h, img_w, _ = image.shape
input_h, input_w = input_shape
scale = min(float(input_w) / float(img_w), float(input_h) / float(img_h))
nw = int(img_w * scale)
nh = int(img_h * scale)
image = cv2.resize(image, (nw, nh))
new_image = np.zeros((input_h, input_w, 3), np.float32)
dh = (input_h - nh) // 2
dw = (input_w - nw) // 2
new_image[dh: (nh + dh), dw: (nw + dw), :] = image
image = new_image
# When the image has only one channel
if len(image.shape) == 2:
image = np.expand_dims(image, axis=-1)
image = np.concatenate([image, image, image], axis=-1)
box = box.astype(np.float32)
box[:, [0, 2]] = (box[:, [0, 2]] * scale + dw) / input_w
box[:, [1, 3]] = (box[:, [1, 3]] * scale + dh) / input_h
return image, np.array((img_h, img_w), np.float32), box
def _data_aug(image, box, is_training, image_size=(300, 300)):
"""Data augmentation function."""
ih, iw, _ = image.shape
w, h = image_size
if not is_training:
return _infer_data(image, image_size, box)
# Random settings
scale_w = _rand(0.75, 1.25)
scale_h = _rand(0.75, 1.25)
flip = _rand() < .5
nw = iw * scale_w
nh = ih * scale_h
scale = min(w / nw, h / nh)
nw = int(scale * nw)
nh = int(scale * nh)
# Resize image
image = cv2.resize(image, (nw, nh))
# place image
new_image = np.zeros((h, w, 3), dtype=np.float32)
dw = (w - nw) // 2
dh = (h - nh) // 2
new_image[dh:dh + nh, dw:dw + nw, :] = image
image = new_image
# Flip image or not
if flip:
image = cv2.flip(image, 1, dst=None)
# Convert image to gray or not
gray = _rand() < .25
if gray:
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# When the image has only one channel
if len(image.shape) == 2:
image = np.expand_dims(image, axis=-1)
image = np.concatenate([image, image, image], axis=-1)
box = box.astype(np.float32)
# Transform box with shape[x1, y1, x2, y2].
box[:, [0, 2]] = (box[:, [0, 2]] * scale * scale_w + dw) / w
box[:, [1, 3]] = (box[:, [1, 3]] * scale * scale_h + dh) / h
if flip:
box[:, [0, 2]] = 1 - box[:, [2, 0]]
box, label, num_match_num = ssd_bboxes_encode(box)
return image, box, label, num_match_num
return _data_aug(image, box, is_training, image_size=config.IMG_SHAPE)
def create_coco_label(is_training):
"""Get image path and annotation from COCO."""
from pycocotools.coco import COCO
coco_root = config.COCO_ROOT
data_type = config.VAL_DATA_TYPE
if is_training:
data_type = config.TRAIN_DATA_TYPE
#Classes need to train or test.
train_cls = config.COCO_CLASSES
train_cls_dict = {}
for i, cls in enumerate(train_cls):
train_cls_dict[cls] = i
anno_json = os.path.join(coco_root, config.INSTANCES_SET.format(data_type))
coco = COCO(anno_json)
classs_dict = {}
cat_ids = coco.loadCats(coco.getCatIds())
for cat in cat_ids:
classs_dict[cat["id"]] = cat["name"]
image_ids = coco.getImgIds()
image_files = []
image_anno_dict = {}
for img_id in image_ids:
image_info = coco.loadImgs(img_id)
file_name = image_info[0]["file_name"]
anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = coco.loadAnns(anno_ids)
image_path = os.path.join(coco_root, data_type, file_name)
annos = []
for label in anno:
bbox = label["bbox"]
class_name = classs_dict[label["category_id"]]
if class_name in train_cls:
x_min, x_max = bbox[0], bbox[0] + bbox[2]
y_min, y_max = bbox[1], bbox[1] + bbox[3]
annos.append(list(map(round, [x_min, y_min, x_max, y_max])) + [train_cls_dict[class_name]])
if len(annos) >= 1:
image_files.append(image_path)
image_anno_dict[image_path] = np.array(annos)
return image_files, image_anno_dict
def anno_parser(annos_str):
"""Parse annotation from string to list."""
annos = []
for anno_str in annos_str:
anno = list(map(int, anno_str.strip().split(',')))
annos.append(anno)
return annos
def filter_valid_data(image_dir, anno_path):
"""Filter valid image file, which both in image_dir and anno_path."""
image_files = []
image_anno_dict = {}
if not os.path.isdir(image_dir):
raise RuntimeError("Path given is not valid.")
if not os.path.isfile(anno_path):
raise RuntimeError("Annotation file is not valid.")
with open(anno_path, "rb") as f:
lines = f.readlines()
for line in lines:
line_str = line.decode("utf-8").strip()
line_split = str(line_str).split(' ')
file_name = line_split[0]
image_path = os.path.join(image_dir, file_name)
if os.path.isfile(image_path):
image_anno_dict[image_path] = anno_parser(line_split[1:])
image_files.append(image_path)
return image_files, image_anno_dict
def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.mindrecord", file_num=8):
"""Create MindRecord file."""
mindrecord_dir = config.MINDRECORD_DIR
mindrecord_path = os.path.join(mindrecord_dir, prefix)
writer = FileWriter(mindrecord_path, file_num)
if dataset == "coco":
image_files, image_anno_dict = create_coco_label(is_training)
else:
image_files, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH)
ssd_json = {
"image": {"type": "bytes"},
"annotation": {"type": "int32", "shape": [-1, 5]},
}
writer.add_schema(ssd_json, "ssd_json")
for image_name in image_files:
with open(image_name, 'rb') as f:
img = f.read()
annos = np.array(image_anno_dict[image_name], dtype=np.int32)
row = {"image": img, "annotation": annos}
writer.write_raw_data([row])
writer.commit()
def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0,
is_training=True, num_parallel_workers=4):
"""Creatr SSD dataset with MindDataset."""
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
num_parallel_workers=num_parallel_workers, shuffle=is_training)
decode = C.Decode()
ds = ds.map(input_columns=["image"], operations=decode)
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
if is_training:
hwc_to_chw = C.HWC2CHW()
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "box", "label", "num_match_num"],
columns_order=["image", "box", "label", "num_match_num"],
operations=compose_map_func, python_multiprocessing=True, num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, python_multiprocessing=True,
num_parallel_workers=num_parallel_workers)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
else:
hwc_to_chw = C.HWC2CHW()
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "annotation"],
columns_order=["image", "image_shape", "annotation"],
operations=compose_map_func)
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
return ds
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation for SSD"""
import os
import argparse
import time
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.model_zoo.ssd import SSD300, ssd_mobilenet_v2
from dataset import create_ssd_dataset, data_to_mindrecord_byte_image
from config import ConfigSSD
from util import metrics
def ssd_eval(dataset_path, ckpt_path):
"""SSD evaluation."""
ds = create_ssd_dataset(dataset_path, batch_size=1, repeat_num=1, is_training=False)
net = SSD300(ssd_mobilenet_v2(), ConfigSSD(), is_training=False)
print("Load Checkpoint!")
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
net.set_train(False)
i = 1.
total = ds.get_dataset_size()
start = time.time()
pred_data = []
print("\n========================================\n")
print("total images num: ", total)
print("Processing, please wait a moment.")
for data in ds.create_dict_iterator():
img_np = data['image']
image_shape = data['image_shape']
annotation = data['annotation']
output = net(Tensor(img_np))
for batch_idx in range(img_np.shape[0]):
pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
"box_scores": output[1].asnumpy()[batch_idx],
"annotation": annotation,
"image_shape": image_shape})
percent = round(i / total * 100, 2)
print(f' {str(percent)} [{i}/{total}]', end='\r')
i += 1
cost_time = int((time.time() - start) * 1000)
print(f' 100% [{total}/{total}] cost {cost_time} ms')
mAP = metrics(pred_data)
print("\n========================================\n")
print(f"mAP: {mAP}")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='SSD evaluation')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True)
config = ConfigSSD()
prefix = "ssd_eval.mindrecord"
mindrecord_dir = config.MINDRECORD_DIR
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
if not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if args_opt.dataset == "coco":
if os.path.isdir(config.COCO_ROOT):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("coco", False, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("COCO_ROOT not exits.")
else:
if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("other", False, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("IMAGE_DIR or ANNO_PATH not exits.")
print("Start Eval!")
ssd_eval(mindrecord_file, args_opt.checkpoint_path)
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: sh run_distribute_train.sh 8 150 coco /data/hccl.json"
echo "It is better to use absolute path."
echo "The learning rate is 0.4 as default, if you want other lr, please change the value in this script."
echo "=============================================================================================================="
# Before starting distributed training, first create the mindrecord files.
python train.py --only_create_dataset=1
echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt"
export RANK_SIZE=$1
EPOCH_SIZE=$2
DATASET=$3
export MINDSPORE_HCCL_CONFIG_PATH=$4
for((i=0;i<RANK_SIZE;i++))
do
export DEVICE_ID=$i
rm -rf LOG$i
mkdir ./LOG$i
cp *.py ./LOG$i
cd ./LOG$i || exit
export RANK_ID=$i
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
python ../train.py \
--distribute=1 \
--lr=0.4 \
--dataset=$DATASET \
--device_num=$RANK_SIZE \
--device_id=$DEVICE_ID \
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
cd ../
done
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train SSD and get checkpoint files."""
import os
import math
import argparse
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from mindspore.train import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common.initializer import initializer
from mindspore.model_zoo.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2
from config import ConfigSSD
from dataset import create_ssd_dataset, data_to_mindrecord_byte_image
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""
Generate learning rate array.
Args:
global_step(int): total steps of the training
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
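# Linear warmup for warmup_steps, then cosine decay from lr_max down to lr_end.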
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_end + (lr_max - lr_end) * \
(1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate
def init_net_param(network, initialize_mode='XavierUniform'):
"""Init the parameters in net."""
params = network.trainable_params()
for p in params:
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
p.set_parameter_data(initializer(initialize_mode, p.data.shape(), p.data.dtype()))
def main():
parser = argparse.ArgumentParser(description="SSD training")
parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create "
"Mindrecord, default is false.")
parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--lr", type=float, default=0.25, help="Learning rate, default is 0.25.")
parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.")
parser.add_argument("--epoch_size", type=int, default=70, help="Epoch size, default is 70.")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path.")
parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True)
if args_opt.distribute:
device_num = args_opt.device_num
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
device_num=device_num)
init()
rank = args_opt.device_id % device_num
else:
rank = 0
device_num = 1
print("Start create dataset!")
# It will generate mindrecord file in args_opt.mindrecord_dir,
# and the file name is ssd.mindrecord0, 1, ... file_num.
config = ConfigSSD()
prefix = "ssd.mindrecord"
mindrecord_dir = config.MINDRECORD_DIR
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
if not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if args_opt.dataset == "coco":
if os.path.isdir(config.COCO_ROOT):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("coco", True, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("COCO_ROOT not exits.")
else:
if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("other", True, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("IMAGE_DIR or ANNO_PATH not exits.")
if not args_opt.only_create_dataset:
loss_scale = float(args_opt.loss_scale)
# When creating MindDataset, use the first mindrecord file, such as ssd.mindrecord0.
dataset = create_ssd_dataset(mindrecord_file, repeat_num=args_opt.epoch_size,
batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
dataset_size = dataset.get_dataset_size()
print("Create dataset done!")
ssd = SSD300(backbone=ssd_mobilenet_v2(), config=config)
net = SSDWithLossCell(ssd, config)
init_net_param(net)
# checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=None, config=ckpt_config)
lr = Tensor(get_lr(global_step=0, lr_init=0, lr_end=0, lr_max=args_opt.lr,
warmup_epochs=max(args_opt.epoch_size // 20, 1),
total_epochs=args_opt.epoch_size,
steps_per_epoch=dataset_size))
opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 0.0001, loss_scale)
net = TrainingWrapper(net, opt, loss_scale)
if args_opt.checkpoint_path != "":
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
model = Model(net)
dataset_sink_mode = False
if args_opt.mode == "sink":
print("In sink mode, one epoch return a loss.")
dataset_sink_mode = True
print("Start train SSD, the first epoch will be slower because of the graph compilation.")
model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)
if __name__ == '__main__':
main()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics utils"""
import numpy as np
from config import ConfigSSD
from dataset import ssd_bboxes_decode
def calc_iou(bbox_pred, bbox_ground):
"""Calculate iou of predicted bbox and ground truth."""
bbox_pred = np.expand_dims(bbox_pred, axis=0)
pred_w = bbox_pred[:, 2] - bbox_pred[:, 0]
pred_h = bbox_pred[:, 3] - bbox_pred[:, 1]
pred_area = pred_w * pred_h
gt_w = bbox_ground[:, 2] - bbox_ground[:, 0]
gt_h = bbox_ground[:, 3] - bbox_ground[:, 1]
gt_area = gt_w * gt_h
iw = np.minimum(bbox_pred[:, 2], bbox_ground[:, 2]) - np.maximum(bbox_pred[:, 0], bbox_ground[:, 0])
ih = np.minimum(bbox_pred[:, 3], bbox_ground[:, 3]) - np.maximum(bbox_pred[:, 1], bbox_ground[:, 1])
iw = np.maximum(iw, 0)
ih = np.maximum(ih, 0)
intersection_area = iw * ih
union_area = pred_area + gt_area - intersection_area
union_area = np.maximum(union_area, np.finfo(float).eps)
iou = intersection_area * 1. / union_area
return iou
def apply_nms(all_boxes, all_scores, thres, max_boxes):
"""Apply NMS to bboxes."""
x1 = all_boxes[:, 0]
y1 = all_boxes[:, 1]
x2 = all_boxes[:, 2]
y2 = all_boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = all_scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
if len(keep) >= max_boxes:
break
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thres)[0]
order = order[inds + 1]
return keep
def calc_ap(recall, precision):
"""Calculate AP."""
correct_recall = np.concatenate(([0.], recall, [1.]))
correct_precision = np.concatenate(([0.], precision, [0.]))
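# Walk backwards so precision becomes monotonically non-increasing (the interpolated precision envelope).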
for i in range(correct_recall.size - 1, 0, -1):
correct_precision[i - 1] = np.maximum(correct_precision[i - 1], correct_precision[i])
i = np.where(correct_recall[1:] != correct_recall[:-1])[0]
ap = np.sum((correct_recall[i + 1] - correct_recall[i]) * correct_precision[i + 1])
return ap
def metrics(pred_data):
"""Calculate mAP of predicted bboxes."""
config = ConfigSSD()
num_classes = config.NUM_CLASSES
all_detections = [None for i in range(num_classes)]
all_pred_scores = [None for i in range(num_classes)]
all_annotations = [None for i in range(num_classes)]
average_precisions = {}
num = [0 for i in range(num_classes)]
accurate_num = [0 for i in range(num_classes)]
for sample in pred_data:
pred_boxes = sample['boxes']
boxes_scores = sample['box_scores']
annotation = sample['annotation']
image_shape = sample['image_shape']
annotation = np.squeeze(annotation, axis=0)
image_shape = np.squeeze(image_shape, axis=0)
pred_labels = np.argmax(boxes_scores, axis=-1)
index = np.nonzero(pred_labels)
pred_boxes = ssd_bboxes_decode(pred_boxes, index, image_shape)
pred_boxes = pred_boxes.clip(0, 1)
boxes_scores = np.max(boxes_scores, axis=-1)
boxes_scores = boxes_scores[index]
pred_labels = pred_labels[index]
top_k = 50
for c in range(1, num_classes):
if len(pred_labels) >= 1:
class_box_scores = boxes_scores[pred_labels == c]
class_boxes = pred_boxes[pred_labels == c]
nms_index = apply_nms(class_boxes, class_box_scores, config.MATCH_THRESHOLD, top_k)
class_boxes = class_boxes[nms_index]
class_box_scores = class_box_scores[nms_index]
cmask = class_box_scores > 0.5
class_boxes = class_boxes[cmask]
class_box_scores = class_box_scores[cmask]
all_detections[c] = class_boxes
all_pred_scores[c] = class_box_scores
for c in range(1, num_classes):
if len(annotation) >= 1:
all_annotations[c] = annotation[annotation[:, 4] == c, :4]
for c in range(1, num_classes):
false_positives = np.zeros((0,))
true_positives = np.zeros((0,))
scores = np.zeros((0,))
num_annotations = 0.0
annotations = all_annotations[c]
num_annotations += annotations.shape[0]
detections = all_detections[c]
pred_scores = all_pred_scores[c]
for index, detection in enumerate(detections):
scores = np.append(scores, pred_scores[index])
if len(annotations) >= 1:
IoUs = calc_iou(detection, annotations)
assigned_anno = np.argmax(IoUs)
max_overlap = IoUs[assigned_anno]
if max_overlap >= 0.5:
false_positives = np.append(false_positives, 0)
true_positives = np.append(true_positives, 1)
else:
false_positives = np.append(false_positives, 1)
true_positives = np.append(true_positives, 0)
else:
false_positives = np.append(false_positives, 1)
true_positives = np.append(true_positives, 0)
if num_annotations == 0:
if c not in average_precisions.keys():
average_precisions[c] = 0
continue
accurate_num[c] = 1
indices = np.argsort(-scores)
false_positives = false_positives[indices]
true_positives = true_positives[indices]
false_positives = np.cumsum(false_positives)
true_positives = np.cumsum(true_positives)
recall = true_positives * 1. / num_annotations
precision = true_positives * 1. / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)
average_precision = calc_ap(recall, precision)
if c not in average_precisions.keys():
average_precisions[c] = average_precision
else:
average_precisions[c] += average_precision
num[c] += 1
count = 0
for key in average_precisions:
if num[key] != 0:
count += (average_precisions[key] / num[key])
mAP = count * 1. / accurate_num.count(1)
return mAP
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SSD net based MobilenetV2."""
import mindspore.common.dtype as mstype
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.initializer import initializer
from .mobilenet import InvertedResidual, ConvBNReLU
def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'):
weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32)
return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
padding=0, pad_mode=pad_mod, weight_init=weight)
def _make_divisible(v, divisor, min_value=None):
"""nsures that all layers have a channel number that is divisible by 8."""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class FlattenConcat(nn.Cell):
"""
Concatenate predictions into a single tensor.
Args:
config (Class): The default config of SSD.
Returns:
Tensor, flatten predictions.
"""
def __init__(self, config):
super(FlattenConcat, self).__init__()
self.sizes = config.FEATURE_SIZE
self.length = len(self.sizes)
self.num_default = config.NUM_DEFAULT
self.concat = P.Concat(axis=-1)
self.transpose = P.Transpose()
def construct(self, x):
output = ()
for i in range(self.length):
shape = F.shape(x[i])
mid_shape = (shape[0], -1, self.num_default[i], self.sizes[i], self.sizes[i])
final_shape = (shape[0], -1, self.num_default[i] * self.sizes[i] * self.sizes[i])
output += (F.reshape(F.reshape(x[i], mid_shape), final_shape),)
res = self.concat(output)
return self.transpose(res, (0, 2, 1))
class MultiBox(nn.Cell):
"""
Multibox conv layers. Each multibox layer contains class conf scores and localization predictions.
Args:
config (Class): The default config of SSD.
Returns:
Tensor, localization predictions.
Tensor, class conf scores.
"""
def __init__(self, config):
super(MultiBox, self).__init__()
num_classes = config.NUM_CLASSES
out_channels = config.EXTRAS_OUT_CHANNELS
num_default = config.NUM_DEFAULT
loc_layers = []
cls_layers = []
for k, out_channel in enumerate(out_channels):
loc_layers += [_conv2d(out_channel, 4 * num_default[k],
kernel_size=3, stride=1, pad_mod='same')]
cls_layers += [_conv2d(out_channel, num_classes * num_default[k],
kernel_size=3, stride=1, pad_mod='same')]
self.multi_loc_layers = nn.layer.CellList(loc_layers)
self.multi_cls_layers = nn.layer.CellList(cls_layers)
self.flatten_concat = FlattenConcat(config)
def construct(self, inputs):
loc_outputs = ()
cls_outputs = ()
for i in range(len(self.multi_loc_layers)):
loc_outputs += (self.multi_loc_layers[i](inputs[i]),)
cls_outputs += (self.multi_cls_layers[i](inputs[i]),)
return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs)
class SSD300(nn.Cell):
"""
SSD300 Network. Default backbone is MobileNetV2.
Args:
backbone (Cell): Backbone Network.
config (Class): The default config of SSD.
Returns:
Tensor, localization predictions.
Tensor, class conf scores.
Examples:
SSD300(backbone=ssd_mobilenet_v2(), config=ConfigSSD()).
"""
def __init__(self, backbone, config, is_training=True):
super(SSD300, self).__init__()
self.backbone = backbone
in_channels = config.EXTRAS_IN_CHANNELS
out_channels = config.EXTRAS_OUT_CHANNELS
ratios = config.EXTRAS_RATIO
strides = config.EXTRAS_STRIDES
residual_list = []
for i in range(2, len(in_channels)):
residual = InvertedResidual(in_channels[i], out_channels[i], stride=strides[i], expand_ratio=ratios[i])
residual_list.append(residual)
self.multi_residual = nn.layer.CellList(residual_list)
self.multi_box = MultiBox(config)
self.is_training = is_training
if not is_training:
self.softmax = P.Softmax()
def construct(self, x):
layer_out_13, output = self.backbone(x)
multi_feature = (layer_out_13, output)
feature = output
for residual in self.multi_residual:
feature = residual(feature)
multi_feature += (feature,)
pred_loc, pred_label = self.multi_box(multi_feature)
if not self.is_training:
pred_label = self.softmax(pred_label)
return pred_loc, pred_label
class LocalizationLoss(nn.Cell):
""""
Computes the localization loss with SmoothL1Loss.
Returns:
Tensor, box regression loss.
"""
def __init__(self):
super(LocalizationLoss, self).__init__()
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.loss = nn.SmoothL1Loss()
self.expand_dims = P.ExpandDims()
self.less = P.Less()
def construct(self, pred_loc, gt_loc, gt_label, num_matched_boxes):
mask = F.cast(self.less(0, gt_label), mstype.float32)
mask = self.expand_dims(mask, -1)
smooth_l1 = self.loss(gt_loc, pred_loc) * mask
box_loss = self.reduce_sum(smooth_l1, 1)
return self.reduce_mean(box_loss / F.cast(num_matched_boxes, mstype.float32), (0, 1))
class ClassificationLoss(nn.Cell):
""""
Computes the classification loss with hard example mining.
Args:
config (Class): The default config of SSD.
Returns:
Tensor, classification loss.
"""
def __init__(self, config):
super(ClassificationLoss, self).__init__()
self.num_classes = config.NUM_CLASSES
self.num_boxes = config.NUM_SSD_BOXES
self.neg_pre_positive = config.NEG_PRE_POSITIVE
self.minimum = P.Minimum()
self.less = P.Less()
self.sort = P.TopK()
self.tile = P.Tile()
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.expand_dims = P.ExpandDims()
self.sort_descend = P.TopK(True)
self.cross_entropy = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
def construct(self, pred_label, gt_label, num_matched_boxes):
gt_label = F.cast(gt_label, mstype.int32)
mask = F.cast(self.less(0, gt_label), mstype.float32)
gt_label_shape = F.shape(gt_label)
pred_label = F.reshape(pred_label, (-1, self.num_classes))
gt_label = F.reshape(gt_label, (-1,))
cross_entropy = self.cross_entropy(pred_label, gt_label)
cross_entropy = F.reshape(cross_entropy, gt_label_shape)
# Hard example mining
num_matched_boxes = F.reshape(num_matched_boxes, (-1,))
neg_masked_cross_entropy = F.cast(cross_entropy * (1 - mask), mstype.float16)
_, loss_idx = self.sort_descend(neg_masked_cross_entropy, self.num_boxes)
_, relative_position = self.sort(F.cast(loss_idx, mstype.float16), self.num_boxes)
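# The two TopK passes rank anchors by negative-example loss, so the hardest num_neg_boxes negatives can be kept below.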
num_neg_boxes = self.minimum(num_matched_boxes * self.neg_pre_positive, self.num_boxes)
tile_num_neg_boxes = self.tile(self.expand_dims(num_neg_boxes, -1), (1, self.num_boxes))
top_k_neg_mask = F.cast(self.less(relative_position, tile_num_neg_boxes), mstype.float32)
class_loss = self.reduce_sum(cross_entropy * (mask + top_k_neg_mask), 1)
return self.reduce_mean(class_loss / F.cast(num_matched_boxes, mstype.float32), 0)
class SSDWithLossCell(nn.Cell):
""""
Provide SSD training loss through network.
Args:
network (Cell): The training network.
config (Class): SSD config.
Returns:
Tensor, the loss of the network.
"""
def __init__(self, network, config):
super(SSDWithLossCell, self).__init__()
self.network = network
self.class_loss = ClassificationLoss(config)
self.box_loss = LocalizationLoss()
def construct(self, x, gt_loc, gt_label, num_matched_boxes):
pred_loc, pred_label = self.network(x)
loss_cls = self.class_loss(pred_label, gt_label, num_matched_boxes)
loss_loc = self.box_loss(pred_loc, gt_loc, gt_label, num_matched_boxes)
return loss_cls + loss_loc
class TrainingWrapper(nn.Cell):
"""
Encapsulation class of SSD network training.
Append an optimizer to the training network; after that, the construct
function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default: 1.0.
"""
def __init__(self, network, optimizer, sens=1.0):
super(TrainingWrapper, self).__init__(auto_prefix=False)
self.network = network
self.weights = ms.ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
if auto_parallel_context().get_device_num_is_set():
degree = context.get_auto_parallel_context("device_num")
else:
degree = get_group_size()
self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, *args):
weights = self.weights
loss = self.network(*args)
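# Build a gradient sensitivity tensor filled with the loss-scale constant, matching the loss shape and dtype.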
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*args, sens)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
class SSDWithMobileNetV2(nn.Cell):
"""
MobileNetV2 architecture for SSD backbone.
Args:
width_mult (float): Channel width multiplier, applied before rounding to multiples of round_nearest. Default is 1.0.
inverted_residual_setting (list): Inverted residual settings. Default is None.
round_nearest (int): Round the number of channels to the nearest multiple of this value. Default is 8.
Returns:
Tensor, the 13th feature after ConvBNReLU in MobileNetV2.
Tensor, the last feature in MobileNetV2.
Examples:
>>> SSDWithMobileNetV2()
"""
def __init__(self, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
super(SSDWithMobileNetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
if inverted_residual_setting is None:
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
if len(inverted_residual_setting[0]) != 4:
raise ValueError("inverted_residual_setting should be non-empty "
"or a 4-element list, got {}".format(inverted_residual_setting))
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(3, input_channel, stride=2)]
# building inverted residual blocks
layer_index = 0
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
if layer_index == 13:
hidden_dim = int(round(input_channel * t))
self.expand_layer_conv_13 = ConvBNReLU(input_channel, hidden_dim, kernel_size=1)
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
input_channel = output_channel
layer_index += 1
# building last several layers
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
self.features_1 = nn.SequentialCell(features[:14])
self.features_2 = nn.SequentialCell(features[14:])
def construct(self, x):
out = self.features_1(x)
expand_layer_conv_13 = self.expand_layer_conv_13(out)
out = self.features_2(out)
return expand_layer_conv_13, out
def get_out_channels(self):
return self.last_channel
def ssd_mobilenet_v2(**kwargs):
return SSDWithMobileNetV2(**kwargs)