提交 d9520c24 编写于 作者: D dengkaipeng

refine compose

上级 c88f4570
...@@ -27,7 +27,7 @@ from paddle.io import DataLoader ...@@ -27,7 +27,7 @@ from paddle.io import DataLoader
from hapi.model import Model, Input, set_device from hapi.model import Model, Input, set_device
from hapi.distributed import DistributedBatchSampler from hapi.distributed import DistributedBatchSampler
from hapi.vision.transforms import BatchCompose from hapi.vision.transforms import Compose, BatchCompose
from modeling import yolov3_darknet53, YoloLoss from modeling import yolov3_darknet53, YoloLoss
from coco import COCODataset from coco import COCODataset
......
...@@ -20,7 +20,6 @@ import traceback ...@@ -20,7 +20,6 @@ import traceback
import numpy as np import numpy as np
__all__ = [ __all__ = [
"Compose",
'ColorDistort', 'ColorDistort',
'RandomExpand', 'RandomExpand',
'RandomCrop', 'RandomCrop',
...@@ -34,37 +33,6 @@ __all__ = [ ...@@ -34,37 +33,6 @@ __all__ = [
] ]
class Compose(object):
"""Composes several transforms together.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, *data):
for f in self.transforms:
try:
data = f(*data)
except Exception as e:
stack_info = traceback.format_exc()
print("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
return data
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class ColorDistort(object): class ColorDistort(object):
"""Random color distortion. """Random color distortion.
......
...@@ -29,8 +29,10 @@ import traceback ...@@ -29,8 +29,10 @@ import traceback
from . import functional as F from . import functional as F
if sys.version_info < (3, 3): if sys.version_info < (3, 3):
Sequence = collections.Sequence
Iterable = collections.Iterable Iterable = collections.Iterable
else: else:
Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable Iterable = collections.abc.Iterable
__all__ = [ __all__ = [
...@@ -64,10 +66,13 @@ class Compose(object): ...@@ -64,10 +66,13 @@ class Compose(object):
def __init__(self, transforms): def __init__(self, transforms):
self.transforms = transforms self.transforms = transforms
def __call__(self, data): def __call__(self, *data):
for f in self.transforms: for f in self.transforms:
try: try:
data = f(data) if isinstance(data, Sequence):
data = f(*data)
else:
data = f(data)
except Exception as e: except Exception as e:
stack_info = traceback.format_exc() stack_info = traceback.format_exc()
print("fail to perform transform [{}] with error: " print("fail to perform transform [{}] with error: "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册