未验证 提交 1a0d26a4 编写于 作者: L LielinJiang 提交者: GitHub

Add vision api for hapi (#24404)

* add vision

* fix predict, test=develop

* add unittest for vision apis, test=develop

* fix typos

* add hapi models api, test=develop

* fix code format, test=develop

* fix typos, test=develop

* fix sample code import, test=develop

* fix sample codes, test=develop

* add decompress, test=develop

* rm darknet, test=develop

* rm debug code, test=develop
上级 046b7ebc
......@@ -22,6 +22,8 @@ import os.path as osp
import shutil
import requests
import hashlib
import tarfile
import zipfile
import time
from collections import OrderedDict
from paddle.fluid.dygraph.parallel import ParallelEnv
......@@ -166,6 +168,11 @@ def get_path_from_url(url, root_dir, md5sum=None, check_exist=True):
else:
while not os.path.exists(fullpath):
time.sleep(1)
if ParallelEnv().local_rank == 0:
if tarfile.is_tarfile(fullpath) or zipfile.is_zipfile(fullpath):
fullpath = _decompress(fullpath)
return fullpath
......@@ -233,3 +240,101 @@ def _md5check(fullname, md5sum=None):
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
def _decompress(fname):
"""
Decompress for zip and tar file
"""
logger.info("Decompressing {}...".format(fname))
# For protecting decompressing interupted,
# decompress to fpath_tmp directory firstly, if decompress
# successed, move decompress files to fpath and delete
# fpath_tmp and remove download compress file.
if tarfile.is_tarfile(fname):
uncompressed_path = _uncompress_file_tar(fname)
elif zipfile.is_zipfile(fname):
uncompressed_path = _uncompress_file_zip(fname)
else:
raise TypeError("Unsupport compress file type {}".format(fname))
return uncompressed_path
def _uncompress_file_zip(filepath):
files = zipfile.ZipFile(filepath, 'r')
file_list = files.namelist()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _uncompress_file_tar(filepath, mode="r:*"):
files = tarfile.open(filepath, mode)
file_list = files.getnames()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _is_a_single_file(file_list):
if len(file_list) == 1 and file_list[0].find(os.sep) < -1:
return True
return False
def _is_a_single_dir(file_list):
file_name = file_list[0].split(os.sep)[0]
for i in range(1, len(file_list)):
if file_name != file_list[i].split(os.sep)[0]:
return False
return True
......@@ -42,7 +42,7 @@ class Metric(object):
m.accumulate()
Advanced usage for :code:`add_metric_op`
Metric calculating con be accelerate by calucateing metric states
Metric calculation can be accelerated by calculating metric states
from model outputs and labels by Paddle OPs in :code:`add_metric_op`,
metric states will be fetch as numpy array and call :code:`update`
with states in numpy format.
......
......@@ -1633,7 +1633,12 @@ class Model(fluid.dygraph.Layer):
for k, v in zip(self._metrics_name(), metrics):
logs[k] = v
else:
outs = getattr(self, mode + '_batch')(data)
if self._inputs is not None:
outs = getattr(self,
mode + '_batch')(data[:len(self._inputs)])
else:
outs = getattr(self, mode + '_batch')(data)
outputs.append(outs)
logs['step'] = step
......
......@@ -45,6 +45,18 @@ class TestDownload(unittest.TestCase):
url = 'https://paddle-hapi.bj.bcebos.com/models/mobilenet_v2_x1.0t.pdparams'
self.download(url, None)
def test_download_and_uncompress(self):
urls = [
"https://paddle-hapi.bj.bcebos.com/unittest/files.tar",
"https://paddle-hapi.bj.bcebos.com/unittest/files.zip",
"https://paddle-hapi.bj.bcebos.com/unittest/single_dir.tar",
"https://paddle-hapi.bj.bcebos.com/unittest/single_dir.zip",
"https://paddle-hapi.bj.bcebos.com/unittest/single_file.tar",
"https://paddle-hapi.bj.bcebos.com/unittest/single_file.zip",
]
for url in urls:
self.download(url, None)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# when test, you should add hapi root path to the PYTHONPATH,
# export PYTHONPATH=PATH_TO_HAPI:$PYTHONPATH
import unittest
import os
import tempfile
import cv2
import shutil
import numpy as np
from paddle.incubate.hapi.datasets import DatasetFolder
from paddle.incubate.hapi.vision.transforms import transforms
class TestTransforms(unittest.TestCase):
def setUp(self):
self.data_dir = tempfile.mkdtemp()
for i in range(2):
sub_dir = os.path.join(self.data_dir, 'class_' + str(i))
if not os.path.exists(sub_dir):
os.makedirs(sub_dir)
for j in range(2):
if j == 0:
fake_img = (np.random.random(
(280, 350, 3)) * 255).astype('uint8')
else:
fake_img = (np.random.random(
(400, 300, 3)) * 255).astype('uint8')
cv2.imwrite(os.path.join(sub_dir, str(j) + '.jpg'), fake_img)
def tearDown(self):
shutil.rmtree(self.data_dir)
def do_transform(self, trans):
dataset_folder = DatasetFolder(self.data_dir, transform=trans)
for _ in dataset_folder:
pass
def test_trans_all(self):
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
trans = transforms.Compose([
transforms.RandomResizedCrop(224), transforms.GaussianNoise(),
transforms.ColorJitter(
brightness=0.4, contrast=0.4, saturation=0.4,
hue=0.4), transforms.RandomHorizontalFlip(),
transforms.Permute(mode='CHW'), normalize
])
self.do_transform(trans)
def test_trans_resize(self):
trans = transforms.Compose([
transforms.Resize(300, [0, 1]),
transforms.RandomResizedCrop((280, 280)),
transforms.Resize(280, [0, 1]),
transforms.Resize((256, 200)),
transforms.Resize((180, 160)),
transforms.CenterCrop(128),
transforms.CenterCrop((128, 128)),
])
self.do_transform(trans)
def test_trans_centerCrop(self):
trans = transforms.Compose([
transforms.CenterCropResize(224),
transforms.CenterCropResize(128, 160),
])
self.do_transform(trans)
def test_flip(self):
trans = transforms.Compose([
transforms.RandomHorizontalFlip(1.0),
transforms.RandomHorizontalFlip(0.0),
transforms.RandomVerticalFlip(0.0),
transforms.RandomVerticalFlip(1.0),
])
self.do_transform(trans)
def test_color_jitter(self):
trans = transforms.BatchCompose([
transforms.BrightnessTransform(0.0),
transforms.HueTransform(0.0),
transforms.SaturationTransform(0.0),
transforms.ContrastTransform(0.0),
])
self.do_transform(trans)
def test_exception(self):
trans = transforms.Compose([transforms.Resize(-1)])
trans_batch = transforms.BatchCompose([transforms.Resize(-1)])
with self.assertRaises(Exception):
self.do_transform(trans)
with self.assertRaises(Exception):
self.do_transform(trans_batch)
with self.assertRaises(ValueError):
transforms.ContrastTransform(-1.0)
with self.assertRaises(ValueError):
transforms.SaturationTransform(-1.0),
with self.assertRaises(ValueError):
transforms.HueTransform(-1.0)
with self.assertRaises(ValueError):
transforms.BrightnessTransform(-1.0)
def test_info(self):
str(transforms.Compose([transforms.Resize((224, 224))]))
str(transforms.BatchCompose([transforms.Resize((224, 224))]))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle.incubate.hapi.vision.models as models
from paddle.incubate.hapi.model import Input
class TestVisonModels(unittest.TestCase):
def models_infer(self, arch, pretrained=False, batch_norm=False):
x = np.array(np.random.random((2, 3, 224, 224)), dtype=np.float32)
if batch_norm:
model = models.__dict__[arch](pretrained=pretrained,
batch_norm=True)
else:
model = models.__dict__[arch](pretrained=pretrained)
inputs = [Input([None, 3, 224, 224], 'float32', name='image')]
model.prepare(inputs=inputs)
model.test_batch(x)
def test_mobilenetv2_pretrained(self):
self.models_infer('mobilenet_v2', pretrained=True)
def test_mobilenetv1(self):
self.models_infer('mobilenet_v1')
def test_vgg11(self):
self.models_infer('vgg11')
def test_vgg13(self):
self.models_infer('vgg13')
def test_vgg16(self):
self.models_infer('vgg16')
def test_vgg16_bn(self):
self.models_infer('vgg16', batch_norm=True)
def test_vgg19(self):
self.models_infer('vgg19')
def test_resnet18(self):
self.models_infer('resnet18')
def test_resnet34(self):
self.models_infer('resnet34')
def test_resnet50(self):
self.models_infer('resnet50')
def test_resnet101(self):
self.models_infer('resnet101')
def test_resnet152(self):
self.models_infer('resnet152')
def test_lenet(self):
lenet = models.__dict__['LeNet']()
inputs = [Input([None, 1, 28, 28], 'float32', name='x')]
lenet.prepare(inputs=inputs)
x = np.array(np.random.random((2, 1, 28, 28)), dtype=np.float32)
lenet.test_batch(x)
if __name__ == '__main__':
unittest.main()
......@@ -13,6 +13,9 @@
# limitations under the License.
from . import models
from . import transforms
from .models import *
from .transforms import *
__all__ = models.__all__
__all__ = models.__all__ \
+ transforms.__all__
......@@ -12,7 +12,20 @@
#See the License for the specific language governing permissions and
#limitations under the License.
from . import resnet
from . import vgg
from . import mobilenetv1
from . import mobilenetv2
from . import lenet
from .resnet import *
from .mobilenetv1 import *
from .mobilenetv2 import *
from .vgg import *
from .lenet import *
__all__ = lenet.__all__
__all__ = resnet.__all__ \
+ vgg.__all__ \
+ mobilenetv1.__all__ \
+ mobilenetv2.__all__ \
+ lenet.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from ...model import Model
from ...download import get_weights_path_from_url
__all__ = ['MobileNetV1', 'mobilenet_v1']
model_urls = {
'mobilenetv1_1.0':
('https://paddle-hapi.bj.bcebos.com/models/mobilenet_v1_x1.0.pdparams',
'bf0d25cb0bed1114d9dac9384ce2b4a6')
}
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='relu',
use_cudnn=True,
name=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(
initializer=MSRA(), name=self.full_name() + "_weights"),
bias_attr=False)
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(name=self.full_name() + "_bn" + "_scale"),
bias_attr=ParamAttr(name=self.full_name() + "_bn" + "_offset"),
moving_mean_name=self.full_name() + "_bn" + '_mean',
moving_variance_name=self.full_name() + "_bn" + '_variance')
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class DepthwiseSeparable(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters1,
num_filters2,
num_groups,
stride,
scale,
name=None):
super(DepthwiseSeparable, self).__init__()
self._depthwise_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=int(num_filters1 * scale),
filter_size=3,
stride=stride,
padding=1,
num_groups=int(num_groups * scale),
use_cudnn=False)
self._pointwise_conv = ConvBNLayer(
num_channels=int(num_filters1 * scale),
filter_size=1,
num_filters=int(num_filters2 * scale),
stride=1,
padding=0)
def forward(self, inputs):
y = self._depthwise_conv(inputs)
y = self._pointwise_conv(y)
return y
class MobileNetV1(Model):
"""MobileNetV1 model from
`"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" <https://arxiv.org/abs/1704.04861>`_.
Args:
scale (float): scale of channels in each layer. Default: 1.0.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import MobileNetV1
model = MobileNetV1()
"""
def __init__(self,
scale=1.0,
num_classes=1000,
with_pool=True,
classifier_activation='softmax'):
super(MobileNetV1, self).__init__()
self.scale = scale
self.dwsl = []
self.num_classes = num_classes
self.with_pool = with_pool
self.conv1 = ConvBNLayer(
num_channels=3,
filter_size=3,
channels=3,
num_filters=int(32 * scale),
stride=2,
padding=1)
dws21 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(32 * scale),
num_filters1=32,
num_filters2=64,
num_groups=32,
stride=1,
scale=scale),
name="conv2_1")
self.dwsl.append(dws21)
dws22 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(64 * scale),
num_filters1=64,
num_filters2=128,
num_groups=64,
stride=2,
scale=scale),
name="conv2_2")
self.dwsl.append(dws22)
dws31 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=128,
num_groups=128,
stride=1,
scale=scale),
name="conv3_1")
self.dwsl.append(dws31)
dws32 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=256,
num_groups=128,
stride=2,
scale=scale),
name="conv3_2")
self.dwsl.append(dws32)
dws41 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=256,
num_groups=256,
stride=1,
scale=scale),
name="conv4_1")
self.dwsl.append(dws41)
dws42 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=512,
num_groups=256,
stride=2,
scale=scale),
name="conv4_2")
self.dwsl.append(dws42)
for i in range(5):
tmp = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=512,
num_groups=512,
stride=1,
scale=scale),
name="conv5_" + str(i + 1))
self.dwsl.append(tmp)
dws56 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=1024,
num_groups=512,
stride=2,
scale=scale),
name="conv5_6")
self.dwsl.append(dws56)
dws6 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(1024 * scale),
num_filters1=1024,
num_filters2=1024,
num_groups=1024,
stride=1,
scale=scale),
name="conv6")
self.dwsl.append(dws6)
if with_pool:
self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
if num_classes > -1:
self.out = Linear(
int(1024 * scale),
num_classes,
act=classifier_activation,
param_attr=ParamAttr(
initializer=MSRA(), name=self.full_name() + "fc7_weights"),
bias_attr=ParamAttr(name="fc7_offset"))
def forward(self, inputs):
y = self.conv1(inputs)
for dws in self.dwsl:
y = dws(y)
if self.with_pool:
y = self.pool2d_avg(y)
if self.num_classes > 0:
y = fluid.layers.reshape(y, shape=[-1, 1024])
y = self.out(y)
return y
def _mobilenet(arch, pretrained=False, **kwargs):
model = MobileNetV1(**kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path)
return model
def mobilenet_v1(pretrained=False, scale=1.0, **kwargs):
"""MobileNetV1
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
scale: (float): scale of channels in each layer. Default: 1.0.
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import mobilenet_v1
# build model
model = mobilenet_v1()
# build model and load imagenet pretrained weight
# model = mobilenet_v1(pretrained=True)
# build mobilenet v1 with scale=0.5
model = mobilenet_v1(scale=0.5)
"""
model = _mobilenet(
'mobilenetv1_' + str(scale), pretrained, scale=scale, **kwargs)
return model
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from ...model import Model
from ...download import get_weights_path_from_url
__all__ = ['MobileNetV2', 'mobilenet_v2']
model_urls = {
'mobilenetv2_1.0':
('https://paddle-hapi.bj.bcebos.com/models/mobilenet_v2_x1.0.pdparams',
'8ff74f291f72533f2a7956a4efff9d88')
}
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
use_cudnn=True):
super(ConvBNLayer, self).__init__()
tmp_param = ParamAttr(name=self.full_name() + "_weights")
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=tmp_param,
bias_attr=False)
self._batch_norm = BatchNorm(
num_filters,
param_attr=ParamAttr(name=self.full_name() + "_bn" + "_scale"),
bias_attr=ParamAttr(name=self.full_name() + "_bn" + "_offset"),
moving_mean_name=self.full_name() + "_bn" + '_mean',
moving_variance_name=self.full_name() + "_bn" + '_variance')
def forward(self, inputs, if_act=True):
y = self._conv(inputs)
y = self._batch_norm(y)
if if_act:
y = fluid.layers.relu6(y)
return y
class InvertedResidualUnit(fluid.dygraph.Layer):
def __init__(
self,
num_channels,
num_in_filter,
num_filters,
stride,
filter_size,
padding,
expansion_factor, ):
super(InvertedResidualUnit, self).__init__()
num_expfilter = int(round(num_in_filter * expansion_factor))
self._expand_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=num_expfilter,
filter_size=1,
stride=1,
padding=0,
num_groups=1)
self._bottleneck_conv = ConvBNLayer(
num_channels=num_expfilter,
num_filters=num_expfilter,
filter_size=filter_size,
stride=stride,
padding=padding,
num_groups=num_expfilter,
use_cudnn=False)
self._linear_conv = ConvBNLayer(
num_channels=num_expfilter,
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
num_groups=1)
def forward(self, inputs, ifshortcut):
y = self._expand_conv(inputs, if_act=True)
y = self._bottleneck_conv(y, if_act=True)
y = self._linear_conv(y, if_act=False)
if ifshortcut:
y = fluid.layers.elementwise_add(inputs, y)
return y
class InvresiBlocks(fluid.dygraph.Layer):
def __init__(self, in_c, t, c, n, s):
super(InvresiBlocks, self).__init__()
self._first_block = InvertedResidualUnit(
num_channels=in_c,
num_in_filter=in_c,
num_filters=c,
stride=s,
filter_size=3,
padding=1,
expansion_factor=t)
self._inv_blocks = []
for i in range(1, n):
tmp = self.add_sublayer(
sublayer=InvertedResidualUnit(
num_channels=c,
num_in_filter=c,
num_filters=c,
stride=1,
filter_size=3,
padding=1,
expansion_factor=t),
name=self.full_name() + "_" + str(i + 1))
self._inv_blocks.append(tmp)
def forward(self, inputs):
y = self._first_block(inputs, ifshortcut=False)
for inv_block in self._inv_blocks:
y = inv_block(y, ifshortcut=True)
return y
class MobileNetV2(Model):
"""MobileNetV2 model from
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
Args:
scale (float): scale of channels in each layer. Default: 1.0.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import MobileNetV2
model = MobileNetV2()
"""
def __init__(self,
scale=1.0,
num_classes=1000,
with_pool=True,
classifier_activation='softmax'):
super(MobileNetV2, self).__init__()
self.scale = scale
self.num_classes = num_classes
self.with_pool = with_pool
bottleneck_params_list = [
(1, 16, 1, 1),
(6, 24, 2, 2),
(6, 32, 3, 2),
(6, 64, 4, 2),
(6, 96, 3, 1),
(6, 160, 3, 2),
(6, 320, 1, 1),
]
self._conv1 = ConvBNLayer(
num_channels=3,
num_filters=int(32 * scale),
filter_size=3,
stride=2,
padding=1)
self._invl = []
i = 1
in_c = int(32 * scale)
for layer_setting in bottleneck_params_list:
t, c, n, s = layer_setting
i += 1
tmp = self.add_sublayer(
sublayer=InvresiBlocks(
in_c=in_c, t=t, c=int(c * scale), n=n, s=s),
name='conv' + str(i))
self._invl.append(tmp)
in_c = int(c * scale)
self._out_c = int(1280 * scale) if scale > 1.0 else 1280
self._conv9 = ConvBNLayer(
num_channels=in_c,
num_filters=self._out_c,
filter_size=1,
stride=1,
padding=0)
if with_pool:
self._pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
if num_classes > 0:
tmp_param = ParamAttr(name=self.full_name() + "fc10_weights")
self._fc = Linear(
self._out_c,
num_classes,
act=classifier_activation,
param_attr=tmp_param,
bias_attr=ParamAttr(name="fc10_offset"))
def forward(self, inputs):
y = self._conv1(inputs, if_act=True)
for inv in self._invl:
y = inv(y)
y = self._conv9(y, if_act=True)
if self.with_pool:
y = self._pool2d_avg(y)
if self.num_classes > 0:
y = fluid.layers.reshape(y, shape=[-1, self._out_c])
y = self._fc(y)
return y
def _mobilenet(arch, pretrained=False, **kwargs):
model = MobileNetV2(**kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path)
return model
def mobilenet_v2(pretrained=False, scale=1.0, **kwargs):
"""MobileNetV2
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
scale: (float): scale of channels in each layer. Default: 1.0.
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import mobilenet_v2
# build model
model = mobilenet_v2()
# build model and load imagenet pretrained weight
# model = mobilenet_v2(pretrained=True)
# build mobilenet v2 with scale=0.5
model = mobilenet_v2(scale=0.5)
"""
model = _mobilenet(
'mobilenetv2_' + str(scale), pretrained, scale=scale, **kwargs)
return model
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import math
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.container import Sequential
from ...model import Model
from ...download import get_weights_path_from_url
__all__ = [
'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'
]
model_urls = {
'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams',
'0ba53eea9bc970962d0ef96f7b94057e'),
'resnet34': ('https://paddle-hapi.bj.bcebos.com/models/resnet34.pdparams',
'46bc9f7c3dd2e55b7866285bee91eff3'),
'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams',
'5ce890a9ad386df17cf7fe2313dca0a1'),
'resnet101': ('https://paddle-hapi.bj.bcebos.com/models/resnet101.pdparams',
'fb07a451df331e4b0bb861ed97c3a9b9'),
'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams',
'f9c700f26d3644bb76ad2226ed5f5713'),
}
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)
self._batch_norm = BatchNorm(num_filters, act=act)
def forward(self, inputs):
x = self._conv(inputs)
x = self._batch_norm(x)
return x
class BasicBlock(fluid.dygraph.Layer):
"""residual block of resnet18 and resnet34
"""
expansion = 1
def __init__(self, num_channels, num_filters, stride, shortcut=True):
super(BasicBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=3,
act='relu')
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
stride=stride)
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = short + conv1
return fluid.layers.relu(y)
class BottleneckBlock(fluid.dygraph.Layer):
"""residual block of resnet50, resnet101 amd resnet152
"""
expansion = 4
def __init__(self, num_channels, num_filters, stride, shortcut=True):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu')
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * self.expansion,
filter_size=1,
act=None)
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * self.expansion,
filter_size=1,
stride=stride)
self.shortcut = shortcut
self._num_channels_out = num_filters * self.expansion
def forward(self, inputs):
x = self.conv0(inputs)
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
x = fluid.layers.elementwise_add(x=short, y=conv2)
return fluid.layers.relu(x)
class ResNet(Model):
"""ResNet model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
Block (BasicBlock|BottleneckBlock): block module of model.
depth (int): layers of resnet, default: 50.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import ResNet
from paddle.incubate.hapi.vision.models.resnet import BottleneckBlock, BasicBlock
resnet50 = ResNet(BottleneckBlock, 50)
resnet18 = ResNet(BasicBlock, 18)
"""
def __init__(self,
Block,
depth=50,
num_classes=1000,
with_pool=True,
classifier_activation='softmax'):
super(ResNet, self).__init__()
self.num_classes = num_classes
self.with_pool = with_pool
layer_config = {
18: [2, 2, 2, 2],
34: [3, 4, 6, 3],
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3],
}
assert depth in layer_config.keys(), \
"supported depth are {} but input layer is {}".format(
layer_config.keys(), depth)
layers = layer_config[depth]
in_channels = 64
out_channels = [64, 128, 256, 512]
self.conv = ConvBNLayer(
num_channels=3, num_filters=64, filter_size=7, stride=2, act='relu')
self.pool = Pool2D(
pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
self.layers = []
for idx, num_blocks in enumerate(layers):
blocks = []
shortcut = False
for b in range(num_blocks):
if b == 1:
in_channels = out_channels[idx] * Block.expansion
block = Block(
num_channels=in_channels,
num_filters=out_channels[idx],
stride=2 if b == 0 and idx != 0 else 1,
shortcut=shortcut)
blocks.append(block)
shortcut = True
layer = self.add_sublayer("layer_{}".format(idx),
Sequential(*blocks))
self.layers.append(layer)
if with_pool:
self.global_pool = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
if num_classes > 0:
stdv = 1.0 / math.sqrt(out_channels[-1] * Block.expansion * 1.0)
self.fc_input_dim = out_channels[-1] * Block.expansion * 1 * 1
self.fc = Linear(
self.fc_input_dim,
num_classes,
act=classifier_activation,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
def forward(self, inputs):
x = self.conv(inputs)
x = self.pool(x)
for layer in self.layers:
x = layer(x)
if self.with_pool:
x = self.global_pool(x)
if self.num_classes > -1:
x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim])
x = self.fc(x)
return x
def _resnet(arch, Block, depth, pretrained, **kwargs):
model = ResNet(Block, depth, **kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path)
return model
def resnet18(pretrained=False, **kwargs):
"""ResNet 18-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import resnet18
# build model
model = resnet18()
# build model and load imagenet pretrained weight
# model = resnet18(pretrained=True)
"""
return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs)
def resnet34(pretrained=False, **kwargs):
"""ResNet 34-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import resnet34
# build model
model = resnet34()
# build model and load imagenet pretrained weight
# model = resnet34(pretrained=True)
"""
return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs)
def resnet50(pretrained=False, **kwargs):
"""ResNet 50-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import resnet50
# build model
model = resnet50()
# build model and load imagenet pretrained weight
# model = resnet50(pretrained=True)
"""
return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs)
def resnet101(pretrained=False, **kwargs):
"""ResNet 101-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import resnet101
# build model
model = resnet101()
# build model and load imagenet pretrained weight
# model = resnet101(pretrained=True)
"""
return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs)
def resnet152(pretrained=False, **kwargs):
"""ResNet 152-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import resnet152
# build model
model = resnet152()
# build model and load imagenet pretrained weight
# model = resnet152(pretrained=True)
"""
return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.container import Sequential
from ...model import Model
from ...download import get_weights_path_from_url
__all__ = [
'VGG',
'vgg11',
'vgg13',
'vgg16',
'vgg19',
]
model_urls = {
'vgg16': ('https://paddle-hapi.bj.bcebos.com/models/vgg16.pdparams',
'c788f453a3b999063e8da043456281ee')
}
class Classifier(fluid.dygraph.Layer):
def __init__(self, num_classes, classifier_activation='softmax'):
super(Classifier, self).__init__()
self.linear1 = Linear(512 * 7 * 7, 4096)
self.linear2 = Linear(4096, 4096)
self.linear3 = Linear(4096, num_classes, act=classifier_activation)
def forward(self, x):
x = self.linear1(x)
x = fluid.layers.relu(x)
x = fluid.layers.dropout(x, 0.5)
x = self.linear2(x)
x = fluid.layers.relu(x)
x = fluid.layers.dropout(x, 0.5)
out = self.linear3(x)
return out
class VGG(Model):
"""VGG model from
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
Args:
features (fluid.dygraph.Layer): vgg features create by function make_layers.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import VGG
from paddle.incubate.hapi.vision.models.vgg import make_layers
vgg11_cfg = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
features = make_layers(vgg11_cfg)
vgg11 = VGG(features)
"""
def __init__(self,
features,
num_classes=1000,
classifier_activation='softmax'):
super(VGG, self).__init__()
self.features = features
self.num_classes = num_classes
if num_classes > 0:
classifier = Classifier(num_classes, classifier_activation)
self.classifier = self.add_sublayer("classifier",
Sequential(classifier))
def forward(self, x):
x = self.features(x)
if self.num_classes > 0:
x = fluid.layers.flatten(x, 1)
x = self.classifier(x)
return x
def make_layers(cfg, batch_norm=False):
layers = []
in_channels = 3
for v in cfg:
if v == 'M':
layers += [Pool2D(pool_size=2, pool_stride=2)]
else:
if batch_norm:
conv2d = Conv2D(in_channels, v, filter_size=3, padding=1)
layers += [conv2d, BatchNorm(v, act='relu')]
else:
conv2d = Conv2D(
in_channels, v, filter_size=3, padding=1, act='relu')
layers += [conv2d]
in_channels = v
return Sequential(*layers)
cfgs = {
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'B':
[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'D': [
64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512,
512, 512, 'M'
],
'E': [
64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512,
'M', 512, 512, 512, 512, 'M'
],
}
def _vgg(arch, cfg, batch_norm, pretrained, **kwargs):
model = VGG(make_layers(
cfgs[cfg], batch_norm=batch_norm),
num_classes=1000,
**kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path)
return model
def vgg11(pretrained=False, batch_norm=False, **kwargs):
"""VGG 11-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import vgg11
# build model
model = vgg11()
# build vgg11 model with batch_norm
model = vgg11(batch_norm=True)
"""
model_name = 'vgg11'
if batch_norm:
model_name += ('_bn')
return _vgg(model_name, 'A', batch_norm, pretrained, **kwargs)
def vgg13(pretrained=False, batch_norm=False, **kwargs):
"""VGG 13-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import vgg13
# build model
model = vgg13()
# build vgg13 model with batch_norm
model = vgg13(batch_norm=True)
"""
model_name = 'vgg13'
if batch_norm:
model_name += ('_bn')
return _vgg(model_name, 'B', batch_norm, pretrained, **kwargs)
def vgg16(pretrained=False, batch_norm=False, **kwargs):
"""VGG 16-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import vgg16
# build model
model = vgg16()
# build vgg16 model with batch_norm
model = vgg16(batch_norm=True)
"""
model_name = 'vgg16'
if batch_norm:
model_name += ('_bn')
return _vgg(model_name, 'D', batch_norm, pretrained, **kwargs)
def vgg19(pretrained=False, batch_norm=False, **kwargs):
"""VGG 19-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import vgg19
# build model
model = vgg19()
# build vgg19 model with batch_norm
model = vgg19(batch_norm=True)
"""
model_name = 'vgg19'
if batch_norm:
model_name += ('_bn')
return _vgg(model_name, 'E', batch_norm, pretrained, **kwargs)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import transforms
from . import functional
from .transforms import *
from .functional import *
__all__ = transforms.__all__ \
+ functional.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import collections
import random
import cv2
import numpy as np
if sys.version_info < (3, 3):
Sequence = collections.Sequence
Iterable = collections.Iterable
else:
Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable
__all__ = ['flip', 'resize']
def flip(image, code):
"""
Accordding to the code (the type of flip), flip the input image
Args:
image: Input image, with (H, W, C) shape
code: Code that indicates the type of flip.
-1 : Flip horizontally and vertically
0 : Flip vertically
1 : Flip horizontally
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import functional as F
fake_img = np.random.rand(224, 224, 3)
# flip horizontally and vertically
F.flip(fake_img, -1)
# flip vertically
F.flip(fake_img, 0)
# flip horizontally
F.flip(fake_img, 1)
"""
return cv2.flip(image, flipCode=code)
def resize(img, size, interpolation=cv2.INTER_LINEAR):
"""
resize the input data to given size
Args:
input: Input data, could be image or masks, with (H, W, C) shape
size: Target size of input data, with (height, width) shape.
interpolation: Interpolation method.
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import functional as F
fake_img = np.random.rand(256, 256, 3)
F.resize(fake_img, 224)
F.resize(fake_img, (200, 150))
"""
if isinstance(interpolation, Sequence):
interpolation = random.choice(interpolation)
if isinstance(size, int):
h, w = img.shape[:2]
if (w <= h and w == size) or (h <= w and h == size):
return img
if w < h:
ow = size
oh = int(size * h / w)
return cv2.resize(img, (ow, oh), interpolation=interpolation)
else:
oh = size
ow = int(size * w / h)
return cv2.resize(img, (ow, oh), interpolation=interpolation)
else:
return cv2.resize(img, size[::-1], interpolation=interpolation)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import math
import sys
import random
import cv2
import numpy as np
import numbers
import types
import collections
import warnings
import traceback
from . import functional as F
if sys.version_info < (3, 3):
Sequence = collections.Sequence
Iterable = collections.Iterable
else:
Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable
__all__ = [
"Compose",
"BatchCompose",
"Resize",
"RandomResizedCrop",
"CenterCropResize",
"CenterCrop",
"RandomHorizontalFlip",
"RandomVerticalFlip",
"Permute",
"Normalize",
"GaussianNoise",
"BrightnessTransform",
"SaturationTransform",
"ContrastTransform",
"HueTransform",
"ColorJitter",
]
class Compose(object):
"""
Composes several transforms together use for composing list of transforms
together for a dataset transform.
Args:
transforms (list): List of transforms to compose.
Returns:
A compose object which is callable, __call__ for this Compose
object will call each given :attr:`transforms` sequencely.
Examples:
.. code-block:: python
from paddle.incubate.hapi.datasets import Flowers
from paddle.incubate.hapi.vision.transforms import Compose, ColorJitter, Resize
transform = Compose([ColorJitter(), Resize(size=608)])
flowers = Flowers(mode='test', transform=transform)
for i in range(10):
sample = flowers[i]
print(sample[0].shape, sample[1])
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, *data):
for f in self.transforms:
try:
# multi-fileds in a sample
if isinstance(data, Sequence):
data = f(*data)
# single field in a sample, call transform directly
else:
data = f(data)
except Exception as e:
stack_info = traceback.format_exc()
print("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
return data
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class BatchCompose(object):
"""Composes several batch transforms together
Args:
transforms (list): List of transforms to compose.
these transforms perform on batch data.
Examples:
.. code-block:: python
import numpy as np
from paddle.io import DataLoader
from paddle.incubate.hapi.model import set_device
from paddle.incubate.hapi.datasets import Flowers
from paddle.incubate.hapi.vision.transforms import Compose, BatchCompose, Resize
class NormalizeBatch(object):
def __init__(self,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
scale=True,
channel_first=True):
self.mean = mean
self.std = std
self.scale = scale
self.channel_first = channel_first
if not (isinstance(self.mean, list) and isinstance(self.std, list) and
isinstance(self.scale, bool)):
raise TypeError("{}: input type is invalid.".format(self))
from functools import reduce
if reduce(lambda x, y: x * y, self.std) == 0:
raise ValueError('{}: std is invalid!'.format(self))
def __call__(self, samples):
for i in range(len(samples)):
samples[i] = list(samples[i])
im = samples[i][0]
im = im.astype(np.float32, copy=False)
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.scale:
im = im / 255.0
im -= mean
im /= std
if self.channel_first:
im = im.transpose((2, 0, 1))
samples[i][0] = im
return samples
transform = Compose([Resize((500, 500))])
flowers_dataset = Flowers(mode='test', transform=transform)
device = set_device('cpu')
collate_fn = BatchCompose([NormalizeBatch()])
loader = DataLoader(
flowers_dataset,
batch_size=4,
places=device,
return_list=True,
collate_fn=collate_fn)
for data in loader:
# do something
break
"""
def __init__(self, transforms=[]):
self.transforms = transforms
def __call__(self, data):
for f in self.transforms:
try:
data = f(data)
except Exception as e:
stack_info = traceback.format_exc()
print("fail to perform batch transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
# sample list to batch data
batch = list(zip(*data))
return batch
class Resize(object):
"""Resize the input Image to the given size.
Args:
size (int|list|tuple): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e, if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (int): Interpolation mode of resize. Default: cv2.INTER_LINEAR.
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import Resize
transform = Resize(size=224)
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self, size, interpolation=cv2.INTER_LINEAR):
assert isinstance(size, int) or (isinstance(size, Iterable) and
len(size) == 2)
self.size = size
self.interpolation = interpolation
def __call__(self, img):
return F.resize(img, self.size, self.interpolation)
class RandomResizedCrop(object):
"""Crop the input data to random size and aspect ratio.
A crop of random size (default: of 0.08 to 1.0) of the original size and a random
aspect ratio (default: of 3/4 to 1.33) of the original aspect ratio is made.
After applying crop transfrom, the input data will be resized to given size.
Args:
output_size (int|list|tuple): Target size of output image, with (height, width) shape.
scale (list|tuple): Range of size of the origin size cropped. Default: (0.08, 1.0)
ratio (list|tuple): Range of aspect ratio of the origin aspect ratio cropped. Default: (0.75, 1.33)
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import RandomResizedCrop
transform = RandomResizedCrop(224)
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self,
output_size,
scale=(0.08, 1.0),
ratio=(3. / 4, 4. / 3),
interpolation=cv2.INTER_LINEAR):
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
self.output_size = output_size
assert (scale[0] <= scale[1]), "scale should be of kind (min, max)"
assert (ratio[0] <= ratio[1]), "ratio should be of kind (min, max)"
self.scale = scale
self.ratio = ratio
self.interpolation = interpolation
def _get_params(self, image, attempts=10):
height, width, _ = image.shape
area = height * width
for _ in range(attempts):
target_area = np.random.uniform(*self.scale) * area
log_ratio = tuple(math.log(x) for x in self.ratio)
aspect_ratio = math.exp(np.random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if 0 < w <= width and 0 < h <= height:
x = np.random.randint(0, width - w + 1)
y = np.random.randint(0, height - h + 1)
return x, y, w, h
# Fallback to central crop
in_ratio = float(width) / float(height)
if in_ratio < min(self.ratio):
w = width
h = int(round(w / min(self.ratio)))
elif in_ratio > max(self.ratio):
h = height
w = int(round(h * max(self.ratio)))
else: # whole image
w = width
h = height
x = (width - w) // 2
y = (height - h) // 2
return x, y, w, h
def __call__(self, img):
x, y, w, h = self._get_params(img)
cropped_img = img[y:y + h, x:x + w]
return F.resize(cropped_img, self.output_size, self.interpolation)
class CenterCropResize(object):
"""Crops to center of image with padding then scales size.
Args:
size (int|list|tuple): Target size of output image, with (height, width) shape.
crop_padding (int): Center crop with the padding. Default: 32.
interpolation (int): Interpolation mode of resize. Default: cv2.INTER_LINEAR.
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import CenterCropResize
transform = CenterCropResize(224)
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self, size, crop_padding=32, interpolation=cv2.INTER_LINEAR):
if isinstance(size, int):
self.size = (size, size)
else:
self.size = size
self.crop_padding = crop_padding
self.interpolation = interpolation
def _get_params(self, img):
h, w = img.shape[:2]
size = min(self.size)
c = int(size / (size + self.crop_padding) * min((h, w)))
x = (h + 1 - c) // 2
y = (w + 1 - c) // 2
return c, x, y
def __call__(self, img):
c, x, y = self._get_params(img)
cropped_img = img[x:x + c, y:y + c, :]
return F.resize(cropped_img, self.size, self.interpolation)
class CenterCrop(object):
"""Crops the given the input data at the center.
Args:
output_size: Target size of output image, with (height, width) shape.
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import CenterCrop
transform = CenterCrop(224)
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self, output_size):
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
self.output_size = output_size
def _get_params(self, img):
th, tw = self.output_size
h, w, _ = img.shape
assert th <= h and tw <= w, "output size is bigger than image size"
x = int(round((w - tw) / 2.0))
y = int(round((h - th) / 2.0))
return x, y
def __call__(self, img):
x, y = self._get_params(img)
th, tw = self.output_size
return img[y:y + th, x:x + tw]
class RandomHorizontalFlip(object):
"""Horizontally flip the input data randomly with a given probability.
Args:
prob (float): Probability of the input data being flipped. Default: 0.5
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import RandomHorizontalFlip
transform = RandomHorizontalFlip(224)
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self, prob=0.5):
self.prob = prob
def __call__(self, img):
if np.random.random() < self.prob:
return F.flip(img, code=1)
return img
class RandomVerticalFlip(object):
"""Vertically flip the input data randomly with a given probability.
Args:
prob (float): Probability of the input data being flipped. Default: 0.5
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import RandomVerticalFlip
transform = RandomVerticalFlip(224)
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self, prob=0.5):
self.prob = prob
def __call__(self, img):
if np.random.random() < self.prob:
return F.flip(img, code=0)
return img
class Normalize(object):
"""Normalize the input data with mean and standard deviation.
Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
this transform will normalize each channel of the input data.
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
Args:
mean (int|float|list): Sequence of means for each channel.
std (int|float|list): Sequence of standard deviations for each channel.
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import Normalize
normalize = Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
fake_img = np.random.rand(3, 500, 500).astype('float32')
fake_img = normalize(fake_img)
print(fake_img.shape)
"""
def __init__(self, mean=0.0, std=1.0):
if isinstance(mean, numbers.Number):
mean = [mean, mean, mean]
if isinstance(std, numbers.Number):
mean = [std, std, std]
self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1)
self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1)
def __call__(self, img):
return (img - self.mean) / self.std
class Permute(object):
"""Change input data to a target mode.
For example, most transforms use HWC mode image,
while the Neural Network might use CHW mode input tensor.
Input image should be HWC mode and an instance of numpy.ndarray.
Args:
mode (str): Output mode of input. Default: "CHW".
to_rgb (bool): Convert 'bgr' image to 'rgb'. Default: True.
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import Permute
transform = Permute()
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self, mode="CHW", to_rgb=True):
assert mode in [
"CHW"
], "Only support 'CHW' mode, but received mode: {}".format(mode)
self.mode = mode
self.to_rgb = to_rgb
def __call__(self, img):
if self.to_rgb:
img = img[..., ::-1]
if self.mode == "CHW":
return img.transpose((2, 0, 1))
return img
class GaussianNoise(object):
"""Add random gaussian noise to the input data.
Gaussian noise is generated with given mean and std.
Args:
mean (float): Gaussian mean used to generate noise.
std (float): Gaussian standard deviation used to generate noise.
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import GaussianNoise
transform = GaussianNoise()
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self, mean=0.0, std=1.0):
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
def __call__(self, img):
dtype = img.dtype
noise = np.random.normal(self.mean, self.std, img.shape) * 255
img = img + noise.astype(np.float32)
return np.clip(img, 0, 255).astype(dtype)
class BrightnessTransform(object):
"""Adjust brightness of the image.
Args:
value (float): How much to adjust the brightness. Can be any
non negative number. 0 gives the original image
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import BrightnessTransform
transform = BrightnessTransform(0.4)
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self, value):
if value < 0:
raise ValueError("brightness value should be non-negative")
self.value = value
def __call__(self, img):
if self.value == 0:
return img
dtype = img.dtype
img = img.astype(np.float32)
alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
img = img * alpha
return img.clip(0, 255).astype(dtype)
class ContrastTransform(object):
"""Adjust contrast of the image.
Args:
value (float): How much to adjust the contrast. Can be any
non negative number. 0 gives the original image
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import ContrastTransform
transform = ContrastTransform(0.4)
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self, value):
if value < 0:
raise ValueError("contrast value should be non-negative")
self.value = value
def __call__(self, img):
if self.value == 0:
return img
dtype = img.dtype
img = img.astype(np.float32)
alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
img = img * alpha + cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).mean() * (
1 - alpha)
return img.clip(0, 255).astype(dtype)
class SaturationTransform(object):
"""Adjust saturation of the image.
Args:
value (float): How much to adjust the saturation. Can be any
non negative number. 0 gives the original image
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import SaturationTransform
transform = SaturationTransform(0.4)
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self, value):
if value < 0:
raise ValueError("saturation value should be non-negative")
self.value = value
def __call__(self, img):
if self.value == 0:
return img
dtype = img.dtype
img = img.astype(np.float32)
alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
gray_img = gray_img[..., np.newaxis]
img = img * alpha + gray_img * (1 - alpha)
return img.clip(0, 255).astype(dtype)
class HueTransform(object):
"""Adjust hue of the image.
Args:
value (float): How much to adjust the hue. Can be any number
between 0 and 0.5, 0 gives the original image
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import HueTransform
transform = HueTransform(0.4)
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self, value):
if value < 0 or value > 0.5:
raise ValueError("hue value should be in [0.0, 0.5]")
self.value = value
def __call__(self, img):
if self.value == 0:
return img
dtype = img.dtype
img = img.astype(np.uint8)
hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV_FULL)
h, s, v = cv2.split(hsv_img)
alpha = np.random.uniform(-self.value, self.value)
h = h.astype(np.uint8)
# uint8 addition take cares of rotation across boundaries
with np.errstate(over="ignore"):
h += np.uint8(alpha * 255)
hsv_img = cv2.merge([h, s, v])
return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR_FULL).astype(dtype)
class ColorJitter(object):
"""Randomly change the brightness, contrast, saturation and hue of an image.
Args:
brightness: How much to jitter brightness.
Chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
or the given [min, max]. Should be non negative numbers.
contrast: How much to jitter contrast.
Chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
or the given [min, max]. Should be non negative numbers.
saturation: How much to jitter saturation.
Chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
or the given [min, max]. Should be non negative numbers.
hue: How much to jitter hue.
Chosen uniformly from [-hue, hue] or the given [min, max].
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import ColorJitter
transform = ColorJitter(0.4)
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
transforms = []
if brightness != 0:
transforms.append(BrightnessTransform(brightness))
if contrast != 0:
transforms.append(ContrastTransform(contrast))
if saturation != 0:
transforms.append(SaturationTransform(saturation))
if hue != 0:
transforms.append(HueTransform(hue))
random.shuffle(transforms)
self.transforms = Compose(transforms)
def __call__(self, img):
return self.transforms(img)
......@@ -182,6 +182,7 @@ packages=['paddle',
'paddle.incubate.hapi.datasets',
'paddle.incubate.hapi.vision',
'paddle.incubate.hapi.vision.models',
'paddle.incubate.hapi.vision.transforms',
'paddle.io',
'paddle.nn',
'paddle.nn.functional',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册