Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
2dd4aa3a
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2dd4aa3a
编写于
4月 29, 2020
作者:
L
LielinJiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add comments and examples
上级
db5f3697
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
544 addition
and
7 deletion
+544
-7
hapi/distributed.py
hapi/distributed.py
+29
-0
hapi/download.py
hapi/download.py
+9
-0
hapi/logger.py
hapi/logger.py
+1
-1
hapi/loss.py
hapi/loss.py
+33
-4
hapi/vision/models/darknet.py
hapi/vision/models/darknet.py
+19
-0
hapi/vision/models/lenet.py
hapi/vision/models/lenet.py
+7
-0
hapi/vision/models/mobilenetv1.py
hapi/vision/models/mobilenetv1.py
+21
-0
hapi/vision/models/mobilenetv2.py
hapi/vision/models/mobilenetv2.py
+21
-0
hapi/vision/models/resnet.py
hapi/vision/models/resnet.py
+71
-2
hapi/vision/models/vgg.py
hapi/vision/models/vgg.py
+44
-0
hapi/vision/transforms/functional.py
hapi/vision/transforms/functional.py
+29
-0
hapi/vision/transforms/transforms.py
hapi/vision/transforms/transforms.py
+260
-0
未找到文件。
hapi/distributed.py
浏览文件 @
2dd4aa3a
...
@@ -48,6 +48,35 @@ class DistributedBatchSampler(BatchSampler):
...
@@ -48,6 +48,35 @@ class DistributedBatchSampler(BatchSampler):
batch indices. Default False.
batch indices. Default False.
drop_last(bool): whether drop the last incomplete batch dataset size
drop_last(bool): whether drop the last incomplete batch dataset size
is not divisible by the batch size. Default False
is not divisible by the batch size. Default False
Examples:
.. code-block:: python
import numpy as np
from hapi.datasets import MNIST
from hapi.distributed import DistributedBatchSampler
class MnistDataset(MNIST):
def __init__(self, mode, return_label=True):
super(MnistDataset, self).__init__(mode=mode)
self.return_label = return_label
def __getitem__(self, idx):
img = np.reshape(self.images[idx], [1, 28, 28])
if self.return_label:
return img, np.array(self.labels[idx]).astype('int64')
return img,
def __len__(self):
return len(self.images)
train_dataset = MnistDataset(mode='train')
dist_train_dataloader = DistributedBatchSampler(train_dataset, batch_size=64)
for data in dist_train_dataloader:
# do something
break
"""
"""
def
__init__
(
self
,
dataset
,
batch_size
,
shuffle
=
False
,
drop_last
=
False
):
def
__init__
(
self
,
dataset
,
batch_size
,
shuffle
=
False
,
drop_last
=
False
):
...
...
hapi/download.py
浏览文件 @
2dd4aa3a
...
@@ -55,6 +55,15 @@ def get_weights_path_from_url(url, md5sum=None):
...
@@ -55,6 +55,15 @@ def get_weights_path_from_url(url, md5sum=None):
Returns:
Returns:
str: a local path to save downloaded weights.
str: a local path to save downloaded weights.
Examples:
.. code-block:: python
from hapi.download import get_weights_path_from_url
resnet18_pretrained_weight_url = 'https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams'
local_weight_path = get_weights_path_from_url(resnet18_pretrained_weight_url)
"""
"""
path
=
get_path_from_url
(
url
,
WEIGHTS_HOME
,
md5sum
)
path
=
get_path_from_url
(
url
,
WEIGHTS_HOME
,
md5sum
)
return
path
return
path
...
...
hapi/logger.py
浏览文件 @
2dd4aa3a
...
@@ -38,7 +38,7 @@ def setup_logger(output=None, name="hapi", log_level=logging.INFO):
...
@@ -38,7 +38,7 @@ def setup_logger(output=None, name="hapi", log_level=logging.INFO):
# stdout logging: only local rank==0
# stdout logging: only local rank==0
local_rank
=
ParallelEnv
().
local_rank
local_rank
=
ParallelEnv
().
local_rank
if
local_rank
==
0
:
if
local_rank
==
0
and
not
logger
.
hasHandlers
()
:
ch
=
logging
.
StreamHandler
(
stream
=
sys
.
stdout
)
ch
=
logging
.
StreamHandler
(
stream
=
sys
.
stdout
)
ch
.
setLevel
(
log_level
)
ch
.
setLevel
(
log_level
)
...
...
hapi/loss.py
浏览文件 @
2dd4aa3a
...
@@ -30,8 +30,8 @@ class Loss(object):
...
@@ -30,8 +30,8 @@ class Loss(object):
Base class for loss, encapsulates loss logic and APIs
Base class for loss, encapsulates loss logic and APIs
Usage:
Usage:
custom_loss = CustomLoss()
custom_loss = CustomLoss()
loss = custom_loss(inputs, labels)
loss = custom_loss(inputs, labels)
"""
"""
def
__init__
(
self
,
average
=
True
):
def
__init__
(
self
,
average
=
True
):
...
@@ -63,10 +63,25 @@ class CrossEntropy(Loss):
...
@@ -63,10 +63,25 @@ class CrossEntropy(Loss):
average (bool, optional): Indicate whether to average the loss, Default: True.
average (bool, optional): Indicate whether to average the loss, Default: True.
Returns:
Returns:
list[Variable]: The tensor variable storing the cross_entropy_loss of inputs and labels.
list[Variable]: The tensor variable storing the cross_entropy_loss of inputs and labels.
Examples:
.. code-block:: python
from hapi.model import Input
from hapi.vision.models import LeNet
from hapi.loss import CrossEntropy
inputs = [Input([-1, 1, 28, 28], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
model = LeNet()
loss = CrossEntropy()
model.prepare(loss_function=loss, inputs=inputs, labels=labels)
"""
"""
def
__init__
(
self
,
average
=
True
):
def
__init__
(
self
,
average
=
True
):
super
(
CrossEntropy
,
self
).
__init__
()
super
(
CrossEntropy
,
self
).
__init__
(
average
)
def
forward
(
self
,
outputs
,
labels
):
def
forward
(
self
,
outputs
,
labels
):
return
[
return
[
...
@@ -85,10 +100,24 @@ class SoftmaxWithCrossEntropy(Loss):
...
@@ -85,10 +100,24 @@ class SoftmaxWithCrossEntropy(Loss):
average (bool, optional): Indicate whether to average the loss, Default: True.
average (bool, optional): Indicate whether to average the loss, Default: True.
Returns:
Returns:
list[Variable]: The tensor variable storing the cross_entropy_loss of inputs and labels.
list[Variable]: The tensor variable storing the cross_entropy_loss of inputs and labels.
Examples:
.. code-block:: python
from hapi.model import Input
from hapi.vision.models import LeNet
from hapi.loss import SoftmaxWithCrossEntropy
inputs = [Input([-1, 1, 28, 28], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
model = LeNet(classifier_activation=None)
loss = SoftmaxWithCrossEntropy()
model.prepare(loss_function=loss, inputs=inputs, labels=labels)
"""
"""
def
__init__
(
self
,
average
=
True
):
def
__init__
(
self
,
average
=
True
):
super
(
SoftmaxWithCrossEntropy
,
self
).
__init__
()
super
(
SoftmaxWithCrossEntropy
,
self
).
__init__
(
average
)
def
forward
(
self
,
outputs
,
labels
):
def
forward
(
self
,
outputs
,
labels
):
return
[
return
[
...
...
hapi/vision/models/darknet.py
浏览文件 @
2dd4aa3a
...
@@ -144,6 +144,13 @@ class DarkNet(Model):
...
@@ -144,6 +144,13 @@ class DarkNet(Model):
will not be defined. Default: 1000.
will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True.
with_pool (bool): use pool before the last fc layer or not. Default: True.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
Examples:
.. code-block:: python
from hapi.vision.models import DarkNet
model = DarkNet()
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -233,5 +240,17 @@ def darknet53(pretrained=False, **kwargs):
...
@@ -233,5 +240,17 @@ def darknet53(pretrained=False, **kwargs):
input_channels (bool): channel number of input data, default 3.
input_channels (bool): channel number of input data, default 3.
pretrained (bool): If True, returns a model pre-trained on ImageNet,
pretrained (bool): If True, returns a model pre-trained on ImageNet,
default True.
default True.
Examples:
.. code-block:: python
from hapi.vision.models import darknet53
# build model
model = darknet53()
#build model and load imagenet pretrained weight
model = darknet53(pretrained=True)
"""
"""
return
_darknet
(
'darknet53'
,
53
,
pretrained
,
**
kwargs
)
return
_darknet
(
'darknet53'
,
53
,
pretrained
,
**
kwargs
)
hapi/vision/models/lenet.py
浏览文件 @
2dd4aa3a
...
@@ -29,6 +29,13 @@ class LeNet(Model):
...
@@ -29,6 +29,13 @@ class LeNet(Model):
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 10.
will not be defined. Default: 10.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
Examples:
.. code-block:: python
from hapi.vision.models import LeNet
model = LeNet()
"""
"""
def
__init__
(
self
,
num_classes
=
10
,
classifier_activation
=
'softmax'
):
def
__init__
(
self
,
num_classes
=
10
,
classifier_activation
=
'softmax'
):
...
...
hapi/vision/models/mobilenetv1.py
浏览文件 @
2dd4aa3a
...
@@ -115,6 +115,13 @@ class MobileNetV1(Model):
...
@@ -115,6 +115,13 @@ class MobileNetV1(Model):
will not be defined. Default: 1000.
will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True.
with_pool (bool): use pool before the last fc layer or not. Default: True.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
Examples:
.. code-block:: python
from hapi.vision.models import MobileNetV1
model = MobileNetV1()
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -282,6 +289,20 @@ def mobilenet_v1(pretrained=False, scale=1.0, **kwargs):
...
@@ -282,6 +289,20 @@ def mobilenet_v1(pretrained=False, scale=1.0, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
scale: (float): scale of channels in each layer. Default: 1.0.
scale: (float): scale of channels in each layer. Default: 1.0.
Examples:
.. code-block:: python
from hapi.vision.models import mobilenet_v1
# build model
model = mobilenet_v1()
#build model and load imagenet pretrained weight
model = mobilenet_v1(pretrained=True)
#build mobilenet v1 with scale=0.5
model = mobilenet_v1(scale=0.5)
"""
"""
model
=
_mobilenet
(
model
=
_mobilenet
(
'mobilenetv1_'
+
str
(
scale
),
pretrained
,
scale
=
scale
,
**
kwargs
)
'mobilenetv1_'
+
str
(
scale
),
pretrained
,
scale
=
scale
,
**
kwargs
)
...
...
hapi/vision/models/mobilenetv2.py
浏览文件 @
2dd4aa3a
...
@@ -160,6 +160,13 @@ class MobileNetV2(Model):
...
@@ -160,6 +160,13 @@ class MobileNetV2(Model):
will not be defined. Default: 1000.
will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True.
with_pool (bool): use pool before the last fc layer or not. Default: True.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
Examples:
.. code-block:: python
from hapi.vision.models import MobileNetV2
model = MobileNetV2()
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -256,6 +263,20 @@ def mobilenet_v2(pretrained=False, scale=1.0, **kwargs):
...
@@ -256,6 +263,20 @@ def mobilenet_v2(pretrained=False, scale=1.0, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
scale: (float): scale of channels in each layer. Default: 1.0.
scale: (float): scale of channels in each layer. Default: 1.0.
Examples:
.. code-block:: python
from hapi.vision.models import mobilenet_v2
# build model
model = mobilenet_v2()
#build model and load imagenet pretrained weight
model = mobilenet_v2(pretrained=True)
#build mobilenet v2 with scale=0.5
model = mobilenet_v2(scale=0.5)
"""
"""
model
=
_mobilenet
(
model
=
_mobilenet
(
'mobilenetv2_'
+
str
(
scale
),
pretrained
,
scale
=
scale
,
**
kwargs
)
'mobilenetv2_'
+
str
(
scale
),
pretrained
,
scale
=
scale
,
**
kwargs
)
...
...
hapi/vision/models/resnet.py
浏览文件 @
2dd4aa3a
...
@@ -26,7 +26,8 @@ from hapi.model import Model
...
@@ -26,7 +26,8 @@ from hapi.model import Model
from
hapi.download
import
get_weights_path_from_url
from
hapi.download
import
get_weights_path_from_url
__all__
=
[
__all__
=
[
'ResNet'
,
'resnet18'
,
'resnet34'
,
'resnet50'
,
'resnet101'
,
'resnet152'
'ResNet'
,
'resnet18'
,
'resnet34'
,
'resnet50'
,
'resnet101'
,
'resnet152'
,
'BottleneckBlock'
,
'BasicBlock'
]
]
model_urls
=
{
model_urls
=
{
...
@@ -75,7 +76,8 @@ class ConvBNLayer(fluid.dygraph.Layer):
...
@@ -75,7 +76,8 @@ class ConvBNLayer(fluid.dygraph.Layer):
class
BasicBlock
(
fluid
.
dygraph
.
Layer
):
class
BasicBlock
(
fluid
.
dygraph
.
Layer
):
"""residual block of resnet18 and resnet34
"""
expansion
=
1
expansion
=
1
def
__init__
(
self
,
num_channels
,
num_filters
,
stride
,
shortcut
=
True
):
def
__init__
(
self
,
num_channels
,
num_filters
,
stride
,
shortcut
=
True
):
...
@@ -117,6 +119,8 @@ class BasicBlock(fluid.dygraph.Layer):
...
@@ -117,6 +119,8 @@ class BasicBlock(fluid.dygraph.Layer):
class
BottleneckBlock
(
fluid
.
dygraph
.
Layer
):
class
BottleneckBlock
(
fluid
.
dygraph
.
Layer
):
"""residual block of resnet50, resnet101 amd resnet152
"""
expansion
=
4
expansion
=
4
...
@@ -177,6 +181,16 @@ class ResNet(Model):
...
@@ -177,6 +181,16 @@ class ResNet(Model):
will not be defined. Default: 1000.
will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True.
with_pool (bool): use pool before the last fc layer or not. Default: True.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
Examples:
.. code-block:: python
from hapi.vision.models import ResNet, BottleneckBlock, BasicBlock
resnet50 = ResNet(BottleneckBlock, 50)
resnet18 = ResNet(BasicBlock, 18)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -280,6 +294,17 @@ def resnet18(pretrained=False, **kwargs):
...
@@ -280,6 +294,17 @@ def resnet18(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from hapi.vision.models import resnet18
# build model
model = resnet18()
#build model and load imagenet pretrained weight
model = resnet18(pretrained=True)
"""
"""
return
_resnet
(
'resnet18'
,
BasicBlock
,
18
,
pretrained
,
**
kwargs
)
return
_resnet
(
'resnet18'
,
BasicBlock
,
18
,
pretrained
,
**
kwargs
)
...
@@ -289,6 +314,17 @@ def resnet34(pretrained=False, **kwargs):
...
@@ -289,6 +314,17 @@ def resnet34(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from hapi.vision.models import resnet34
# build model
model = resnet34()
#build model and load imagenet pretrained weight
model = resnet34(pretrained=True)
"""
"""
return
_resnet
(
'resnet34'
,
BasicBlock
,
34
,
pretrained
,
**
kwargs
)
return
_resnet
(
'resnet34'
,
BasicBlock
,
34
,
pretrained
,
**
kwargs
)
...
@@ -298,6 +334,17 @@ def resnet50(pretrained=False, **kwargs):
...
@@ -298,6 +334,17 @@ def resnet50(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from hapi.vision.models import resnet50
# build model
model = resnet50()
#build model and load imagenet pretrained weight
model = resnet50(pretrained=True)
"""
"""
return
_resnet
(
'resnet50'
,
BottleneckBlock
,
50
,
pretrained
,
**
kwargs
)
return
_resnet
(
'resnet50'
,
BottleneckBlock
,
50
,
pretrained
,
**
kwargs
)
...
@@ -307,6 +354,17 @@ def resnet101(pretrained=False, **kwargs):
...
@@ -307,6 +354,17 @@ def resnet101(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from hapi.vision.models import resnet101
# build model
model = resnet101()
#build model and load imagenet pretrained weight
model = resnet101(pretrained=True)
"""
"""
return
_resnet
(
'resnet101'
,
BottleneckBlock
,
101
,
pretrained
,
**
kwargs
)
return
_resnet
(
'resnet101'
,
BottleneckBlock
,
101
,
pretrained
,
**
kwargs
)
...
@@ -316,5 +374,16 @@ def resnet152(pretrained=False, **kwargs):
...
@@ -316,5 +374,16 @@ def resnet152(pretrained=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from hapi.vision.models import resnet152
# build model
model = resnet152()
#build model and load imagenet pretrained weight
model = resnet152(pretrained=True)
"""
"""
return
_resnet
(
'resnet152'
,
BottleneckBlock
,
152
,
pretrained
,
**
kwargs
)
return
_resnet
(
'resnet152'
,
BottleneckBlock
,
152
,
pretrained
,
**
kwargs
)
hapi/vision/models/vgg.py
浏览文件 @
2dd4aa3a
...
@@ -143,6 +143,17 @@ def vgg11(pretrained=False, batch_norm=False, **kwargs):
...
@@ -143,6 +143,17 @@ def vgg11(pretrained=False, batch_norm=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
Examples:
.. code-block:: python
from hapi.vision.models import vgg11
# build model
model = vgg11()
#build vgg11 model with batch_norm
model = vgg11(batch_norm=True)
"""
"""
model_name
=
'vgg11'
model_name
=
'vgg11'
if
batch_norm
:
if
batch_norm
:
...
@@ -156,6 +167,17 @@ def vgg13(pretrained=False, batch_norm=False, **kwargs):
...
@@ -156,6 +167,17 @@ def vgg13(pretrained=False, batch_norm=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
Examples:
.. code-block:: python
from hapi.vision.models import vgg13
# build model
model = vgg13()
#build vgg13 model with batch_norm
model = vgg13(batch_norm=True)
"""
"""
model_name
=
'vgg13'
model_name
=
'vgg13'
if
batch_norm
:
if
batch_norm
:
...
@@ -169,6 +191,17 @@ def vgg16(pretrained=False, batch_norm=False, **kwargs):
...
@@ -169,6 +191,17 @@ def vgg16(pretrained=False, batch_norm=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
Examples:
.. code-block:: python
from hapi.vision.models import vgg16
# build model
model = vgg16()
#build vgg16 model with batch_norm
model = vgg16(batch_norm=True)
"""
"""
model_name
=
'vgg16'
model_name
=
'vgg16'
if
batch_norm
:
if
batch_norm
:
...
@@ -182,6 +215,17 @@ def vgg19(pretrained=False, batch_norm=False, **kwargs):
...
@@ -182,6 +215,17 @@ def vgg19(pretrained=False, batch_norm=False, **kwargs):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
Examples:
.. code-block:: python
from hapi.vision.models import vgg19
# build model
model = vgg19()
#build vgg19 model with batch_norm
model = vgg19(batch_norm=True)
"""
"""
model_name
=
'vgg19'
model_name
=
'vgg19'
if
batch_norm
:
if
batch_norm
:
...
...
hapi/vision/transforms/functional.py
浏览文件 @
2dd4aa3a
...
@@ -39,6 +39,23 @@ def flip(image, code):
...
@@ -39,6 +39,23 @@ def flip(image, code):
-1 : Flip horizontally and vertically
-1 : Flip horizontally and vertically
0 : Flip vertically
0 : Flip vertically
1 : Flip horizontally
1 : Flip horizontally
Examples:
.. code-block:: python
import numpy as np
from hapi.vision.transforms import functional as F
fake_img = np.random.rand(224, 224, 3)
# flip horizontally and vertically
F.flip(fake_img, -1)
# flip vertically
F.flip(fake_img, 0)
# flip horizontally
F.flip(fake_img, 1)
"""
"""
return
cv2
.
flip
(
image
,
flipCode
=
code
)
return
cv2
.
flip
(
image
,
flipCode
=
code
)
...
@@ -51,6 +68,18 @@ def resize(img, size, interpolation=cv2.INTER_LINEAR):
...
@@ -51,6 +68,18 @@ def resize(img, size, interpolation=cv2.INTER_LINEAR):
input: Input data, could be image or masks, with (H, W, C) shape
input: Input data, could be image or masks, with (H, W, C) shape
size: Target size of input data, with (height, width) shape.
size: Target size of input data, with (height, width) shape.
interpolation: Interpolation method.
interpolation: Interpolation method.
Examples:
.. code-block:: python
import numpy as np
from hapi.vision.transforms import functional as F
fake_img = np.random.rand(256, 256, 3)
F.resize(fake_img, 224)
F.resize(fake_img, (200, 150))
"""
"""
if
isinstance
(
interpolation
,
Sequence
):
if
isinstance
(
interpolation
,
Sequence
):
...
...
hapi/vision/transforms/transforms.py
浏览文件 @
2dd4aa3a
...
@@ -118,6 +118,67 @@ class BatchCompose(object):
...
@@ -118,6 +118,67 @@ class BatchCompose(object):
transforms (list of ``Transform`` objects): list of transforms to compose.
transforms (list of ``Transform`` objects): list of transforms to compose.
these transforms perform on batch data.
these transforms perform on batch data.
Examples:
.. code-block:: python
import numpy as np
from paddle.io import DataLoader
from hapi.model import set_device
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, BatchCompose, Resize
class NormalizeBatch(object):
def __init__(self,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
scale=True,
channel_first=True):
self.mean = mean
self.std = std
self.scale = scale
self.channel_first = channel_first
if not (isinstance(self.mean, list) and isinstance(self.std, list) and
isinstance(self.scale, bool)):
raise TypeError("{}: input type is invalid.".format(self))
from functools import reduce
if reduce(lambda x, y: x * y, self.std) == 0:
raise ValueError('{}: std is invalid!'.format(self))
def __call__(self, samples):
for i in range(len(samples)):
samples[i] = list(samples[i])
im = samples[i][0]
im = im.astype(np.float32, copy=False)
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.scale:
im = im / 255.0
im -= mean
im /= std
if self.channel_first:
im = im.transpose((2, 0, 1))
samples[i][0] = im
return samples
transform = Compose([Resize((500, 500))])
flowers_dataset = Flowers(mode='test', transform=transform)
device = set_device('cpu')
collate_fn = BatchCompose([NormalizeBatch()])
loader = DataLoader(
flowers_dataset,
batch_size=4,
places=device,
return_list=True,
collate_fn=collate_fn)
for data in loader:
# do something
break
"""
"""
def
__init__
(
self
,
transforms
=
[]):
def
__init__
(
self
,
transforms
=
[]):
...
@@ -149,6 +210,20 @@ class Resize(object):
...
@@ -149,6 +210,20 @@ class Resize(object):
i.e, if height > width, then image will be rescaled to
i.e, if height > width, then image will be rescaled to
(size * height / width, size)
(size * height / width, size)
interpolation (int): interpolation mode of resize. Default: cv2.INTER_LINEAR.
interpolation (int): interpolation mode of resize. Default: cv2.INTER_LINEAR.
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, Resize
transform = Compose([Resize(size=224)])
flowers = Flowers(mode='test', transform=transform)
for i in range(10):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
"""
def
__init__
(
self
,
size
,
interpolation
=
cv2
.
INTER_LINEAR
):
def
__init__
(
self
,
size
,
interpolation
=
cv2
.
INTER_LINEAR
):
...
@@ -171,6 +246,20 @@ class RandomResizedCrop(object):
...
@@ -171,6 +246,20 @@ class RandomResizedCrop(object):
output_size (int|list|tuple): Target size of output image, with (height, width) shape.
output_size (int|list|tuple): Target size of output image, with (height, width) shape.
scale (list|tuple): Range of size of the origin size cropped. Default: (0.08, 1.0)
scale (list|tuple): Range of size of the origin size cropped. Default: (0.08, 1.0)
ratio (list|tuple): Range of aspect ratio of the origin aspect ratio cropped. Default: (0.75, 1.33)
ratio (list|tuple): Range of aspect ratio of the origin aspect ratio cropped. Default: (0.75, 1.33)
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, Resize, RandomResizedCrop
transform = Compose([Resize(500), RandomResizedCrop(224)])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -233,6 +322,20 @@ class CenterCropResize(object):
...
@@ -233,6 +322,20 @@ class CenterCropResize(object):
size (int|list|tuple): Target size of output image, with (height, width) shape.
size (int|list|tuple): Target size of output image, with (height, width) shape.
crop_padding (int): center crop with the padding. Default: 32.
crop_padding (int): center crop with the padding. Default: 32.
interpolation (int): interpolation mode of resize. Default: cv2.INTER_LINEAR.
interpolation (int): interpolation mode of resize. Default: cv2.INTER_LINEAR.
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, Resize, CenterCropResize
transform = Compose([Resize(500), CenterCropResize(224)])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
"""
def
__init__
(
self
,
size
,
crop_padding
=
32
,
interpolation
=
cv2
.
INTER_LINEAR
):
def
__init__
(
self
,
size
,
crop_padding
=
32
,
interpolation
=
cv2
.
INTER_LINEAR
):
...
@@ -262,6 +365,20 @@ class CenterCrop(object):
...
@@ -262,6 +365,20 @@ class CenterCrop(object):
Args:
Args:
output_size: Target size of output image, with (height, width) shape.
output_size: Target size of output image, with (height, width) shape.
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, Resize, CenterCrop
transform = Compose([Resize(500), CenterCrop(224)])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
"""
def
__init__
(
self
,
output_size
):
def
__init__
(
self
,
output_size
):
...
@@ -289,6 +406,20 @@ class RandomHorizontalFlip(object):
...
@@ -289,6 +406,20 @@ class RandomHorizontalFlip(object):
Args:
Args:
prob (float): probability of the input data being flipped. Default: 0.5
prob (float): probability of the input data being flipped. Default: 0.5
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, RandomHorizontalFlip
transform = Compose([RandomHorizontalFlip()])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
"""
def
__init__
(
self
,
prob
=
0.5
):
def
__init__
(
self
,
prob
=
0.5
):
...
@@ -305,6 +436,20 @@ class RandomVerticalFlip(object):
...
@@ -305,6 +436,20 @@ class RandomVerticalFlip(object):
Args:
Args:
prob (float): probability of the input data being flipped. Default: 0.5
prob (float): probability of the input data being flipped. Default: 0.5
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, RandomVerticalFlip
transform = Compose([RandomVerticalFlip()])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
"""
def
__init__
(
self
,
prob
=
0.5
):
def
__init__
(
self
,
prob
=
0.5
):
...
@@ -325,6 +470,23 @@ class Normalize(object):
...
@@ -325,6 +470,23 @@ class Normalize(object):
Args:
Args:
mean (int|float|list): Sequence of means for each channel.
mean (int|float|list): Sequence of means for each channel.
std (int|float|list): Sequence of standard deviations for each channel.
std (int|float|list): Sequence of standard deviations for each channel.
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, Normalize, Permute
normalize = Normalize(mean=[123.675, 116.28, 103.53],
std=[58.395, 57.120, 57.375])
transform = Compose([Permute(), normalize])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
"""
...
@@ -351,6 +513,20 @@ class Permute(object):
...
@@ -351,6 +513,20 @@ class Permute(object):
Args:
Args:
mode: Output mode of input. Default: "CHW".
mode: Output mode of input. Default: "CHW".
to_rgb: convert 'bgr' image to 'rgb'. Default: True.
to_rgb: convert 'bgr' image to 'rgb'. Default: True.
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, Permute
transform = Compose([Permute()])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
"""
def
__init__
(
self
,
mode
=
"CHW"
,
to_rgb
=
True
):
def
__init__
(
self
,
mode
=
"CHW"
,
to_rgb
=
True
):
...
@@ -375,6 +551,20 @@ class GaussianNoise(object):
...
@@ -375,6 +551,20 @@ class GaussianNoise(object):
Args:
Args:
mean: Gaussian mean used to generate noise.
mean: Gaussian mean used to generate noise.
std: Gaussian standard deviation used to generate noise.
std: Gaussian standard deviation used to generate noise.
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, GaussianNoise
transform = Compose([GaussianNoise()])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
"""
def
__init__
(
self
,
mean
=
0.0
,
std
=
1.0
):
def
__init__
(
self
,
mean
=
0.0
,
std
=
1.0
):
...
@@ -394,6 +584,20 @@ class BrightnessTransform(object):
...
@@ -394,6 +584,20 @@ class BrightnessTransform(object):
Args:
Args:
value: How much to adjust the brightness. Can be any
value: How much to adjust the brightness. Can be any
non negative number. 0 gives the original image
non negative number. 0 gives the original image
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, BrightnessTransform
transform = Compose([BrightnessTransform(0.4)])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
"""
def
__init__
(
self
,
value
):
def
__init__
(
self
,
value
):
...
@@ -418,6 +622,20 @@ class ContrastTransform(object):
...
@@ -418,6 +622,20 @@ class ContrastTransform(object):
Args:
Args:
value: How much to adjust the contrast. Can be any
value: How much to adjust the contrast. Can be any
non negative number. 0 gives the original image
non negative number. 0 gives the original image
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, ContrastTransform
transform = Compose([ContrastTransform(0.4)])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
"""
def
__init__
(
self
,
value
):
def
__init__
(
self
,
value
):
...
@@ -443,6 +661,20 @@ class SaturationTransform(object):
...
@@ -443,6 +661,20 @@ class SaturationTransform(object):
Args:
Args:
value: How much to adjust the saturation. Can be any
value: How much to adjust the saturation. Can be any
non negative number. 0 gives the original image
non negative number. 0 gives the original image
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, SaturationTransform
transform = Compose([SaturationTransform(0.4)])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
"""
def
__init__
(
self
,
value
):
def
__init__
(
self
,
value
):
...
@@ -469,6 +701,20 @@ class HueTransform(object):
...
@@ -469,6 +701,20 @@ class HueTransform(object):
Args:
Args:
value: How much to adjust the hue. Can be any number
value: How much to adjust the hue. Can be any number
between 0 and 0.5, 0 gives the original image
between 0 and 0.5, 0 gives the original image
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, HueTransform
transform = Compose([HueTransform(0.4)])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
"""
def
__init__
(
self
,
value
):
def
__init__
(
self
,
value
):
...
@@ -510,6 +756,20 @@ class ColorJitter(object):
...
@@ -510,6 +756,20 @@ class ColorJitter(object):
hue: How much to jitter hue.
hue: How much to jitter hue.
Chosen uniformly from [-hue, hue] or the given [min, max].
Chosen uniformly from [-hue, hue] or the given [min, max].
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
Examples:
.. code-block:: python
from hapi.datasets import Flowers
from hapi.vision.transforms import Compose, ColorJitter
transform = Compose([ColorJitter(0.4)])
flowers = Flowers(mode='test', transform=transform)
for i in range(2):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
"""
def
__init__
(
self
,
brightness
=
0
,
contrast
=
0
,
saturation
=
0
,
hue
=
0
):
def
__init__
(
self
,
brightness
=
0
,
contrast
=
0
,
saturation
=
0
,
hue
=
0
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录