未验证 提交 2de737eb 编写于 作者: Z zhiboniu 提交者: GitHub

update 2.0 public api in vision (#33308)

* update 2.0 public api in vision

* fix some flake8 errors
上级 6760d737
...@@ -324,7 +324,7 @@ class ProgBarLogger(Callback): ...@@ -324,7 +324,7 @@ class ProgBarLogger(Callback):
]) ])
train_dataset = MNIST(mode='train', transform=transform) train_dataset = MNIST(mode='train', transform=transform)
lenet = paddle.vision.LeNet() lenet = paddle.vision.models.LeNet()
model = paddle.Model(lenet, model = paddle.Model(lenet,
inputs, labels) inputs, labels)
...@@ -558,7 +558,7 @@ class ModelCheckpoint(Callback): ...@@ -558,7 +558,7 @@ class ModelCheckpoint(Callback):
]) ])
train_dataset = MNIST(mode='train', transform=transform) train_dataset = MNIST(mode='train', transform=transform)
lenet = paddle.vision.LeNet() lenet = paddle.vision.models.LeNet()
model = paddle.Model(lenet, model = paddle.Model(lenet,
inputs, labels) inputs, labels)
...@@ -618,7 +618,7 @@ class LRScheduler(Callback): ...@@ -618,7 +618,7 @@ class LRScheduler(Callback):
]) ])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform) train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
lenet = paddle.vision.LeNet() lenet = paddle.vision.models.LeNet()
model = paddle.Model(lenet, model = paddle.Model(lenet,
inputs, labels) inputs, labels)
...@@ -634,7 +634,7 @@ class LRScheduler(Callback): ...@@ -634,7 +634,7 @@ class LRScheduler(Callback):
boundaries=boundaries, values=values) boundaries=boundaries, values=values)
learning_rate = paddle.optimizer.lr.LinearWarmup( learning_rate = paddle.optimizer.lr.LinearWarmup(
learning_rate=learning_rate, learning_rate=learning_rate,
warmup_steps=wamup_epochs, warmup_steps=wamup_steps,
start_lr=base_lr / 5., start_lr=base_lr / 5.,
end_lr=base_lr, end_lr=base_lr,
verbose=True) verbose=True)
...@@ -860,7 +860,7 @@ class VisualDL(Callback): ...@@ -860,7 +860,7 @@ class VisualDL(Callback):
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform) train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
eval_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform) eval_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
net = paddle.vision.LeNet() net = paddle.vision.models.LeNet()
model = paddle.Model(net, inputs, labels) model = paddle.Model(net, inputs, labels)
optim = paddle.optimizer.Adam(0.001, parameters=net.parameters()) optim = paddle.optimizer.Adam(0.001, parameters=net.parameters())
......
...@@ -30,20 +30,28 @@ from collections import Iterable ...@@ -30,20 +30,28 @@ from collections import Iterable
import paddle import paddle
from paddle import fluid from paddle import fluid
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.framework import in_dygraph_mode, Variable, ParamBase, _current_expected_place from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import in_dygraph_mode, Variable, _get_paddle_place from paddle.fluid.framework import Variable
from paddle.fluid.framework import ParamBase
from paddle.fluid.framework import _current_expected_place
from paddle.fluid.framework import _get_paddle_place
from paddle.fluid.framework import _current_expected_place as _get_device from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.executor import global_scope from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer from paddle.fluid.io import is_belong_to_optimizer
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, FunctionSpec from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX from paddle.fluid.dygraph.dygraph_to_static.program_translator import FunctionSpec
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX
from paddle.fluid.dygraph.io import INFER_PARAMS_SUFFIX
from paddle.fluid.layers.utils import flatten from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers import collective from paddle.fluid.layers import collective
from paddle.io import DataLoader, Dataset, DistributedBatchSampler from paddle.io import DataLoader
from paddle.fluid.executor import scope_guard, Executor from paddle.io import Dataset
from paddle.io import DistributedBatchSampler
from paddle.fluid.executor import scope_guard
from paddle.fluid.executor import Executor
from paddle.fluid.dygraph.layers import Layer from paddle.fluid.dygraph.layers import Layer
from paddle.metric import Metric from paddle.metric import Metric
from paddle.static import InputSpec as Input from paddle.static import InputSpec as Input
...@@ -166,7 +174,6 @@ def init_communicator(program, rank, nranks, wait_port, current_endpoint, ...@@ -166,7 +174,6 @@ def init_communicator(program, rank, nranks, wait_port, current_endpoint,
name=fluid.unique_name.generate('hccl_id'), name=fluid.unique_name.generate('hccl_id'),
persistable=True, persistable=True,
type=core.VarDesc.VarType.RAW) type=core.VarDesc.VarType.RAW)
endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
block.append_op( block.append_op(
type='c_gen_hccl_id', type='c_gen_hccl_id',
inputs={}, inputs={},
...@@ -1363,8 +1370,9 @@ class Model(object): ...@@ -1363,8 +1370,9 @@ class Model(object):
# pure float16 training has some restricts now # pure float16 training has some restricts now
if self._adapter._amp_level == "O2": if self._adapter._amp_level == "O2":
if in_dygraph_mode(): if in_dygraph_mode():
warnings.warn("Pure float16 training is not supported in dygraph mode now, "\ warnings.warn(
"and it will be supported in future version.") "Pure float16 training is not supported in dygraph mode now, and it will be supported in future version."
)
else: else:
# grad clip is not supported in pure fp16 training now # grad clip is not supported in pure fp16 training now
assert self._optimizer._grad_clip is None, \ assert self._optimizer._grad_clip is None, \
...@@ -1398,8 +1406,7 @@ class Model(object): ...@@ -1398,8 +1406,7 @@ class Model(object):
if 'use_pure_fp16' in amp_configs: if 'use_pure_fp16' in amp_configs:
raise ValueError( raise ValueError(
"''use_pure_fp16' is an invalid parameter, " "'use_pure_fp16' is an invalid parameter, the level of mixed precision training only depends on 'O1' or 'O2'."
"the level of mixed precision training only depends on 'O1' or 'O2'."
) )
_check_pure_fp16_configs() _check_pure_fp16_configs()
...@@ -1427,9 +1434,8 @@ class Model(object): ...@@ -1427,9 +1434,8 @@ class Model(object):
} }
if amp_config_key_set - accepted_param_set: if amp_config_key_set - accepted_param_set:
raise ValueError( raise ValueError(
"Except for 'level', the keys of 'amp_configs' must be accepted by mixed precision APIs, " "Except for 'level', the keys of 'amp_configs' must be accepted by mixed precision APIs, but {} could not be recognized.".
"but {} could not be recognized.".format( format(tuple(amp_config_key_set - accepted_param_set)))
tuple(amp_config_key_set - accepted_param_set)))
if 'use_fp16_guard' in amp_config_key_set: if 'use_fp16_guard' in amp_config_key_set:
if in_dygraph_mode(): if in_dygraph_mode():
...@@ -1501,8 +1507,9 @@ class Model(object): ...@@ -1501,8 +1507,9 @@ class Model(object):
self._optimizer = optimizer self._optimizer = optimizer
if loss is not None: if loss is not None:
if not isinstance(loss, paddle.nn.Layer) and not callable(loss): if not isinstance(loss, paddle.nn.Layer) and not callable(loss):
raise TypeError("'loss' must be sub classes of " \ raise TypeError(
"`paddle.nn.Layer` or any callable function.") "'loss' must be sub classes of `paddle.nn.Layer` or any callable function."
)
self._loss = loss self._loss = loss
metrics = metrics or [] metrics = metrics or []
...@@ -2084,7 +2091,7 @@ class Model(object): ...@@ -2084,7 +2091,7 @@ class Model(object):
input = InputSpec([None, 1, 28, 28], 'float32', 'image') input = InputSpec([None, 1, 28, 28], 'float32', 'image')
label = InputSpec([None, 1], 'int64', 'label') label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(paddle.vision.LeNet(), model = paddle.Model(paddle.vision.models.LeNet(),
input, label) input, label)
optim = paddle.optimizer.Adam( optim = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters()) learning_rate=0.001, parameters=model.parameters())
...@@ -2126,9 +2133,11 @@ class Model(object): ...@@ -2126,9 +2133,11 @@ class Model(object):
else: else:
out_specs = to_list(specs) out_specs = to_list(specs)
elif isinstance(specs, dict): elif isinstance(specs, dict):
assert is_input == False assert is_input is False
out_specs = [specs[n] \ out_specs = [
for n in extract_args(self.network.forward) if n != 'self'] specs[n] for n in extract_args(self.network.forward)
if n != 'self'
]
else: else:
out_specs = to_list(specs) out_specs = to_list(specs)
# Note: checks each element has specificed `name`. # Note: checks each element has specificed `name`.
......
...@@ -222,7 +222,7 @@ class Accuracy(Metric): ...@@ -222,7 +222,7 @@ class Accuracy(Metric):
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = MNIST(mode='train', transform=transform) train_dataset = MNIST(mode='train', transform=transform)
model = paddle.Model(paddle.vision.LeNet(), input, label) model = paddle.Model(paddle.vision.models.LeNet(), input, label)
optim = paddle.optimizer.Adam( optim = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters()) learning_rate=0.001, parameters=model.parameters())
model.prepare( model.prepare(
......
...@@ -55,7 +55,7 @@ class TestCallbacks(unittest.TestCase): ...@@ -55,7 +55,7 @@ class TestCallbacks(unittest.TestCase):
train_dataset = MnistDataset(mode='train', transform=transform) train_dataset = MnistDataset(mode='train', transform=transform)
eval_dataset = MnistDataset(mode='test', transform=transform) eval_dataset = MnistDataset(mode='test', transform=transform)
net = paddle.vision.LeNet() net = paddle.vision.models.LeNet()
model = paddle.Model(net, inputs, labels) model = paddle.Model(net, inputs, labels)
optim = paddle.optimizer.Adam(0.001, parameters=net.parameters()) optim = paddle.optimizer.Adam(0.001, parameters=net.parameters())
......
...@@ -11,22 +11,59 @@ ...@@ -11,22 +11,59 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle
import paddle.nn as nn
from . import models # noqa: F401
from . import transforms # noqa: F401
from . import datasets # noqa: F401
from . import ops # noqa: F401
from .image import set_image_backend # noqa: F401
from .image import get_image_backend # noqa: F401
from .image import image_load # noqa: F401
from .models import LeNet as models_LeNet
import paddle.utils.deprecated as deprecated
from . import models __all__ = [ #noqa
from .models import * 'set_image_backend', 'get_image_backend', 'image_load'
]
from . import transforms
from .transforms import *
from . import datasets class LeNet(models_LeNet):
from .datasets import * """LeNet model from
`"LeCun Y, Bottou L, Bengio Y, et al. Gradient-based learning applied to document recognition[J]. Proceedings of the IEEE, 1998, 86(11): 2278-2324.`_
from . import image Args:
from .image import * num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 10.
from . import ops Examples:
.. code-block:: python
__all__ = models.__all__ \ from paddle.vision.models import LeNet
+ transforms.__all__ \
+ datasets.__all__ \ model = LeNet()
+ image.__all__ """
@deprecated(
since="2.0.0",
update_to="paddle.vision.models.LeNet",
level=1,
reason="Please use new API in models, paddle.vision.LeNet will be removed in future"
)
def __init__(self, num_classes=10):
super(LeNet, self).__init__(num_classes=10)
self.num_classes = num_classes
self.features = nn.Sequential(
nn.Conv2D(
1, 6, 3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2D(2, 2),
nn.Conv2D(
6, 16, 5, stride=1, padding=0),
nn.ReLU(),
nn.MaxPool2D(2, 2))
if num_classes > 0:
self.fc = nn.Sequential(
nn.Linear(400, 120),
nn.Linear(120, 84), nn.Linear(84, num_classes))
...@@ -12,20 +12,22 @@ ...@@ -12,20 +12,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import folder from .folder import DatasetFolder # noqa: F401
from . import mnist from .folder import ImageFolder # noqa: F401
from . import flowers from .mnist import MNIST # noqa: F401
from . import cifar from .mnist import FashionMNIST # noqa: F401
from . import voc2012 from .flowers import Flowers # noqa: F401
from .cifar import Cifar10 # noqa: F401
from .cifar import Cifar100 # noqa: F401
from .voc2012 import VOC2012 # noqa: F401
from .folder import * __all__ = [ #noqa
from .mnist import * 'DatasetFolder'
from .flowers import * 'ImageFolder',
from .cifar import * 'MNIST',
from .voc2012 import * 'FashionMNIST',
'Flowers',
__all__ = folder.__all__ \ 'Cifar10',
+ mnist.__all__ \ 'Cifar100',
+ flowers.__all__ \ 'VOC2012'
+ cifar.__all__ \ ]
+ voc2012.__all__
...@@ -24,7 +24,7 @@ import paddle ...@@ -24,7 +24,7 @@ import paddle
from paddle.io import Dataset from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download from paddle.dataset.common import _check_exists_and_download
__all__ = ['Cifar10', 'Cifar100'] __all__ = []
URL_PREFIX = 'https://dataset.bj.bcebos.com/cifar/' URL_PREFIX = 'https://dataset.bj.bcebos.com/cifar/'
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz' CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
......
...@@ -25,7 +25,7 @@ from paddle.io import Dataset ...@@ -25,7 +25,7 @@ from paddle.io import Dataset
from paddle.utils import try_import from paddle.utils import try_import
from paddle.dataset.common import _check_exists_and_download from paddle.dataset.common import _check_exists_and_download
__all__ = ["Flowers"] __all__ = []
DATA_URL = 'http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz' DATA_URL = 'http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz'
LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat' LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat'
......
...@@ -20,7 +20,7 @@ import paddle ...@@ -20,7 +20,7 @@ import paddle
from paddle.io import Dataset from paddle.io import Dataset
from paddle.utils import try_import from paddle.utils import try_import
__all__ = ["DatasetFolder", "ImageFolder"] __all__ = []
def has_valid_extension(filename, extensions): def has_valid_extension(filename, extensions):
......
...@@ -24,7 +24,7 @@ import paddle ...@@ -24,7 +24,7 @@ import paddle
from paddle.io import Dataset from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download from paddle.dataset.common import _check_exists_and_download
__all__ = ["MNIST", "FashionMNIST"] __all__ = []
class MNIST(Dataset): class MNIST(Dataset):
......
...@@ -23,7 +23,7 @@ import paddle ...@@ -23,7 +23,7 @@ import paddle
from paddle.io import Dataset from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download from paddle.dataset.common import _check_exists_and_download
__all__ = ["VOC2012"] __all__ = []
VOC_URL = 'https://dataset.bj.bcebos.com/voc/VOCtrainval_11-May-2012.tar' VOC_URL = 'https://dataset.bj.bcebos.com/voc/VOCtrainval_11-May-2012.tar'
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from PIL import Image from PIL import Image
from paddle.utils import try_import from paddle.utils import try_import
__all__ = ['set_image_backend', 'get_image_backend', 'image_load'] __all__ = []
_image_backend = 'pil' _image_backend = 'pil'
......
...@@ -12,20 +12,38 @@ ...@@ -12,20 +12,38 @@
#See the License for the specific language governing permissions and #See the License for the specific language governing permissions and
#limitations under the License. #limitations under the License.
from . import resnet from .resnet import ResNet # noqa: F401
from . import vgg from .resnet import resnet18 # noqa: F401
from . import mobilenetv1 from .resnet import resnet34 # noqa: F401
from . import mobilenetv2 from .resnet import resnet50 # noqa: F401
from . import lenet from .resnet import resnet101 # noqa: F401
from .resnet import resnet152 # noqa: F401
from .mobilenetv1 import MobileNetV1 # noqa: F401
from .mobilenetv1 import mobilenet_v1 # noqa: F401
from .mobilenetv2 import MobileNetV2 # noqa: F401
from .mobilenetv2 import mobilenet_v2 # noqa: F401
from .vgg import VGG # noqa: F401
from .vgg import vgg11 # noqa: F401
from .vgg import vgg13 # noqa: F401
from .vgg import vgg16 # noqa: F401
from .vgg import vgg19 # noqa: F401
from .lenet import LeNet # noqa: F401
from .resnet import * __all__ = [ #noqa
from .mobilenetv1 import * 'ResNet',
from .mobilenetv2 import * 'resnet18',
from .vgg import * 'resnet34',
from .lenet import * 'resnet50',
'resnet101',
__all__ = resnet.__all__ \ 'resnet152',
+ vgg.__all__ \ 'VGG',
+ mobilenetv1.__all__ \ 'vgg11',
+ mobilenetv2.__all__ \ 'vgg13',
+ lenet.__all__ 'vgg16',
'vgg19',
'MobileNetV1',
'mobilenet_v1',
'MobileNetV2',
'mobilenet_v2',
'LeNet'
]
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
__all__ = ['LeNet'] __all__ = []
class LeNet(nn.Layer): class LeNet(nn.Layer):
......
...@@ -17,7 +17,7 @@ import paddle.nn as nn ...@@ -17,7 +17,7 @@ import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url from paddle.utils.download import get_weights_path_from_url
__all__ = ['MobileNetV1', 'mobilenet_v1'] __all__ = []
model_urls = { model_urls = {
'mobilenetv1_1.0': 'mobilenetv1_1.0':
......
...@@ -20,7 +20,7 @@ import paddle.nn.functional as F ...@@ -20,7 +20,7 @@ import paddle.nn.functional as F
from paddle.utils.download import get_weights_path_from_url from paddle.utils.download import get_weights_path_from_url
__all__ = ['MobileNetV2', 'mobilenet_v2'] __all__ = []
model_urls = { model_urls = {
'mobilenetv2_1.0': 'mobilenetv2_1.0':
......
...@@ -20,9 +20,7 @@ import paddle.nn as nn ...@@ -20,9 +20,7 @@ import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url from paddle.utils.download import get_weights_path_from_url
__all__ = [ __all__ = []
'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'
]
model_urls = { model_urls = {
'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams', 'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams',
......
...@@ -17,13 +17,7 @@ import paddle.nn as nn ...@@ -17,13 +17,7 @@ import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url from paddle.utils.download import get_weights_path_from_url
__all__ = [ __all__ = []
'VGG',
'vgg11',
'vgg13',
'vgg16',
'vgg19',
]
model_urls = { model_urls = {
'vgg16': ('https://paddle-hapi.bj.bcebos.com/models/vgg16.pdparams', 'vgg16': ('https://paddle-hapi.bj.bcebos.com/models/vgg16.pdparams',
......
...@@ -22,8 +22,12 @@ from ..fluid.initializer import Normal ...@@ -22,8 +22,12 @@ from ..fluid.initializer import Normal
from paddle.common_ops_import import * from paddle.common_ops_import import *
__all__ = [ __all__ = [ #noqa
'yolo_loss', 'yolo_box', 'deform_conv2d', 'DeformConv2D', 'read_file', 'yolo_loss',
'yolo_box',
'deform_conv2d',
'DeformConv2D',
'read_file',
'decode_jpeg' 'decode_jpeg'
] ]
......
...@@ -12,11 +12,70 @@ ...@@ -12,11 +12,70 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import transforms from .transforms import BaseTransform # noqa: F401
from . import functional from .transforms import Compose # noqa: F401
from .transforms import Resize # noqa: F401
from .transforms import RandomResizedCrop # noqa: F401
from .transforms import CenterCrop # noqa: F401
from .transforms import RandomHorizontalFlip # noqa: F401
from .transforms import RandomVerticalFlip # noqa: F401
from .transforms import Transpose # noqa: F401
from .transforms import Normalize # noqa: F401
from .transforms import BrightnessTransform # noqa: F401
from .transforms import SaturationTransform # noqa: F401
from .transforms import ContrastTransform # noqa: F401
from .transforms import HueTransform # noqa: F401
from .transforms import ColorJitter # noqa: F401
from .transforms import RandomCrop # noqa: F401
from .transforms import Pad # noqa: F401
from .transforms import RandomRotation # noqa: F401
from .transforms import Grayscale # noqa: F401
from .transforms import ToTensor # noqa: F401
from .functional import to_tensor # noqa: F401
from .functional import hflip # noqa: F401
from .functional import vflip # noqa: F401
from .functional import resize # noqa: F401
from .functional import pad # noqa: F401
from .functional import rotate # noqa: F401
from .functional import to_grayscale # noqa: F401
from .functional import crop # noqa: F401
from .functional import center_crop # noqa: F401
from .functional import adjust_brightness # noqa: F401
from .functional import adjust_contrast # noqa: F401
from .functional import adjust_hue # noqa: F401
from .functional import normalize # noqa: F401
from .transforms import * __all__ = [ #noqa
from .functional import * 'BaseTransform',
'Compose',
__all__ = transforms.__all__ \ 'Resize',
+ functional.__all__ 'RandomResizedCrop',
'CenterCrop',
'RandomHorizontalFlip',
'RandomVerticalFlip',
'Transpose',
'Normalize',
'BrightnessTransform',
'SaturationTransform',
'ContrastTransform',
'HueTransform',
'ColorJitter',
'RandomCrop',
'Pad',
'RandomRotation',
'Grayscale',
'ToTensor',
'to_tensor',
'hflip',
'vflip',
'resize',
'pad',
'rotate',
'to_grayscale',
'crop',
'center_crop',
'adjust_brightness',
'adjust_contrast',
'adjust_hue',
'normalize'
]
...@@ -29,11 +29,7 @@ from . import functional_pil as F_pil ...@@ -29,11 +29,7 @@ from . import functional_pil as F_pil
from . import functional_cv2 as F_cv2 from . import functional_cv2 as F_cv2
from . import functional_tensor as F_t from . import functional_tensor as F_t
__all__ = [ __all__ = []
'to_tensor', 'hflip', 'vflip', 'resize', 'pad', 'rotate', 'to_grayscale',
'crop', 'center_crop', 'adjust_brightness', 'adjust_contrast', 'adjust_hue',
'normalize'
]
def _is_pil_image(img): def _is_pil_image(img):
......
...@@ -33,6 +33,8 @@ else: ...@@ -33,6 +33,8 @@ else:
Sequence = collections.abc.Sequence Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable Iterable = collections.abc.Iterable
__all__ = []
def to_tensor(pic, data_format='CHW'): def to_tensor(pic, data_format='CHW'):
"""Converts a ``numpy.ndarray`` to paddle.Tensor. """Converts a ``numpy.ndarray`` to paddle.Tensor.
...@@ -49,7 +51,7 @@ def to_tensor(pic, data_format='CHW'): ...@@ -49,7 +51,7 @@ def to_tensor(pic, data_format='CHW'):
""" """
if not data_format in ['CHW', 'HWC']: if data_format not in ['CHW', 'HWC']:
raise ValueError('data_format should be CHW or HWC. Got {}'.format( raise ValueError('data_format should be CHW or HWC. Got {}'.format(
data_format)) data_format))
......
...@@ -41,6 +41,8 @@ _pil_interp_from_str = { ...@@ -41,6 +41,8 @@ _pil_interp_from_str = {
'hamming': Image.HAMMING 'hamming': Image.HAMMING
} }
__all__ = []
def to_tensor(pic, data_format='CHW'): def to_tensor(pic, data_format='CHW'):
"""Converts a ``PIL.Image`` to paddle.Tensor. """Converts a ``PIL.Image`` to paddle.Tensor.
...@@ -57,7 +59,7 @@ def to_tensor(pic, data_format='CHW'): ...@@ -57,7 +59,7 @@ def to_tensor(pic, data_format='CHW'):
""" """
if not data_format in ['CHW', 'HWC']: if data_format not in ['CHW', 'HWC']:
raise ValueError('data_format should be CHW or HWC. Got {}'.format( raise ValueError('data_format should be CHW or HWC. Got {}'.format(
data_format)) data_format))
......
...@@ -23,6 +23,8 @@ import paddle.nn.functional as F ...@@ -23,6 +23,8 @@ import paddle.nn.functional as F
import sys import sys
import collections import collections
__all__ = []
def _assert_image_tensor(img, data_format): def _assert_image_tensor(img, data_format):
if not isinstance( if not isinstance(
......
...@@ -35,13 +35,7 @@ else: ...@@ -35,13 +35,7 @@ else:
Sequence = collections.abc.Sequence Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable Iterable = collections.abc.Iterable
__all__ = [ __all__ = []
"BaseTransform", "Compose", "Resize", "RandomResizedCrop", "CenterCrop",
"RandomHorizontalFlip", "RandomVerticalFlip", "Transpose", "Normalize",
"BrightnessTransform", "SaturationTransform", "ContrastTransform",
"HueTransform", "ColorJitter", "RandomCrop", "Pad", "RandomRotation",
"Grayscale", "ToTensor"
]
def _get_image_size(img): def _get_image_size(img):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册