提交 b10c1190 编写于 作者: A Adam Geitgey 提交者: GitHub

Make sure detected face locations don't cross outside image bounds (#15)

上级 3976fae6
......@@ -12,6 +12,8 @@ except:
print("pip install git+https://github.com/ageitgey/face_recognition_models")
quit()
face_detector = dlib.get_frontal_face_detector()
predictor_model = face_recognition_models.pose_predictor_model_location()
pose_predictor = dlib.shape_predictor(predictor_model)
......@@ -39,6 +41,17 @@ def _css_to_rect(css):
return dlib.rectangle(css[3], css[0], css[1], css[2])
def _trim_css_to_bounds(css, image_shape):
"""
Make sure a tuple in (top, right, bottom, left) order is within the bounds of the image.
:param css: plain tuple representation of the rect in (top, right, bottom, left) order
:param image_shape: numpy shape of the image array
:return: a trimmed plain tuple representation of the rect in (top, right, bottom, left) order
"""
return max(css[0], 0), min(css[1], image_shape[1]), min(css[2], image_shape[0]), max(css[3], 0)
def _face_distance(faces, face_to_compare):
"""
Given a list of face encodings, compared them to a known face encoding and get a euclidean distance
......@@ -70,7 +83,6 @@ def _raw_face_locations(img, number_of_times_to_upsample=1):
:param number_of_times_to_upsample: How many times to upsample the image looking for faces. Higher numbers find smaller faces.
:return: A list of dlib 'rect' objects of found face locations
"""
face_detector = dlib.get_frontal_face_detector()
return face_detector(img, number_of_times_to_upsample)
......@@ -82,7 +94,7 @@ def face_locations(img, number_of_times_to_upsample=1):
:param number_of_times_to_upsample: How many times to upsample the image looking for faces. Higher numbers find smaller faces.
:return: A list of tuples of found face locations in css (top, right, bottom, left) order
"""
return [_rect_to_css(face) for face in _raw_face_locations(img, number_of_times_to_upsample)]
return [_trim_css_to_bounds(_rect_to_css(face), img.shape) for face in _raw_face_locations(img, number_of_times_to_upsample)]
def _raw_face_landmarks(face_image, face_locations=None):
......
......@@ -50,6 +50,19 @@ class Test_face_recognition(unittest.TestCase):
self.assertEqual(len(detected_faces), 1)
self.assertEqual(detected_faces[0], (142, 617, 409, 349))
def test_partial_face_locations(self):
img = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama_partial_face.jpg'))
detected_faces = api.face_locations(img)
self.assertEqual(len(detected_faces), 1)
self.assertEqual(detected_faces[0], (142, 191, 365, 0))
img = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama_partial_face2.jpg'))
detected_faces = api.face_locations(img)
self.assertEqual(len(detected_faces), 1)
self.assertEqual(detected_faces[0], (142, 551, 409, 349))
def test_raw_face_landmarks(self):
img = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama.jpg'))
face_landmarks = api._raw_face_landmarks(img)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册