未验证 提交 89fb3c8b 编写于 作者: L lijianshe02 提交者: GitHub

add blazeface detector (#229)

* add blazeface detector
上级 273ba198
......@@ -57,6 +57,12 @@ parser.add_argument("--ratio",
type=float,
default=0.4,
help="margin ratio")
parser.add_argument(
"--face_detector",
dest="face_detector",
type=str,
default='sfd',
help="face detector to be used, can choose s3fd or blazeface")
parser.set_defaults(relative=False)
parser.set_defaults(adapt_scale=False)
......@@ -75,5 +81,6 @@ if __name__ == "__main__":
adapt_scale=args.adapt_scale,
find_best_frame=args.find_best_frame,
best_frame=args.best_frame,
ratio=args.ratio)
ratio=args.ratio,
face_detector=args.face_detector)
predictor.run(args.source_image, args.driving_video)
......@@ -97,6 +97,12 @@ parser.add_argument(
action='store_true',
help='Prevent smoothing face detections over a short temporal window')
parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
parser.add_argument(
"--face_detector",
dest="face_detector",
type=str,
default='sfd',
help="face detector to be used, can choose s3fd or blazeface")
if __name__ == "__main__":
args = parser.parse_args()
......
......@@ -46,7 +46,8 @@ class FirstOrderPredictor(BasePredictor):
find_best_frame=False,
best_frame=None,
ratio=1.0,
filename='result.mp4'):
filename='result.mp4',
face_detector='sfd'):
if config is not None and isinstance(config, str):
self.cfg = yaml.load(config, Loader=yaml.SafeLoader)
elif isinstance(config, dict):
......@@ -95,6 +96,7 @@ class FirstOrderPredictor(BasePredictor):
self.find_best_frame = find_best_frame
self.best_frame = best_frame
self.ratio = ratio
self.face_detector = face_detector
self.generator, self.kp_detector = self.load_checkpoints(
self.cfg, self.weight_path)
......@@ -261,7 +263,9 @@ class FirstOrderPredictor(BasePredictor):
def extract_bbox(self, image):
detector = face_detection.FaceAlignment(
face_detection.LandmarksType._2D, flip_input=False)
face_detection.LandmarksType._2D,
flip_input=False,
face_detector=self.face_detector)
frame = [image]
predictions = detector.get_detections_for_image(np.array(frame))
......
......@@ -36,7 +36,9 @@ class Wav2LipPredictor(BasePredictor):
def face_detect(self, images):
detector = face_detection.FaceAlignment(
face_detection.LandmarksType._2D, flip_input=False)
face_detection.LandmarksType._2D,
flip_input=False,
face_detector=self.args.face_detector)
batch_size = self.args.face_det_batch_size
......
......@@ -80,7 +80,7 @@ class FaceAlignment:
d = d[0]
d = np.clip(d, 0, None)
x1, y1, x2, y2 = map(int, d[:-1])
x1, y1, x2, y2 = map(int, d[:4])
results.append((x1, y1, x2, y2))
return results
......
from .blazeface_detector import BlazeFaceDetector as FaceDetector
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
from paddle.utils.download import get_weights_path_from_url
from ..core import FaceDetector
from .net_blazeface import BlazeFace
from .detect import *
blazeface_weights = 'https://paddlegan.bj.bcebos.com/models/blazeface.pdparams'
blazeface_anchors = 'https://paddlegan.bj.bcebos.com/models/anchors.npy'
class BlazeFaceDetector(FaceDetector):
def __init__(self,
path_to_detector=None,
path_to_anchor=None,
verbose=False,
min_score_thresh=0.5,
min_suppression_threshold=0.3):
super(BlazeFaceDetector, self).__init__(verbose)
# Initialise the face detector
if path_to_detector is None:
model_weights_path = get_weights_path_from_url(blazeface_weights)
model_weights = paddle.load(model_weights_path)
model_anchors = np.load(
get_weights_path_from_url(blazeface_anchors))
else:
model_weights = paddle.load(path_to_detector)
model_anchors = np.load(path_to_anchor)
self.face_detector = BlazeFace()
self.face_detector.load_dict(model_weights)
self.face_detector.load_anchors_from_npy(model_anchors)
self.face_detector.min_score_thresh = min_score_thresh
self.face_detector.min_suppression_threshold = min_suppression_threshold
self.face_detector.eval()
def detect_from_image(self, tensor_or_path):
image = self.tensor_or_path_to_ndarray(tensor_or_path)
bboxlist = detect(self.face_detector, image)[0]
return bboxlist
def detect_from_batch(self, tensor):
bboxlists = batch_detect(self.face_detector, tensor)
return bboxlists
@property
def reference_scale(self):
return 195
@property
def reference_x_shift(self):
return 0
@property
def reference_y_shift(self):
return 0
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn.functional as F
import cv2
import numpy as np
from .utils import *
def detect(net, img, device):
H, W, C = img.shape
orig_size = min(H, W)
img, (xshift, yshift) = resize_and_crop_image(img, 128)
preds = net.predict_on_image(img.astype('float32')).numpy()
if 0 == len(preds):
return [[]]
shift = np.array([xshift, yshift] * 2)
scores = preds[:, -1:]
locs = np.concatenate(
(preds[:, 1:2], preds[:, 0:1], preds[:, 3:4], preds[:, 2:3]), axis=1)
return [np.concatenate((locs * orig_size + shift, scores), axis=1)]
def batch_detect(net, img_batch):
"""
Inputs:
- img_batch: a numpy array or tensor of shape (Batch size, Channels, Height, Width)
Outputs:
- list of 2-dim numpy arrays with shape (faces_on_this_image, 5): x1, y1, x2, y2, confidence
(x1, y1) - top left corner, (x2, y2) - bottom right corner
"""
B, H, W, C = img_batch.shape
orig_size = min(H, W)
if isinstance(img_batch, paddle.Tensor):
img_batch = img_batch.numpy()
imgs, (xshift, yshift) = resize_and_crop_batch(img_batch, 128)
preds = net.predict_on_batch(imgs.astype('float32'))
bboxlists = []
for pred in preds:
pred = pred.numpy()
shift = np.array([xshift, yshift] * 2)
scores = pred[:, -1:]
xmin = pred[:, 1:2]
ymin = pred[:, 0:1]
xmax = pred[:, 3:4]
ymax = pred[:, 2:3]
locs = np.concatenate((xmin, ymin, xmax, ymax), axis=1)
bboxlists.append(
np.concatenate((locs * orig_size + shift, scores), axis=1))
return bboxlists
def flip_detect(net, img):
img = cv2.flip(img, 1)
b = detect(net, img)
bboxlist = np.zeros(b.shape)
bboxlist[:, 0] = img.shape[1] - b[:, 2]
bboxlist[:, 1] = b[:, 1]
bboxlist[:, 2] = img.shape[1] - b[:, 0]
bboxlist[:, 3] = b[:, 3]
bboxlist[:, 4] = b[:, 4]
return bboxlist
def pts_to_bb(pts):
min_x, min_y = np.min(pts, axis=0)
max_x, max_y = np.max(pts, axis=0)
return np.array([min_x, min_y, max_x, max_y])
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class BlazeBlock(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
super(BlazeBlock, self).__init__()
self.stride = stride
self.channel_pad = out_channels - in_channels
if stride == 2:
self.max_pool = nn.MaxPool2D(kernel_size=stride, stride=stride)
padding = 0
else:
padding = (kernel_size - 1) // 2
self.convs = nn.Sequential(
nn.Conv2D(in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=in_channels),
nn.Conv2D(in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0),
)
self.act = nn.ReLU()
def forward(self, x):
if self.stride == 2:
h = F.pad(x, [0, 2, 0, 2], "constant", 0)
x = self.max_pool(x)
else:
h = x
if self.channel_pad > 0:
x = F.pad(x, [0, 0, 0, self.channel_pad, 0, 0, 0, 0], "constant", 0)
return self.act(self.convs(h) + x)
class BlazeFace(nn.Layer):
"""The BlazeFace face detection model.
"""
def __init__(self):
super(BlazeFace, self).__init__()
self.num_classes = 1
self.num_anchors = 896
self.num_coords = 16
self.score_clipping_thresh = 100.0
self.x_scale = 128.0
self.y_scale = 128.0
self.h_scale = 128.0
self.w_scale = 128.0
self.min_score_thresh = 0.75
self.min_suppression_threshold = 0.3
self._define_layers()
def _define_layers(self):
self.backbone1 = nn.Sequential(
nn.Conv2D(in_channels=3,
out_channels=24,
kernel_size=5,
stride=2,
padding=0),
nn.ReLU(),
BlazeBlock(24, 24),
BlazeBlock(24, 28),
BlazeBlock(28, 32, stride=2),
BlazeBlock(32, 36),
BlazeBlock(36, 42),
BlazeBlock(42, 48, stride=2),
BlazeBlock(48, 56),
BlazeBlock(56, 64),
BlazeBlock(64, 72),
BlazeBlock(72, 80),
BlazeBlock(80, 88),
)
self.backbone2 = nn.Sequential(
BlazeBlock(88, 96, stride=2),
BlazeBlock(96, 96),
BlazeBlock(96, 96),
BlazeBlock(96, 96),
BlazeBlock(96, 96),
)
self.classifier_8 = nn.Conv2D(88, 2, 1)
self.classifier_16 = nn.Conv2D(96, 6, 1)
self.regressor_8 = nn.Conv2D(88, 32, 1)
self.regressor_16 = nn.Conv2D(96, 96, 1)
def forward(self, x):
x = F.pad(x, [1, 2, 1, 2], "constant", 0)
b = x.shape[0]
x = self.backbone1(x) # (b, 88, 16, 16)
h = self.backbone2(x) # (b, 96, 8, 8)
c1 = self.classifier_8(x) # (b, 2, 16, 16)
c1 = c1.transpose([0, 2, 3, 1]) # (b, 16, 16, 2)
c1 = c1.reshape([b, -1, 1]) # (b, 512, 1)
c2 = self.classifier_16(h) # (b, 6, 8, 8)
c2 = c2.transpose([0, 2, 3, 1]) # (b, 8, 8, 6)
c2 = c2.reshape([b, -1, 1]) # (b, 384, 1)
c = paddle.concat((c1, c2), axis=1) # (b, 896, 1)
r1 = self.regressor_8(x) # (b, 32, 16, 16)
r1 = r1.transpose([0, 2, 3, 1]) # (b, 16, 16, 32)
r1 = r1.reshape([b, -1, 16]) # (b, 512, 16)
r2 = self.regressor_16(h) # (b, 96, 8, 8)
r2 = r2.transpose([0, 2, 3, 1]) # (b, 8, 8, 96)
r2 = r2.reshape([b, -1, 16]) # (b, 384, 16)
r = paddle.concat((r1, r2), axis=1) # (b, 896, 16)
return [r, c]
def load_weights(self, path):
paddle.load_dict(paddle.load(path))
self.eval()
def load_anchors(self, path):
self.anchors = paddle.to_tensor(np.load(path), dtype='float32')
assert (self.anchors.shape == 2)
assert (self.anchors.shape[0] == self.num_anchors)
assert (self.anchors.shape[1] == 4)
def load_anchors_from_npy(self, arr):
self.anchors = paddle.to_tensor(arr, dtype='float32')
assert (len(self.anchors.shape) == 2)
assert (self.anchors.shape[0] == self.num_anchors)
assert (self.anchors.shape[1] == 4)
def _preprocess(self, x):
"""Converts the image pixels to the range [-1, 1]."""
return x.astype('float32') / 127.5 - 1.0
def predict_on_image(self, img):
"""Makes a prediction on a single image.
Arguments:
img: a NumPy array of shape (H, W, 3) or a Paddle tensor of
shape (3, H, W). The image's height and width should be
128 pixels.
Returns:
A tensor with face detections.
"""
if isinstance(img, np.ndarray):
img = paddle.to_tensor(img).transpose((2, 0, 1))
return self.predict_on_batch(img.unsqueeze(0))[0]
def predict_on_batch(self, x):
"""Makes a prediction on a batch of images.
Arguments:
x: a NumPy array of shape (b, H, W, 3) or a Paddle tensor of
shape (b, 3, H, W). The height and width should be 128 pixels.
Returns:
A list containing a tensor of face detections for each image in
the batch. If no faces are found for an image, returns a tensor
of shape (0, 17).
Each face detection is a Paddle tensor consisting of 17 numbers:
- ymin, xmin, ymax, xmax
- x,y-coordinates for the 6 keypoints
- confidence score
"""
if isinstance(x, np.ndarray):
x = paddle.to_tensor(x).transpose((0, 3, 1, 2))
assert x.shape[1] == 3
assert x.shape[2] == 128
assert x.shape[3] == 128
x = self._preprocess(x)
with paddle.no_grad():
out = self.__call__(x)
detections = self._tensors_to_detections(out[0], out[1], self.anchors)
filtered_detections = []
for i in range(len(detections)):
faces = self._weighted_non_max_suppression(detections[i])
faces = paddle.stack(faces) if len(faces) > 0 else paddle.zeros(
(0, 17))
filtered_detections.append(faces)
return filtered_detections
def _tensors_to_detections(self, raw_box_tensor, raw_score_tensor, anchors):
"""The output of the neural network is a tensor of shape (b, 896, 16)
containing the bounding box regressor predictions, as well as a tensor
of shape (b, 896, 1) with the classification confidences.
Returns a list of (num_detections, 17) tensors, one for each image in
the batch.
"""
assert len(raw_box_tensor.shape) == 3
assert raw_box_tensor.shape[1] == self.num_anchors
assert raw_box_tensor.shape[2] == self.num_coords
assert len(raw_score_tensor.shape) == 3
assert raw_score_tensor.shape[1] == self.num_anchors
assert raw_score_tensor.shape[2] == self.num_classes
assert raw_box_tensor.shape[0] == raw_score_tensor.shape[0]
detection_boxes = self._decode_boxes(raw_box_tensor, anchors)
thresh = self.score_clipping_thresh
raw_score_tensor = raw_score_tensor.clip(-thresh, thresh)
detection_scores = F.sigmoid(raw_score_tensor).squeeze(axis=-1)
mask = detection_scores >= self.min_score_thresh
mask = mask.numpy()
detection_boxes = detection_boxes.numpy()
detection_scores = detection_scores.numpy()
output_detections = []
for i in range(raw_box_tensor.shape[0]):
boxes = paddle.to_tensor(detection_boxes[i, mask[i]])
scores = paddle.to_tensor(
detection_scores[i, mask[i]]).unsqueeze(axis=-1)
output_detections.append(paddle.concat((boxes, scores), axis=-1))
return output_detections
def _decode_boxes(self, raw_boxes, anchors):
"""Converts the predictions into actual coordinates using
the anchor boxes. Processes the entire batch at once.
"""
boxes = paddle.zeros_like(raw_boxes)
x_center = raw_boxes[:,:, 0] / self.x_scale * \
anchors[:, 2] + anchors[:, 0]
y_center = raw_boxes[:,:, 1] / self.y_scale * \
anchors[:, 3] + anchors[:, 1]
w = raw_boxes[:, :, 2] / self.w_scale * anchors[:, 2]
h = raw_boxes[:, :, 3] / self.h_scale * anchors[:, 3]
boxes[:, :, 0] = y_center - h / 2. # ymin
boxes[:, :, 1] = x_center - w / 2. # xmin
boxes[:, :, 2] = y_center + h / 2. # ymax
boxes[:, :, 3] = x_center + w / 2. # xmax
for k in range(6):
offset = 4 + k * 2
keypoint_x = raw_boxes[:,:, offset] / \
self.x_scale * anchors[:, 2] + anchors[:, 0]
keypoint_y = raw_boxes[:,:, offset + 1] / \
self.y_scale * anchors[:, 3] + anchors[:, 1]
boxes[:, :, offset] = keypoint_x
boxes[:, :, offset + 1] = keypoint_y
return boxes
def _weighted_non_max_suppression(self, detections):
"""The alternative NMS method as mentioned in the BlazeFace paper:
The input detections should be a Tensor of shape (count, 17).
Returns a list of Paddle tensors, one for each detected face.
"""
if len(detections) == 0:
return []
output_detections = []
# Sort the detections from highest to lowest score.
remaining = paddle.argsort(detections[:, 16], descending=True).numpy()
detections = detections.numpy()
while len(remaining) > 0:
detection = detections[remaining[0]]
first_box = detection[:4]
other_boxes = detections[remaining, :4]
ious = overlap_similarity(paddle.to_tensor(first_box),
paddle.to_tensor(other_boxes))
mask = ious > self.min_suppression_threshold
mask = mask.numpy()
overlapping = remaining[mask]
remaining = remaining[~mask]
weighted_detection = detection.copy()
if len(overlapping) > 1:
coordinates = detections[overlapping, :16]
scores = detections[overlapping, 16:17]
total_score = scores.sum()
weighted = (coordinates * scores).sum(axis=0) / total_score
weighted_detection[:16] = weighted
weighted_detection[16] = total_score / len(overlapping)
output_detections.append(paddle.to_tensor(weighted_detection))
return output_detections
def intersect(box_a, box_b):
"""Compute the area of intersect between box_a and box_b.
Args:
box_a: (tensor) bounding boxes, Shape: [A,4].
box_b: (tensor) bounding boxes, Shape: [B,4].
Return:
(tensor) intersection area, Shape: [A,B].
"""
A = box_a.shape[0]
B = box_b.shape[0]
max_xy = paddle.minimum(box_a[:, 2:].unsqueeze(1).expand((A, B, 2)),
box_b[:, 2:].unsqueeze(0).expand((A, B, 2)))
min_xy = paddle.maximum(box_a[:, :2].unsqueeze(1).expand((A, B, 2)),
box_b[:, :2].unsqueeze(0).expand((A, B, 2)))
inter = paddle.clip((max_xy - min_xy), min=0)
return inter[:, :, 0] * inter[:, :, 1]
def jaccard(box_a, box_b):
"""Compute the jaccard overlap of two sets of boxes.
Args:
box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
Return:
jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
"""
inter = intersect(box_a, box_b)
area_a = ((box_a[:, 2] - box_a[:, 0]) *
(box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter)
area_b = ((box_b[:, 2] - box_b[:, 0]) *
(box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter)
union = area_a + area_b - inter
return inter / union
def overlap_similarity(box, other_boxes):
"""Computes the IOU between a bounding box and set of other boxes."""
return jaccard(box.unsqueeze(0), other_boxes).squeeze(0)
def init_model():
net = BlazeFace()
net.load_weights("blazeface.pdparams")
net.load_anchors("anchors.npy")
net.min_score_thresh = 0.75
net.min_suppression_threshold = 0.3
return net
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
def image_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
dim = None
(h, w) = image.shape[:2]
if width is None and height is None:
return image
if width is None:
r = height / float(h)
dim = (int(w * r), height)
else:
r = width / float(w)
dim = (width, int(h * r))
resized = cv2.resize(image, dim, interpolation=inter)
return resized
def resize_and_crop_image(image, dim):
if image.shape[0] > image.shape[1]:
img = image_resize(image, width=dim)
yshift, xshift = (image.shape[0] - image.shape[1]) // 2, 0
y_start = (img.shape[0] - img.shape[1]) // 2
y_end = y_start + dim
return img[y_start:y_end, :, :], (xshift, yshift)
else:
img = image_resize(image, height=dim)
yshift, xshift = 0, (image.shape[1] - image.shape[0]) // 2
x_start = (img.shape[1] - img.shape[0]) // 2
x_end = x_start + dim
return img[:, x_start:x_end, :], (xshift, yshift)
def resize_and_crop_batch(frames, dim):
smframes = []
xshift, yshift = 0, 0
for i in range(len(frames)):
smframe, (xshift, yshift) = resize_and_crop_image(frames[i], dim)
smframes.append(smframe)
smframes = np.stack(smframes)
return smframes, (xshift, yshift)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册