Unverified commit 07599eed authored by chenjian, committed by GitHub

add firstordermotion module (#1748)

* add firstordermotion module

* fix a bug

* fix according to review
Parent 94e2a74e
# first_order_motion
|Module Name|first_order_motion|
| :--- | :---: |
|Category|Image - Image Generation|
|Network|S3FD|
|Dataset|-|
|Fine-tuning supported|No|
|Module Size|343MB|
|Latest update date|2021-12-24|
|Data indicators|-|
## I. Basic Information
- ### Application Effect Display
  - Sample results:
<p align="center">
<img src="https://user-images.githubusercontent.com/22424850/147347145-1a7e84b6-2853-4490-8eaf-caf9cfdca79b.png" width = "40%" hspace='10'/>
<br />
Input image
<br />
<img src="https://user-images.githubusercontent.com/22424850/147347151-d6c5690b-00cd-433f-b82b-3f8bb90bc7bd.gif" width = "40%" hspace='10'/>
<br />
Input video
<br />
<img src="https://user-images.githubusercontent.com/22424850/147348127-52eb3f26-9b2c-49d5-a4a2-20a31f159802.gif" width = "40%" hspace='10'/>
<br />
Output video
<br />
</p>
- ### Module Introduction
  - First Order Motion performs image animation: given a source image and a driving video as input, the person in the source image is animated to perform the movements of the person in the driving video.
## II. Installation
- ### 1. Environmental Dependence
  - paddlepaddle >= 2.1.0
  - paddlehub >= 2.1.0 | [How to install PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)
- ### 2. Installation
- ```shell
$ hub install first_order_motion
```
  - In case of any problems during installation, please refer to: [Windows Quickstart](../../../../docs/docs_ch/get_start/windows_quickstart.md)
  | [Linux Quickstart](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [MacOS Quickstart](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## III. Module API Prediction
- ### 1. Command line Prediction
- ```shell
$ hub run first_order_motion --source_image "/PATH/TO/IMAGE" --driving_video "/PATH/TO/VIDEO" --use_gpu
```
  - Invoke the video-driven generation model through the command line. For more information, please refer to [PaddleHub Command Line Instruction](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2. Prediction Code Example
- ```python
import paddlehub as hub
module = hub.Module(name="first_order_motion")
module.generate(source_image="/PATH/TO/IMAGE", driving_video="/PATH/TO/VIDEO", ratio=0.4, image_size=256, output_dir='./motion_driving_result/', filename='result.mp4', use_gpu=False)
```
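  - With the default arguments above, the generated video is written to `./motion_driving_result/result.mp4` (`output_dir` joined with `filename`).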
- ### 3. API
- ```python
generate(self, source_image=None, driving_video=None, ratio=0.4, image_size=256, output_dir='./motion_driving_result/', filename='result.mp4', use_gpu=False)
```
  - Video-driven generation API.
  - **Parameters**
    - source_image (str): source image. Images with a single person or multiple persons are supported; the expressions and movements of the person in the driving video will be transferred to the person(s) in this image.
    - driving_video (str): driving video; the expressions and movements of the person in it are the motions to be transferred.
    - ratio (float): proportion of the source image occupied by the pasted-back generated face region. Tune it according to the output quality, especially when several faces are close to each other (see the sketch after this list). Default is 0.4; the suggested range is [0.4, 0.5].
    - image_size (int): size of the face image, 256 by default; can be set to 512.
    - output_dir (str): directory in which to save the result.
    - filename (str): filename of the saved result.
    - use_gpu (bool): whether to use GPU.
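  - For group photos where the pasted-back faces sit close together, raising `ratio` toward 0.5 can reduce clipping between neighbouring faces. A minimal sketch (the paths are placeholders):
  - ```python
    import paddlehub as hub

    module = hub.Module(name="first_order_motion")
    # ratio=0.5 enlarges the circular region used to paste each generated
    # face back into the source image
    module.generate(
        source_image="/PATH/TO/GROUP_PHOTO",
        driving_video="/PATH/TO/VIDEO",
        ratio=0.5,
        image_size=256,
        output_dir='./motion_driving_result/',
        filename='result.mp4',
        use_gpu=True)
    ```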
## IV. Release Note
* 1.0.0
  First release
- ```shell
$ hub install first_order_motion==1.0.0
```
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import math
import pickle
import yaml
import imageio
import numpy as np
from tqdm import tqdm
from scipy.spatial import ConvexHull
import cv2
import paddle
from ppgan.utils.download import get_path_from_url
from ppgan.utils.animate import normalize_kp
from ppgan.modules.keypoint_detector import KPDetector
from ppgan.models.generators.occlusion_aware import OcclusionAwareGenerator
from ppgan.faceutils import face_detection
class FirstOrderPredictor:
def __init__(self,
weight_path=None,
config=None,
image_size=256,
relative=True,
adapt_scale=False,
find_best_frame=False,
best_frame=None,
face_detector='sfd',
multi_person=False,
face_enhancement=True,
batch_size=1,
mobile_net=False):
if config is not None and isinstance(config, str):
with open(config) as f:
self.cfg = yaml.load(f, Loader=yaml.SafeLoader)
elif isinstance(config, dict):
self.cfg = config
elif config is None:
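            # no config given: fall back to the default configuration that
            # matches the released vox checkpoints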
self.cfg = {
'model': {
'common_params': {
'num_kp': 10,
'num_channels': 3,
'estimate_jacobian': True
},
'generator': {
'kp_detector_cfg': {
'temperature': 0.1,
'block_expansion': 32,
'max_features': 1024,
'scale_factor': 0.25,
'num_blocks': 5
},
'generator_cfg': {
'block_expansion': 64,
'max_features': 512,
'num_down_blocks': 2,
'num_bottleneck_blocks': 6,
'estimate_occlusion_map': True,
'dense_motion_params': {
'block_expansion': 64,
'max_features': 1024,
'num_blocks': 5,
'scale_factor': 0.25
}
}
}
}
}
self.image_size = image_size
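        # no weights given: download the matching released checkpoint
        # (mobile, 512, or the standard 256 model)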
if weight_path is None:
if mobile_net:
vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-mobile.pdparams'
else:
if self.image_size == 512:
vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk-512.pdparams'
else:
vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk.pdparams'
weight_path = get_path_from_url(vox_cpk_weight_url)
self.weight_path = weight_path
self.relative = relative
self.adapt_scale = adapt_scale
self.find_best_frame = find_best_frame
self.best_frame = best_frame
self.face_detector = face_detector
self.generator, self.kp_detector = self.load_checkpoints(self.cfg, self.weight_path)
self.multi_person = multi_person
self.face_enhancement = face_enhancement
self.batch_size = batch_size
if face_enhancement:
from ppgan.faceutils.face_enhancement import FaceEnhancement
self.faceenhancer = FaceEnhancement(batch_size=batch_size)
def read_img(self, path):
img = imageio.imread(path)
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
        # some images have 4 channels
if img.shape[2] > 3:
img = img[:, :, :3]
return img
def run(self, source_image, driving_video, ratio, image_size, output_dir, filename):
self.ratio = ratio
self.image_size = image_size
self.output = output_dir
self.filename = filename
if not os.path.exists(output_dir):
os.makedirs(output_dir)
def get_prediction(face_image):
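            # When a best frame is given or searched for, animate forward and
            # backward from that frame and stitch the two halves together, so
            # the animation starts from the driving pose closest to the source.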
if self.find_best_frame or self.best_frame is not None:
i = self.best_frame if self.best_frame is not None else self.find_best_frame_func(
source_image, driving_video)
print("Best frame: " + str(i))
driving_forward = driving_video[i:]
driving_backward = driving_video[:(i + 1)][::-1]
predictions_forward = self.make_animation(
face_image,
driving_forward,
self.generator,
self.kp_detector,
relative=self.relative,
adapt_movement_scale=self.adapt_scale)
predictions_backward = self.make_animation(
face_image,
driving_backward,
self.generator,
self.kp_detector,
relative=self.relative,
adapt_movement_scale=self.adapt_scale)
predictions = predictions_backward[::-1] + predictions_forward[1:]
else:
predictions = self.make_animation(
face_image,
driving_video,
self.generator,
self.kp_detector,
relative=self.relative,
adapt_movement_scale=self.adapt_scale)
return predictions
source_image = self.read_img(source_image)
reader = imageio.get_reader(driving_video)
fps = reader.get_meta_data()['fps']
driving_video = []
try:
for im in reader:
driving_video.append(im)
        except RuntimeError:
            print("Read driving video error!")
reader.close()
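        # resize driving frames to the working resolution and normalize to [0, 1]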
driving_video = [cv2.resize(frame, (self.image_size, self.image_size)) / 255.0 for frame in driving_video]
results = []
bboxes = self.extract_bbox(source_image.copy())
print(str(len(bboxes)) + " persons have been detected")
# for multi person
for rec in bboxes:
face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]]
face_image = cv2.resize(face_image, (self.image_size, self.image_size)) / 255.0
predictions = get_prediction(face_image)
results.append({'rec': rec, 'predict': [predictions[i] for i in range(predictions.shape[0])]})
if len(bboxes) == 1 or not self.multi_person:
break
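        # paste each generated face back into the source frame; with several
        # faces, blend each patch through a circular mask centered on its bbox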
out_frame = []
for i in range(len(driving_video)):
frame = source_image.copy()
for result in results:
x1, y1, x2, y2, _ = result['rec']
h = y2 - y1
w = x2 - x1
out = result['predict'][i]
out = cv2.resize(out.astype(np.uint8), (x2 - x1, y2 - y1))
if len(results) == 1:
frame[y1:y2, x1:x2] = out
break
else:
patch = np.zeros(frame.shape).astype('uint8')
patch[y1:y2, x1:x2] = out
mask = np.zeros(frame.shape[:2]).astype('uint8')
cx = int((x1 + x2) / 2)
cy = int((y1 + y2) / 2)
cv2.circle(mask, (cx, cy), math.ceil(h * self.ratio), (255, 255, 255), -1, 8, 0)
frame = cv2.copyTo(patch, mask, frame)
out_frame.append(frame)
        imageio.mimsave(os.path.join(self.output, self.filename), out_frame, fps=fps)
def load_checkpoints(self, config, checkpoint_path):
generator = OcclusionAwareGenerator(
**config['model']['generator']['generator_cfg'], **config['model']['common_params'], inference=True)
kp_detector = KPDetector(**config['model']['generator']['kp_detector_cfg'], **config['model']['common_params'])
checkpoint = paddle.load(self.weight_path)
generator.set_state_dict(checkpoint['generator'])
kp_detector.set_state_dict(checkpoint['kp_detector'])
generator.eval()
kp_detector.eval()
return generator, kp_detector
def make_animation(self,
source_image,
driving_video,
generator,
kp_detector,
relative=True,
adapt_movement_scale=True):
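        # batched inference: the source keypoints are tiled to batch_size and
        # the driving frames are processed batch_size frames at a time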
with paddle.no_grad():
predictions = []
source = paddle.to_tensor(source_image[np.newaxis].astype(np.float32)).transpose([0, 3, 1, 2])
driving = paddle.to_tensor(np.array(driving_video).astype(np.float32)).transpose([0, 3, 1, 2])
kp_source = kp_detector(source)
kp_driving_initial = kp_detector(driving[0:1])
kp_source_batch = {}
kp_source_batch["value"] = paddle.tile(kp_source["value"], repeat_times=[self.batch_size, 1, 1])
kp_source_batch["jacobian"] = paddle.tile(kp_source["jacobian"], repeat_times=[self.batch_size, 1, 1, 1])
source = paddle.tile(source, repeat_times=[self.batch_size, 1, 1, 1])
begin_idx = 0
for frame_idx in tqdm(range(int(np.ceil(float(driving.shape[0]) / self.batch_size)))):
frame_num = min(self.batch_size, driving.shape[0] - begin_idx)
driving_frame = driving[begin_idx:begin_idx + frame_num]
kp_driving = kp_detector(driving_frame)
kp_source_img = {}
kp_source_img["value"] = kp_source_batch["value"][0:frame_num]
kp_source_img["jacobian"] = kp_source_batch["jacobian"][0:frame_num]
kp_norm = normalize_kp(
kp_source=kp_source,
kp_driving=kp_driving,
kp_driving_initial=kp_driving_initial,
use_relative_movement=relative,
use_relative_jacobian=relative,
adapt_movement_scale=adapt_movement_scale)
out = generator(source[0:frame_num], kp_source=kp_source_img, kp_driving=kp_norm)
img = np.transpose(out['prediction'].numpy(), [0, 2, 3, 1]) * 255.0
if self.face_enhancement:
img = self.faceenhancer.enhance_from_batch(img)
predictions.append(img)
begin_idx += frame_num
return np.concatenate(predictions)
def find_best_frame_func(self, source, driving):
import face_alignment
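        # choose the driving frame whose scale-normalized landmarks are closest
        # (squared L2 distance) to the source landmarks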
def normalize_kp(kp):
kp = kp - kp.mean(axis=0, keepdims=True)
area = ConvexHull(kp[:, :2]).volume
area = np.sqrt(area)
kp[:, :2] = kp[:, :2] / area
return kp
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True)
kp_source = fa.get_landmarks(255 * source)[0]
kp_source = normalize_kp(kp_source)
norm = float('inf')
frame_num = 0
for i, image in tqdm(enumerate(driving)):
kp_driving = fa.get_landmarks(255 * image)[0]
kp_driving = normalize_kp(kp_driving)
new_norm = (np.abs(kp_source - kp_driving)**2).sum()
if new_norm < norm:
norm = new_norm
frame_num = i
return frame_num
def extract_bbox(self, image):
detector = face_detection.FaceAlignment(
face_detection.LandmarksType._2D, flip_input=False, face_detector=self.face_detector)
frame = [image]
predictions = detector.get_detections_for_image(np.array(frame))
person_num = len(predictions)
if person_num == 0:
return np.array([])
results = []
face_boxs = []
h, w, _ = image.shape
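        # expand each detection to a larger crop: a full margin vertically and
        # 0.8 * margin horizontally, clipped to the image bounds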
for rect in predictions:
bh = rect[3] - rect[1]
bw = rect[2] - rect[0]
cy = rect[1] + int(bh / 2)
cx = rect[0] + int(bw / 2)
margin = max(bh, bw)
y1 = max(0, cy - margin)
x1 = max(0, cx - int(0.8 * margin))
y2 = min(h, cy + margin)
x2 = min(w, cx + int(0.8 * margin))
area = (y2 - y1) * (x2 - x1)
results.append([x1, y1, x2, y2, area])
# if a person has more than one bbox, keep the largest one
# maybe greedy will be better?
        results.sort(key=lambda box: box[4], reverse=True)
results_box = [results[0]]
for i in range(1, person_num):
num = len(results_box)
add_person = True
for j in range(num):
pre_person = results_box[j]
iou = self.IOU(pre_person[0], pre_person[1], pre_person[2], pre_person[3], pre_person[4], results[i][0],
results[i][1], results[i][2], results[i][3], results[i][4])
if iou > 0.5:
add_person = False
break
if add_person:
results_box.append(results[i])
boxes = np.array(results_box)
return boxes
def IOU(self, ax1, ay1, ax2, ay2, sa, bx1, by1, bx2, by2, sb):
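        # intersection-over-union of boxes a and b, with precomputed areas sa and sb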
#sa = abs((ax2 - ax1) * (ay2 - ay1))
#sb = abs((bx2 - bx1) * (by2 - by1))
x1, y1 = max(ax1, bx1), max(ay1, by1)
x2, y2 = min(ax2, bx2), min(ay2, by2)
w = x2 - x1
h = y2 - y1
if w < 0 or h < 0:
return 0.0
else:
return 1.0 * w * h / (sa + sb - w * h)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import copy
import paddle
import paddlehub as hub
from paddlehub.module.module import moduleinfo, runnable, serving
import numpy as np
import cv2
from skimage.io import imread
from skimage.transform import rescale, resize
from .model import FirstOrderPredictor
@moduleinfo(
name="first_order_motion", type="CV/gan", author="paddlepaddle", author_email="", summary="", version="1.0.0")
class FirstOrderMotion:
def __init__(self):
self.pretrained_model = os.path.join(self.directory, "vox-cpk.pdparams")
self.network = FirstOrderPredictor(weight_path=self.pretrained_model, face_enhancement=True)
def generate(self,
source_image=None,
driving_video=None,
ratio=0.4,
image_size=256,
output_dir='./motion_driving_result/',
filename='result.mp4',
use_gpu=False):
        '''
        source_image (str): path to the source image
        driving_video (str): path to the driving video
        ratio (float): margin ratio of the pasted-back face region
        image_size (int): size of the face image
        output_dir (str): directory in which to save the results
        filename (str): filename of the saved result
        use_gpu (bool): if True, use GPU; otherwise CPU
        '''
paddle.disable_static()
place = 'gpu:0' if use_gpu else 'cpu'
place = paddle.set_device(place)
        if source_image is None or driving_video is None:
print('No image or driving video provided. Please input an image and a driving video.')
return
self.network.run(source_image, driving_video, ratio, image_size, output_dir, filename)
@runnable
def run_cmd(self, argvs: list):
"""
Run as a command.
"""
self.parser = argparse.ArgumentParser(
description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options", description="Run configuration for controlling module behavior, not required.")
self.add_module_config_arg()
self.add_module_input_arg()
self.args = self.parser.parse_args(argvs)
        self.generate(
            source_image=self.args.source_image,
            driving_video=self.args.driving_video,
            ratio=self.args.ratio,
            image_size=self.args.image_size,
            output_dir=self.args.output_dir,
            filename=self.args.filename,
            use_gpu=self.args.use_gpu)
return
def add_module_config_arg(self):
"""
Add the command config options.
"""
self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not")
self.arg_config_group.add_argument(
'--output_dir', type=str, default='motion_driving_result', help='output directory for saving result.')
self.arg_config_group.add_argument("--filename", default='result.mp4', help="filename to output")
def add_module_input_arg(self):
"""
Add the command input options.
"""
self.arg_input_group.add_argument("--source_image", type=str, help="path to source image")
self.arg_input_group.add_argument("--driving_video", type=str, help="path to driving video")
self.arg_input_group.add_argument("--ratio", dest="ratio", type=float, default=0.4, help="margin ratio")
self.arg_input_group.add_argument(
"--image_size", dest="image_size", type=int, default=256, help="size of image")