提交 49784f25 编写于 作者: D dengkaipeng

update paddle.io

上级 d90f7cc6
......@@ -22,7 +22,7 @@ import sys
sys.path.append('../')
from distributed import DistributedBatchSampler
from paddle.fluid.io import Dataset, DataLoader
from paddle.io import Dataset, DataLoader
logger = logging.getLogger(__name__)
......
......@@ -30,7 +30,7 @@ IMAGES_ROOT = "./data/" + DATASET + "/"
import paddle.fluid as fluid
class Cityscapes(fluid.io.Dataset):
class Cityscapes(paddle.io.Dataset):
def __init__(self, root_path, file_path, mode='train', return_name=False):
self.root_path = root_path
self.file_path = file_path
......
......@@ -86,13 +86,13 @@ def main():
if FLAGS.resume:
g.load(FLAGS.resume)
loader_A = fluid.io.DataLoader(
loader_A = paddle.io.DataLoader(
data.DataA(),
places=place,
shuffle=True,
return_list=True,
batch_size=FLAGS.batch_size)
loader_B = fluid.io.DataLoader(
loader_B = paddle.io.DataLoader(
data.DataB(),
places=place,
shuffle=True,
......
......@@ -23,7 +23,7 @@ import numpy as np
from paddle import fluid
from paddle.fluid.layers import collective
from paddle.fluid.dygraph.parallel import ParallelEnv, ParallelStrategy
from paddle.fluid.io import BatchSampler
from paddle.io import BatchSampler
_parallel_context_initialized = False
......@@ -39,7 +39,7 @@ class DistributedBatchSampler(BatchSampler):
Dataset is assumed to be of constant size.
Args:
data_source: this could be a `fluid.io.Dataset` implement
data_source: this could be a `paddle.io.Dataset` implement
or other python object which implemented
`__len__` for BatchSampler to get sample
number of data source.
......
......@@ -32,7 +32,7 @@ from imagenet_dataset import ImageNetDataset
from distributed import DistributedBatchSampler
from paddle.fluid.dygraph.parallel import ParallelEnv
from metrics import Accuracy
from paddle.fluid.io import BatchSampler, DataLoader
from paddle.io import BatchSampler, DataLoader
def make_optimizer(step_per_epoch, parameter_list=None):
......
......@@ -24,7 +24,7 @@ import numpy as np
from paddle import fluid
from paddle.fluid.optimizer import Momentum
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.io import MNIST as MnistDataset
from .vision.datasets import MNIST as MnistDataset
from model import Model, CrossEntropy, Input, set_device
from metrics import Accuracy
......
......@@ -32,7 +32,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.layers.utils import flatten
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
from paddle.fluid.incubate.fleet.base import role_maker
from paddle.fluid.io import DataLoader, Dataset
from paddle.io import DataLoader, Dataset
from distributed import DistributedBatchSampler, _all_gather, prepare_distributed_context, _parallel_context_initialized
from metrics import Metric
......@@ -913,11 +913,11 @@ class Model(fluid.dygraph.Layer):
FIXME: add more comments and usage
Args:
train_data (Dataset|DataLoader): An iterable data loader is used for
train. An instance of paddle.fluid.io.Dataset or
paddle.fluid.io.Dataloader is recomended.
train. An instance of paddle paddle.io.Dataset or
paddle.io.Dataloader is recomended.
eval_data (Dataset|DataLoader): An iterable data loader is used for
evaluation at the end of epoch. If None, will not do evaluation.
An instance of paddle.fluid.io.Dataset or paddle.fluid.io.Dataloader
An instance of paddle.io.Dataset or paddle.io.Dataloader
is recomended.
batch_size (int): Integer number. The batch size of train_data and eval_data.
When train_data and eval_data are both the instance of Dataloader, this
......@@ -1041,8 +1041,8 @@ class Model(fluid.dygraph.Layer):
FIXME: add more comments and usage
Args:
eval_data (Dataset|DataLoader): An iterable data loader is used for
evaluation. An instance of paddle.fluid.io.Dataset or
paddle.fluid.io.Dataloader is recomended.
evaluation. An instance of paddle.io.Dataset or
paddle.io.Dataloader is recomended.
batch_size (int): Integer number. The batch size of train_data and eval_data.
When train_data and eval_data are both the instance of Dataloader, this
parameter will be ignored.
......@@ -1116,7 +1116,7 @@ class Model(fluid.dygraph.Layer):
FIXME: add more comments and usage
Args:
test_data (Dataset|DataLoader): An iterable data loader is used for
predict. An instance of paddle.fluid.io.Dataset or paddle.fluid.io.Dataloader
predict. An instance of paddle.io.Dataset or paddle.io.Dataloader
is recomended.
batch_size (int): Integer number. The batch size of train_data and eval_data.
When train_data and eval_data are both the instance of Dataloader, this
......@@ -1177,8 +1177,8 @@ class Model(fluid.dygraph.Layer):
"""
Args:
eval_data (Dataset|DataLoader|None): An iterable data loader is used for
eval. An instance of paddle.fluid.io.Dataset or
paddle.fluid.io.Dataloader is recomended.
eval. An instance of paddle.io.Dataset or
paddle.io.Dataloader is recomended.
"""
assert isinstance(
eval_data,
......
......@@ -21,7 +21,7 @@ from paddle.fluid.dygraph.nn import Conv2D, BatchNorm
from model import Model
from .download import get_weights_path
__all__ = ['DarkNet53', 'ConvBNLayer', 'darknet53']
__all__ = ['DarkNet', 'ConvBNLayer', 'darknet53']
# {num_layers: (url, md5)}
pretrain_infos = {
......@@ -136,9 +136,17 @@ class LayerWarp(fluid.dygraph.Layer):
DarkNet_cfg = {53: ([1, 2, 8, 8, 4])}
class DarkNet53(Model):
class DarkNet(Model):
"""DarkNet model from
`"YOLOv3: An Incremental Improvement" <https://arxiv.org/abs/1804.02767>`_
Args:
num_layers (int): layer number of DarkNet, only 53 supported currently, default: 53.
ch_in (int): channel number of input data, default 3.
"""
def __init__(self, num_layers=53, ch_in=3):
super(DarkNet53, self).__init__()
super(DarkNet, self).__init__()
assert num_layers in DarkNet_cfg.keys(), \
"only support num_layers in {} currently" \
.format(DarkNet_cfg.keys())
......@@ -188,7 +196,7 @@ class DarkNet53(Model):
def _darknet(num_layers=53, input_channels=3, pretrained=True):
model = DarkNet53(num_layers, input_channels)
model = DarkNet(num_layers, input_channels)
if pretrained:
assert num_layers in pretrain_infos.keys(), \
"DarkNet{} do not have pretrained weights now, " \
......@@ -201,4 +209,11 @@ def _darknet(num_layers=53, input_channels=3, pretrained=True):
def darknet53(input_channels=3, pretrained=True):
"""DarkNet 53-layer model
Args:
input_channels (bool): channel number of input data, default 3.
pretrained (bool): If True, returns a model pre-trained on ImageNet,
default True.
"""
return _darknet(53, input_channels, pretrained)
......@@ -201,4 +201,12 @@ def _tsm_resnet(num_layers, seg_num=8, num_classes=400, pretrained=True):
def tsm_resnet50(seg_num=8, num_classes=400, pretrained=True):
"""TSM model with 50-layer ResNet as backbone
Args:
seg_num (int): segment number of each video sample. Default 8.
num_classes (int): video class number. Default 400.
pretrained (bool): If True, returns a model with pre-trained model
on COCO, default True
"""
return _tsm_resnet(50, seg_num, num_classes, pretrained)
......@@ -88,6 +88,20 @@ class YoloDetectionBlock(fluid.dygraph.Layer):
class YOLOv3(Model):
"""YOLOv3 model from
`"YOLOv3: An Incremental Improvement" <https://arxiv.org/abs/1804.02767>`_
Args:
num_classes (int): class number, default 80.
model_mode (str): 'train', 'eval', 'test' mode, network structure
will be diffrent in the output layer and data, in 'train' mode,
no output layer append, in 'eval' and 'test', output feature
map will be decode to predictions by 'fluid.layers.yolo_box',
in 'eval' mode, return feature maps and predictions, in 'test'
mode, only return predictions. Default 'train'.
"""
def __init__(self, num_classes=80, model_mode='train'):
super(YOLOv3, self).__init__()
self.num_classes = num_classes
......@@ -245,4 +259,17 @@ def _yolov3_darknet(num_layers=53, num_classes=80,
def yolov3_darknet53(num_classes=80, model_mode='train', pretrained=True):
"""YOLOv3 model with 53-layer DarkNet as backbone
Args:
num_classes (int): class number, default 80.
model_mode (str): 'train', 'eval', 'test' mode, network structure
will be diffrent in the output layer and data, in 'train' mode,
no output layer append, in 'eval' and 'test', output feature
map will be decode to predictions by 'fluid.layers.yolo_box',
in 'eval' mode, return feature maps and predictions, in 'test'
mode, only return predictions. Default 'train'.
pretrained (bool): If True, returns a model with pre-trained model
on COCO, default True
"""
return _yolov3_darknet(53, num_classes, model_mode, pretrained)
......@@ -168,13 +168,13 @@ def create_lexnet_data_generator(args, reader, file_name, place, mode="train"):
def create_dataloader(generator, place, feed_list=None):
if not feed_list:
data_loader = fluid.io.DataLoader.from_generator(
data_loader = paddle.io.DataLoader.from_generator(
capacity=50,
use_double_buffer=True,
iterable=True,
return_list=True)
else:
data_loader = fluid.io.DataLoader.from_generator(
data_loader = paddle.io.DataLoader.from_generator(
feed_list=feed_list,
capacity=50,
use_double_buffer=True,
......
......@@ -22,7 +22,7 @@ from functools import partial
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.io import DataLoader
from paddle.io import DataLoader
from paddle.fluid.layers.utils import flatten
from utils.configure import PDConfig
......
......@@ -22,7 +22,7 @@ from functools import partial
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import BatchSampler, DataLoader, Dataset
from paddle.io import BatchSampler, DataLoader, Dataset
def create_data_loader(args, device):
......
......@@ -21,7 +21,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.io import DataLoader
from paddle.io import DataLoader
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
......
......@@ -26,7 +26,7 @@ except ImportError:
import pickle
from io import BytesIO
from paddle.fluid.io import Dataset
from paddle.io import Dataset
import logging
logger = logging.getLogger(__name__)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .folder import *
from .mnist import *
from .flowers import *
from .coco import *
......@@ -20,7 +20,7 @@ import cv2
import numpy as np
from pycocotools.coco import COCO
from paddle.fluid.io import Dataset
from paddle.io import Dataset
import logging
logger = logging.getLogger(__name__)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import io
import tarfile
import numpy as np
import scipy.io as scio
from PIL import Image
from paddle.io import Dataset
from .utils import _check_exists_and_download
__all__ = ["Flowers"]
DATA_URL = 'http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz'
LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat'
SETID_URL = 'http://paddlemodels.bj.bcebos.com/flowers/setid.mat'
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
# In official 'readme', tstid is the flag of test data
# and trnid is the flag of train data. But test data is more than train data.
# So we exchange the train data and test data.
MODE_FLAG_MAP = {'train': 'tstid', 'test': 'trnid', 'valid': "valid"}
class Flowers(Dataset):
"""
Implement of flowers dataset
Args:
data_file(str): path to data file, can be set None if
:attr:`download` is True. Default None
label_file(str): path to label file, can be set None if
:attr:`download` is True. Default None
setid_file(str): path to subset index file, can be set
None if :attr:`download` is True. Default None
mode(str): 'train', 'valid' or 'test' mode. Default 'train'.
download(bool): whether auto download mnist dataset if
:attr:`image_path`/:attr:`label_path` unset. Default
True
Examples:
.. code-block:: python
from hapi.vision.datasets import Flowers
flowers = Flowers(mode='test')
for i in range(len(flowers)):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
def __init__(self,
data_file=None,
label_file=None,
setid_file=None,
mode='train',
transform=None,
target_transform=None,
download=True):
assert mode.lower() in ['train', 'valid', 'test'], \
"mode should be 'train', 'valid' or 'test', but got {}".format(mode)
self.flag = MODE_FLAG_MAP[mode.lower()]
self.data_file = data_file
if self.data_file is None:
assert download, "data_file not set and auto download disabled"
self.data_file = _check_exists_and_download(
data_file, DATA_URL, DATA_MD5, 'flowers', download)
self.label_file = label_file
if self.label_file is None:
assert download, "label_file not set and auto download disabled"
self.label_file = _check_exists_and_download(
label_file, LABEL_URL, LABEL_MD5, 'flowers', download)
self.setid_file = setid_file
if self.setid_file is None:
assert download, "setid_file not set and auto download disabled"
self.setid_file = _check_exists_and_download(
setid_file, SETID_URL, SETID_MD5, 'flowers', download)
self.transform = transform
self.target_transform = target_transform
# read dataset into memory
self._load_anno()
def _load_anno(self):
self.name2mem = {}
self.data_tar = tarfile.open(self.data_file)
for ele in self.data_tar.getmembers():
self.name2mem[ele.name] = ele
self.labels = scio.loadmat(self.label_file)['labels'][0]
self.indexes = scio.loadmat(self.setid_file)[self.flag][0]
def __getitem__(self, idx):
index = self.indexes[idx]
label = np.array([self.labels[index - 1]])
img_name = "jpg/image_%05d.jpg" % index
img_ele = self.name2mem[img_name]
image = self.data_tar.extractfile(img_ele).read()
image = np.array(Image.open(io.BytesIO(image)))
if self.transform is not None:
image = self.transform(image)
if self.target_transform is not None:
label = self.target_transform(label)
return image, label
def __len__(self):
return len(self.indexes)
......@@ -16,7 +16,9 @@ import os
import sys
import cv2
from paddle.fluid.io import Dataset
from paddle.io import Dataset
__all__ = ["DatasetFolder"]
def has_valid_extension(filename, extensions):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import gzip
import struct
import numpy as np
import paddle.dataset.common
from paddle.io import Dataset
from .utils import _check_exists_and_download
__all__ = ["MNIST"]
URL_PREFIX = 'https://dataset.bj.bcebos.com/mnist/'
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
TEST_IMAGE_MD5 = '9fb629c4189551a2d022fa330f9573f3'
TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
TEST_LABEL_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c'
TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'
class MNIST(Dataset):
"""
Implement of MNIST dataset
Args:
image_path(str): path to image file, can be set None if
:attr:`download` is True. Default None
label_path(str): path to label file, can be set None if
:attr:`download` is True. Default None
mode(str): 'train' or 'test' mode. Default 'train'.
download(bool): whether auto download mnist dataset if
:attr:`image_path`/:attr:`label_path` unset. Default
True
Returns:
Dataset: MNIST Dataset.
Examples:
.. code-block:: python
from hapi.vision.datasets import MNIST
mnist = MNIST(mode='test')
for i in range(len(mnist)):
sample = mnist[i]
print(sample[0].shape, sample[1])
"""
def __init__(self,
image_path=None,
label_path=None,
mode='train',
transform=None,
target_transform=None,
download=True):
assert mode.lower() in ['train', 'test'], \
"mode should be 'train' or 'test', but got {}".format(mode)
self.mode = mode.lower()
self.image_path = image_path
if self.image_path is None:
assert download, "image_path not set and auto download disabled"
image_url = TRAIN_IMAGE_URL if mode == 'train' else TEST_IMAGE_URL
image_md5 = TRAIN_IMAGE_MD5 if mode == 'train' else TEST_IMAGE_MD5
self.image_path = _check_exists_and_download(
image_path, image_url, image_md5, 'mnist', download)
self.label_path = label_path
if self.label_path is None:
assert download, "label_path not set and auto download disabled"
label_url = TRAIN_LABEL_URL if mode == 'train' else TEST_LABEL_URL
label_md5 = TRAIN_LABEL_MD5 if mode == 'train' else TEST_LABEL_MD5
self.label_path = _check_exists_and_download(
label_path, label_url, label_md5, 'mnist', download)
self.transform = transform
self.target_transform = target_transform
# read dataset into memory
self._parse_dataset()
def _parse_dataset(self, buffer_size=100):
self.images = []
self.labels = []
with gzip.GzipFile(self.image_path, 'rb') as image_file:
img_buf = image_file.read()
with gzip.GzipFile(self.label_path, 'rb') as label_file:
lab_buf = label_file.read()
step_label = 0
offset_img = 0
# read from Big-endian
# get file info from magic byte
# image file : 16B
magic_byte_img = '>IIII'
magic_img, image_num, rows, cols = struct.unpack_from(
magic_byte_img, img_buf, offset_img)
offset_img += struct.calcsize(magic_byte_img)
offset_lab = 0
# label file : 8B
magic_byte_lab = '>II'
magic_lab, label_num = struct.unpack_from(magic_byte_lab,
lab_buf, offset_lab)
offset_lab += struct.calcsize(magic_byte_lab)
while True:
if step_label >= label_num:
break
fmt_label = '>' + str(buffer_size) + 'B'
labels = struct.unpack_from(fmt_label, lab_buf, offset_lab)
offset_lab += struct.calcsize(fmt_label)
step_label += buffer_size
fmt_images = '>' + str(buffer_size * rows * cols) + 'B'
images_temp = struct.unpack_from(fmt_images, img_buf,
offset_img)
images = np.reshape(images_temp, (buffer_size, rows *
cols)).astype('float32')
offset_img += struct.calcsize(fmt_images)
images = images / 255.0
images = images * 2.0
images = images - 1.0
for i in range(buffer_size):
self.images.append(images[i, :])
self.labels.append(np.array([labels[i]]))
def __getitem__(self, idx):
image, label = self.images[idx], self.labels[idx]
if self.transform is not None:
image = self.transform(image)
if self.target_transform is not None:
label = self.target_transform(label)
return image, label
def __len__(self):
return len(self.labels)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import paddle.dataset.common
def _check_exists_and_download(path, url, md5, module_name, download=True):
if path and os.path.exists(path):
return path
if download:
return paddle.dataset.common.download(url, module_name, md5)
else:
raise FileNotFoundError(
'{} not exists and auto download disabled'.format(path))
......@@ -13,3 +13,5 @@
# limitations under the License.
from .transforms import *
from .functional import *
from .detection_transforms import *
......@@ -19,48 +19,18 @@ import cv2
import traceback
import numpy as np
import logging
logger = logging.getLogger(__name__)
__all__ = ['ColorDistort', 'RandomExpand', 'RandomCrop', 'RandomFlip',
'NormalizeBox', 'PadBox', 'RandomShape', 'NormalizeImage',
'BboxXYXY2XYWH', 'ResizeImage', 'Compose', 'BatchCompose']
class Compose(object):
def __init__(self, transforms=[]):
self.transforms = transforms
def __call__(self, *data):
for f in self.transforms:
try:
data = f(*data)
except Exception as e:
stack_info = traceback.format_exc()
logger.info("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
return data
class BatchCompose(object):
def __init__(self, transforms=[]):
self.transforms = transforms
def __call__(self, data):
for f in self.transforms:
try:
data = f(data)
except Exception as e:
stack_info = traceback.format_exc()
logger.info("fail to perform batch transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
# sample list to batch data
batch = list(zip(*data))
return batch
__all__ = [
'ColorDistort',
'RandomExpand',
'RandomCrop',
'RandomFlip',
'NormalizeBox',
'PadBox',
'RandomShape',
'NormalizeImage',
'BboxXYXY2XYWH',
'ResizeImage',
]
class ColorDistort(object):
......
......@@ -24,6 +24,7 @@ import numbers
import types
import collections
import warnings
import traceback
from . import functional as F
......@@ -34,6 +35,7 @@ else:
__all__ = [
"Compose",
"BatchCompose",
"Resize",
"RandomResizedCrop",
"CenterCropResize",
......@@ -62,10 +64,16 @@ class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img):
for t in self.transforms:
img = t(img)
return img
def __call__(self, *data):
for f in self.transforms:
try:
data = f(*data)
except Exception as e:
stack_info = traceback.format_exc()
print("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
return data
def __repr__(self):
format_string = self.__class__.__name__ + '('
......@@ -76,6 +84,33 @@ class Compose(object):
return format_string
class BatchCompose(object):
"""Composes several batch transforms together
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
these transforms perform on batch data.
"""
def __init__(self, transforms=[]):
self.transforms = transforms
def __call__(self, data):
for f in self.transforms:
try:
data = f(data)
except Exception as e:
stack_info = traceback.format_exc()
print("fail to perform batch transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
# sample list to batch data
batch = list(zip(*data))
return batch
class Resize(object):
"""Resize the input PIL Image to the given size.
......
......@@ -22,7 +22,7 @@ from PIL import Image
from paddle import fluid
from paddle.fluid.optimizer import Momentum
from paddle.fluid.io import DataLoader
from paddle.io import DataLoader
from model import Model, Input, set_device
from models import yolov3_darknet53, YoloLoss
......
......@@ -23,15 +23,15 @@ import numpy as np
from paddle import fluid
from paddle.fluid.optimizer import Momentum
from paddle.fluid.io import DataLoader
from paddle.io import DataLoader
from model import Model, Input, set_device
from distributed import DistributedBatchSampler
from models import yolov3_darknet53, YoloLoss
from coco_metric import COCOMetric
from coco import COCODataset
from transforms import *
from vision.datasets import COCODataset
from vision.transforms import *
NUM_MAX_BOXES = 50
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册