From 43facfd375b9e1018bc1e9c3c862e9f59d10653c Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Wed, 1 Jul 2020 21:09:28 +0800 Subject: [PATCH] [Cherry-pick]Add DistributedBatchSampler and Colerjitter (#25242) * add DistributedSampler and ColorJitter, test=develop --- python/CMakeLists.txt | 1 + python/paddle/__init__.py | 3 + python/paddle/incubate/__init__.py | 18 ++ python/paddle/incubate/hapi/__init__.py | 18 ++ python/paddle/incubate/hapi/distributed.py | 134 +++++++++++ .../paddle/incubate/hapi/tests/CMakeLists.txt | 6 + .../hapi/tests/test_distributed_sampler.py | 69 ++++++ .../incubate/hapi/tests/test_transforms.py | 52 ++++ .../paddle/incubate/hapi/vision/__init__.py | 18 ++ .../hapi/vision/transforms/__init__.py | 19 ++ .../hapi/vision/transforms/transforms.py | 222 ++++++++++++++++++ python/setup.py.in | 4 + 12 files changed, 564 insertions(+) create mode 100644 python/paddle/incubate/__init__.py create mode 100644 python/paddle/incubate/hapi/__init__.py create mode 100644 python/paddle/incubate/hapi/distributed.py create mode 100644 python/paddle/incubate/hapi/tests/CMakeLists.txt create mode 100644 python/paddle/incubate/hapi/tests/test_distributed_sampler.py create mode 100644 python/paddle/incubate/hapi/tests/test_transforms.py create mode 100644 python/paddle/incubate/hapi/vision/__init__.py create mode 100644 python/paddle/incubate/hapi/vision/transforms/__init__.py create mode 100644 python/paddle/incubate/hapi/vision/transforms/transforms.py diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 77a75280049..59dfc5c9d03 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -96,6 +96,7 @@ if (WITH_TESTING) add_subdirectory(paddle/fluid/tests) add_subdirectory(paddle/fluid/contrib/tests) add_subdirectory(paddle/fluid/contrib/slim/tests) + add_subdirectory(paddle/incubate/hapi/tests) endif() install(DIRECTORY ${PADDLE_PYTHON_PACKAGE_DIR} DESTINATION opt/paddle/share/wheels diff 
--git a/python/paddle/__init__.py b/python/paddle/__init__.py index b67429dcfdc..ce08b1ea114 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -35,3 +35,6 @@ import paddle.distributed batch = batch.batch import paddle.sysconfig import paddle.complex + +from . import incubate +from .incubate import hapi diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py new file mode 100644 index 00000000000..e6888ebc8f4 --- /dev/null +++ b/python/paddle/incubate/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import hapi + +__all__ = [] +__all__ += hapi.__all__ diff --git a/python/paddle/incubate/hapi/__init__.py b/python/paddle/incubate/hapi/__init__.py new file mode 100644 index 00000000000..5680164737c --- /dev/null +++ b/python/paddle/incubate/hapi/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class DistributedBatchSampler(BatchSampler):
    """Batch sampler that restricts data loading to a per-rank subset.

    Each process builds its own ``DistributedBatchSampler``; the rank and the
    number of ranks are read from ``ParallelEnv`` at construction time. The
    index list is padded (by repeating indices from the start) so that it
    divides evenly among ranks, optionally shuffled with a seed equal to the
    current epoch, and then split so that every rank receives
    ``ceil(len(dataset) / nranks)`` indices.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        dataset(paddle.io.Dataset): a `paddle.io.Dataset` implementation or
            any other python object implementing `__len__`, used to obtain
            the number of samples in the data source.
        batch_size(int): number of sample indices in a mini-batch.
        shuffle(bool): whether to shuffle the index order before generating
            batch indices. Default False.
        drop_last(bool): whether to drop the last incomplete batch when the
            per-rank sample count is not divisible by the batch size.
            Default False.

    Examples:
        .. code-block:: python

            from paddle.incubate.hapi.distributed import DistributedBatchSampler

            class FakeDataset():
                def __init__(self):
                    pass

                def __getitem__(self, idx):
                    return idx,

                def __len__(self):
                    return 10

            train_dataset = FakeDataset()
            dist_train_sampler = DistributedBatchSampler(train_dataset, batch_size=4)
            for batch_indices in dist_train_sampler:
                # do something
                break
    """

    def __init__(self, dataset, batch_size, shuffle=False, drop_last=False):
        self.dataset = dataset

        assert isinstance(batch_size, int) and batch_size > 0, \
            "batch_size should be a positive integer"
        self.batch_size = batch_size
        assert isinstance(shuffle, bool), \
            "shuffle should be a boolean value"
        self.shuffle = shuffle
        assert isinstance(drop_last, bool), \
            "drop_last should be a boolean number"
        self.drop_last = drop_last

        env = ParallelEnv()
        self.nranks = env.nranks
        self.local_rank = env.local_rank
        # Used as the shuffle seed; see ``set_epoch``.
        self.epoch = 0
        # Per-rank sample count, rounded up so that every rank gets the same
        # number of indices; ``total_size`` is the padded global length.
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
        self.total_size = self.num_samples * self.nranks

    def _pad_to_total_size(self, indices):
        """Return ``indices`` extended to ``self.total_size`` entries.

        Padding re-uses indices from the front of the list. The list is
        repeated as often as necessary, so padding also works when more
        entries are needed than the dataset contains (for example one sample
        spread over four ranks).
        """
        pad = self.total_size - len(indices)
        if pad > 0 and indices:
            repeats = -(-pad // len(indices))  # ceil division
            indices = indices + (indices * repeats)[:pad]
        return indices

    def _subsample_for_rank(self, indices):
        """Pick the indices of ``indices`` that belong to ``self.local_rank``.

        The padded list is consumed in global batches of
        ``batch_size * nranks``; within each, this rank takes its own
        ``batch_size`` slice. The remaining tail (shorter than one global
        batch) is divided evenly between the ranks.
        """
        global_batch = self.batch_size * self.nranks
        last_batch_size = self.total_size % global_batch
        assert last_batch_size % self.nranks == 0
        last_local_batch_size = last_batch_size // self.nranks
        full_end = len(indices) - last_batch_size

        subsampled = []
        for start in range(self.local_rank * self.batch_size, full_end,
                           global_batch):
            subsampled.extend(indices[start:start + self.batch_size])

        tail = indices[full_end:]
        lo = self.local_rank * last_local_batch_size
        subsampled.extend(tail[lo:lo + last_local_batch_size])
        return subsampled

    def __iter__(self):
        """Yield lists of dataset indices, one list per mini-batch."""
        indices = self._pad_to_total_size(list(range(len(self.dataset))))
        assert len(indices) == self.total_size
        if self.shuffle:
            # Seeding with the epoch makes every rank draw the same
            # permutation, which keeps the per-rank slices disjoint.
            np.random.RandomState(self.epoch).shuffle(indices)
            self.epoch += 1

        if self.nranks > 1:
            indices = self._subsample_for_rank(indices)
        assert len(indices) == self.num_samples

        batch_indices = []
        for idx in indices:
            batch_indices.append(idx)
            if len(batch_indices) == self.batch_size:
                yield batch_indices
                batch_indices = []
        if not self.drop_last and len(batch_indices) > 0:
            yield batch_indices

    def __len__(self):
        """Return the number of batches this rank yields per epoch."""
        num_samples = self.num_samples
        # Without drop_last the trailing partial batch counts as well, so
        # round the division up instead of down.
        num_samples += int(not self.drop_last) * (self.batch_size - 1)
        return num_samples // self.batch_size

    def set_epoch(self, epoch):
        """Set the epoch number, which is the seed of the next shuffle."""
        self.epoch = epoch
class FakeDataset():
    """Minimal map-style dataset with ten items; item ``i`` is ``i`` itself."""

    def __init__(self):
        pass

    def __getitem__(self, index):
        return index

    def __len__(self):
        return 10


class TestDistributedBatchSampler(unittest.TestCase):
    """Checks the batches produced by ``DistributedBatchSampler``."""

    def _make_rank_sampler(self, dataset, nranks, local_rank):
        """Build a sampler and override its rank info to emulate one rank.

        ``num_samples`` and ``total_size`` are derived from ``nranks`` in the
        constructor, so they are recomputed after the override.
        """
        sampler = DistributedBatchSampler(
            dataset, batch_size=4, shuffle=True, drop_last=True)
        sampler.nranks = nranks
        sampler.local_rank = local_rank
        sampler.num_samples = int(
            math.ceil(len(dataset) * 1.0 / sampler.nranks))
        sampler.total_size = sampler.num_samples * sampler.nranks
        return sampler

    def test_sampler(self):
        dataset = FakeDataset()
        sampler = DistributedBatchSampler(dataset, batch_size=1, shuffle=True)

        batches = list(sampler)

        self.assertEqual(len(batches), len(sampler))
        for batch_idx in batches:
            self.assertEqual(len(batch_idx), 1)
            self.assertTrue(0 <= batch_idx[0] < len(dataset))

    def test_multiple_gpus_sampler(self):
        dataset = FakeDataset()
        sampler1 = self._make_rank_sampler(dataset, nranks=2, local_rank=0)
        sampler2 = self._make_rank_sampler(dataset, nranks=2, local_rank=1)

        batches1 = list(sampler1)
        batches2 = list(sampler2)

        # 5 samples per rank with batch_size=4 and drop_last=True leaves
        # exactly one full batch on each rank.
        for sampler, batches in ((sampler1, batches1), (sampler2, batches2)):
            self.assertEqual(len(batches), len(sampler))
            self.assertEqual(len(batches), 1)
            self.assertEqual(len(batches[0]), 4)
            for idx in batches[0]:
                self.assertTrue(0 <= idx < len(dataset))

        # Both samplers shuffle with the same epoch seed, so the two ranks
        # must not share any index.
        self.assertFalse(set(batches1[0]) & set(batches2[0]))
class TestTransforms(unittest.TestCase):
    """Smoke and validation tests for the color transforms."""

    def do_transform(self, trans):
        """Apply ``trans`` in order to a random uint8 image and check it.

        Every transform in this module is expected to keep both the shape
        and the dtype of its input.
        """
        fake_img = (np.random.random((400, 300, 3)) * 255).astype('uint8')
        expected_shape = fake_img.shape
        expected_dtype = fake_img.dtype
        for t in trans:
            fake_img = t(fake_img)
            self.assertEqual(fake_img.shape, expected_shape)
            self.assertEqual(fake_img.dtype, expected_dtype)

    def test_color_jitter(self):
        trans = [
            transforms.BrightnessTransform(0.0),
            transforms.HueTransform(0.0),
            transforms.SaturationTransform(0.0),
            transforms.ContrastTransform(0.0),
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
        ]
        self.do_transform(trans)

    def test_exception(self):
        with self.assertRaises(ValueError):
            transforms.ContrastTransform(-1.0)

        with self.assertRaises(ValueError):
            transforms.SaturationTransform(-1.0)

        with self.assertRaises(ValueError):
            transforms.HueTransform(-1.0)

        # Hue is additionally bounded from above by 0.5.
        with self.assertRaises(ValueError):
            transforms.HueTransform(0.6)

        with self.assertRaises(ValueError):
            transforms.BrightnessTransform(-1.0)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import transforms +from .transforms import * + +__all__ = transforms.__all__ diff --git a/python/paddle/incubate/hapi/vision/transforms/__init__.py b/python/paddle/incubate/hapi/vision/transforms/__init__.py new file mode 100644 index 00000000000..3248639f3b2 --- /dev/null +++ b/python/paddle/incubate/hapi/vision/transforms/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import transforms + +from .transforms import * + +__all__ = transforms.__all__ diff --git a/python/paddle/incubate/hapi/vision/transforms/transforms.py b/python/paddle/incubate/hapi/vision/transforms/transforms.py new file mode 100644 index 00000000000..23250922c19 --- /dev/null +++ b/python/paddle/incubate/hapi/vision/transforms/transforms.py @@ -0,0 +1,222 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class BrightnessTransform(object):
    """Adjust brightness of the image.

    The image is multiplied by a factor drawn uniformly from
    ``[max(0, 1 - value), 1 + value]`` and clipped to ``[0, 255]``.

    Args:
        value (float): How much to adjust the brightness. Can be any
            non negative number. 0 gives the original image.

    Raises:
        ValueError: if ``value`` is negative.

    Examples:

        .. code-block:: python

            import numpy as np
            from paddle.incubate.hapi.vision.transforms import BrightnessTransform

            transform = BrightnessTransform(0.4)
            fake_img = (np.random.rand(500, 500, 3) * 255).astype('uint8')
            fake_img = transform(fake_img)
            print(fake_img.shape)
    """

    def __init__(self, value):
        if value < 0:
            raise ValueError("brightness value should be non-negative")
        self.value = value

    def __call__(self, img):
        if self.value == 0:
            return img

        dtype = img.dtype
        img = img.astype(np.float32)
        alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
        img = img * alpha
        return img.clip(0, 255).astype(dtype)


class ContrastTransform(object):
    """Adjust contrast of the image.

    The image is blended with the mean of its grayscale version using a
    factor drawn uniformly from ``[max(0, 1 - value), 1 + value]``, then
    clipped to ``[0, 255]``. The grayscale conversion uses
    ``cv2.COLOR_BGR2GRAY``.

    Args:
        value (float): How much to adjust the contrast. Can be any
            non negative number. 0 gives the original image.

    Raises:
        ValueError: if ``value`` is negative.

    Examples:

        .. code-block:: python

            import numpy as np
            from paddle.incubate.hapi.vision.transforms import ContrastTransform

            transform = ContrastTransform(0.4)
            fake_img = (np.random.rand(500, 500, 3) * 255).astype('uint8')
            fake_img = transform(fake_img)
            print(fake_img.shape)
    """

    def __init__(self, value):
        if value < 0:
            raise ValueError("contrast value should be non-negative")
        self.value = value

    def __call__(self, img):
        if self.value == 0:
            return img

        dtype = img.dtype
        img = img.astype(np.float32)
        alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
        gray_mean = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).mean()
        img = img * alpha + gray_mean * (1 - alpha)
        return img.clip(0, 255).astype(dtype)


class SaturationTransform(object):
    """Adjust saturation of the image.

    The image is blended per pixel with its grayscale version using a
    factor drawn uniformly from ``[max(0, 1 - value), 1 + value]``, then
    clipped to ``[0, 255]``. The grayscale conversion uses
    ``cv2.COLOR_BGR2GRAY``.

    Args:
        value (float): How much to adjust the saturation. Can be any
            non negative number. 0 gives the original image.

    Raises:
        ValueError: if ``value`` is negative.

    Examples:

        .. code-block:: python

            import numpy as np
            from paddle.incubate.hapi.vision.transforms import SaturationTransform

            transform = SaturationTransform(0.4)
            fake_img = (np.random.rand(500, 500, 3) * 255).astype('uint8')
            fake_img = transform(fake_img)
            print(fake_img.shape)
    """

    def __init__(self, value):
        if value < 0:
            raise ValueError("saturation value should be non-negative")
        self.value = value

    def __call__(self, img):
        if self.value == 0:
            return img

        dtype = img.dtype
        img = img.astype(np.float32)
        alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
        gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # Restore the channel axis so the gray image broadcasts against img.
        gray_img = gray_img[..., np.newaxis]
        img = img * alpha + gray_img * (1 - alpha)
        return img.clip(0, 255).astype(dtype)


class HueTransform(object):
    """Adjust hue of the image.

    The image is cast to ``uint8`` and converted to HSV
    (``cv2.COLOR_BGR2HSV_FULL``); the hue channel is rotated by
    ``alpha * 255`` with ``alpha`` drawn uniformly from ``[-value, value]``,
    and the result is converted back. Because of the ``uint8`` cast, pixel
    values are expected to lie in ``[0, 255]``; a float image in ``[0, 1)``
    would be truncated to zeros.

    Args:
        value (float): How much to adjust the hue. Can be any number
            between 0 and 0.5, 0 gives the original image.

    Raises:
        ValueError: if ``value`` is outside ``[0.0, 0.5]``.

    Examples:

        .. code-block:: python

            import numpy as np
            from paddle.incubate.hapi.vision.transforms import HueTransform

            transform = HueTransform(0.4)
            fake_img = (np.random.rand(500, 500, 3) * 255).astype('uint8')
            fake_img = transform(fake_img)
            print(fake_img.shape)
    """

    def __init__(self, value):
        if value < 0 or value > 0.5:
            raise ValueError("hue value should be in [0.0, 0.5]")
        self.value = value

    @staticmethod
    def _rotate_hue(h, alpha):
        """Rotate the uint8 hue channel ``h`` by ``alpha * 255``, modulo 256.

        The shift is truncated toward zero and reduced modulo 256 in Python
        integers before it touches the array. Converting a negative float
        straight to ``np.uint8`` is platform dependent, so it is avoided.
        """
        shift = int(alpha * 255) % 256
        return ((h.astype(np.int16) + shift) % 256).astype(np.uint8)

    def __call__(self, img):
        if self.value == 0:
            return img

        dtype = img.dtype
        img = img.astype(np.uint8)
        hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV_FULL)
        h, s, v = cv2.split(hsv_img)

        alpha = np.random.uniform(-self.value, self.value)
        h = self._rotate_hue(h, alpha)
        hsv_img = cv2.merge([h, s, v])
        return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR_FULL).astype(dtype)


class ColorJitter(object):
    """Randomly change the brightness, contrast, saturation and hue of an image.

    One transform is created for every non-zero argument. On each call the
    enabled transforms are applied in a freshly shuffled order.

    Args:
        brightness (float): How much to jitter brightness. The factor is
            chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
            Should be non negative.
        contrast (float): How much to jitter contrast. The factor is
            chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
            Should be non negative.
        saturation (float): How much to jitter saturation. The factor is
            chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
            Should be non negative.
        hue (float): How much to jitter hue. The factor is chosen uniformly
            from [-hue, hue]. Should have 0 <= hue <= 0.5.

    Raises:
        ValueError: if any argument is outside its valid range.

    Examples:

        .. code-block:: python

            import numpy as np
            from paddle.incubate.hapi.vision.transforms import ColorJitter

            transform = ColorJitter(0.4)
            fake_img = (np.random.rand(500, 500, 3) * 255).astype('uint8')
            fake_img = transform(fake_img)
            print(fake_img.shape)
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        transforms = []
        if brightness != 0:
            transforms.append(BrightnessTransform(brightness))
        if contrast != 0:
            transforms.append(ContrastTransform(contrast))
        if saturation != 0:
            transforms.append(SaturationTransform(saturation))
        if hue != 0:
            transforms.append(HueTransform(hue))

        self.transforms = transforms

    def __call__(self, img):
        # Shuffle a copy on every call so the application order varies from
        # image to image instead of being frozen at construction time.
        order = list(self.transforms)
        random.shuffle(order)
        for t in order:
            img = t(img)
        return img