From 89fb3c8b49e6cfef92a7f5f894add74ba154104b Mon Sep 17 00:00:00 2001
From: lijianshe02 <48898730+lijianshe02@users.noreply.github.com>
Date: Wed, 24 Mar 2021 11:48:57 +0800
Subject: [PATCH] add blazeface detector (#229)

* add blazeface detector
---
 applications/tools/first-order-demo.py        |   9 +-
 applications/tools/wav2lip.py                 |   6 +
 ppgan/apps/first_order_predictor.py           |   8 +-
 ppgan/apps/wav2lip_predictor.py               |   4 +-
 ppgan/faceutils/face_detection/api.py         |   2 +-
 .../detection/blazeface/__init__.py           |   1 +
 .../detection/blazeface/blazeface_detector.py |  77 ++++
 .../detection/blazeface/detect.py             |  89 ++++
 .../detection/blazeface/net_blazeface.py      | 380 ++++++++++++++++++
 .../detection/blazeface/utils.py              |  60 +++
 10 files changed, 631 insertions(+), 5 deletions(-)
 create mode 100644 ppgan/faceutils/face_detection/detection/blazeface/__init__.py
 create mode 100644 ppgan/faceutils/face_detection/detection/blazeface/blazeface_detector.py
 create mode 100644 ppgan/faceutils/face_detection/detection/blazeface/detect.py
 create mode 100644 ppgan/faceutils/face_detection/detection/blazeface/net_blazeface.py
 create mode 100644 ppgan/faceutils/face_detection/detection/blazeface/utils.py

diff --git a/applications/tools/first-order-demo.py b/applications/tools/first-order-demo.py
index e08ac73..472e4a7 100644
--- a/applications/tools/first-order-demo.py
+++ b/applications/tools/first-order-demo.py
@@ -57,6 +57,12 @@ parser.add_argument("--ratio",
                     type=float,
                     default=0.4,
                     help="margin ratio")
+parser.add_argument(
+    "--face_detector",
+    dest="face_detector",
+    type=str,
+    default='sfd',
+    help="face detector to be used; one of 'sfd' or 'blazeface'")
 parser.set_defaults(relative=False)
 parser.set_defaults(adapt_scale=False)
@@ -75,5 +81,6 @@ if __name__ == "__main__":
         adapt_scale=args.adapt_scale,
         find_best_frame=args.find_best_frame,
         best_frame=args.best_frame,
-        ratio=args.ratio)
+        ratio=args.ratio,
+        face_detector=args.face_detector)
     predictor.run(args.source_image, args.driving_video)
diff --git a/applications/tools/wav2lip.py b/applications/tools/wav2lip.py
index 8a708fc..74f40f6 100644
--- a/applications/tools/wav2lip.py
+++ b/applications/tools/wav2lip.py
@@ -97,6 +97,12 @@ parser.add_argument(
     action='store_true',
     help='Prevent smoothing face detections over a short temporal window')
 parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
+parser.add_argument(
+    "--face_detector",
+    dest="face_detector",
+    type=str,
+    default='sfd',
+    help="face detector to be used; one of 'sfd' or 'blazeface'")

 if __name__ == "__main__":
     args = parser.parse_args()
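For reference, the new flag is forwarded to FirstOrderPredictor in the next hunks. A minimal Python sketch of the equivalent call, using only arguments visible in this patch (the input paths are placeholders):

    from ppgan.apps.first_order_predictor import FirstOrderPredictor

    # face_detector='blazeface' selects the new detector; 'sfd' stays the default.
    predictor = FirstOrderPredictor(ratio=0.4, face_detector='blazeface')
    predictor.run('source.png', 'driving.mp4')  # placeholder inputs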
diff --git a/ppgan/apps/first_order_predictor.py b/ppgan/apps/first_order_predictor.py
index 204615a..4ed4791 100644
--- a/ppgan/apps/first_order_predictor.py
+++ b/ppgan/apps/first_order_predictor.py
@@ -46,7 +46,8 @@ class FirstOrderPredictor(BasePredictor):
                  find_best_frame=False,
                  best_frame=None,
                  ratio=1.0,
-                 filename='result.mp4'):
+                 filename='result.mp4',
+                 face_detector='sfd'):
         if config is not None and isinstance(config, str):
             self.cfg = yaml.load(config, Loader=yaml.SafeLoader)
         elif isinstance(config, dict):
@@ -95,6 +96,7 @@ class FirstOrderPredictor(BasePredictor):
         self.find_best_frame = find_best_frame
         self.best_frame = best_frame
         self.ratio = ratio
+        self.face_detector = face_detector
         self.generator, self.kp_detector = self.load_checkpoints(
             self.cfg, self.weight_path)
@@ -261,7 +263,9 @@

     def extract_bbox(self, image):
         detector = face_detection.FaceAlignment(
-            face_detection.LandmarksType._2D, flip_input=False)
+            face_detection.LandmarksType._2D,
+            flip_input=False,
+            face_detector=self.face_detector)

         frame = [image]
         predictions = detector.get_detections_for_image(np.array(frame))
diff --git a/ppgan/apps/wav2lip_predictor.py b/ppgan/apps/wav2lip_predictor.py
index d29c76e..a8014bb 100644
--- a/ppgan/apps/wav2lip_predictor.py
+++ b/ppgan/apps/wav2lip_predictor.py
@@ -36,7 +36,9 @@ class Wav2LipPredictor(BasePredictor):

     def face_detect(self, images):
         detector = face_detection.FaceAlignment(
-            face_detection.LandmarksType._2D, flip_input=False)
+            face_detection.LandmarksType._2D,
+            flip_input=False,
+            face_detector=self.args.face_detector)

         batch_size = self.args.face_det_batch_size
diff --git a/ppgan/faceutils/face_detection/api.py b/ppgan/faceutils/face_detection/api.py
index 7511880..608ad5b 100644
--- a/ppgan/faceutils/face_detection/api.py
+++ b/ppgan/faceutils/face_detection/api.py
@@ -80,7 +80,7 @@ class FaceAlignment:
             d = d[0]
             d = np.clip(d, 0, None)

-            x1, y1, x2, y2 = map(int, d[:-1])
+            x1, y1, x2, y2 = map(int, d[:4])
             results.append((x1, y1, x2, y2))

         return results
diff --git a/ppgan/faceutils/face_detection/detection/blazeface/__init__.py b/ppgan/faceutils/face_detection/detection/blazeface/__init__.py
new file mode 100644
index 0000000..e2a4708
--- /dev/null
+++ b/ppgan/faceutils/face_detection/detection/blazeface/__init__.py
@@ -0,0 +1 @@
+from .blazeface_detector import BlazeFaceDetector as FaceDetector
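The updated call sites above all funnel through FaceAlignment. A short sketch of the new keyword in use, assuming the import path used elsewhere in PaddleGAN (the zero frame is a stand-in image):

    import numpy as np
    from ppgan.faceutils import face_detection

    detector = face_detection.FaceAlignment(
        face_detection.LandmarksType._2D,
        flip_input=False,
        face_detector='blazeface')

    frame = np.zeros((256, 256, 3), dtype='uint8')  # stand-in image
    # Returns integer (x1, y1, x2, y2) tuples, clipped as in api.py above.
    boxes = detector.get_detections_for_image(np.array([frame]))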
diff --git a/ppgan/faceutils/face_detection/detection/blazeface/blazeface_detector.py b/ppgan/faceutils/face_detection/detection/blazeface/blazeface_detector.py
new file mode 100644
index 0000000..1bc58b5
--- /dev/null
+++ b/ppgan/faceutils/face_detection/detection/blazeface/blazeface_detector.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+from paddle.utils.download import get_weights_path_from_url
+
+from ..core import FaceDetector
+
+from .net_blazeface import BlazeFace
+from .detect import *
+
+blazeface_weights = 'https://paddlegan.bj.bcebos.com/models/blazeface.pdparams'
+blazeface_anchors = 'https://paddlegan.bj.bcebos.com/models/anchors.npy'
+
+
+class BlazeFaceDetector(FaceDetector):
+    def __init__(self,
+                 path_to_detector=None,
+                 path_to_anchor=None,
+                 verbose=False,
+                 min_score_thresh=0.5,
+                 min_suppression_threshold=0.3):
+        super(BlazeFaceDetector, self).__init__(verbose)
+
+        # Initialise the face detector
+        if path_to_detector is None:
+            model_weights_path = get_weights_path_from_url(blazeface_weights)
+            model_weights = paddle.load(model_weights_path)
+            model_anchors = np.load(
+                get_weights_path_from_url(blazeface_anchors))
+        else:
+            model_weights = paddle.load(path_to_detector)
+            model_anchors = np.load(path_to_anchor)
+
+        self.face_detector = BlazeFace()
+        self.face_detector.load_dict(model_weights)
+        self.face_detector.load_anchors_from_npy(model_anchors)
+
+        self.face_detector.min_score_thresh = min_score_thresh
+        self.face_detector.min_suppression_threshold = min_suppression_threshold
+
+        self.face_detector.eval()
+
+    def detect_from_image(self, tensor_or_path):
+        image = self.tensor_or_path_to_ndarray(tensor_or_path)
+
+        bboxlist = detect(self.face_detector, image)[0]
+
+        return bboxlist
+
+    def detect_from_batch(self, tensor):
+        bboxlists = batch_detect(self.face_detector, tensor)
+        return bboxlists
+
+    @property
+    def reference_scale(self):
+        return 195
+
+    @property
+    def reference_x_shift(self):
+        return 0
+
+    @property
+    def reference_y_shift(self):
+        return 0
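A usage sketch for the class added above; the weights and anchors are fetched from the two URLs at the top of the file on first use, and the input frame is illustrative:

    import numpy as np
    from ppgan.faceutils.face_detection.detection.blazeface import FaceDetector

    detector = FaceDetector(min_score_thresh=0.5)    # BlazeFaceDetector via the alias in __init__.py
    frame = np.zeros((720, 1280, 3), dtype='uint8')  # stand-in frame
    bboxes = detector.detect_from_image(frame)       # rows of x1, y1, x2, y2, score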
diff --git a/ppgan/faceutils/face_detection/detection/blazeface/detect.py b/ppgan/faceutils/face_detection/detection/blazeface/detect.py
new file mode 100644
index 0000000..f4f2b89
--- /dev/null
+++ b/ppgan/faceutils/face_detection/detection/blazeface/detect.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn.functional as F
+
+import cv2
+import numpy as np
+
+from .utils import *
+
+
+def detect(net, img):
+    H, W, C = img.shape
+    orig_size = min(H, W)
+    img, (xshift, yshift) = resize_and_crop_image(img, 128)
+    preds = net.predict_on_image(img.astype('float32')).numpy()
+
+    if len(preds) == 0:
+        return [[]]
+
+    shift = np.array([xshift, yshift] * 2)
+    scores = preds[:, -1:]
+
+    locs = np.concatenate(
+        (preds[:, 1:2], preds[:, 0:1], preds[:, 3:4], preds[:, 2:3]), axis=1)
+    return [np.concatenate((locs * orig_size + shift, scores), axis=1)]
+
+
+def batch_detect(net, img_batch):
+    """
+    Inputs:
+        - img_batch: a numpy array or tensor of shape (batch size, height, width, channels)
+    Outputs:
+        - list of 2-dim numpy arrays with shape (faces_on_this_image, 5): x1, y1, x2, y2, confidence
+          (x1, y1) - top left corner, (x2, y2) - bottom right corner
+    """
+    B, H, W, C = img_batch.shape
+    orig_size = min(H, W)
+
+    if isinstance(img_batch, paddle.Tensor):
+        img_batch = img_batch.numpy()
+
+    imgs, (xshift, yshift) = resize_and_crop_batch(img_batch, 128)
+    preds = net.predict_on_batch(imgs.astype('float32'))
+    bboxlists = []
+    for pred in preds:
+        pred = pred.numpy()
+        shift = np.array([xshift, yshift] * 2)
+        scores = pred[:, -1:]
+        xmin = pred[:, 1:2]
+        ymin = pred[:, 0:1]
+        xmax = pred[:, 3:4]
+        ymax = pred[:, 2:3]
+        locs = np.concatenate((xmin, ymin, xmax, ymax), axis=1)
+        bboxlists.append(
+            np.concatenate((locs * orig_size + shift, scores), axis=1))
+
+    return bboxlists
+
+
+def flip_detect(net, img):
+    img = cv2.flip(img, 1)
+    b = detect(net, img)[0]
+
+    bboxlist = np.zeros(b.shape)
+    bboxlist[:, 0] = img.shape[1] - b[:, 2]
+    bboxlist[:, 1] = b[:, 1]
+    bboxlist[:, 2] = img.shape[1] - b[:, 0]
+    bboxlist[:, 3] = b[:, 3]
+    bboxlist[:, 4] = b[:, 4]
+    return bboxlist
+
+
+def pts_to_bb(pts):
+    min_x, min_y = np.min(pts, axis=0)
+    max_x, max_y = np.max(pts, axis=0)
+    return np.array([min_x, min_y, max_x, max_y])
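To make the coordinate mapping in detect()/batch_detect() concrete: the network sees a 128x128 center crop of the frame's short side, predicts boxes normalized to [0, 1] in (ymin, xmin, ymax, xmax) order (reordered above), and the helpers map them back to original-frame pixels as pixel = norm * min(H, W) + shift. A small self-contained check:

    import numpy as np

    H, W = 720, 1280                  # landscape frame
    orig_size = min(H, W)             # the crop covers a 720x720 square
    xshift, yshift = (W - H) // 2, 0  # crop offset in original pixels
    shift = np.array([xshift, yshift] * 2)

    norm_box = np.array([0.25, 0.25, 0.75, 0.75])  # x1, y1, x2, y2 in crop space
    print(norm_box * orig_size + shift)            # [460. 180. 820. 540.] in frame pixels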
diff --git a/ppgan/faceutils/face_detection/detection/blazeface/net_blazeface.py b/ppgan/faceutils/face_detection/detection/blazeface/net_blazeface.py
new file mode 100644
index 0000000..ca066be
--- /dev/null
+++ b/ppgan/faceutils/face_detection/detection/blazeface/net_blazeface.py
@@ -0,0 +1,380 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class BlazeBlock(nn.Layer):
+    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
+        super(BlazeBlock, self).__init__()
+
+        self.stride = stride
+        self.channel_pad = out_channels - in_channels
+
+        if stride == 2:
+            self.max_pool = nn.MaxPool2D(kernel_size=stride, stride=stride)
+            padding = 0
+        else:
+            padding = (kernel_size - 1) // 2
+
+        self.convs = nn.Sequential(
+            nn.Conv2D(in_channels=in_channels,
+                      out_channels=in_channels,
+                      kernel_size=kernel_size,
+                      stride=stride,
+                      padding=padding,
+                      groups=in_channels),
+            nn.Conv2D(in_channels=in_channels,
+                      out_channels=out_channels,
+                      kernel_size=1,
+                      stride=1,
+                      padding=0),
+        )
+
+        self.act = nn.ReLU()
+
+    def forward(self, x):
+        if self.stride == 2:
+            h = F.pad(x, [0, 2, 0, 2], "constant", 0)
+            x = self.max_pool(x)
+        else:
+            h = x
+        if self.channel_pad > 0:
+            x = F.pad(x, [0, 0, 0, self.channel_pad, 0, 0, 0, 0], "constant",
+                      0)
+
+        return self.act(self.convs(h) + x)
+
+
+class BlazeFace(nn.Layer):
+    """The BlazeFace face detection model.
+    """
+    def __init__(self):
+        super(BlazeFace, self).__init__()
+
+        self.num_classes = 1
+        self.num_anchors = 896
+        self.num_coords = 16
+        self.score_clipping_thresh = 100.0
+        self.x_scale = 128.0
+        self.y_scale = 128.0
+        self.h_scale = 128.0
+        self.w_scale = 128.0
+        self.min_score_thresh = 0.75
+        self.min_suppression_threshold = 0.3
+
+        self._define_layers()
+
+    def _define_layers(self):
+        self.backbone1 = nn.Sequential(
+            nn.Conv2D(in_channels=3,
+                      out_channels=24,
+                      kernel_size=5,
+                      stride=2,
+                      padding=0),
+            nn.ReLU(),
+            BlazeBlock(24, 24),
+            BlazeBlock(24, 28),
+            BlazeBlock(28, 32, stride=2),
+            BlazeBlock(32, 36),
+            BlazeBlock(36, 42),
+            BlazeBlock(42, 48, stride=2),
+            BlazeBlock(48, 56),
+            BlazeBlock(56, 64),
+            BlazeBlock(64, 72),
+            BlazeBlock(72, 80),
+            BlazeBlock(80, 88),
+        )
+
+        self.backbone2 = nn.Sequential(
+            BlazeBlock(88, 96, stride=2),
+            BlazeBlock(96, 96),
+            BlazeBlock(96, 96),
+            BlazeBlock(96, 96),
+            BlazeBlock(96, 96),
+        )
+
+        self.classifier_8 = nn.Conv2D(88, 2, 1)
+        self.classifier_16 = nn.Conv2D(96, 6, 1)
+
+        self.regressor_8 = nn.Conv2D(88, 32, 1)
+        self.regressor_16 = nn.Conv2D(96, 96, 1)
+
+    def forward(self, x):
+        x = F.pad(x, [1, 2, 1, 2], "constant", 0)
+
+        b = x.shape[0]
+
+        x = self.backbone1(x)  # (b, 88, 16, 16)
+        h = self.backbone2(x)  # (b, 96, 8, 8)
+
+        c1 = self.classifier_8(x)  # (b, 2, 16, 16)
+        c1 = c1.transpose([0, 2, 3, 1])  # (b, 16, 16, 2)
+        c1 = c1.reshape([b, -1, 1])  # (b, 512, 1)
+
+        c2 = self.classifier_16(h)  # (b, 6, 8, 8)
+        c2 = c2.transpose([0, 2, 3, 1])  # (b, 8, 8, 6)
+        c2 = c2.reshape([b, -1, 1])  # (b, 384, 1)
+
+        c = paddle.concat((c1, c2), axis=1)  # (b, 896, 1)
+
+        r1 = self.regressor_8(x)  # (b, 32, 16, 16)
+        r1 = r1.transpose([0, 2, 3, 1])  # (b, 16, 16, 32)
+        r1 = r1.reshape([b, -1, 16])  # (b, 512, 16)
+
+        r2 = self.regressor_16(h)  # (b, 96, 8, 8)
+        r2 = r2.transpose([0, 2, 3, 1])  # (b, 8, 8, 96)
+        r2 = r2.reshape([b, -1, 16])  # (b, 384, 16)
+
+        r = paddle.concat((r1, r2), axis=1)  # (b, 896, 16)
+        return [r, c]
+
+    def load_weights(self, path):
+        self.load_dict(paddle.load(path))
+        self.eval()
+
+    def load_anchors(self, path):
+        self.anchors = paddle.to_tensor(np.load(path), dtype='float32')
+        assert len(self.anchors.shape) == 2
+        assert self.anchors.shape[0] == self.num_anchors
+        assert self.anchors.shape[1] == 4
+
+    def load_anchors_from_npy(self, arr):
+        self.anchors = paddle.to_tensor(arr, dtype='float32')
+        assert len(self.anchors.shape) == 2
+        assert self.anchors.shape[0] == self.num_anchors
+        assert self.anchors.shape[1] == 4
+
+    def _preprocess(self, x):
+        """Converts the image pixels to the range [-1, 1]."""
+        return x.astype('float32') / 127.5 - 1.0
+
+    def predict_on_image(self, img):
+        """Makes a prediction on a single image.
+
+        Arguments:
+            img: a NumPy array of shape (H, W, 3) or a Paddle tensor of
+                 shape (3, H, W). The image's height and width should be
+                 128 pixels.
+
+        Returns:
+            A tensor with face detections.
+        """
+        if isinstance(img, np.ndarray):
+            img = paddle.to_tensor(img).transpose((2, 0, 1))
+
+        return self.predict_on_batch(img.unsqueeze(0))[0]
+
+    def predict_on_batch(self, x):
+        """Makes a prediction on a batch of images.
+
+        Arguments:
+            x: a NumPy array of shape (b, H, W, 3) or a Paddle tensor of
+               shape (b, 3, H, W). The height and width should be 128 pixels.
+
+        Returns:
+            A list containing a tensor of face detections for each image in
+            the batch. If no faces are found for an image, returns a tensor
+            of shape (0, 17).
+
+        Each face detection is a Paddle tensor consisting of 17 numbers:
+            - ymin, xmin, ymax, xmax
+            - x,y-coordinates for the 6 keypoints
+            - confidence score
+        """
+        if isinstance(x, np.ndarray):
+            x = paddle.to_tensor(x).transpose((0, 3, 1, 2))
+
+        assert x.shape[1] == 3
+        assert x.shape[2] == 128
+        assert x.shape[3] == 128
+
+        x = self._preprocess(x)
+
+        with paddle.no_grad():
+            out = self.__call__(x)
+
+        detections = self._tensors_to_detections(out[0], out[1], self.anchors)
+
+        filtered_detections = []
+        for i in range(len(detections)):
+            faces = self._weighted_non_max_suppression(detections[i])
+            faces = paddle.stack(faces) if len(faces) > 0 else paddle.zeros(
+                (0, 17))
+            filtered_detections.append(faces)
+
+        return filtered_detections
+
+    def _tensors_to_detections(self, raw_box_tensor, raw_score_tensor, anchors):
+        """The output of the neural network is a tensor of shape (b, 896, 16)
+        containing the bounding box regressor predictions, as well as a tensor
+        of shape (b, 896, 1) with the classification confidences.
+
+        Returns a list of (num_detections, 17) tensors, one for each image in
+        the batch.
+        """
+        assert len(raw_box_tensor.shape) == 3
+        assert raw_box_tensor.shape[1] == self.num_anchors
+        assert raw_box_tensor.shape[2] == self.num_coords
+
+        assert len(raw_score_tensor.shape) == 3
+        assert raw_score_tensor.shape[1] == self.num_anchors
+        assert raw_score_tensor.shape[2] == self.num_classes
+
+        assert raw_box_tensor.shape[0] == raw_score_tensor.shape[0]
+
+        detection_boxes = self._decode_boxes(raw_box_tensor, anchors)
+
+        thresh = self.score_clipping_thresh
+        raw_score_tensor = raw_score_tensor.clip(-thresh, thresh)
+        detection_scores = F.sigmoid(raw_score_tensor).squeeze(axis=-1)
+
+        mask = detection_scores >= self.min_score_thresh
+        mask = mask.numpy()
+        detection_boxes = detection_boxes.numpy()
+        detection_scores = detection_scores.numpy()
+
+        output_detections = []
+        for i in range(raw_box_tensor.shape[0]):
+            boxes = paddle.to_tensor(detection_boxes[i, mask[i]])
+            scores = paddle.to_tensor(
+                detection_scores[i, mask[i]]).unsqueeze(axis=-1)
+            output_detections.append(paddle.concat((boxes, scores), axis=-1))
+
+        return output_detections
+
+    def _decode_boxes(self, raw_boxes, anchors):
+        """Converts the predictions into actual coordinates using
+        the anchor boxes. Processes the entire batch at once.
+        """
+        boxes = paddle.zeros_like(raw_boxes)
+
+        x_center = raw_boxes[:, :, 0] / self.x_scale * \
+            anchors[:, 2] + anchors[:, 0]
+        y_center = raw_boxes[:, :, 1] / self.y_scale * \
+            anchors[:, 3] + anchors[:, 1]
+
+        w = raw_boxes[:, :, 2] / self.w_scale * anchors[:, 2]
+        h = raw_boxes[:, :, 3] / self.h_scale * anchors[:, 3]
+
+        boxes[:, :, 0] = y_center - h / 2.  # ymin
+        boxes[:, :, 1] = x_center - w / 2.  # xmin
+        boxes[:, :, 2] = y_center + h / 2.  # ymax
+        boxes[:, :, 3] = x_center + w / 2.  # xmax
+
+        for k in range(6):
+            offset = 4 + k * 2
+            keypoint_x = raw_boxes[:, :, offset] / \
+                self.x_scale * anchors[:, 2] + anchors[:, 0]
+            keypoint_y = raw_boxes[:, :, offset + 1] / \
+                self.y_scale * anchors[:, 3] + anchors[:, 1]
+            boxes[:, :, offset] = keypoint_x
+            boxes[:, :, offset + 1] = keypoint_y
+
+        return boxes
+
+    def _weighted_non_max_suppression(self, detections):
+        """The alternative NMS method as mentioned in the BlazeFace paper.
+
+        The input detections should be a Tensor of shape (count, 17).
+        Returns a list of Paddle tensors, one for each detected face.
+        """
+        if len(detections) == 0:
+            return []
+
+        output_detections = []
+
+        # Sort the detections from highest to lowest score.
+        remaining = paddle.argsort(detections[:, 16], descending=True).numpy()
+        detections = detections.numpy()
+
+        while len(remaining) > 0:
+            detection = detections[remaining[0]]
+
+            first_box = detection[:4]
+            other_boxes = detections[remaining, :4]
+            ious = overlap_similarity(paddle.to_tensor(first_box),
+                                      paddle.to_tensor(other_boxes))
+
+            mask = ious > self.min_suppression_threshold
+            mask = mask.numpy()
+
+            overlapping = remaining[mask]
+            remaining = remaining[~mask]
+
+            weighted_detection = detection.copy()
+            if len(overlapping) > 1:
+                coordinates = detections[overlapping, :16]
+                scores = detections[overlapping, 16:17]
+                total_score = scores.sum()
+                weighted = (coordinates * scores).sum(axis=0) / total_score
+                weighted_detection[:16] = weighted
+                weighted_detection[16] = total_score / len(overlapping)
+
+            output_detections.append(paddle.to_tensor(weighted_detection))
+
+        return output_detections
+
+
+def intersect(box_a, box_b):
+    """Compute the area of intersect between box_a and box_b.
+    Args:
+        box_a: (tensor) bounding boxes, Shape: [A,4].
+        box_b: (tensor) bounding boxes, Shape: [B,4].
+    Return:
+        (tensor) intersection area, Shape: [A,B].
+    """
+    A = box_a.shape[0]
+    B = box_b.shape[0]
+    max_xy = paddle.minimum(box_a[:, 2:].unsqueeze(1).expand((A, B, 2)),
+                            box_b[:, 2:].unsqueeze(0).expand((A, B, 2)))
+    min_xy = paddle.maximum(box_a[:, :2].unsqueeze(1).expand((A, B, 2)),
+                            box_b[:, :2].unsqueeze(0).expand((A, B, 2)))
+    inter = paddle.clip((max_xy - min_xy), min=0)
+    return inter[:, :, 0] * inter[:, :, 1]
+
+
+def jaccard(box_a, box_b):
+    """Compute the jaccard overlap of two sets of boxes.
+    Args:
+        box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
+        box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
+    Return:
+        jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
+    """
+    inter = intersect(box_a, box_b)
+    area_a = ((box_a[:, 2] - box_a[:, 0]) *
+              (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter)
+    area_b = ((box_b[:, 2] - box_b[:, 0]) *
+              (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter)
+    union = area_a + area_b - inter
+    return inter / union
+
+
+def overlap_similarity(box, other_boxes):
+    """Computes the IOU between a bounding box and set of other boxes."""
+    return jaccard(box.unsqueeze(0), other_boxes).squeeze(0)
+
+
+def init_model():
+    net = BlazeFace()
+    net.load_weights("blazeface.pdparams")
+    net.load_anchors("anchors.npy")
+
+    net.min_score_thresh = 0.75
+    net.min_suppression_threshold = 0.3
+
+    return net
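A minimal inference sketch for the model above, assuming blazeface.pdparams and anchors.npy are present locally as init_model() expects (the zero image is a stand-in):

    import numpy as np
    from ppgan.faceutils.face_detection.detection.blazeface.net_blazeface import init_model

    net = init_model()                              # loads local weights and anchors
    img = np.zeros((128, 128, 3), dtype='float32')  # the network expects 128x128 input
    detections = net.predict_on_image(img)          # (N, 17): box, 6 keypoints, score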
diff --git a/ppgan/faceutils/face_detection/detection/blazeface/utils.py b/ppgan/faceutils/face_detection/detection/blazeface/utils.py
new file mode 100644
index 0000000..a5691ce
--- /dev/null
+++ b/ppgan/faceutils/face_detection/detection/blazeface/utils.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import numpy as np
+
+
+def image_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
+    dim = None
+    (h, w) = image.shape[:2]
+
+    if width is None and height is None:
+        return image
+
+    if width is None:
+        r = height / float(h)
+        dim = (int(w * r), height)
+    else:
+        r = width / float(w)
+        dim = (width, int(h * r))
+
+    resized = cv2.resize(image, dim, interpolation=inter)
+
+    return resized
+
+
+def resize_and_crop_image(image, dim):
+    if image.shape[0] > image.shape[1]:
+        img = image_resize(image, width=dim)
+        yshift, xshift = (image.shape[0] - image.shape[1]) // 2, 0
+        y_start = (img.shape[0] - img.shape[1]) // 2
+        y_end = y_start + dim
+        return img[y_start:y_end, :, :], (xshift, yshift)
+    else:
+        img = image_resize(image, height=dim)
+        yshift, xshift = 0, (image.shape[1] - image.shape[0]) // 2
+        x_start = (img.shape[1] - img.shape[0]) // 2
+        x_end = x_start + dim
+        return img[:, x_start:x_end, :], (xshift, yshift)
+
+
+def resize_and_crop_batch(frames, dim):
+    smframes = []
+    xshift, yshift = 0, 0
+    for i in range(len(frames)):
+        smframe, (xshift, yshift) = resize_and_crop_image(frames[i], dim)
+        smframes.append(smframe)
+    smframes = np.stack(smframes)
+    return smframes, (xshift, yshift)
--
GitLab
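As a sanity check on the crop helper above: a landscape frame is resized so its short side is 128, center-cropped, and the offset is reported in original-frame pixels. Note that resize_and_crop_batch returns only the last frame's offset, so all frames in a batch are assumed to share one size. Illustrative shapes:

    import numpy as np
    from ppgan.faceutils.face_detection.detection.blazeface.utils import resize_and_crop_image

    frame = np.zeros((720, 1280, 3), dtype='uint8')
    crop, (xshift, yshift) = resize_and_crop_image(frame, 128)
    print(crop.shape)      # (128, 128, 3)
    print(xshift, yshift)  # 280 0 -- the shift detect() uses to map boxes back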