diff --git a/paddlex/command.py b/paddlex/command.py
index a08118c0460f02453439f7358cd504f2d5e4c0e4..74079b9b708073c264f1baaacbd4f4fe075bbfe1 100644
--- a/paddlex/command.py
+++ b/paddlex/command.py
@@ -200,10 +200,8 @@ def main():
             logging.error("The value of split is not correct.")
         if not osp.exists(save_dir):
             logging.error("The path of saved split information doesn't exist.")
-        print(11111111111111)
         pdx.tools.split.dataset_split(dataset_dir, dataset_form, val_value,
                                       test_value, save_dir)
-        print(222222222)
 
 
 if __name__ == "__main__":
diff --git a/paddlex/tools/dataset_split/__init__.py b/paddlex/tools/dataset_split/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/paddlex/tools/dataset_split/coco_split.py b/paddlex/tools/dataset_split/coco_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..37356da25c4fb7d5d99f925fbc9a70487b27c43e
--- /dev/null
+++ b/paddlex/tools/dataset_split/coco_split.py
@@ -0,0 +1,69 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import random
+import json
+from .utils import MyEncoder
+
+
+def split_coco_dataset(dataset_dir, val_percent, test_percent, save_dir):
+    if not osp.exists(osp.join(dataset_dir, "annotations.json")):
+        raise ValueError("\'annotations.json\' is not found in {}!".format(
+            dataset_dir))
+    try:
+        from pycocotools.coco import COCO
+    except ImportError:
+        print(
+            "pycocotools is not installed, follow this doc to install pycocotools: https://paddlex.readthedocs.io/zh_CN/develop/install.html#pycocotools"
+        )
+        return
+
+    annotation_file = osp.join(dataset_dir, "annotations.json")
+    coco = COCO(annotation_file)
+    img_ids = coco.getImgIds()
+    cat_ids = coco.getCatIds()
+    anno_ids = coco.getAnnIds()
+
+    val_num = int(len(img_ids) * val_percent)
+    test_num = int(len(img_ids) * test_percent)
+    train_num = len(img_ids) - val_num - test_num
+
+    random.shuffle(img_ids)
+    train_files_ids = img_ids[:train_num]
+    val_files_ids = img_ids[train_num:train_num + val_num]
+    test_files_ids = img_ids[train_num + val_num:]
+
+    for img_id_list in [train_files_ids, val_files_ids, test_files_ids]:
+        img_anno_ids = coco.getAnnIds(imgIds=img_id_list, iscrowd=0)
+        imgs = coco.loadImgs(img_id_list)
+        instances = coco.loadAnns(img_anno_ids)
+        categories = coco.loadCats(cat_ids)
+        img_dict = {
+            "annotations": instances,
+            "images": imgs,
+            "categories": categories
+        }
+
+        if img_id_list == train_files_ids:
+            with open(osp.join(save_dir, 'train.json'), 'w') as json_file:
+                json.dump(img_dict, json_file, cls=MyEncoder)
+        elif img_id_list == val_files_ids:
+            with open(osp.join(save_dir, 'val.json'), 'w') as json_file:
+                json.dump(img_dict, json_file, cls=MyEncoder)
+        elif img_id_list == test_files_ids and len(test_files_ids):
+            with open(osp.join(save_dir, 'test.json'), 'w') as json_file:
+                json.dump(img_dict, json_file, cls=MyEncoder)
+
+    return train_num, val_num, test_num
diff --git a/paddlex/tools/dataset_split/imagenet_split.py b/paddlex/tools/dataset_split/imagenet_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..164a812b43012c533d81f9e9f166e3b2bf066bff
--- /dev/null
+++ b/paddlex/tools/dataset_split/imagenet_split.py
@@ -0,0 +1,74 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import random
+from .utils import list_files, is_pic
+
+
+def split_imagenet_dataset(dataset_dir, val_percent, test_percent, save_dir):
+    all_files = list_files(dataset_dir)
+    label_list = list()
+    train_image_anno_list = list()
+    val_image_anno_list = list()
+    test_image_anno_list = list()
+    for file in all_files:
+        if not is_pic(file):
+            continue
+        label, image_name = osp.split(file)
+        if label not in label_list:
+            label_list.append(label)
+    label_list = sorted(label_list)
+
+    for i in range(len(label_list)):
+        image_list = list_files(osp.join(dataset_dir, label_list[i]))
+        image_anno_list = list()
+        for img in image_list:
+            image_anno_list.append([osp.join(label_list[i], img), i])
+        random.shuffle(image_anno_list)
+        image_num = len(image_anno_list)
+        val_num = int(image_num * val_percent)
+        test_num = int(image_num * test_percent)
+        train_num = image_num - val_num - test_num
+
+        train_image_anno_list += image_anno_list[:train_num]
+        val_image_anno_list += image_anno_list[train_num:train_num + val_num]
+        test_image_anno_list += image_anno_list[train_num + val_num:]
+
+    with open(
+            osp.join(save_dir, 'train_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in train_image_anno_list:
+            file, label = x
+            f.write('{} {}\n'.format(file, label))
+    with open(
+            osp.join(save_dir, 'val_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in val_image_anno_list:
+            file, label = x
+            f.write('{} {}\n'.format(file, label))
+    if len(test_image_anno_list):
+        with open(
+                osp.join(save_dir, 'test_list.txt'), mode='w',
+                encoding='utf-8') as f:
+            for x in test_image_anno_list:
+                file, label = x
+                f.write('{} {}\n'.format(file, label))
+    with open(
+            osp.join(save_dir, 'labels.txt'), mode='w', encoding='utf-8') as f:
+        for l in sorted(label_list):
+            f.write('{}\n'.format(l))
+
+    return len(train_image_anno_list), len(val_image_anno_list), len(
+        test_image_anno_list)
diff --git a/paddlex/tools/dataset_split/seg_split.py b/paddlex/tools/dataset_split/seg_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cfd0d2b37c17989c60ee9564b039e4929a47511
--- /dev/null
+++ b/paddlex/tools/dataset_split/seg_split.py
@@ -0,0 +1,93 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import random
+from .utils import list_files, is_pic, replace_ext, read_seg_ann
+
+
+def split_seg_dataset(dataset_dir, val_percent, test_percent, save_dir):
+    if not osp.exists(osp.join(dataset_dir, "JPEGImages")):
+        raise ValueError("\'JPEGImages\' is not found in {}!".format(
+            dataset_dir))
+    if not osp.exists(osp.join(dataset_dir, "Annotations")):
+        raise ValueError("\'Annotations\' is not found in {}!".format(
+            dataset_dir))
+
+    all_image_files = list_files(osp.join(dataset_dir, "JPEGImages"))
+
+    image_anno_list = list()
+    label_list = list()
+    for image_file in all_image_files:
+        if not is_pic(image_file):
+            continue
+        anno_name = replace_ext(image_file, "png")
+        if osp.exists(osp.join(dataset_dir, "Annotations", anno_name)):
+            image_anno_list.append([image_file, anno_name])
+        else:
+            anno_name = replace_ext(image_file, "PNG")
+            if osp.exists(osp.join(dataset_dir, "Annotations", anno_name)):
+                image_anno_list.append([image_file, anno_name])
+
+    if not osp.exists(osp.join(dataset_dir, "labels.txt")):
+        for image_anno in image_anno_list:
+            labels = read_seg_ann(
+                osp.join(dataset_dir, "Annotations", image_anno[1]))
+            for i in labels:
+                if i not in label_list:
+                    label_list.append(i)
+        # If the max label value exceeds the class count, fill in the missing labels
+        if len(label_list) != max(label_list) + 1:
+            label_list = [i for i in range(max(label_list) + 1)]
+
+    random.shuffle(image_anno_list)
+    image_num = len(image_anno_list)
+    val_num = int(image_num * val_percent)
+    test_num = int(image_num * test_percent)
+    train_num = image_num - val_num - test_num
+
+    train_image_anno_list = image_anno_list[:train_num]
+    val_image_anno_list = image_anno_list[train_num:train_num + val_num]
+    test_image_anno_list = image_anno_list[train_num + val_num:]
+
+    with open(
+            osp.join(save_dir, 'train_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in train_image_anno_list:
+            file = osp.join("JPEGImages", x[0])
+            label = osp.join("Annotations", x[1])
+            f.write('{} {}\n'.format(file, label))
+    with open(
+            osp.join(save_dir, 'val_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in val_image_anno_list:
+            file = osp.join("JPEGImages", x[0])
+            label = osp.join("Annotations", x[1])
+            f.write('{} {}\n'.format(file, label))
+    if len(test_image_anno_list):
+        with open(
+                osp.join(save_dir, 'test_list.txt'), mode='w',
+                encoding='utf-8') as f:
+            for x in test_image_anno_list:
+                file = osp.join("JPEGImages", x[0])
+                label = osp.join("Annotations", x[1])
+                f.write('{} {}\n'.format(file, label))
+    if len(label_list):
+        with open(
+                osp.join(save_dir, 'labels.txt'), mode='w',
+                encoding='utf-8') as f:
+            for l in sorted(label_list):
+                f.write('{}\n'.format(l))
+
+    return train_num, val_num, test_num
diff --git a/paddlex/tools/dataset_split/utils.py b/paddlex/tools/dataset_split/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9106b479000be497464df9d77835268b16c604a4
--- /dev/null
+++ b/paddlex/tools/dataset_split/utils.py
@@ -0,0 +1,102 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+from PIL import Image
+import numpy as np
+import json
+
+
+class MyEncoder(json.JSONEncoder):
+    # Customize how numpy types are serialized to JSON
+    def default(self, obj):
+        if isinstance(obj, np.integer):
+            return int(obj)
+        elif isinstance(obj, np.floating):
+            return float(obj)
+        elif isinstance(obj, np.ndarray):
+            return obj.tolist()
+        else:
+            return super(MyEncoder, self).default(obj)
+
+
+def list_files(dirname):
+    """ List all files in a directory (including files in first-level subdirectories)
+
+    Args:
+        dirname: directory path
+    """
+
+    def filter_file(f):
+        if f.startswith('.'):
+            return True
+        return False
+
+    all_files = list()
+    dirs = list()
+    for f in os.listdir(dirname):
+        if filter_file(f):
+            continue
+        if osp.isdir(osp.join(dirname, f)):
+            dirs.append(f)
+        else:
+            all_files.append(f)
+    for d in dirs:
+        for f in os.listdir(osp.join(dirname, d)):
+            if filter_file(f):
+                continue
+            if osp.isdir(osp.join(dirname, d, f)):
+                continue
+            all_files.append(osp.join(d, f))
+    return all_files
+
+
+def is_pic(filename):
+    """ Check whether a file is an image, judging by its extension
+
+    Args:
+        filename: file path
+    """
+    suffixes = {'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png'}
+    suffix = filename.strip().split('.')[-1]
+    if suffix not in suffixes:
+        return False
+    return True
+
+
+def replace_ext(filename, new_ext):
+    """ Replace a file's extension
+
+    Args:
+        filename: file path
+        new_ext: the new extension to use
+    """
+    items = filename.split(".")
+    items[-1] = new_ext
+    new_filename = ".".join(items)
+    return new_filename
+
+
+def read_seg_ann(pngfile):
+    """ Parse a semantic segmentation annotation png
+
+    Args:
+        pngfile: path of the png file that holds the annotation
+    """
+    grt = np.asarray(Image.open(pngfile))
+    labels = list(np.unique(grt))
+    if 255 in labels:
+        labels.remove(255)
+    return labels
diff --git a/paddlex/tools/dataset_split/voc_split.py b/paddlex/tools/dataset_split/voc_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..1822e2ace2786a059737976608019b21fd7ff961
--- /dev/null
+++ b/paddlex/tools/dataset_split/voc_split.py
@@ -0,0 +1,88 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import random
+import xml.etree.ElementTree as ET
+from .utils import list_files, is_pic, replace_ext
+
+
+def split_voc_dataset(dataset_dir, val_percent, test_percent, save_dir):
+    if not osp.exists(osp.join(dataset_dir, "JPEGImages")):
+        raise ValueError("\'JPEGImages\' is not found in {}!".format(
+            dataset_dir))
+    if not osp.exists(osp.join(dataset_dir, "Annotations")):
+        raise ValueError("\'Annotations\' is not found in {}!".format(
+            dataset_dir))
+
+    all_image_files = list_files(osp.join(dataset_dir, "JPEGImages"))
+
+    image_anno_list = list()
+    label_list = list()
+    for image_file in all_image_files:
+        if not is_pic(image_file):
+            continue
+        anno_name = replace_ext(image_file, "xml")
+        if osp.exists(osp.join(dataset_dir, "Annotations", anno_name)):
+            image_anno_list.append([image_file, anno_name])
+            try:
+                tree = ET.parse(
+                    osp.join(dataset_dir, "Annotations", anno_name))
+            except ET.ParseError:
+                raise Exception("{} is not a well-formed xml file, please check this annotation file".format(
+                    osp.join(dataset_dir, "Annotations", anno_name)))
+            objs = tree.findall("object")
+            for i, obj in enumerate(objs):
+                cname = obj.find('name').text
+                if cname not in label_list:
+                    label_list.append(cname)
+
+    random.shuffle(image_anno_list)
+    image_num = len(image_anno_list)
+    val_num = int(image_num * val_percent)
+    test_num = int(image_num * test_percent)
+    train_num = image_num - val_num - test_num
+
+    train_image_anno_list = image_anno_list[:train_num]
+    val_image_anno_list = image_anno_list[train_num:train_num + val_num]
+    test_image_anno_list = image_anno_list[train_num + val_num:]
+
+    with open(
+            osp.join(save_dir, 'train_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in train_image_anno_list:
+            file = osp.join("JPEGImages", x[0])
+            label = osp.join("Annotations", x[1])
+            f.write('{} {}\n'.format(file, label))
+    with open(
+            osp.join(save_dir, 'val_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in val_image_anno_list:
+            file = osp.join("JPEGImages", x[0])
+            label = osp.join("Annotations", x[1])
+            f.write('{} {}\n'.format(file, label))
+    if len(test_image_anno_list):
+        with open(
+                osp.join(save_dir, 'test_list.txt'), mode='w',
+                encoding='utf-8') as f:
+            for x in test_image_anno_list:
+                file = osp.join("JPEGImages", x[0])
+                label = osp.join("Annotations", x[1])
+                f.write('{} {}\n'.format(file, label))
+    with open(
+            osp.join(save_dir, 'labels.txt'), mode='w', encoding='utf-8') as f:
+        for l in sorted(label_list):
+            f.write('{}\n'.format(l))
+
+    return train_num, val_num, test_num
diff --git a/paddlex/tools/split.py b/paddlex/tools/split.py
index 8e892c236ed44ed2c7cf7a95b0b57a25e16bd743..ac6b7f01ed32140a1a1a57f2f2fe525b34a56ed8 100644
--- a/paddlex/tools/split.py
+++ b/paddlex/tools/split.py
@@ -14,7 +14,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .dataset_split.coco_split import split_coco_dataset
+from .dataset_split.voc_split import split_voc_dataset
+from .dataset_split.imagenet_split import split_imagenet_dataset
+from .dataset_split.seg_split import split_seg_dataset
+
 
 def dataset_split(dataset_dir, dataset_form, val_value, test_value, save_dir):
-    print(dataset_dir, dataset_form, val_value, test_value, save_dir)
-    print(12345)
+    if dataset_form == "coco":
+        train_num, val_num, test_num = split_coco_dataset(
+            dataset_dir, val_value, test_value, save_dir)
+    elif dataset_form == "voc":
+        train_num, val_num, test_num = split_voc_dataset(
+            dataset_dir, val_value, test_value, save_dir)
+    elif dataset_form == "seg":
+        train_num, val_num, test_num = split_seg_dataset(
+            dataset_dir, val_value, test_value, save_dir)
+    elif dataset_form == "imagenet":
+        train_num, val_num, test_num = split_imagenet_dataset(
+            dataset_dir, val_value, test_value, save_dir)
+    else:
+        raise ValueError("Unsupported dataset format: {}".format(dataset_form))
+    print("Dataset Split Done.")
+    print("Train samples: {}".format(train_num))
+    print("Eval samples: {}".format(val_num))
+    print("Test samples: {}".format(test_num))
+    print("Split file saved in {}".format(save_dir))
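
A minimal usage sketch for reviewers, not part of the diff itself: the call path pdx.tools.split.dataset_split and its argument names come straight from command.py and split.py above, while the "MyDataset" path and the 0.2/0.1 ratios are hypothetical examples.

# Hypothetical example (not in this PR): split a VOC-format dataset
# into 70% train / 20% val / 10% test. Assumes "MyDataset" contains the
# JPEGImages/ and Annotations/ folders that split_voc_dataset checks for.
import paddlex as pdx

pdx.tools.split.dataset_split(
    dataset_dir="MyDataset",  # dataset root (hypothetical path)
    dataset_form="voc",       # one of "coco", "voc", "seg", "imagenet"
    val_value=0.2,            # fraction of samples for validation
    test_value=0.1,           # fraction of samples for test
    save_dir="MyDataset")     # where train_list.txt / val_list.txt / labels.txt are written

On success this prints the train/eval/test sample counts and writes the split files into save_dir, as implemented in dataset_split above.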