From 8590cb53c0e3692e9d2dcf8e98ada5152398b960 Mon Sep 17 00:00:00 2001 From: sunyanfang01 Date: Tue, 12 May 2020 17:04:39 +0800 Subject: [PATCH] add data convert --- paddlex/__init__.py | 1 + paddlex/tools/__init__.py | 24 +++ paddlex/tools/base.py | 43 +++++ paddlex/tools/x2coco.py | 247 +++++++++++++++++++++++++++ paddlex/tools/x2imagenet.py | 50 ++++++ paddlex/tools/x2seg.py | 326 ++++++++++++++++++++++++++++++++++++ paddlex/tools/x2voc.py | 189 +++++++++++++++++++++ 7 files changed, 880 insertions(+) create mode 100644 paddlex/tools/__init__.py create mode 100644 paddlex/tools/base.py create mode 100644 paddlex/tools/x2coco.py create mode 100644 paddlex/tools/x2imagenet.py create mode 100644 paddlex/tools/x2seg.py create mode 100644 paddlex/tools/x2voc.py diff --git a/paddlex/__init__.py b/paddlex/__init__.py index a0b333f..a473b1c 100644 --- a/paddlex/__init__.py +++ b/paddlex/__init__.py @@ -19,6 +19,7 @@ from . import det from . import seg from . import cls from . import slim +from . import tools try: import pycocotools diff --git a/paddlex/tools/__init__.py b/paddlex/tools/__init__.py new file mode 100644 index 0000000..1b749d9 --- /dev/null +++ b/paddlex/tools/__init__.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# coding: utf-8 +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .x2imagenet import EasyData2ImageNet +from .x2coco import LabelMe2COCO +from .x2coco import EasyData2COCO +from .x2voc import LabelMe2VOC +from .x2voc import EasyData2VOC +from .x2seg import JingLing2Seg +from .x2seg import LabelMe2Seg +from .x2seg import EasyData2Seg diff --git a/paddlex/tools/base.py b/paddlex/tools/base.py new file mode 100644 index 0000000..94f9fa6 --- /dev/null +++ b/paddlex/tools/base.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# coding: utf-8 +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import json
+import chardet
+import numpy as np
+
+class MyEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, np.integer):
+            return int(obj)
+        elif isinstance(obj, np.floating):
+            return float(obj)
+        elif isinstance(obj, np.ndarray):
+            return obj.tolist()
+        else:
+            return super(MyEncoder, self).default(obj)
+
+def is_pic(img_name):
+    valid_suffix = ["JPEG", "jpeg", "JPG", "jpg", "BMP", "bmp", "PNG", "png"]
+    suffix = img_name.split(".")[-1]
+    if suffix not in valid_suffix:
+        return False
+    return True
+
+def get_encoding(path):
+    with open(path, 'rb') as f:
+        data = f.read()
+    file_encoding = chardet.detect(data).get('encoding')
+    return file_encoding
\ No newline at end of file
diff --git a/paddlex/tools/x2coco.py b/paddlex/tools/x2coco.py
new file mode 100644
index 0000000..5bbc901
--- /dev/null
+++ b/paddlex/tools/x2coco.py
@@ -0,0 +1,247 @@
+#!/usr/bin/env python
+# coding: utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import json
+import os
+import os.path as osp
+import shutil
+import numpy as np
+import PIL.ImageDraw
+from .base import MyEncoder, is_pic, get_encoding
+
+
+class X2COCO(object):
+    def __init__(self):
+        self.images_list = []
+        self.categories_list = []
+        self.annotations_list = []
+
+    def generate_categories_field(self, label, labels_list):
+        category = {}
+        category["supercategory"] = "component"
+        category["id"] = len(labels_list) + 1
+        category["name"] = label
+        return category
+
+    def generate_rectangle_anns_field(self, points, label, image_id, object_id, label_to_num):
+        annotation = {}
+        seg_points = np.asarray(points).copy()
+        seg_points[1, :] = np.asarray(points)[2, :]
+        seg_points[2, :] = np.asarray(points)[1, :]
+        annotation["segmentation"] = [list(seg_points.flatten())]
+        annotation["iscrowd"] = 0
+        annotation["image_id"] = image_id + 1
+        annotation["bbox"] = list(
+            map(float, [
+                points[0][0], points[0][1], points[1][0] - points[0][0],
+                points[1][1] - points[0][1]
+            ]))
+        annotation["area"] = annotation["bbox"][2] * annotation["bbox"][3]
+        annotation["category_id"] = label_to_num[label]
+        annotation["id"] = object_id + 1
+        return annotation
+
+    def convert(self, image_input_dir, json_input_dir, dataset_save_dir):
+        assert osp.exists(image_input_dir), "The image folder does not exist!"
+        assert osp.exists(json_input_dir), "The json folder does not exist!"
+        assert osp.exists(dataset_save_dir), "The save folder does not exist!"
+        # Convert the image files.
+        new_image_dir = osp.join(dataset_save_dir, "JPEGImages")
+        if osp.exists(new_image_dir):
+            shutil.rmtree(new_image_dir)
+        os.makedirs(new_image_dir)
+        for img_name in os.listdir(image_input_dir):
+            if is_pic(img_name):
+                shutil.copyfile(
+                    osp.join(image_input_dir, img_name),
+                    osp.join(new_image_dir, img_name))
+        # Convert the json files.
+        self.analyse_json(new_image_dir, json_input_dir)
+        coco_data = {}
+        coco_data["images"] = self.images_list
+        coco_data["categories"] = self.categories_list
+        coco_data["annotations"] = self.annotations_list
+        json_path = osp.join(dataset_save_dir, "annotations.json")
+        json.dump(
+            coco_data,
+            open(json_path, "w"),
+            indent=4,
+            cls=MyEncoder)
+
+
+class LabelMe2COCO(X2COCO):
+    def __init__(self):
+        super(LabelMe2COCO, self).__init__()
+
+    def generate_images_field(self, json_info, image_id):
+        image = {}
+        image["height"] = json_info["imageHeight"]
+        image["width"] = json_info["imageWidth"]
+        image["id"] = image_id + 1
+        image["file_name"] = json_info["imagePath"].split("/")[-1]
+        return image
+
+    def generate_polygon_anns_field(self, height, width,
+                                    points, label, image_id,
+                                    object_id, label_to_num):
+        annotation = {}
+        annotation["segmentation"] = [list(np.asarray(points).flatten())]
+        annotation["iscrowd"] = 0
+        annotation["image_id"] = image_id + 1
+        annotation["bbox"] = list(map(float, self.get_bbox(height, width, points)))
+        annotation["area"] = annotation["bbox"][2] * annotation["bbox"][3]
+        annotation["category_id"] = label_to_num[label]
+        annotation["id"] = object_id + 1
+        return annotation
+
+    def get_bbox(self, height, width, points):
+        polygons = points
+        mask = np.zeros([height, width], dtype=np.uint8)
+        mask = PIL.Image.fromarray(mask)
+        xy = list(map(tuple, polygons))
+        PIL.ImageDraw.Draw(mask).polygon(xy=xy, outline=1, fill=1)
+        mask = np.array(mask, dtype=bool)
+        index = np.argwhere(mask == 1)
+        rows = index[:, 0]
+        cols = index[:, 1]
+        left_top_r = np.min(rows)
+        left_top_c = np.min(cols)
+        right_bottom_r = np.max(rows)
+        right_bottom_c = np.max(cols)
+        return [
+            left_top_c, left_top_r, right_bottom_c - left_top_c,
+            right_bottom_r - left_top_r
+        ]
+
+    def analyse_json(self, img_dir, json_dir):
+        image_id = -1
+        object_id = -1
+        labels_list = []
+        label_to_num = {}
+        for img_file in os.listdir(img_dir):
+            img_name_part = osp.splitext(img_file)[0]
+            json_file = osp.join(json_dir, img_name_part + ".json")
+            if not osp.exists(json_file):
+                os.remove(osp.join(img_dir, img_file))
+                continue
+            image_id = image_id + 1
+            with open(json_file, mode='r', \
+                      encoding=get_encoding(json_file)) as j:
+                json_info = json.load(j)
+                img_info = self.generate_images_field(json_info, image_id)
+                self.images_list.append(img_info)
+                for shapes in json_info["shapes"]:
+                    object_id = object_id + 1
+                    label = shapes["label"]
+                    if label not in labels_list:
+                        self.categories_list.append(\
+                            self.generate_categories_field(label, labels_list))
+                        labels_list.append(label)
+                        label_to_num[label] = len(labels_list)
+                    points = shapes["points"]
+                    p_type = shapes["shape_type"]
+                    if p_type == "polygon":
+                        self.annotations_list.append(
+                            self.generate_polygon_anns_field(
+                                json_info["imageHeight"], json_info["imageWidth"],
+                                points, label, image_id, object_id, label_to_num))
+                    if p_type == "rectangle":
+                        points.append([points[0][0], points[1][1]])
+                        points.append([points[1][0], points[0][1]])
+                        self.annotations_list.append(
+                            self.generate_rectangle_anns_field(
+                                points, label, image_id, object_id, label_to_num))
+
+
+class EasyData2COCO(X2COCO):
+    def __init__(self):
+        super(EasyData2COCO, self).__init__()
+
+    def generate_images_field(self, img_path, image_id):
+        image = {}
+        img = cv2.imread(img_path)
+        image["height"] = img.shape[0]
+        image["width"] = img.shape[1]
+        image["id"] = image_id + 1
+        image["file_name"] = osp.split(img_path)[-1]
+        return image
+
+    def generate_polygon_anns_field(self, 
points, segmentation,
+                                    label, image_id, object_id,
+                                    label_to_num):
+        annotation = {}
+        annotation["segmentation"] = segmentation
+        annotation["iscrowd"] = 1 if len(segmentation) > 1 else 0
+        annotation["image_id"] = image_id + 1
+        annotation["bbox"] = list(map(float, [
+            points[0][0], points[0][1], points[1][0] - points[0][0],
+            points[1][1] - points[0][1]
+        ]))
+        annotation["area"] = annotation["bbox"][2] * annotation["bbox"][3]
+        annotation["category_id"] = label_to_num[label]
+        annotation["id"] = object_id + 1
+        return annotation
+
+    def analyse_json(self, img_dir, json_dir):
+        from pycocotools.mask import decode
+        image_id = -1
+        object_id = -1
+        labels_list = []
+        label_to_num = {}
+        for img_file in os.listdir(img_dir):
+            img_name_part = osp.splitext(img_file)[0]
+            json_file = osp.join(json_dir, img_name_part + ".json")
+            if not osp.exists(json_file):
+                os.remove(osp.join(img_dir, img_file))
+                continue
+            image_id = image_id + 1
+            with open(json_file, mode='r', \
+                      encoding=get_encoding(json_file)) as j:
+                json_info = json.load(j)
+                img_info = self.generate_images_field(osp.join(img_dir, img_file), image_id)
+                self.images_list.append(img_info)
+                for shapes in json_info["labels"]:
+                    object_id = object_id + 1
+                    label = shapes["name"]
+                    if label not in labels_list:
+                        self.categories_list.append(\
+                            self.generate_categories_field(label, labels_list))
+                        labels_list.append(label)
+                        label_to_num[label] = len(labels_list)
+                    points = [[shapes["x1"], shapes["y1"]],
+                              [shapes["x2"], shapes["y2"]]]
+                    if "mask" not in shapes:
+                        points.append([points[0][0], points[1][1]])
+                        points.append([points[1][0], points[0][1]])
+                        self.annotations_list.append(
+                            self.generate_rectangle_anns_field(
+                                points, label, image_id, object_id, label_to_num))
+                    else:
+                        mask_dict = {}
+                        mask_dict['size'] = [img_info["height"], img_info["width"]]
+                        mask_dict['counts'] = shapes['mask'].encode()
+                        mask = decode(mask_dict)
+                        contours, hierarchy = cv2.findContours(
+                            (mask).astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+                        segmentation = []
+                        for contour in contours:
+                            contour_list = contour.flatten().tolist()
+                            if len(contour_list) > 4:
+                                segmentation.append(contour_list)
+                        self.annotations_list.append(
+                            self.generate_polygon_anns_field(
+                                points, segmentation, label, image_id, object_id, label_to_num))
\ No newline at end of file
diff --git a/paddlex/tools/x2imagenet.py b/paddlex/tools/x2imagenet.py
new file mode 100644
index 0000000..aeea7d1
--- /dev/null
+++ b/paddlex/tools/x2imagenet.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python
+# coding: utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import json
+import os
+import os.path as osp
+import shutil
+import numpy as np
+from .base import MyEncoder, is_pic, get_encoding
+
+class EasyData2ImageNet(object):
+    def __init__(self):
+        pass
+
+    def convert(self, image_dir, json_dir, dataset_save_dir):
+        assert osp.exists(image_dir), "The image folder does not exist!"
+        assert osp.exists(json_dir), "The json folder does not exist!"
+        assert osp.exists(dataset_save_dir), "The save folder does not exist!"
+        assert len(os.listdir(dataset_save_dir)) == 0, "The save folder must be empty!"
+        for img_name in os.listdir(image_dir):
+            img_name_part = osp.splitext(img_name)[0]
+            json_file = osp.join(json_dir, img_name_part + ".json")
+            if not osp.exists(json_file):
+                continue
+            with open(json_file, mode="r", \
+                      encoding=get_encoding(json_file)) as j:
+                json_info = json.load(j)
+                for output in json_info['labels']:
+                    cls_name = output['name']
+                    new_image_dir = osp.join(dataset_save_dir, cls_name)
+                    if not osp.exists(new_image_dir):
+                        os.makedirs(new_image_dir)
+                    if is_pic(img_name):
+                        shutil.copyfile(
+                            osp.join(image_dir, img_name),
+                            osp.join(new_image_dir, img_name))
\ No newline at end of file
diff --git a/paddlex/tools/x2seg.py b/paddlex/tools/x2seg.py
new file mode 100644
index 0000000..d4b12e7
--- /dev/null
+++ b/paddlex/tools/x2seg.py
@@ -0,0 +1,326 @@
+#!/usr/bin/env python
+# coding: utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import uuid
+import json
+import os
+import os.path as osp
+import shutil
+import numpy as np
+import PIL.Image, PIL.ImageDraw
+from .base import MyEncoder, is_pic, get_encoding
+
+class X2Seg(object):
+    def __init__(self):
+        self.labels2ids = {'_background_': 0}
+
+    def shapes_to_label(self, img_shape, shapes, label_name_to_value):
+        def shape_to_mask(img_shape, points, shape_type=None,
+                          line_width=10, point_size=5):
+            mask = np.zeros(img_shape[:2], dtype=np.uint8)
+            mask = PIL.Image.fromarray(mask)
+            draw = PIL.ImageDraw.Draw(mask)
+            xy = [tuple(point) for point in points]
+            if shape_type == 'circle':
+                assert len(xy) == 2, 'Shape of shape_type=circle must have 2 points'
+                (cx, cy), (px, py) = xy
+                d = np.sqrt((cx - px) ** 2 + (cy - py) ** 2)
+                draw.ellipse([cx - d, cy - d, cx + d, cy + d], outline=1, fill=1)
+            elif shape_type == 'rectangle':
+                assert len(xy) == 2, 'Shape of shape_type=rectangle must have 2 points'
+                draw.rectangle(xy, outline=1, fill=1)
+            elif shape_type == 'line':
+                assert len(xy) == 2, 'Shape of shape_type=line must have 2 points'
+                draw.line(xy=xy, fill=1, width=line_width)
+            elif shape_type == 'linestrip':
+                draw.line(xy=xy, fill=1, width=line_width)
+            elif shape_type == 'point':
+                assert len(xy) == 1, 'Shape of shape_type=point must have 1 point'
+                cx, cy = xy[0]
+                r = point_size
+                draw.ellipse([cx - r, cy - r, cx + r, cy + r], outline=1, fill=1)
+            else:
+                assert len(xy) > 2, 'Polygon must have points more than 2'
+                draw.polygon(xy=xy, outline=1, fill=1)
+            mask = np.array(mask, dtype=bool)
+            return mask
+        cls = np.zeros(img_shape[:2], dtype=np.int32)
+        ins = np.zeros_like(cls)
+        instances = []
+        for shape in shapes:
+            points = shape['points']
+            label = shape['label']
+            group_id = shape.get('group_id')
+            if group_id is None:
+                group_id = uuid.uuid1()
+            shape_type = shape.get('shape_type', None)
+
+            cls_name = label
+            instance = 
(cls_name, group_id) + + if instance not in instances: + instances.append(instance) + ins_id = instances.index(instance) + 1 + cls_id = label_name_to_value[cls_name] + mask = shape_to_mask(img_shape[:2], points, shape_type) + cls[mask] = cls_id + ins[mask] = ins_id + return cls, ins + + def get_color_map_list(self, num_classes): + """ Returns the color map for visualizing the segmentation mask, + which can support arbitrary number of classes. + Args: + num_classes: Number of classes + Returns: + The color map + """ + color_map = num_classes * [0, 0, 0] + for i in range(0, num_classes): + j = 0 + lab = i + while lab: + color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j)) + color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j)) + color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)) + j += 1 + lab >>= 3 + return color_map + + def convert(self, image_dir, json_dir, dataset_save_dir): + assert osp.exists(image_dir), "The image folder does not exist!" + assert osp.exists(json_dir), "The json folder does not exist!" + assert osp.exists(dataset_save_dir), "The save folder does not exist!" + # Convert the image files. + new_image_dir = osp.join(dataset_save_dir, "JPEGImages") + if osp.exists(new_image_dir): + shutil.rmtree(new_image_dir) + os.makedirs(new_image_dir) + for img_name in os.listdir(image_dir): + if is_pic(img_name): + shutil.copyfile( + osp.join(image_dir, img_name), + osp.join(new_image_dir, img_name)) + # Convert the json files. + png_dir = osp.join(dataset_save_dir, "Annotations") + if osp.exists(png_dir): + shutil.rmtree(png_dir) + os.makedirs(png_dir) + self.get_labels2ids(new_image_dir, json_dir) + self.json2png(new_image_dir, json_dir, png_dir) + # Generate the labels.txt + ids2labels = {v : k for k, v in self.labels2ids.items()} + with open(osp.join(dataset_save_dir, 'labels.txt'), 'w') as fw: + for i in range(len(ids2labels)): + fw.write(ids2labels[i] + '\n') + + +class JingLing2Seg(X2Seg): + def __init__(self): + super(JingLing2Seg, self).__init__() + + def get_labels2ids(self, image_dir, json_dir): + for img_name in os.listdir(image_dir): + img_name_part = osp.splitext(img_name)[0] + json_file = osp.join(json_dir, img_name_part + ".json") + if not osp.exists(json_file): + os.remove(os.remove(osp.join(image_dir, img_name))) + continue + with open(json_file, mode="r", \ + encoding=get_encoding(json_file)) as j: + json_info = json.load(j) + if 'outputs' in json_info: + for output in json_info['outputs']['object']: + cls_name = output['name'] + if cls_name not in self.labels2ids: + self.labels2ids[cls_name] = len(self.labels2ids) + + def json2png(self, image_dir, json_dir, png_dir): + color_map = self.get_color_map_list(256) + for img_name in os.listdir(image_dir): + img_name_part = osp.splitext(img_name)[0] + json_file = osp.join(json_dir, img_name_part + ".json") + if not osp.exists(json_file): + os.remove(os.remove(osp.join(image_dir, img_name))) + continue + with open(json_file, mode="r", \ + encoding=get_encoding(json_file)) as j: + json_info = json.load(j) + data_shapes = [] + if 'outputs' in json_info: + for output in json_info['outputs']['object']: + if 'polygon' in output.keys(): + polygon = output['polygon'] + name = output['name'] + points = [] + for i in range(1, int(len(polygon) / 2) + 1): + points.append( + [polygon['x' + str(i)], polygon['y' + str(i)]]) + shape = { + 'label': name, + 'points': points, + 'shape_type': 'polygon' + } + data_shapes.append(shape) + if 'size' not in json_info: + continue + img_shape = (json_info['size']['height'], + 
json_info['size']['width'], + json_info['size']['depth']) + lbl, _ = self.shapes_to_label( + img_shape=img_shape, + shapes=data_shapes, + label_name_to_value=self.labels2ids, + ) + out_png_file = osp.join(png_dir, img_name_part + '.png') + if lbl.min() >= 0 and lbl.max() <= 255: + lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P') + lbl_pil.putpalette(color_map) + lbl_pil.save(out_png_file) + else: + raise ValueError( + '[%s] Cannot save the pixel-wise class label as PNG. ' + 'Please consider using the .npy format.' % out_png_file) + + +class LabelMe2Seg(X2Seg): + def __init__(self): + super(LabelMe2Seg, self).__init__() + + def get_labels2ids(self, image_dir, json_dir): + for img_name in os.listdir(image_dir): + img_name_part = osp.splitext(img_name)[0] + json_file = osp.join(json_dir, img_name_part + ".json") + if not osp.exists(json_file): + os.remove(os.remove(osp.join(image_dir, img_name))) + continue + with open(json_file, mode="r", \ + encoding=get_encoding(json_file)) as j: + json_info = json.load(j) + for shape in json_info['shapes']: + cls_name = shape['label'] + if cls_name not in self.labels2ids: + self.labels2ids[cls_name] = len(self.labels2ids) + + def json2png(self, image_dir, json_dir, png_dir): + color_map = self.get_color_map_list(256) + for img_name in os.listdir(image_dir): + img_name_part = osp.splitext(img_name)[0] + json_file = osp.join(json_dir, img_name_part + ".json") + if not osp.exists(json_file): + os.remove(os.remove(osp.join(image_dir, img_name))) + continue + img_file = osp.join(image_dir, img_name) + img = np.asarray(PIL.Image.open(img_file)) + with open(json_file, mode="r", \ + encoding=get_encoding(json_file)) as j: + json_info = json.load(j) + lbl, _ = self.shapes_to_label( + img_shape=img.shape, + shapes=json_info['shapes'], + label_name_to_value=self.labels2ids, + ) + out_png_file = osp.join(png_dir, img_name_part + '.png') + if lbl.min() >= 0 and lbl.max() <= 255: + lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P') + lbl_pil.putpalette(color_map) + lbl_pil.save(out_png_file) + else: + raise ValueError( + '[%s] Cannot save the pixel-wise class label as PNG. ' + 'Please consider using the .npy format.' 
% out_png_file) + + +class EasyData2Seg(X2Seg): + def __init__(self): + super(EasyData2Seg, self).__init__() + + def get_labels2ids(self, image_dir, json_dir): + for img_name in os.listdir(image_dir): + img_name_part = osp.splitext(img_name)[0] + json_file = osp.join(json_dir, img_name_part + ".json") + if not osp.exists(json_file): + os.remove(os.remove(osp.join(image_dir, img_name))) + continue + with open(json_file, mode="r", \ + encoding=get_encoding(json_file)) as j: + json_info = json.load(j) + for shape in json_info["labels"]: + cls_name = shape['name'] + if cls_name not in self.labels2ids: + self.labels2ids[cls_name] = len(self.labels2ids) + + def mask2polygon(self, mask, label): + contours, hierarchy = cv2.findContours( + (mask).astype(np.uint8), cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE) + segmentation = [] + for contour in contours: + contour_list = contour.flatten().tolist() + if len(contour_list) > 4: + points = [] + for i in range(0, len(contour_list), 2): + points.append( + [contour_list[i], contour_list[i + 1]]) + shape = { + 'label': label, + 'points': points, + 'shape_type': 'polygon' + } + segmentation.append(shape) + return segmentation + + def json2png(self, image_dir, json_dir, png_dir): + from pycocotools.mask import decode + color_map = self.get_color_map_list(256) + for img_name in os.listdir(image_dir): + img_name_part = osp.splitext(img_name)[0] + json_file = osp.join(json_dir, img_name_part + ".json") + if not osp.exists(json_file): + os.remove(os.remove(osp.join(image_dir, img_name))) + continue + img_file = osp.join(image_dir, img_name) + img = np.asarray(PIL.Image.open(img_file)) + img_h = img.shape[0] + img_w = img.shape[1] + with open(json_file, mode="r", \ + encoding=get_encoding(json_file)) as j: + json_info = json.load(j) + data_shapes = [] + for shape in json_info['labels']: + mask_dict = {} + mask_dict['size'] = [img_h, img_w] + mask_dict['counts'] = shape['mask'].encode() + mask = decode(mask_dict) + polygon = self.mask2polygon(mask, shape["name"]) + data_shapes.extend(polygon) + lbl, _ = self.shapes_to_label( + img_shape=img.shape, + shapes=data_shapes, + label_name_to_value=self.labels2ids, + ) + out_png_file = osp.join(png_dir, img_name_part + '.png') + if lbl.min() >= 0 and lbl.max() <= 255: + lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P') + lbl_pil.putpalette(color_map) + lbl_pil.save(out_png_file) + else: + raise ValueError( + '[%s] Cannot save the pixel-wise class label as PNG. ' + 'Please consider using the .npy format.' % out_png_file) + + + diff --git a/paddlex/tools/x2voc.py b/paddlex/tools/x2voc.py new file mode 100644 index 0000000..b241716 --- /dev/null +++ b/paddlex/tools/x2voc.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python +# coding: utf-8 +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import cv2 +import json +import os +import os.path as osp +import shutil +import numpy as np +from .base import MyEncoder, is_pic, get_encoding + +class X2VOC(object): + def __init__(self): + pass + + def convert(self, image_dir, json_dir, dataset_save_dir): + assert osp.exists(image_dir), "The image folder does not exist!" + assert osp.exists(json_dir), "The json folder does not exist!" + assert osp.exists(dataset_save_dir), "The save folder does not exist!" + # Convert the image files. + new_image_dir = osp.join(dataset_save_dir, "JPEGImages") + if osp.exists(new_image_dir): + shutil.rmtree(new_image_dir) + os.makedirs(new_image_dir) + for img_name in os.listdir(image_dir): + if is_pic(img_name): + shutil.copyfile( + osp.join(image_dir, img_name), + osp.join(new_image_dir, img_name)) + # Convert the json files. + xml_dir = osp.join(dataset_save_dir, "Annotations") + if osp.exists(xml_dir): + shutil.rmtree(xml_dir) + os.makedirs(xml_dir) + self.json2xml(new_image_dir, json_dir, xml_dir) + + +class LabelMe2VOC(X2VOC): + def __init__(self): + pass + + def json2xml(self, image_dir, json_dir, xml_dir): + import xml.dom.minidom as minidom + for img_name in os.listdir(image_dir): + img_name_part = osp.splitext(img_name)[0] + json_file = osp.join(json_dir, img_name_part + ".json") + if not osp.exists(json_file): + os.remove(os.remove(osp.join(image_dir, img_name))) + continue + xml_doc = minidom.Document() + root = xml_doc.createElement("annotation") + xml_doc.appendChild(root) + node_folder = xml_doc.createElement("folder") + node_folder.appendChild(xml_doc.createTextNode("JPEGImages")) + root.appendChild(node_folder) + node_filename = xml_doc.createElement("filename") + node_filename.appendChild(xml_doc.createTextNode(img_name)) + root.appendChild(node_filename) + with open(json_file, mode="r", \ + encoding=get_encoding(json_file)) as j: + json_info = json.load(j) + h = json_info["imageHeight"] + w = json_info["imageWidth"] + node_size = xml_doc.createElement("size") + node_width = xml_doc.createElement("width") + node_width.appendChild(xml_doc.createTextNode(str(w))) + node_size.appendChild(node_width) + node_height = xml_doc.createElement("height") + node_height.appendChild(xml_doc.createTextNode(str(h))) + node_size.appendChild(node_height) + node_depth = xml_doc.createElement("depth") + node_depth.appendChild(xml_doc.createTextNode(str(3))) + node_size.appendChild(node_depth) + root.appendChild(node_size) + for shape in json_info["shapes"]: + if shape["shape_type"] != "rectangle": + continue + label = shape["label"] + (xmin, ymin), (xmax, ymax) = shape["points"] + xmin, xmax = sorted([xmin, xmax]) + ymin, ymax = sorted([ymin, ymax]) + node_obj = xml_doc.createElement("object") + node_name = xml_doc.createElement("name") + node_name.appendChild(xml_doc.createTextNode(label)) + node_obj.appendChild(node_name) + node_diff = xml_doc.createElement("difficult") + node_diff.appendChild(xml_doc.createTextNode(str(0))) + node_obj.appendChild(node_diff) + node_box = xml_doc.createElement("bndbox") + node_xmin = xml_doc.createElement("xmin") + node_xmin.appendChild(xml_doc.createTextNode(str(xmin))) + node_box.appendChild(node_xmin) + node_ymin = xml_doc.createElement("ymin") + node_ymin.appendChild(xml_doc.createTextNode(str(ymin))) + node_box.appendChild(node_ymin) + node_xmax = xml_doc.createElement("xmax") + node_xmax.appendChild(xml_doc.createTextNode(str(xmax))) + node_box.appendChild(node_xmax) + node_ymax = xml_doc.createElement("ymax") + 
node_ymax.appendChild(xml_doc.createTextNode(str(ymax))) + node_box.appendChild(node_ymax) + node_obj.appendChild(node_box) + root.appendChild(node_obj) + with open(osp.join(xml_dir, img_name_part + ".xml"), 'w') as fxml: + xml_doc.writexml(fxml, indent='\t', addindent='\t', newl='\n', encoding="utf-8") + + +class EasyData2VOC(X2VOC): + def __init__(self): + pass + + def json2xml(self, image_dir, json_dir, xml_dir): + import xml.dom.minidom as minidom + for img_name in os.listdir(image_dir): + img_name_part = osp.splitext(img_name)[0] + json_file = osp.join(json_dir, img_name_part + ".json") + if not osp.exists(json_file): + os.remove(os.remove(osp.join(image_dir, img_name))) + continue + xml_doc = minidom.Document() + root = xml_doc.createElement("annotation") + xml_doc.appendChild(root) + node_folder = xml_doc.createElement("folder") + node_folder.appendChild(xml_doc.createTextNode("JPEGImages")) + root.appendChild(node_folder) + node_filename = xml_doc.createElement("filename") + node_filename.appendChild(xml_doc.createTextNode(img_name)) + root.appendChild(node_filename) + img = cv2.imread(osp.join(image_dir, img_name)) + h = img.shape[0] + w = img.shape[1] + node_size = xml_doc.createElement("size") + node_width = xml_doc.createElement("width") + node_width.appendChild(xml_doc.createTextNode(str(w))) + node_size.appendChild(node_width) + node_height = xml_doc.createElement("height") + node_height.appendChild(xml_doc.createTextNode(str(h))) + node_size.appendChild(node_height) + node_depth = xml_doc.createElement("depth") + node_depth.appendChild(xml_doc.createTextNode(str(3))) + node_size.appendChild(node_depth) + root.appendChild(node_size) + with open(json_file, mode="r", \ + encoding=get_encoding(json_file)) as j: + json_info = json.load(j) + for shape in json_info["labels"]: + label = shape["name"] + xmin = shape["x1"] + ymin = shape["y1"] + xmax = shape["x2"] + ymax = shape["y2"] + node_obj = xml_doc.createElement("object") + node_name = xml_doc.createElement("name") + node_name.appendChild(xml_doc.createTextNode(label)) + node_obj.appendChild(node_name) + node_diff = xml_doc.createElement("difficult") + node_diff.appendChild(xml_doc.createTextNode(str(0))) + node_obj.appendChild(node_diff) + node_box = xml_doc.createElement("bndbox") + node_xmin = xml_doc.createElement("xmin") + node_xmin.appendChild(xml_doc.createTextNode(str(xmin))) + node_box.appendChild(node_xmin) + node_ymin = xml_doc.createElement("ymin") + node_ymin.appendChild(xml_doc.createTextNode(str(ymin))) + node_box.appendChild(node_ymin) + node_xmax = xml_doc.createElement("xmax") + node_xmax.appendChild(xml_doc.createTextNode(str(xmax))) + node_box.appendChild(node_xmax) + node_ymax = xml_doc.createElement("ymax") + node_ymax.appendChild(xml_doc.createTextNode(str(ymax))) + node_box.appendChild(node_ymax) + node_obj.appendChild(node_box) + root.appendChild(node_obj) + with open(osp.join(xml_dir, img_name_part + ".xml"), 'w') as fxml: + xml_doc.writexml(fxml, indent='\t', addindent='\t', newl='\n', encoding="utf-8") + -- GitLab
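
For reference, a minimal usage sketch of the converters this patch adds. The paddlex.tools entry points and the convert(image_dir, json_dir, dataset_save_dir) signature come from the patch itself; the directory names and the pdx alias below are illustrative only, and the save folder must already exist because each converter asserts that before writing:

    import os
    import paddlex as pdx

    save_dir = "./coco_dataset"           # illustrative output path
    os.makedirs(save_dir, exist_ok=True)  # convert() requires an existing save folder

    # LabelMe polygon/rectangle annotations -> MSCOCO-style dataset
    pdx.tools.LabelMe2COCO().convert("./images", "./annotations", save_dir)

    # The other converters follow the same pattern, for example:
    # pdx.tools.LabelMe2VOC().convert("./images", "./annotations", "./voc_dataset")
    # pdx.tools.JingLing2Seg().convert("./images", "./annotations", "./seg_dataset")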