提交 b22752e2 编写于 作者: H HydrogenSulfate

refactor DALI

上级 49519f63
# 基于PaddleClas的DALI开发实践
- [1. 简介](#1-简介)
- [2. 环境准备](#2-环境准备)
- [2.1 安装DALI](#21-安装dali)
- [3. 基本概念介绍](#3-基本概念介绍)
- [3.1 Operator](#31-operator)
- [3.2 Device](#32-device)
- [3.3 DataNode](#33-datanode)
- [3.4 Pipeline](#34-pipeline)
- [4. 开发实践](#4-开发实践)
- [4.1 开发与接入流程](#41-开发与接入流程)
- [4.2 RandomFlip](#42-randomflip)
- [4.2.1 继承DALI已有类](#421-继承dali已有类)
- [4.2.2 重载 \_\_init\_\_ 方法](#422-重载-__init__-方法)
- [4.2.3 重载 \_\_call\_\_ 方法](#423-重载-__call__-方法)
- [4.3 RandomRotation](#43-randomrotation)
- [4.3.1 继承DALI已有类](#431-继承dali已有类)
- [4.3.2 重载 \_\_init\_\_ 方法](#432-重载-__init__-方法)
- [4.3.3 重载 \_\_call\_\_ 方法](#433-重载-__call__-方法)
- [5. FAQ](#5-faq)
## 1. 简介
NVIDIA **Da**ta Loading **Li**brary (DALI) 是由 NVIDIA 开发的一套高性能数据预处理开源代码库,其提供了许多优化后的预处理算子,能很大程度上减少数据预处理耗时,非常适合在深度学习任务中使用。具体地,DALI 通过将大部分的数据预处理转移到 GPU 来解决 CPU 瓶颈问题。此外,DALI 编写了配套的高效执行引擎,最大限度地提高输入管道的吞吐量。
实际上 DALI 提供了不同粒度的图像、音频处理算子与随机数算子,这一特点基本上满足了大部分用户的需求,即用户只需在python侧进行开发,而不需要接触更为底层更复杂的C++代码。
本文档作为DALI的开发入门实践教程,在 PaddleClas 的文件结构与代码逻辑基础上,来介绍如何利用已有的DALI算子,根据自己的需求进行python侧的二次开发,以减少初学者的学习成本,提升模型训练效率。
## 2. 环境准备
### 2.1 安装DALI
进入DALI的官方教程 **[DALI-installtion](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html#)**
首先运行 `nvcc -V` 查看运行环境中的CUDA版本
```log
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_19:24:38_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89
```
可以看到具体版本号是 `release 10.2, V10.2.89`,因此接下来需安装 CUDA10.2 的DALI包
```shell
# for CUDA10.2
python3.7 -m pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda102
```
其余版本的CUDA请将上述命令末尾的 `cuda102` 改成对应的CUDA版本即可,如 CUDA11.0就改成 `cuda110`。DALI的具体支持设备可查看 **[DALI-support_matrix](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/support_matrix.html)**
## 3. 基本概念介绍
### 3.1 Operator
DALI 预处理过程的基本单位是 Operator(算子),PaddleClas 的 `operators.py` 设计逻辑与之类似,是一种较为通用的设计方式。DALI 提供了多种算子供用户根据具体需求使用,如 `nvidia.dali.ops.decoders.Image`(图像解码算子), `nvidia.dali.ops.Flip`(水平、垂直翻转算子),以及稍复杂的融合算子 `nvidia.dali.ops.decoders.ImageRandomCrop`(图像解码+随机裁剪的融合算子)。同时 DALI 也提供了一些随机数算子以在图像增强中加入随机性,如 `nvidia.dali.ops.random.CoinFlip`(伯努利分布随机数算子),`nvidia.dali.ops.random.Uniform`(均匀分布随机数算子)。
详细的算子库结构可以查看 **[DALI-operators](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/supported_ops_legacy.html#modules)**
### 3.2 Device
DALI 可以选择将数据预处理放到GPU上进行,因此绝大部分算子自身具有 `device` 这一参数,以在不同的设备上运行。
而 DALI 将具体情况分为了三种:
1. `cpu` - 接受在CPU上的输入,且输出在CPU上。
2. `mixed` - 接受在CPU上的输入,但输出在GPU上。
3. `gpu` - 接受在GPU上的输入,且输出在GPU上。
因此可以指定每个算子的处理时的设备,加快并行效率,减少阻塞耗时。
### 3.3 DataNode
与常见的深度学习框架中静态图的设计思路(如 tensorflow)相似,DALI 的 Operator 输入和输出一般是一个或多个在CPU/GPU上的数据,被称为 **DataNode**,这些 DataNode 在多个 Operator 中被有顺序地处理、传递,直到成为最后一个 Operator 的输出,然后才被用户获取并输入到网络模型中去。
### 3.4 Pipeline
从用户读取、解析给定的图片路径文件(如`.txt`格式文件)开始,到解码出图片,再到使用一个或多个Operator对图片进行预处理,最后返回处理完毕的图像(一般为Tensor格式)。这一整个过程称之为 **Pipeline**,当准备好需要的 Operator(s) 之后,就需要开始编写这一部分的代码,将 数据读取、预处理Operator(s) 组装成一个 Pipeline。如果将 Pipeline 当作是一个计算图,那么 Operator 和 DataNode 都是图中的结点,如下图所示。
![DALI-pipeline](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/_images/two_readers.svg)
## 4. 开发实践
本章节希望通过一个简单的例子和一个稍复杂的例子,介绍如何基于 DALI 提供的算子,在python侧进行二次开发,以满足用户实际需要。
### 4.1 开发与接入流程
1.`ppcls/data/preprocess/ops/dali_operators.py` 中开发python侧DALI算子的代码。
2.`ppcls/data/preprocess/ops/dali.py` 开头处 import 导入开发好的算子类,并在 `convert_cfg_to_dali` 函数内参照其它算子配置转换逻辑,为添加的算子也加入对应的配置转换逻辑。
3. (可选)如果开发的算子属于 fused operator,则还需在 `ppcls/data/preprocess/ops/dali.py``build_dali_transforms` 函数内,参照已有融合算子逻辑,添加新算子对应的融合逻辑。
4. (可选)如果开发的是 External Source 类的 sampler 算子,可参照已有的 `ExternalSource_RandomIdentity` 代码进行开发,并在添加对应调用逻辑。实际上 External Source 类可视作对原有的Dataset和Sampler代码进行合并。
### 4.2 RandomFlip
以 PaddleClas 已有的 [RandFlipImage](../../../../ppcls/data/preprocess/ops/operators.py#L499) 算子为例,我们希望在使用DALI训练时,将其转换为对应的 DALI 算子,且同样具备 **按指定的 `prob` 概率进行 指定的水平 or 垂直翻转**
#### 4.2.1 继承DALI已有类
DALI 已经提供了简单的翻转算子 [`nvidia.dali.ops.Flip`](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/supported_ops_legacy.html#nvidia.dali.ops.Flip),其通过 `horizontal``vertical` 参数来分别控制是否对图像进行水平、垂直翻转。但是其缺少随机性,无法直接按照一定概率进行翻转或不翻转,因此我们需要继承这个翻转类,并重载其 `__init__` 方法和 `__call__` 方法。继承代码如下所示:
```python
import nvidia.dali.ops as ops
class RandFlipImage(ops.Flip):
def __init__(self, *kargs, device="cpu", **kwargs):
super(RandFlipImage, self).__init__(*kargs, device=device, **kwargs)
...
def __call__(self, data, **kwargs):
...
```
#### 4.2.2 重载 \_\_init\_\_ 方法
我们需要在构造算子时加入随机参数来控制是否翻转,因此仿照普通 `RandFlipImage`算子的逻辑,在继承类的初始化方法中加入参数 `prob`,同理再加入 `flip_code` 用于控制水平、垂直翻转。
由于每一次执行我们都需要生成一个随机数(此处用0或1表示),代表是否在翻转轴上进行翻转,因此我们实例化一个 `ops.random.CoinFlip` 来作为随机数生成器(实例化对象为下方代码中的 `self.rng`),同理我们也需要记录翻转轴参数 `flip_code`,以供 `__call__` 方法中调用。
修改后代码如下所示:
```python
class RandFlipImage(ops.Flip):
def __init__(self, *kargs, device="cpu", prob=0.5, flip_code=1, **kwargs):
super(RandFlipImage, self).__init__(*kargs, device=device, **kwargs)
self.flip_code = flip_code
self.rng = ops.random.CoinFlip(probability=prob)
def __call__(self, data, **kwargs):
...
```
#### 4.2.3 重载 \_\_call\_\_ 方法
有了 `self.rng``self.flip_code`,我们就能在每次调用的 `__call__` 方法内部,加入随机性并控制方向。首先调用 `self.rng()``__call__` 方法,生成一个0或1的随机整数,0代表不进行翻转,1代表进行翻转;然后根据 `self.flip_code` ,将这个随机整数作为父类 `__call__` 方法的 `horizontal``vertical` 参数,调用父类的 `__call__` 方法完成翻转。这样就完成了一个简单的自定义DALI RandomFlip 算子的编写。完整代码如下所示:
```python
class RandFlipImage(ops.Flip):
def __init__(self, *kargs, device="cpu", prob=0.5, flip_code=1, **kwargs):
super(RandFlipImage, self).__init__(*kargs, device=device, **kwargs)
self.flip_code = flip_code
self.rng = ops.random.CoinFlip(probability=prob)
def __call__(self, data, **kwargs):
do_flip = self.rng()
if self.flip_code == 1:
return super(RandFlipImage, self).__call__(
data, horizontal=do_flip, vertical=0, **kwargs)
elif self.flip_code == 1:
return super(RandFlipImage, self).__call__(
data, horizontal=0, vertical=do_flip, **kwargs)
else:
return super(RandFlipImage, self).__call__(
data, horizontal=do_flip, vertical=do_flip, **kwargs)
```
### 4.3 RandomRotation
以 PaddleClas 已有的 [RandomRotation](../../../../ppcls/data/preprocess/ops/operators.py#L684) 算子为例,我们希望在使用DALI训练时,将其转换为对应的 DALI 算子,且同样具备 **按指定的参数与角度进行随机旋转**
#### 4.3.1 继承DALI已有类
DALI 已经提供了简单的翻转算子 [`nvidia.dali.ops.Rotate`](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/supported_ops_legacy.html#nvidia.dali.ops.Rotate),其通过 `angle``fill_value``interp_type` 等参数控制旋转的角度、填充值以及插值方式。但是其缺少一定的随机性,此我们需要继承这个旋转类,并重载其 `__init__` 方法和 `__call__` 方法。继承代码如下所示:
```python
import nvidia.dali.ops as ops
class RandomRotation(ops.Rotate):
def __init__(self, *kargs, device="cpu", **kwargs):
super(RandomRotation, self).__init__(*kargs, device=device, **kwargs)
...
def __call__(self, data, **kwargs):
...
```
#### 4.3.2 重载 \_\_init\_\_ 方法
我们需要在构造算子时加入随机参数来控制是否翻转,因此仿照普通 `RandomRotation` 算子的逻辑,在继承类的初始化方法中加入参数 `prob`,同理再加入 `angle` 用于控制旋转角度。
由于每一次执行我们都需要生成一个随机数(此处用0或1表示),代表是否进行随机旋转,因此我们实例化一个 `ops.random.CoinFlip` 来作为随机数生成器(实例化对象为下方代码中的 `self.rng`)。除此之外我们还需要实例化一个随机数生成器来作为实际旋转时的角度(实例化对象为下方代码中的 `self.rng_angle`),由于角度是一个均匀分布而不是伯努利分布,因此需要使用 `random.Uniform` 这个类。
修改后代码如下所示:
```python
class RandomRotation(ops.Rotate):
def __init__(self, *kargs, device="cpu", prob=0.5, angle=0, **kwargs):
super(RandomRotation, self).__init__(*kargs, device=device, **kwargs)
self.rng = ops.random.CoinFlip(probability=prob)
self.rng_angle = ops.random.Uniform(range=(-angle, angle))
def __call__(self, data, **kwargs):
...
```
#### 4.3.3 重载 \_\_call\_\_ 方法
有了以上的一些变量,根据 `operators.py``RandomRotation` 的逻辑,仿照 [RandomFlip-重载__call__方法](#413-重载-__call__-方法) 的写法进行代码编写,就能得到完整代码,如下所示:
```python
class RandomRotation(ops.Rotate):
def __init__(self, *kargs, device="cpu", prob=0.5, angle=0, **kwargs):
super(RandomRotation, self).__init__(*kargs, device=device, **kwargs)
self.rng = ops.random.CoinFlip(probability=prob)
discrete_angle = list(range(-angle, angle + 1))
self.rng_angle = ops.random.Uniform(values=discrete_angle)
def __call__(self, data, **kwargs):
do_rotate = self.rng()
angle = self.rng_angle()
flip_data = super(RandomRotation, self).__call__(
data,
angle=fn.cast(
do_rotate, dtype=types.FLOAT) * angle,
keep_size=True,
fill_value=0,
**kwargs)
return flip_data
```
具体每个参数的含义,如`angle``keep_size``fill_value`等,可以查看DALI对应算子的文档
## 5. FAQ
- **Q**:是否所有算子都能以继承-重载的方式改写成DALI算子?
**A**:具体视算子本身的执行逻辑而定,如 `RandomErasing` 算子实际上比较难在python侧转换成DALI算子,尽管DALI有一个对应的 [random_erasing Demo](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/erase.html?highlight=erase),但其实际执行中的随机逻辑与 `RandomErasing` 存在一定差异,无法保证等价转换。可以尝试使用 [python_function](https://docs.nvidia.com/deeplearning/dali/main-user-guide/docs/operations/nvidia.dali.fn.python_function.html?highlight=python_function#nvidia.dali.fn.python_function) 来接入python实现的数据增强
- **Q**:使用DALI训练模型的最终精度与不使用DALI不同?
**A**:由于DALI底层实现是NVIDIA官方编写的代码,而operators.py中调用的是cv2、Pillow库,可能存在无法避免的细微差异,如同样的插值方法,实现存在不同。因此只能尽量从执行逻辑、参数、随机数分布上进行等价转换,而无法做到完全一致。如果出现较大diff,可以检查转换来的DALI算子代码执行逻辑、参数、随机数分布是否存在问题,也可以将读取结果可视化检查。另外需要注意的是如果使用DALI的数据预处理接口进行训练,那么为了获得最佳的精度,也应该用DALI的数据预处理接口进行测试,否则可能会造成精度下降。
- **Q**:如果模型使用比较复杂的Sampler如PKsampler该如何改写呢?
**A**:从开发成本考虑,目前比较推荐的方法([#issue 4407](https://github.com/NVIDIA/DALI/issues/4407#issuecomment-1298132180))是使用DALI官方提供的 [`External Source Operator`](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/external_input.html) 完成自定义Sampler的编写,实际上 [dali.py](../../../../ppcls/data/dataloader/dali.py) 也提供了基于 `External Source Operator``PKSampler` 的实现 `ExternalSource_RandomIdentity`
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
* [2. 安装 DALI](#2) * [2. 安装 DALI](#2)
* [3. 使用 DALI](#3) * [3. 使用 DALI](#3)
* [4. 使用 FP16 训练](#4) * [4. 使用 FP16 训练](#4)
* [5. 新增 DALI 算子](#5)
<a name='1'></a> <a name='1'></a>
...@@ -70,3 +71,6 @@ python -m paddle.distributed.launch \ ...@@ -70,3 +71,6 @@ python -m paddle.distributed.launch \
ppcls/train.py \ ppcls/train.py \
-c ./ppcls/configs/ImageNet/ResNet/ResNet50_fp16_dygraph.yaml -c ./ppcls/configs/ImageNet/ResNet/ResNet50_fp16_dygraph.yaml
``` ```
## 5. 新增 DALI 算子
PaddleClas提供了基于DALI已有API的自定义python算子教程,可以参考 [develop_with_DALI.md](../advanced/develop_with_DALI.md)
...@@ -98,7 +98,8 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None): ...@@ -98,7 +98,8 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
mode, mode,
paddle.device.get_device(), paddle.device.get_device(),
num_threads=config[mode]['loader']["num_workers"], num_threads=config[mode]['loader']["num_workers"],
seed=seed) seed=seed,
enable_fuse=True)
class_num = config.get("class_num", None) class_num = config.get("class_num", None)
epochs = config.get("epochs", None) epochs = config.get("epochs", None)
......
...@@ -14,131 +14,640 @@ ...@@ -14,131 +14,640 @@
from __future__ import division from __future__ import division
import copy
import os import os
from collections import defaultdict
from typing import Any, Callable, Dict, List, Tuple, Union, Optional
import numpy as np
import nvidia.dali.fn as fn
import nvidia.dali.ops as ops import nvidia.dali.ops as ops
import nvidia.dali.pipeline as pipeline
import nvidia.dali.types as types import nvidia.dali.types as types
import paddle import paddle
from typing import List
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.paddle import DALIGenericIterator from nvidia.dali.plugin.paddle import DALIGenericIterator
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
from ppcls.data.preprocess.ops.dali_operators import ColorJitter
from ppcls.data.preprocess.ops.dali_operators import CropImage
from ppcls.data.preprocess.ops.dali_operators import CropMirrorNormalize
from ppcls.data.preprocess.ops.dali_operators import DecodeImage
from ppcls.data.preprocess.ops.dali_operators import DecodeRandomResizedCrop
from ppcls.data.preprocess.ops.dali_operators import NormalizeImage
from ppcls.data.preprocess.ops.dali_operators import Pad
from ppcls.data.preprocess.ops.dali_operators import RandCropImage
from ppcls.data.preprocess.ops.dali_operators import RandCropImageV2
from ppcls.data.preprocess.ops.dali_operators import RandFlipImage
from ppcls.data.preprocess.ops.dali_operators import RandomCropImage
from ppcls.data.preprocess.ops.dali_operators import RandomRot90
from ppcls.data.preprocess.ops.dali_operators import RandomRotation
from ppcls.data.preprocess.ops.dali_operators import ResizeImage
from ppcls.data.preprocess.ops.dali_operators import ToCHWImage
from ppcls.engine.train.utils import type_name
from ppcls.utils import logger
INTERP_MAP = {
"nearest": types.DALIInterpType.INTERP_NN, # cv2.INTER_NEAREST
"bilinear": types.DALIInterpType.INTERP_LINEAR, # cv2.INTER_LINEAR
"bicubic": types.DALIInterpType.INTERP_CUBIC, # cv2.INTER_CUBIC
"lanczos": types.DALIInterpType.INTERP_LANCZOS3, # cv2.INTER_LANCZOS4
}
def make_pair(x: Union[Any, Tuple[Any], List[Any]]) -> Tuple[Any]:
"""repeat input x to be an tuple if x is an single element, else return x directly
Args:
x (Union[Any, Tuple[Any], List[Any]]): input x
Returns:
Tuple[Any]: tupled input
"""
return x if isinstance(x, (tuple, list)) else (x, x)
def parse_value_with_key(content: Union[Dict, List[Dict]],
key: str) -> Union[None, Any]:
"""parse value according to given key recursively, return None if not found
Args:
content (Union[Dict, List[Dict]]): content to be parsed
key (str): given key
Returns:
Union[None, Any]: result
"""
if isinstance(content, dict):
if key in content:
return content[key]
for content_ in content.values():
value = parse_value_with_key(content_, key)
if value is not None:
return value
elif isinstance(content, (tuple, list)):
for content_ in content:
value = parse_value_with_key(content_, key)
if value is not None:
return value
return None
def convert_cfg_to_dali(op_name: str, device: str, **op_cfg) -> Dict[str, Any]:
"""convert original preprocess op params into DALI-based op params
Args:
op_name (str): name of operator
device (str): device which operator applied on
Returns:
Dict[str, Any]: converted arguments for DALI initialization
"""
assert device in ["cpu", "gpu"
], f"device({device}) must in [\"cpu\", \"gpu\"]"
dali_op_cfg = {}
if op_name == "DecodeImage":
device = "cpu" if device == "cpu" else "mixed"
to_rgb = op_cfg.get("to_rgb", True)
channel_first = op_cfg.get("channel_first", False)
assert channel_first is False, \
f"`channel_first` must set to False when using DALI, but got {channel_first}"
dali_op_cfg.update({"device": device})
dali_op_cfg.update({
"output_type": types.DALIImageType.RGB
if to_rgb else types.DALIImageType.BGR
})
dali_op_cfg.update({
"device_memory_padding":
op_cfg.get("device_memory_padding", 211025920)
})
dali_op_cfg.update({
"host_memory_padding": op_cfg.get("host_memory_padding", 140544512)
})
elif op_name == "ResizeImage":
size = op_cfg.get("size", None)
resize_short = op_cfg.get("resize_short", None)
interpolation = op_cfg.get("interpolation", None)
if size is not None:
size = make_pair(size)
dali_op_cfg.update({"resize_y": size[0], "resize_x": size[1]})
if resize_short is not None:
dali_op_cfg.update({"resize_shorter": resize_short})
if interpolation is not None:
dali_op_cfg.update({"interp_type": INTERP_MAP[interpolation]})
elif op_name == "CropImage":
size = op_cfg.get("size", 224)
size = make_pair(size)
dali_op_cfg.update({"crop_h": size[1], "crop_w": size[0]})
dali_op_cfg.update({"crop_pos_x": 0.5, "crop_pos_y": 0.5})
elif op_name == "RandomCropImage":
size = op_cfg.get("size", 224)
if size is not None:
size = make_pair(size)
dali_op_cfg.update({"crop_h": size[1], "crop_w": size[0]})
elif op_name == "RandCropImage":
size = op_cfg.get("size", 224)
size = make_pair(size)
scale = op_cfg.get("scale", [0.08, 1.0])
ratio = op_cfg.get("ratio", [3.0 / 4, 4.0 / 3])
interpolation = op_cfg.get("interpolation", "bilinear")
dali_op_cfg.update({"size": size})
if scale is not None:
dali_op_cfg.update({"random_area": scale})
if ratio is not None:
dali_op_cfg.update({"random_aspect_ratio": ratio})
if interpolation is not None:
dali_op_cfg.update({"interp_type": INTERP_MAP[interpolation]})
elif op_name == "RandCropImageV2":
size = op_cfg.get("size", 224)
size = make_pair(size)
dali_op_cfg.update({"crop_h": size[1], "crop_w": size[0]})
elif op_name == "RandFlipImage":
prob = op_cfg.get("prob", 0.5)
flip_code = op_cfg.get("flip_code", 1)
dali_op_cfg.update({"prob": prob})
dali_op_cfg.update({"flip_code": flip_code})
elif op_name == "NormalizeImage":
# scale * (in - mean) / stddev + shift
scale = op_cfg.get("scale", 1.0 / 255.0)
if isinstance(scale, str):
scale = eval(scale)
mean = op_cfg.get("mean", [0.485, 0.456, 0.406])
std = op_cfg.get("std", [0.229, 0.224, 0.225])
mean = [v / scale for v in mean]
std = [v / scale for v in std]
order = op_cfg.get("order", "chw")
channel_num = op_cfg.get("channel_num", 3)
output_fp16 = op_cfg.get("output_fp16", False)
dali_op_cfg.update({
"mean": np.reshape(
np.array(
mean, dtype="float32"), [channel_num, 1, 1]
if order == "chw" else [1, 1, channel_num])
})
dali_op_cfg.update({
"stddev": np.reshape(
np.array(
std, dtype="float32"), [channel_num, 1, 1]
if order == "chw" else [1, 1, channel_num])
})
if output_fp16:
dali_op_cfg.update({"dtype": types.FLOAT16})
elif op_name == "ToCHWImage":
dali_op_cfg.update({"perm": [2, 0, 1]})
elif op_name == "ColorJitter":
prob = op_cfg.get("prob", 1.0)
brightness = op_cfg.get("brightness", 0.0)
contrast = op_cfg.get("contrast", 0.0)
saturation = op_cfg.get("saturation", 0.0)
hue = op_cfg.get("hue", 0.0)
dali_op_cfg.update({"prob": prob})
dali_op_cfg.update({"brightness_factor": brightness})
dali_op_cfg.update({"contrast_factor": contrast})
dali_op_cfg.update({"saturation_factor": saturation})
dali_op_cfg.update({"hue_factor": hue})
elif op_name == "RandomRotation":
prob = op_cfg.get("prob", 0.5)
degrees = op_cfg.get("degrees", 90)
interpolation = op_cfg.get("interpolation", "bilinear")
dali_op_cfg.update({"prob": prob})
dali_op_cfg.update({"angle": degrees})
dali_op_cfg.update({"interp_type": INTERP_MAP[interpolation]})
elif op_name == "Pad":
size = op_cfg.get("size", 224)
size = make_pair(size)
padding = op_cfg.get("padding", 0)
fill = op_cfg.get("fill", 0)
dali_op_cfg.update({
"crop_h": padding + size[1] + padding,
"crop_w": padding + size[0] + padding
})
dali_op_cfg.update({"fill_values": fill})
dali_op_cfg.update({"out_of_bounds_policy": "pad"})
elif op_name == "RandomRot90":
interpolation = op_cfg.get("interpolation", "nearest")
elif op_name == "DecodeRandomResizedCrop":
device = "cpu" if device == "cpu" else "mixed"
output_type = op_cfg.get("output_type", types.DALIImageType.RGB)
device_memory_padding = op_cfg.get("device_memory_padding", 211025920)
host_memory_padding = op_cfg.get("host_memory_padding", 140544512)
scale = op_cfg.get("scale", [0.08, 1.0])
ratio = op_cfg.get("ratio", [3.0 / 4, 4.0 / 3])
num_attempts = op_cfg.get("num_attempts", 100)
size = op_cfg.get("size", 224)
dali_op_cfg.update({"device": device})
if output_type is not None:
dali_op_cfg.update({"output_type": output_type})
if device_memory_padding is not None:
dali_op_cfg.update({
"device_memory_padding": device_memory_padding
})
if host_memory_padding is not None:
dali_op_cfg.update({"host_memory_padding": host_memory_padding})
if scale is not None:
dali_op_cfg.update({"random_area": scale})
if ratio is not None:
dali_op_cfg.update({"random_aspect_ratio": ratio})
if num_attempts is not None:
dali_op_cfg.update({"num_attempts": num_attempts})
if size is not None:
dali_op_cfg.update({"resize_x": size, "resize_y": size})
elif op_name == "CropMirrorNormalize":
dtype = types.FLOAT16 if op_cfg.get("output_fp16",
False) else types.FLOAT
output_layout = op_cfg.get("output_layout", "CHW")
size = op_cfg.get("size", None)
scale = op_cfg.get("scale", 1 / 255.0)
if isinstance(scale, str):
scale = eval(scale)
mean = op_cfg.get("mean", [0.485, 0.456, 0.406])
mean = [v / scale for v in mean]
std = op_cfg.get("std", [0.229, 0.224, 0.225])
std = [v / scale for v in std]
pad_output = op_cfg.get("channel_num", 3) == 4
prob = op_cfg.get("prob", 0.5)
dali_op_cfg.update({"dtype": dtype})
if output_layout is not None:
dali_op_cfg.update({"output_layout": output_layout})
if size is not None:
dali_op_cfg.update({"crop": (size, size)})
if mean is not None:
dali_op_cfg.update({"mean": mean})
if std is not None:
dali_op_cfg.update({"std": std})
if pad_output is not None:
dali_op_cfg.update({"pad_output": pad_output})
if prob is not None:
dali_op_cfg.update({"prob": prob})
else:
raise ValueError(
f"DALI operator \"{op_name}\" in PaddleClas is not implemented now. please refer to docs/zh_CN/training/config_description/develop_with_DALI.md"
)
if "device" not in dali_op_cfg:
dali_op_cfg.update({"device": device})
return dali_op_cfg
def build_dali_transforms(op_cfg_list: List[Dict[str, Any]],
mode: str,
device: str="gpu",
enable_fuse: bool=True) -> List[Callable]:
"""create dali operators based on the config
Args:
op_cfg_list (List[Dict[str, Any]]): a dict list, used to create some operators, such as config below
--------------------------------
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ""
--------------------------------
mode (str): mode.
device (str): device which dali operator(s) applied in. Defaults to "gpu".
enable_fuse (bool): whether to use fused dali operators instead of single operators, such as DecodeRandomResizedCrop. Defaults to True.
Returns:
List[Callable]: Callable DALI operators in list.
"""
assert isinstance(op_cfg_list, list), "operator config should be a list"
# build dali transforms list
dali_op_list = []
idx = 0
num_cfg_node = len(op_cfg_list)
while idx < num_cfg_node:
op_cfg = op_cfg_list[idx]
op_name = list(op_cfg)[0]
op_param = {} if op_cfg[op_name] is None else copy.deepcopy(op_cfg[
op_name])
fused_success = False
if enable_fuse:
# fuse operators if enabled
if idx + 1 < num_cfg_node:
op_name_nxt = list(op_cfg_list[idx + 1])[0]
if (op_name == "DecodeImage" and
op_name_nxt == "RandCropImage"):
fused_op_name = "DecodeRandomResizedCrop"
fused_op_param = convert_cfg_to_dali(
fused_op_name, device, **{
** op_param, ** (op_cfg_list[idx + 1][op_name_nxt])
})
fused_dali_op = eval(fused_op_name)(**fused_op_param)
idx += 2
dali_op_list.append(fused_dali_op)
fused_success = True
logger.info(
f"DALI fused Operator conversion({mode}): [DecodeImage, RandCropImage] -> {type_name(dali_op_list[-1])}: {fused_op_param}"
)
if not fused_success and 0 < idx and idx + 1 < num_cfg_node:
op_name_pre = list(op_cfg_list[idx - 1])[0]
op_name_nxt = list(op_cfg_list[idx + 1])[0]
if (op_name_pre == "RandCropImage" and
op_name == "RandFlipImage" and
op_name_nxt == "NormalizeImage"):
fused_op_name = "CropMirrorNormalize"
fused_op_param = convert_cfg_to_dali(
fused_op_name, device, **{
** op_param, **
(op_cfg_list[idx - 1][op_name_pre]), **
(op_cfg_list[idx + 1][op_name_nxt])
})
fused_dali_op = eval(fused_op_name)(**fused_op_param)
idx += 2
dali_op_list.append(fused_dali_op)
fused_success = True
logger.info(
f"DALI fused Operator conversion({mode}): [RandCropImage, RandFlipImage, NormalizeImage] -> {type_name(dali_op_list[-1])}: {fused_op_param}"
)
if not fused_success and idx + 1 < num_cfg_node:
op_name_nxt = list(op_cfg_list[idx + 1])[0]
if (op_name == "CropImage" and
op_name_nxt == "NormalizeImage"):
fused_op_name = "CropMirrorNormalize"
fused_op_param = convert_cfg_to_dali(
fused_op_name, device, **{
**
op_param,
**
(op_cfg_list[idx + 1][op_name_nxt]),
"prob": 0.0
})
fused_dali_op = eval(fused_op_name)(**fused_op_param)
idx += 2
dali_op_list.append(fused_dali_op)
fused_success = True
logger.info(
f"DALI fused Operator conversion({mode}): [CropImage, NormalizeImage] -> {type_name(dali_op_list[-1])}: {fused_op_param}"
)
if not enable_fuse or not fused_success:
assert isinstance(op_cfg,
dict) and len(op_cfg) == 1, "yaml format error"
if op_name == "Pad":
# NOTE: Argument `size` must be provided for DALI operator
op_param.update({
"size": parse_value_with_key(op_cfg_list[:idx], "size")
})
dali_param = convert_cfg_to_dali(op_name, device, **op_param)
dali_op = eval(op_name)(**dali_param)
dali_op_list.append(dali_op)
idx += 1
logger.info(
f"DALI Operator conversion({mode}): {op_name} -> {type_name(dali_op_list[-1])}: {dali_param}"
)
return dali_op_list
class ExternalSource_RandomIdentity(object):
"""PKsampler implemented with ExternalSource
Args:
batch_size (int): batch size
sample_per_id (int): number of instance(s) within an class
device_id (int): device id
shard_id (int): shard id
num_gpus (int): number of gpus
image_root (str): image root directory
cls_label_path (str): path to annotation file, such as `train_list.txt` or `val_list.txt`
delimiter (Optional[str], optional): delimiter. Defaults to None.
relabel (bool, optional): whether do relabel when original label do not starts from 0 or are discontinuous. Defaults to False.
sample_method (str, optional): sample method when generating prob_list. Defaults to "sample_avg_prob".
id_list (List[int], optional): list of (start_id, end_id, start_id, end_id) for set of ids to duplicated. Defaults to None.
ratio (List[Union[int, float]], optional): list of (ratio1, ratio2..) the duplication number for ids in id_list. Defaults to None.
shuffle (bool): whether to shuffle label list. Defaults to True.
"""
class HybridTrainPipe(Pipeline):
def __init__(self, def __init__(self,
file_root, batch_size: int,
file_list, sample_per_id: int,
batch_size, device_id: int,
resize_shorter, shard_id: int,
crop, num_gpus: int,
min_area, image_root: str,
lower, cls_label_path: str,
upper, delimiter: Optional[str]=None,
interp, relabel: bool=False,
mean, sample_method: str="sample_avg_prob",
std, id_list: List[int]=None,
device_id, ratio: List[Union[int, float]]=None,
shard_id=0, shuffle: bool=True):
num_shards=1, self.batch_size = batch_size
random_shuffle=True, self.sample_per_id = sample_per_id
num_threads=4, self.label_per_batch = self.batch_size // self.sample_per_id
seed=42, self.device_id = device_id
pad_output=False, self.shard_id = shard_id
output_dtype=types.FLOAT, self.num_gpus = num_gpus
dataset='Train'): self._img_root = image_root
super(HybridTrainPipe, self).__init__( self._cls_path = cls_label_path
batch_size, num_threads, device_id, seed=seed) self.delimiter = delimiter if delimiter is not None else " "
self.input = ops.readers.File( self.relabel = relabel
file_root=file_root, self.sample_method = sample_method
file_list=file_list, self.image_paths = []
shard_id=shard_id, self.labels = []
num_shards=num_shards, self.epoch = 0
random_shuffle=random_shuffle)
# set internal nvJPEG buffers size to handle full-sized ImageNet images
# without additional reallocations
device_memory_padding = 211025920
host_memory_padding = 140544512
self.decode = ops.decoders.ImageRandomCrop(
device='mixed',
output_type=types.DALIImageType.RGB,
device_memory_padding=device_memory_padding,
host_memory_padding=host_memory_padding,
random_aspect_ratio=[lower, upper],
random_area=[min_area, 1.0],
num_attempts=100)
self.res = ops.Resize(
device='gpu', resize_x=crop, resize_y=crop, interp_type=interp)
self.cmnp = ops.CropMirrorNormalize(
device="gpu",
dtype=output_dtype,
output_layout='CHW',
crop=(crop, crop),
mean=mean,
std=std,
pad_output=pad_output)
self.coin = ops.random.CoinFlip(probability=0.5)
self.to_int64 = ops.Cast(dtype=types.DALIDataType.INT64, device="gpu")
def define_graph(self): # NOTE: code from ImageNetDataset below
rng = self.coin() with open(self._cls_path, "r") as fd:
jpegs, labels = self.input(name="Reader") lines = fd.readlines()
images = self.decode(jpegs) if self.relabel:
images = self.res(images) label_set = set()
output = self.cmnp(images.gpu(), mirror=rng) for line in lines:
return [output, self.to_int64(labels.gpu())] line = line.strip().split(self.delimiter)
label_set.add(np.int64(line[1]))
label_map = {
oldlabel: newlabel
for newlabel, oldlabel in enumerate(label_set)
}
for line in lines:
line = line.strip().split(self.delimiter)
self.image_paths.append(os.path.join(self._img_root, line[0]))
if self.relabel:
self.labels.append(label_map[np.int64(line[1])])
else:
self.labels.append(np.int64(line[1]))
assert os.path.exists(self.image_paths[
-1]), f"path {self.image_paths[-1]} does not exist."
# NOTE: code from PKSampler below
# group sample indexes into their label bucket
self.label_dict = defaultdict(list)
for idx, label in enumerate(self.labels):
self.label_dict[label].append(idx)
# get all label
self.label_list = list(self.label_dict)
assert len(self.label_list) * self.sample_per_id >= self.batch_size, \
f"batch size({self.batch_size}) should not be bigger than than #classes({len(self.label_list)})*sample_per_id({self.sample_per_id})"
if self.sample_method == "id_avg_prob":
self.prob_list = np.array([1 / len(self.label_list)] *
len(self.label_list))
elif self.sample_method == "sample_avg_prob":
counter = []
for label_i in self.label_list:
counter.append(len(self.label_dict[label_i]))
self.prob_list = np.array(counter) / sum(counter)
# reweight prob_list according to id_list and ratio if provided
if id_list and ratio:
assert len(id_list) % 2 == 0 and len(id_list) == len(ratio) * 2
for i in range(len(self.prob_list)):
for j in range(len(ratio)):
if i >= id_list[j * 2] and i <= id_list[j * 2 + 1]:
self.prob_list[i] = self.prob_list[i] * ratio[j]
break
self.prob_list = self.prob_list / sum(self.prob_list)
assert os.path.exists(
self._cls_path), f"path {self._cls_path} does not exist."
assert os.path.exists(
self._img_root), f"path {self._img_root} does not exist."
diff = np.abs(sum(self.prob_list) - 1)
if diff > 0.00000001:
self.prob_list[-1] = 1 - sum(self.prob_list[:-1])
if self.prob_list[-1] > 1 or self.prob_list[-1] < 0:
logger.error("PKSampler prob list error")
else:
logger.info(
"sum of prob list not equal to 1, diff is {}, change the last prob".
format(diff))
# whole dataset size
self.data_set_len = len(self.image_paths)
# get sharded size
self.sharded_data_set_len = self.data_set_len // self.num_gpus
# iteration log
self.shuffle = shuffle
self.total_iter = self.sharded_data_set_len // batch_size
self.iter_count = 0
def __iter__(self):
if self.shuffle:
seed = self.shard_id * 12345 + self.epoch
np.random.RandomState(seed).shuffle(self.label_list)
self.epoch += 1
return self
def __next__(self):
if self.iter_count >= self.total_iter:
self.__iter__()
self.iter_count = 0
batch_indexes = []
for _ in range(self.sharded_data_set_len):
batch_label_list = np.random.choice(
self.label_list,
size=self.label_per_batch,
replace=False,
p=self.prob_list)
for label_i in batch_label_list:
label_i_indexes = self.label_dict[label_i]
if self.sample_per_id <= len(label_i_indexes):
batch_indexes.extend(
np.random.choice(
label_i_indexes,
size=self.sample_per_id,
replace=False))
else:
batch_indexes.extend(
np.random.choice(
label_i_indexes,
size=self.sample_per_id,
replace=True))
if len(batch_indexes) == self.batch_size:
break
batch_indexes = []
batch_raw_images = []
batch_labels = []
for index in batch_indexes:
batch_raw_images.append(
np.fromfile(
self.image_paths[index], dtype="uint8"))
batch_labels.append(self.labels[index])
self.iter_count += 1
return (batch_raw_images, np.array(batch_labels, dtype="int64"))
def __len__(self): def __len__(self):
return self.epoch_size("Reader") return self.sharded_data_set_len
class HybridValPipe(Pipeline): class HybridPipeline(pipeline.Pipeline):
"""Hybrid Pipeline
Args:
device (str): device
batch_size (int): batch size
py_num_workers (int): number of python worker(s)
num_threads (int): number of thread(s)
device_id (int): device id
seed (int): random seed
file_root (str): file root path
file_list (str): path to annotation file, such as `train_list.txt` or `val_list.txt`
transform_list (List[Callable]): List of DALI transform operator(s)
shard_id (int, optional): shard id. Defaults to 0.
num_shards (int, optional): number of shard(s). Defaults to 1.
random_shuffle (bool, optional): whether shuffle data during training. Defaults to True.
ext_src (optional): custom external source. Defaults to None.
"""
def __init__(self, def __init__(self,
file_root, device: str,
file_list, batch_size: int,
batch_size, py_num_workers: int,
resize_shorter, num_threads: int,
crop, device_id: int,
interp, seed: int,
mean, file_root: str,
std, file_list: str,
device_id, transform_list: List[Callable],
shard_id=0, shard_id: int=0,
num_shards=1, num_shards: int=1,
random_shuffle=False, random_shuffle: bool=True,
num_threads=4, ext_src=None):
seed=42, super(HybridPipeline, self).__init__(
pad_output=False, batch_size=batch_size,
output_dtype=types.FLOAT): device_id=device_id,
super(HybridValPipe, self).__init__( seed=seed,
batch_size, num_threads, device_id, seed=seed) py_start_method="fork" if ext_src is None else "spawn",
self.input = ops.readers.File( py_num_workers=py_num_workers,
num_threads=num_threads)
self.device = device
self.ext_src = ext_src
if ext_src is None:
self.reader = ops.readers.File(
file_root=file_root, file_root=file_root,
file_list=file_list, file_list=file_list,
shard_id=shard_id, shard_id=shard_id,
num_shards=num_shards, num_shards=num_shards,
random_shuffle=random_shuffle) random_shuffle=random_shuffle)
self.decode = ops.decoders.Image(device="mixed") self.transforms = ops.Compose(transform_list)
self.res = ops.Resize( self.cast = ops.Cast(dtype=types.DALIDataType.INT64, device=device)
device="gpu", resize_shorter=resize_shorter, interp_type=interp)
self.cmnp = ops.CropMirrorNormalize(
device="gpu",
dtype=output_dtype,
output_layout='CHW',
crop=(crop, crop),
mean=mean,
std=std,
pad_output=pad_output)
self.to_int64 = ops.Cast(dtype=types.DALIDataType.INT64, device="gpu")
def define_graph(self): def define_graph(self):
jpegs, labels = self.input(name="Reader") if self.ext_src:
images = self.decode(jpegs) raw_images, labels = fn.external_source(
images = self.res(images) source=self.ext_src,
output = self.cmnp(images) num_outputs=2,
return [output, self.to_int64(labels.gpu())] dtype=[types.DALIDataType.UINT8, types.DALIDataType.INT64],
batch=True,
parallel=True)
else:
raw_images, labels = self.reader(name="Reader")
images = self.transforms(raw_images)
return [
images, self.cast(labels.gpu() if self.device == "gpu" else labels)
]
def __len__(self): def __len__(self):
return self.epoch_size("Reader") if self.ext_src is not None:
return len(self.ext_src)
return self.epoch_size(name="Reader")
class DALIImageNetIterator(DALIGenericIterator): class DALIImageNetIterator(DALIGenericIterator):
...@@ -158,180 +667,128 @@ class DALIImageNetIterator(DALIGenericIterator): ...@@ -158,180 +667,128 @@ class DALIImageNetIterator(DALIGenericIterator):
return data_batch return data_batch
def dali_dataloader(config, mode, device, num_threads=4, seed=None): def dali_dataloader(config: Dict[str, Any],
assert "gpu" in device, "gpu training is required for DALI" mode: str,
device_id = int(device.split(':')[1]) device: str,
config_dataloader = config[mode] py_num_workers: int=1,
seed = 42 if seed is None else seed num_threads: int=4,
ops = [ seed: Optional[int]=None,
list(x.keys())[0] enable_fuse: bool=True) -> DALIImageNetIterator:
for x in config_dataloader["dataset"]["transform_ops"] """build and return HybridPipeline
]
support_ops_train = [
"DecodeImage", "NormalizeImage", "RandFlipImage", "RandCropImage"
]
support_ops_eval = [
"DecodeImage", "ResizeImage", "CropImage", "NormalizeImage"
]
if mode.lower() == 'train': Args:
assert set(ops) == set( config (Dict[str, Any]): train/eval dataloader configuration
support_ops_train mode (str): mode
), "The supported trasform_ops for train_dataset in dali is : {}".format( device (str): device string
",".join(support_ops_train)) py_num_workers (int, optional): number of python worker(s). Defaults to 1.
else: num_threads (int, optional): number of thread(s). Defaults to 4.
assert set(ops) == set( seed (Optional[int], optional): random seed. Defaults to None.
support_ops_eval enable_fuse (bool, optional): enable fused operator(s). Defaults to True.
), "The supported trasform_ops for eval_dataset in dali is : {}".format(
",".join(support_ops_eval))
normalize_ops = [
op for op in config_dataloader["dataset"]["transform_ops"]
if "NormalizeImage" in op
][0]["NormalizeImage"]
channel_num = normalize_ops.get("channel_num", 3)
output_dtype = types.FLOAT16 if normalize_ops.get("output_fp16",
False) else types.FLOAT
Returns:
DALIImageNetIterator: Iterable DALI dataloader
"""
assert "gpu" in device, f"device must be \"gpu\" when running with DALI, but got {device}"
config_dataloader = config[mode]
device_id = int(device.split(":")[1])
device = "gpu"
seed = 42 if seed is None else seed
env = os.environ env = os.environ
# assert float(env.get('FLAGS_fraction_of_gpu_memory_to_use', 0.92)) < 0.9, \ num_gpus = paddle.distributed.get_world_size()
# "Please leave enough GPU memory for DALI workspace, e.g., by setting" \
# " `export FLAGS_fraction_of_gpu_memory_to_use=0.8`"
gpu_num = paddle.distributed.get_world_size()
batch_size = config_dataloader["sampler"]["batch_size"] batch_size = config_dataloader["sampler"]["batch_size"]
file_root = config_dataloader["dataset"]["image_root"] file_root = config_dataloader["dataset"]["image_root"]
file_list = config_dataloader["dataset"]["cls_label_path"] file_list = config_dataloader["dataset"]["cls_label_path"]
interp = 1 # settings.interpolation or 1 # default to linear
interp_map = {
0: types.DALIInterpType.INTERP_NN, # cv2.INTER_NEAREST
1: types.DALIInterpType.INTERP_LINEAR, # cv2.INTER_LINEAR
2: types.DALIInterpType.INTERP_CUBIC, # cv2.INTER_CUBIC
3: types.DALIInterpType.
INTERP_LANCZOS3, # XXX use LANCZOS3 for cv2.INTER_LANCZOS4
}
assert interp in interp_map, "interpolation method not supported by DALI"
interp = interp_map[interp]
pad_output = channel_num == 4
transforms = {
k: v
for d in config_dataloader["dataset"]["transform_ops"]
for k, v in d.items()
}
scale = transforms["NormalizeImage"].get("scale", 1.0 / 255)
scale = eval(scale) if isinstance(scale, str) else scale
mean = transforms["NormalizeImage"].get("mean", [0.485, 0.456, 0.406])
std = transforms["NormalizeImage"].get("std", [0.229, 0.224, 0.225])
mean = [v / scale for v in mean]
std = [v / scale for v in std]
sampler_name = config_dataloader["sampler"].get("name", sampler_name = config_dataloader["sampler"].get("name",
"DistributedBatchSampler") "DistributedBatchSampler")
assert sampler_name in ["DistributedBatchSampler", "BatchSampler"] transform_ops_cfg = config_dataloader["dataset"]["transform_ops"]
random_shuffle = config_dataloader["sampler"].get("shuffle", None)
if mode.lower() == "train": dali_transforms = build_dali_transforms(
resize_shorter = 256 transform_ops_cfg, mode, device, enable_fuse=enable_fuse)
crop = transforms["RandCropImage"]["size"] if "ToCHWImage" not in [type_name(op) for op in dali_transforms] and (
scale = transforms["RandCropImage"].get("scale", [0.08, 1.]) "CropMirrorNormalize" not in
ratio = transforms["RandCropImage"].get("ratio", [3.0 / 4, 4.0 / 3]) [type_name(op) for op in dali_transforms]):
min_area = scale[0] dali_transforms.append(ToCHWImage(perm=[2, 0, 1], device=device))
lower = ratio[0] logger.info(
upper = ratio[1] "Append DALI operator \"ToCHWImage\" at the end of dali_transforms for getting output in \"CHW\" shape"
)
if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and 'FLAGS_selected_gpus' in env:
shard_id = int(env['PADDLE_TRAINER_ID']) if mode.lower() in ["train"]:
num_shards = int(env['PADDLE_TRAINERS_NUM']) if "PADDLE_TRAINER_ID" in env and "PADDLE_TRAINERS_NUM" in env and "FLAGS_selected_gpus" in env:
device_id = int(env['FLAGS_selected_gpus']) shard_id = int(env["PADDLE_TRAINER_ID"])
pipe = HybridTrainPipe( num_shards = int(env["PADDLE_TRAINERS_NUM"])
file_root, device_id = int(env["FLAGS_selected_gpus"])
file_list,
batch_size,
resize_shorter,
crop,
min_area,
lower,
upper,
interp,
mean,
std,
device_id,
shard_id,
num_shards,
num_threads=num_threads,
seed=seed + shard_id,
pad_output=pad_output,
output_dtype=output_dtype)
pipe.build()
pipelines = [pipe]
# sample_per_shard = len(pipe) // num_shards
else: else:
pipe = HybridTrainPipe( shard_id = 0
file_root, num_shards = 1
file_list, logger.info(
batch_size, f"Building DALI {mode} pipeline with num_shards: {num_shards}, num_gpus: {num_gpus}"
resize_shorter, )
crop,
min_area, random_shuffle = random_shuffle if random_shuffle is not None else True
lower, if sampler_name in ["PKSampler", "DistributedRandomIdentitySampler"]:
upper, ext_src = ExternalSource_RandomIdentity(
interp, batch_size=batch_size,
mean, sample_per_id=config_dataloader["sampler"][
std, "sample_per_id"
if sampler_name == "PKSampler" else "num_instances"],
device_id=device_id, device_id=device_id,
shard_id=0, shard_id=shard_id,
num_shards=1, num_gpus=num_gpus,
num_threads=num_threads, image_root=file_root,
seed=seed, cls_label_path=file_list,
pad_output=pad_output, delimiter=None,
output_dtype=output_dtype) relabel=config_dataloader["dataset"].get("relabel", False),
sample_method=config_dataloader["sampler"].get(
"sample_method", "sample_avg_prob"),
id_list=config_dataloader["sampler"].get("id_list", None),
ratio=config_dataloader["sampler"].get("ratio", None),
shuffle=random_shuffle)
logger.info(
f"Building DALI {mode} pipeline with ext_src({type_name(ext_src)})"
)
else:
ext_src = None
pipe = HybridPipeline(device, batch_size, py_num_workers, num_threads,
device_id, seed + shard_id, file_root, file_list,
dali_transforms, shard_id, num_shards,
random_shuffle, ext_src)
pipe.build() pipe.build()
pipelines = [pipe] pipelines = [pipe]
# sample_per_shard = len(pipelines[0]) if ext_src is None:
return DALIImageNetIterator( return DALIImageNetIterator(
pipelines, ['data', 'label'], reader_name='Reader') pipelines, ["data", "label"], reader_name="Reader")
else: else:
resize_shorter = transforms["ResizeImage"].get("resize_short", 256) return DALIImageNetIterator(
crop = transforms["CropImage"]["size"] pipelines,
if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and 'FLAGS_selected_gpus' in env and sampler_name == "DistributedBatchSampler": ["data", "label"],
shard_id = int(env['PADDLE_TRAINER_ID']) size=len(ext_src),
num_shards = int(env['PADDLE_TRAINERS_NUM']) last_batch_policy=LastBatchPolicy.
device_id = int(env['FLAGS_selected_gpus']) DROP # make reset() successfully
)
pipe = HybridValPipe( elif mode.lower() in ["eval", "gallery", "query"]:
file_root, assert sampler_name in ["DistributedBatchSampler"], \
file_list, f"sampler_name({sampler_name}) must in [\"DistributedBatchSampler\"]"
batch_size, if "PADDLE_TRAINER_ID" in env and "PADDLE_TRAINERS_NUM" in env and "FLAGS_selected_gpus" in env:
resize_shorter, shard_id = int(env["PADDLE_TRAINER_ID"])
crop, num_shards = int(env["PADDLE_TRAINERS_NUM"])
interp, device_id = int(env["FLAGS_selected_gpus"])
mean,
std,
device_id=device_id,
shard_id=shard_id,
num_shards=num_shards,
num_threads=num_threads,
pad_output=pad_output,
output_dtype=output_dtype)
else: else:
pipe = HybridValPipe( shard_id = 0
file_root, num_shards = 1
file_list, logger.info(
batch_size, f"Building DALI {mode} pipeline with num_shards: {num_shards}, num_gpus: {num_gpus}..."
resize_shorter, )
crop,
interp, random_shuffle = random_shuffle if random_shuffle is not None else False
mean, pipe = HybridPipeline(device, batch_size, py_num_workers, num_threads,
std, device_id, seed + shard_id, file_root, file_list,
device_id=device_id, dali_transforms, shard_id, num_shards,
num_threads=num_threads, random_shuffle)
pad_output=pad_output,
output_dtype=output_dtype)
pipe.build() pipe.build()
pipelines = [pipe]
return DALIImageNetIterator( return DALIImageNetIterator(
[pipe], ['data', 'label'], reader_name="Reader") pipelines, ["data", "label"], reader_name="Reader")
else:
raise ValueError(f"Invalid mode({mode}) when building DALI pipeline")
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import nvidia.dali.fn as fn
import nvidia.dali.ops as ops
import nvidia.dali.types as types
class DecodeImage(ops.decoders.Image):
def __init__(self, *kargs, device="cpu", **kwargs):
super(DecodeImage, self).__init__(*kargs, device=device, **kwargs)
def __call__(self, data, **kwargs):
return super(DecodeImage, self).__call__(data, **kwargs)
class ToCHWImage(ops.Transpose):
def __init__(self, *kargs, device="cpu", **kwargs):
super(ToCHWImage, self).__init__(*kargs, device=device, **kwargs)
def __call__(self, data, **kwargs):
return super(ToCHWImage, self).__call__(data, **kwargs)
class ColorJitter(ops.ColorTwist):
def __init__(self,
*kargs,
device="cpu",
prob=1.0,
brightness_factor=0.0,
contrast_factor=0.0,
saturation_factor=0.0,
hue_factor=0.0,
**kwargs):
super(ColorJitter, self).__init__(*kargs, device=device, **kwargs)
self.brightness_factor = brightness_factor
self.contrast_factor = contrast_factor
self.saturation_factor = saturation_factor
self.hue_factor = hue_factor
self.rng = ops.random.CoinFlip(probability=prob)
def __call__(self, data, **kwargs):
do_jitter = self.rng()
brightness = fn.random.uniform(
range=(max(0, 1 - self.brightness_factor),
1 + self.brightness_factor)) * do_jitter
contrast = fn.random.uniform(
range=(max(0, 1 - self.contrast_factor),
1 + self.contrast_factor)) * do_jitter
saturation = fn.random.uniform(
range=(max(0, 1 - self.saturation_factor),
1 + self.saturation_factor)) * do_jitter
hue = fn.random.uniform(range=(-self.hue_factor,
self.hue_factor)) * do_jitter
return super(ColorJitter, self).__call__(
data,
brightness=brightness,
contrast=contrast,
saturation=saturation,
hue=hue,
**kwargs)
class DecodeRandomResizedCrop(ops.decoders.ImageRandomCrop):
def __init__(self,
*kargs,
device="cpu",
resize_x=224,
resize_y=224,
resize_short=None,
interp_type=types.DALIInterpType.INTERP_LINEAR,
**kwargs):
super(DecodeRandomResizedCrop, self).__init__(
*kargs, device=device, **kwargs)
if resize_short is None:
self.resize = ops.Resize(
device="gpu" if device == "mixed" else "cpu",
resize_x=resize_x,
resize_y=resize_y,
interp_type=interp_type)
else:
self.resize = ops.Resize(
device="gpu" if device == "mixed" else "cpu",
resize_short=resize_short,
interp_type=interp_type)
def __call__(self, data, **kwargs):
data = super(DecodeRandomResizedCrop, self).__call__(data, **kwargs)
data = self.resize(data)
return data
class CropMirrorNormalize(ops.CropMirrorNormalize):
def __init__(self, *kargs, device="cpu", prob=0.5, **kwargs):
super(CropMirrorNormalize, self).__init__(
*kargs, device=device, **kwargs)
self.rng = ops.random.CoinFlip(probability=prob)
def __call__(self, data, **kwargs):
do_mirror = self.rng()
return super(CropMirrorNormalize, self).__call__(
data, mirror=do_mirror, **kwargs)
class RandCropImage(ops.RandomResizedCrop):
def __init__(self, *kargs, device="cpu", **kwargs):
super(RandCropImage, self).__init__(*kargs, device=device, **kwargs)
def __call__(self, data, **kwargs):
return super(RandCropImage, self).__call__(data, **kwargs)
class CropImage(ops.Crop):
def __init__(self, *kargs, device="cpu", **kwargs):
super(CropImage, self).__init__(*kargs, device=device, **kwargs)
def __call__(self, data, **kwargs):
return super(CropImage, self).__call__(data, **kwargs)
class ResizeImage(ops.Resize):
def __init__(self, *kargs, device="cpu", **kwargs):
super(ResizeImage, self).__init__(*kargs, device=device, **kwargs)
def __call__(self, data, **kwargs):
return super(ResizeImage, self).__call__(data, **kwargs)
class RandFlipImage(ops.Flip):
def __init__(self, *kargs, device="cpu", prob=0.5, flip_code=1, **kwargs):
super(RandFlipImage, self).__init__(*kargs, device=device, **kwargs)
self.flip_code = flip_code
self.rng = ops.random.CoinFlip(probability=prob)
def __call__(self, data, **kwargs):
do_flip = self.rng()
if self.flip_code == 1:
return super(RandFlipImage, self).__call__(
data, horizontal=do_flip, vertical=0, **kwargs)
elif self.flip_code == 0:
return super(RandFlipImage, self).__call__(
data, horizontal=0, vertical=do_flip, **kwargs)
else:
return super(RandFlipImage, self).__call__(
data, horizontal=do_flip, vertical=do_flip, **kwargs)
class Pad(ops.Crop):
"""
use ops.Crop to implement Pad operator, for ops.Pad alwayls only pad in right and bottom.
"""
def __init__(self, *kargs, device="cpu", **kwargs):
super(Pad, self).__init__(*kargs, device=device, **kwargs)
def __call__(self, data, **kwargs):
return super(Pad, self).__call__(data, **kwargs)
class RandCropImageV2(ops.Crop):
def __init__(self, *kargs, device="cpu", **kwargs):
super(RandCropImageV2, self).__init__(*kargs, device=device, **kwargs)
self.rng_x = ops.random.Uniform(range=(0.0, 1.0))
self.rng_y = ops.random.Uniform(range=(0.0, 1.0))
def __call__(self, data, **kwargs):
pos_x = self.rng_x()
pos_y = self.rng_y()
return super(RandCropImageV2, self).__call__(
data, crop_pos_x=pos_x, crop_pos_y=pos_y, **kwargs)
class RandomCropImage(ops.Crop):
def __init__(self, *kargs, device="cpu", **kwargs):
super(RandomCropImage, self).__init__(*kargs, device=device, **kwargs)
self.rng_x = ops.random.Uniform(range=(0.0, 1.0))
self.rng_y = ops.random.Uniform(range=(0.0, 1.0))
def __call__(self, data, **kwargs):
pos_x = self.rng_x()
pos_y = self.rng_y()
return super(RandomCropImage, self).__call__(
data, crop_pos_x=pos_x, crop_pos_y=pos_y, **kwargs)
class RandomRotation(ops.Rotate):
def __init__(self, *kargs, device="cpu", prob=0.5, angle=0, **kwargs):
super(RandomRotation, self).__init__(*kargs, device=device, **kwargs)
self.rng = ops.random.CoinFlip(probability=prob)
discrete_angle = list(range(-angle, angle + 1))
self.rng_angle = ops.random.Uniform(values=discrete_angle)
def __call__(self, data, **kwargs):
do_rotate = self.rng()
angle = self.rng_angle()
flip_data = super(RandomRotation, self).__call__(
data,
angle=do_rotate * angle,
keep_size=True,
fill_value=0,
**kwargs)
return flip_data
class RandomRot90(ops.Rotate):
def __init__(self, *kargs, device="cpu", **kwargs):
super(RandomRot90, self).__init__(*kargs, device=device, **kwargs)
self.rng_angle = ops.random.Uniform(values=[0.0, 1.0, 2.0, 3.0])
def __call__(self, data, **kwargs):
angle = self.rng_angle() * 90.0
flip_data = super(RandomRot90, self).__call__(
data, angle=angle, keep_size=True, fill_value=0, **kwargs)
return flip_data
class NormalizeImage(ops.Normalize):
def __init__(self, *kargs, device="cpu", **kwargs):
super(NormalizeImage, self).__init__(*kargs, device=device, **kwargs)
def __call__(self, data, **kwargs):
return super(NormalizeImage, self).__call__(data, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册