未验证 提交 00e77dde 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #58 from heavengate/refine_compose

refine compose
......@@ -20,6 +20,7 @@ import argparse
import numpy as np
from hapi.model import Input, set_device
from hapi.vision.transforms import Compose
from check import check_gpu, check_version
from modeling import tsm_resnet50
......
......@@ -24,6 +24,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from hapi.model import Model, CrossEntropy, Input, set_device
from hapi.metrics import Accuracy
from hapi.vision.transforms import Compose
from modeling import tsm_resnet50
from check import check_gpu, check_version
......
......@@ -21,24 +21,7 @@ import logging
logger = logging.getLogger(__name__)
__all__ = ['GroupScale', 'GroupMultiScaleCrop', 'GroupRandomCrop',
'GroupRandomFlip', 'GroupCenterCrop', 'NormalizeImage',
'Compose']
class Compose(object):
def __init__(self, transforms=[]):
self.transforms = transforms
def __call__(self, *data):
for f in self.transforms:
try:
data = f(*data)
except Exception as e:
stack_info = traceback.format_exc()
logger.info("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
return data
'GroupRandomFlip', 'GroupCenterCrop', 'NormalizeImage']
class GroupScale(object):
......
......@@ -27,7 +27,7 @@ from paddle.io import DataLoader
from hapi.model import Model, Input, set_device
from hapi.distributed import DistributedBatchSampler
from hapi.vision.transforms import BatchCompose
from hapi.vision.transforms import Compose, BatchCompose
from modeling import yolov3_darknet53, YoloLoss
from coco import COCODataset
......
......@@ -20,7 +20,6 @@ import traceback
import numpy as np
__all__ = [
"Compose",
'ColorDistort',
'RandomExpand',
'RandomCrop',
......@@ -34,37 +33,6 @@ __all__ = [
]
class Compose(object):
"""Composes several transforms together.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, *data):
for f in self.transforms:
try:
data = f(*data)
except Exception as e:
stack_info = traceback.format_exc()
print("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
return data
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class ColorDistort(object):
"""Random color distortion.
......
......@@ -121,7 +121,7 @@ class Flowers(Dataset):
image = np.array(Image.open(io.BytesIO(image)))
if self.transform is not None:
image, label = self.transform(image, label)
image = self.transform(image)
return image, label
......
......@@ -149,7 +149,7 @@ class MNIST(Dataset):
def __getitem__(self, idx):
image, label = self.images[idx], self.labels[idx]
if self.transform is not None:
image, label = self.transform(image, label)
image = self.transform(image)
return image, label
def __len__(self):
......
......@@ -29,8 +29,10 @@ import traceback
from . import functional as F
if sys.version_info < (3, 3):
Sequence = collections.Sequence
Iterable = collections.Iterable
else:
Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable
__all__ = [
......@@ -54,20 +56,45 @@ __all__ = [
class Compose(object):
"""Composes several transforms together.
"""
Composes several transforms together use for composing list of transforms
together for a dataset transform.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Returns:
A compose object which is callable, __call__ for this Compose
object will call each given :attr:`transforms` sequencely.
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, ColorJitter, Resize
transform = Compose([ColorJitter(), Resize(size=608)])
flowers = Flowers(mode='test', transform=transform)
for i in range(10):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, data):
def __call__(self, *data):
for f in self.transforms:
try:
data = f(data)
# multi-fileds in a sample
if isinstance(data, Sequence):
data = f(*data)
# single field in a sample, call transform directly
else:
data = f(data)
except Exception as e:
stack_info = traceback.format_exc()
print("fail to perform transform [{}] with error: "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册