diff --git a/model_zoo/official/cv/ssd/eval.py b/model_zoo/official/cv/ssd/eval.py index 9054bf6f244b397963edc48f3d9aa09b7ea005a7..1ce66e38f182a1ae6cc63a1441eda0f60fdebe45 100644 --- a/model_zoo/official/cv/ssd/eval.py +++ b/model_zoo/official/cv/ssd/eval.py @@ -22,7 +22,7 @@ import numpy as np from mindspore import context, Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.ssd import SSD300, ssd_mobilenet_v2 -from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image +from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord from src.config import config from src.coco_eval import metrics @@ -88,6 +88,13 @@ if __name__ == '__main__': print("Create Mindrecord Done, at {}".format(mindrecord_dir)) else: print("coco_root not exits.") + elif args_opt.dataset == "voc": + if os.path.isdir(config.voc_dir) and os.path.isdir(config.voc_root): + print("Create Mindrecord.") + voc_data_to_mindrecord(mindrecord_dir, False, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("voc_root or voc_dir not exits.") else: if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): print("Create Mindrecord.") diff --git a/model_zoo/official/cv/ssd/src/config.py b/model_zoo/official/cv/ssd/src/config.py index 683b8de31fd33c8a0ee3262106318c1b48d0a8b4..ff0cd219634bfe089856c4c857fd249d832bdc8b 100644 --- a/model_zoo/official/cv/ssd/src/config.py +++ b/model_zoo/official/cv/ssd/src/config.py @@ -71,8 +71,11 @@ config = ed({ 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'), "num_classes": 81, - - # if coco used, `image_dir` and `anno_path` are useless. + # The annotation.json position of voc validation dataset. + "voc_root": "", + # voc original dataset. + "voc_dir": "", + # if coco or voc used, `image_dir` and `anno_path` are useless. "image_dir": "", "anno_path": "", }) diff --git a/model_zoo/official/cv/ssd/src/dataset.py b/model_zoo/official/cv/ssd/src/dataset.py index 98097f474b277932c2fffd1ba44ced75737d1668..d842aef709fd549ae2482f0b01a07bcb6756e1b5 100644 --- a/model_zoo/official/cv/ssd/src/dataset.py +++ b/model_zoo/official/cv/ssd/src/dataset.py @@ -18,6 +18,8 @@ from __future__ import division import os +import json +import xml.etree.ElementTree as et import cv2 import numpy as np @@ -32,6 +34,13 @@ def _rand(a=0., b=1.): """Generate random.""" return np.random.rand() * (b - a) + a +def get_imageId_from_fileName(filename): + """Get imageID from fileName""" + try: + filename = os.path.splitext(filename)[0] + return int(filename) + except: + raise NotImplementedError('Filename %s is supposed to be an integer.'%(filename)) def random_sample_crop(image, boxes): """Random Crop the image and boxes""" @@ -144,6 +153,96 @@ def preprocess_fn(img_id, image, box, is_training): return _data_aug(image, box, is_training, image_size=config.img_shape) +def create_voc_label(is_training): + """Get image path and annotation from VOC.""" + voc_dir = config.voc_dir + cls_map = {name: i for i, name in enumerate(config.coco_classes)} + sub_dir = 'train' if is_training else 'eval' + #sub_dir = 'train' + voc_dir = os.path.join(voc_dir, sub_dir) + if not os.path.isdir(voc_dir): + raise ValueError(f'Cannot find {sub_dir} dataset path.') + + image_dir = anno_dir = voc_dir + if os.path.isdir(os.path.join(voc_dir, 'Images')): + image_dir = os.path.join(voc_dir, 'Images') + if os.path.isdir(os.path.join(voc_dir, 'Annotations')): + anno_dir = os.path.join(voc_dir, 'Annotations') + + if not is_training: + data_dir = config.voc_root + json_file = os.path.join(data_dir, config.instances_set.format(sub_dir)) + file_dir = os.path.split(json_file)[0] + if not os.path.isdir(file_dir): + os.makedirs(file_dir) + json_dict = {"images": [], "type": "instances", "annotations": [], + "categories": []} + bnd_id = 1 + + image_files_dict = {} + image_anno_dict = {} + images = [] + for anno_file in os.listdir(anno_dir): + print(anno_file) + if not anno_file.endswith('xml'): + continue + tree = et.parse(os.path.join(anno_dir, anno_file)) + root_node = tree.getroot() + file_name = root_node.find('filename').text + img_id = get_imageId_from_fileName(file_name) + image_path = os.path.join(image_dir, file_name) + print(image_path) + if not os.path.isfile(image_path): + print(f'Cannot find image {file_name} according to annotations.') + continue + + labels = [] + for obj in root_node.iter('object'): + cls_name = obj.find('name').text + if cls_name not in cls_map: + print(f'Label "{cls_name}" not in "{config.coco_classes}"') + continue + bnd_box = obj.find('bndbox') + x_min = int(bnd_box.find('xmin').text) - 1 + y_min = int(bnd_box.find('ymin').text) - 1 + x_max = int(bnd_box.find('xmax').text) - 1 + y_max = int(bnd_box.find('ymax').text) - 1 + labels.append([y_min, x_min, y_max, x_max, cls_map[cls_name]]) + + if not is_training: + o_width = abs(x_max - x_min) + o_height = abs(y_max - y_min) + ann = {'area': o_width * o_height, 'iscrowd': 0, 'image_id': \ + img_id, 'bbox': [x_min, y_min, o_width, o_height], \ + 'category_id': cls_map[cls_name], 'id': bnd_id, \ + 'ignore': 0, \ + 'segmentation': []} + json_dict['annotations'].append(ann) + bnd_id = bnd_id + 1 + + if labels: + images.append(img_id) + image_files_dict[img_id] = image_path + image_anno_dict[img_id] = np.array(labels) + + if not is_training: + size = root_node.find("size") + width = int(size.find('width').text) + height = int(size.find('height').text) + image = {'file_name': file_name, 'height': height, 'width': width, + 'id': img_id} + json_dict['images'].append(image) + + for cls_name, cid in cls_map.items(): + cat = {'supercategory': 'none', 'id': cid, 'name': cls_name} + json_dict['categories'].append(cat) + json_fp = open(json_file, 'w') + json_str = json.dumps(json_dict) + json_fp.write(json_str) + json_fp.close() + + return images, image_files_dict, image_anno_dict + def create_coco_label(is_training): """Get image path and annotation from COCO.""" from pycocotools.coco import COCO @@ -233,6 +332,30 @@ def filter_valid_data(image_dir, anno_path): return images, image_path_dict, image_anno_dict +def voc_data_to_mindrecord(mindrecord_dir, is_training, prefix="ssd.mindrecord", file_num=8): + """Create MindRecord file by image_dir and anno_path.""" + mindrecord_path = os.path.join(mindrecord_dir, prefix) + writer = FileWriter(mindrecord_path, file_num) + images, image_path_dict, image_anno_dict = create_voc_label(is_training) + + ssd_json = { + "img_id": {"type": "int32", "shape": [1]}, + "image": {"type": "bytes"}, + "annotation": {"type": "int32", "shape": [-1, 5]}, + } + writer.add_schema(ssd_json, "ssd_json") + + for img_id in images: + image_path = image_path_dict[img_id] + with open(image_path, 'rb') as f: + img = f.read() + annos = np.array(image_anno_dict[img_id], dtype=np.int32) + img_id = np.array([img_id], dtype=np.int32) + row = {"img_id": img_id, "image": img, "annotation": annos} + writer.write_raw_data([row]) + writer.commit() + + def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.mindrecord", file_num=8): """Create MindRecord file.""" mindrecord_dir = config.mindrecord_dir diff --git a/model_zoo/official/cv/ssd/src/init_params.py b/model_zoo/official/cv/ssd/src/init_params.py index 335030d2e9fd03767a095410d6099d9fe77e3029..b71ee2c4dc5a47bf0b680347b07e6fb888673280 100644 --- a/model_zoo/official/cv/ssd/src/init_params.py +++ b/model_zoo/official/cv/ssd/src/init_params.py @@ -40,3 +40,9 @@ def load_backbone_params(network, param_dict): param_name = '.'.join(['features', str(int(name_split[1]) + 14)] + name_split[2:]) if param_name in param_dict: param.set_parameter_data(param_dict[param_name].data) + +def filter_checkpoint_parameter(param_dict): + """remove useless parameters""" + for key in list(param_dict.keys()): + if 'multi_loc_layers' in key or 'multi_cls_layers' in key: + del param_dict[key] diff --git a/model_zoo/official/cv/ssd/train.py b/model_zoo/official/cv/ssd/train.py index a0969a06fddddef488d8d383e6bf65decd508a2c..a1739d0e4ed3d034c804eccd76c715c68341f9ff 100644 --- a/model_zoo/official/cv/ssd/train.py +++ b/model_zoo/official/cv/ssd/train.py @@ -25,9 +25,9 @@ from mindspore.train import Model, ParallelMode from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2 from src.config import config -from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image +from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord from src.lr_schedule import get_lr -from src.init_params import init_net_param +from src.init_params import init_net_param, filter_checkpoint_parameter def main(): @@ -79,6 +79,13 @@ def main(): print("Create Mindrecord Done, at {}".format(mindrecord_dir)) else: print("coco_root not exits.") + elif args_opt.dataset == "voc": + if os.path.isdir(config.voc_dir): + print("Create Mindrecord.") + voc_data_to_mindrecord(mindrecord_dir, True, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("voc_dir not exits.") else: if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): print("Create Mindrecord.") @@ -110,6 +117,7 @@ def main(): if args_opt.pre_trained_epoch_size <= 0: raise KeyError("pre_trained_epoch_size must be greater than 0.") param_dict = load_checkpoint(args_opt.pre_trained) + filter_checkpoint_parameter(param_dict) load_param_into_net(net, param_dict) lr = Tensor(get_lr(global_step=config.global_step,