未验证 提交 2328f9fb 编写于 作者: L LielinJiang 提交者: GitHub

polish code (#457)

上级 7f628db5
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/Rudrabha/Wav2Lip
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/Rudrabha/Wav2Lip#license-and-citation
import cv2
import random
......
# Copyright (c) MMEditing Authors.
import paddle
import paddle.nn as nn
import paddle.vision.models.vgg as vgg
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) MMEditing Authors.
import paddle
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
# Copyright (c) 2018 Jason Antic
import numpy as np
import paddle
......
......@@ -48,8 +48,18 @@ class UGATITModel(BaseModel):
cam_weight=1000.0):
"""Initialize the UGATITModel class.
Parameters:
opt (config)-- stores all the experiment flags; needs to be a subclass of Dict
Args:
generator (dict): config of generator.
discriminator_g (dict): config of discriminator_g.
discriminator_l (dict): config of discriminator_l.
l1_criterion (dict): config of l1_criterion.
mse_criterion (dict): config of mse_criterion.
bce_criterion (dict): config of bce_criterion.
direction (str): direction of dataset, default: 'a2b'.
adv_weight (float): adversarial loss weight, default: 1.0.
cycle_weight (float): cycle loss weight, default: 10.0.
identity_weight (float): identity loss weight, default: 10.0.
cam_weight (float): cam loss weight, default: 1000.0.
"""
super(UGATITModel, self).__init__()
self.adv_weight = adv_weight
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/AliaksandrSiarohin/first-order-model
import numpy as np
from scipy.spatial import ConvexHull
......
# code was heavily based on https://github.com/Rudrabha/Wav2Lip
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/Rudrabha/Wav2Lip#license-and-citation
import numpy as np
from scipy import signal
from scipy.io import wavfile
......
from easydict import EasyDict as edict
from .config import AttrDict
_C = edict()
_audio_cfg = AttrDict()
_C.num_mels = 80
_C.rescale = True
_C.rescaling_max = 0.9
_C.use_lws = False
_C.n_fft = 800
_C.hop_size = 200
_C.win_size = 800
_C.sample_rate = 16000
_C.frame_shift_ms = None
_C.signal_normalization = True
_C.allow_clipping_in_normalization = True
_C.symmetric_mels = True
_C.max_abs_value = 4.
_C.preemphasize = True
_C.preemphasis = 0.97
_C.min_level_db = -100
_C.ref_level_db = 20
_C.fmin = 55
_C.fmax = 7600
_C.fps = 25
_audio_cfg.num_mels = 80
_audio_cfg.rescale = True
_audio_cfg.rescaling_max = 0.9
_audio_cfg.use_lws = False
_audio_cfg.n_fft = 800
_audio_cfg.hop_size = 200
_audio_cfg.win_size = 800
_audio_cfg.sample_rate = 16000
_audio_cfg.frame_shift_ms = None
_audio_cfg.signal_normalization = True
_audio_cfg.allow_clipping_in_normalization = True
_audio_cfg.symmetric_mels = True
_audio_cfg.max_abs_value = 4.
_audio_cfg.preemphasize = True
_audio_cfg.preemphasis = 0.97
_audio_cfg.min_level_db = -100
_audio_cfg.ref_level_db = 20
_audio_cfg.fmin = 55
_audio_cfg.fmax = 7600
_audio_cfg.fps = 25
def get_audio_config():
return _C
return _audio_cfg
......@@ -19,7 +19,6 @@ __all__ = ['get_config']
class AttrDict(dict):
def __getattr__(self, key):
# return self[key]
try:
return self[key]
except KeyError:
......
......@@ -21,51 +21,48 @@ class ImagePool():
This buffer enables us to update discriminators using a history of generated images
rather than the ones produced by the latest generators.
"""
def __init__(self, pool_size):
"""Initialize the ImagePool class
Parameters:
pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
"""
Args:
pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
"""
def __init__(self, pool_size, prob=0.5):
self.pool_size = pool_size
if self.pool_size > 0: # create an empty pool
self.prob = prob
if self.pool_size > 0:
self.num_imgs = 0
self.images = []
def query(self, images):
"""Return an image from the pool.
Parameters:
images: the latest generated images from the generator
Args:
images (paddle.Tensor): the latest generated images from the generator
Returns images from the buffer.
With 50% probability, the buffer returns the input images unchanged.
With the other 50% probability, it returns images previously stored in the buffer
and inserts the current images into the buffer in their place.
"""
if self.pool_size == 0: # if the buffer size is 0, do nothing
# if the buffer size is 0, do nothing
if self.pool_size == 0:
return images
return_images = []
for image in images:
image = paddle.unsqueeze(image, 0)
if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
# if the buffer is not full; keep inserting current images to the buffer
if self.num_imgs < self.pool_size:
self.num_imgs = self.num_imgs + 1
self.images.append(image)
return_images.append(image)
else:
p = random.uniform(0, 1)
if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
random_id = random.randint(0, self.pool_size -
1) # randint is inclusive
# FIXME: clone
# tmp = (self.images[random_id]).detach() #.clone()
tmp = self.images[random_id] #.clone()
# by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
if p > self.prob:
random_id = random.randint(0, self.pool_size - 1)
tmp = self.images[random_id].clone()
self.images[random_id] = image
return_images.append(tmp)
else: # by another 50% chance, the buffer will return the current image
else:
# by another 50% chance, the buffer will return the current image
return_images.append(image)
return_images = paddle.concat(return_images,
0) # collect all the images and return
# collect all the images and return
return_images = paddle.concat(return_images, 0)
return return_images
......@@ -61,10 +61,9 @@ def setup_logger(output=None, name="ppgan"):
if local_rank > 0:
filename = filename + ".rank{}".format(local_rank)
# PathManager.mkdirs(os.path.dirname(filename))
# make dir if path not exist
os.makedirs(os.path.dirname(filename), exist_ok=True)
# fh = logging.StreamHandler(_cached_log_stream(filename)
fh = logging.FileHandler(filename, mode='a')
fh.setLevel(logging.DEBUG)
fh.setFormatter(plain_formatter)
......
......@@ -77,6 +77,8 @@ class Registry(object):
return ret
# code was based on mmcv
# Copyright (c) OpenMMLab.
def build_from_config(cfg, registry, default_args=None):
"""Build a class from config dict.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册