From 02db4c3b473c504eb6f4f6fa013cfd3659566195 Mon Sep 17 00:00:00 2001 From: lilong12 Date: Mon, 10 Jun 2019 13:43:24 +0800 Subject: [PATCH] Remove PyTorch APIs (#2280) * add transforms.py and rename torchvision_reader.py to reader.py * add datasets.py * remove pytorch apis * fix some small bugs * remove torch and torchvision from requirements.txt * modify core.CUDAPlace to fluid.CUDAPlace --- .../fast_imagenet/datasets.py | 198 +++++++++++++++ .../{torchvision_reader.py => reader.py} | 22 +- .../fast_imagenet/requirements.txt | 2 - .../fast_imagenet/train.py | 24 +- .../fast_imagenet/transforms.py | 239 ++++++++++++++++++ 5 files changed, 463 insertions(+), 22 deletions(-) create mode 100644 PaddleCV/image_classification/fast_imagenet/datasets.py rename PaddleCV/image_classification/fast_imagenet/{torchvision_reader.py => reader.py} (92%) create mode 100644 PaddleCV/image_classification/fast_imagenet/transforms.py diff --git a/PaddleCV/image_classification/fast_imagenet/datasets.py b/PaddleCV/image_classification/fast_imagenet/datasets.py new file mode 100644 index 00000000..866e58db --- /dev/null +++ b/PaddleCV/image_classification/fast_imagenet/datasets.py @@ -0,0 +1,198 @@ +from PIL import Image + +import os +import os.path +import sys + +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff') + + +def has_file_allowed_extension(filename, extensions): + """Checks if a file is an allowed extension. + + Args: + filename (string): path to a file + extensions (tuple of strings): extensions to consider (lowercase) + + Returns: + bool: True if the filename ends with one of given extensions + """ + return filename.lower().endswith(extensions) + + +def is_image_file(filename): + """Checks if a file is an allowed image extension. + + Args: + filename (string): path to a file + + Returns: + bool: True if the filename ends with a known image extension + """ + return has_file_allowed_extension(filename, IMG_EXTENSIONS) + + +def make_dataset(dir, class_to_idx, extensions): + images = [] + dir = os.path.expanduser(dir) + for target in sorted(class_to_idx.keys()): + d = os.path.join(dir, target) + if not os.path.isdir(d): + continue + + for root, _, fnames in sorted(os.walk(d)): + for fname in sorted(fnames): + if has_file_allowed_extension(fname, extensions): + path = os.path.join(root, fname) + item = (path, class_to_idx[target]) + images.append(item) + + return images + + +class DatasetFolder(object): + """A generic data loader where the samples are arranged in this way: :: + + root/class_x/xxx.ext + root/class_x/xxy.ext + root/class_x/xxz.ext + + root/class_y/123.ext + root/class_y/nsdf3.ext + root/class_y/asd932_.ext + + Args: + root (string): Root directory path. + loader (callable): A function to load a sample given its path. + extensions (tuple[string]): A list of allowed extensions. + transform (callable, optional): A function/transform that takes in + a sample and returns a transformed version. + E.g, ``transforms.RandomCrop`` for images. + target_transform (callable, optional): A function/transform that takes + in the target and transforms it. + + Attributes: + classes (list): List of the class names. + class_to_idx (dict): Dict with items (class_name, class_index). + samples (list): List of (sample path, class_index) tuples + targets (list): The class_index value for each image in the dataset + """ + + def __init__(self, + root, + loader, + extensions, + transform=None, + target_transform=None): + self.root = root + self.transform = transform + self.target_transform = target_transform + classes, class_to_idx = self._find_classes(self.root) + samples = make_dataset(self.root, class_to_idx, extensions) + if len(samples) == 0: + raise (RuntimeError( + "Found 0 files in subfolders of: " + self.root + "\n" + "Supported extensions are: " + ",".join(extensions))) + + self.loader = loader + self.extensions = extensions + + self.classes = classes + self.class_to_idx = class_to_idx + self.samples = samples + self.targets = [s[1] for s in samples] + + def _find_classes(self, dir): + """ + Finds the class folders in a dataset. + + Args: + dir (string): Root directory path. + + Returns: + tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. + + Ensures: + No class is a subdirectory of another. + """ + if sys.version_info >= (3, 5): + # Faster and available in Python 3.5 and above + classes = [d.name for d in os.scandir(dir) if d.is_dir()] + else: + classes = [ + d for d in os.listdir(dir) + if os.path.isdir(os.path.join(dir, d)) + ] + classes.sort() + class_to_idx = {classes[i]: i for i in range(len(classes))} + return classes, class_to_idx + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (sample, target) where target is class_index of the target class. + """ + path, target = self.samples[index] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target + + def __len__(self): + return len(self.samples) + + +def pil_loader(path): + # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + + +def default_loader(path): + return pil_loader(path) + + +class ImageFolder(DatasetFolder): + """A generic data loader where the images are arranged in this way: :: + + root/dog/xxx.png + root/dog/xxy.png + root/dog/xxz.png + + root/cat/123.png + root/cat/nsdf3.png + root/cat/asd932_.png + + Args: + root (string): Root directory path. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + loader (callable, optional): A function to load an image given its path. + + Attributes: + classes (list): List of the class names. + class_to_idx (dict): Dict with items (class_name, class_index). + imgs (list): List of (image path, class_index) tuples + """ + + def __init__(self, + root, + transform=None, + target_transform=None, + loader=default_loader): + super(ImageFolder, self).__init__( + root, + loader, + IMG_EXTENSIONS, + transform=transform, + target_transform=target_transform) + self.imgs = self.samples diff --git a/PaddleCV/image_classification/fast_imagenet/torchvision_reader.py b/PaddleCV/image_classification/fast_imagenet/reader.py similarity index 92% rename from PaddleCV/image_classification/fast_imagenet/torchvision_reader.py rename to PaddleCV/image_classification/fast_imagenet/reader.py index 20500816..9d277322 100644 --- a/PaddleCV/image_classification/fast_imagenet/torchvision_reader.py +++ b/PaddleCV/image_classification/fast_imagenet/reader.py @@ -4,31 +4,27 @@ import os import numpy as np import math import random -import torch -import torch.utils.data -from torch.utils.data.distributed import DistributedSampler -import torchvision.transforms as transforms -import torchvision.datasets as datasets - -from torch.utils.data.sampler import Sampler -import torchvision + import pickle from tqdm import tqdm import time import multiprocessing +import transforms +import datasets + FINISH_EVENT = "FINISH_EVENT" class PaddleDataLoader(object): def __init__(self, - torch_dataset, + dataset, indices=None, concurrent=24, queue_size=3072, shuffle=True, shuffle_seed=0): - self.torch_dataset = torch_dataset + self.dataset = dataset self.indices = indices self.concurrent = concurrent self.shuffle = shuffle @@ -39,7 +35,7 @@ class PaddleDataLoader(object): cnt = 0 for idx in worker_indices: cnt += 1 - img, label = self.torch_dataset[idx] + img, label = self.dataset[idx] img = np.array(img).astype('uint8').transpose((2, 0, 1)) queue.put((img, label)) print("worker: [%d] read [%d] samples. " % (worker_id, cnt)) @@ -49,7 +45,7 @@ class PaddleDataLoader(object): def _reader_creator(): worker_processes = [] index_queues = [] - total_img = len(self.torch_dataset) + total_img = len(self.dataset) print("total image: ", total_img) if self.shuffle: self.indices = [i for i in xrange(total_img)] @@ -146,7 +142,7 @@ class CropArTfm(object): else: h = int(self.target_size * target_ar) size = (self.target_size, h // 8 * 8) - return torchvision.transforms.functional.center_crop(img, size) + return transforms.center_crop(img, size) def sort_ar(valdir): diff --git a/PaddleCV/image_classification/fast_imagenet/requirements.txt b/PaddleCV/image_classification/fast_imagenet/requirements.txt index 5e13381c..78620c47 100644 --- a/PaddleCV/image_classification/fast_imagenet/requirements.txt +++ b/PaddleCV/image_classification/fast_imagenet/requirements.txt @@ -1,3 +1 @@ -torch==0.4.1 -torchvision tqdm diff --git a/PaddleCV/image_classification/fast_imagenet/train.py b/PaddleCV/image_classification/fast_imagenet/train.py index 33300f10..95c8782d 100644 --- a/PaddleCV/image_classification/fast_imagenet/train.py +++ b/PaddleCV/image_classification/fast_imagenet/train.py @@ -19,8 +19,8 @@ import os import traceback import numpy as np -import torchvision_reader -import torch +import math +import reader import paddle import paddle.fluid as fluid import paddle.fluid.profiler as profiler @@ -235,9 +235,19 @@ def refresh_program(args, if var.name.startswith('conv2d_') ] for var in conv2d_w_vars: - torch_w = torch.empty(var.shape) - kaiming_np = torch.nn.init.kaiming_normal_( - torch_w, mode='fan_out', nonlinearity='relu').numpy() + #torch_w = torch.empty(var.shape) + #kaiming_np = torch.nn.init.kaiming_normal_(torch_w, mode='fan_out', nonlinearity='relu').numpy() + shape = var.shape + if not shape or len(shape) == 0: + fan_out = 1 + elif len(shape) == 1: + fan_out = shape[0] + elif len(shape) == 2: + fan_out = shape[1] + else: + fan_out = shape[0] * np.prod(shape[2:]) + std = math.sqrt(2.0 / fan_out) + kaiming_np = np.random.normal(0, std, var.shape) tensor = fluid.global_scope().find_var(var.name).get_tensor() if args.fp16: tensor.set(np.array( @@ -283,7 +293,7 @@ def refresh_program(args, def prepare_reader(epoch_id, train_py_reader, train_bs, val_bs, trn_dir, img_dim, min_scale, rect_val): - train_reader = torchvision_reader.train( + train_reader = reader.train( traindir="/data/imagenet/%strain" % trn_dir, sz=img_dim, min_scale=min_scale, @@ -292,7 +302,7 @@ def prepare_reader(epoch_id, train_py_reader, train_bs, val_bs, trn_dir, paddle.batch( train_reader, batch_size=train_bs)) - test_reader = torchvision_reader.test( + test_reader = reader.test( valdir="/data/imagenet/%svalidation" % trn_dir, bs=val_bs * DEVICE_NUM, sz=img_dim, diff --git a/PaddleCV/image_classification/fast_imagenet/transforms.py b/PaddleCV/image_classification/fast_imagenet/transforms.py new file mode 100644 index 00000000..ad9707ba --- /dev/null +++ b/PaddleCV/image_classification/fast_imagenet/transforms.py @@ -0,0 +1,239 @@ +from __future__ import division +import math +import random +from PIL import Image +import numpy as np +import warnings + +__all__ = [ + "Compose", "Resize", "Scale", "RandomHorizontalFlip", "RandomResizedCrop", + "CenterCrop" +] + + +def _is_pil_image(img): + return isinstance(img, Image.Image) + + +def crop(img, i, j, h, w): + if not _is_pil_image(img): + raise TypeError('img should be a PIL Image, but be {}'.format( + type(img))) + return img.crop((j, i, j + w, i + h)) + + +def resize(img, size, interpolation=Image.BILINEAR): + if not _is_pil_image(img): + raise TypeError('img should be a PIL Image, but be {}'.format( + type(img))) + if not (isinstance(size, int) or + (isinstance(size, tuple) and len(size) == 2)): + raise TypeError('Wrong size arg: {}'.format(size)) + if isinstance(size, int): + w, h = img.size + if (w <= h and w == size) or (h <= w and h == size): + return img + if w < h: + ow = size + oh = int(size * h / w) + return img.resize((ow, oh), interpolation) + else: + oh = size + ow = int(size * w / h) + return img.resize((ow, oh), interpolation) + else: + return img.resize(size[::-1], interpolation) + + +def center_crop(img, output_size): + if isinstance(output_size, int): + output_size = (output_size, output_size) + w, h = img.size + th, tw = output_size + i = int(round((h - th) / 2.)) + j = int(round((w - tw) / 2.)) + return crop(img, i, j, th, tw) + + +class Compose(object): + """Make some transforms in a chain. + + Args: + transforms (list): list of transforms to be in a chain. + """ + + def __init__(self, transforms): + self._transforms = transforms + + def __call__(self, img): + for t in self._transforms: + img = t(img) + return img + + +class Resize(object): + """Resize the input PIL Image. + + Args: + size (tuple | int): Output size. If the size is a tuple, + resize image to that size (h, w). If the size is an int, + smaller edge of the image will be resized to this size. + interpolation (int): Interpolation method. + """ + + def __init__(self, size, interpolation=Image.BILINEAR): + assert isinstance(size, int) or (isinstance(size, tuple) and + len(size) == 2) + self._size = size + self._interpolation = interpolation + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be resized. + + Returns: + PIL Image: Resized image. + """ + return resize(img, self._size, self._interpolation) + + +class Scale(Resize): + """ + Note: This transform is deprecated in favor of Resize. + """ + + def __init__(self, *args, **kwargs): + warnings.warn( + "The use of the transforms.Scale transform is deprecated, " + + "please use transforms.Resize instead.") + super(Scale, self).__init__(*args, **kwargs) + + +class RandomHorizontalFlip(object): + """Horizontally flip the given PIL Image randomly with a given probability. + + Args: + p (float): probability of the image being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Randomly flipped image. + """ + if random.random() < self.p: + if not _is_pil_image(img): + raise TypeError('img should be a PIL image, but be {}'.format( + type(img))) + return img.transpose(Image.FLIP_LEFT_RIGHT) + return img + + +class RandomResizedCrop(object): + """Crop the input PIL Image to random size and aspect ratio, and then + resize the PIL Image to target size. + + Args: + size: target size + scale: range of ratio of the origin size to be cropped + ratio: range of aspect ratio of the origin aspect ratio to be cropped + interpolation: interpolation method + """ + + def __init__(self, + size, + scale=(0.08, 1.0), + ratio=(3. / 4., 4. / 3.), + interpolation=Image.BILINEAR): + if isinstance(size, tuple): + self._size = size + else: + self._size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + raise ErrorValue("range should be of kind of (min, max)") + + self._interpolation = interpolation + self._scale = scale + self._ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + area = img.size[0] * img.size[1] + + for attempt in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w <= img.size[0] and h <= img.size[1]: + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if (in_ratio < min(ratio)): + w = img.size[0] + h = w / min(ratio) + elif (in_ratio > max(ratio)): + h = img.size[1] + w = h * max(ratio) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 + j = (img.size[0] - w) // 2 + return i, j, h, w + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + + Returns: + PIL Image: Randomly cropped and resized image. + """ + i, j, h, w = self.get_params(img, self._scale, self._ratio) + assert _is_pil_image(img), 'image should be a PIL Image' + img = crop(img, i, j, h, w) + img = resize(img, self._size, self._interpolation) + return img + + +class CenterCrop(object): + """Crops the given PIL Image at the center. + + Args: + size (tuple|int): Output size. If size is an int instead of a tuple + like (h, w), a square crop (size, size) is made. + """ + + def __init__(self, size): + if isinstance(size, int): + self.size = (size, size) + else: + self.size = size + + def __call__(self, img): + return center_crop(img, self.size) -- GitLab