提交 8f8c8b15 编写于 作者: A Adam Geitgey

Expose face_distance and prevent errors with empty inputs

上级 b7b8b9f6
...@@ -4,4 +4,4 @@ __author__ = """Adam Geitgey""" ...@@ -4,4 +4,4 @@ __author__ = """Adam Geitgey"""
__email__ = 'ageitgey@gmail.com' __email__ = 'ageitgey@gmail.com'
__version__ = '0.1.0' __version__ = '0.1.0'
from .api import load_image_file, face_locations, face_landmarks, face_encodings, compare_faces from .api import load_image_file, face_locations, face_landmarks, face_encodings, compare_faces, face_distance
...@@ -52,16 +52,19 @@ def _trim_css_to_bounds(css, image_shape): ...@@ -52,16 +52,19 @@ def _trim_css_to_bounds(css, image_shape):
return max(css[0], 0), min(css[1], image_shape[1]), min(css[2], image_shape[0]), max(css[3], 0) return max(css[0], 0), min(css[1], image_shape[1]), min(css[2], image_shape[0]), max(css[3], 0)
def _face_distance(faces, face_to_compare): def face_distance(face_encodings, face_to_compare):
""" """
Given a list of face encodings, compared them to a known face encoding and get a euclidean distance Given a list of face encodings, compare them to a known face encoding and get a euclidean distance
for each comparison face. for each comparison face. The distance tells you how similar the faces are.
:param faces: List of face encodings to compare :param faces: List of face encodings to compare
:param face_to_compare: A face encoding to compare against :param face_to_compare: A face encoding to compare against
:return: A list with the distance for each face in the same order as the 'faces' array :return: A numpy ndarray with the distance for each face in the same order as the 'faces' array
""" """
return np.linalg.norm(faces - face_to_compare, axis=1) if len(face_encodings) == 0:
return np.empty((0))
return np.linalg.norm(face_encodings - face_to_compare, axis=1)
def load_image_file(filename, mode='RGB'): def load_image_file(filename, mode='RGB'):
...@@ -154,4 +157,4 @@ def compare_faces(known_face_encodings, face_encoding_to_check, tolerance=0.6): ...@@ -154,4 +157,4 @@ def compare_faces(known_face_encodings, face_encoding_to_check, tolerance=0.6):
:param tolerance: How much distance between faces to consider it a match. Lower is more strict. 0.6 is typical best performance. :param tolerance: How much distance between faces to consider it a match. Lower is more strict. 0.6 is typical best performance.
:return: A list of True/False values indicating which known_face_encodings match the face encoding to check :return: A list of True/False values indicating which known_face_encodings match the face encoding to check
""" """
return list(_face_distance(known_face_encodings, face_encoding_to_check) <= tolerance) return list(face_distance(known_face_encodings, face_encoding_to_check) <= tolerance)
...@@ -11,6 +11,7 @@ Tests for `face_recognition` module. ...@@ -11,6 +11,7 @@ Tests for `face_recognition` module.
import unittest import unittest
import os import os
import numpy as np
from click.testing import CliRunner from click.testing import CliRunner
from face_recognition import api from face_recognition import api
...@@ -94,6 +95,50 @@ class Test_face_recognition(unittest.TestCase): ...@@ -94,6 +95,50 @@ class Test_face_recognition(unittest.TestCase):
self.assertEqual(len(encodings), 1) self.assertEqual(len(encodings), 1)
self.assertEqual(len(encodings[0]), 128) self.assertEqual(len(encodings[0]), 128)
def test_face_distance(self):
img_a1 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama.jpg'))
img_a2 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama2.jpg'))
img_a3 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama3.jpg'))
img_b1 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'biden.jpg'))
face_encoding_a1 = api.face_encodings(img_a1)[0]
face_encoding_a2 = api.face_encodings(img_a2)[0]
face_encoding_a3 = api.face_encodings(img_a3)[0]
face_encoding_b1 = api.face_encodings(img_b1)[0]
faces_to_compare = [
face_encoding_a2,
face_encoding_a3,
face_encoding_b1]
distance_results = api.face_distance(faces_to_compare, face_encoding_a1)
# 0.6 is the default face distance match threshold. So we'll spot-check that the numbers returned
# are above or below that based on if they should match (since the exact numbers could vary).
self.assertEqual(type(distance_results), np.ndarray)
self.assertLessEqual(distance_results[0], 0.6)
self.assertLessEqual(distance_results[1], 0.6)
self.assertGreater(distance_results[2], 0.6)
def test_face_distance_empty_lists(self):
img = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'biden.jpg'))
face_encoding = api.face_encodings(img)[0]
# empty python list
faces_to_compare = []
distance_results = api.face_distance(faces_to_compare, face_encoding)
self.assertEqual(type(distance_results), np.ndarray)
self.assertEqual(len(distance_results), 0)
# empty numpy list
faces_to_compare = np.array([])
distance_results = api.face_distance(faces_to_compare, face_encoding)
self.assertEqual(type(distance_results), np.ndarray)
self.assertEqual(len(distance_results), 0)
def test_compare_faces(self): def test_compare_faces(self):
img_a1 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama.jpg')) img_a1 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama.jpg'))
img_a2 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama2.jpg')) img_a2 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama2.jpg'))
...@@ -112,10 +157,30 @@ class Test_face_recognition(unittest.TestCase): ...@@ -112,10 +157,30 @@ class Test_face_recognition(unittest.TestCase):
face_encoding_b1] face_encoding_b1]
match_results = api.compare_faces(faces_to_compare, face_encoding_a1) match_results = api.compare_faces(faces_to_compare, face_encoding_a1)
self.assertEqual(type(match_results), list)
self.assertTrue(match_results[0]) self.assertTrue(match_results[0])
self.assertTrue(match_results[1]) self.assertTrue(match_results[1])
self.assertFalse(match_results[2]) self.assertFalse(match_results[2])
def test_compare_faces_empty_lists(self):
img = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'biden.jpg'))
face_encoding = api.face_encodings(img)[0]
# empty python list
faces_to_compare = []
match_results = api.compare_faces(faces_to_compare, face_encoding)
self.assertEqual(type(match_results), list)
self.assertListEqual(match_results, [])
# empty numpy list
faces_to_compare = np.array([])
match_results = api.compare_faces(faces_to_compare, face_encoding)
self.assertEqual(type(match_results), list)
self.assertListEqual(match_results, [])
def test_command_line_interface(self): def test_command_line_interface(self):
target_string = '--help Show this message and exit.' target_string = '--help Show this message and exit.'
runner = CliRunner() runner = CliRunner()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册