未验证 提交 f5be3a9d 编写于 作者: L lijianshe02 提交者: GitHub

remove duplicate frames and keep timestamp, fix psgan docs (#74)

* remove duplicate frames and keep timestamp, fix psgan docs
上级 83c1b7ef
......@@ -35,14 +35,14 @@ python tools/psgan_infer.py \
```
mv landmarks/makeup MT-Dataset/landmarks/makeup
mv landmarks/non-makeup MT-Dataset/landmarks/non-makeup
mv landmarks/train_makeup.txt MT-Dataset/makeup.txt
mv tlandmarks/train_non-makeup.txt MT-Dataset/non-makeup.txt
cp landmarks/train_makeup.txt MT-Dataset/train_makeup.txt
cp landmarks/train_non-makeup.txt MT-Dataset/train_non-makeup.txt
```
The final data directory should look like:
```
data
data/MT-Dataset
├── images
│ ├── makeup
│ └── non-makeup
......
......@@ -2,7 +2,7 @@
## 1. PSGAN原理
[PSGAN](https://arxiv.org/abs/1909.06956)模型的任务是妆容迁移, 即将任意参照图像上的妆容迁移到不带妆容的源图像上。很多人像美化应用都需要这种技术。近来的一些妆容迁移方法大都基于生成对抗网络(GAN)。它们通常采用 CycleGAN 的框架,并在两个数据集上进行训练,即无妆容图像和有妆容图像。但是,现有的方法存在一个局限性:只在正面人脸图像上表现良好,没有为处理源图像和参照图像之间的姿态和表情差异专门设计模块。PSGAN是一种全新的姿态稳健可感知空间的生成对抗网络。PSGAN 主要分为三部分:妆容提炼网络(MDNet)、注意式妆容变形(AMM)模块和卸妆-再化妆网络(DRNet)。这三种新提出的模块能让 PSGAN 具备上述的完美妆容迁移模型所应具备的能力。
[PSGAN](https://arxiv.org/abs/1909.06956)模型的任务是妆容迁移, 即将任意参照图像上的妆容迁移到不带妆容的源图像上。很多人像美化应用都需要这种技术。近来的一些妆容迁移方法大都基于生成对抗网络(GAN)。它们通常采用 CycleGAN 的框架,并在两个数据集上进行训练,即无妆容图像和有妆容图像。但是,现有的方法存在一个局限性:只在正面人脸图像上表现良好,没有为处理源图像和参照图像之间的姿态和表情差异专门设计模块。PSGAN是一种全新的姿态稳健可感知空间的生成对抗网络。PSGAN 主要分为三部分:妆容提炼网络(MDNet)、注意式妆容变形(AMM)模块和卸妆-再化妆网络(DRNet)。这三种新提出的模块能让 PSGAN 具备上述的完美妆容迁移模型所应具备的能力。
<div align="center">
<img src="../../imgs/psgan_arc.png" width="800"/>
......@@ -35,13 +35,13 @@ python tools/psgan_infer.py \
```
mv landmarks/makeup MT-Dataset/landmarks/makeup
mv landmarks/non-makeup MT-Dataset/landmarks/non-makeup
mv landmarks/train_makeup.txt MT-Dataset/makeup.txt
mv tlandmarks/train_non-makeup.txt MT-Dataset/non-makeup.txt
cp landmarks/train_makeup.txt MT-Dataset/train_makeup.txt
cp landmarks/train_non-makeup.txt MT-Dataset/train_non-makeup.txt
```
最后数据集目录如下所示:
```
data
data/MT-Dataset
├── images
│   ├── makeup
│   └── non-makeup
......
......@@ -82,14 +82,9 @@ class DAINPredictor(BasePredictor):
vidname = video_path.split('/')[-1].split('.')[0]
frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
orig_frames = len(frames)
need_frames = orig_frames * times_interp
if self.remove_duplicates:
frames = self.remove_duplicate_frames(out_path)
left_frames = len(frames)
timestep = left_frames / need_frames
num_frames = int(1.0 / timestep) - 1
img = imread(frames[0])
......@@ -125,9 +120,11 @@ class DAINPredictor(BasePredictor):
if not os.path.exists(os.path.join(frame_path_combined, vidname)):
os.makedirs(os.path.join(frame_path_combined, vidname))
for i in tqdm(range(frame_num - 1)):
for i in range(frame_num - 1):
first = frames[i]
second = frames[i + 1]
first_index = int(first.split('/')[-1].split('.')[-2])
second_index = int(second.split('/')[-1].split('.')[-2])
img_first = imread(first)
img_second = imread(second)
......@@ -173,22 +170,43 @@ class DAINPredictor(BasePredictor):
padding_left:padding_left + int_width],
(1, 2, 0)) for item in y_
]
time_offsets = [kk * timestep for kk in range(1, 1 + num_frames, 1)]
count = 1
for item, time_offset in zip(y_, time_offsets):
out_dir = os.path.join(frame_path_interpolated, vidname,
"{:0>6d}_{:0>4d}.png".format(i, count))
count = count + 1
imsave(out_dir, np.round(item).astype(np.uint8))
num_frames = int(1.0 / timestep) - 1
if self.remove_duplicates:
num_frames = times_interp * (second_index - first_index) - 1
time_offsets = [
kk * timestep for kk in range(1, 1 + num_frames, 1)
]
start = times_interp * first_index + 1
for item, time_offset in zip(y_, time_offsets):
out_dir = os.path.join(frame_path_interpolated, vidname,
"{:08d}.png".format(start))
imsave(out_dir, np.round(item).astype(np.uint8))
start = start + 1
else:
time_offsets = [
kk * timestep for kk in range(1, 1 + num_frames, 1)
]
count = 1
for item, time_offset in zip(y_, time_offsets):
out_dir = os.path.join(
frame_path_interpolated, vidname,
"{:0>6d}_{:0>4d}.png".format(i, count))
count = count + 1
imsave(out_dir, np.round(item).astype(np.uint8))
input_dir = os.path.join(frame_path_input, vidname)
interpolated_dir = os.path.join(frame_path_interpolated, vidname)
combined_dir = os.path.join(frame_path_combined, vidname)
self.combine_frames(input_dir, interpolated_dir, combined_dir,
num_frames)
if self.remove_duplicates:
self.combine_frames_with_rm(input_dir, interpolated_dir,
combined_dir, times_interp)
else:
num_frames = int(1.0 / timestep) - 1
self.combine_frames(input_dir, interpolated_dir, combined_dir,
num_frames)
frame_pattern_combined = os.path.join(frame_path_combined, vidname,
'%08d.png')
......@@ -223,6 +241,26 @@ class DAINPredictor(BasePredictor):
except Exception as e:
print(e)
def combine_frames_with_rm(self, input, interpolated, combined,
                           times_interp):
    """Merge original and interpolated frames into one directory.

    Used when duplicate frames were removed, so the original frames are no
    longer numbered contiguously and their target position has to be
    derived from the number encoded in each file name.

    Args:
        input: Directory holding the original ``*.png`` frames; each file
            stem must be an integer frame index. (The name shadows the
            builtin but is kept so keyword callers keep working.)
        interpolated: Directory holding the interpolated ``*.png`` frames.
        combined: Existing directory that receives the merged frames.
        times_interp: Multiplier applied to each original frame index to
            obtain its file number in ``combined``.

    Raises:
        ValueError: If an original frame's file stem is not an integer.
    """
    # Original frames are re-numbered: frame ``k`` is written as
    # ``times_interp * k`` (zero-padded to 8 digits).
    for src in sorted(glob.glob(os.path.join(input, '*.png'))):
        # os.path.basename instead of split('/') so that paths using the
        # platform's native separator are parsed correctly.
        index = int(os.path.basename(src).split('.')[-2])
        dst = os.path.join(combined,
                           '{:08d}.png'.format(times_interp * index))
        shutil.copy2(src, dst)

    # Interpolated frames are copied under their existing file names.
    for src in sorted(glob.glob(os.path.join(interpolated, '*.png'))):
        dst = os.path.join(combined, os.path.basename(src))
        shutil.copy2(src, dst)
def remove_duplicate_frames(self, paths):
def dhash(image, hash_size=8):
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
......@@ -241,14 +279,19 @@ class DAINPredictor(BasePredictor):
for (h, hashed_paths) in hashes.items():
if len(hashed_paths) > 1:
for p in hashed_paths[1:]:
os.remove(p)
frames = sorted(glob.glob(os.path.join(paths, '*.png')))
for fid, frame in enumerate(frames):
new_name = '{:08d}'.format(fid) + '.png'
new_name = os.path.join(paths, new_name)
os.rename(frame, new_name)
first_index = int(hashed_paths[0].split('/')[-1].split('.')[-2])
last_index = int(
hashed_paths[-1].split('/')[-1].split('.')[-2]) + 1
gap = 2 * (last_index - first_index) - 1
if gap > 9:
mid = len(hashed_paths) // 2
for p in hashed_paths[1:mid - 1]:
os.remove(p)
for p in hashed_paths[mid + 1:]:
os.remove(p)
else:
for p in hashed_paths[1:]:
os.remove(p)
frames = sorted(glob.glob(os.path.join(paths, '*.png')))
return frames
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册