Unverified · Commit 8dc86476, authored by FNRE, committed by GitHub

modify review code (#297)

* test

* add training of first order motion

* Modify codes according to reviews

* Modify codes according to reviews

* (1) modify single-person and multi-person processing in the fom model. (2) get rid of sklearn, skimage

* modify documents of fom

* modify review code

* modify review code

* modify review code

* 1. add vox training; 2. fix attribute error for DataParallel; 3. fix fom fps
Parent 7b592f5c
epochs: 100
output_dir: output_dir

dataset:
  train:
    name: FirstOrderDataset
    batch_size: 8
    num_workers: 4
    use_shared_memory: False
    phase: train
    dataroot: data/first_order/Voxceleb/
    frame_shape: [256, 256, 3]
    id_sampling: True
    pairs_list: None
    time_flip: True
    num_repeats: 75
    create_frames_folder: False
    transforms:
      - name: PairedRandomHorizontalFlip
        prob: 0.5
        keys: [image, image]
      - name: PairedColorJitter
        brightness: 0.1
        contrast: 0.1
        saturation: 0.1
        hue: 0.1
        keys: [image, image]
  test:
    name: FirstOrderDataset
    dataroot: data/first_order/Voxceleb/
    phase: test
    batch_size: 1
    num_workers: 1
    time_flip: False
    id_sampling: False
    create_frames_folder: False
    frame_shape: [256, 256, 3]

model:
  name: FirstOrderModel
  common_params:
    num_kp: 10
    num_channels: 3
    estimate_jacobian: True
  generator:
    name: FirstOrderGenerator
    kp_detector_cfg:
      temperature: 0.1
      block_expansion: 32
      max_features: 1024
      scale_factor: 0.25
      num_blocks: 5
    generator_cfg:
      block_expansion: 64
      max_features: 512
      num_down_blocks: 2
      num_bottleneck_blocks: 6
      estimate_occlusion_map: True
      dense_motion_params:
        block_expansion: 64
        max_features: 1024
        num_blocks: 5
        scale_factor: 0.25
  discriminator:
    name: FirstOrderDiscriminator
    discriminator_cfg:
      scales: [1]
      block_expansion: 32
      max_features: 512
      num_blocks: 4
      sn: True
  train_params:
    num_epochs: 100
    scales: [1, 0.5, 0.25, 0.125]
    checkpoint_freq: 50
    transform_params:
      sigma_affine: 0.05
      sigma_tps: 0.005
      points_tps: 5
    loss_weights:
      generator_gan: 0
      discriminator_gan: 1
      feature_matching: [10, 10, 10, 10]
      perceptual: [10, 10, 10, 10, 10]
      equivariance_value: 10
      equivariance_jacobian: 10

lr_scheduler:
  name: MultiStepDecay
  epoch_milestones: [237360, 356040]
  lr_generator: 2.0e-4
  lr_discriminator: 2.0e-4
  lr_kp_detector: 2.0e-4

reconstruction_params:
  num_videos: 1000
  format: '.mp4'

animate_params:
  num_pairs: 50
  format: '.mp4'
  normalization_params:
    adapt_movement_scale: False
    use_relative_movement: True
    use_relative_jacobian: True

visualizer_params:
  kp_size: 5
  draw_border: True
  colormap: 'gist_rainbow'

log_config:
  interval: 10
  visiual_interval: 10

validate:
  interval: 10
  save_img: false

snapshot_config:
  interval: 10

optimizer:
  name: Adam
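
For reference, the lr_scheduler section above maps onto Paddle's built-in MultiStepDecay; the milestone values are far larger than num_epochs, so they are presumably iteration counts. A minimal sketch of the scheduler it describes; gamma is an assumption (Paddle's default of 0.1), since the config does not set one:

```
import paddle

# Minimal sketch of the scheduler described by the lr_scheduler section.
# gamma is assumed (Paddle's MultiStepDecay default is 0.1); the config
# only supplies the learning rates and the milestones.
gen_lr = paddle.optimizer.lr.MultiStepDecay(
    learning_rate=2.0e-4,          # lr_generator
    milestones=[237360, 356040],   # epoch_milestones
    gamma=0.1)
```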
@@ -34,7 +34,9 @@ python -u tools/first-order-demo.py \
     --ratio 0.4 \
     --relative --adapt_scale
 ```
-- multi face
+- multi face:
+```
 cd applications/
 python -u tools/first-order-demo.py \
     --driving_video ../docs/imgs/fom_dv.mp4 \
@@ -43,6 +45,7 @@ python -u tools/first-order-demo.py \
     --relative --adapt_scale \
     --multi_person
+```
 **params:**
 - driving_video: the driving video; its motion will be transferred to the source image.
 - source_image: the source image, supporting both single-person and multi-person images; it will be animated according to the motion of the driving video.
......
@@ -41,7 +41,8 @@ python -u tools/first-order-demo.py \
     --ratio 0.4 \
     --relative --adapt_scale
 ```
-- multi face
+- multi face:
+```
 cd applications/
 python -u tools/first-order-demo.py \
     --driving_video ../docs/imgs/fom_dv.mp4 \
......
@@ -163,7 +163,8 @@ class FirstOrderPredictor(BasePredictor):
             imageio.mimsave(os.path.join(self.output, self.filename), [
                 cv2.resize((frame * 255.0).astype('uint8'), (h, w))
                 for frame in predictions
-            ])
+            ],
+                            fps=fps)
             return
         bboxes = self.extract_bbox(source_image.copy())
@@ -175,7 +176,8 @@ class FirstOrderPredictor(BasePredictor):
             imageio.mimsave(os.path.join(self.output, self.filename), [
                 cv2.resize((frame * 255.0).astype('uint8'), (h, w))
                 for frame in predictions
-            ])
+            ],
+                            fps=fps)
             return
         # for multi person
......
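
Both hunks above add the missing fps argument so the saved video keeps the driving video's frame rate instead of imageio's default; this is the "fix fom fps" item in the commit message. A minimal sketch of the call using imageio's v2 API; frames and fps are illustrative stand-ins, not the predictor's actual locals:

```
import imageio
import numpy as np

# Illustrative only: 30 black 256x256 frames written at an explicit rate.
# Without fps, the ffmpeg writer falls back to its default and the output
# plays at the wrong speed.
frames = [np.zeros((256, 256, 3), dtype='uint8') for _ in range(30)]
fps = 25  # in the predictor this is read from the driving video
imageio.mimsave('result.mp4', frames, fps=fps)
```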
@@ -104,7 +104,9 @@ def read_video(name: Path, frame_shape=tuple([256, 256, 3]), saveto='folder'):
     if name.is_dir():
         frames = sorted(name.iterdir(),
                         key=lambda x: int(x.with_suffix('').name))
-        video_array = np.array([imread(path) for path in frames])
+        video_array = np.array([imread(path) for path in frames],
+                               dtype='float32')
+        return video_array
     elif name.suffix.lower() in ['.gif', '.mp4', '.mov']:
         try:
             video = mimread(name, memtest=False)
@@ -135,9 +137,9 @@ def read_video(name: Path, frame_shape=tuple([256, 256, 3]), saveto='folder'):
         for idx, img in enumerate(video_array_reshape):
             cv2.imwrite(sub_dir.joinpath('%i.png' % idx), img)
         name.unlink()
-        return video_array_reshape
     else:
         raise Exception("Unknown dataset file extensions %s" % name)
+    return video_array_reshape

 class FramesDataset(Dataset):
......
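
The hunk above fixes read_video's control flow: the directory branch now returns its own float32 array directly, and the final return is moved out of the branch chain so every successful path returns a value. A condensed sketch of the corrected directory branch; read_frames_dir is a hypothetical name for illustration, not the repo's function:

```
import numpy as np
from imageio import imread

# Condensed sketch of the corrected directory branch: frames from a folder
# are stacked as float32 in one step (so later code gets float32 frames
# without an skimage dependency) and the branch returns early rather than
# falling through to the video-file path.
def read_frames_dir(frame_paths):  # hypothetical helper for illustration
    video_array = np.array([imread(p) for p in frame_paths], dtype='float32')
    return video_array
```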
@@ -64,27 +64,13 @@ class FirstOrderModel(BaseModel):
         generator_cfg.update({'train_params': train_params})
         generator_cfg.update(
             {'dis_scales': discriminator.discriminator_cfg.scales})
-        self.Gen_Full = build_generator(generator_cfg)
+        self.nets['Gen_Full'] = build_generator(generator_cfg)
         discriminator_cfg = discriminator
         discriminator_cfg.update({'common_params': common_params})
         discriminator_cfg.update({'train_params': train_params})
-        self.Dis = build_discriminator(discriminator_cfg)
+        self.nets['Dis'] = build_discriminator(discriminator_cfg)
         self.visualizer = Visualizer()
-        if isinstance(self.Gen_Full, paddle.DataParallel):
-            self.nets['kp_detector'] = self.Gen_Full._layers.kp_extractor
-            self.nets['generator'] = self.Gen_Full._layers.generator
-            self.nets['discriminator'] = self.Dis._layers.discriminator
-        else:
-            self.nets['kp_detector'] = self.Gen_Full.kp_extractor
-            self.nets['generator'] = self.Gen_Full.generator
-            self.nets['discriminator'] = self.Dis.discriminator
-        # init params
-        init_weight(self.nets['kp_detector'])
-        init_weight(self.nets['generator'])
-        init_weight(self.nets['discriminator'])

     def setup_lr_schedulers(self, lr_cfg):
         self.kp_lr = MultiStepDecay(learning_rate=lr_cfg['lr_kp_detector'],
                                     milestones=lr_cfg['epoch_milestones'],
@@ -102,6 +88,21 @@ class FirstOrderModel(BaseModel):
         }

     def setup_optimizers(self, lr_cfg, optimizer):
+        if isinstance(self.nets['Gen_Full'], paddle.DataParallel):
+            self.nets['kp_detector'] = self.nets[
+                'Gen_Full']._layers.kp_extractor
+            self.nets['generator'] = self.nets['Gen_Full']._layers.generator
+            self.nets['discriminator'] = self.nets['Dis']._layers.discriminator
+        else:
+            self.nets['kp_detector'] = self.nets['Gen_Full'].kp_extractor
+            self.nets['generator'] = self.nets['Gen_Full'].generator
+            self.nets['discriminator'] = self.nets['Dis'].discriminator
+        # init params
+        init_weight(self.nets['kp_detector'])
+        init_weight(self.nets['generator'])
+        init_weight(self.nets['discriminator'])

         # define loss functions
         self.losses = {}
@@ -124,7 +125,7 @@ class FirstOrderModel(BaseModel):
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
         self.losses_generator, self.generated = \
-            self.Gen_Full(self.input_data.copy(), self.nets['discriminator'])
+            self.nets['Gen_Full'](self.input_data.copy(), self.nets['discriminator'])
         self.visual_items['driving_source_gen'] = self.visualizer.visualize(
             self.input_data['driving'].detach(),
             self.input_data['source'].detach(), self.generated)
@@ -136,7 +137,8 @@ class FirstOrderModel(BaseModel):
         loss.backward()

     def backward_D(self):
-        losses_discriminator = self.Dis(self.input_data.copy(), self.generated)
+        losses_discriminator = self.nets['Dis'](self.input_data.copy(),
+                                                self.generated)
         loss_values = [val.mean() for val in losses_discriminator.values()]
         loss = paddle.add_n(loss_values)
         loss.backward()
......
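
The model hunk above registers Gen_Full and Dis in self.nets and defers the sub-network lookup to setup_optimizers, by which point the nets may already be wrapped in paddle.DataParallel; the wrapped module is then reached via ._layers. This is the "fix attribute error for DataParallel" item in the commit message. A minimal standalone sketch of the unwrapping pattern:

```
import paddle
import paddle.nn as nn

# Minimal sketch: paddle.DataParallel keeps the wrapped module in ._layers,
# so sub-module access must go through it once the net is parallelized.
net = nn.Linear(4, 4)
if paddle.distributed.get_world_size() > 1:
    net = paddle.DataParallel(net)

inner = net._layers if isinstance(net, paddle.DataParallel) else net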