From 2a09260714e4ee6d87624b34bcee8bd580545d52 Mon Sep 17 00:00:00 2001 From: lijianshe02 <48898730+lijianshe02@users.noreply.github.com> Date: Tue, 1 Dec 2020 18:53:11 +0800 Subject: [PATCH] fix auto downloading problem of dlib weight and add s3fd face detector (#98) * fix auto downloading problem of dlib weight and add s3fd face detector --- ppgan/faceutils/dlibutils/dlib_utils.py | 16 +- ppgan/faceutils/face_detection/__init__.py | 1 + ppgan/faceutils/face_detection/api.py | 86 ++++++++ .../face_detection/detection/__init__.py | 1 + .../face_detection/detection/core.py | 145 +++++++++++++ .../face_detection/detection/sfd/__init__.py | 1 + .../face_detection/detection/sfd/bbox.py | 147 +++++++++++++ .../face_detection/detection/sfd/detect.py | 104 ++++++++++ .../face_detection/detection/sfd/net_s3fd.py | 195 ++++++++++++++++++ .../detection/sfd/sfd_detector.py | 77 +++++++ ppgan/faceutils/face_detection/utils.py | 78 +++++++ ppgan/models/discriminators/syncnet.py | 145 +++++++++++++ ppgan/modules/conv.py | 68 ++++++ requirments.txt => requirements.txt | 0 14 files changed, 1054 insertions(+), 10 deletions(-) create mode 100644 ppgan/faceutils/face_detection/__init__.py create mode 100644 ppgan/faceutils/face_detection/api.py create mode 100644 ppgan/faceutils/face_detection/detection/__init__.py create mode 100644 ppgan/faceutils/face_detection/detection/core.py create mode 100644 ppgan/faceutils/face_detection/detection/sfd/__init__.py create mode 100644 ppgan/faceutils/face_detection/detection/sfd/bbox.py create mode 100644 ppgan/faceutils/face_detection/detection/sfd/detect.py create mode 100644 ppgan/faceutils/face_detection/detection/sfd/net_s3fd.py create mode 100644 ppgan/faceutils/face_detection/detection/sfd/sfd_detector.py create mode 100644 ppgan/faceutils/face_detection/utils.py create mode 100644 ppgan/models/discriminators/syncnet.py create mode 100644 ppgan/modules/conv.py rename requirments.txt => requirements.txt (100%) diff --git a/ppgan/faceutils/dlibutils/dlib_utils.py b/ppgan/faceutils/dlibutils/dlib_utils.py index 0554198..eca5ba6 100644 --- a/ppgan/faceutils/dlibutils/dlib_utils.py +++ b/ppgan/faceutils/dlibutils/dlib_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import os.path as osp import numpy as np @@ -19,17 +20,9 @@ from PIL import Image import dlib import cv2 from ..image import resize_by_max -from ppgan.utils.logger import get_logger -logger = get_logger() +from paddle.utils.download import get_weights_path_from_url -detector = dlib.get_frontal_face_detector() - -try: - predictor = dlib.shape_predictor( - osp.split(osp.realpath(__file__))[0] + '/lms.dat') -except Exception as e: - predictor = None - logger.warning(e) +LANDMARKS_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lms.dat' def detect(image: Image): @@ -37,6 +30,7 @@ def detect(image: Image): h, w = image.shape[:2] image = resize_by_max(image, 361) actual_h, actual_w = image.shape[:2] + detector = dlib.get_frontal_face_detector() faces_on_small = detector(image, 1) faces = dlib.rectangles() for face in faces_on_small: @@ -129,6 +123,8 @@ def crop_by_image_size(image: Image, face): def landmarks(image: Image, face): + weight_path = get_weights_path_from_url(LANDMARKS_WEIGHT_URL) + predictor = dlib.shape_predictor(weight_path) shape = predictor(np.asarray(image), face).parts() return np.array([[p.y, p.x] for p in shape]) diff --git a/ppgan/faceutils/face_detection/__init__.py b/ppgan/faceutils/face_detection/__init__.py new file mode 100644 index 0000000..7af5231 --- /dev/null +++ b/ppgan/faceutils/face_detection/__init__.py @@ -0,0 +1 @@ +from .api import FaceAlignment, LandmarksType, NetworkSize diff --git a/ppgan/faceutils/face_detection/api.py b/ppgan/faceutils/face_detection/api.py new file mode 100644 index 0000000..7b61606 --- /dev/null +++ b/ppgan/faceutils/face_detection/api.py @@ -0,0 +1,86 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from enum import Enum +import numpy as np +import cv2 + +from .utils import * +import sys + + +class LandmarksType(Enum): + """Enum class defining the type of landmarks to detect. + + ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face + ``_2halfD`` - this points represent the projection of the 3D points into 3D + ``_3D`` - detect the points ``(x,y,z)``` in a 3D space + + """ + _2D = 1 + _2halfD = 2 + _3D = 3 + + +class NetworkSize(Enum): + # TINY = 1 + # SMALL = 2 + # MEDIUM = 3 + LARGE = 4 + + def __new__(cls, value): + member = object.__new__(cls) + member._value_ = value + return member + + def __int__(self): + return self.value + + +class FaceAlignment: + def __init__(self, + landmarks_type, + network_size=NetworkSize.LARGE, + flip_input=False, + face_detector='sfd', + verbose=False): + self.flip_input = flip_input + self.landmarks_type = landmarks_type + self.verbose = verbose + + network_size = int(network_size) + + # Get the face detector + face_detector_module = __import__( + 'face_detection.detection.' + face_detector, globals(), locals(), + [face_detector], 0) + self.face_detector = face_detector_module.FaceDetector(verbose=verbose) + + def get_detections_for_batch(self, images): + images = images[..., ::-1] + detected_faces = self.face_detector.detect_from_batch(images.copy()) + results = [] + + for i, d in enumerate(detected_faces): + if len(d) == 0: + results.append(None) + continue + d = d[0] + d = np.clip(d, 0, None) + + x1, y1, x2, y2 = map(int, d[:-1]) + results.append((x1, y1, x2, y2)) + + return results diff --git a/ppgan/faceutils/face_detection/detection/__init__.py b/ppgan/faceutils/face_detection/detection/__init__.py new file mode 100644 index 0000000..9dbd01b --- /dev/null +++ b/ppgan/faceutils/face_detection/detection/__init__.py @@ -0,0 +1 @@ +from .core import FaceDetector diff --git a/ppgan/faceutils/face_detection/detection/core.py b/ppgan/faceutils/face_detection/detection/core.py new file mode 100644 index 0000000..b9988f8 --- /dev/null +++ b/ppgan/faceutils/face_detection/detection/core.py @@ -0,0 +1,145 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import glob +from tqdm import tqdm +import numpy as np +import paddle +import cv2 + + +class FaceDetector(object): + """An abstract class representing a face detector. + + Any other face detection implementation must subclass it. All subclasses + must implement ``detect_from_image``, that return a list of detected + bounding boxes. Optionally, for speed considerations detect from path is + recommended. + """ + def __init__(self, verbose): + self.verbose = verbose + + def detect_from_image(self, tensor_or_path): + """Detects faces in a given image. + + This function detects the faces present in a provided BGR(usually) + image. The input can be either the image itself or the path to it. + + Args: + tensor_or_path {numpy.ndarray, paddle.tensor or string} -- the path + to an image or the image itself. + + Example:: + + >>> path_to_image = 'data/image_01.jpg' + ... detected_faces = detect_from_image(path_to_image) + [A list of bounding boxes (x1, y1, x2, y2)] + >>> image = cv2.imread(path_to_image) + ... detected_faces = detect_from_image(image) + [A list of bounding boxes (x1, y1, x2, y2)] + + """ + raise NotImplementedError + + def detect_from_directory(self, + path, + extensions=['.jpg', '.png'], + recursive=False, + show_progress_bar=True): + """Detects faces from all the images present in a given directory. + + Ars: + path {string} -- a string containing a path that points to the folder containing the images + extensions {list} -- list of string containing the extensions to be + consider in the following format: ``.extension_name`` (default: + {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the + folder recursively (default: {False}) show_progress_bar {bool} -- + display a progressbar (default: {True}) + + Example: + >>> directory = 'data' + ... detected_faces = detect_from_directory(directory) + {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]} + + """ + if self.verbose: + logger = logging.getLogger(__name__) + + if len(extensions) == 0: + if self.verbose: + logger.error( + "Expected at list one extension, but none was received.") + raise ValueError + + if self.verbose: + logger.info("Constructing the list of images.") + additional_pattern = '/**/*' if recursive else '/*' + files = [] + for extension in extensions: + files.extend( + glob.glob(path + additional_pattern + extension, + recursive=recursive)) + + if self.verbose: + logger.info("Finished searching for images. %s images found", + len(files)) + logger.info("Preparing to run the detection.") + + predictions = {} + for image_path in tqdm(files, disable=not show_progress_bar): + if self.verbose: + logger.info("Running the face detector on image: %s", + image_path) + predictions[image_path] = self.detect_from_image(image_path) + + if self.verbose: + logger.info("The detector was successfully run on all %s images", + len(files)) + + return predictions + + @property + def reference_scale(self): + raise NotImplementedError + + @property + def reference_x_shift(self): + raise NotImplementedError + + @property + def reference_y_shift(self): + raise NotImplementedError + + @staticmethod + def tensor_or_path_to_ndarray(tensor_or_path, rgb=True): + """Convert path (represented as a string) or paddle.tensor to a numpy.ndarray + + Args: + tensor_or_path {numpy.ndarray, paddle.tensor or string} -- path to the image, or the image itself + """ + if isinstance(tensor_or_path, str): + return cv2.imread(tensor_or_path) if not rgb else cv2.imread( + tensor_or_path)[..., ::-1] + elif isinstance( + tensor_or_path, + (paddle.fluid.framework.Variable, paddle.fluid.core.VarBase)): + # Call cpu in case its coming from cuda + return tensor_or_path.numpy()[ + ..., ::-1].copy() if not rgb else tensor_or_path.numpy() + elif isinstance(tensor_or_path, np.ndarray): + return tensor_or_path[ + ..., ::-1].copy() if not rgb else tensor_or_path + else: + raise TypeError diff --git a/ppgan/faceutils/face_detection/detection/sfd/__init__.py b/ppgan/faceutils/face_detection/detection/sfd/__init__.py new file mode 100644 index 0000000..fef5689 --- /dev/null +++ b/ppgan/faceutils/face_detection/detection/sfd/__init__.py @@ -0,0 +1 @@ +from .sfd_detector import SFDDetector as FaceDetector diff --git a/ppgan/faceutils/face_detection/detection/sfd/bbox.py b/ppgan/faceutils/face_detection/detection/sfd/bbox.py new file mode 100644 index 0000000..02b21a4 --- /dev/null +++ b/ppgan/faceutils/face_detection/detection/sfd/bbox.py @@ -0,0 +1,147 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import os +import sys +import cv2 +import random +import datetime +import time +import math +import argparse +import numpy as np +import paddle + +try: + from iou import IOU +except BaseException: + # IOU cython speedup 10x + def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2): + sa = abs((ax2 - ax1) * (ay2 - ay1)) + sb = abs((bx2 - bx1) * (by2 - by1)) + x1, y1 = max(ax1, bx1), max(ay1, by1) + x2, y2 = min(ax2, bx2), min(ay2, by2) + w = x2 - x1 + h = y2 - y1 + if w < 0 or h < 0: + return 0.0 + else: + return 1.0 * w * h / (sa + sb - w * h) + + +def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh): + xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1 + dx, dy = (xc - axc) / aww, (yc - ayc) / ahh + dw, dh = math.log(ww / aww), math.log(hh / ahh) + return dx, dy, dw, dh + + +def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh): + xc, yc = dx * aww + axc, dy * ahh + ayc + ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh + x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2 + return x1, y1, x2, y2 + + +def nms(dets, thresh): + if 0 == len(dets): + return [] + x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:,2], \ + dets[:, 3], dets[:,4] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1, yy1 = np.maximum(x1[i], + x1[order[1:]]), np.maximum(y1[i], y1[order[1:]]) + xx2, yy2 = np.minimum(x2[i], + x2[order[1:]]), np.minimum(y2[i], y2[order[1:]]) + + w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1) + ovr = w * h / (areas[i] + areas[order[1:]] - w * h) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, 2:]) + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = paddle.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return paddle.concat([g_cxcy, g_wh], 1) # [num_priors,4] + + +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = paddle.concat( + (priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * paddle.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def batch_decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = paddle.concat( + (priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, 2:] * paddle.exp(loc[:, :, 2:] * variances[1])), 2) + boxes[:, :, :2] -= boxes[:, :, 2:] / 2 + boxes[:, :, 2:] += boxes[:, :, :2] + return boxes diff --git a/ppgan/faceutils/face_detection/detection/sfd/detect.py b/ppgan/faceutils/face_detection/detection/sfd/detect.py new file mode 100644 index 0000000..b5493cc --- /dev/null +++ b/ppgan/faceutils/face_detection/detection/sfd/detect.py @@ -0,0 +1,104 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn.functional as F + +import os +import sys +import cv2 +import random +import datetime +import math +import argparse +import numpy as np + +import scipy.io as sio +import zipfile +from .net_s3fd import s3fd +from .bbox import * + + +def detect(net, img): + img = img - np.array([104, 117, 123]) + img = img.transpose(2, 0, 1) + img = img.reshape((1, ) + img.shape) + + img = paddle.to_tensor(img).astype('float32') + BB, CC, HH, WW = img.shape + with paddle.no_grad(): + olist = net(img) + + bboxlist = [] + for i in range(len(olist) // 2): + olist[i * 2] = F.softmax(olist[i * 2], axis=1) + for i in range(len(olist) // 2): + ocls, oreg = olist[i * 2], olist[i * 2 + 1] + FB, FC, FH, FW = ocls.shape # feature map size + stride = 2**(i + 2) # 4,8,16,32,64,128 + anchor = stride * 4 + poss = zip(*np.where(ocls.numpy()[:, 1, :, :] > 0.05)) + for Iindex, hindex, windex in poss: + axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride + score = ocls.numpy()[0, 1, hindex, windex] + loc = oreg.numpy()[0, :, hindex, windex].reshape(1, 4) + priors = paddle.to_tensor( + [[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]) + variances = [0.1, 0.2] + box = decode(paddle.to_tensor(loc), priors, variances) + x1, y1, x2, y2 = box[0] * 1.0 + bboxlist.append([x1, y1, x2, y2, score]) + bboxlist = np.array(bboxlist) + if 0 == len(bboxlist): + bboxlist = np.zeros((1, 5)) + + return bboxlist + + +def batch_detect(net, imgs): + imgs = imgs - np.array([104, 117, 123]) + imgs = imgs.transpose(0, 3, 1, 2) + + imgs = paddle.to_tensor(imgs).astype('float32') + BB, CC, HH, WW = imgs.shape + with paddle.no_grad(): + olist = net(imgs) + + bboxlist = [] + for i in range(len(olist) // 2): + olist[i * 2] = F.softmax(olist[i * 2], axis=1) + for i in range(len(olist) // 2): + ocls, oreg = olist[i * 2], olist[i * 2 + 1] + FB, FC, FH, FW = ocls.shape # feature map size + stride = 2**(i + 2) # 4,8,16,32,64,128 + anchor = stride * 4 + poss = zip(*np.where(ocls.numpy()[:, 1, :, :] > 0.05)) + for Iindex, hindex, windex in poss: + axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride + score = ocls.numpy()[:, 1, hindex, windex] + loc = oreg.numpy()[:, :, hindex, windex].reshape(BB, 1, 4) + priors = paddle.to_tensor( + [[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, + stride * 4 / 1.0]]).reshape([1, 1, 4]) + variances = [0.1, 0.2] + box = batch_decode(paddle.to_tensor(loc), priors, variances) + box = box[:, 0] * 1.0 + bboxlist.append( + paddle.concat([box, paddle.to_tensor(score).unsqueeze(1)], + 1).numpy()) + bboxlist = np.array(bboxlist) + if 0 == len(bboxlist): + bboxlist = np.zeros((1, BB, 5)) + + return bboxlist diff --git a/ppgan/faceutils/face_detection/detection/sfd/net_s3fd.py b/ppgan/faceutils/face_detection/detection/sfd/net_s3fd.py new file mode 100644 index 0000000..aa5a7db --- /dev/null +++ b/ppgan/faceutils/face_detection/detection/sfd/net_s3fd.py @@ -0,0 +1,195 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class L2Norm(nn.Layer): + def __init__(self, n_channels, scale=1.0): + super(L2Norm, self).__init__() + self.n_channels = n_channels + self.scale = scale + self.eps = 1e-10 + self.weight = paddle.create_parameter(shape=[self.n_channels], + dtype='float32') + self.weight.set_value(paddle.zeros([self.n_channels]) + self.scale) + + def forward(self, x): + norm = x.pow(2).sum(axis=1, keepdim=True).sqrt() + self.eps + x = x / norm * self.weight.reshape([1, -1, 1, 1]) + return x + + +class s3fd(nn.Layer): + def __init__(self): + super(s3fd, self).__init__() + self.conv1_1 = nn.Conv2D(3, 64, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2D(64, 64, kernel_size=3, stride=1, padding=1) + + self.conv2_1 = nn.Conv2D(64, 128, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2D(128, 128, kernel_size=3, stride=1, padding=1) + + self.conv3_1 = nn.Conv2D(128, 256, kernel_size=3, stride=1, padding=1) + self.conv3_2 = nn.Conv2D(256, 256, kernel_size=3, stride=1, padding=1) + self.conv3_3 = nn.Conv2D(256, 256, kernel_size=3, stride=1, padding=1) + + self.conv4_1 = nn.Conv2D(256, 512, kernel_size=3, stride=1, padding=1) + self.conv4_2 = nn.Conv2D(512, 512, kernel_size=3, stride=1, padding=1) + self.conv4_3 = nn.Conv2D(512, 512, kernel_size=3, stride=1, padding=1) + + self.conv5_1 = nn.Conv2D(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2D(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_3 = nn.Conv2D(512, 512, kernel_size=3, stride=1, padding=1) + + self.fc6 = nn.Conv2D(512, 1024, kernel_size=3, stride=1, padding=3) + self.fc7 = nn.Conv2D(1024, 1024, kernel_size=1, stride=1, padding=0) + + self.conv6_1 = nn.Conv2D(1024, 256, kernel_size=1, stride=1, padding=0) + self.conv6_2 = nn.Conv2D(256, 512, kernel_size=3, stride=2, padding=1) + + self.conv7_1 = nn.Conv2D(512, 128, kernel_size=1, stride=1, padding=0) + self.conv7_2 = nn.Conv2D(128, 256, kernel_size=3, stride=2, padding=1) + + self.conv3_3_norm = L2Norm(256, scale=10) + self.conv4_3_norm = L2Norm(512, scale=8) + self.conv5_3_norm = L2Norm(512, scale=5) + + self.conv3_3_norm_mbox_conf = nn.Conv2D(256, + 4, + kernel_size=3, + stride=1, + padding=1) + self.conv3_3_norm_mbox_loc = nn.Conv2D(256, + 4, + kernel_size=3, + stride=1, + padding=1) + self.conv4_3_norm_mbox_conf = nn.Conv2D(512, + 2, + kernel_size=3, + stride=1, + padding=1) + self.conv4_3_norm_mbox_loc = nn.Conv2D(512, + 4, + kernel_size=3, + stride=1, + padding=1) + self.conv5_3_norm_mbox_conf = nn.Conv2D(512, + 2, + kernel_size=3, + stride=1, + padding=1) + self.conv5_3_norm_mbox_loc = nn.Conv2D(512, + 4, + kernel_size=3, + stride=1, + padding=1) + + self.fc7_mbox_conf = nn.Conv2D(1024, + 2, + kernel_size=3, + stride=1, + padding=1) + self.fc7_mbox_loc = nn.Conv2D(1024, + 4, + kernel_size=3, + stride=1, + padding=1) + self.conv6_2_mbox_conf = nn.Conv2D(512, + 2, + kernel_size=3, + stride=1, + padding=1) + self.conv6_2_mbox_loc = nn.Conv2D(512, + 4, + kernel_size=3, + stride=1, + padding=1) + self.conv7_2_mbox_conf = nn.Conv2D(256, + 2, + kernel_size=3, + stride=1, + padding=1) + self.conv7_2_mbox_loc = nn.Conv2D(256, + 4, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + h = F.relu(self.conv1_1(x)) + h = F.relu(self.conv1_2(h)) + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv2_1(h)) + h = F.relu(self.conv2_2(h)) + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv3_1(h)) + h = F.relu(self.conv3_2(h)) + h = F.relu(self.conv3_3(h)) + f3_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv4_1(h)) + h = F.relu(self.conv4_2(h)) + h = F.relu(self.conv4_3(h)) + f4_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv5_1(h)) + h = F.relu(self.conv5_2(h)) + h = F.relu(self.conv5_3(h)) + f5_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.fc6(h)) + h = F.relu(self.fc7(h)) + ffc7 = h + h = F.relu(self.conv6_1(h)) + h = F.relu(self.conv6_2(h)) + f6_2 = h + h = F.relu(self.conv7_1(h)) + h = F.relu(self.conv7_2(h)) + f7_2 = h + + f3_3 = self.conv3_3_norm(f3_3) + f4_3 = self.conv4_3_norm(f4_3) + f5_3 = self.conv5_3_norm(f5_3) + + cls1 = self.conv3_3_norm_mbox_conf(f3_3) + reg1 = self.conv3_3_norm_mbox_loc(f3_3) + cls2 = self.conv4_3_norm_mbox_conf(f4_3) + reg2 = self.conv4_3_norm_mbox_loc(f4_3) + cls3 = self.conv5_3_norm_mbox_conf(f5_3) + reg3 = self.conv5_3_norm_mbox_loc(f5_3) + cls4 = self.fc7_mbox_conf(ffc7) + reg4 = self.fc7_mbox_loc(ffc7) + cls5 = self.conv6_2_mbox_conf(f6_2) + reg5 = self.conv6_2_mbox_loc(f6_2) + cls6 = self.conv7_2_mbox_conf(f7_2) + reg6 = self.conv7_2_mbox_loc(f7_2) + + # max-out background label + chunk = paddle.chunk(cls1, 4, 1) + tmp_max = paddle.where(chunk[0] > chunk[1], chunk[0], chunk[1]) + bmax = paddle.where(tmp_max > chunk[2], tmp_max, chunk[2]) + cls1 = paddle.concat([bmax, chunk[3]], axis=1) + + return [ + cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, + reg6 + ] diff --git a/ppgan/faceutils/face_detection/detection/sfd/sfd_detector.py b/ppgan/faceutils/face_detection/detection/sfd/sfd_detector.py new file mode 100644 index 0000000..4db2d28 --- /dev/null +++ b/ppgan/faceutils/face_detection/detection/sfd/sfd_detector.py @@ -0,0 +1,77 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import cv2 +from paddle.utils.download import get_weights_path_from_url + +from ..core import FaceDetector + +from .net_s3fd import s3fd +from .bbox import * +from .detect import * + +models_urls = { + 's3fd': 'https://paddlegan.bj.bcebos.com/models/s3fd_paddle.pdparams', +} + + +class SFDDetector(FaceDetector): + def __init__(self, path_to_detector=None, verbose=False): + super(SFDDetector, self).__init__(verbose) + + # Initialise the face detector + if path_to_detector is None: + model_weights_path = get_weights_path_from_url( + models_urls['s3fd'], cur_path) + model_weights = paddle.load(model_weights_path) + else: + model_weights = paddle.load(path_to_detector) + + self.face_detector = s3fd() + self.face_detector.load_dict(model_weights) + self.face_detector.eval() + + def detect_from_image(self, tensor_or_path): + image = self.tensor_or_path_to_ndarray(tensor_or_path) + + bboxlist = detect(self.face_detector, image) + keep = nms(bboxlist, 0.3) + bboxlist = bboxlist[keep, :] + bboxlist = [x for x in bboxlist if x[-1] > 0.5] + + return bboxlist + + def detect_from_batch(self, images): + bboxlists = batch_detect(self.face_detector, images) + keeps = [ + nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1]) + ] + bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)] + bboxlists = [[x for x in bboxlist if x[-1] > 0.5] + for bboxlist in bboxlists] + + return bboxlists + + @property + def reference_scale(self): + return 195 + + @property + def reference_x_shift(self): + return 0 + + @property + def reference_y_shift(self): + return 0 diff --git a/ppgan/faceutils/face_detection/utils.py b/ppgan/faceutils/face_detection/utils.py new file mode 100644 index 0000000..6590f96 --- /dev/null +++ b/ppgan/faceutils/face_detection/utils.py @@ -0,0 +1,78 @@ +import time +import paddle +import math +import numpy as np +import cv2 + + +def transform(point, center, scale, resolution, invert=False): + """Generate and affine transformation matrix. + + Given a set of points, a center, a scale and a targer resolution, the + function generates and affine transformation matrix. If invert is ``True`` + it will produce the inverse transformation. + + Args: + point {paddle.tensor} -- the input 2D point + center {paddle.tensor or numpy.array} -- the center around which to perform the transformations + scale {float} -- the scale of the face/object + resolution {float} -- the output resolution + invert {bool} -- define wherever the function should produce the direct or the + inverse transformation matrix (default: {False}) + """ + _pt = paddle.ones([3]) + _pt[0] = point[0] + _pt[1] = point[1] + + h = 200.0 * scale + t = paddle.eye(3) + t[0, 0] = resolution / h + t[1, 1] = resolution / h + t[0, 2] = resolution * (-center[0] / h + 0.5) + t[1, 2] = resolution * (-center[1] / h + 0.5) + + if invert: + t = paddle.inverse(t) + + new_point = (paddle.matmul(t, _pt))[0:2] + + return new_point.astype('int32') + + +def crop(image, center, scale, resolution=256.0): + """Center crops an image or set of heatmaps + + Args: + image {numpy.array} -- an rgb image + center {numpy.array} -- the center of the object, usually the same as of the bounding box + scale {float} -- scale of the face + resolution {float} -- the size of the output cropped image (default: {256.0}) + + """ + """ Crops the image around the center. Input is expected to be an np.ndarray """ + ul = transform([1, 1], center, scale, resolution, True) + ul = ul.numpy() + br = transform([resolution, resolution], center, scale, resolution, True) + br = br.numpy() + if image.ndim > 2: + newDim = np.array([br[1] - ul[1], br[0] - ul[0], image.shape[2]], + dtype=np.int32) + newImg = np.zeros(newDim, dtype=np.uint8) + else: + newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int) + newImg = np.zeros(newDim, dtype=np.uint8) + ht = image.shape[0] + wd = image.shape[1] + newX = np.array( + [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32) + newY = np.array( + [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32) + oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32) + oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32) + newImg[newY[0] - 1:newY[1], + newX[0] - 1:newX[1]] = image[oldY[0] - 1:oldY[1], + oldX[0] - 1:oldX[1], :] + newImg = cv2.resize(newImg, + dsize=(int(resolution), int(resolution)), + interpolation=cv2.INTER_LINEAR) + return newImg diff --git a/ppgan/models/discriminators/syncnet.py b/ppgan/models/discriminators/syncnet.py new file mode 100644 index 0000000..d07b4fd --- /dev/null +++ b/ppgan/models/discriminators/syncnet.py @@ -0,0 +1,145 @@ +import paddle +from paddle import nn +from paddle.nn import functional as F +from ...modules.conv import ConvBNRelu + + +class SyncNetColor(nn.Layer): + def __init__(self): + super(SyncNetColor, self).__init__() + + self.face_encoder = nn.Sequential( + ConvBNRelu(15, 32, kernel_size=(7, 7), stride=1, padding=3), + ConvBNRelu(32, 64, kernel_size=5, stride=(1, 2), padding=1), + ConvBNRelu(64, + 64, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(64, + 64, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(64, 128, kernel_size=3, stride=2, padding=1), + ConvBNRelu(128, + 128, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(128, + 128, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(128, + 128, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(128, 256, kernel_size=3, stride=2, padding=1), + ConvBNRelu(256, + 256, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(256, + 256, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(256, 512, kernel_size=3, stride=2, padding=1), + ConvBNRelu(512, + 512, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(512, + 512, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(512, 512, kernel_size=3, stride=2, padding=1), + ConvBNRelu(512, 512, kernel_size=3, stride=1, padding=0), + ConvBNRelu(512, 512, kernel_size=1, stride=1, padding=0), + ) + + self.audio_encoder = nn.Sequential( + ConvBNRelu(1, 32, kernel_size=3, stride=1, padding=1), + ConvBNRelu(32, + 32, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(32, + 32, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(32, 64, kernel_size=3, stride=(3, 1), padding=1), + ConvBNRelu(64, + 64, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(64, + 64, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(64, 128, kernel_size=3, stride=3, padding=1), + ConvBNRelu(128, + 128, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(128, + 128, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(128, 256, kernel_size=3, stride=(3, 2), padding=1), + ConvBNRelu(256, + 256, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(256, + 256, + kernel_size=3, + stride=1, + padding=1, + residual=True), + ConvBNRelu(256, 512, kernel_size=3, stride=1, padding=0), + ConvBNRelu(512, 512, kernel_size=1, stride=1, padding=0), + ) + + def forward(self, audio_sequences, + face_sequences): # audio_sequences := (B, dim, T) + face_embedding = self.face_encoder(face_sequences) + audio_embedding = self.audio_encoder(audio_sequences) + + audio_embedding = audio_embedding.reshape( + [audio_embedding.shape[0], -1]) + face_embedding = face_embedding.reshape([face_embedding.shape[0], -1]) + + audio_embedding = F.normalize(audio_embedding, p=2, axis=1) + face_embedding = F.normalize(face_embedding, p=2, axis=1) + + return audio_embedding, face_embedding diff --git a/ppgan/modules/conv.py b/ppgan/modules/conv.py new file mode 100644 index 0000000..a7f5b2c --- /dev/null +++ b/ppgan/modules/conv.py @@ -0,0 +1,68 @@ +import paddle +from paddle import nn +from paddle.nn import functional as F + + +class ConvBNRelu(nn.Layer): + def __init__(self, + cin, + cout, + kernel_size, + stride, + padding, + residual=False, + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.Conv2D(cin, cout, kernel_size, stride, padding), + nn.BatchNorm2D(cout)) + self.act = nn.ReLU() + self.residual = residual + + def forward(self, x): + out = self.conv_block(x) + if self.residual: + out += x + return self.act(out) + + +class NonNormConv2d(nn.Layer): + def __init__(self, + cin, + cout, + kernel_size, + stride, + padding, + residual=False, + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.Conv2D(cin, cout, kernel_size, stride, padding), ) + self.act = nn.LeakyReLU(0.01, inplace=True) + + def forward(self, x): + out = self.conv_block(x) + return self.act(out) + + +class Conv2dTranspseRelu(nn.Layer): + def __init__(self, + cin, + cout, + kernel_size, + stride, + padding, + output_padding=0, + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.ConvTranspose2D(cin, cout, kernel_size, stride, padding, + output_padding), nn.BatchNorm2D(cout)) + self.act = nn.ReLU() + + def forward(self, x): + out = self.conv_block(x) + return self.act(out) diff --git a/requirments.txt b/requirements.txt similarity index 100% rename from requirments.txt rename to requirements.txt -- GitLab