未验证 提交 02db4c3b 编写于 作者: L lilong12 提交者: GitHub

Remove PyTorch APIs (#2280)

* add transforms.py and rename torchvision_reader.py to reader.py

* add datasets.py

* remove pytorch apis

* fix some small bugs

* remove torch and torchvision from requirements.txt

* modify core.CUDAPlace to fluid.CUDAPlace
上级 6bd90a39
from PIL import Image
import os
import os.path
import sys
# Lowercase filename suffixes recognised as images; used by is_image_file()
# and passed by ImageFolder as the allowed extensions of its dataset.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
def has_file_allowed_extension(filename, extensions):
    """Check whether a filename ends with one of the allowed extensions.

    The filename is lowercased before matching, so the comparison is
    case-insensitive on the filename side only; ``extensions`` must already
    be lowercase.

    Args:
        filename (string): path to a file
        extensions (string | tuple | list | set of strings): extension(s)
            to consider (lowercase)

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    # str.endswith accepts only a string or a tuple of strings. Convert the
    # other common containers so that passing a list does not raise TypeError.
    if isinstance(extensions, (list, set, frozenset)):
        extensions = tuple(extensions)
    return filename.lower().endswith(extensions)
def is_image_file(filename):
    """Tell whether ``filename`` carries one of the known image suffixes.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with an entry of ``IMG_EXTENSIONS``
    """
    known_suffixes = IMG_EXTENSIONS
    return has_file_allowed_extension(filename, known_suffixes)
def make_dataset(dir, class_to_idx, extensions):
    """Collect ``(path, class_index)`` pairs for every matching file.

    Each key of ``class_to_idx`` is treated as a sub-directory name of
    ``dir``; keys that do not name an existing directory are skipped. The
    class directories are walked with ``os.walk``, so files in nested
    sub-directories are included as well.

    Args:
        dir (string): dataset root; a leading ``~`` is expanded.
        class_to_idx (dict): maps class (directory) name to class index.
        extensions: allowed extensions, forwarded to
            ``has_file_allowed_extension``.

    Returns:
        list: ``(file path, class index)`` tuples, ordered by class name,
        then by walked directory, then by file name.
    """
    base = os.path.expanduser(dir)
    samples = []
    for class_name in sorted(class_to_idx.keys()):
        class_dir = os.path.join(base, class_name)
        if os.path.isdir(class_dir):
            label = class_to_idx[class_name]
            for folder, _, file_names in sorted(os.walk(class_dir)):
                samples.extend(
                    (os.path.join(folder, name), label)
                    for name in sorted(file_names)
                    if has_file_allowed_extension(name, extensions))
    return samples
class DatasetFolder(object):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset

    Raises:
        RuntimeError: if no file with an allowed extension is found below
            ``root``.
    """

    def __init__(self,
                 root,
                 loader,
                 extensions,
                 transform=None,
                 target_transform=None):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        classes, class_to_idx = self._find_classes(self.root)
        samples = make_dataset(self.root, class_to_idx, extensions)
        if not samples:
            raise RuntimeError(
                "Found 0 files in subfolders of: " + self.root + "\n"
                "Supported extensions are: " + ",".join(extensions))

        self.loader = loader
        self.extensions = extensions
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    def _find_classes(self, dir):
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path; a leading ``~`` is expanded.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        # make_dataset() expands "~" in the root; do the same here so that a
        # root such as "~/data" works for class discovery as well.
        dir = os.path.expanduser(dir)
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            classes = [
                d for d in os.listdir(dir)
                if os.path.isdir(os.path.join(dir, d))
            ]
        # Sorted names give every class a stable index.
        classes.sort()
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        """Return the number of samples found under the root directory."""
        return len(self.samples)
def pil_loader(path):
    """Read the image stored at ``path`` and return it converted to RGB."""
    # The file is opened here rather than handing the path to PIL, to avoid a
    # ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
def default_loader(path):
    """Load a sample from ``path`` with the default (PIL based) loader."""
    image = pil_loader(path)
    return image
class ImageFolder(DatasetFolder):
    """Image dataset read from one sub-directory per class: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    This is a ``DatasetFolder`` whose allowed extensions are fixed to
    ``IMG_EXTENSIONS``.

    Args:
        root (string): Root directory path.
        transform (callable, optional): Applied to each loaded image; its
            result replaces the image. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): Applied to each target; its
            result replaces the target.
        loader (callable, optional): Loads an image given its path.
            Defaults to ``default_loader``.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples; the same
            list object as ``samples``.
    """

    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 loader=default_loader):
        parent = super(ImageFolder, self)
        parent.__init__(root, loader, IMG_EXTENSIONS, transform,
                        target_transform)
        # Alias for the sample list under the name `imgs`.
        self.imgs = self.samples
...@@ -4,31 +4,27 @@ import os ...@@ -4,31 +4,27 @@ import os
import numpy as np import numpy as np
import math import math
import random import random
import torch
import torch.utils.data
from torch.utils.data.distributed import DistributedSampler
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data.sampler import Sampler
import torchvision
import pickle import pickle
from tqdm import tqdm from tqdm import tqdm
import time import time
import multiprocessing import multiprocessing
import transforms
import datasets
FINISH_EVENT = "FINISH_EVENT" FINISH_EVENT = "FINISH_EVENT"
class PaddleDataLoader(object): class PaddleDataLoader(object):
def __init__(self, def __init__(self,
torch_dataset, dataset,
indices=None, indices=None,
concurrent=24, concurrent=24,
queue_size=3072, queue_size=3072,
shuffle=True, shuffle=True,
shuffle_seed=0): shuffle_seed=0):
self.torch_dataset = torch_dataset self.dataset = dataset
self.indices = indices self.indices = indices
self.concurrent = concurrent self.concurrent = concurrent
self.shuffle = shuffle self.shuffle = shuffle
...@@ -39,7 +35,7 @@ class PaddleDataLoader(object): ...@@ -39,7 +35,7 @@ class PaddleDataLoader(object):
cnt = 0 cnt = 0
for idx in worker_indices: for idx in worker_indices:
cnt += 1 cnt += 1
img, label = self.torch_dataset[idx] img, label = self.dataset[idx]
img = np.array(img).astype('uint8').transpose((2, 0, 1)) img = np.array(img).astype('uint8').transpose((2, 0, 1))
queue.put((img, label)) queue.put((img, label))
print("worker: [%d] read [%d] samples. " % (worker_id, cnt)) print("worker: [%d] read [%d] samples. " % (worker_id, cnt))
...@@ -49,7 +45,7 @@ class PaddleDataLoader(object): ...@@ -49,7 +45,7 @@ class PaddleDataLoader(object):
def _reader_creator(): def _reader_creator():
worker_processes = [] worker_processes = []
index_queues = [] index_queues = []
total_img = len(self.torch_dataset) total_img = len(self.dataset)
print("total image: ", total_img) print("total image: ", total_img)
if self.shuffle: if self.shuffle:
self.indices = [i for i in xrange(total_img)] self.indices = [i for i in xrange(total_img)]
...@@ -146,7 +142,7 @@ class CropArTfm(object): ...@@ -146,7 +142,7 @@ class CropArTfm(object):
else: else:
h = int(self.target_size * target_ar) h = int(self.target_size * target_ar)
size = (self.target_size, h // 8 * 8) size = (self.target_size, h // 8 * 8)
return torchvision.transforms.functional.center_crop(img, size) return transforms.center_crop(img, size)
def sort_ar(valdir): def sort_ar(valdir):
......
...@@ -19,8 +19,8 @@ import os ...@@ -19,8 +19,8 @@ import os
import traceback import traceback
import numpy as np import numpy as np
import torchvision_reader import math
import torch import reader
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.profiler as profiler import paddle.fluid.profiler as profiler
...@@ -235,9 +235,19 @@ def refresh_program(args, ...@@ -235,9 +235,19 @@ def refresh_program(args,
if var.name.startswith('conv2d_') if var.name.startswith('conv2d_')
] ]
for var in conv2d_w_vars: for var in conv2d_w_vars:
torch_w = torch.empty(var.shape) #torch_w = torch.empty(var.shape)
kaiming_np = torch.nn.init.kaiming_normal_( #kaiming_np = torch.nn.init.kaiming_normal_(torch_w, mode='fan_out', nonlinearity='relu').numpy()
torch_w, mode='fan_out', nonlinearity='relu').numpy() shape = var.shape
if not shape or len(shape) == 0:
fan_out = 1
elif len(shape) == 1:
fan_out = shape[0]
elif len(shape) == 2:
fan_out = shape[1]
else:
fan_out = shape[0] * np.prod(shape[2:])
std = math.sqrt(2.0 / fan_out)
kaiming_np = np.random.normal(0, std, var.shape)
tensor = fluid.global_scope().find_var(var.name).get_tensor() tensor = fluid.global_scope().find_var(var.name).get_tensor()
if args.fp16: if args.fp16:
tensor.set(np.array( tensor.set(np.array(
...@@ -283,7 +293,7 @@ def refresh_program(args, ...@@ -283,7 +293,7 @@ def refresh_program(args,
def prepare_reader(epoch_id, train_py_reader, train_bs, val_bs, trn_dir, def prepare_reader(epoch_id, train_py_reader, train_bs, val_bs, trn_dir,
img_dim, min_scale, rect_val): img_dim, min_scale, rect_val):
train_reader = torchvision_reader.train( train_reader = reader.train(
traindir="/data/imagenet/%strain" % trn_dir, traindir="/data/imagenet/%strain" % trn_dir,
sz=img_dim, sz=img_dim,
min_scale=min_scale, min_scale=min_scale,
...@@ -292,7 +302,7 @@ def prepare_reader(epoch_id, train_py_reader, train_bs, val_bs, trn_dir, ...@@ -292,7 +302,7 @@ def prepare_reader(epoch_id, train_py_reader, train_bs, val_bs, trn_dir,
paddle.batch( paddle.batch(
train_reader, batch_size=train_bs)) train_reader, batch_size=train_bs))
test_reader = torchvision_reader.test( test_reader = reader.test(
valdir="/data/imagenet/%svalidation" % trn_dir, valdir="/data/imagenet/%svalidation" % trn_dir,
bs=val_bs * DEVICE_NUM, bs=val_bs * DEVICE_NUM,
sz=img_dim, sz=img_dim,
......
from __future__ import division
import math
import random
from PIL import Image
import numpy as np
import warnings
# Names exported by a star-import of this module: the transform classes only.
# The functional helpers (crop, resize, center_crop) are not listed here.
__all__ = [
"Compose", "Resize", "Scale", "RandomHorizontalFlip", "RandomResizedCrop",
"CenterCrop"
]
def _is_pil_image(img):
    """Return True when ``img`` is an instance of ``PIL.Image.Image``."""
    pil_type = Image.Image
    return isinstance(img, pil_type)
def crop(img, i, j, h, w):
    """Cut a rectangle out of a PIL image.

    Args:
        img (PIL Image): image to crop.
        i (int): top edge of the rectangle.
        j (int): left edge of the rectangle.
        h (int): height of the rectangle.
        w (int): width of the rectangle.

    Returns:
        PIL Image: result of ``img.crop`` on the box
        ``(j, i, j + w, i + h)``.

    Raises:
        TypeError: if ``img`` is not a PIL image.
    """
    if _is_pil_image(img):
        # PIL expects the box as (left, upper, right, lower).
        left, upper = j, i
        return img.crop((left, upper, left + w, upper + h))
    raise TypeError('img should be a PIL Image, but be {}'.format(
        type(img)))
def resize(img, size, interpolation=Image.BILINEAR):
    """Resize a PIL image.

    Args:
        img (PIL Image): image to resize.
        size (int | tuple): a 2-tuple ``(h, w)`` gives the exact output
            size. An int is the target length of the shorter edge; the
            other edge is scaled by the same factor and truncated to int.
        interpolation (int): interpolation method handed to
            ``img.resize``.

    Returns:
        PIL Image: the resized image, or ``img`` itself when ``size`` is
        an int that already equals the shorter edge.

    Raises:
        TypeError: if ``img`` is not a PIL image, or ``size`` is neither
            an int nor a tuple of length 2.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be a PIL Image, but be {}'.format(
            type(img)))

    size_is_int = isinstance(size, int)
    size_is_pair = isinstance(size, tuple) and len(size) == 2
    if not (size_is_int or size_is_pair):
        raise TypeError('Wrong size arg: {}'.format(size))

    if not size_is_int:
        # size is (h, w) while PIL wants (w, h).
        return img.resize(size[::-1], interpolation)

    width, height = img.size
    if min(width, height) == size:
        # Shorter edge already matches: nothing to do.
        return img
    if width < height:
        new_w, new_h = size, int(size * height / width)
    else:
        new_h, new_w = size, int(size * width / height)
    return img.resize((new_w, new_h), interpolation)
def center_crop(img, output_size):
    """Crop the central region of an image.

    Args:
        img (PIL Image): image to crop; its ``size`` is read as
            ``(width, height)``.
        output_size (int | tuple): ``(h, w)`` of the crop; an int means a
            square crop of that side length.

    Returns:
        PIL Image: the crop produced by ``crop`` with a top-left corner
        chosen so that the region is centred.
    """
    if isinstance(output_size, int):
        output_size = (output_size, output_size)
    img_w, img_h = img.size
    crop_h, crop_w = output_size
    top = int(round((img_h - crop_h) / 2.))
    left = int(round((img_w - crop_w) / 2.))
    return crop(img, top, left, crop_h, crop_w)
class Compose(object):
    """Chain several transforms into a single callable.

    Args:
        transforms (list): transforms to apply, in order.
    """

    def __init__(self, transforms):
        self._transforms = transforms

    def __call__(self, img):
        """Feed ``img`` through every transform in turn and return the result."""
        result = img
        for step in self._transforms:
            result = step(result)
        return result
class Resize(object):
    """Transform that resizes a PIL Image via ``resize``.

    Args:
        size (tuple | int): Output size. A tuple is taken as the exact
            output size (h, w). An int is the size the smaller edge of
            the image is resized to.
        interpolation (int): Interpolation method.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        size_ok = isinstance(size, int) or (isinstance(size, tuple) and
                                            len(size) == 2)
        assert size_ok
        self._interpolation = interpolation
        self._size = size

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be resized.

        Returns:
            PIL Image: Resized image.
        """
        resized = resize(img, self._size, self._interpolation)
        return resized
class Scale(Resize):
    """
    Deprecated alias of ``Resize``: behaves the same but emits a warning
    when constructed.
    """

    def __init__(self, *args, **kwargs):
        message = ("The use of the transforms.Scale transform is deprecated, "
                   "please use transforms.Resize instead.")
        warnings.warn(message)
        super(Scale, self).__init__(*args, **kwargs)
class RandomHorizontalFlip(object):
    """Mirror a PIL Image left-to-right with probability ``p``.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: the flipped image, or ``img`` itself when no flip
            is drawn.

        Raises:
            TypeError: if a flip is drawn and ``img`` is not a PIL image.
        """
        flip_drawn = random.random() < self.p
        if not flip_drawn:
            return img
        # The type check runs only when a flip is actually performed.
        if _is_pil_image(img):
            return img.transpose(Image.FLIP_LEFT_RIGHT)
        raise TypeError('img should be a PIL image, but be {}'.format(
            type(img)))
class RandomResizedCrop(object):
    """Crop the input PIL Image to random size and aspect ratio, and then
    resize the PIL Image to target size.

    Args:
        size: target size; a tuple is used as is, any other value ``s``
            becomes ``(s, s)``.
        scale: range (min, max) of ratio of the origin size to be cropped
        ratio: range (min, max) of aspect ratio of the origin aspect ratio
            to be cropped
        interpolation: interpolation method

    Raises:
        ValueError: if ``scale`` or ``ratio`` is not ordered as (min, max).
    """

    def __init__(self,
                 size,
                 scale=(0.08, 1.0),
                 ratio=(3. / 4., 4. / 3.),
                 interpolation=Image.BILINEAR):
        if isinstance(size, tuple):
            self._size = size
        else:
            self._size = (size, size)
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            # ValueError is the built-in name; "ErrorValue" does not exist
            # and would surface as a NameError instead.
            raise ValueError("range should be of kind of (min, max)")

        self._interpolation = interpolation
        self._scale = scale
        self._ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Up to ten random (area, aspect ratio) draws are tried; the first
        one that yields a non-empty box fitting inside the image is used.
        If none fits, a centred crop is returned whose aspect ratio is
        clamped into ``ratio``.

        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop. All four values are ints.
        """
        width, height = img.size
        area = width * height
        # Sampling the aspect ratio uniformly in log space; the bounds do
        # not change between attempts, so compute them once.
        log_ratio = (math.log(ratio[0]), math.log(ratio[1]))

        for _ in range(10):
            target_area = random.uniform(*scale) * area
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            # Reject boxes that round down to zero pixels as well as boxes
            # larger than the image.
            if 0 < w <= width and 0 < h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w

        # Fallback to central crop
        in_ratio = float(width) / height
        if in_ratio < min(ratio):
            w = width
            # Round to int: true division would otherwise leak floats into
            # the returned crop parameters.
            h = int(round(float(w) / min(ratio)))
        elif in_ratio > max(ratio):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.

        Raises:
            TypeError: if ``img`` is not a PIL image.
        """
        # Validate before get_params touches img.size, and raise TypeError
        # like crop() and resize() do for the same condition.
        if not _is_pil_image(img):
            raise TypeError('img should be a PIL Image, but be {}'.format(
                type(img)))
        i, j, h, w = self.get_params(img, self._scale, self._ratio)
        img = crop(img, i, j, h, w)
        img = resize(img, self._size, self._interpolation)
        return img
class CenterCrop(object):
    """Transform that cuts the central region out of a PIL Image.

    Args:
        size (tuple|int): Output size. An int ``s`` is stored as the
            square size ``(s, s)``; anything else (e.g. an ``(h, w)``
            tuple) is stored unchanged.
    """

    def __init__(self, size):
        self.size = (size, size) if isinstance(size, int) else size

    def __call__(self, img):
        """Return the centre crop of ``img`` with the configured size."""
        target = self.size
        return center_crop(img, target)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册