Unverified commit 2cabaccf authored by Guanghua Yu, committed by GitHub

update image_folder loader in yolo act (#1393)

Parent 8c46869c
......@@ -62,25 +62,47 @@ pip install paddleslim==2.3.3
#### 3.2 Prepare the Dataset
This example runs the auto-compression experiment on the COCO dataset by default. You can download [Train](http://images.cocodataset.org/zips/train2017.zip), [Val](http://images.cocodataset.org/zips/val2017.zip) and the [annotation](http://images.cocodataset.org/annotations/annotations_trainval2017.zip) from the [MS COCO website](https://cocodataset.org).
**Prepare the data with either method (1) or (2) below.**
The directory layout is as follows:
```
dataset/coco/
├── annotations
│ ├── instances_train2017.json
│ ├── instances_val2017.json
│ | ...
├── train2017
│ ├── 000000000009.jpg
│ ├── 000000580008.jpg
│ | ...
├── val2017
│ ├── 000000000139.jpg
│ ├── 000000000285.jpg
```
- (1) Unlabeled images are supported: pass an image folder directly, but the model's mAP cannot be evaluated.
Set the `image_path` in [config](./configs) to an image folder from your real inference scenario. Choose the number of images according to the dataset size and try to cover all deployment scenarios. The loader sketch after these config examples shows how each option is consumed.
```yaml
Global:
image_path: dataset/coco/val2017
```
- (2) Loading a COCO-format dataset is supported, **and the model's mAP can be evaluated during compression**.
You can download [Train](http://images.cocodataset.org/zips/train2017.zip), [Val](http://images.cocodataset.org/zips/val2017.zip) and the [annotation](http://images.cocodataset.org/annotations/annotations_trainval2017.zip) from the [MS COCO website](https://cocodataset.org).
For a custom dataset, prepare the data in the same COCO format.
The directory layout is as follows:
```
dataset/coco/
├── annotations
│ ├── instances_train2017.json
│ ├── instances_val2017.json
│ | ...
├── train2017
│ ├── 000000000009.jpg
│ ├── 000000580008.jpg
│ | ...
├── val2017
│ ├── 000000000139.jpg
│ ├── 000000000285.jpg
```
If you use a custom dataset, prepare the data in the COCO format shown above.
After the dataset is ready, set the `coco_dataset_dir` path in [config](./configs) as shown below.
```yaml
Global:
coco_dataset_dir: dataset/coco/
coco_train_image_dir: train2017
coco_train_anno_path: annotations/instances_train2017.json
coco_val_image_dir: val2017
coco_val_anno_path: annotations/instances_val2017.json
```
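
A minimal sketch of how the run script in this example consumes these settings, assuming the config file name below (illustrative only) and the `COCOTrainDataset` / `yolo_image_preprocess` helpers from this example's `dataset.py`: if `image_path` is set, unlabeled images are read with `paddle.vision.datasets.ImageFolder`; otherwise the `coco_*` keys build a COCO-format training set.
```python
import paddle
from paddleslim.common import load_config
from dataset import COCOTrainDataset, yolo_image_preprocess

# Hypothetical config path; use the YAML that matches the model you compress.
global_config = load_config('./configs/yolov5s_qat_dis.yaml')["Global"]

if global_config['image_path'] != 'None':
    # Option (1): unlabeled image folder, no mAP evaluation.
    paddle.vision.image.set_image_backend('cv2')
    train_dataset = paddle.vision.datasets.ImageFolder(
        global_config['image_path'], transform=yolo_image_preprocess)
else:
    # Option (2): COCO-format dataset, mAP can be evaluated during compression.
    train_dataset = COCOTrainDataset(
        dataset_dir=global_config['coco_dataset_dir'],
        image_dir=global_config['coco_train_image_dir'],
        anno_path=global_config['coco_train_anno_path'])

train_loader = paddle.io.DataLoader(
    train_dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0)
```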
#### 3.3 Prepare the Inference Model
......
Global:
model_dir: ./yolov5s.onnx
dataset_dir: dataset/coco/
train_image_dir: train2017
val_image_dir: val2017
train_anno_path: annotations/instances_train2017.json
val_anno_path: annotations/instances_val2017.json
Evaluation: True
image_path: None # If image_path is set, compression trains directly on the unlabeled images and the COCO dataset paths below are not needed.
coco_dataset_dir: dataset/coco/
coco_train_image_dir: train2017
coco_train_anno_path: annotations/instances_train2017.json
coco_val_image_dir: val2017
coco_val_anno_path: annotations/instances_val2017.json
arch: YOLOv5
Distillation:
......
Global:
model_dir: ./yolov6s.onnx
dataset_dir: dataset/coco/
train_image_dir: train2017
val_image_dir: val2017
train_anno_path: annotations/instances_train2017.json
val_anno_path: annotations/instances_val2017.json
Evaluation: True
image_path: None # If image_path is set, compression trains directly on the unlabeled images and the COCO dataset paths below are not needed.
coco_dataset_dir: dataset/coco/
coco_train_image_dir: train2017
coco_train_anno_path: annotations/instances_train2017.json
coco_val_image_dir: val2017
coco_val_anno_path: annotations/instances_val2017.json
arch: YOLOv6
Distillation:
......
Global:
model_dir: ./yolov7.onnx
dataset_dir: dataset/coco/
train_image_dir: train2017
val_image_dir: val2017
train_anno_path: annotations/instances_train2017.json
val_anno_path: annotations/instances_val2017.json
Evaluation: True
image_path: None # If image_path is set, compression trains directly on the unlabeled images and the COCO dataset paths below are not needed.
coco_dataset_dir: dataset/coco/
coco_train_image_dir: train2017
coco_train_anno_path: annotations/instances_train2017.json
coco_val_image_dir: val2017
coco_val_anno_path: annotations/instances_val2017.json
arch: YOLOv7
Distillation:
......
Global:
model_dir: ./yolov7-tiny.onnx
dataset_dir: dataset/coco/
train_image_dir: train2017
val_image_dir: val2017
train_anno_path: annotations/instances_train2017.json
val_anno_path: annotations/instances_val2017.json
Evaluation: True
image_path: None # If image_path is set, compression trains directly on the unlabeled images and the COCO dataset paths below are not needed.
coco_dataset_dir: dataset/coco/
coco_train_image_dir: train2017
coco_train_anno_path: annotations/instances_train2017.json
coco_val_image_dir: val2017
coco_val_anno_path: annotations/instances_val2017.json
arch: YOLOv7
Distillation:
......
from pycocotools.coco import COCO
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import os
import numpy as np
......@@ -12,6 +25,7 @@ class COCOValDataset(paddle.io.Dataset):
anno_path=None,
img_size=[640, 640],
input_name='x2paddle_images'):
from pycocotools.coco import COCO
self.dataset_dir = dataset_dir
self.image_dir = image_dir
self.img_size = img_size
......@@ -113,3 +127,39 @@ class COCOTrainDataset(COCOValDataset):
img = self._get_img_data_from_img_id(img_id)
img, scale_factor = self.image_preprocess(img, self.img_size)
return {self.input_name: img}
def _generate_scale(im, target_shape):
    # Scale so the short side matches the target short side, capped so the
    # long side does not exceed the target long side (aspect ratio preserved).
    origin_shape = im.shape[:2]
    im_size_min = np.min(origin_shape)
    im_size_max = np.max(origin_shape)
    target_size_min = np.min(target_shape)
    target_size_max = np.max(target_shape)
    im_scale = float(target_size_min) / float(im_size_min)
    if np.round(im_scale * im_size_max) > target_size_max:
        im_scale = float(target_size_max) / float(im_size_max)
    im_scale_x = im_scale
    im_scale_y = im_scale
    return im_scale_y, im_scale_x


def yolo_image_preprocess(img, target_shape=[640, 640]):
    # Resize image while keeping the aspect ratio
    im_scale_y, im_scale_x = _generate_scale(img, target_shape)
    img = cv2.resize(
        img,
        None,
        None,
        fx=im_scale_x,
        fy=im_scale_y,
        interpolation=cv2.INTER_LINEAR)
    # Pad the remainder with the value 114 (YOLO letterbox convention)
    im_h, im_w = img.shape[:2]
    h, w = target_shape[:]
    if h != im_h or w != im_w:
        canvas = np.ones((h, w, 3), dtype=np.float32)
        canvas *= np.array([114.0, 114.0, 114.0], dtype=np.float32)
        canvas[0:im_h, 0:im_w, :] = img.astype(np.float32)
        img = canvas
    # Normalize to [0, 1] and convert HWC -> CHW
    img = np.transpose(img / 255, [2, 0, 1])
    return img.astype(np.float32)
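

# Illustrative usage sketch (comments only, not part of the original file):
#   im = cv2.imread("example.jpg")   # hypothetical path; HWC, BGR, uint8
#   x = yolo_image_preprocess(im)    # float32, CHW, letterboxed to 640x640
#   batch = np.expand_dims(x, 0)     # NCHW batch, ready for the compression loader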
......@@ -80,9 +80,9 @@ def main():
global val_loader
dataset = COCOValDataset(
dataset_dir=global_config['dataset_dir'],
image_dir=global_config['val_image_dir'],
anno_path=global_config['val_anno_path'])
dataset_dir=global_config['coco_dataset_dir'],
image_dir=global_config['coco_val_image_dir'],
anno_path=global_config['coco_val_anno_path'])
global anno_file
anno_file = dataset.ann_file
val_loader = paddle.io.DataLoader(
......
......@@ -20,7 +20,7 @@ from tqdm import tqdm
import paddle
from paddleslim.common import load_config
from paddleslim.auto_compression import AutoCompression
from dataset import COCOValDataset, COCOTrainDataset
from dataset import COCOValDataset, COCOTrainDataset, yolo_image_preprocess
from post_process import YOLOPostProcess, coco_metric
......@@ -42,12 +42,18 @@ def argsparser():
type=str,
default='gpu',
help="which device used to compress.")
parser.add_argument(
'--eval', type=bool, default=False, help="whether to run evaluation.")
return parser
def reader_wrapper(reader, input_name='x2paddle_images'):
    # Wrap the image-folder DataLoader so each batch is yielded as a feed
    # dict keyed by the model's input name.
    def gen():
        for data in reader:
            yield {input_name: data[0]}

    return gen
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
bboxes_list, bbox_nums_list, image_id_list = [], [], []
with tqdm(
......@@ -79,32 +85,45 @@ def main():
global_config = all_config["Global"]
input_name = 'x2paddle_image_arrays' if global_config[
'arch'] == 'YOLOv6' else 'x2paddle_images'
dataset = COCOTrainDataset(
dataset_dir=global_config['dataset_dir'],
image_dir=global_config['train_image_dir'],
anno_path=global_config['train_anno_path'],
input_name=input_name)
train_loader = paddle.io.DataLoader(
dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0)
if 'Evaluation' in global_config.keys() and global_config[
'Evaluation'] and paddle.distributed.get_rank() == 0:
eval_func = eval_function
global val_loader
dataset = COCOValDataset(
dataset_dir=global_config['dataset_dir'],
image_dir=global_config['val_image_dir'],
anno_path=global_config['val_anno_path'])
global anno_file
anno_file = dataset.ann_file
val_loader = paddle.io.DataLoader(
dataset,
if global_config['image_path'] != 'None':
assert os.path.exists(global_config['image_path'])
paddle.vision.image.set_image_backend('cv2')
train_dataset = paddle.vision.datasets.ImageFolder(
global_config['image_path'], transform=yolo_image_preprocess)
train_loader = paddle.io.DataLoader(
train_dataset,
batch_size=1,
shuffle=False,
drop_last=False,
shuffle=True,
drop_last=True,
num_workers=0)
else:
train_loader = reader_wrapper(train_loader, input_name=input_name)
eval_func = None
else:
dataset = COCOTrainDataset(
dataset_dir=global_config['coco_dataset_dir'],
image_dir=global_config['coco_train_image_dir'],
anno_path=global_config['coco_train_anno_path'],
input_name=input_name)
train_loader = paddle.io.DataLoader(
dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0)
if paddle.distributed.get_rank() == 0:
eval_func = eval_function
global val_loader
dataset = COCOValDataset(
dataset_dir=global_config['coco_dataset_dir'],
image_dir=global_config['coco_val_image_dir'],
anno_path=global_config['coco_val_anno_path'])
global anno_file
anno_file = dataset.ann_file
val_loader = paddle.io.DataLoader(
dataset,
batch_size=1,
shuffle=False,
drop_last=False,
num_workers=0)
else:
eval_func = None
ac = AutoCompression(
model_dir=global_config["model_dir"],
......