From 209cfdaa0e74b09a8208e6c97253eebe774ad749 Mon Sep 17 00:00:00 2001 From: SunAhong1993 <48579383+SunAhong1993@users.noreply.github.com> Date: Thu, 27 Jun 2019 10:21:21 +0800 Subject: [PATCH] modify the crop operator (#2507) * modify the crop operator * modify the judgment * modify the judgment v2 * modify the data_feed.py * modify the readme and add labelme2coco code * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README_cn.md * Update README_cn.md * Update README_cn.md * Update README_cn.md * Update README_cn.md * Update README_cn.md * Update README.md * Update README_cn.md * Update README.md * Update README_cn.md * Update labelme2coco.py * Update labelme2coco.py * Update labelme2coco.py * Update data_feed.py --- .../object_detection/ppdet/data/README.md | 202 ++++++++++++++ .../object_detection/ppdet/data/README_cn.md | 201 ++++++++++++++ .../ppdet/data/tools/labelme2coco.py | 259 ++++++++++++++++++ .../ppdet/data/transform/operators.py | 11 +- 4 files changed, 670 insertions(+), 3 deletions(-) create mode 100644 PaddleCV/object_detection/ppdet/data/README.md create mode 100644 PaddleCV/object_detection/ppdet/data/README_cn.md create mode 100644 PaddleCV/object_detection/ppdet/data/tools/labelme2coco.py diff --git a/PaddleCV/object_detection/ppdet/data/README.md b/PaddleCV/object_detection/ppdet/data/README.md new file mode 100644 index 00000000..34e647b9 --- /dev/null +++ b/PaddleCV/object_detection/ppdet/data/README.md @@ -0,0 +1,202 @@ +## Introduction +This is a Python module used to load and convert data into formats for detection model training, evaluation and inference. The converted sample schema is a tuple of np.ndarrays. For example, the schema of Faster R-CNN training data is: `[(im, im_info, im_id, gt_bbox, gt_class, is_crowd), (...)]`. + +### Implementation +This module is consists of four sub-systems: data parsing, image pre-processing, data conversion and data feeding apis. + +We use `dataset.Dataset` to abstract a set of data samples. For example, `COCO` data contains 3 sets of data for training, validation, and testing respectively. Original data stored in files could be loaded into memory using `dataset.source`; Then make use of `dataset.transform` to process the data; Finally, the batch data could be fetched by the api of `dataset.Reader`. + +Sub-systems introduction: +1. Data prasing +By data parsing, we can get a `dataset.Dataset` instance, whose implementation is located in `dataset.source`. This sub-system is used to parse different data formats, which is easy to add new data format supports. Currently, only following data sources are included: + +- COCO data source +This kind of source is used to load `COCO` data directly, eg: `COCO2017`. It's composed of json files for labeling info and image files. And it's directory structure is as follows: + + ``` + data/coco/ + ├── annotations + │ ├── instances_train2017.json + │ ├── instances_val2017.json + | ... + ├── train2017 + │ ├── 000000000009.jpg + │ ├── 000000580008.jpg + | ... + ├── val2017 + │ ├── 000000000139.jpg + │ ├── 000000000285.jpg + | ... + ``` + +- Pascal VOC data source +This kind of source is used to load `VOC` data directly, eg: `VOC2007`. It's composed of xml files for labeling info and image files. And it's directory structure is as follows: + + + ``` + data/pascalvoc/ + ├──Annotations + │ ├── i000050.jpg + │ ├── 003876.xml + | ... + ├── ImageSets + │ ├──Main + └── train.txt + └── val.txt + └── test.txt + └── dog_train.txt + └── dog_trainval.txt + └── dog_val.txt + └── dog_test.txt + └── ... + │ ├──Layout + └──... + │ ├── Segmentation + └──... + ├── JPEGImages + │ ├── 000050.jpg + │ ├── 003876.jpg + | ... + ``` + + + +- Roidb data source +This kind of source is a normalized data format which only contains a pickle file. The pickle file only has a dictionary which only has a list named 'records' (maybe there is a mapping file for label name to label id named 'canme2id'). You can convert `COCO` or `VOC` data into this format. The pickle file's content is as follows: +```python +(records, catname2clsid) +'records' is list of dict whose structure is: +{ + 'im_file': im_fname, # image file name + 'im_id': im_id, # image id + 'h': im_h, # height of image + 'w': im_w, # width + 'is_crowd': is_crowd, + 'gt_class': gt_class, + 'gt_bbox': gt_bbox, + 'gt_poly': gt_poly, +} +'cname2id' is a dict to map category name to class id + +``` +We also provide the tool to generate the roidb data source in `./tools/`. You can use the follow command to implement. +```python +# --type: the type of original data (xml or json) +# --annotation: the path of file, which contains the name of annotation files +# --save-dir: the save path +# --samples: the number of samples (default is -1, which mean all datas in dataset) +python ./tools/generate_data_for_training.py + --type=json \ + --annotation=./annotations/instances_val2017.json \ + --save-dir=./roidb \ + --samples=-1 +``` + + 2. Image preprocessing + Image preprocessing subsystem includes operations such as image decoding, expanding, cropping, etc. We use `dataset.transform.operator` to unify the implementation, which is convenient for extension. In addition, multiple operators can be combined to form a complex processing pipeline, and used by data transformers in `dataset.transformer`, such as multi-threading to acclerate a complex image data processing. + + 3. Data transformer + The function of the data transformer is used to convert a `dataset.Dataset` to a new `dataset.Dataset`, for example: convert a jpeg image dataset into a decoded and resized dataset. We use the decorator pattern to implement different transformers which are all subclass of `dataset.Dataset`. For example, the `dataset.transform.paralle_map` transformer is for multi-process preprocessing, more transformers can be found in `dataset.transform.transformer`. + + 4. Data feeding apis +To facilitate data pipeline building and data feeding for training, we combine multiple `dataset.Dataset` to form a `dataset.Reader` which can provide data for training, validation and testing respectively. The user only needs to call `Reader.[train|eval|infer]` to get the corresponding data stream. `Reader` supports yaml file to configure data address, preprocessing oprators, acceleration mode, and so on. + + + +The main APIs are as follows: + + + +1. Data parsing + + - `source/coco_loader.py`: Use to parse the COCO dataset. [detail code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/object_detection/ppdet/data/source/coco_loader.py) + - `source/voc_loader.py`: Use to parse the Pascal VOC dataset. [detail code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/object_detection/ppdet/data/source/voc_loader.py) + [Note] When using VOC datasets, if you do not use the default label list, you need to generate `label_list.txt` using `tools/generate_data_for_training.py` (the usage method is same as generating the roidb data source) or provide `label_list.txt` in `data/pascalvoc/ImageSets/Main` firstly. Also set the parameter `use_default_label` to `false` in the configuration file. + - `source/loader.py`: Use to parse the Roidb dataset. [detail code](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/object_detection/ppdet/data/source/loader.py) + +2. Operator + `transform/operators.py`: Contains a variety of data enhancement methods, including: + +``` python +RandomFlipImage: Horizontal flip. +RandomDistort: Distort brightness, contrast, saturation, and hue. +ResizeImage: Adjust the image size according to the specific interpolation method. +RandomInterpImage: Use a random interpolation method to resize the image. +CropImage: Crop image with respect to different scale, aspect ratio, and overlap. +ExpandImage: Put the original image into a larger expanded image which is initialized using image mean. +DecodeImage: Read images in RGB format. +Permute: Arrange the channels of the image and converted to the BGR format. +NormalizeImage: Normalize image pixel values. +NormalizeBox: Normalize the bounding box. +MixupImage: Mixup two images in proportion. +``` +[Note] The mixup operation can refer to[paper](https://arxiv.org/pdf/1710.09412.pdf)。 + +`transform/arrange_sample.py`: Sort the data which need to input the network. +3. Transformer +`transform/post_map.py`: A pre-processing operation for completing batch data, which mainly includes: + +``` python +Randomly adjust the image size of the batch data +Multi-scale adjustment of image size +Padding operation +``` +`transform/transformer.py`: Used to filter useless data and return batch data. +`transform/parallel_map.py`: Used to achieve acceleration. +4. Reader +`reader.py`: Used to combine source and transformer operations, and return batch data according to `max_iter`. +`data_feed.py`: Configure default parameters for `reader.py`. + + + + + +### Usage + +#### Ordinary usage +The function of this module is completed by combining the configuration information in the yaml file. The use of yaml files can be found in the configuration file section. + + - Read data for training + +``` python +ccfg = load_cfg('./config.yml') +coco = Reader(ccfg.DATA, ccfg.TRANSFORM, maxiter=-1) +``` +#### How to use customized dataset? +- Option 1: Convert the dataset to the VOC format or COCO format. +```python + # In ./tools/, the code named labelme2coco.py is provided to convert + # the dataset which is annotatedby Labelme to a COCO dataset. + python ./tools/labelme2coco.py --json_input_dir ./labelme_annos/ + --image_input_dir ./labelme_imgs/ + --output_dir ./cocome/ + --train_proportion 0.8 + --val_proportion 0.2 + --test_proportion 0.0 + # --json_input_dir:The path of json files which are annotated by Labelme. + # --image_input_dir:The path of images. + # --output_dir:The path of coverted COCO dataset. + # --train_proportion:The train proportion of annatation data. + # --val_proportion:The validation proportion of annatation data. + # --test_proportion: The inference proportion of annatation data. +``` +- Option 2: + +1. Following the `./source/coco_loader.py` and `./source/voc_loader.py`, add `./source/XX_loader.py` and implement the `load` function. +2. Add the entry for `./source/XX_loader.py` in the `load` function of `./source/loader.py`. +3. Modify `./source/__init__.py`: + + +```python +if data_cf['type'] in ['VOCSource', 'COCOSource', 'RoiDbSource']: + source_type = 'RoiDbSource' +# Replace the above code with the following code: +if data_cf['type'] in ['VOCSource', 'COCOSource', 'RoiDbSource', 'XXSource']: + source_type = 'RoiDbSource' +``` + +4. In the configure file, define the `type` of `dataset` as `XXSource`。 + +#### How to add data pre-processing? +- If you want to add the enhanced preprocessing of a single image, you can refer to the code of each class in `transform/operators.py`, and create a new class to implement new data enhancement. Also add the name of this preprocessing to the configuration file. +- If you want to add image preprocessing for a single batch, you can refer to the code for each function in `build_post_map` of `transform/post_map.py`, and create a new internal function to implement new batch data preprocessing. Also add the name of this preprocessing to the configuration file. diff --git a/PaddleCV/object_detection/ppdet/data/README_cn.md b/PaddleCV/object_detection/ppdet/data/README_cn.md new file mode 100644 index 00000000..0dfce342 --- /dev/null +++ b/PaddleCV/object_detection/ppdet/data/README_cn.md @@ -0,0 +1,201 @@ +## 介绍 +本模块是一个Python模块,用于加载数据并将其转换成适用于检测模型的训练、验证、测试所需要的格式——由多个np.ndarray组成的tuple数组,例如用于Faster R-CNN模型的训练数据格式为:`[(im, im_info, im_id, gt_bbox, gt_class, is_crowd), (...)]`。 + +### 实现 +该模块内部可分为4个子功能:数据解析、图片预处理、数据转换和数据获取接口。 + +我们采用`dataset.Dataset`表示一份数据,比如`COCO`数据包含3份数据,分别用于训练、验证和测试。原始数据存储与文件中,通过`dataset.source`加载到内存,然后使用`dataset.transform`对数据进行处理转换,最终通过`dataset.Reader`的接口可以获得用于训练、验证和测试的batch数据。 + +子功能介绍: + +1. 数据解析 + 数据解析得到的是`dataset.Dataset`,实现逻辑位于`dataset.source`中。通过它可以实现解析不同格式的数据集,已支持的数据源包括: +- COCO数据源 + 该数据集目前分为COCO2012和COCO2017,主要由json文件和image文件组成,其组织结构如下所示: + + ``` + data/coco/ + ├── annotations + │ ├── instances_train2014.json + │ ├── instances_train2017.json + │ ├── instances_val2014.json + │ ├── instances_val2017.json + | ... + ├── train2017 + │ ├── 000000000009.jpg + │ ├── 000000580008.jpg + | ... + ├── val2017 + │ ├── 000000000139.jpg + │ ├── 000000000285.jpg + | ... + ``` + + +- Pascal VOC数据源 + 该数据集目前分为VOC2007和VOC2012,主要由xml文件和image文件组成,其组织结构如下所示: + + + ``` + data/pascalvoc/ + ├──Annotations + │ ├── i000050.jpg + │ ├── 003876.xml + | ... + ├── ImageSets + │ ├──Main + └── train.txt + └── val.txt + └── test.txt + └── dog_train.txt + └── dog_trainval.txt + └── dog_val.txt + └── dog_test.txt + └── ... + │ ├──Layout + └──... + │ ├── Segmentation + └──... + ├── JPEGImages + │ ├── 000050.jpg + │ ├── 003876.jpg + | ... + ``` + + + +- Roidb数据源 + 该数据集主要由COCO数据集和Pascal VOC数据集转换而成的pickle文件,包含一个dict,而dict中只包含一个命名为‘records’的list(可能还有一个命名为‘cname2cid’的字典),其内容如下所示: +```python +(records, catname2clsid) +'records'是一个list并且它的结构如下: +{ + 'im_file': im_fname, # 图像文件名 + 'im_id': im_id, # 图像id + 'h': im_h, # 图像高度 + 'w': im_w, # 图像宽度 + 'is_crowd': is_crowd, # 是否重叠 + 'gt_class': gt_class, # 真实框类别 + 'gt_bbox': gt_bbox, # 真实框坐标 + 'gt_poly': gt_poly, # 多边形坐标 +} +'cname2id'是一个dict,保存了类别名到id的映射 + +``` +我们在`./tools/`中提供了一个生成roidb数据集的代码,可以通过下面命令实现该功能。 +```python +# --type: 原始数据集的类别(只能是xml或者json) +# --annotation: 一个包含所需标注文件名的文件的路径 +# --save-dir: 保存路径 +# --samples: sample的个数(默认是-1,代表使用所有sample) +python ./tools/generate_data_for_training.py + --type=json \ + --annotation=./annotations/instances_val2017.json \ + --save-dir=./roidb \ + --samples=-1 +``` + 2. 图片预处理 + 图片预处理通过包括图片解码、缩放、裁剪等操作,我们采用`dataset.transform.operator`算子的方式来统一实现,这样能方便扩展。此外,多个算子还可以组合形成复杂的处理流程, 并被`dataset.transformer`中的转换器使用,比如多线程完成一个复杂的预处理流程。 + + 3. 数据转换器 + 数据转换器的功能是完成对某个`dataset.Dataset`进行转换处理,从而得到一个新的`dataset.Dataset`。我们采用装饰器模式实现各种不同的`dataset.transform.transformer`。比如用于多进程预处理的`dataset.transform.paralle_map`转换器。 + + 4. 数据获取接口 + 为方便训练时的数据获取,我们将多个`dataset.Dataset`组合在一起构成一个`dataset.Reader`为用户提供数据,用户只需要调用`Reader.[train|eval|infer]`即可获得对应的数据流。`Reader`支持yaml文件配置数据地址、预处理过程、加速方式等。 + +主要的APIs如下: + + + + +1. 数据解析 + + - `source/coco_loader.py`:用于解析COCO数据集。[详见代码](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/object_detection/ppdet/data/source/coco_loader.py) + - `source/voc_loader.py`:用于解析Pascal VOC数据集。[详见代码](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/object_detection/ppdet/data/source/voc_loader.py) + [注意]在使用VOC数据集时,若不使用默认的label列表,则需要先使用`tools/generate_data_for_training.py`生成`label_list.txt`(使用方式与数据解析中的roidb数据集获取过程一致),或提供`label_list.txt`放置于`data/pascalvoc/ImageSets/Main`中;同时在配置文件中设置参数`use_default_label`为`true`。 + - `source/loader.py`:用于解析Roidb数据集。[详见代码](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/object_detection/ppdet/data/source/loader.py) + +2. 算子 + `transform/operators.py`:包含多种数据增强方式,主要包括: + +``` python +RandomFlipImage:水平翻转。 +RandomDistort:随机扰动图片亮度、对比度、饱和度和色相。 +ResizeImage:根据特定的插值方式调整图像大小。 +RandomInterpImage:使用随机的插值方式调整图像大小。 +CropImage:根据缩放比例、长宽比例两个参数生成若干候选框,再依据这些候选框和标注框的面积交并比(IoU)挑选出符合要求的裁剪结果。 +ExpandImage:将原始图片放进一张使用像素均值填充(随后会在减均值操作中减掉)的扩张图中,再对此图进行裁剪、缩放和翻转。 +DecodeImage:以RGB格式读取图像。 +Permute:对图像的通道进行排列并转为BGR格式。 +NormalizeImage:对图像像素值进行归一化。 +NormalizeBox:对bounding box进行归一化。 +MixupImage:按比例叠加两张图像。 +``` +[注意]:Mixup的操作可参考[论文](https://arxiv.org/pdf/1710.09412.pdf)。 + +`transform/arrange_sample.py`:实现对输入网络数据的排序。 +3. 转换 +`transform/post_map.py`:用于完成批数据的预处理操作,其主要包括: + +``` python +随机调整批数据的图像大小 +多尺度调整图像大小 +padding操作 +``` +`transform/transformer.py`:用于过滤无用的数据,并返回批数据。 +`transform/parallel_map.py`:用于实现加速。 +4. 读取 +`reader.py`:用于组合source和transformer操作,根据`max_iter`返回batch数据。 +`data_feed.py`: 用于配置 `reader.py`中所需的默认参数. + + + + +### 使用 +#### 常规使用 +结合yaml文件中的配置信息,完成本模块的功能。yaml文件的使用可以参见配置文件部分。 + + - 读取用于训练的数据 + +``` python +ccfg = load_cfg('./config.yml') +coco = Reader(ccfg.DATA, ccfg.TRANSFORM, maxiter=-1) +``` +#### 如何使用自定义数据集? + +- 选择1:将数据集转换为VOC格式或者COCO格式。 +```python + # 在./tools/中提供了labelme2coco.py用于将labelme标注的数据集转换为COCO数据集 + python ./tools/labelme2coco.py --json_input_dir ./labelme_annos/ + --image_input_dir ./labelme_imgs/ + --output_dir ./cocome/ + --train_proportion 0.8 + --val_proportion 0.2 + --test_proportion 0.0 + # --json_input_dir:使用labelme标注的json文件所在文件夹 + # --image_input_dir:图像文件所在文件夹 + # --output_dir:转换后的COCO格式数据集存放位置 + # --train_proportion:标注数据中用于train的比例 + # --val_proportion:标注数据中用于validation的比例 + # --test_proportion: 标注数据中用于infer的比例 +``` +- 选择2: + +1. 仿照`./source/coco_loader.py`和`./source/voc_loader.py`,添加`./source/XX_loader.py`并实现`load`函数。 +2. 在`./source/loader.py`的`load`函数中添加使用`./source/XX_loader.py`的入口。 +3. 修改`./source/__init__.py`: + + +```python +if data_cf['type'] in ['VOCSource', 'COCOSource', 'RoiDbSource']: + source_type = 'RoiDbSource' +# 将上述代码替换为如下代码: +if data_cf['type'] in ['VOCSource', 'COCOSource', 'RoiDbSource', 'XXSource']: + source_type = 'RoiDbSource' +``` + +4. 在配置文件中修改`dataset`下的`type`为`XXSource`。 + +#### 如何增加数据预处理? +- 若增加单张图像的增强预处理,可在`transform/operators.py`中参考每个类的代码,新建一个类来实现新的数据增强;同时在配置文件中增加该预处理。 +- 若增加单个batch的图像预处理,可在`transform/post_map.py`中参考`build_post_map`中每个函数的代码,新建一个内部函数来实现新的批数据预处理;同时在配置文件中增加该预处理。 diff --git a/PaddleCV/object_detection/ppdet/data/tools/labelme2coco.py b/PaddleCV/object_detection/ppdet/data/tools/labelme2coco.py new file mode 100644 index 00000000..bf5dc32b --- /dev/null +++ b/PaddleCV/object_detection/ppdet/data/tools/labelme2coco.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python +# coding: utf-8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import glob +import json +import os +import os.path as osp +import sys +import shutil + +import numpy as np +import PIL.ImageDraw + + +class MyEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + else: + return super(MyEncoder, self).default(obj) + + +def getbbox(self, points): + polygons = points + mask = self.polygons_to_mask([self.height, self.width], polygons) + return self.mask2box(mask) + + +def images(data, num): + image = {} + image['height'] = data['imageHeight'] + image['width'] = data['imageWidth'] + image['id'] = num + 1 + image['file_name'] = data['imagePath'].split('/')[-1] + return image + + +def categories(label, labels_list): + category = {} + category['supercategory'] = 'component' + category['id'] = len(labels_list) + 1 + category['name'] = label + return category + + +def annotations_rectangle(points, label, num, label_to_num): + annotation = {} + seg_points = np.asarray(points).copy() + seg_points[1, :] = np.asarray(points)[2, :] + seg_points[2, :] = np.asarray(points)[1, :] + annotation['segmentation'] = [list(seg_points.flatten())] + annotation['iscrowd'] = 0 + annotation['image_id'] = num + 1 + annotation['bbox'] = list( + map(float, [ + points[0][0], points[0][1], points[1][0] - points[0][0], points[1][ + 1] - points[0][1] + ])) + annotation['area'] = annotation['bbox'][2] * annotation['bbox'][3] + annotation['category_id'] = label_to_num[label] + annotation['id'] = num + 1 + return annotation + + +def annotations_polygon(height, width, points, label, num, label_to_num): + annotation = {} + annotation['segmentation'] = [list(np.asarray(points).flatten())] + annotation['iscrowd'] = 0 + annotation['image_id'] = num + 1 + annotation['bbox'] = list(map(float, get_bbox(height, width, points))) + annotation['area'] = annotation['bbox'][2] * annotation['bbox'][3] + annotation['category_id'] = label_to_num[label] + annotation['id'] = num + 1 + return annotation + + +def get_bbox(height, width, points): + polygons = points + mask = np.zeros([height, width], dtype=np.uint8) + mask = PIL.Image.fromarray(mask) + xy = list(map(tuple, polygons)) + PIL.ImageDraw.Draw(mask).polygon(xy=xy, outline=1, fill=1) + mask = np.array(mask, dtype=bool) + index = np.argwhere(mask == 1) + rows = index[:, 0] + clos = index[:, 1] + left_top_r = np.min(rows) + left_top_c = np.min(clos) + right_bottom_r = np.max(rows) + right_bottom_c = np.max(clos) + return [ + left_top_c, left_top_r, right_bottom_c - left_top_c, + right_bottom_r - left_top_r + ] + + +def deal_json(img_path, json_path): + data_coco = {} + label_to_num = {} + images_list = [] + categories_list = [] + annotations_list = [] + labels_list = [] + num = -1 + for img_file in os.listdir(img_path): + img_label = img_file.split('.')[0] + label_file = osp.join(json_path, img_label + '.json') + print('Generating dataset from:', label_file) + num = num + 1 + with open(label_file) as f: + data = json.load(f) + images_list.append(images(data, num)) + for shapes in data['shapes']: + label = shapes['label'] + if label not in labels_list: + categories_list.append(categories(label, labels_list)) + labels_list.append(label) + label_to_num[label] = len(labels_list) + points = shapes['points'] + p_type = shapes['shape_type'] + if p_type == 'polygon': + annotations_list.append( + annotations_polygon(data['imageHeight'], data[ + 'imageWidth'], points, label, num, label_to_num)) + + if p_type == 'rectangle': + points.append([points[0][0], points[1][1]]) + points.append([points[1][0], points[0][1]]) + annotations_list.append( + annotations_rectangle(points, label, num, label_to_num)) + data_coco['images'] = images_list + data_coco['categories'] = categories_list + data_coco['annotations'] = annotations_list + return data_coco + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--json_input_dir', help='input annotated directory') + parser.add_argument('--image_input_dir', help='image directory') + parser.add_argument( + '--output_dir', help='output dataset directory', default='../../../') + parser.add_argument( + '--train_proportion', + help='the proportion of train dataset', + type=float, + default=1.0) + parser.add_argument( + '--val_proportion', + help='the proportion of validation dataset', + type=float, + default=0.0) + parser.add_argument( + '--test_proportion', + help='the proportion of test dataset', + type=float, + default=0.0) + args = parser.parse_args() + try: + assert os.path.exists(args.json_input_dir) + except AssertionError as e: + print('The json folder does not exist!') + os._exit(0) + try: + assert os.path.exists(args.image_input_dir) + except AssertionError as e: + print('The image folder does not exist!') + os._exit(0) + try: + assert args.train_proportion + args.val_proportion + args.test_proportion == 1.0 + except AssertionError as e: + print( + 'The sum of pqoportion of training, validation and test datase must be 1!' + ) + os._exit(0) + + # Allocate the dataset. + total_num = len(glob.glob(osp.join(args.json_input_dir, '*.json'))) + if args.train_proportion != 0: + train_num = int(total_num * args.train_proportion) + os.makedirs(args.output_dir + '/train') + else: + train_num = 0 + if args.val_proportion == 0.0: + val_num = 0 + test_num = total_num - train_num + if args.test_proportion != 0.0: + os.makedirs(args.output_dir + '/test') + else: + val_num = int(total_num * args.val_proportion) + test_num = total_num - train_num - val_num + os.makedirs(args.output_dir + '/val') + if args.test_proportion != 0.0: + os.makedirs(args.output_dir + '/test') + count = 1 + for img_name in os.listdir(args.image_input_dir): + if count <= train_num: + shutil.copyfile( + osp.join(args.image_input_dir, img_name), + osp.join(args.output_dir + '/train/', img_name)) + else: + if count <= train_num + val_num: + shutil.copyfile( + osp.join(args.image_input_dir, img_name), + osp.join(args.output_dir + '/val/', img_name)) + else: + shutil.copyfile( + osp.join(args.image_input_dir, img_name), + osp.join(args.output_dir + '/test/', img_name)) + count = count + 1 + + # Deal with the json files. + if not os.path.exists(args.output_dir + '/annotations'): + os.makedirs(args.output_dir + '/annotations') + if args.train_proportion != 0: + train_data_coco = deal_json(args.output_dir + '/train', + args.json_input_dir) + train_json_path = osp.join(args.output_dir + '/annotations', + 'instance_train.json') + json.dump( + train_data_coco, + open(train_json_path, 'w'), + indent=4, + cls=MyEncoder) + if args.val_proportion != 0: + val_data_coco = deal_json(args.output_dir + '/val', args.json_input_dir) + val_json_path = osp.join(args.output_dir + '/annotations', + 'instance_val.json') + json.dump( + val_data_coco, open(val_json_path, 'w'), indent=4, cls=MyEncoder) + if args.test_proportion != 0: + test_data_coco = deal_json(args.output_dir + '/test', + args.json_input_dir) + test_json_path = osp.join(args.output_dir + '/annotations', + 'instance_test.json') + json.dump( + test_data_coco, open(test_json_path, 'w'), indent=4, cls=MyEncoder) + +if __name__ == '__main__': + main() diff --git a/PaddleCV/object_detection/ppdet/data/transform/operators.py b/PaddleCV/object_detection/ppdet/data/transform/operators.py index 76a02b64..7d86cf8c 100644 --- a/PaddleCV/object_detection/ppdet/data/transform/operators.py +++ b/PaddleCV/object_detection/ppdet/data/transform/operators.py @@ -499,11 +499,14 @@ class ExpandImage(BaseOperator): @register_op class CropImage(BaseOperator): - def __init__(self, batch_sampler, satisfy_all=False): + def __init__(self, batch_sampler, satisfy_all=False, avoid_no_bbox=True): """ Args: batch_sampler (list): Multiple sets of different parameters for cropping. + satisfy_all (bool): whether all boxes must satisfy. + avoid_no_bbox (bool): whether to to avoid the + situation where the box does not appear. e.g.[[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0], [1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 1.0], [1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 1.0], @@ -518,6 +521,7 @@ class CropImage(BaseOperator): super(CropImage, self).__init__() self.batch_sampler = batch_sampler self.satisfy_all = satisfy_all + self.avoid_no_bbox = avoid_no_bbox def __call__(self, sample, context): """ @@ -556,8 +560,9 @@ class CropImage(BaseOperator): sample_bbox = clip_bbox(sample_bbox) crop_bbox, crop_class, crop_score = \ filter_and_process(sample_bbox, gt_bbox, gt_class, gt_score) - if len(crop_bbox) <= 1: - continue + if self.avoid_no_bbox: + if len(crop_bbox) < 1: + continue xmin = int(sample_bbox[0] * im_width) xmax = int(sample_bbox[2] * im_width) ymin = int(sample_bbox[1] * im_height) -- GitLab