提交 76ee482e 编写于 作者: M minqiyang

Fix cv2 issues

上级 6dc07e7f
...@@ -33,6 +33,11 @@ import numpy as np ...@@ -33,6 +33,11 @@ import numpy as np
try: try:
import cv2 import cv2
except ImportError: except ImportError:
import sys
sys.stderr.write(
'''Warning with paddle image module: opencv-python should be imported,
or paddle image module could NOT work; please install opencv-python first.'''
)
cv2 = None cv2 = None
import os import os
import tarfile import tarfile
...@@ -126,6 +131,8 @@ def load_image_bytes(bytes, is_color=True): ...@@ -126,6 +131,8 @@ def load_image_bytes(bytes, is_color=True):
load and return a gray image. load and return a gray image.
:type is_color: bool :type is_color: bool
""" """
assert cv2 is not None
flag = 1 if is_color else 0 flag = 1 if is_color else 0
file_bytes = np.asarray(bytearray(bytes), dtype=np.uint8) file_bytes = np.asarray(bytearray(bytes), dtype=np.uint8)
img = cv2.imdecode(file_bytes, flag) img = cv2.imdecode(file_bytes, flag)
...@@ -149,6 +156,8 @@ def load_image(file, is_color=True): ...@@ -149,6 +156,8 @@ def load_image(file, is_color=True):
load and return a gray image. load and return a gray image.
:type is_color: bool :type is_color: bool
""" """
assert cv2 is not None
# cv2.IMAGE_COLOR for OpenCV3 # cv2.IMAGE_COLOR for OpenCV3
# cv2.CV_LOAD_IMAGE_COLOR for older OpenCV Version # cv2.CV_LOAD_IMAGE_COLOR for older OpenCV Version
# cv2.IMAGE_GRAYSCALE for OpenCV3 # cv2.IMAGE_GRAYSCALE for OpenCV3
...@@ -176,12 +185,14 @@ def resize_short(im, size): ...@@ -176,12 +185,14 @@ def resize_short(im, size):
:param size: the shorter edge size of image after resizing. :param size: the shorter edge size of image after resizing.
:type size: int :type size: int
""" """
assert cv2 is not None
h, w = im.shape[:2] h, w = im.shape[:2]
h_new, w_new = size, size h_new, w_new = size, size
if h > w: if h > w:
h_new = size * h / w h_new = size * h // w
else: else:
w_new = size * w / h w_new = size * w // h
im = cv2.resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC) im = cv2.resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC)
return im return im
...@@ -228,8 +239,8 @@ def center_crop(im, size, is_color=True): ...@@ -228,8 +239,8 @@ def center_crop(im, size, is_color=True):
:type is_color: bool :type is_color: bool
""" """
h, w = im.shape[:2] h, w = im.shape[:2]
h_start = (h - size) / 2 h_start = (h - size) // 2
w_start = (w - size) / 2 w_start = (w - size) // 2
h_end, w_end = h_start + size, w_start + size h_end, w_end = h_start + size, w_start + size
if is_color: if is_color:
im = im[h_start:h_end, w_start:w_end, :] im = im[h_start:h_end, w_start:w_end, :]
......
...@@ -362,9 +362,14 @@ class OpTest(unittest.TestCase): ...@@ -362,9 +362,14 @@ class OpTest(unittest.TestCase):
def check_output_customized(self, checker): def check_output_customized(self, checker):
places = self._get_places() places = self._get_places()
import sys
print('places', places)
for place in places: for place in places:
outs = self.calc_output(place) outs = self.calc_output(place)
outs = [np.array(out) for out in outs] outs = [np.array(out) for out in outs]
import sys
print('outs', outs)
sys.stdout.flush()
checker(outs) checker(outs)
def __assert_is_close(self, numeric_grads, analytic_grads, names, def __assert_is_close(self, numeric_grads, analytic_grads, names,
......
...@@ -27,6 +27,7 @@ from six.moves import zip ...@@ -27,6 +27,7 @@ from six.moves import zip
import itertools import itertools
import random import random
import zlib import zlib
import paddle.fluid.compat as cpt
def map_readers(func, *readers): def map_readers(func, *readers):
...@@ -390,9 +391,9 @@ class PipeReader: ...@@ -390,9 +391,9 @@ class PipeReader:
buff = self.process.stdout.read(self.bufsize) buff = self.process.stdout.read(self.bufsize)
if buff: if buff:
if self.file_type == "gzip": if self.file_type == "gzip":
decomp_buff = self.dec.decompress(buff) decomp_buff = cpt.to_literal_str(self.dec.decompress(buff))
elif self.file_type == "plain": elif self.file_type == "plain":
decomp_buff = buff decomp_buff = cpt.to_literal_str(buff)
else: else:
raise TypeError("file_type %s is not allowed" % raise TypeError("file_type %s is not allowed" %
self.file_type) self.file_type)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册