Add SkyAR module (#1218)
## Model Overview
* SkyAR is a vision-based method for sky replacement and harmonization in videos. It can automatically generate realistic sky backgrounds in videos with controllable styles.
* The algorithm is a purely vision-based solution. Its advantages are that it can handle non-static scenes, is not limited by the capture device, requires no user interaction, and can process both online and offline videos.
* The algorithm consists of three core components:
    * Sky Matting Network: a matting-style image segmentation network that detects the sky region in each video frame and produces an accurate sky mask.
    * Motion Estimation: a motion estimator that recovers the motion of the sky, so that the generated sky stays synchronized with the camera's movement.
    * Image Blending: blends the user-specified sky template into the video frames. It also performs relighting and recoloring so that the blended result looks visually realistic in color and dynamic range (see the sketch after this list).
* The overall framework is shown below:
![](http://p4.itc.cn/q_70/images03/20201114/42eaf00af8dd4aa4ae3c0cdc6e50b793.jpeg)
* Reference: Zhengxia Zou. [Castle in the Sky: Dynamic Sky Replacement and Harmonization in Videos](https://arxiv.org/abs/2010.11800). CoRR, abs/2010.11800, 2020.
* Official open-source project: [jiupinjia/SkyAR](https://github.com/jiupinjia/SkyAR)
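To make the Image Blending step concrete, here is a minimal NumPy sketch of the core compositing it performs: an alpha blend of the frame with the warped sky under the soft sky mask, followed by a screen blend for the halo. The array names and shapes here are illustrative assumptions, and the actual module additionally blurs the halo layer:

```python
import numpy as np

# Illustrative only: frame, skybg and skymask are assumed to be float32
# RGB arrays in [0, 1] with the same shape (H, W, 3).
frame = np.random.rand(360, 640, 3).astype(np.float32)    # input video frame
skybg = np.random.rand(360, 640, 3).astype(np.float32)    # warped sky template
skymask = np.random.rand(360, 640, 3).astype(np.float32)  # soft sky mask

# Alpha blend: keep the foreground where the mask is 0, the sky where it is 1.
blended = frame * (1 - skymask) + skybg * skymask

# Screen blend 1 - (1-a)(1-b): adds a soft halo from bright sky regions
# (the module blurs the halo layer as well, omitted here for brevity).
halo = 0.5 * skybg * skymask
composite = np.clip(1 - (1 - blended) * (1 - halo), 0.0, 1.0)
```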
## Model Installation
```shell
$hub install SkyAR
```
## Results
* Original video:
![Original video](https://img-blog.csdnimg.cn/20210126142046572.gif)
* Jupiter:
![Jupiter](https://img-blog.csdnimg.cn/20210125211435619.gif)
* Rainy:
![Rainy](https://img-blog.csdnimg.cn/2021012521152492.gif)
* Galaxy:
![Galaxy](https://img-blog.csdnimg.cn/20210125211523491.gif)
* District 9 ship:
![District 9 ship](https://img-blog.csdnimg.cn/20210125211520955.gif)
* Original video:
![Original video](https://img-blog.csdnimg.cn/20210126142038716.gif)
* Floating castle:
![Floating castle](https://img-blog.csdnimg.cn/20210125211514997.gif)
* Thunderstorm:
![Thunderstorm](https://img-blog.csdnimg.cn/20210125211433591.gif)
* Super moon:
![Super moon](https://img-blog.csdnimg.cn/20210125211417524.gif)
## API Reference
```python
def MagicSky(
video_path, save_path, config='jupiter',
is_rainy=False, preview_frames_num=0, is_video_sky=False, is_show=False,
skybox_img=None, skybox_video=None, rain_cap_path=None,
halo_effect=True, auto_light_matching=False,
relighting_factor=0.8, recoloring_factor=0.5, skybox_center_crop=0.5
)
```
Video sky replacement and harmonization API.
**Parameters**
* video_path(str): path to the input video
* save_path(str): path for saving the output video
* config(str): preset SkyBox configuration. All available presets are listed below; set it to None when using a custom SkyBox:
```
[
'cloudy', 'district9ship', 'floatingcastle', 'galaxy', 'jupiter',
'rainy', 'sunny', 'sunset', 'supermoon', 'thunderstorm'
]
```
* skybox_img(str): path to a custom SkyBox image
* skybox_video(str): path to a custom SkyBox video
* is_video_sky(bool): whether the custom SkyBox is a video
* rain_cap_path(str): path to a custom rain-effect video
* is_rainy(bool): whether the sky is rainy
* halo_effect(bool): whether to enable the halo effect
* auto_light_matching(bool): whether to enable automatic brightness matching
* relighting_factor(float): relighting factor
* recoloring_factor(float): recoloring factor
* skybox_center_crop(float): SkyBox center-crop factor
* preview_frames_num(int): number of preview frames, i.e. only the first N frames are processed; set to 0 to process the entire video
* is_show(bool): whether to show a graphical preview
## Prediction Example
```python
import paddlehub as hub
model = hub.Module(name='SkyAR')
model.MagicSky(
video_path=[path to input video path],
save_path=[path to save video path]
)
```
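A custom SkyBox can be used by disabling the presets, as in the following sketch based on the parameters documented above; the file paths here are hypothetical placeholders:

```python
import paddlehub as hub

model = hub.Module(name='SkyAR')
model.MagicSky(
    video_path='input.mp4',    # hypothetical input path
    save_path='output.mp4',    # hypothetical output path
    config=None,               # None disables the preset SkyBoxes
    skybox_img='my_sky.jpg',   # hypothetical custom sky template image
    preview_frames_num=30      # process only the first 30 frames as a preview
)
```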
## Model Information
### Model Code
https://github.com/jm12138/SkyAR_Paddle_GUI
### Dependencies
paddlepaddle >= 2.0.0rc0
paddlehub >= 2.0.0rc0
import os
import paddle.nn as nn
from .skyfilter import SkyFilter
from paddlehub.module.module import moduleinfo
@moduleinfo(
name="SkyAR",
type="CV/Video_editing",
author="jm12138",
author_email="",
summary="SkyAR",
version="1.0.0"
)
class SkyAR(nn.Layer):
def __init__(self, model_path=None):
super(SkyAR, self).__init__()
self.imgs = ['cloudy', 'district9ship', 'floatingcastle',
'galaxy', 'jupiter', 'rainy', 'sunny', 'sunset', 'supermoon']
self.videos = ['thunderstorm']
if model_path:
self.model_path = model_path
        else:
            # default to the exported ResNet50-FCN model shipped with the module
            self.model_path = os.path.join(self.directory, 'ResNet50FCN')
def MagicSky(
self, video_path, save_path, config='jupiter',
is_rainy=False, preview_frames_num=0, is_video_sky=False, is_show=False,
skybox_img=None, skybox_video=None, rain_cap_path=None,
halo_effect=True, auto_light_matching=False,
relighting_factor=0.8, recoloring_factor=0.5, skybox_center_crop=0.5
):
        if config in self.imgs:
            # use a preset static skybox image
            skybox_img = os.path.join(
                self.directory, 'skybox', '%s.jpg' % config)
            skybox_video = None
            is_video_sky = False
        elif config in self.videos:
            # use a preset skybox video
            skybox_img = None
            skybox_video = os.path.join(
                self.directory, 'skybox', '%s.mp4' % config)
            is_video_sky = True
        elif skybox_img:
            # user-supplied static skybox image
            is_video_sky = False
            skybox_video = None
        elif is_video_sky and skybox_video:
            # user-supplied skybox video
            skybox_img = None
        else:
            raise ValueError('please check your configs')
if not rain_cap_path:
rain_cap_path = os.path.join(
self.directory, 'rain_streaks', 'videoplayback.mp4')
skyfilter = SkyFilter(
model_path=self.model_path,
video_path=video_path,
save_path=save_path,
in_size=(384, 384),
halo_effect=halo_effect,
auto_light_matching=auto_light_matching,
relighting_factor=relighting_factor,
recoloring_factor=recoloring_factor,
skybox_center_crop=skybox_center_crop,
rain_cap_path=rain_cap_path,
skybox_img=skybox_img,
skybox_video=skybox_video,
is_video=is_video_sky,
is_rainy=is_rainy,
is_show=is_show
)
skyfilter.run(preview_frames_num)
import cv2
import numpy as np
__all__ = ['Rain']
class Rain():
def __init__(self, rain_cap_path, rain_intensity=1.0, haze_intensity=4.0, gamma=2.0, light_correction=0.9):
self.rain_intensity = rain_intensity
self.haze_intensity = haze_intensity
self.gamma = gamma
self.light_correction = light_correction
self.frame_id = 1
self.cap = cv2.VideoCapture(rain_cap_path)
def _get_rain_layer(self):
ret, frame = self.cap.read()
if ret:
rain_layer = frame
        else:  # reached the last frame; rewind and read from the beginning
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            ret, frame = self.cap.read()
            rain_layer = frame
rain_layer = cv2.cvtColor(rain_layer, cv2.COLOR_BGR2RGB) / 255.0
rain_layer = np.array(rain_layer, dtype=np.float32)
return rain_layer
def _create_haze_layer(self, rain_layer):
return 0.1*np.ones_like(rain_layer)
def forward(self, img):
# get input image size
h, w, c = img.shape
# create a rain layer
rain_layer = self._get_rain_layer()
rain_layer = cv2.resize(rain_layer, (w, h))
rain_layer = cv2.blur(rain_layer, (3, 3))
rain_layer = rain_layer * \
(1 - cv2.boxFilter(img, -1, (int(w/10), int(h/10))))
# create a haze layer
haze_layer = self._create_haze_layer(rain_layer)
# combine the rain layer and haze layer together
rain_layer = self.rain_intensity*rain_layer + \
self.haze_intensity*haze_layer
# synthesize an output image (screen blend)
img_out = 1 - (1 - rain_layer) * (1 - img)
# gamma and light correction
img_out = self.light_correction*(img_out**self.gamma)
# check boundary
img_out = np.clip(img_out, a_min=0, a_max=1.)
return img_out
import cv2
import numpy as np
from .rain import Rain
from .utils import build_transformation_matrix, update_transformation_matrix, estimate_partial_transform, removeOutliers, guidedfilter
class SkyBox():
def __init__(
self, out_size, skybox_img, skybox_video, halo_effect,
auto_light_matching, relighting_factor, recoloring_factor,
skybox_center_crop, rain_cap_path, is_video, is_rainy):
self.out_size_w, self.out_size_h = out_size
self.skybox_img = skybox_img
self.skybox_video = skybox_video
self.is_rainy = is_rainy
self.is_video = is_video
self.halo_effect = halo_effect
self.auto_light_matching = auto_light_matching
self.relighting_factor = relighting_factor
self.recoloring_factor = recoloring_factor
self.skybox_center_crop = skybox_center_crop
self.load_skybox()
self.rainmodel = Rain(
rain_cap_path=rain_cap_path,
rain_intensity=0.8,
haze_intensity=0.0,
gamma=1.0,
light_correction=1.0
)
# motion parameters
self.M = np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32)
self.frame_id = 0
def tile_skybox_img(self, imgtile):
screen_y1 = int(imgtile.shape[0] / 2 - self.out_size_h / 2)
screen_x1 = int(imgtile.shape[1] / 2 - self.out_size_w / 2)
imgtile = np.concatenate(
[imgtile[screen_y1:, :, :], imgtile[0:screen_y1, :, :]], axis=0)
imgtile = np.concatenate(
[imgtile[:, screen_x1:, :], imgtile[:, 0:screen_x1, :]], axis=1)
return imgtile
def load_skybox(self):
print('initialize skybox...')
if not self.is_video:
            # static background
skybox_img = cv2.imread(self.skybox_img, cv2.IMREAD_COLOR)
skybox_img = cv2.cvtColor(skybox_img, cv2.COLOR_BGR2RGB)
self.skybox_img = cv2.resize(
skybox_img, (self.out_size_w, self.out_size_h))
cc = 1. / self.skybox_center_crop
imgtile = cv2.resize(
skybox_img, (int(cc * self.out_size_w),
int(cc*self.out_size_h)))
self.skybox_imgx2 = self.tile_skybox_img(imgtile)
self.skybox_imgx2 = np.expand_dims(self.skybox_imgx2, axis=0)
else:
            # video background
cap = cv2.VideoCapture(self.skybox_video)
m_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
cc = 1. / self.skybox_center_crop
self.skybox_imgx2 = np.zeros(
[m_frames, int(cc*self.out_size_h),
int(cc*self.out_size_w), 3], np.uint8)
for i in range(m_frames):
_, skybox_img = cap.read()
skybox_img = cv2.cvtColor(skybox_img, cv2.COLOR_BGR2RGB)
imgtile = cv2.resize(
skybox_img, (int(cc * self.out_size_w),
int(cc * self.out_size_h)))
skybox_imgx2 = self.tile_skybox_img(imgtile)
self.skybox_imgx2[i, :] = skybox_imgx2
    def skymask_refinement(self, G_pred, img):
        # refine the predicted sky mask with a guided filter,
        # using the frame's blue channel as guidance
        r, eps = 20, 0.01
refined_skymask = guidedfilter(img[:, :, 2], G_pred[:, :, 0], r, eps)
refined_skymask = np.stack(
[refined_skymask, refined_skymask, refined_skymask], axis=-1)
return np.clip(refined_skymask, a_min=0, a_max=1)
def get_skybg_from_box(self, m):
self.M = update_transformation_matrix(self.M, m)
nbgs, bgh, bgw, c = self.skybox_imgx2.shape
fetch_id = self.frame_id % nbgs
skybg_warp = cv2.warpAffine(
self.skybox_imgx2[fetch_id, :, :, :], self.M,
(bgw, bgh), borderMode=cv2.BORDER_WRAP)
skybg = skybg_warp[0:self.out_size_h, 0:self.out_size_w, :]
self.frame_id += 1
return np.array(skybg, np.float32)/255.
def skybox_tracking(self, frame, frame_prev, skymask):
if np.mean(skymask) < 0.05:
print('sky area is too small')
return np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32)
prev_gray = cv2.cvtColor(frame_prev, cv2.COLOR_RGB2GRAY)
prev_gray = np.array(255*prev_gray, dtype=np.uint8)
curr_gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
curr_gray = np.array(255*curr_gray, dtype=np.uint8)
mask = np.array(skymask[:, :, 0] > 0.99, dtype=np.uint8)
template_size = int(0.05*mask.shape[0])
mask = cv2.erode(mask, np.ones([template_size, template_size]))
# ShiTomasi corner detection
prev_pts = cv2.goodFeaturesToTrack(
prev_gray, mask=mask, maxCorners=200,
qualityLevel=0.01, minDistance=30, blockSize=3)
if prev_pts is None:
print('no feature point detected')
return np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32)
# Calculate optical flow (i.e. track feature points)
curr_pts, status, err = cv2.calcOpticalFlowPyrLK(
prev_gray, curr_gray, prev_pts, None)
# Filter only valid points
idx = np.where(status == 1)[0]
if idx.size == 0:
print('no good point matched')
return np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32)
prev_pts, curr_pts = removeOutliers(prev_pts, curr_pts)
if curr_pts.shape[0] < 10:
print('no good point matched')
return np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32)
# limit the motion to translation + rotation
dxdyda = estimate_partial_transform((
np.array(prev_pts), np.array(curr_pts)))
m = build_transformation_matrix(dxdyda)
return m
def relighting(self, img, skybg, skymask):
# color matching, reference: skybox_img
step = int(img.shape[0]/20)
skybg_thumb = skybg[::step, ::step, :]
img_thumb = img[::step, ::step, :]
skymask_thumb = skymask[::step, ::step, :]
skybg_mean = np.mean(skybg_thumb, axis=(0, 1), keepdims=True)
img_mean = np.sum(img_thumb * (1-skymask_thumb), axis=(0, 1), keepdims=True) \
/ ((1-skymask_thumb).sum(axis=(0, 1), keepdims=True) + 1e-9)
diff = skybg_mean - img_mean
img_colortune = img + self.recoloring_factor*diff
if self.auto_light_matching:
img = img_colortune
else:
            # keep the foreground ambient light and manually adjust the lighting
img = self.relighting_factor * \
(img_colortune + (img.mean() - img_colortune.mean()))
return img
def halo(self, syneth, skybg, skymask):
# reflection
halo = 0.5*cv2.blur(
skybg*skymask, (int(self.out_size_w/5),
int(self.out_size_w/5)))
# screen blend 1 - (1-a)(1-b)
syneth_with_halo = 1 - (1-syneth) * (1-halo)
return syneth_with_halo
def skyblend(self, img, img_prev, skymask):
m = self.skybox_tracking(img, img_prev, skymask)
skybg = self.get_skybg_from_box(m)
img = self.relighting(img, skybg, skymask)
syneth = img * (1 - skymask) + skybg * skymask
if self.halo_effect:
# halo effect brings better visual realism but will slow down the speed
syneth = self.halo(syneth, skybg, skymask)
if self.is_rainy:
syneth = self.rainmodel.forward(syneth)
return np.clip(syneth, a_min=0, a_max=1)
import os
import cv2
import paddle
import numpy as np
from .skybox import SkyBox
__all__ = ['SkyFilter']
class SkyFilter():
def __init__(self, model_path, video_path, save_path, in_size,
halo_effect, auto_light_matching, relighting_factor,
recoloring_factor, skybox_center_crop, rain_cap_path,
skybox_img, skybox_video, is_video, is_rainy, is_show
):
self.in_size = in_size
self.is_show = is_show
self.cap = cv2.VideoCapture(video_path)
self.m_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
self.fps = self.cap.get(cv2.CAP_PROP_FPS)
self.out_size = int(self.cap.get(
cv2.CAP_PROP_FRAME_WIDTH)), int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
self.model = paddle.jit.load(
model_path, model_filename='__model__', params_filename='__params__')
self.model.eval()
self.skyboxengine = SkyBox(
out_size=self.out_size, skybox_img=skybox_img, skybox_video=skybox_video,
halo_effect=halo_effect, auto_light_matching=auto_light_matching,
relighting_factor=relighting_factor, recoloring_factor=recoloring_factor,
skybox_center_crop=skybox_center_crop, rain_cap_path=rain_cap_path,
is_video=is_video, is_rainy=is_rainy
)
path, _ = os.path.split(save_path)
if path == '':
path = '.'
if not os.path.exists(path):
os.mkdir(path)
self.video_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'MP4V'),
self.fps, self.out_size)
def synthesize(self, img_HD, img_HD_prev):
h, w, _ = img_HD.shape
img = cv2.resize(img_HD, self.in_size)
img = np.array(img, dtype=np.float32)
img = img.transpose(2, 0, 1)
img = img[np.newaxis, ...]
img = paddle.to_tensor(img)
G_pred = self.model(img)
G_pred = paddle.nn.functional.interpolate(
G_pred, (h, w), mode='bicubic', align_corners=False)
G_pred = G_pred[0, :].transpose([1, 2, 0])
G_pred = paddle.concat([G_pred, G_pred, G_pred], axis=-1)
G_pred = G_pred.detach().numpy()
G_pred = np.clip(G_pred, a_max=1.0, a_min=0.0)
skymask = self.skyboxengine.skymask_refinement(G_pred, img_HD)
syneth = self.skyboxengine.skyblend(img_HD, img_HD_prev, skymask)
return syneth, G_pred, skymask
def run(self, preview_frames_num=0):
img_HD_prev = None
frames_num = preview_frames_num if 0 < preview_frames_num < self.m_frames else self.m_frames
print('frames_num: %d, running evaluation...' % frames_num)
for idx in range(1, frames_num+1):
ret, frame = self.cap.read()
if ret:
frame = cv2.resize(frame, self.out_size)
img_HD = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
img_HD = np.array(img_HD / 255., dtype=np.float32)
if img_HD_prev is None:
img_HD_prev = img_HD
syneth, _, _ = self.synthesize(img_HD, img_HD_prev)
result = np.array(255.0 * syneth[:, :, ::-1], dtype=np.uint8)
self.video_writer.write(result)
if self.is_show:
show_img = np.concatenate([frame, result], 1)
h, w = show_img.shape[:2]
show_img = cv2.resize(show_img, (720, int(720/w*h)))
cv2.imshow('preview', show_img)
k = cv2.waitKey(1)
if (k == 27) or (cv2.getWindowProperty('preview', 0) == -1):
self.video_writer.release()
cv2.destroyAllWindows()
break
print('processing: %d / %d ...' % (idx, frames_num))
img_HD_prev = img_HD
else:
self.video_writer.release()
cv2.destroyAllWindows()
break
import cv2
import numpy as np
from sklearn.neighbors import KernelDensity
__all__ = [
'build_transformation_matrix',
'update_transformation_matrix',
'estimate_partial_transform',
'removeOutliers',
'guidedfilter'
]
def build_transformation_matrix(transform):
"""Convert transform list to transformation matrix
:param transform: transform list as [dx, dy, da]
:return: transform matrix as 2d (2, 3) numpy array
"""
transform_matrix = np.zeros((2, 3))
transform_matrix[0, 0] = np.cos(transform[2])
transform_matrix[0, 1] = -np.sin(transform[2])
transform_matrix[1, 0] = np.sin(transform[2])
transform_matrix[1, 1] = np.cos(transform[2])
transform_matrix[0, 2] = transform[0]
transform_matrix[1, 2] = transform[1]
return transform_matrix
def update_transformation_matrix(M, m):
# extend M and m to 3x3 by adding an [0,0,1] to their 3rd row
M_ = np.concatenate([M, np.zeros([1, 3])], axis=0)
M_[-1, -1] = 1
m_ = np.concatenate([m, np.zeros([1, 3])], axis=0)
m_[-1, -1] = 1
M_new = np.matmul(m_, M_)
return M_new[0:2, :]
def estimate_partial_transform(matched_keypoints):
"""Wrapper of cv2.estimateRigidTransform for convenience in vidstab process
:param matched_keypoints: output of match_keypoints util function; tuple of (cur_matched_kp, prev_matched_kp)
:return: transform as list of [dx, dy, da]
"""
prev_matched_kp, cur_matched_kp = matched_keypoints
transform = cv2.estimateAffinePartial2D(np.array(prev_matched_kp),
np.array(cur_matched_kp))[0]
if transform is not None:
# translation x
dx = transform[0, 2]
# translation y
dy = transform[1, 2]
# rotation
da = np.arctan2(transform[1, 0], transform[0, 0])
else:
dx = dy = da = 0
return [dx, dy, da]
def removeOutliers(prev_pts, curr_pts):
d = np.sum((prev_pts - curr_pts)**2, axis=-1)**0.5
d_ = np.array(d).reshape(-1, 1)
kde = KernelDensity(kernel='gaussian', bandwidth=0.5).fit(d_)
density = np.exp(kde.score_samples(d_))
prev_pts = prev_pts[np.where((density >= 0.1))]
curr_pts = curr_pts[np.where((density >= 0.1))]
return prev_pts, curr_pts
def boxfilter(img, r):
(rows, cols) = img.shape
imDst = np.zeros_like(img)
imCum = np.cumsum(img, 0)
imDst[0: r+1, :] = imCum[r: 2*r+1, :]
imDst[r+1: rows-r, :] = imCum[2*r+1: rows, :] - imCum[0: rows-2*r-1, :]
imDst[rows-r: rows, :] = np.tile(imCum[rows-1, :],
[r, 1]) - imCum[rows-2*r-1: rows-r-1, :]
imCum = np.cumsum(imDst, 1)
imDst[:, 0: r+1] = imCum[:, r: 2*r+1]
imDst[:, r+1: cols-r] = imCum[:, 2*r+1: cols] - imCum[:, 0: cols-2*r-1]
imDst[:, cols-r: cols] = np.tile(imCum[:, cols-1],
[r, 1]).T - imCum[:, cols-2*r-1: cols-r-1]
return imDst
def guidedfilter(img, p, r, eps):
(rows, cols) = img.shape
N = boxfilter(np.ones([rows, cols]), r)
meanI = boxfilter(img, r) / N
meanP = boxfilter(p, r) / N
meanIp = boxfilter(img * p, r) / N
covIp = meanIp - meanI * meanP
meanII = boxfilter(img * img, r) / N
varI = meanII - meanI * meanI
a = covIp / (varI + eps)
b = meanP - a * meanI
meanA = boxfilter(a, r) / N
meanB = boxfilter(b, r) / N
q = meanA * img + meanB
return q
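# Minimal usage sketch (illustrative only, not part of the module): refine a
# coarse sky probability map with the guided filter, mirroring what
# SkyBox.skymask_refinement does with r=20, eps=0.01. The arrays below are
# synthetic placeholders.
if __name__ == '__main__':
    guide = np.random.rand(384, 384)        # guidance channel in [0, 1]
    coarse_mask = np.random.rand(384, 384)  # coarse sky mask in [0, 1]
    refined = np.clip(guidedfilter(guide, coarse_mask, r=20, eps=0.01), 0, 1)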