提交 3078fd41 编写于 作者: H hanjun996

add voc_dataset to mindrecord

上级 8b3e6496
......@@ -22,7 +22,7 @@ import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.ssd import SSD300, ssd_mobilenet_v2
from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image
from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord
from src.config import config
from src.coco_eval import metrics
......@@ -88,6 +88,13 @@ if __name__ == '__main__':
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("coco_root not exits.")
elif args_opt.dataset == "voc":
if os.path.isdir(config.voc_dir) and os.path.isdir(config.voc_root):
print("Create Mindrecord.")
voc_data_to_mindrecord(mindrecord_dir, False, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("voc_root or voc_dir not exits.")
else:
if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
print("Create Mindrecord.")
......
......@@ -71,8 +71,11 @@ config = ed({
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush'),
"num_classes": 81,
# if coco used, `image_dir` and `anno_path` are useless.
# The annotation.json position of voc validation dataset.
"voc_root": "",
# voc original dataset.
"voc_dir": "",
# if coco or voc used, `image_dir` and `anno_path` are useless.
"image_dir": "",
"anno_path": "",
})
......@@ -18,6 +18,8 @@
from __future__ import division
import os
import json
import xml.etree.ElementTree as et
import cv2
import numpy as np
......@@ -32,6 +34,13 @@ def _rand(a=0., b=1.):
"""Generate random."""
return np.random.rand() * (b - a) + a
def get_imageId_from_fileName(filename):
"""Get imageID from fileName"""
try:
filename = os.path.splitext(filename)[0]
return int(filename)
except:
raise NotImplementedError('Filename %s is supposed to be an integer.'%(filename))
def random_sample_crop(image, boxes):
"""Random Crop the image and boxes"""
......@@ -144,6 +153,96 @@ def preprocess_fn(img_id, image, box, is_training):
return _data_aug(image, box, is_training, image_size=config.img_shape)
def create_voc_label(is_training):
"""Get image path and annotation from VOC."""
voc_dir = config.voc_dir
cls_map = {name: i for i, name in enumerate(config.coco_classes)}
sub_dir = 'train' if is_training else 'eval'
#sub_dir = 'train'
voc_dir = os.path.join(voc_dir, sub_dir)
if not os.path.isdir(voc_dir):
raise ValueError(f'Cannot find {sub_dir} dataset path.')
image_dir = anno_dir = voc_dir
if os.path.isdir(os.path.join(voc_dir, 'Images')):
image_dir = os.path.join(voc_dir, 'Images')
if os.path.isdir(os.path.join(voc_dir, 'Annotations')):
anno_dir = os.path.join(voc_dir, 'Annotations')
if not is_training:
data_dir = config.voc_root
json_file = os.path.join(data_dir, config.instances_set.format(sub_dir))
file_dir = os.path.split(json_file)[0]
if not os.path.isdir(file_dir):
os.makedirs(file_dir)
json_dict = {"images": [], "type": "instances", "annotations": [],
"categories": []}
bnd_id = 1
image_files_dict = {}
image_anno_dict = {}
images = []
for anno_file in os.listdir(anno_dir):
print(anno_file)
if not anno_file.endswith('xml'):
continue
tree = et.parse(os.path.join(anno_dir, anno_file))
root_node = tree.getroot()
file_name = root_node.find('filename').text
img_id = get_imageId_from_fileName(file_name)
image_path = os.path.join(image_dir, file_name)
print(image_path)
if not os.path.isfile(image_path):
print(f'Cannot find image {file_name} according to annotations.')
continue
labels = []
for obj in root_node.iter('object'):
cls_name = obj.find('name').text
if cls_name not in cls_map:
print(f'Label "{cls_name}" not in "{config.coco_classes}"')
continue
bnd_box = obj.find('bndbox')
x_min = int(bnd_box.find('xmin').text) - 1
y_min = int(bnd_box.find('ymin').text) - 1
x_max = int(bnd_box.find('xmax').text) - 1
y_max = int(bnd_box.find('ymax').text) - 1
labels.append([y_min, x_min, y_max, x_max, cls_map[cls_name]])
if not is_training:
o_width = abs(x_max - x_min)
o_height = abs(y_max - y_min)
ann = {'area': o_width * o_height, 'iscrowd': 0, 'image_id': \
img_id, 'bbox': [x_min, y_min, o_width, o_height], \
'category_id': cls_map[cls_name], 'id': bnd_id, \
'ignore': 0, \
'segmentation': []}
json_dict['annotations'].append(ann)
bnd_id = bnd_id + 1
if labels:
images.append(img_id)
image_files_dict[img_id] = image_path
image_anno_dict[img_id] = np.array(labels)
if not is_training:
size = root_node.find("size")
width = int(size.find('width').text)
height = int(size.find('height').text)
image = {'file_name': file_name, 'height': height, 'width': width,
'id': img_id}
json_dict['images'].append(image)
for cls_name, cid in cls_map.items():
cat = {'supercategory': 'none', 'id': cid, 'name': cls_name}
json_dict['categories'].append(cat)
json_fp = open(json_file, 'w')
json_str = json.dumps(json_dict)
json_fp.write(json_str)
json_fp.close()
return images, image_files_dict, image_anno_dict
def create_coco_label(is_training):
"""Get image path and annotation from COCO."""
from pycocotools.coco import COCO
......@@ -233,6 +332,30 @@ def filter_valid_data(image_dir, anno_path):
return images, image_path_dict, image_anno_dict
def voc_data_to_mindrecord(mindrecord_dir, is_training, prefix="ssd.mindrecord", file_num=8):
"""Create MindRecord file by image_dir and anno_path."""
mindrecord_path = os.path.join(mindrecord_dir, prefix)
writer = FileWriter(mindrecord_path, file_num)
images, image_path_dict, image_anno_dict = create_voc_label(is_training)
ssd_json = {
"img_id": {"type": "int32", "shape": [1]},
"image": {"type": "bytes"},
"annotation": {"type": "int32", "shape": [-1, 5]},
}
writer.add_schema(ssd_json, "ssd_json")
for img_id in images:
image_path = image_path_dict[img_id]
with open(image_path, 'rb') as f:
img = f.read()
annos = np.array(image_anno_dict[img_id], dtype=np.int32)
img_id = np.array([img_id], dtype=np.int32)
row = {"img_id": img_id, "image": img, "annotation": annos}
writer.write_raw_data([row])
writer.commit()
def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.mindrecord", file_num=8):
"""Create MindRecord file."""
mindrecord_dir = config.mindrecord_dir
......
......@@ -40,3 +40,9 @@ def load_backbone_params(network, param_dict):
param_name = '.'.join(['features', str(int(name_split[1]) + 14)] + name_split[2:])
if param_name in param_dict:
param.set_parameter_data(param_dict[param_name].data)
def filter_checkpoint_parameter(param_dict):
"""remove useless parameters"""
for key in list(param_dict.keys()):
if 'multi_loc_layers' in key or 'multi_cls_layers' in key:
del param_dict[key]
......@@ -25,9 +25,9 @@ from mindspore.train import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2
from src.config import config
from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image
from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord
from src.lr_schedule import get_lr
from src.init_params import init_net_param
from src.init_params import init_net_param, filter_checkpoint_parameter
def main():
......@@ -79,6 +79,13 @@ def main():
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("coco_root not exits.")
elif args_opt.dataset == "voc":
if os.path.isdir(config.voc_dir):
print("Create Mindrecord.")
voc_data_to_mindrecord(mindrecord_dir, True, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("voc_dir not exits.")
else:
if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
print("Create Mindrecord.")
......@@ -110,6 +117,7 @@ def main():
if args_opt.pre_trained_epoch_size <= 0:
raise KeyError("pre_trained_epoch_size must be greater than 0.")
param_dict = load_checkpoint(args_opt.pre_trained)
filter_checkpoint_parameter(param_dict)
load_param_into_net(net, param_dict)
lr = Tensor(get_lr(global_step=config.global_step,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册