未验证 提交 2328f9fb 编写于 作者: L LielinJiang 提交者: GitHub

polish code (#457)

上级 7f628db5
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # code was heavily based on https://github.com/Rudrabha/Wav2Lip
# # Users should be careful about adopting these functions in any commercial matters.
# Licensed under the Apache License, Version 2.0 (the "License"); # https://github.com/Rudrabha/Wav2Lip#license-and-citation
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2 import cv2
import random import random
......
# Copyright (c) MMEditing Authors.
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.vision.models.vgg as vgg import paddle.vision.models.vgg as vgg
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) MMEditing Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle import paddle
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # MIT License
# # Copyright (c) 2018 Jason Antic
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np import numpy as np
import paddle import paddle
......
...@@ -48,8 +48,18 @@ class UGATITModel(BaseModel): ...@@ -48,8 +48,18 @@ class UGATITModel(BaseModel):
cam_weight=1000.0): cam_weight=1000.0):
"""Initialize the CycleGAN class. """Initialize the CycleGAN class.
Parameters: Args:
opt (config)-- stores all the experiment flags; needs to be a subclass of Dict generator (dict): config of generator.
discriminator_g (dict): config of discriminator_g.
discriminator_l (dict): config of discriminator_l.
l1_criterion (dict): config of l1_criterion.
mse_criterion (dict): config of mse_criterion.
bce_criterion (dict): config of bce_criterion.
direction (str): direction of dataset, default: 'a2b'.
adv_weight (float): adversarial loss weight, default: 1.0.
cycle_weight (float): cycle loss weight, default: 10.0.
identity_weight (float): identity loss weight, default: 10.0.
cam_weight (float): cam loss weight, default: 1000.0.
""" """
super(UGATITModel, self).__init__() super(UGATITModel, self).__init__()
self.adv_weight = adv_weight self.adv_weight = adv_weight
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# code was heavily based on https://github.com/AliaksandrSiarohin/first-order-model
import numpy as np import numpy as np
from scipy.spatial import ConvexHull from scipy.spatial import ConvexHull
......
# code was heavily based on https://github.com/Rudrabha/Wav2Lip
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/Rudrabha/Wav2Lip#license-and-citation
import numpy as np import numpy as np
from scipy import signal from scipy import signal
from scipy.io import wavfile from scipy.io import wavfile
......
from easydict import EasyDict as edict from .config import AttrDict
_C = edict() _audio_cfg = AttrDict()
_C.num_mels = 80 _audio_cfg.num_mels = 80
_C.rescale = True _audio_cfg.rescale = True
_C.rescaling_max = 0.9 _audio_cfg.rescaling_max = 0.9
_C.use_lws = False _audio_cfg.use_lws = False
_C.n_fft = 800 _audio_cfg.n_fft = 800
_C.hop_size = 200 _audio_cfg.hop_size = 200
_C.win_size = 800 _audio_cfg.win_size = 800
_C.sample_rate = 16000 _audio_cfg.sample_rate = 16000
_C.frame_shift_ms = None _audio_cfg.frame_shift_ms = None
_C.signal_normalization = True _audio_cfg.signal_normalization = True
_C.allow_clipping_in_normalization = True _audio_cfg.allow_clipping_in_normalization = True
_C.symmetric_mels = True _audio_cfg.symmetric_mels = True
_C.max_abs_value = 4. _audio_cfg.max_abs_value = 4.
_C.preemphasize = True _audio_cfg.preemphasize = True
_C.preemphasis = 0.97 _audio_cfg.preemphasis = 0.97
_C.min_level_db = -100 _audio_cfg.min_level_db = -100
_C.ref_level_db = 20 _audio_cfg.ref_level_db = 20
_C.fmin = 55 _audio_cfg.fmin = 55
_C.fmax = 7600 _audio_cfg.fmax = 7600
_C.fps = 25 _audio_cfg.fps = 25
def get_audio_config(): def get_audio_config():
return _C return _audio_cfg
...@@ -19,7 +19,6 @@ __all__ = ['get_config'] ...@@ -19,7 +19,6 @@ __all__ = ['get_config']
class AttrDict(dict): class AttrDict(dict):
def __getattr__(self, key): def __getattr__(self, key):
# return self[key]
try: try:
return self[key] return self[key]
except KeyError: except KeyError:
......
...@@ -21,51 +21,48 @@ class ImagePool(): ...@@ -21,51 +21,48 @@ class ImagePool():
This buffer enables us to update discriminators using a history of generated images This buffer enables us to update discriminators using a history of generated images
rather than the ones produced by the latest generators. rather than the ones produced by the latest generators.
"""
def __init__(self, pool_size):
"""Initialize the ImagePool class
Parameters: Args:
pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
""" """
def __init__(self, pool_size, prob=0.5):
self.pool_size = pool_size self.pool_size = pool_size
if self.pool_size > 0: # create an empty pool self.prob = prob
if self.pool_size > 0:
self.num_imgs = 0 self.num_imgs = 0
self.images = [] self.images = []
def query(self, images): def query(self, images):
"""Return an image from the pool. """Return an image from the pool.
Parameters: Args:
images: the latest generated images from the generator images (paddle.Tensor): the latest generated images from the generator
Returns images from the buffer. Returns images from the buffer.
By 50/100, the buffer will return input images.
By 50/100, the buffer will return images previously stored in the buffer,
and insert the current images to the buffer.
""" """
if self.pool_size == 0: # if the buffer size is 0, do nothing # if the buffer size is 0, do nothing
if self.pool_size == 0:
return images return images
return_images = [] return_images = []
for image in images: for image in images:
image = paddle.unsqueeze(image, 0) image = paddle.unsqueeze(image, 0)
if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer # if the buffer is not full; keep inserting current images to the buffer
if self.num_imgs < self.pool_size:
self.num_imgs = self.num_imgs + 1 self.num_imgs = self.num_imgs + 1
self.images.append(image) self.images.append(image)
return_images.append(image) return_images.append(image)
else: else:
p = random.uniform(0, 1) p = random.uniform(0, 1)
if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
random_id = random.randint(0, self.pool_size - if p > self.prob:
1) # randint is inclusive random_id = random.randint(0, self.pool_size - 1)
# FIXME: clone tmp = self.images[random_id].clone()
# tmp = (self.images[random_id]).detach() #.clone()
tmp = self.images[random_id] #.clone()
self.images[random_id] = image self.images[random_id] = image
return_images.append(tmp) return_images.append(tmp)
else: # by another 50% chance, the buffer will return the current image else:
# by another 50% chance, the buffer will return the current image
return_images.append(image) return_images.append(image)
return_images = paddle.concat(return_images, # collect all the images and return
0) # collect all the images and return return_images = paddle.concat(return_images, 0)
return return_images return return_images
...@@ -61,10 +61,9 @@ def setup_logger(output=None, name="ppgan"): ...@@ -61,10 +61,9 @@ def setup_logger(output=None, name="ppgan"):
if local_rank > 0: if local_rank > 0:
filename = filename + ".rank{}".format(local_rank) filename = filename + ".rank{}".format(local_rank)
# PathManager.mkdirs(os.path.dirname(filename)) # make dir if path not exist
os.makedirs(os.path.dirname(filename), exist_ok=True) os.makedirs(os.path.dirname(filename), exist_ok=True)
# fh = logging.StreamHandler(_cached_log_stream(filename)
fh = logging.FileHandler(filename, mode='a') fh = logging.FileHandler(filename, mode='a')
fh.setLevel(logging.DEBUG) fh.setLevel(logging.DEBUG)
fh.setFormatter(plain_formatter) fh.setFormatter(plain_formatter)
......
...@@ -77,6 +77,8 @@ class Registry(object): ...@@ -77,6 +77,8 @@ class Registry(object):
return ret return ret
# code was based on mmcv
# Copyright (c) OpenMMLab.
def build_from_config(cfg, registry, default_args=None): def build_from_config(cfg, registry, default_args=None):
"""Build a class from config dict. """Build a class from config dict.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册