diff --git a/RELEASE.md b/RELEASE.md index f919bd7a2fab2d5494ed201b90a51a449145ab1c..686f4ccac10e5febc1a9b19f5cc2a800e072abcf 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -7,6 +7,7 @@ * DeepFM: a factorization-machine based neural network for CTR prediction on Criteo dataset. * DeepLabV3: significantly improves over our previous DeepLab versions without DenseCRF post-processing and attains comparable performance with other state-of-art models on the PASCAL VOC 2007 semantic image segmentation benchmark. * Faster-RCNN: towards real-time object detection with region proposal networks on COCO 2017 dataset. + * SSD: a single stage object detection methods on COCO 2017 dataset. * GoogLeNet: a deep convolutional neural network architecture codenamed Inception V1 for classification and detection on CIFAR-10 dataset. * Wide&Deep: jointly trained wide linear models and deep neural networks for recommender systems on Criteo dataset. * Frontend and User Interface diff --git a/example/ssd_coco2017/README.md b/example/ssd_coco2017/README.md index bd43344b8b87a8044a45e1dac34ab1e62446e57a..0e0db68944bbf1671f82a75a570f62310774f97b 100644 --- a/example/ssd_coco2017/README.md +++ b/example/ssd_coco2017/README.md @@ -60,10 +60,10 @@ To train the model, run `train.py`. If the `MINDRECORD_DIR` is empty, it will ge - Distribute mode ``` - sh run_distribute_train.sh 8 150 coco /data/hccl.json + sh run_distribute_train.sh 8 500 0.2 coco /data/hccl.json ``` - The input parameters are device numbers, epoch size, dataset mode and [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute path.** + The input parameters are device numbers, epoch size, learning rate, dataset mode and [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute path.** You will get the loss value of each step as following: @@ -75,14 +75,15 @@ epoch: 3 step: 455, loss is 5.458992 epoch: 148 step: 455, loss is 1.8340507 epoch: 149 step: 455, loss is 2.0876894 epoch: 150 step: 455, loss is 2.239692 +... ``` ### Evaluation -for evaluation , run `eval.py` with `ckpt_path`. `ckpt_path` is the path of [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file. +for evaluation , run `eval.py` with `checkpoint_path`. `checkpoint_path` is the path of [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file. ``` -python eval.py --ckpt_path ssd.ckpt --dataset coco +python eval.py --checkpoint_path ssd.ckpt --dataset coco ``` You can run ```python eval.py -h``` to get more information. diff --git a/example/ssd_coco2017/config.py b/example/ssd_coco2017/config.py index 452aaf970081b6c59046c305c199702407d0a2c0..62df38b762bdd65b92fb902813abea9652c79e9e 100644 --- a/example/ssd_coco2017/config.py +++ b/example/ssd_coco2017/config.py @@ -27,6 +27,9 @@ class ConfigSSD: NUM_SSD_BOXES = 1917 NEG_PRE_POSITIVE = 3 MATCH_THRESHOLD = 0.5 + NMS_THRESHOLD = 0.6 + MIN_SCORE = 0.05 + TOP_K = 100 NUM_DEFAULT = [3, 6, 6, 6, 6, 6] EXTRAS_IN_CHANNELS = [256, 576, 1280, 512, 256, 256] @@ -34,20 +37,21 @@ class ConfigSSD: EXTRAS_STRIDES = [1, 1, 2, 2, 2, 2] EXTRAS_RATIO = [0.2, 0.2, 0.2, 0.25, 0.5, 0.25] FEATURE_SIZE = [19, 10, 5, 3, 2, 1] - SCALES = [21, 45, 99, 153, 207, 261, 315] - ASPECT_RATIOS = [(1,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)] + MIN_SCALE = 0.2 + MAX_SCALE = 0.95 + ASPECT_RATIOS = [(2,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)] STEPS = (16, 32, 64, 100, 150, 300) PRIOR_SCALING = (0.1, 0.2) # `MINDRECORD_DIR` and `COCO_ROOT` are better to use absolute path. - MINDRECORD_DIR = "MindRecord_COCO" - COCO_ROOT = "coco2017" + MINDRECORD_DIR = "/data/MindRecord_COCO" + COCO_ROOT = "/data/coco2017" TRAIN_DATA_TYPE = "train2017" VAL_DATA_TYPE = "val2017" INSTANCES_SET = "annotations/instances_{}.json" COCO_CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', - 'train', 'truck', 'boat', 'traffic light', 'fire', 'hydrant', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', @@ -58,7 +62,7 @@ class ConfigSSD: 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', - 'keyboard', 'cell phone', 'microwave oven', 'toaster', 'sink', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush') NUM_CLASSES = len(COCO_CLASSES) diff --git a/example/ssd_coco2017/dataset.py b/example/ssd_coco2017/dataset.py index b88b22c8626d520d142a9570b345f359a4ef3468..e228ce5a1500ab078a702c462215d2d2fc749108 100644 --- a/example/ssd_coco2017/dataset.py +++ b/example/ssd_coco2017/dataset.py @@ -32,36 +32,38 @@ config = ConfigSSD() class GeneratDefaultBoxes(): """ Generate Default boxes for SSD, follows the order of (W, H, archor_sizes). - `self.default_boxes` has a shape of [archor_sizes, H, W, 4], the last dimension is [x, y, w, h]. - `self.default_boxes_ltrb` has a shape as `self.default_boxes`, the last dimension is [x1, y1, x2, y2]. + `self.default_boxes` has a shape of [archor_sizes, H, W, 4], the last dimension is [y, x, h, w]. + `self.default_boxes_ltrb` has a shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2]. """ def __init__(self): fk = config.IMG_SHAPE[0] / np.array(config.STEPS) + scale_rate = (config.MAX_SCALE - config.MIN_SCALE) / (len(config.NUM_DEFAULT) - 1) + scales = [config.MIN_SCALE + scale_rate * i for i in range(len(config.NUM_DEFAULT))] + [1.0] self.default_boxes = [] for idex, feature_size in enumerate(config.FEATURE_SIZE): - sk1 = config.SCALES[idex] / config.IMG_SHAPE[0] - sk2 = config.SCALES[idex + 1] / config.IMG_SHAPE[0] + sk1 = scales[idex] + sk2 = scales[idex + 1] sk3 = math.sqrt(sk1 * sk2) - - if config.NUM_DEFAULT[idex] == 3: - all_sizes = [(0.5, 1.0), (1.0, 1.0), (1.0, 0.5)] + if idex == 0: + w, h = sk1 * math.sqrt(2), sk1 / math.sqrt(2) + all_sizes = [(0.1, 0.1), (w, h), (h, w)] else: - all_sizes = [(sk1, sk1), (sk3, sk3)] + all_sizes = [(sk1, sk1)] for aspect_ratio in config.ASPECT_RATIOS[idex]: w, h = sk1 * math.sqrt(aspect_ratio), sk1 / math.sqrt(aspect_ratio) all_sizes.append((w, h)) all_sizes.append((h, w)) + all_sizes.append((sk3, sk3)) assert len(all_sizes) == config.NUM_DEFAULT[idex] for i, j in it.product(range(feature_size), repeat=2): for w, h in all_sizes: cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex] - box = [np.clip(k, 0, 1) for k in (cx, cy, w, h)] - self.default_boxes.append(box) + self.default_boxes.append([cy, cx, h, w]) - def to_ltrb(cx, cy, w, h): - return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2 + def to_ltrb(cy, cx, h, w): + return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2 # For IoU calculation self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32') @@ -70,17 +72,22 @@ class GeneratDefaultBoxes(): default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb default_boxes = GeneratDefaultBoxes().default_boxes -x1, y1, x2, y2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1) +y1, x1, y2, x2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1) vol_anchors = (x2 - x1) * (y2 - y1) matching_threshold = config.MATCH_THRESHOLD +def _rand(a=0., b=1.): + """Generate random.""" + return np.random.rand() * (b - a) + a + + def ssd_bboxes_encode(boxes): """ Labels anchors with ground truth inputs. Args: - boxex: ground truth with shape [N, 5], for each row, it stores [x, y, w, h, cls]. + boxex: ground truth with shape [N, 5], for each row, it stores [y, x, h, w, cls]. Returns: gt_loc: location ground truth with shape [num_anchors, 4]. @@ -91,10 +98,10 @@ def ssd_bboxes_encode(boxes): def jaccard_with_anchors(bbox): """Compute jaccard score a box and the anchors.""" # Intersection bbox and volume. - xmin = np.maximum(x1, bbox[0]) - ymin = np.maximum(y1, bbox[1]) - xmax = np.minimum(x2, bbox[2]) - ymax = np.minimum(y2, bbox[3]) + ymin = np.maximum(y1, bbox[0]) + xmin = np.maximum(x1, bbox[1]) + ymax = np.minimum(y2, bbox[2]) + xmax = np.minimum(x2, bbox[3]) w = np.maximum(xmax - xmin, 0.) h = np.maximum(ymax - ymin, 0.) @@ -110,12 +117,11 @@ def ssd_bboxes_encode(boxes): for bbox in boxes: label = int(bbox[4]) scores = jaccard_with_anchors(bbox) + idx = np.argmax(scores) + scores[idx] = 2.0 mask = (scores > matching_threshold) - if not np.any(mask): - mask[np.argmax(scores)] = True - mask = mask & (scores > pre_scores) - pre_scores = np.maximum(pre_scores, scores) + pre_scores = np.maximum(pre_scores, scores * mask) t_label = mask * label + (1 - mask) * t_label for i in range(4): t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i] @@ -134,13 +140,13 @@ def ssd_bboxes_encode(boxes): bboxes_t[:, 2:4] = np.log(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4]) / config.PRIOR_SCALING[1] bboxes[index] = bboxes_t - num_match_num = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32) - return bboxes, t_label.astype(np.int32), num_match_num + num_match = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32) + return bboxes, t_label.astype(np.int32), num_match -def ssd_bboxes_decode(boxes, index): - """Decode predict boxes to [x, y, w, h]""" - boxes_t = boxes[index] - default_boxes_t = default_boxes[index] +def ssd_bboxes_decode(boxes): + """Decode predict boxes to [y, x, h, w]""" + boxes_t = boxes.copy() + default_boxes_t = default_boxes.copy() boxes_t[:, :2] = boxes_t[:, :2] * config.PRIOR_SCALING[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2] boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.PRIOR_SCALING[1]) * default_boxes_t[:, 2:4] @@ -149,41 +155,101 @@ def ssd_bboxes_decode(boxes, index): bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2 bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2 - return bboxes + return np.clip(bboxes, 0, 1) -def preprocess_fn(image, box, is_training): - """Preprocess function for dataset.""" - def _rand(a=0., b=1.): - """Generate random.""" - return np.random.rand() * (b - a) + a +def intersect(box_a, box_b): + """Compute the intersect of two sets of boxes.""" + max_yx = np.minimum(box_a[:, 2:4], box_b[2:4]) + min_yx = np.maximum(box_a[:, :2], box_b[:2]) + inter = np.clip((max_yx - min_yx), a_min=0, a_max=np.inf) + return inter[:, 0] * inter[:, 1] - def _infer_data(image, input_shape, box): - img_h, img_w, _ = image.shape - input_h, input_w = input_shape - scale = min(float(input_w) / float(img_w), float(input_h) / float(img_h)) - nw = int(img_w * scale) - nh = int(img_h * scale) +def jaccard_numpy(box_a, box_b): + """Compute the jaccard overlap of two sets of boxes.""" + inter = intersect(box_a, box_b) + area_a = ((box_a[:, 2] - box_a[:, 0]) * + (box_a[:, 3] - box_a[:, 1])) + area_b = ((box_b[2] - box_b[0]) * + (box_b[3] - box_b[1])) + union = area_a + area_b - inter + return inter / union + + +def random_sample_crop(image, boxes): + """Random Crop the image and boxes""" + height, width, _ = image.shape + min_iou = np.random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9]) + + if min_iou is None: + return image, boxes + + # max trails (50) + for _ in range(50): + image_t = image + + w = _rand(0.3, 1.0) * width + h = _rand(0.3, 1.0) * height + + # aspect ratio constraint b/t .5 & 2 + if h / w < 0.5 or h / w > 2: + continue + + left = _rand() * (width - w) + top = _rand() * (height - h) + + rect = np.array([int(top), int(left), int(top+h), int(left+w)]) + overlap = jaccard_numpy(boxes, rect) + + # dropout some boxes + drop_mask = overlap > 0 + if not drop_mask.any(): + continue + + if overlap[drop_mask].min() < min_iou: + continue + + image_t = image_t[rect[0]:rect[2], rect[1]:rect[3], :] + + centers = (boxes[:, :2] + boxes[:, 2:4]) / 2.0 + + m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) + m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) + + # mask in that both m1 and m2 are true + mask = m1 * m2 * drop_mask - image = cv2.resize(image, (nw, nh)) + # have any valid boxes? try again if not + if not mask.any(): + continue - new_image = np.zeros((input_h, input_w, 3), np.float32) - dh = (input_h - nh) // 2 - dw = (input_w - nw) // 2 - new_image[dh: (nh + dh), dw: (nw + dw), :] = image - image = new_image + # take only matching gt boxes + boxes_t = boxes[mask, :].copy() + + boxes_t[:, :2] = np.maximum(boxes_t[:, :2], rect[:2]) + boxes_t[:, :2] -= rect[:2] + boxes_t[:, 2:4] = np.minimum(boxes_t[:, 2:4], rect[2:4]) + boxes_t[:, 2:4] -= rect[:2] + + return image_t, boxes_t + return image, boxes + + +def preprocess_fn(img_id, image, box, is_training): + """Preprocess function for dataset.""" + def _infer_data(image, input_shape): + img_h, img_w, _ = image.shape + input_h, input_w = input_shape + + image = cv2.resize(image, (input_w, input_h)) #When the channels of image is 1 if len(image.shape) == 2: image = np.expand_dims(image, axis=-1) image = np.concatenate([image, image, image], axis=-1) - box = box.astype(np.float32) - - box[:, [0, 2]] = (box[:, [0, 2]] * scale + dw) / input_w - box[:, [1, 3]] = (box[:, [1, 3]] * scale + dh) / input_h - return image, np.array((img_h, img_w), np.float32), box + return img_id, image, np.array((img_h, img_w), np.float32) def _data_aug(image, box, is_training, image_size=(300, 300)): """Data augmentation function.""" @@ -191,53 +257,34 @@ def preprocess_fn(image, box, is_training): w, h = image_size if not is_training: - return _infer_data(image, image_size, box) - # Random settings - scale_w = _rand(0.75, 1.25) - scale_h = _rand(0.75, 1.25) + return _infer_data(image, image_size) - flip = _rand() < .5 - nw = iw * scale_w - nh = ih * scale_h - scale = min(w / nw, h / nh) - nw = int(scale * nw) - nh = int(scale * nh) + # Random crop + box = box.astype(np.float32) + image, box = random_sample_crop(image, box) + ih, iw, _ = image.shape # Resize image - image = cv2.resize(image, (nw, nh)) - - # place image - new_image = np.zeros((h, w, 3), dtype=np.float32) - dw = (w - nw) // 2 - dh = (h - nh) // 2 - new_image[dh:dh + nh, dw:dw + nw, :] = image - image = new_image + image = cv2.resize(image, (w, h)) # Flip image or not + flip = _rand() < .5 if flip: image = cv2.flip(image, 1, dst=None) - # Convert image to gray or not - gray = _rand() < .25 - if gray: - image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) - # When the channels of image is 1 if len(image.shape) == 2: image = np.expand_dims(image, axis=-1) image = np.concatenate([image, image, image], axis=-1) - box = box.astype(np.float32) - - # Transform box with shape[x1, y1, x2, y2]. - box[:, [0, 2]] = (box[:, [0, 2]] * scale * scale_w + dw) / w - box[:, [1, 3]] = (box[:, [1, 3]] * scale * scale_h + dh) / h + box[:, [0, 2]] = box[:, [0, 2]] / ih + box[:, [1, 3]] = box[:, [1, 3]] / iw if flip: - box[:, [0, 2]] = 1 - box[:, [2, 0]] + box[:, [1, 3]] = 1 - box[:, [3, 1]] - box, label, num_match_num = ssd_bboxes_encode(box) - return image, box, label, num_match_num + box, label, num_match = ssd_bboxes_encode(box) + return image, box, label, num_match return _data_aug(image, box, is_training, image_size=config.IMG_SHAPE) @@ -265,7 +312,8 @@ def create_coco_label(is_training): classs_dict[cat["id"]] = cat["name"] image_ids = coco.getImgIds() - image_files = [] + images = [] + image_path_dict = {} image_anno_dict = {} for img_id in image_ids: @@ -275,17 +323,23 @@ def create_coco_label(is_training): anno = coco.loadAnns(anno_ids) image_path = os.path.join(coco_root, data_type, file_name) annos = [] + iscrowd = False for label in anno: bbox = label["bbox"] class_name = classs_dict[label["category_id"]] + iscrowd = iscrowd or label["iscrowd"] if class_name in train_cls: x_min, x_max = bbox[0], bbox[0] + bbox[2] y_min, y_max = bbox[1], bbox[1] + bbox[3] - annos.append(list(map(round, [x_min, y_min, x_max, y_max])) + [train_cls_dict[class_name]]) + annos.append(list(map(round, [y_min, x_min, y_max, x_max])) + [train_cls_dict[class_name]]) + if not is_training and iscrowd: + continue if len(annos) >= 1: - image_files.append(image_path) - image_anno_dict[image_path] = np.array(annos) - return image_files, image_anno_dict + images.append(img_id) + image_path_dict[img_id] = image_path + image_anno_dict[img_id] = np.array(annos) + + return images, image_path_dict, image_anno_dict def anno_parser(annos_str): @@ -299,7 +353,8 @@ def anno_parser(annos_str): def filter_valid_data(image_dir, anno_path): """Filter valid image file, which both in image_dir and anno_path.""" - image_files = [] + images = [] + image_path_dict = {} image_anno_dict = {} if not os.path.isdir(image_dir): raise RuntimeError("Path given is not valid.") @@ -308,15 +363,17 @@ def filter_valid_data(image_dir, anno_path): with open(anno_path, "rb") as f: lines = f.readlines() - for line in lines: + for img_id, line in enumerate(lines): line_str = line.decode("utf-8").strip() line_split = str(line_str).split(' ') file_name = line_split[0] image_path = os.path.join(image_dir, file_name) if os.path.isfile(image_path): - image_anno_dict[image_path] = anno_parser(line_split[1:]) - image_files.append(image_path) - return image_files, image_anno_dict + images.append(img_id) + image_path_dict[img_id] = image_path + image_anno_dict[img_id] = anno_parser(line_split[1:]) + + return images, image_path_dict, image_anno_dict def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.mindrecord", file_num=8): @@ -325,21 +382,24 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd. mindrecord_path = os.path.join(mindrecord_dir, prefix) writer = FileWriter(mindrecord_path, file_num) if dataset == "coco": - image_files, image_anno_dict = create_coco_label(is_training) + images, image_path_dict, image_anno_dict = create_coco_label(is_training) else: - image_files, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH) + images, image_path_dict, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH) ssd_json = { + "img_id": {"type": "int32", "shape": [1]}, "image": {"type": "bytes"}, "annotation": {"type": "int32", "shape": [-1, 5]}, } writer.add_schema(ssd_json, "ssd_json") - for image_name in image_files: - with open(image_name, 'rb') as f: + for img_id in images: + image_path = image_path_dict[img_id] + with open(image_path, 'rb') as f: img = f.read() - annos = np.array(image_anno_dict[image_name], dtype=np.int32) - row = {"image": img, "annotation": annos} + annos = np.array(image_anno_dict[img_id], dtype=np.int32) + img_id = np.array([img_id], dtype=np.int32) + row = {"img_id": img_id, "image": img, "annotation": annos} writer.write_raw_data([row]) writer.commit() @@ -347,29 +407,26 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd. def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0, is_training=True, num_parallel_workers=4): """Creatr SSD dataset with MindDataset.""" - ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank, - num_parallel_workers=num_parallel_workers, shuffle=is_training) + ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num, + shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training) decode = C.Decode() ds = ds.map(input_columns=["image"], operations=decode) - compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training)) - + change_swap_op = C.HWC2CHW() + normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255]) + color_adjust_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) + compose_map_func = (lambda img_id, image, annotation: preprocess_fn(img_id, image, annotation, is_training)) if is_training: - hwc_to_chw = C.HWC2CHW() - ds = ds.map(input_columns=["image", "annotation"], - output_columns=["image", "box", "label", "num_match_num"], - columns_order=["image", "box", "label", "num_match_num"], - operations=compose_map_func, python_multiprocessing=True, num_parallel_workers=num_parallel_workers) - ds = ds.map(input_columns=["image"], operations=hwc_to_chw, python_multiprocessing=True, - num_parallel_workers=num_parallel_workers) - ds = ds.batch(batch_size, drop_remainder=True) - ds = ds.repeat(repeat_num) + output_columns = ["image", "box", "label", "num_match"] + trans = [color_adjust_op, normalize_op, change_swap_op] else: - hwc_to_chw = C.HWC2CHW() - ds = ds.map(input_columns=["image", "annotation"], - output_columns=["image", "image_shape", "annotation"], - columns_order=["image", "image_shape", "annotation"], - operations=compose_map_func) - ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers) - ds = ds.batch(batch_size, drop_remainder=True) - ds = ds.repeat(repeat_num) + output_columns = ["img_id", "image", "image_shape"] + trans = [normalize_op, change_swap_op] + ds = ds.map(input_columns=["img_id", "image", "annotation"], + output_columns=output_columns, columns_order=output_columns, + operations=compose_map_func, python_multiprocessing=is_training, + num_parallel_workers=num_parallel_workers) + ds = ds.map(input_columns=["image"], operations=trans, python_multiprocessing=is_training, + num_parallel_workers=num_parallel_workers) + ds = ds.batch(batch_size, drop_remainder=True) + ds = ds.repeat(repeat_num) return ds diff --git a/example/ssd_coco2017/eval.py b/example/ssd_coco2017/eval.py index d5e0d86b67a6dfc8089970c2d1686404ad8c161a..98410930f2b3daf9c1938c0b55c83c8d14dd5e8b 100644 --- a/example/ssd_coco2017/eval.py +++ b/example/ssd_coco2017/eval.py @@ -17,6 +17,7 @@ import os import argparse import time +import numpy as np from mindspore import context, Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.model_zoo.ssd import SSD300, ssd_mobilenet_v2 @@ -26,8 +27,8 @@ from util import metrics def ssd_eval(dataset_path, ckpt_path): """SSD evaluation.""" - - ds = create_ssd_dataset(dataset_path, batch_size=1, repeat_num=1, is_training=False) + batch_size = 32 + ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1, is_training=False) net = SSD300(ssd_mobilenet_v2(), ConfigSSD(), is_training=False) print("Load Checkpoint!") param_dict = load_checkpoint(ckpt_path) @@ -35,28 +36,28 @@ def ssd_eval(dataset_path, ckpt_path): load_param_into_net(net, param_dict) net.set_train(False) - i = 1. - total = ds.get_dataset_size() + i = batch_size + total = ds.get_dataset_size() * batch_size start = time.time() pred_data = [] print("\n========================================\n") print("total images num: ", total) print("Processing, please wait a moment.") for data in ds.create_dict_iterator(): + img_id = data['img_id'] img_np = data['image'] image_shape = data['image_shape'] - annotation = data['annotation'] output = net(Tensor(img_np)) for batch_idx in range(img_np.shape[0]): pred_data.append({"boxes": output[0].asnumpy()[batch_idx], "box_scores": output[1].asnumpy()[batch_idx], - "annotation": annotation, - "image_shape": image_shape}) - percent = round(i / total * 100, 2) + "img_id": int(np.squeeze(img_id[batch_idx])), + "image_shape": image_shape[batch_idx]}) + percent = round(i / total * 100., 2) print(f' {str(percent)} [{i}/{total}]', end='\r') - i += 1 + i += batch_size cost_time = int((time.time() - start) * 1000) print(f' 100% [{total}/{total}] cost {cost_time} ms') mAP = metrics(pred_data) diff --git a/example/ssd_coco2017/run_distribute_train.sh b/example/ssd_coco2017/run_distribute_train.sh index 4c1049fccc06114397ac49483b149ec3389a7e34..8afadd7ad1129c1a0990ed9d4f700c582c06cb06 100644 --- a/example/ssd_coco2017/run_distribute_train.sh +++ b/example/ssd_coco2017/run_distribute_train.sh @@ -16,11 +16,17 @@ echo "==============================================================================================================" echo "Please run the scipt as: " -echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDSPORE_HCCL_CONFIG_PATH" -echo "for example: sh run_distribute_train.sh 8 150 coco /data/hccl.json" +echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET MINDSPORE_HCCL_CONFIG_PATH PRE_TRAINED PRE_TRAINED_EPOCH_SIZE" +echo "for example: sh run_distribute_train.sh 8 500 0.2 coco /data/hccl.json /opt/ssd-300.ckpt(optional) 200(optional)" echo "It is better to use absolute path." -echo "The learning rate is 0.4 as default, if you want other lr, please change the value in this script." -echo "==============================================================================================================" +echo "=================================================================================================================" + +if [ $# != 5 ] && [ $# != 7 ] +then + echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] \ +[MINDSPORE_HCCL_CONFIG_PATH] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)" + exit 1 +fi # Before start distribute train, first create mindrecord files. python train.py --only_create_dataset=1 @@ -29,9 +35,11 @@ echo "After running the scipt, the network runs in the background. The log will export RANK_SIZE=$1 EPOCH_SIZE=$2 -DATASET=$3 -export MINDSPORE_HCCL_CONFIG_PATH=$4 - +LR=$3 +DATASET=$4 +PRE_TRAINED=$6 +PRE_TRAINED_EPOCH_SIZE=$7 +export MINDSPORE_HCCL_CONFIG_PATH=$5 for((i=0;i env.log - python ../train.py \ - --distribute=1 \ - --lr=0.4 \ - --dataset=$DATASET \ - --device_num=$RANK_SIZE \ - --device_id=$DEVICE_ID \ - --epoch_size=$EPOCH_SIZE > log.txt 2>&1 & + if [ $# == 5 ] + then + python ../train.py \ + --distribute=1 \ + --lr=$LR \ + --dataset=$DATASET \ + --device_num=$RANK_SIZE \ + --device_id=$DEVICE_ID \ + --epoch_size=$EPOCH_SIZE > log.txt 2>&1 & + fi + + if [ $# == 7 ] + then + python ../train.py \ + --distribute=1 \ + --lr=$LR \ + --dataset=$DATASET \ + --device_num=$RANK_SIZE \ + --device_id=$DEVICE_ID \ + --pre_trained=$PRE_TRAINED \ + --pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \ + --epoch_size=$EPOCH_SIZE > log.txt 2>&1 & + fi + cd ../ done diff --git a/example/ssd_coco2017/train.py b/example/ssd_coco2017/train.py index 75f9a6d31f7e237d3a3ef7e0b49248d2214f0e2d..b28bcc1523133f34aea3a33e157c7503ad4f0b93 100644 --- a/example/ssd_coco2017/train.py +++ b/example/ssd_coco2017/train.py @@ -16,79 +16,34 @@ """train SSD and get checkpoint files.""" import os -import math import argparse -import numpy as np import mindspore.nn as nn from mindspore import context, Tensor from mindspore.communication.management import init from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor from mindspore.train import Model, ParallelMode from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.common.initializer import initializer - from mindspore.model_zoo.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2 from config import ConfigSSD from dataset import create_ssd_dataset, data_to_mindrecord_byte_image +from util import get_lr, init_net_param -def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch): - """ - generate learning rate array - - Args: - global_step(int): total steps of the training - lr_init(float): init learning rate - lr_end(float): end learning rate - lr_max(float): max learning rate - warmup_epochs(int): number of warmup epochs - total_epochs(int): total epoch of training - steps_per_epoch(int): steps of one epoch - - Returns: - np.array, learning rate array - """ - lr_each_step = [] - total_steps = steps_per_epoch * total_epochs - warmup_steps = steps_per_epoch * warmup_epochs - for i in range(total_steps): - if i < warmup_steps: - lr = lr_init + (lr_max - lr_init) * i / warmup_steps - else: - lr = lr_end + (lr_max - lr_end) * \ - (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2. - if lr < 0.0: - lr = 0.0 - lr_each_step.append(lr) - - current_step = global_step - lr_each_step = np.array(lr_each_step).astype(np.float32) - learning_rate = lr_each_step[current_step:] - - return learning_rate - - -def init_net_param(network, initialize_mode='XavierUniform'): - """Init the parameters in net.""" - params = network.trainable_params() - for p in params: - if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: - p.set_parameter_data(initializer(initialize_mode, p.data.shape(), p.data.dtype())) - def main(): parser = argparse.ArgumentParser(description="SSD training") parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create " - "Mindrecord, default is false.") - parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.") + "Mindrecord, default is False.") + parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is False.") parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") - parser.add_argument("--lr", type=float, default=0.25, help="Learning rate, default is 0.25.") + parser.add_argument("--lr", type=float, default=0.1, help="Learning rate, default is 0.1.") parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.") parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.") - parser.add_argument("--epoch_size", type=int, default=70, help="Epoch size, default is 70.") + parser.add_argument("--epoch_size", type=int, default=250, help="Epoch size, default is 250.") parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.") - parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.") + parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.") + parser.add_argument("--save_checkpoint_epochs", type=int, default=10, help="Save checkpoint epochs, default is 5.") parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") args_opt = parser.parse_args() @@ -142,7 +97,8 @@ def main(): dataset_size = dataset.get_dataset_size() print("Create dataset done!") - ssd = SSD300(backbone=ssd_mobilenet_v2(), config=config) + backbone = ssd_mobilenet_v2() + ssd = SSD300(backbone=backbone, config=config) net = SSDWithLossCell(ssd, config) init_net_param(net) @@ -150,17 +106,19 @@ def main(): ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs) ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=None, config=ckpt_config) - lr = Tensor(get_lr(global_step=0, lr_init=0, lr_end=0, lr_max=args_opt.lr, - warmup_epochs=max(args_opt.epoch_size // 20, 1), - total_epochs=args_opt.epoch_size, - steps_per_epoch=dataset_size)) - opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 0.0001, loss_scale) - net = TrainingWrapper(net, opt, loss_scale) - if args_opt.pre_trained: + if args_opt.pre_trained_epoch_size <= 0: + raise KeyError("pre_trained_epoch_size must be greater than 0.") param_dict = load_checkpoint(args_opt.pre_trained) load_param_into_net(net, param_dict) + lr = Tensor(get_lr(global_step=0, lr_init=0.001, lr_end=0.001 * args_opt.lr, lr_max=args_opt.lr, + warmup_epochs=2, + total_epochs=args_opt.epoch_size, + steps_per_epoch=dataset_size)) + opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4, loss_scale) + net = TrainingWrapper(net, opt, loss_scale) + callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] model = Model(net) diff --git a/example/ssd_coco2017/util.py b/example/ssd_coco2017/util.py index 6e1028537576df231b6d820096eeb148858c3bea..3c547dbbeeee123f00dae13bbeabf975a2fbbfc6 100644 --- a/example/ssd_coco2017/util.py +++ b/example/ssd_coco2017/util.py @@ -14,43 +14,83 @@ # ============================================================================ """metrics utils""" +import os +import json +import math import numpy as np +from mindspore import Tensor +from mindspore.common.initializer import initializer, TruncatedNormal from config import ConfigSSD from dataset import ssd_bboxes_decode -def calc_iou(bbox_pred, bbox_ground): - """Calculate iou of predicted bbox and ground truth.""" - bbox_pred = np.expand_dims(bbox_pred, axis=0) - - pred_w = bbox_pred[:, 2] - bbox_pred[:, 0] - pred_h = bbox_pred[:, 3] - bbox_pred[:, 1] - pred_area = pred_w * pred_h - - gt_w = bbox_ground[:, 2] - bbox_ground[:, 0] - gt_h = bbox_ground[:, 3] - bbox_ground[:, 1] - gt_area = gt_w * gt_h - - iw = np.minimum(bbox_pred[:, 2], bbox_ground[:, 2]) - np.maximum(bbox_pred[:, 0], bbox_ground[:, 0]) - ih = np.minimum(bbox_pred[:, 3], bbox_ground[:, 3]) - np.maximum(bbox_pred[:, 1], bbox_ground[:, 1]) - - iw = np.maximum(iw, 0) - ih = np.maximum(ih, 0) - intersection_area = iw * ih +def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch): + """ + generate learning rate array + + Args: + global_step(int): total steps of the training + lr_init(float): init learning rate + lr_end(float): end learning rate + lr_max(float): max learning rate + warmup_epochs(int): number of warmup epochs + total_epochs(int): total epoch of training + steps_per_epoch(int): steps of one epoch + + Returns: + np.array, learning rate array + """ + lr_each_step = [] + total_steps = steps_per_epoch * total_epochs + warmup_steps = steps_per_epoch * warmup_epochs + for i in range(total_steps): + if i < warmup_steps: + lr = lr_init + (lr_max - lr_init) * i / warmup_steps + else: + lr = lr_end + \ + (lr_max - lr_end) * \ + (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2. + if lr < 0.0: + lr = 0.0 + lr_each_step.append(lr) + + current_step = global_step + lr_each_step = np.array(lr_each_step).astype(np.float32) + learning_rate = lr_each_step[current_step:] + + return learning_rate + + +def init_net_param(network, initialize_mode='TruncatedNormal'): + """Init the parameters in net.""" + params = network.trainable_params() + for p in params: + if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: + if initialize_mode == 'TruncatedNormal': + p.set_parameter_data(initializer(TruncatedNormal(0.03), p.data.shape(), p.data.dtype())) + else: + p.set_parameter_data(initialize_mode, p.data.shape(), p.data.dtype()) - union_area = pred_area + gt_area - intersection_area - union_area = np.maximum(union_area, np.finfo(float).eps) - iou = intersection_area * 1. / union_area - return iou +def load_backbone_params(network, param_dict): + """Init the parameters from pre-train model, default is mobilenetv2.""" + for _, param in net.parameters_and_names(): + param_name = param.name.replace('network.backbone.', '') + name_split = param_name.split('.') + if 'features_1' in param_name: + param_name = param_name.replace('features_1', 'features') + if 'features_2' in param_name: + param_name = '.'.join(['features', str(int(name_split[1]) + 14)] + name_split[2:]) + if param_name in param_dict: + param.set_parameter_data(param_dict[param_name].data) def apply_nms(all_boxes, all_scores, thres, max_boxes): """Apply NMS to bboxes.""" - x1 = all_boxes[:, 0] - y1 = all_boxes[:, 1] - x2 = all_boxes[:, 2] - y2 = all_boxes[:, 3] + y1 = all_boxes[:, 0] + x1 = all_boxes[:, 1] + y2 = all_boxes[:, 2] + x2 = all_boxes[:, 3] areas = (x2 - x1 + 1) * (y2 - y1 + 1) order = all_scores.argsort()[::-1] @@ -80,127 +120,73 @@ def apply_nms(all_boxes, all_scores, thres, max_boxes): return keep -def calc_ap(recall, precision): - """Calculate AP.""" - correct_recall = np.concatenate(([0.], recall, [1.])) - correct_precision = np.concatenate(([0.], precision, [0.])) - - for i in range(correct_recall.size - 1, 0, -1): - correct_precision[i - 1] = np.maximum(correct_precision[i - 1], correct_precision[i]) - - i = np.where(correct_recall[1:] != correct_recall[:-1])[0] - - ap = np.sum((correct_recall[i + 1] - correct_recall[i]) * correct_precision[i + 1]) - - return ap - def metrics(pred_data): """Calculate mAP of predicted bboxes.""" + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval config = ConfigSSD() num_classes = config.NUM_CLASSES - all_detections = [None for i in range(num_classes)] - all_pred_scores = [None for i in range(num_classes)] - all_annotations = [None for i in range(num_classes)] - average_precisions = {} - num = [0 for i in range(num_classes)] - accurate_num = [0 for i in range(num_classes)] + coco_root = config.COCO_ROOT + data_type = config.VAL_DATA_TYPE - for sample in pred_data: - pred_boxes = sample['boxes'] - boxes_scores = sample['box_scores'] - annotation = sample['annotation'] + #Classes need to train or test. + val_cls = config.COCO_CLASSES + val_cls_dict = {} + for i, cls in enumerate(val_cls): + val_cls_dict[i] = cls - annotation = np.squeeze(annotation, axis=0) + anno_json = os.path.join(coco_root, config.INSTANCES_SET.format(data_type)) + coco_gt = COCO(anno_json) + classs_dict = {} + cat_ids = coco_gt.loadCats(coco_gt.getCatIds()) + for cat in cat_ids: + classs_dict[cat["name"]] = cat["id"] - pred_labels = np.argmax(boxes_scores, axis=-1) - index = np.nonzero(pred_labels) - pred_boxes = ssd_bboxes_decode(pred_boxes, index) + predictions = [] + img_ids = [] - pred_boxes = pred_boxes.clip(0, 1) - boxes_scores = np.max(boxes_scores, axis=-1) - boxes_scores = boxes_scores[index] - pred_labels = pred_labels[index] + for sample in pred_data: + pred_boxes = sample['boxes'] + box_scores = sample['box_scores'] + img_id = sample['img_id'] + h, w = sample['image_shape'] - top_k = 50 + pred_boxes = ssd_bboxes_decode(pred_boxes) + final_boxes = [] + final_label = [] + final_score = [] + img_ids.append(img_id) for c in range(1, num_classes): - if len(pred_labels) >= 1: - class_box_scores = boxes_scores[pred_labels == c] - class_boxes = pred_boxes[pred_labels == c] - - nms_index = apply_nms(class_boxes, class_box_scores, config.MATCH_THRESHOLD, top_k) + class_box_scores = box_scores[:, c] + score_mask = class_box_scores > config.MIN_SCORE + class_box_scores = class_box_scores[score_mask] + class_boxes = pred_boxes[score_mask] * [h, w, h, w] + if score_mask.any(): + nms_index = apply_nms(class_boxes, class_box_scores, config.NMS_THRESHOLD, config.TOP_K) class_boxes = class_boxes[nms_index] class_box_scores = class_box_scores[nms_index] - cmask = class_box_scores > 0.5 - class_boxes = class_boxes[cmask] - class_box_scores = class_box_scores[cmask] - - all_detections[c] = class_boxes - all_pred_scores[c] = class_box_scores - - for c in range(1, num_classes): - if len(annotation) >= 1: - all_annotations[c] = annotation[annotation[:, 4] == c, :4] - - for c in range(1, num_classes): - false_positives = np.zeros((0,)) - true_positives = np.zeros((0,)) - scores = np.zeros((0,)) - num_annotations = 0.0 - - annotations = all_annotations[c] - num_annotations += annotations.shape[0] - detections = all_detections[c] - pred_scores = all_pred_scores[c] - - for index, detection in enumerate(detections): - scores = np.append(scores, pred_scores[index]) - if len(annotations) >= 1: - IoUs = calc_iou(detection, annotations) - assigned_anno = np.argmax(IoUs) - max_overlap = IoUs[assigned_anno] - - if max_overlap >= 0.5: - false_positives = np.append(false_positives, 0) - true_positives = np.append(true_positives, 1) - else: - false_positives = np.append(false_positives, 1) - true_positives = np.append(true_positives, 0) - else: - false_positives = np.append(false_positives, 1) - true_positives = np.append(true_positives, 0) - - if num_annotations == 0: - if c not in average_precisions.keys(): - average_precisions[c] = 0 - continue - accurate_num[c] = 1 - indices = np.argsort(-scores) - false_positives = false_positives[indices] - true_positives = true_positives[indices] - - false_positives = np.cumsum(false_positives) - true_positives = np.cumsum(true_positives) - - recall = true_positives * 1. / num_annotations - precision = true_positives * 1. / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps) - - average_precision = calc_ap(recall, precision) - - if c not in average_precisions.keys(): - average_precisions[c] = average_precision - else: - average_precisions[c] += average_precision - - num[c] += 1 - - count = 0 - for key in average_precisions: - if num[key] != 0: - count += (average_precisions[key] / num[key]) - - mAP = count * 1. / accurate_num.count(1) - return mAP + final_boxes += class_boxes.tolist() + final_score += class_box_scores.tolist() + final_label += [classs_dict[val_cls_dict[c]]] * len(class_box_scores) + + for loc, label, score in zip(final_boxes, final_label, final_score): + res = {} + res['image_id'] = img_id + res['bbox'] = [loc[1], loc[0], loc[3] - loc[1], loc[2] - loc[0]] + res['score'] = score + res['category_id'] = label + predictions.append(res) + with open('predictions.json', 'w') as f: + json.dump(predictions, f) + + coco_dt = coco_gt.loadRes('predictions.json') + E = COCOeval(coco_gt, coco_dt, iouType='bbox') + E.params.imgIds = img_ids + E.evaluate() + E.accumulate() + E.summarize() + return E.stats[0] diff --git a/mindspore/model_zoo/ssd.py b/mindspore/model_zoo/ssd.py index b69942cd5c1bab1ebb61057d6598c46cf5d5e119..32c2689fc6b2beefd6112c900d943e484dd9bba9 100644 --- a/mindspore/model_zoo/ssd.py +++ b/mindspore/model_zoo/ssd.py @@ -17,22 +17,13 @@ import mindspore.common.dtype as mstype import mindspore as ms import mindspore.nn as nn -from mindspore import context +from mindspore import Parameter, context, Tensor from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.communication.management import get_group_size from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.ops import composite as C from mindspore.common.initializer import initializer -from mindspore.ops.operations import TensorAdd -from mindspore import Parameter - - -def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'): - weight_shape = (out_channel, in_channel, kernel_size, kernel_size) - weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor() - return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, - padding=0, pad_mode=pad_mod, weight_init=weight) def _make_divisible(v, divisor, min_value=None): @@ -46,6 +37,55 @@ def _make_divisible(v, divisor, min_value=None): return new_v +def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'): + return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, + padding=0, pad_mode=pad_mod, has_bias=True) + + +def _bn(channel): + return nn.BatchNorm2d(channel, eps=1e-3, momentum=0.97, + gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +def _last_conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same', pad=0): + depthwise_conv = DepthwiseConv(in_channel, kernel_size, stride, pad_mode='same', pad=pad) + conv = _conv2d(in_channel, out_channel, kernel_size=1) + return nn.SequentialCell([depthwise_conv, _bn(in_channel), nn.ReLU6(), conv]) + + +class ConvBNReLU(nn.Cell): + """ + Convolution/Depthwise fused with Batchnorm and ReLU block definition. + + Args: + in_planes (int): Input channel. + out_planes (int): Output channel. + kernel_size (int): Input kernel size. + stride (int): Stride size for the first convolutional layer. Default: 1. + groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) + """ + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + super(ConvBNReLU, self).__init__() + padding = 0 + if groups == 1: + conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same', + padding=padding) + else: + conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='same', pad=padding) + layers = [conv, _bn(out_planes), nn.ReLU6()] + self.features = nn.SequentialCell(layers) + + def construct(self, x): + output = self.features(x) + return output + + class DepthwiseConv(nn.Cell): """ Depthwise Convolution warpper definition. @@ -64,6 +104,7 @@ class DepthwiseConv(nn.Cell): Examples: >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1) """ + def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): super(DepthwiseConv, self).__init__() self.has_bias = has_bias @@ -91,42 +132,9 @@ class DepthwiseConv(nn.Cell): return output -class ConvBNReLU(nn.Cell): - """ - Convolution/Depthwise fused with Batchnorm and ReLU block definition. - - Args: - in_planes (int): Input channel. - out_planes (int): Output channel. - kernel_size (int): Input kernel size. - stride (int): Stride size for the first convolutional layer. Default: 1. - groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. - - Returns: - Tensor, output tensor. - - Examples: - >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) - """ - def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): - super(ConvBNReLU, self).__init__() - padding = (kernel_size - 1) // 2 - if groups == 1: - conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', - padding=padding) - else: - conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding) - layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()] - self.features = nn.SequentialCell(layers) - - def construct(self, x): - output = self.features(x) - return output - - class InvertedResidual(nn.Cell): """ - Mobilenetv2 residual block definition. + Residual block definition. Args: inp (int): Input channel. @@ -140,7 +148,7 @@ class InvertedResidual(nn.Cell): Examples: >>> ResidualBlock(3, 256, 1, 1) """ - def __init__(self, inp, oup, stride, expand_ratio): + def __init__(self, inp, oup, stride, expand_ratio, last_relu=False): super(InvertedResidual, self).__init__() assert stride in [1, 2] @@ -155,17 +163,21 @@ class InvertedResidual(nn.Cell): ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), # pw-linear nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False), - nn.BatchNorm2d(oup), + _bn(oup), ]) self.conv = nn.SequentialCell(layers) - self.add = TensorAdd() + self.add = P.TensorAdd() self.cast = P.Cast() + self.last_relu = last_relu + self.relu = nn.ReLU6() def construct(self, x): identity = x x = self.conv(x) if self.use_res_connect: - return self.add(identity, x) + x = self.add(identity, x) + if self.last_relu: + x = self.relu(x) return x @@ -214,10 +226,10 @@ class MultiBox(nn.Cell): loc_layers = [] cls_layers = [] for k, out_channel in enumerate(out_channels): - loc_layers += [_conv2d(out_channel, 4 * num_default[k], - kernel_size=3, stride=1, pad_mod='same')] - cls_layers += [_conv2d(out_channel, num_classes * num_default[k], - kernel_size=3, stride=1, pad_mod='same')] + loc_layers += [_last_conv2d(out_channel, 4 * num_default[k], + kernel_size=3, stride=1, pad_mod='same', pad=0)] + cls_layers += [_last_conv2d(out_channel, num_classes * num_default[k], + kernel_size=3, stride=1, pad_mod='same', pad=0)] self.multi_loc_layers = nn.layer.CellList(loc_layers) self.multi_cls_layers = nn.layer.CellList(cls_layers) @@ -258,13 +270,14 @@ class SSD300(nn.Cell): strides = config.EXTRAS_STRIDES residual_list = [] for i in range(2, len(in_channels)): - residual = InvertedResidual(in_channels[i], out_channels[i], stride=strides[i], expand_ratio=ratios[i]) + residual = InvertedResidual(in_channels[i], out_channels[i], stride=strides[i], + expand_ratio=ratios[i], last_relu=True) residual_list.append(residual) self.multi_residual = nn.layer.CellList(residual_list) self.multi_box = MultiBox(config) self.is_training = is_training if not is_training: - self.softmax = P.Softmax() + self.activation = P.Sigmoid() def construct(self, x): layer_out_13, output = self.backbone(x) @@ -275,77 +288,42 @@ class SSD300(nn.Cell): multi_feature += (feature,) pred_loc, pred_label = self.multi_box(multi_feature) if not self.is_training: - pred_label = self.softmax(pred_label) + pred_label = self.activation(pred_label) return pred_loc, pred_label -class LocalizationLoss(nn.Cell): +class SigmoidFocalClassificationLoss(nn.Cell): """" - Computes the localization loss with SmoothL1Loss. - - Returns: - Tensor, box regression loss. - """ - def __init__(self): - super(LocalizationLoss, self).__init__() - self.reduce_sum = P.ReduceSum() - self.reduce_mean = P.ReduceMean() - self.loss = nn.SmoothL1Loss() - self.expand_dims = P.ExpandDims() - self.less = P.Less() - - def construct(self, pred_loc, gt_loc, gt_label, num_matched_boxes): - mask = F.cast(self.less(0, gt_label), mstype.float32) - mask = self.expand_dims(mask, -1) - smooth_l1 = self.loss(gt_loc, pred_loc) * mask - box_loss = self.reduce_sum(smooth_l1, 1) - return self.reduce_mean(box_loss / F.cast(num_matched_boxes, mstype.float32), (0, 1)) - - -class ClassificationLoss(nn.Cell): - """" - Computes the classification loss with hard example mining. + Sigmoid focal-loss for classification. Args: - config (Class): The default config of SSD. + gamma (float): Hyper-parameter to balance the easy and hard examples. Default: 2.0 + alpha (float): Hyper-parameter to balance the positive and negative example. Default: 0.25 Returns: - Tensor, classification loss. + Tensor, the focal loss. """ - def __init__(self, config): - super(ClassificationLoss, self).__init__() - self.num_classes = config.NUM_CLASSES - self.num_boxes = config.NUM_SSD_BOXES - self.neg_pre_positive = config.NEG_PRE_POSITIVE - self.minimum = P.Minimum() - self.less = P.Less() - self.sort = P.TopK() - self.tile = P.Tile() - self.reduce_sum = P.ReduceSum() - self.reduce_mean = P.ReduceMean() - self.expand_dims = P.ExpandDims() - self.sort_descend = P.TopK(True) - self.cross_entropy = nn.SoftmaxCrossEntropyWithLogits(sparse=True) - - def construct(self, pred_label, gt_label, num_matched_boxes): - gt_label = F.cast(gt_label, mstype.int32) - mask = F.cast(self.less(0, gt_label), mstype.float32) - gt_label_shape = F.shape(gt_label) - pred_label = F.reshape(pred_label, (-1, self.num_classes)) - gt_label = F.reshape(gt_label, (-1,)) - cross_entropy = self.cross_entropy(pred_label, gt_label) - cross_entropy = F.reshape(cross_entropy, gt_label_shape) - - # Hard example mining - num_matched_boxes = F.reshape(num_matched_boxes, (-1,)) - neg_masked_cross_entropy = F.cast(cross_entropy * (1- mask), mstype.float16) - _, loss_idx = self.sort_descend(neg_masked_cross_entropy, self.num_boxes) - _, relative_position = self.sort(F.cast(loss_idx, mstype.float16), self.num_boxes) - num_neg_boxes = self.minimum(num_matched_boxes * self.neg_pre_positive, self.num_boxes) - tile_num_neg_boxes = self.tile(self.expand_dims(num_neg_boxes, -1), (1, self.num_boxes)) - top_k_neg_mask = F.cast(self.less(relative_position, tile_num_neg_boxes), mstype.float32) - class_loss = self.reduce_sum(cross_entropy * (mask + top_k_neg_mask), 1) - return self.reduce_mean(class_loss / F.cast(num_matched_boxes, mstype.float32), 0) + def __init__(self, gamma=2.0, alpha=0.75): + super(SigmoidFocalClassificationLoss, self).__init__() + self.sigmiod_cross_entropy = P.SigmoidCrossEntropyWithLogits() + self.sigmoid = P.Sigmoid() + self.pow = P.Pow() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.gamma = gamma + self.alpha = alpha + + def construct(self, logits, label): + label = self.onehot(label, F.shape(logits)[-1], self.on_value, self.off_value) + sigmiod_cross_entropy = self.sigmiod_cross_entropy(logits, label) + sigmoid = self.sigmoid(logits) + label = F.cast(label, mstype.float32) + p_t = label * sigmoid + (1 - label) * (1 - sigmoid) + modulating_factor = self.pow(1 - p_t, self.gamma) + alpha_weight_factor = label * self.alpha + (1 - label) * (1 - self.alpha) + focal_loss = modulating_factor * alpha_weight_factor * sigmiod_cross_entropy + return focal_loss class SSDWithLossCell(nn.Cell): @@ -362,14 +340,29 @@ class SSDWithLossCell(nn.Cell): def __init__(self, network, config): super(SSDWithLossCell, self).__init__() self.network = network - self.class_loss = ClassificationLoss(config) - self.box_loss = LocalizationLoss() + self.less = P.Less() + self.tile = P.Tile() + self.reduce_sum = P.ReduceSum() + self.reduce_mean = P.ReduceMean() + self.expand_dims = P.ExpandDims() + self.class_loss = SigmoidFocalClassificationLoss() + self.loc_loss = nn.SmoothL1Loss() def construct(self, x, gt_loc, gt_label, num_matched_boxes): pred_loc, pred_label = self.network(x) - loss_cls = self.class_loss(pred_label, gt_label, num_matched_boxes) - loss_loc = self.box_loss(pred_loc, gt_loc, gt_label, num_matched_boxes) - return loss_cls + loss_loc + mask = F.cast(self.less(0, gt_label), mstype.float32) + num_matched_boxes = self.reduce_sum(F.cast(num_matched_boxes, mstype.float32)) + + # Localization Loss + mask_loc = self.tile(self.expand_dims(mask, -1), (1, 1, 4)) + smooth_l1 = self.loc_loss(pred_loc, gt_loc) * mask_loc + loss_loc = self.reduce_sum(self.reduce_mean(smooth_l1, -1), -1) + + # Classification Loss + loss_cls = self.class_loss(pred_label, gt_label) + loss_cls = self.reduce_sum(loss_cls, (1, 2)) + + return self.reduce_sum((loss_cls + loss_loc) / num_matched_boxes) class TrainingWrapper(nn.Cell): @@ -415,7 +408,6 @@ class TrainingWrapper(nn.Cell): return F.depend(loss, self.optimizer(grads)) - class SSDWithMobileNetV2(nn.Cell): """ MobileNetV2 architecture for SSD backbone.