# torchvision.datasets
`torchvision.datasets` includes the following datasets:

- MNIST
- COCO (Captioning and Detection)
- LSUN Classification
- ImageFolder
- Imagenet-12
- CIFAR10 and CIFAR100
- STL10

All `Datasets` provide the following API:

`__getitem__`
`__len__`
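
As a rough illustration of this two-method protocol, here is a minimal, dependency-free sketch. `ToyDataset` and its sample data are hypothetical and not part of torchvision; a real dataset would subclass `torch.utils.data.Dataset` and return images rather than strings.

```python
class ToyDataset:
    """Hypothetical stand-in for a dataset; not part of torchvision."""

    def __init__(self, samples):
        # Each sample is a (data, target) pair.
        self.samples = list(samples)

    def __getitem__(self, index):
        # Return the sample stored at the given position.
        return self.samples[index]

    def __len__(self):
        # Return the total number of samples.
        return len(self.samples)


toy = ToyDataset([("img0", 0), ("img1", 1), ("img2", 0)])
print(len(toy))  # number of samples
print(toy[1])    # second (data, target) pair
```

Anything that supports indexing and `len()` in this way can be iterated sample by sample, which is what a loader relies on.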

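Since all of the above `Datasets` are subclasses of `torch.utils.data.Dataset`, they can be passed to `torch.utils.data.DataLoader`, which can load samples in parallel using multiple worker processes (Python `multiprocessing`).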

For example:
`torch.utils.data.DataLoader(coco_cap, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)`

Each dataset's constructor has a slightly different API as needed, but they all take the following keyword arguments:
- `transform`: a function that takes in the original image and returns a transformed version (see the `torchvision.transforms` section for details).

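- `target_transform`: a function that takes in the `target` and returns a transformed version of it. For example, it can take in a caption string and return the indices of its words.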
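
Both arguments are plain callables, so the idea can be sketched without any library. In the sketch below, `vocab`, `caption_to_indices`, and `scale_pixels` are hypothetical names chosen for illustration, and a list of pixel values stands in for a real image.

```python
# Hypothetical vocabulary; a real one would be built from the training captions.
vocab = {"<unk>": 0, "a": 1, "plane": 2, "mountain": 3}


def caption_to_indices(caption):
    """Example target_transform: map a caption string to word indices."""
    return [vocab.get(word, vocab["<unk>"]) for word in caption.lower().split()]


def scale_pixels(pixels):
    """Example transform: map raw 0-255 pixel values to floats in [0, 1]."""
    return [p / 255.0 for p in pixels]


sample_image = [0, 255, 51]  # stand-in for an image
sample_caption = "A plane over a mountain"

print(scale_pixels(sample_image))
print(caption_to_indices(sample_caption))
```

A dataset would apply `transform` to the image and `target_transform` to the target each time a sample is fetched.
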
## MNIST
```python
dset.MNIST(root, train=True, transform=None, target_transform=None, download=False)
```
Parameters:
- root : root directory containing `processed/training.pt` and `processed/test.pt`
- train : `True` = training set, `False` = test set
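- download : `True` = download the dataset from the internet and put it in the `root` directory. If the dataset has already been downloaded, the processed data (see the related functions in `mnist.py`) is placed in the `processed` folder.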

## COCO
Requires the [COCO API](https://github.com/pdollar/coco/tree/master/PythonAPI) to be installed.

### Captions:
```python
dset.CocoCaptions(root="dir where images are", annFile="json annotation file", [transform, target_transform])
```
Example:
```python
import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are',
                        annFile = 'json annotation file',
                        transform=transforms.ToTensor())

print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample

print("Image Size: ", img.size())
print(target)
```
Output:
```
Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']
```
### Detection:
```python
dset.CocoDetection(root="dir where images are", annFile="json annotation file", [transform, target_transform])
```
## LSUN
```python
dset.LSUN(db_path, classes='train', [transform, target_transform])
```
Parameters:
- db_path = root directory of the dataset files
- classes = 'train' (all categories, training set), 'val' (all categories, validation set), 'test' (all categories, test set), or ['bedroom\_train', 'church\_train', …] (a list of categories to load)
## ImageFolder
A generic data loader where the images are arranged in this way:
```
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
```
```python
dset.ImageFolder(root="root folder path", [transform, target_transform])
```
It has the following members:

- self.classes - a list of the class names
- self.class_to_idx - a mapping from class name to class index
- self.imgs - a list of (img-path, class) tuples
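
To make the relationship between the directory layout and these three members concrete, here is a minimal sketch using only the standard library. `scan_image_folder` is a hypothetical helper, not torchvision's implementation, and assigning indices in sorted class-name order is an assumption of this sketch.

```python
import os
import tempfile


def scan_image_folder(root):
    """Return (classes, class_to_idx, imgs) for a root/<class>/<file> layout."""
    # Each sub-directory of root is treated as one class (sorted order assumed).
    classes = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    imgs = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            imgs.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return classes, class_to_idx, imgs


# Build a throwaway directory tree mirroring the layout shown above.
with tempfile.TemporaryDirectory() as root:
    for cls, fname in [("dog", "xxx.png"), ("dog", "xxy.png"), ("cat", "123.png")]:
        os.makedirs(os.path.join(root, cls), exist_ok=True)
        open(os.path.join(root, cls, fname), "wb").close()
    classes, class_to_idx, imgs = scan_image_folder(root)

print(classes)
print(class_to_idx)
print(len(imgs))
```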

## Imagenet-12
This is simply implemented with an ImageFolder dataset.

The data is preprocessed [as described here](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset)

[Here is an example](https://github.com/pytorch/examples/blob/27e2a46c1d1505324032b1d94fc6ce24d5b67e97/imagenet/main.py#L48-L62)

## CIFAR
```python
dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)

dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)
```
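Parameters:
- root : root directory of the dataset, containing the `cifar-10-batches-py` folder
- train : `True` = training set, `False` = test set
- download : `True` = download the dataset from the internet and put it in the `root` directory. If the dataset has already been downloaded, nothing is done.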
## STL10
```python
dset.STL10(root, split='train', transform=None, target_transform=None, download=False)
```
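Parameters:
- root : root directory of the dataset, containing the `stl10_binary` folder
- split : 'train' = training set, 'test' = test set, 'unlabeled' = unlabeled set, 'train+unlabeled' = training + unlabeled set (unlabeled samples are marked with label -1)
- download : `True` = download the dataset from the internet and put it in the `root` directory. If the dataset has already been downloaded, nothing is done.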
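
Because the 'train+unlabeled' split mixes labeled samples with samples marked -1, downstream code usually needs to tell the two apart. The sketch below shows one way to do that on made-up toy pairs; the sample names and labels are hypothetical and are not real STL10 data.

```python
# Toy (sample, label) pairs; -1 marks an unlabeled sample, as described above.
samples = [("img_a", 3), ("img_b", -1), ("img_c", 7), ("img_d", -1)]

# Keep the labeled pairs for supervised training.
labeled = [(x, y) for x, y in samples if y != -1]

# Keep only the inputs of the unlabeled samples.
unlabeled = [x for x, y in samples if y == -1]

print(labeled)
print(unlabeled)
```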