提交 2dd4aa3a 编写于 作者: L LielinJiang

add comments and examples

上级 db5f3697
...@@ -48,6 +48,35 @@ class DistributedBatchSampler(BatchSampler): ...@@ -48,6 +48,35 @@ class DistributedBatchSampler(BatchSampler):
batch indices. Default False. batch indices. Default False.
drop_last(bool): whether drop the last incomplete batch dataset size drop_last(bool): whether drop the last incomplete batch dataset size
is not divisible by the batch size. Default False is not divisible by the batch size. Default False
Examples:
.. code-block:: python
import numpy as np
from hapi.datasets import MNIST
from hapi.distributed import DistributedBatchSampler
class MnistDataset(MNIST):
def __init__(self, mode, return_label=True):
super(MnistDataset, self).__init__(mode=mode)
self.return_label = return_label
def __getitem__(self, idx):
img = np.reshape(self.images[idx], [1, 28, 28])
if self.return_label:
return img, np.array(self.labels[idx]).astype('int64')
return img,
def __len__(self):
return len(self.images)
train_dataset = MnistDataset(mode='train')
dist_train_dataloader = DistributedBatchSampler(train_dataset, batch_size=64)
for data in dist_train_dataloader:
# do something
break
""" """
def __init__(self, dataset, batch_size, shuffle=False, drop_last=False): def __init__(self, dataset, batch_size, shuffle=False, drop_last=False):
......
...@@ -55,6 +55,15 @@ def get_weights_path_from_url(url, md5sum=None): ...@@ -55,6 +55,15 @@ def get_weights_path_from_url(url, md5sum=None):
Returns: Returns:
str: a local path to save downloaded weights. str: a local path to save downloaded weights.
Examples:
.. code-block:: python
from hapi.download import get_weights_path_from_url
resnet18_pretrained_weight_url = 'https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams'
local_weight_path = get_weights_path_from_url(resnet18_pretrained_weight_url)
""" """
path = get_path_from_url(url, WEIGHTS_HOME, md5sum) path = get_path_from_url(url, WEIGHTS_HOME, md5sum)
return path return path
......
...@@ -38,7 +38,7 @@ def setup_logger(output=None, name="hapi", log_level=logging.INFO): ...@@ -38,7 +38,7 @@ def setup_logger(output=None, name="hapi", log_level=logging.INFO):
# stdout logging: only local rank==0 # stdout logging: only local rank==0
local_rank = ParallelEnv().local_rank local_rank = ParallelEnv().local_rank
if local_rank == 0: if local_rank == 0 and not logger.hasHandlers():
ch = logging.StreamHandler(stream=sys.stdout) ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(log_level) ch.setLevel(log_level)
......
...@@ -30,8 +30,8 @@ class Loss(object): ...@@ -30,8 +30,8 @@ class Loss(object):
Base class for loss, encapsulates loss logic and APIs Base class for loss, encapsulates loss logic and APIs
Usage: Usage:
custom_loss = CustomLoss() custom_loss = CustomLoss()
loss = custom_loss(inputs, labels) loss = custom_loss(inputs, labels)
""" """
def __init__(self, average=True): def __init__(self, average=True):
...@@ -63,10 +63,25 @@ class CrossEntropy(Loss): ...@@ -63,10 +63,25 @@ class CrossEntropy(Loss):
average (bool, optional): Indicate whether to average the loss, Default: True. average (bool, optional): Indicate whether to average the loss, Default: True.
Returns: Returns:
list[Variable]: The tensor variable storing the cross_entropy_loss of inputs and labels. list[Variable]: The tensor variable storing the cross_entropy_loss of inputs and labels.
Examples:
.. code-block:: python
from hapi.model import Input
from hapi.vision.models import LeNet
from hapi.loss import CrossEntropy
inputs = [Input([-1, 1, 28, 28], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
model = LeNet()
loss = CrossEntropy()
model.prepare(loss_function=loss, inputs=inputs, labels=labels)
""" """
def __init__(self, average=True): def __init__(self, average=True):
super(CrossEntropy, self).__init__() super(CrossEntropy, self).__init__(average)
def forward(self, outputs, labels): def forward(self, outputs, labels):
return [ return [
...@@ -85,10 +100,24 @@ class SoftmaxWithCrossEntropy(Loss): ...@@ -85,10 +100,24 @@ class SoftmaxWithCrossEntropy(Loss):
average (bool, optional): Indicate whether to average the loss, Default: True. average (bool, optional): Indicate whether to average the loss, Default: True.
Returns: Returns:
list[Variable]: The tensor variable storing the cross_entropy_loss of inputs and labels. list[Variable]: The tensor variable storing the cross_entropy_loss of inputs and labels.
Examples:
.. code-block:: python
from hapi.model import Input
from hapi.vision.models import LeNet
from hapi.loss import SoftmaxWithCrossEntropy
inputs = [Input([-1, 1, 28, 28], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
model = LeNet(classifier_activation=None)
loss = SoftmaxWithCrossEntropy()
model.prepare(loss_function=loss, inputs=inputs, labels=labels)
""" """
def __init__(self, average=True): def __init__(self, average=True):
super(SoftmaxWithCrossEntropy, self).__init__() super(SoftmaxWithCrossEntropy, self).__init__(average)
def forward(self, outputs, labels): def forward(self, outputs, labels):
return [ return [
......
...@@ -144,6 +144,13 @@ class DarkNet(Model): ...@@ -144,6 +144,13 @@ class DarkNet(Model):
will not be defined. Default: 1000. will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True. with_pool (bool): use pool before the last fc layer or not. Default: True.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'. classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
Examples:
.. code-block:: python
from hapi.vision.models import DarkNet
model = DarkNet()
""" """
def __init__(self, def __init__(self,
...@@ -233,5 +240,17 @@ def darknet53(pretrained=False, **kwargs): ...@@ -233,5 +240,17 @@ def darknet53(pretrained=False, **kwargs):
input_channels (bool): channel number of input data, default 3. input_channels (bool): channel number of input data, default 3.
pretrained (bool): If True, returns a model pre-trained on ImageNet, pretrained (bool): If True, returns a model pre-trained on ImageNet,
default True. default True.
Examples:
.. code-block:: python
from hapi.vision.models import darknet53
# build model
model = darknet53()
#build model and load imagenet pretrained weight
model = darknet53(pretrained=True)
""" """
return _darknet('darknet53', 53, pretrained, **kwargs) return _darknet('darknet53', 53, pretrained, **kwargs)
...@@ -29,6 +29,13 @@ class LeNet(Model): ...@@ -29,6 +29,13 @@ class LeNet(Model):
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 10. will not be defined. Default: 10.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'. classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
Examples:
.. code-block:: python
from hapi.vision.models import LeNet
model = LeNet()
""" """
def __init__(self, num_classes=10, classifier_activation='softmax'): def __init__(self, num_classes=10, classifier_activation='softmax'):
......
...@@ -115,6 +115,13 @@ class MobileNetV1(Model): ...@@ -115,6 +115,13 @@ class MobileNetV1(Model):
will not be defined. Default: 1000. will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True. with_pool (bool): use pool before the last fc layer or not. Default: True.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'. classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
Examples:
.. code-block:: python
from hapi.vision.models import MobileNetV1
model = MobileNetV1()
""" """
def __init__(self, def __init__(self,
...@@ -282,6 +289,20 @@ def mobilenet_v1(pretrained=False, scale=1.0, **kwargs): ...@@ -282,6 +289,20 @@ def mobilenet_v1(pretrained=False, scale=1.0, **kwargs):
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False. pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
scale: (float): scale of channels in each layer. Default: 1.0. scale: (float): scale of channels in each layer. Default: 1.0.
Examples:
.. code-block:: python
from hapi.vision.models import mobilenet_v1
# build model
model = mobilenet_v1()
#build model and load imagenet pretrained weight
model = mobilenet_v1(pretrained=True)
#build mobilenet v1 with scale=0.5
model = mobilenet_v1(scale=0.5)
""" """
model = _mobilenet( model = _mobilenet(
'mobilenetv1_' + str(scale), pretrained, scale=scale, **kwargs) 'mobilenetv1_' + str(scale), pretrained, scale=scale, **kwargs)
......
...@@ -160,6 +160,13 @@ class MobileNetV2(Model): ...@@ -160,6 +160,13 @@ class MobileNetV2(Model):
will not be defined. Default: 1000. will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True. with_pool (bool): use pool before the last fc layer or not. Default: True.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'. classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
Examples:
.. code-block:: python
from hapi.vision.models import MobileNetV2
model = MobileNetV2()
""" """
def __init__(self, def __init__(self,
...@@ -256,6 +263,20 @@ def mobilenet_v2(pretrained=False, scale=1.0, **kwargs): ...@@ -256,6 +263,20 @@ def mobilenet_v2(pretrained=False, scale=1.0, **kwargs):
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False. pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
scale: (float): scale of channels in each layer. Default: 1.0. scale: (float): scale of channels in each layer. Default: 1.0.
Examples:
.. code-block:: python
from hapi.vision.models import mobilenet_v2
# build model
model = mobilenet_v2()
#build model and load imagenet pretrained weight
model = mobilenet_v2(pretrained=True)
#build mobilenet v2 with scale=0.5
model = mobilenet_v2(scale=0.5)
""" """
model = _mobilenet( model = _mobilenet(
'mobilenetv2_' + str(scale), pretrained, scale=scale, **kwargs) 'mobilenetv2_' + str(scale), pretrained, scale=scale, **kwargs)
......
...@@ -26,7 +26,8 @@ from hapi.model import Model ...@@ -26,7 +26,8 @@ from hapi.model import Model
from hapi.download import get_weights_path_from_url from hapi.download import get_weights_path_from_url
__all__ = [ __all__ = [
'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152' 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
'BottleneckBlock', 'BasicBlock'
] ]
model_urls = { model_urls = {
...@@ -75,7 +76,8 @@ class ConvBNLayer(fluid.dygraph.Layer): ...@@ -75,7 +76,8 @@ class ConvBNLayer(fluid.dygraph.Layer):
class BasicBlock(fluid.dygraph.Layer): class BasicBlock(fluid.dygraph.Layer):
"""residual block of resnet18 and resnet34
"""
expansion = 1 expansion = 1
def __init__(self, num_channels, num_filters, stride, shortcut=True): def __init__(self, num_channels, num_filters, stride, shortcut=True):
...@@ -117,6 +119,8 @@ class BasicBlock(fluid.dygraph.Layer): ...@@ -117,6 +119,8 @@ class BasicBlock(fluid.dygraph.Layer):
class BottleneckBlock(fluid.dygraph.Layer): class BottleneckBlock(fluid.dygraph.Layer):
"""residual block of resnet50, resnet101 amd resnet152
"""
expansion = 4 expansion = 4
...@@ -177,6 +181,16 @@ class ResNet(Model): ...@@ -177,6 +181,16 @@ class ResNet(Model):
will not be defined. Default: 1000. will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True. with_pool (bool): use pool before the last fc layer or not. Default: True.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'. classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
Examples:
.. code-block:: python
from hapi.vision.models import ResNet, BottleneckBlock, BasicBlock
resnet50 = ResNet(BottleneckBlock, 50)
resnet18 = ResNet(BasicBlock, 18)
""" """
def __init__(self, def __init__(self,
...@@ -280,6 +294,17 @@ def resnet18(pretrained=False, **kwargs): ...@@ -280,6 +294,17 @@ def resnet18(pretrained=False, **kwargs):
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from hapi.vision.models import resnet18
# build model
model = resnet18()
#build model and load imagenet pretrained weight
model = resnet18(pretrained=True)
""" """
return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs) return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs)
...@@ -289,6 +314,17 @@ def resnet34(pretrained=False, **kwargs): ...@@ -289,6 +314,17 @@ def resnet34(pretrained=False, **kwargs):
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from hapi.vision.models import resnet34
# build model
model = resnet34()
#build model and load imagenet pretrained weight
model = resnet34(pretrained=True)
""" """
return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs) return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs)
...@@ -298,6 +334,17 @@ def resnet50(pretrained=False, **kwargs): ...@@ -298,6 +334,17 @@ def resnet50(pretrained=False, **kwargs):
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from hapi.vision.models import resnet50
# build model
model = resnet50()
#build model and load imagenet pretrained weight
model = resnet50(pretrained=True)
""" """
return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs) return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs)
...@@ -307,6 +354,17 @@ def resnet101(pretrained=False, **kwargs): ...@@ -307,6 +354,17 @@ def resnet101(pretrained=False, **kwargs):
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from hapi.vision.models import resnet101
# build model
model = resnet101()
#build model and load imagenet pretrained weight
model = resnet101(pretrained=True)
""" """
return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs) return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs)
...@@ -316,5 +374,16 @@ def resnet152(pretrained=False, **kwargs): ...@@ -316,5 +374,16 @@ def resnet152(pretrained=False, **kwargs):
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from hapi.vision.models import resnet152
# build model
model = resnet152()
#build model and load imagenet pretrained weight
model = resnet152(pretrained=True)
""" """
return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs) return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs)
...@@ -143,6 +143,17 @@ def vgg11(pretrained=False, batch_norm=False, **kwargs): ...@@ -143,6 +143,17 @@ def vgg11(pretrained=False, batch_norm=False, **kwargs):
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False. pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False. batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
Examples:
.. code-block:: python
from hapi.vision.models import vgg11
# build model
model = vgg11()
#build vgg11 model with batch_norm
model = vgg11(batch_norm=True)
""" """
model_name = 'vgg11' model_name = 'vgg11'
if batch_norm: if batch_norm:
...@@ -156,6 +167,17 @@ def vgg13(pretrained=False, batch_norm=False, **kwargs): ...@@ -156,6 +167,17 @@ def vgg13(pretrained=False, batch_norm=False, **kwargs):
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False. pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False. batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
Examples:
.. code-block:: python
from hapi.vision.models import vgg13
# build model
model = vgg13()
#build vgg13 model with batch_norm
model = vgg13(batch_norm=True)
""" """
model_name = 'vgg13' model_name = 'vgg13'
if batch_norm: if batch_norm:
...@@ -169,6 +191,17 @@ def vgg16(pretrained=False, batch_norm=False, **kwargs): ...@@ -169,6 +191,17 @@ def vgg16(pretrained=False, batch_norm=False, **kwargs):
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False. pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False. batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
Examples:
.. code-block:: python
from hapi.vision.models import vgg16
# build model
model = vgg16()
#build vgg16 model with batch_norm
model = vgg16(batch_norm=True)
""" """
model_name = 'vgg16' model_name = 'vgg16'
if batch_norm: if batch_norm:
...@@ -182,6 +215,17 @@ def vgg19(pretrained=False, batch_norm=False, **kwargs): ...@@ -182,6 +215,17 @@ def vgg19(pretrained=False, batch_norm=False, **kwargs):
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False. pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False. batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
Examples:
.. code-block:: python
from hapi.vision.models import vgg19
# build model
model = vgg19()
#build vgg19 model with batch_norm
model = vgg19(batch_norm=True)
""" """
model_name = 'vgg19' model_name = 'vgg19'
if batch_norm: if batch_norm:
......
...@@ -39,6 +39,23 @@ def flip(image, code): ...@@ -39,6 +39,23 @@ def flip(image, code):
-1 : Flip horizontally and vertically -1 : Flip horizontally and vertically
0 : Flip vertically 0 : Flip vertically
1 : Flip horizontally 1 : Flip horizontally
Examples:
.. code-block:: python
import numpy as np
from hapi.vision.transforms import functional as F
fake_img = np.random.rand(224, 224, 3)
# flip horizontally and vertically
F.flip(fake_img, -1)
# flip vertically
F.flip(fake_img, 0)
# flip horizontally
F.flip(fake_img, 1)
""" """
return cv2.flip(image, flipCode=code) return cv2.flip(image, flipCode=code)
...@@ -51,6 +68,18 @@ def resize(img, size, interpolation=cv2.INTER_LINEAR): ...@@ -51,6 +68,18 @@ def resize(img, size, interpolation=cv2.INTER_LINEAR):
input: Input data, could be image or masks, with (H, W, C) shape input: Input data, could be image or masks, with (H, W, C) shape
size: Target size of input data, with (height, width) shape. size: Target size of input data, with (height, width) shape.
interpolation: Interpolation method. interpolation: Interpolation method.
Examples:
.. code-block:: python
import numpy as np
from hapi.vision.transforms import functional as F
fake_img = np.random.rand(256, 256, 3)
F.resize(fake_img, 224)
F.resize(fake_img, (200, 150))
""" """
if isinstance(interpolation, Sequence): if isinstance(interpolation, Sequence):
......
...@@ -118,6 +118,67 @@ class BatchCompose(object): ...@@ -118,6 +118,67 @@ class BatchCompose(object):
transforms (list of ``Transform`` objects): list of transforms to compose. transforms (list of ``Transform`` objects): list of transforms to compose.
these transforms perform on batch data. these transforms perform on batch data.
Examples:
.. code-block:: python
import numpy as np
from paddle.io import DataLoader
from hapi.model import set_device
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, BatchCompose, Resize
class NormalizeBatch(object):
def __init__(self,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
scale=True,
channel_first=True):
self.mean = mean
self.std = std
self.scale = scale
self.channel_first = channel_first
if not (isinstance(self.mean, list) and isinstance(self.std, list) and
isinstance(self.scale, bool)):
raise TypeError("{}: input type is invalid.".format(self))
from functools import reduce
if reduce(lambda x, y: x * y, self.std) == 0:
raise ValueError('{}: std is invalid!'.format(self))
def __call__(self, samples):
for i in range(len(samples)):
samples[i] = list(samples[i])
im = samples[i][0]
im = im.astype(np.float32, copy=False)
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.scale:
im = im / 255.0
im -= mean
im /= std
if self.channel_first:
im = im.transpose((2, 0, 1))
samples[i][0] = im
return samples
transform = Compose([Resize((500, 500))])
flowers_dataset = Flowers(mode='test', transform=transform)
device = set_device('cpu')
collate_fn = BatchCompose([NormalizeBatch()])
loader = DataLoader(
flowers_dataset,
batch_size=4,
places=device,
return_list=True,
collate_fn=collate_fn)
for data in loader:
# do something
break
""" """
def __init__(self, transforms=[]): def __init__(self, transforms=[]):
...@@ -149,6 +210,20 @@ class Resize(object): ...@@ -149,6 +210,20 @@ class Resize(object):
i.e, if height > width, then image will be rescaled to i.e, if height > width, then image will be rescaled to
(size * height / width, size) (size * height / width, size)
interpolation (int): interpolation mode of resize. Default: cv2.INTER_LINEAR. interpolation (int): interpolation mode of resize. Default: cv2.INTER_LINEAR.
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, Resize
transform = Compose([Resize(size=224)])
flowers = Flowers(mode='test', transform=transform)
for i in range(10):
sample = flowers[i]
print(sample[0].shape, sample[1])
""" """
def __init__(self, size, interpolation=cv2.INTER_LINEAR): def __init__(self, size, interpolation=cv2.INTER_LINEAR):
...@@ -171,6 +246,20 @@ class RandomResizedCrop(object): ...@@ -171,6 +246,20 @@ class RandomResizedCrop(object):
output_size (int|list|tuple): Target size of output image, with (height, width) shape. output_size (int|list|tuple): Target size of output image, with (height, width) shape.
scale (list|tuple): Range of size of the origin size cropped. Default: (0.08, 1.0) scale (list|tuple): Range of size of the origin size cropped. Default: (0.08, 1.0)
ratio (list|tuple): Range of aspect ratio of the origin aspect ratio cropped. Default: (0.75, 1.33) ratio (list|tuple): Range of aspect ratio of the origin aspect ratio cropped. Default: (0.75, 1.33)
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, Resize, RandomResizedCrop
transform = Compose([Resize(500), RandomResizedCrop(224)])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
""" """
def __init__(self, def __init__(self,
...@@ -233,6 +322,20 @@ class CenterCropResize(object): ...@@ -233,6 +322,20 @@ class CenterCropResize(object):
size (int|list|tuple): Target size of output image, with (height, width) shape. size (int|list|tuple): Target size of output image, with (height, width) shape.
crop_padding (int): center crop with the padding. Default: 32. crop_padding (int): center crop with the padding. Default: 32.
interpolation (int): interpolation mode of resize. Default: cv2.INTER_LINEAR. interpolation (int): interpolation mode of resize. Default: cv2.INTER_LINEAR.
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, Resize, CenterCropResize
transform = Compose([Resize(500), CenterCropResize(224)])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
""" """
def __init__(self, size, crop_padding=32, interpolation=cv2.INTER_LINEAR): def __init__(self, size, crop_padding=32, interpolation=cv2.INTER_LINEAR):
...@@ -262,6 +365,20 @@ class CenterCrop(object): ...@@ -262,6 +365,20 @@ class CenterCrop(object):
Args: Args:
output_size: Target size of output image, with (height, width) shape. output_size: Target size of output image, with (height, width) shape.
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, Resize, CenterCrop
transform = Compose([Resize(500), CenterCrop(224)])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
""" """
def __init__(self, output_size): def __init__(self, output_size):
...@@ -289,6 +406,20 @@ class RandomHorizontalFlip(object): ...@@ -289,6 +406,20 @@ class RandomHorizontalFlip(object):
Args: Args:
prob (float): probability of the input data being flipped. Default: 0.5 prob (float): probability of the input data being flipped. Default: 0.5
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, RandomHorizontalFlip
transform = Compose([RandomHorizontalFlip()])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
""" """
def __init__(self, prob=0.5): def __init__(self, prob=0.5):
...@@ -305,6 +436,20 @@ class RandomVerticalFlip(object): ...@@ -305,6 +436,20 @@ class RandomVerticalFlip(object):
Args: Args:
prob (float): probability of the input data being flipped. Default: 0.5 prob (float): probability of the input data being flipped. Default: 0.5
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, RandomVerticalFlip
transform = Compose([RandomVerticalFlip()])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
""" """
def __init__(self, prob=0.5): def __init__(self, prob=0.5):
...@@ -325,6 +470,23 @@ class Normalize(object): ...@@ -325,6 +470,23 @@ class Normalize(object):
Args: Args:
mean (int|float|list): Sequence of means for each channel. mean (int|float|list): Sequence of means for each channel.
std (int|float|list): Sequence of standard deviations for each channel. std (int|float|list): Sequence of standard deviations for each channel.
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, Normalize, Permute
normalize = Normalize(mean=[123.675, 116.28, 103.53],
std=[58.395, 57.120, 57.375])
transform = Compose([Permute(), normalize])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
""" """
...@@ -351,6 +513,20 @@ class Permute(object): ...@@ -351,6 +513,20 @@ class Permute(object):
Args: Args:
mode: Output mode of input. Default: "CHW". mode: Output mode of input. Default: "CHW".
to_rgb: convert 'bgr' image to 'rgb'. Default: True. to_rgb: convert 'bgr' image to 'rgb'. Default: True.
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, Permute
transform = Compose([Permute()])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
""" """
def __init__(self, mode="CHW", to_rgb=True): def __init__(self, mode="CHW", to_rgb=True):
...@@ -375,6 +551,20 @@ class GaussianNoise(object): ...@@ -375,6 +551,20 @@ class GaussianNoise(object):
Args: Args:
mean: Gaussian mean used to generate noise. mean: Gaussian mean used to generate noise.
std: Gaussian standard deviation used to generate noise. std: Gaussian standard deviation used to generate noise.
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, GaussianNoise
transform = Compose([GaussianNoise()])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
""" """
def __init__(self, mean=0.0, std=1.0): def __init__(self, mean=0.0, std=1.0):
...@@ -394,6 +584,20 @@ class BrightnessTransform(object): ...@@ -394,6 +584,20 @@ class BrightnessTransform(object):
Args: Args:
value: How much to adjust the brightness. Can be any value: How much to adjust the brightness. Can be any
non negative number. 0 gives the original image non negative number. 0 gives the original image
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, BrightnessTransform
transform = Compose([BrightnessTransform(0.4)])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
""" """
def __init__(self, value): def __init__(self, value):
...@@ -418,6 +622,20 @@ class ContrastTransform(object): ...@@ -418,6 +622,20 @@ class ContrastTransform(object):
Args: Args:
value: How much to adjust the contrast. Can be any value: How much to adjust the contrast. Can be any
non negative number. 0 gives the original image non negative number. 0 gives the original image
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, ContrastTransform
transform = Compose([ContrastTransform(0.4)])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
""" """
def __init__(self, value): def __init__(self, value):
...@@ -443,6 +661,20 @@ class SaturationTransform(object): ...@@ -443,6 +661,20 @@ class SaturationTransform(object):
Args: Args:
value: How much to adjust the saturation. Can be any value: How much to adjust the saturation. Can be any
non negative number. 0 gives the original image non negative number. 0 gives the original image
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, SaturationTransform
transform = Compose([SaturationTransform(0.4)])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
""" """
def __init__(self, value): def __init__(self, value):
...@@ -469,6 +701,20 @@ class HueTransform(object): ...@@ -469,6 +701,20 @@ class HueTransform(object):
Args: Args:
value: How much to adjust the hue. Can be any number value: How much to adjust the hue. Can be any number
between 0 and 0.5, 0 gives the original image between 0 and 0.5, 0 gives the original image
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, HueTransform
transform = Compose([HueTransform(0.4)])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
""" """
def __init__(self, value): def __init__(self, value):
...@@ -510,6 +756,20 @@ class ColorJitter(object): ...@@ -510,6 +756,20 @@ class ColorJitter(object):
hue: How much to jitter hue. hue: How much to jitter hue.
Chosen uniformly from [-hue, hue] or the given [min, max]. Chosen uniformly from [-hue, hue] or the given [min, max].
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, ColorJitter
transform = Compose([ColorJitter(0.4)])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
""" """
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册