# coding: utf8 # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function import argparse import glob import math import json import os import os.path as osp import numpy as np import PIL.Image import PIL.ImageDraw import cv2 from gray2pseudo_color import get_color_map_list def parse_args(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('input_dir', help='input annotated directory') return parser.parse_args() def main(args): output_dir = osp.join(args.input_dir, 'annotations') if not osp.exists(output_dir): os.makedirs(output_dir) print('Creating annotations directory:', output_dir) # get the all class names for the given dataset class_names = ['_background_'] for label_file in glob.glob(osp.join(args.input_dir, '*.json')): with open(label_file) as f: data = json.load(f) for shape in data['shapes']: label = shape['label'] cls_name = label if not cls_name in class_names: class_names.append(cls_name) class_name_to_id = {} for i, class_name in enumerate(class_names): class_id = i # starts with 0 class_name_to_id[class_name] = class_id if class_id == 0: assert class_name == '_background_' class_names = tuple(class_names) print('class_names:', class_names) out_class_names_file = osp.join(args.input_dir, 'class_names.txt') with open(out_class_names_file, 'w') as f: f.writelines('\n'.join(class_names)) print('Saved class_names:', out_class_names_file) color_map = get_color_map_list(256) for label_file in glob.glob(osp.join(args.input_dir, '*.json')): print('Generating dataset from:', label_file) with open(label_file) as f: base = osp.splitext(osp.basename(label_file))[0] out_png_file = osp.join(output_dir, base + '.png') data = json.load(f) img_file = osp.join(osp.dirname(label_file), data['imagePath']) img = np.asarray(cv2.imread(img_file)) lbl = shape2label( img_size=img.shape, shapes=data['shapes'], class_name_mapping=class_name_to_id, ) if osp.splitext(out_png_file)[1] != '.png': out_png_file += '.png' # Assume label ranges [0, 255] for uint8, if lbl.min() >= 0 and lbl.max() <= 255: lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P') lbl_pil.putpalette(color_map) lbl_pil.save(out_png_file) else: raise ValueError( '[%s] Cannot save the pixel-wise class label as PNG. ' 'Please consider using the .npy format.' % out_png_file) def shape2mask(img_size, points): label_mask = PIL.Image.fromarray(np.zeros(img_size[:2], dtype=np.uint8)) image_draw = PIL.ImageDraw.Draw(label_mask) points_list = [tuple(point) for point in points] assert len(points_list) > 2, 'Polygon must have points more than 2' image_draw.polygon(xy=points_list, outline=1, fill=1) return np.array(label_mask, dtype=bool) def shape2label(img_size, shapes, class_name_mapping): label = np.zeros(img_size[:2], dtype=np.int32) for shape in shapes: points = shape['points'] class_name = shape['label'] shape_type = shape.get('shape_type', None) class_id = class_name_mapping[class_name] label_mask = shape2mask(img_size[:2], points) label[label_mask] = class_id return label if __name__ == '__main__': args = parse_args() main(args)