Commit 7420b0f4 authored by: M mamingjie-China

update

Parent a237df88
......@@ -200,10 +200,8 @@ def main():
logging.error("The value of split is not correct.")
if not osp.exists(save_dir):
logging.error("The path of saved split information doesn't exist.")
pdx.tools.split.dataset_split(dataset_dir, dataset_form, val_value,
test_value, save_dir)
if __name__ == "__main__":
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import random
import json
from .utils import MyEncoder
def split_coco_dataset(dataset_dir, val_percent, test_percent, save_dir):
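    """Split a COCO-format dataset into train/val/test subsets.

    Args:
        dataset_dir: directory containing annotations.json
        val_percent: fraction of images assigned to the validation set
        test_percent: fraction of images assigned to the test set
        save_dir: directory where train.json / val.json / test.json are written
    """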
if not osp.exists(osp.join(dataset_dir, "annotations.json")):
raise ValueError("\'annotations.json\' is not found in {}!".format(
dataset_dir))
    try:
        from pycocotools.coco import COCO
    except ImportError:
        print(
            "pycocotools is not installed, follow this doc to install pycocotools: https://paddlex.readthedocs.io/zh_CN/develop/install.html#pycocotools"
        )
        return
annotation_file = osp.join(dataset_dir, "annotations.json")
coco = COCO(annotation_file)
img_ids = coco.getImgIds()
cat_ids = coco.getCatIds()
anno_ids = coco.getAnnIds()
val_num = int(len(img_ids) * val_percent)
test_num = int(len(img_ids) * test_percent)
train_num = len(img_ids) - val_num - test_num
random.shuffle(img_ids)
train_files_ids = img_ids[:train_num]
val_files_ids = img_ids[train_num:train_num + val_num]
test_files_ids = img_ids[train_num + val_num:]
for img_id_list in [train_files_ids, val_files_ids, test_files_ids]:
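        # Gather the non-crowd annotations (iscrowd=0) for the images in this subset.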
img_anno_ids = coco.getAnnIds(imgIds=img_id_list, iscrowd=0)
imgs = coco.loadImgs(img_id_list)
instances = coco.loadAnns(img_anno_ids)
categories = coco.loadCats(cat_ids)
img_dict = {
"annotations": instances,
"images": imgs,
"categories": categories
}
        if img_id_list == train_files_ids:
            with open(osp.join(save_dir, 'train.json'), 'w') as json_file:
                json.dump(img_dict, json_file, cls=MyEncoder)
        elif img_id_list == val_files_ids:
            with open(osp.join(save_dir, 'val.json'), 'w') as json_file:
                json.dump(img_dict, json_file, cls=MyEncoder)
        elif img_id_list == test_files_ids and len(test_files_ids):
            with open(osp.join(save_dir, 'test.json'), 'w') as json_file:
                json.dump(img_dict, json_file, cls=MyEncoder)
return train_num, val_num, test_num
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import random
from .utils import list_files, is_pic
def split_imagenet_dataset(dataset_dir, val_percent, test_percent, save_dir):
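    """Split an ImageNet-style dataset (one subdirectory per class) into
    train/val/test lists, writing train_list.txt, val_list.txt, an optional
    test_list.txt and labels.txt to save_dir.
    """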
all_files = list_files(dataset_dir)
label_list = list()
train_image_anno_list = list()
val_image_anno_list = list()
test_image_anno_list = list()
for file in all_files:
if not is_pic(file):
continue
label, image_name = osp.split(file)
if label not in label_list:
label_list.append(label)
label_list = sorted(label_list)
for i in range(len(label_list)):
image_list = list_files(osp.join(dataset_dir, label_list[i]))
image_anno_list = list()
for img in image_list:
image_anno_list.append([osp.join(label_list[i], img), i])
random.shuffle(image_anno_list)
image_num = len(image_anno_list)
val_num = int(image_num * val_percent)
test_num = int(image_num * test_percent)
train_num = image_num - val_num - test_num
train_image_anno_list += image_anno_list[:train_num]
val_image_anno_list += image_anno_list[train_num:train_num + val_num]
test_image_anno_list += image_anno_list[train_num + val_num:]
with open(
osp.join(save_dir, 'train_list.txt'), mode='w',
encoding='utf-8') as f:
for x in train_image_anno_list:
file, label = x
f.write('{} {}\n'.format(file, label))
with open(
osp.join(save_dir, 'val_list.txt'), mode='w',
encoding='utf-8') as f:
for x in val_image_anno_list:
file, label = x
f.write('{} {}\n'.format(file, label))
if len(test_image_anno_list):
with open(
osp.join(save_dir, 'test_list.txt'), mode='w',
encoding='utf-8') as f:
for x in test_image_anno_list:
file, label = x
f.write('{} {}\n'.format(file, label))
with open(
osp.join(save_dir, 'labels.txt'), mode='w', encoding='utf-8') as f:
for l in sorted(label_list):
f.write('{}\n'.format(l))
return len(train_image_anno_list), len(val_image_anno_list), len(
test_image_anno_list)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import random
from .utils import list_files, is_pic, replace_ext, read_seg_ann
def split_seg_dataset(dataset_dir, val_percent, test_percent, save_dir):
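    """Split a segmentation dataset (images in JPEGImages, PNG masks in
    Annotations) into train/val/test lists written to save_dir.
    """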
if not osp.exists(osp.join(dataset_dir, "JPEGImages")):
raise ValueError("\'JPEGImages\' is not found in {}!".format(
dataset_dir))
if not osp.exists(osp.join(dataset_dir, "Annotations")):
raise ValueError("\'Annotations\' is not found in {}!".format(
dataset_dir))
all_image_files = list_files(osp.join(dataset_dir, "JPEGImages"))
image_anno_list = list()
label_list = list()
for image_file in all_image_files:
if not is_pic(image_file):
continue
anno_name = replace_ext(image_file, "png")
if osp.exists(osp.join(dataset_dir, "Annotations", anno_name)):
image_anno_list.append([image_file, anno_name])
else:
anno_name = replace_ext(image_file, "PNG")
if osp.exists(osp.join(dataset_dir, "Annotations", anno_name)):
image_anno_list.append([image_file, anno_name])
if not osp.exists(osp.join(dataset_dir, "labels.txt")):
        for image_anno in image_anno_list:
            labels = read_seg_ann(
                osp.join(dataset_dir, "Annotations", image_anno[1]))
for i in labels:
if i not in label_list:
label_list.append(i)
        # If the largest label value exceeds the number of labels found,
        # fill in the missing intermediate labels.
        if len(label_list) != max(label_list) + 1:
            label_list = [i for i in range(max(label_list) + 1)]
random.shuffle(image_anno_list)
image_num = len(image_anno_list)
val_num = int(image_num * val_percent)
test_num = int(image_num * test_percent)
train_num = image_num - val_num - test_num
train_image_anno_list = image_anno_list[:train_num]
val_image_anno_list = image_anno_list[train_num:train_num + val_num]
test_image_anno_list = image_anno_list[train_num + val_num:]
with open(
osp.join(save_dir, 'train_list.txt'), mode='w',
encoding='utf-8') as f:
for x in train_image_anno_list:
file = osp.join("JPEGImages", x[0])
label = osp.join("Annotations", x[1])
f.write('{} {}\n'.format(file, label))
with open(
osp.join(save_dir, 'val_list.txt'), mode='w',
encoding='utf-8') as f:
for x in val_image_anno_list:
file = osp.join("JPEGImages", x[0])
label = osp.join("Annotations", x[1])
f.write('{} {}\n'.format(file, label))
if len(test_image_anno_list):
with open(
osp.join(save_dir, 'test_list.txt'), mode='w',
encoding='utf-8') as f:
for x in test_image_anno_list:
file = osp.join("JPEGImages", x[0])
label = osp.join("Annotations", x[1])
f.write('{} {}\n'.format(file, label))
if len(label_list):
with open(
osp.join(save_dir, 'labels.txt'), mode='w',
encoding='utf-8') as f:
for l in sorted(label_list):
f.write('{}\n'.format(l))
return train_num, val_num, test_num
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
from PIL import Image
import numpy as np
import json
class MyEncoder(json.JSONEncoder):
    # Convert NumPy types to native Python types so they can be stored in JSON.
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
else:
return super(MyEncoder, self).default(obj)
def list_files(dirname):
""" 列出目录下所有文件(包括所属的一级子目录下文件)
Args:
dirname: 目录路径
"""
def filter_file(f):
if f.startswith('.'):
return True
return False
all_files = list()
dirs = list()
for f in os.listdir(dirname):
if filter_file(f):
continue
if osp.isdir(osp.join(dirname, f)):
dirs.append(f)
else:
all_files.append(f)
for d in dirs:
for f in os.listdir(osp.join(dirname, d)):
if filter_file(f):
continue
if osp.isdir(osp.join(dirname, d, f)):
continue
all_files.append(osp.join(d, f))
return all_files
def is_pic(filename):
""" 判断文件是否为图片格式
Args:
filename: 文件路径
"""
suffixes = {'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png'}
suffix = filename.strip().split('.')[-1]
if suffix not in suffixes:
return False
return True
def replace_ext(filename, new_ext):
""" 替换文件后缀
Args:
filename: 文件路径
new_ext: 需要替换的新的后缀
"""
items = filename.split(".")
items[-1] = new_ext
new_filename = ".".join(items)
return new_filename
def read_seg_ann(pngfile):
""" 解析语义分割的标注png图片
Args:
pngfile: 包含标注信息的png图片路径
"""
grt = np.asarray(Image.open(pngfile))
labels = list(np.unique(grt))
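    # 255 conventionally marks "ignore" pixels in segmentation masks, so it is
    # excluded from the label set.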
if 255 in labels:
labels.remove(255)
return labels
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import random
import xml.etree.ElementTree as ET
from .utils import list_files, is_pic, replace_ext
def split_voc_dataset(dataset_dir, val_percent, test_percent, save_dir):
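    """Split a PascalVOC-format dataset (images in JPEGImages, XML annotations
    in Annotations) into train/val/test lists, collecting the object class
    names into labels.txt.
    """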
if not osp.exists(osp.join(dataset_dir, "JPEGImages")):
raise ValueError("\'JPEGImages\' is not found in {}!".format(
dataset_dir))
if not osp.exists(osp.join(dataset_dir, "Annotations")):
raise ValueError("\'Annotations\' is not found in {}!".format(
dataset_dir))
all_image_files = list_files(osp.join(dataset_dir, "JPEGImages"))
image_anno_list = list()
label_list = list()
for image_file in all_image_files:
if not is_pic(image_file):
continue
anno_name = replace_ext(image_file, "xml")
if osp.exists(osp.join(dataset_dir, "Annotations", anno_name)):
image_anno_list.append([image_file, anno_name])
            try:
                tree = ET.parse(
                    osp.join(dataset_dir, "Annotations", anno_name))
            except ET.ParseError:
                raise Exception(
                    "The file {} is not a well-formed XML file. Please check "
                    "the annotation file.".format(
                        osp.join(dataset_dir, "Annotations", anno_name)))
objs = tree.findall("object")
for i, obj in enumerate(objs):
cname = obj.find('name').text
if not cname in label_list:
label_list.append(cname)
random.shuffle(image_anno_list)
image_num = len(image_anno_list)
val_num = int(image_num * val_percent)
test_num = int(image_num * test_percent)
train_num = image_num - val_num - test_num
train_image_anno_list = image_anno_list[:train_num]
val_image_anno_list = image_anno_list[train_num:train_num + val_num]
test_image_anno_list = image_anno_list[train_num + val_num:]
with open(
osp.join(save_dir, 'train_list.txt'), mode='w',
encoding='utf-8') as f:
for x in train_image_anno_list:
file = osp.join("JPEGImages", x[0])
label = osp.join("Annotations", x[1])
f.write('{} {}\n'.format(file, label))
with open(
osp.join(save_dir, 'val_list.txt'), mode='w',
encoding='utf-8') as f:
for x in val_image_anno_list:
file = osp.join("JPEGImages", x[0])
label = osp.join("Annotations", x[1])
f.write('{} {}\n'.format(file, label))
if len(test_image_anno_list):
with open(
osp.join(save_dir, 'test_list.txt'), mode='w',
encoding='utf-8') as f:
for x in test_image_anno_list:
file = osp.join("JPEGImages", x[0])
label = osp.join("Annotations", x[1])
f.write('{} {}\n'.format(file, label))
with open(
osp.join(save_dir, 'labels.txt'), mode='w', encoding='utf-8') as f:
for l in sorted(label_list):
f.write('{}\n'.format(l))
return train_num, val_num, test_num
......@@ -14,7 +14,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .dataset_split.coco_split import split_coco_dataset
from .dataset_split.voc_split import split_voc_dataset
from .dataset_split.imagenet_split import split_imagenet_dataset
from .dataset_split.seg_split import split_seg_dataset
def dataset_split(dataset_dir, dataset_form, val_value, test_value, save_dir):
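    """Split a dataset into train/val/test subsets according to its format.

    Args:
        dataset_dir: root directory of the dataset
        dataset_form: one of "coco", "voc", "seg" or "imagenet"
        val_value: fraction of samples for the validation set
        test_value: fraction of samples for the test set
        save_dir: directory where the split files are written
    """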
if dataset_form == "coco":
train_num, val_num, test_num = split_coco_dataset(
dataset_dir, val_value, test_value, save_dir)
elif dataset_form == "voc":
train_num, val_num, test_num = split_voc_dataset(
dataset_dir, val_value, test_value, save_dir)
elif dataset_form == "seg":
train_num, val_num, test_num = split_seg_dataset(
dataset_dir, val_value, test_value, save_dir)
elif dataset_form == "imagenet":
train_num, val_num, test_num = split_imagenet_dataset(
dataset_dir, val_value, test_value, save_dir)
print("Dataset Split Done.")
print("Train samples: {}".format(train_num))
print("Eval samples: {}".format(val_num))
print("Test samples: {}".format(test_num))
print("Split file saved in {}".format(save_dir))