提交 d9520c24 编写于 作者: D dengkaipeng

refine compose

上级 c88f4570
......@@ -27,7 +27,7 @@ from paddle.io import DataLoader
from hapi.model import Model, Input, set_device
from hapi.distributed import DistributedBatchSampler
from hapi.vision.transforms import BatchCompose
from hapi.vision.transforms import Compose, BatchCompose
from modeling import yolov3_darknet53, YoloLoss
from coco import COCODataset
......
......@@ -20,7 +20,6 @@ import traceback
import numpy as np
__all__ = [
"Compose",
'ColorDistort',
'RandomExpand',
'RandomCrop',
......@@ -34,37 +33,6 @@ __all__ = [
]
class Compose(object):
"""Composes several transforms together.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, *data):
for f in self.transforms:
try:
data = f(*data)
except Exception as e:
stack_info = traceback.format_exc()
print("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
return data
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class ColorDistort(object):
"""Random color distortion.
......
......@@ -29,8 +29,10 @@ import traceback
from . import functional as F
if sys.version_info < (3, 3):
Sequence = collections.Sequence
Iterable = collections.Iterable
else:
Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable
__all__ = [
......@@ -64,10 +66,13 @@ class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, data):
def __call__(self, *data):
for f in self.transforms:
try:
data = f(data)
if isinstance(data, Sequence):
data = f(*data)
else:
data = f(data)
except Exception as e:
stack_info = traceback.format_exc()
print("fail to perform transform [{}] with error: "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册