dali.py 11.9 KB
Newer Older
T
Tingquan Gao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division

import os

import numpy as np
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.plugin.paddle import DALIGenericIterator

import paddle
from paddle import fluid


class HybridTrainPipe(Pipeline):
    """DALI training pipeline: random crop on decode, resize, flip, normalize.

    Samples are read with ``FileReader`` (sharded by ``shard_id`` /
    ``num_shards``). Each image is decoded on the ``mixed`` device with a
    random crop whose aspect ratio lies in ``[lower, upper]`` and whose area
    fraction lies in ``[min_area, 1.0]``. The crop is resized to
    ``crop x crop`` and passed through ``CropMirrorNormalize`` with
    ``mean``/``std``. A 50% coin flip per sample drives the ``mirror``
    argument. Labels are cast to INT64 on the GPU.

    ``resize_shorter`` appears in the signature but is not used by this
    pipeline.
    """

    def __init__(self,
                 file_root,
                 file_list,
                 batch_size,
                 resize_shorter,
                 crop,
                 min_area,
                 lower,
                 upper,
                 interp,
                 mean,
                 std,
                 device_id,
                 shard_id=0,
                 num_shards=1,
                 random_shuffle=True,
                 num_threads=4,
                 seed=42,
                 pad_output=False,
                 output_dtype=types.FLOAT):
        super(HybridTrainPipe, self).__init__(
            batch_size, num_threads, device_id, seed=seed)

        reader_kwargs = dict(
            file_root=file_root,
            file_list=file_list,
            shard_id=shard_id,
            num_shards=num_shards,
            random_shuffle=random_shuffle)
        self.input = ops.FileReader(**reader_kwargs)

        # Pre-size nvJPEG's internal buffers so full-sized ImageNet images
        # decode without additional reallocations.
        nvjpeg_device_padding = 211025920
        nvjpeg_host_padding = 140544512
        self.decode = ops.ImageDecoderRandomCrop(
            device='mixed',
            output_type=types.RGB,
            device_memory_padding=nvjpeg_device_padding,
            host_memory_padding=nvjpeg_host_padding,
            random_aspect_ratio=[lower, upper],
            random_area=[min_area, 1.0],
            num_attempts=100)

        self.res = ops.Resize(
            device='gpu', resize_x=crop, resize_y=crop, interp_type=interp)

        normalize_kwargs = dict(
            device="gpu",
            output_dtype=output_dtype,
            output_layout=types.NCHW,
            crop=(crop, crop),
            image_type=types.RGB,
            mean=mean,
            std=std,
            pad_output=pad_output)
        self.cmnp = ops.CropMirrorNormalize(**normalize_kwargs)

        self.coin = ops.CoinFlip(probability=0.5)
        self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu")

    def define_graph(self):
        # Call order matches the original pipeline graph definition.
        mirror_flags = self.coin()
        encoded, labels = self.input(name="Reader")
        decoded = self.decode(encoded)
        resized = self.res(decoded)
        normalized = self.cmnp(resized.gpu(), mirror=mirror_flags)
        return [normalized, self.to_int64(labels.gpu())]

    def __len__(self):
        # Number of samples the "Reader" op reports per epoch.
        return self.epoch_size("Reader")


class HybridValPipe(Pipeline):
    """DALI evaluation pipeline: decode, resize shorter side, crop, normalize.

    Samples are read with ``FileReader`` (no shuffling by default) and
    decoded on the ``mixed`` device. The shorter side is resized to
    ``resize_shorter``. ``CropMirrorNormalize`` then crops to
    ``crop x crop`` (at its default crop position, without mirroring) and
    normalizes with ``mean``/``std``. Labels are cast to INT64 on the GPU.
    """

    def __init__(self,
                 file_root,
                 file_list,
                 batch_size,
                 resize_shorter,
                 crop,
                 interp,
                 mean,
                 std,
                 device_id,
                 shard_id=0,
                 num_shards=1,
                 random_shuffle=False,
                 num_threads=4,
                 seed=42,
                 pad_output=False,
                 output_dtype=types.FLOAT):
        super(HybridValPipe, self).__init__(
            batch_size, num_threads, device_id, seed=seed)

        reader_kwargs = dict(
            file_root=file_root,
            file_list=file_list,
            shard_id=shard_id,
            num_shards=num_shards,
            random_shuffle=random_shuffle)
        self.input = ops.FileReader(**reader_kwargs)

        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        self.res = ops.Resize(
            device="gpu", resize_shorter=resize_shorter, interp_type=interp)

        normalize_kwargs = dict(
            device="gpu",
            output_dtype=output_dtype,
            output_layout=types.NCHW,
            crop=(crop, crop),
            image_type=types.RGB,
            mean=mean,
            std=std,
            pad_output=pad_output)
        self.cmnp = ops.CropMirrorNormalize(**normalize_kwargs)

        self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu")

    def define_graph(self):
        encoded, labels = self.input(name="Reader")
        decoded = self.decode(encoded)
        resized = self.res(decoded)
        normalized = self.cmnp(resized)
        return [normalized, self.to_int64(labels.gpu())]

    def __len__(self):
        # Number of samples the "Reader" op reports per epoch.
        return self.epoch_size("Reader")


def build(config, mode='train'):
    env = os.environ
    assert config.get('use_gpu',
                      True) == True, "gpu training is required for DALI"
    assert not config.get(
        'use_aa'), "auto augment is not supported by DALI reader"
    assert float(env.get('FLAGS_fraction_of_gpu_memory_to_use', 0.92)) < 0.9, \
        "Please leave enough GPU memory for DALI workspace, e.g., by setting" \
        " `export FLAGS_fraction_of_gpu_memory_to_use=0.8`"

    dataset_config = config[mode.upper()]

    gpu_num = paddle.fluid.core.get_cuda_device_count() if (
        'PADDLE_TRAINERS_NUM') and (
            'PADDLE_TRAINER_ID'
        ) not in env else int(env.get('PADDLE_TRAINERS_NUM', 0))

    batch_size = dataset_config.batch_size
    assert batch_size % gpu_num == 0, \
        "batch size must be multiple of number of devices"
    batch_size = batch_size // gpu_num

    file_root = dataset_config.data_dir
    file_list = dataset_config.file_list

    interp = 1  # settings.interpolation or 1  # default to linear
    interp_map = {
        0: types.INTERP_NN,  # cv2.INTER_NEAREST
        1: types.INTERP_LINEAR,  # cv2.INTER_LINEAR
        2: types.INTERP_CUBIC,  # cv2.INTER_CUBIC
        4: types.INTERP_LANCZOS3,  # XXX use LANCZOS3 for cv2.INTER_LANCZOS4
    }
H
huangxu96 已提交
179 180 181 182 183

    output_dtype = (types.FLOAT16 if 'AMP' in config and
                    config.AMP.get("use_pure_fp16", False) 
                    else types.FLOAT)
    
T
Tingquan Gao 已提交
184 185
    assert interp in interp_map, "interpolation method not supported by DALI"
    interp = interp_map[interp]
H
huangxu96 已提交
186 187 188 189
    pad_output = False
    image_shape = config.get("image_shape", None)
    if image_shape and image_shape[0] == 4:
        pad_output = True
T
Tingquan Gao 已提交
190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231

    transforms = {
        k: v
        for d in dataset_config["transforms"] for k, v in d.items()
    }

    scale = transforms["NormalizeImage"].get("scale", 1.0 / 255)
    if isinstance(scale, str):
        scale = eval(scale)
    mean = transforms["NormalizeImage"].get("mean", [0.485, 0.456, 0.406])
    std = transforms["NormalizeImage"].get("std", [0.229, 0.224, 0.225])
    mean = [v / scale for v in mean]
    std = [v / scale for v in std]

    if mode == "train":
        resize_shorter = 256
        crop = transforms["RandCropImage"]["size"]
        scale = transforms["RandCropImage"].get("scale", [0.08, 1.])
        ratio = transforms["RandCropImage"].get("ratio", [3.0 / 4, 4.0 / 3])
        min_area = scale[0]
        lower = ratio[0]
        upper = ratio[1]

        if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env:
            shard_id = int(env['PADDLE_TRAINER_ID'])
            num_shards = int(env['PADDLE_TRAINERS_NUM'])
            device_id = int(env['FLAGS_selected_gpus'])
            pipe = HybridTrainPipe(
                file_root,
                file_list,
                batch_size,
                resize_shorter,
                crop,
                min_area,
                lower,
                upper,
                interp,
                mean,
                std,
                device_id,
                shard_id,
                num_shards,
H
huangxu96 已提交
232 233 234
                seed=42 + shard_id,
                pad_output=pad_output,
                output_dtype=output_dtype)
T
Tingquan Gao 已提交
235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260
            pipe.build()
            pipelines = [pipe]
            sample_per_shard = len(pipe) // num_shards
        else:
            pipelines = []
            places = fluid.framework.cuda_places()
            num_shards = len(places)
            for idx, p in enumerate(places):
                place = fluid.core.Place()
                place.set_place(p)
                device_id = place.gpu_device_id()
                pipe = HybridTrainPipe(
                    file_root,
                    file_list,
                    batch_size,
                    resize_shorter,
                    crop,
                    min_area,
                    lower,
                    upper,
                    interp,
                    mean,
                    std,
                    device_id,
                    idx,
                    num_shards,
H
huangxu96 已提交
261 262 263
                    seed=42 + idx,
                pad_output=pad_output,
                output_dtype=output_dtype)
T
Tingquan Gao 已提交
264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285
                pipe.build()
                pipelines.append(pipe)
            sample_per_shard = len(pipelines[0])
        return DALIGenericIterator(
            pipelines, ['feed_image', 'feed_label'], size=sample_per_shard)
    else:
        resize_shorter = transforms["ResizeImage"].get("resize_short", 256)
        crop = transforms["CropImage"]["size"]

        p = fluid.framework.cuda_places()[0]
        place = fluid.core.Place()
        place.set_place(p)
        device_id = place.gpu_device_id()
        pipe = HybridValPipe(
            file_root,
            file_list,
            batch_size,
            resize_shorter,
            crop,
            interp,
            mean,
            std,
H
huangxu96 已提交
286 287 288
            device_id=device_id,
            pad_output=pad_output,
            output_dtype=output_dtype)
T
Tingquan Gao 已提交
289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361
        pipe.build()
        return DALIGenericIterator(
            pipe, ['feed_image', 'feed_label'],
            size=len(pipe),
            dynamic_shape=True,
            fill_last_batch=True,
            last_batch_padded=True)


def train(config):
    """Return the DALI iterator built from the ``TRAIN`` section of ``config``."""
    return build(config, mode='train')


def val(config):
    """Return the DALI iterator built from the ``VALID`` section of ``config``."""
    return build(config, mode='valid')


def _to_Tensor(lod_tensor, dtype):
    """Copy array-like ``lod_tensor`` into a newly created fluid tensor.

    Args:
        lod_tensor: any value accepted by ``np.array``.
        dtype (str): dtype used for both the numpy conversion and the new
            tensor.

    Returns:
        The fluid tensor variable holding the converted data.
    """
    target = fluid.layers.create_tensor(dtype=dtype)
    fluid.layers.assign(np.array(lod_tensor).astype(dtype), target)
    return target


def normalize(feeds, config):
    """Divide ``feeds['image']`` by 255 and standardize it per channel.

    The per-channel mean/std are hard-coded (the usual ImageNet RGB values).
    They are shaped ``(3, 1, 1)`` so they broadcast over a channel-first
    image. ``config`` is accepted but not used.

    Args:
        feeds (dict): must contain ``'image'`` and ``'label'``.
        config: unused.

    Returns:
        dict: the same ``feeds`` dict, with ``'image'`` replaced by the
        normalized float32 tensor (``stop_gradient`` set).
    """
    image, _label = feeds['image'], feeds['label']
    channel_shape = (3, 1, 1)
    mean_np = np.array([0.485, 0.456, 0.406]).reshape(channel_shape)
    std_np = np.array([0.229, 0.224, 0.225]).reshape(channel_shape)

    image = fluid.layers.cast(image, 'float32')
    pixel_max = fluid.layers.fill_constant(
        shape=[1], value=255.0, dtype='float32')
    image = fluid.layers.elementwise_div(image, pixel_max)

    mean_var = fluid.layers.create_tensor(dtype="float32")
    fluid.layers.assign(input=mean_np.astype("float32"), output=mean_var)
    std_var = fluid.layers.create_tensor(dtype="float32")
    fluid.layers.assign(input=std_np.astype("float32"), output=std_var)

    centered = fluid.layers.elementwise_sub(image, mean_var)
    image = fluid.layers.elementwise_div(centered, std_var)
    image.stop_gradient = True

    feeds['image'] = image
    return feeds


def mix(feeds, config, is_train=True):
    """Apply mixup to a training batch.

    Args:
        feeds (dict): must contain ``'image'`` and ``'label'`` graph
            variables.
        config: global config. ``config.TRAIN.batch_size`` is treated as
            the global batch size and divided by the device count to get
            the length of the sample permutation.
        is_train (bool): mixup is built only when true.

    Returns:
        dict: when ``is_train`` is true, a new dict with ``image`` (mixed
        images), ``feed_y_a`` (original labels), ``feed_y_b`` (labels
        gathered with the same permutation as the images) and ``feed_lam``
        (the mixing weight, repeated ``batch_size`` times). Otherwise
        ``feeds`` itself, unchanged.

    Raises:
        AssertionError: if no device is found.
    """
    if not is_train:
        # The mixed result is discarded outside training, so don't build the
        # unused ops or consume the global numpy RNG.
        return feeds

    env = os.environ
    # Trainer count in a distributed launch (both variables set), local CUDA
    # device count otherwise. (The previous inline condition tested only
    # PADDLE_TRAINER_ID due to operator precedence and could yield 0.)
    if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env:
        gpu_num = int(env['PADDLE_TRAINERS_NUM'])
    else:
        gpu_num = paddle.fluid.core.get_cuda_device_count()
    assert gpu_num > 0, "no CUDA device available for mixup"

    # NOTE(review): assumes every batch has exactly this many samples; a
    # smaller final batch would not match the permutation length -- confirm.
    batch_size = config.TRAIN.batch_size // gpu_num

    images = feeds['image']
    label = feeds['label']
    # TODO: hard code here, should be fixed!
    alpha = 0.2
    idx = _to_Tensor(np.random.permutation(batch_size), 'int32')
    lam = np.random.beta(alpha, alpha)

    images = lam * images + (1 - lam) * paddle.fluid.layers.gather(images, idx)

    return {
        'image': images,
        'feed_y_a': label,
        'feed_y_b': paddle.fluid.layers.gather(label, idx),
        'feed_lam': _to_Tensor([lam] * batch_size, 'float32')
    }