提交 af9816b2 编写于 作者: D dangqingqing

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into config_parse

......@@ -29,7 +29,7 @@ if [ $USE_VIRTUALENV_FOR_TEST -ne 0 ]; then
fi
export PYTHONPATH=$SCRIPTPATH/../../python/
$PYTHON -m pip install $SCRIPTPATH/../dist/*.whl requests matplotlib ipython==5.3
$PYTHON -m pip install $SCRIPTPATH/../dist/*.whl requests matplotlib opencv-python ipython==5.3
for fn in "$@"
do
......
......@@ -33,11 +33,12 @@ import networks
import py_paddle.swig_paddle as api
import minibatch
import plot
import image
__all__ = [
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer',
'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'dataset', 'reader',
'topology', 'networks', 'infer', 'plot', 'evaluator'
'topology', 'networks', 'infer', 'plot', 'evaluator', 'image'
]
......
import numpy as np
try:
import cv2
except ImportError:
cv2 = None
from cv2 import resize
__all__ = [
"load_image", "resize_short", "to_chw", "center_crop", "random_crop",
"left_right_flip", "simple_transform", "load_and_transform"
]
"""
This file contains some common interfaces for image preprocess.
Many users are confused about the image layout. We introduce
the image layout as follows.
- CHW Layout
- The abbreviations: C=channel, H=Height, W=Width
- The default layout of image opened by cv2 or PIL is HWC.
PaddlePaddle only supports the CHW layout. And CHW is simply
a transpose of HWC. It must transpose the input image.
- Color format: RGB or BGR
OpenCV use BGR color format. PIL use RGB color format. Both
formats can be used for training. Noted that, the format should
be keep consistent between the training and inference peroid.
"""
def load_image(file, is_color=True):
"""
Load an color or gray image from the file path.
Example usage:
.. code-block:: python
im = load_image('cat.jpg')
:param file: the input image path.
:type file: string
:param is_color: If set is_color True, it will load and
return a color image. Otherwise, it will
load and return a gray image.
"""
# cv2.IMAGE_COLOR for OpenCV3
# cv2.CV_LOAD_IMAGE_COLOR for older OpenCV Version
# cv2.IMAGE_GRAYSCALE for OpenCV3
# cv2.CV_LOAD_IMAGE_GRAYSCALE for older OpenCV Version
# Here, use constant 1 and 0
# 1: COLOR, 0: GRAYSCALE
flag = 1 if is_color else 0
im = cv2.imread(file, flag)
return im
def resize_short(im, size):
"""
Resize an image so that the length of shorter edge is size.
Example usage:
.. code-block:: python
im = load_image('cat.jpg')
im = resize_short(im, 256)
:param im: the input image with HWC layout.
:type im: ndarray
:param size: the shorter edge size of image after resizing.
:type size: int
"""
assert im.shape[-1] == 1 or im.shape[-1] == 3
h, w = im.shape[:2]
h_new, w_new = size, size
if h > w:
h_new = size * h / w
else:
w_new = size * w / h
im = resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC)
return im
def to_chw(im, order=(2, 0, 1)):
"""
Transpose the input image order. The image layout is HWC format
opened by cv2 or PIL. Transpose the input image to CHW layout
according the order (2,0,1).
Example usage:
.. code-block:: python
im = load_image('cat.jpg')
im = resize_short(im, 256)
im = to_chw(im)
:param im: the input image with HWC layout.
:type im: ndarray
:param order: the transposed order.
:type order: tuple|list
"""
assert len(im.shape) == len(order)
im = im.transpose(order)
return im
def center_crop(im, size, is_color=True):
"""
Crop the center of image with size.
Example usage:
.. code-block:: python
im = center_crop(im, 224)
:param im: the input image with HWC layout.
:type im: ndarray
:param size: the cropping size.
:type size: int
:param is_color: whether the image is color or not.
:type is_color: bool
"""
h, w = im.shape[:2]
h_start = (h - size) / 2
w_start = (w - size) / 2
h_end, w_end = h_start + size, w_start + size
if is_color:
im = im[h_start:h_end, w_start:w_end, :]
else:
im = im[h_start:h_end, w_start:w_end]
return im
def random_crop(im, size, is_color=True):
"""
Randomly crop input image with size.
Example usage:
.. code-block:: python
im = random_crop(im, 224)
:param im: the input image with HWC layout.
:type im: ndarray
:param size: the cropping size.
:type size: int
:param is_color: whether the image is color or not.
:type is_color: bool
"""
h, w = im.shape[:2]
h_start = np.random.randint(0, h - size + 1)
w_start = np.random.randint(0, w - size + 1)
h_end, w_end = h_start + size, w_start + size
if is_color:
im = im[h_start:h_end, w_start:w_end, :]
else:
im = im[h_start:h_end, w_start:w_end]
return im
def left_right_flip(im):
"""
Flip an image along the horizontal direction.
Return the flipped image.
Example usage:
.. code-block:: python
im = left_right_flip(im)
:paam im: input image with HWC layout
:type im: ndarray
"""
if len(im.shape) == 3:
return im[:, ::-1, :]
else:
return im[:, ::-1, :]
def simple_transform(im, resize_size, crop_size, is_train, is_color=True):
"""
Simply data argumentation for training. These operations include
resizing, croping and flipping.
Example usage:
.. code-block:: python
im = simple_transform(im, 256, 224, True)
:param im: The input image with HWC layout.
:type im: ndarray
:param resize_size: The shorter edge length of the resized image.
:type resize_size: int
:param crop_size: The cropping size.
:type crop_size: int
:param is_train: Whether it is training or not.
:type is_train: bool
"""
im = resize_short(im, resize_size)
if is_train:
im = random_crop(im, crop_size)
if np.random.randint(2) == 0:
im = left_right_flip(im)
else:
im = center_crop(im, crop_size)
im = to_chw(im)
return im
def load_and_transform(filename,
resize_size,
crop_size,
is_train,
is_color=True):
"""
Load image from the input file `filename` and transform image for
data argumentation. Please refer to the `simple_transform` interface
for the transform operations.
Example usage:
.. code-block:: python
im = load_and_transform('cat.jpg', 256, 224, True)
:param filename: The file name of input image.
:type filename: string
:param resize_size: The shorter edge length of the resized image.
:type resize_size: int
:param crop_size: The cropping size.
:type crop_size: int
:param is_train: Whether it is training or not.
:type is_train: bool
"""
im = load_image(filename)
im = simple_transform(im, resize_size, crop_size, is_train, is_color)
return im
add_python_test(test_v2_api test_data_feeder.py test_parameters.py test_layer.py test_rnn_layer.py test_topology.py)
add_python_test(test_v2_api test_data_feeder.py test_parameters.py
test_layer.py test_rnn_layer.py test_topology.py test_image.py)
# Copyright PaddlePaddle contributors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle.v2.image as image
class Image(unittest.TestCase):
def test_resize_flip_chw(self):
# resize
im = image.load_image('cat.jpg')
im = image.resize_short(im, 256)
self.assertEqual(256, min(im.shape[:2]))
self.assertEqual(3, im.shape[2])
# flip
im = image.left_right_flip(im)
im2 = np.flip(im, 1)
self.assertEqual(im.all(), im2.all())
# to_chw
h, w, c = im.shape
im = image.to_chw(im)
self.assertEqual(c, im.shape[0])
self.assertEqual(h, im.shape[1])
self.assertEqual(w, im.shape[2])
if __name__ == '__main__':
unittest.main()
......@@ -18,6 +18,7 @@ setup(name='paddle',
"numpy",
"protobuf==${PROTOBUF_VERSION}",
"matplotlib",
"opencv-python",
],
packages=packages,
package_dir={
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册