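{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Loading Datasets\n", "\n", "## Overview\n", "\n", "MindSpore can load common standard datasets, datasets in specific data formats, and custom datasets. Before loading a dataset, import the required library `mindspore.dataset`.\n", "\n", "The following sections use a common dataset (CIFAR-10), a dataset in a specific format, and a custom dataset as examples to show how MindSpore loads data." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overall Process\n", "\n", "1. Preparation: download the datasets required for this walkthrough.\n", "2. Load a common dataset and print the results, using the CIFAR-10 binary dataset as an example.\n", "3. Load a dataset in a specific format and print the results, using a MindRecord dataset as an example.\n", "4. Load a custom dataset and print the results." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preparation\n", "\n", "### Importing the `mindspore.dataset` Module" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import mindspore.dataset as ds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Downloading the Required Datasets\n", "\n", "1. In the current `notebook` working directory, create the `./datasets/cifar-10` directory to store the datasets.\n", "2. In the current `notebook` working directory, create the `./datasets/mindrecord` directory to store the MindRecord dataset files produced by the later conversion.\n", "3. Download the [CIFAR-10 binary dataset](https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz) and extract it to the `./datasets/cifar-10/cifar-10-batches-bin` directory.\n", "4. 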
Download the [CIFAR-10 Python dataset](http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz) and extract it to the `./datasets/cifar-10/cifar-10-batches-py` directory.\n", "\n", "   The `datasets` directory in the current `notebook` working directory now has the following structure:\n", "\n", "   ```shell\n", "   $ tree datasets\n", "   datasets\n", "   ├── cifar-10\n", "   │   ├── cifar-10-batches-bin\n", "   │   │   ├── batches.meta.txt\n", "   │   │   ├── data_batch_1.bin\n", "   │   │   ├── data_batch_2.bin\n", "   │   │   ├── data_batch_3.bin\n", "   │   │   ├── data_batch_4.bin\n", "   │   │   ├── data_batch_5.bin\n", "   │   │   ├── readme.html\n", "   │   │   └── test_batch.bin\n", "   │   └── cifar-10-batches-py\n", "   │       ├── batches.meta\n", "   │       ├── data_batch_1\n", "   │       ├── data_batch_2\n", "   │       ├── data_batch_3\n", "   │       ├── data_batch_4\n", "   │       ├── data_batch_5\n", "   │       ├── readme.html\n", "   │       └── test_batch\n", "   └── mindrecord\n", "   ```\n", "\n", "   In this structure:\n", "   - `cifar-10-batches-bin` is the directory of the CIFAR-10 binary dataset.\n", "   - `cifar-10-batches-py` is the directory of the CIFAR-10 Python dataset." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading Common Datasets\n", "\n", "MindSpore can load common standard datasets. The supported datasets are listed in the following table:\n", "\n", "| Dataset | Description |\n", "| :---------: | :------------- |\n", "| ImageNet | An image database organized according to the WordNet hierarchy, in which each node of the hierarchy is depicted by hundreds or thousands of images. |\n", "| MNIST | A large database of handwritten digit images, commonly used to train various image processing systems. |\n", "| CIFAR-10 | A collection of images commonly used to train machine learning and computer vision algorithms. It contains 60,000 32x32 color images in 10 classes. |\n", "| CIFAR-100 | Similar to CIFAR-10, except that it has 100 classes, each containing 600 images: 500 training images and 100 test images. |\n", "| PASCAL-VOC | A dataset with diverse content, used to train computer vision models for classification, localization, detection, segmentation, action recognition, and similar tasks. |\n", "| CelebA | A face dataset containing images of more than 10,000 celebrity identities. Each image carries 40 attribute annotations. It is commonly used for face-related training tasks. |\n", "\n", "The following steps show how to load a supported dataset, using CIFAR-10 as an example.\n", "\n", "1. Using the binary version of the dataset (CIFAR-10 binary version), configure the dataset directory and define the dataset instance to load." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "DATA_DIR = \"./datasets/cifar-10/cifar-10-batches-bin\"\n", "cifar10_dataset = ds.Cifar10Dataset(DATA_DIR)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "2. 
Create an iterator and read data through it. The following code reads the first two images and their labels." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The data of image 1 is below:\n", "[[[179 147 140]\n", "  [173 148 138]\n", "  [131 108  98]\n", "  ...\n", "  [129  90  77]\n", "  [167 140 124]\n", "  [188 172 154]]\n", "\n", " [[177 156 131]\n", "  [182 167 142]\n", "  [120 108  85]\n", "  ...\n", "  [156 142 130]\n", "  [199 171 159]\n", "  [174 126 106]]\n", "\n", " [[145 129 103]\n", "  [128 107  81]\n", "  [166 144 118]\n", "  ...\n", "  [145 129 115]\n", "  [138  94  72]\n", "  [179 108  84]]\n", "\n", " ...\n", "\n", " [[123 135  91]\n", "  [134 146 101]\n", "  [113 123  86]\n", "  ...\n", "  [117 106  79]\n", "  [ 87  81  67]\n", "  [ 80  80  56]]\n", "\n", " [[148 159 114]\n", "  [135 146 103]\n", "  [125 135  97]\n", "  ...\n", "  [150 137  93]\n", "  [123 116  88]\n", "  [124 120  93]]\n", "\n", " [[150 162 102]\n", "  [160 171 115]\n", "  [132 141  97]\n", "  ...\n", "  [139 126  79]\n", "  [113 100  84]\n", "  [ 98  83  72]]]\n" ] },
{ "name": "stdout", "output_type": "stream", "text": [ "\n", "The label of image 1 is : 6\n", "The data of image 2 is below:\n", "[[[ 91  93 133]\n", "  [ 94  97 127]\n", "  [ 75  86 127]\n", "  ...\n", "  [ 86  89 117]\n", "  [ 84  86 113]\n", "  [ 80  80 110]]\n", "\n", " [[ 96 104 130]\n", "  [ 98 106 124]\n", "  [ 83  99 124]\n", "  ...\n", "  [102 102 111]\n", "  [ 99 101 110]\n", "  [ 75  88 106]]\n", "\n", " [[ 76  92 126]\n", "  [ 91 101 126]\n", "  [ 89 104 132]\n", "  ...\n", "  [100 104 114]\n", "  [102 106 115]\n", "  [ 88  95 116]]\n", "\n", " ...\n", "\n", " [[127 113 132]\n", "  [139 125 138]\n", "  [147 131 142]\n", "  ...\n", "  [159 127 111]\n", "  [133 127 137]\n", "  [133 124 139]]\n", "\n", " [[132 120 135]\n", "  [140 129 136]\n", "  [142 130 138]\n", "  ...\n", "  [166 133 115]\n", "  [139 130 136]\n", "  [141 133 142]]\n", "\n", " [[118 115 143]\n", "  [126 121 143]\n", "  [115 111 134]\n", "  ...\n", "  [148 130 146]\n", "  [139 130 156]\n", "  [129 121 146]]]\n" ] },
{ "name": "stdout", "output_type": "stream", "text": [ "\n", "The label of image 2 is : 5\n" ] } ], "source": [ "from matplotlib import pyplot as plt\n", "\n", "count = 0\n", "for data in cifar10_dataset.create_dict_iterator():\n", "    # In the CIFAR-10 dataset, each data dictionary has the keys \"image\" and \"label\".\n", "    image = data[\"image\"]\n", "    print(f\"The data of image {count+1} is below:\")\n", "    print(image)\n", "    plt.figure(count)\n", "    plt.imshow(image)\n", "    plt.title(f\"image{count+1}\")\n", "    plt.axis('off')\n", "    plt.show()\n", "    print(f\"\\nThe label of image {count+1} is :\", data[\"label\"])\n", "    count += 1\n", "    if count == 2:\n", "        break\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading Datasets in a Specific Data Format\n", "\n", "### MindSpore Data Format\n", "\n", "MindSpore natively supports reading datasets stored in `MindRecord`, the MindSpore data format, which offers better performance and feature support.\n", "\n", "> Read [Converting Datasets to the MindSpore Data Format](https://www.mindspore.cn/tutorial/zh-CN/r0.7/use/data_preparation/converting_datasets.html) to learn how to convert a dataset to the MindSpore data format.\n", "\n", "A `MindDataset` object reads the dataset, as shown in the following steps:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1. Convert the CIFAR-10 dataset to the `MindRecord` format. The dataset used here is the CIFAR-10 Python dataset (`cifar-10-batches-py`)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MSRStatus.SUCCESS" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from mindspore.mindrecord import Cifar10ToMR\n", "\n", "CIFAR10_DIR = \"./datasets/cifar-10/cifar-10-batches-py\"\n", "MINDRECORD_FILE = \"./datasets/mindrecord/cifar10.mindrecord\"\n", "cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)\n", "# Run the conversion; ['label'] is passed as the list of index fields.\n", "cifar10_transformer.transform(['label'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "2. 
Use the `MindDataset` class to create the dataset `data_set` for reading the data. The `dataset_file` parameter specifies a MindRecord file or a list of files." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "data_set = ds.MindDataset(dataset_file=\"./datasets/mindrecord/cifar10.mindrecord\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "3. Create a dictionary iterator and read data records through it. The following code prints the labels of the first 5 records." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2\n", "6\n", "0\n", "6\n", "9\n" ] } ], "source": [ "num_iter = 0\n", "for data in data_set.create_dict_iterator():\n", "    print(data[\"label\"])\n", "    num_iter += 1\n", "    if num_iter == 5:\n", "        break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading Custom Datasets\n", "\n", "In real-world scenarios, datasets come in many forms. A custom dataset, or a dataset that MindSpore does not yet support loading directly, can be handled in two ways.\n", "One is to convert the dataset to the MindRecord format (see [Converting Datasets to the MindSpore Data Format](https://www.mindspore.cn/tutorial/zh-CN/r0.7/use/data_preparation/converting_datasets.html)). The other is to load it with a `GeneratorDataset` object. The following steps show how to use `GeneratorDataset`.\n", "\n", "1. 
Define an iterable object that generates the dataset. Two examples are shown below: a custom function that returns values with `yield`, and a custom class that implements `__getitem__`. Both produce a dataset containing the numbers 0 through 9.\n", "\n", "> Each time it is called, the custom iterable object returns a tuple of `numpy array`s, which forms one row of data." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "    The following code creates a custom function `generator_func` that returns values with `yield`:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import numpy as np  # Import the NumPy library.\n", "\n", "\n", "def generator_func(num):\n", "    for i in range(num):\n", "        yield (np.array([i]),)  # Note: a single-element tuple needs a trailing comma.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "    The following code creates a custom class that implements `__getitem__` and `__len__`:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import numpy as np  # Import the NumPy library.\n", "\n", "\n", "class Generator:\n", "\n", "    def __init__(self, num):\n", "        self.num = num\n", "\n", "    def __getitem__(self, item):\n", "        return (np.array([item]),)  # Note: a single-element tuple needs a trailing comma.\n", "\n", "    def __len__(self):\n", "        return self.num\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "2. 
Use `GeneratorDataset` to create the datasets, and obtain the data by creating iterators over them.\n", "\n", "   - Pass `generator_func` to `GeneratorDataset` to create `dataset1`, and set the column name to `data`.\n", "   - Pass a `Generator` object to `GeneratorDataset` to create `dataset2`, and set the column name to `data`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "    The following code creates a tuple iterator, whose items are sequences, for each of `dataset1` and `dataset2`, and prints the data." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dataset1:\n", "[array([0], dtype=int32)]\n", "[array([1], dtype=int32)]\n", "[array([2], dtype=int32)]\n", "[array([3], dtype=int32)]\n", "[array([4], dtype=int32)]\n", "[array([5], dtype=int32)]\n", "[array([6], dtype=int32)]\n", "[array([7], dtype=int32)]\n", "[array([8], dtype=int32)]\n", "[array([9], dtype=int32)]\n", "dataset2:\n", "[array([0], dtype=int64)]\n", "[array([1], dtype=int64)]\n", "[array([2], dtype=int64)]\n", "[array([3], dtype=int64)]\n", "[array([4], dtype=int64)]\n", "[array([5], dtype=int64)]\n", "[array([6], dtype=int64)]\n", "[array([7], dtype=int64)]\n", "[array([8], dtype=int64)]\n", "[array([9], dtype=int64)]\n" ] } ], "source": [ "dataset1 = ds.GeneratorDataset(source=generator_func(10), column_names=[\"data\"], shuffle=False)\n", "dataset2 = ds.GeneratorDataset(source=Generator(10), column_names=[\"data\"], shuffle=False)\n", "\n", "print(\"dataset1:\")\n", "for data in dataset1.create_tuple_iterator():  # each data is a sequence\n", "    print(data)\n", "\n", "print(\"dataset2:\")\n", "for data in dataset2.create_tuple_iterator():  # each data is a sequence\n", "    print(data)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "    The following code creates a dictionary iterator for each of `dataset1` and `dataset2`, and prints the output data." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dataset1:\n", "[0]\n", "[1]\n", "[2]\n", "[3]\n", "[4]\n", "[5]\n", "[6]\n", "[7]\n", "[8]\n", "[9]\n", "dataset2:\n",
"[0]\n", "[1]\n", "[2]\n", "[3]\n", "[4]\n", "[5]\n", "[6]\n", "[7]\n", "[8]\n", "[9]\n" ] } ], "source": [ "dataset1 = ds.GeneratorDataset(source=generator_func(10), column_names=[\"data\"], shuffle=False)\n", "dataset2 = ds.GeneratorDataset(source=Generator(10), column_names=[\"data\"], shuffle=False)\n", "\n", "print(\"dataset1:\")\n", "for data in dataset1.create_dict_iterator():  # each data is a dictionary\n", "    print(data[\"data\"])\n", "\n", "print(\"dataset2:\")\n", "for data in dataset2.create_dict_iterator():  # each data is a dictionary\n", "    print(data[\"data\"])\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "This completes the walkthrough of loading datasets with MindSpore. It covered the ways MindSpore can load datasets and the dataset types it supports, how to create a custom dataset, and how to print the loaded data." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 }