提交 8f8c8b15 编写于 作者: A Adam Geitgey

Expose face_distance and prevent errors with empty inputs

上级 b7b8b9f6
......@@ -4,4 +4,4 @@ __author__ = """Adam Geitgey"""
__email__ = 'ageitgey@gmail.com'
__version__ = '0.1.0'
from .api import load_image_file, face_locations, face_landmarks, face_encodings, compare_faces
from .api import load_image_file, face_locations, face_landmarks, face_encodings, compare_faces, face_distance
......@@ -52,16 +52,19 @@ def _trim_css_to_bounds(css, image_shape):
return max(css[0], 0), min(css[1], image_shape[1]), min(css[2], image_shape[0]), max(css[3], 0)
def _face_distance(faces, face_to_compare):
def face_distance(face_encodings, face_to_compare):
"""
Given a list of face encodings, compared them to a known face encoding and get a euclidean distance
for each comparison face.
Given a list of face encodings, compare them to a known face encoding and get a euclidean distance
for each comparison face. The distance tells you how similar the faces are.
:param faces: List of face encodings to compare
:param face_to_compare: A face encoding to compare against
:return: A list with the distance for each face in the same order as the 'faces' array
:return: A numpy ndarray with the distance for each face in the same order as the 'faces' array
"""
return np.linalg.norm(faces - face_to_compare, axis=1)
if len(face_encodings) == 0:
return np.empty((0))
return np.linalg.norm(face_encodings - face_to_compare, axis=1)
def load_image_file(filename, mode='RGB'):
......@@ -154,4 +157,4 @@ def compare_faces(known_face_encodings, face_encoding_to_check, tolerance=0.6):
:param tolerance: How much distance between faces to consider it a match. Lower is more strict. 0.6 is typical best performance.
:return: A list of True/False values indicating which known_face_encodings match the face encoding to check
"""
return list(_face_distance(known_face_encodings, face_encoding_to_check) <= tolerance)
return list(face_distance(known_face_encodings, face_encoding_to_check) <= tolerance)
......@@ -11,6 +11,7 @@ Tests for `face_recognition` module.
import unittest
import os
import numpy as np
from click.testing import CliRunner
from face_recognition import api
......@@ -94,6 +95,50 @@ class Test_face_recognition(unittest.TestCase):
self.assertEqual(len(encodings), 1)
self.assertEqual(len(encodings[0]), 128)
def test_face_distance(self):
img_a1 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama.jpg'))
img_a2 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama2.jpg'))
img_a3 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama3.jpg'))
img_b1 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'biden.jpg'))
face_encoding_a1 = api.face_encodings(img_a1)[0]
face_encoding_a2 = api.face_encodings(img_a2)[0]
face_encoding_a3 = api.face_encodings(img_a3)[0]
face_encoding_b1 = api.face_encodings(img_b1)[0]
faces_to_compare = [
face_encoding_a2,
face_encoding_a3,
face_encoding_b1]
distance_results = api.face_distance(faces_to_compare, face_encoding_a1)
# 0.6 is the default face distance match threshold. So we'll spot-check that the numbers returned
# are above or below that based on if they should match (since the exact numbers could vary).
self.assertEqual(type(distance_results), np.ndarray)
self.assertLessEqual(distance_results[0], 0.6)
self.assertLessEqual(distance_results[1], 0.6)
self.assertGreater(distance_results[2], 0.6)
def test_face_distance_empty_lists(self):
img = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'biden.jpg'))
face_encoding = api.face_encodings(img)[0]
# empty python list
faces_to_compare = []
distance_results = api.face_distance(faces_to_compare, face_encoding)
self.assertEqual(type(distance_results), np.ndarray)
self.assertEqual(len(distance_results), 0)
# empty numpy list
faces_to_compare = np.array([])
distance_results = api.face_distance(faces_to_compare, face_encoding)
self.assertEqual(type(distance_results), np.ndarray)
self.assertEqual(len(distance_results), 0)
def test_compare_faces(self):
img_a1 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama.jpg'))
img_a2 = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'obama2.jpg'))
......@@ -112,10 +157,30 @@ class Test_face_recognition(unittest.TestCase):
face_encoding_b1]
match_results = api.compare_faces(faces_to_compare, face_encoding_a1)
self.assertEqual(type(match_results), list)
self.assertTrue(match_results[0])
self.assertTrue(match_results[1])
self.assertFalse(match_results[2])
def test_compare_faces_empty_lists(self):
img = api.load_image_file(os.path.join(os.path.dirname(__file__), 'test_images', 'biden.jpg'))
face_encoding = api.face_encodings(img)[0]
# empty python list
faces_to_compare = []
match_results = api.compare_faces(faces_to_compare, face_encoding)
self.assertEqual(type(match_results), list)
self.assertListEqual(match_results, [])
# empty numpy list
faces_to_compare = np.array([])
match_results = api.compare_faces(faces_to_compare, face_encoding)
self.assertEqual(type(match_results), list)
self.assertListEqual(match_results, [])
def test_command_line_interface(self):
target_string = '--help Show this message and exit.'
runner = CliRunner()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册