未验证 提交 43ffc949 编写于 作者: L lijianshe02 提交者: GitHub

add automatic weight download (#146)

* add automatic weight download
上级 edd62113
...@@ -10,17 +10,16 @@ parser = argparse.ArgumentParser( ...@@ -10,17 +10,16 @@ parser = argparse.ArgumentParser(
parser.add_argument('--checkpoint_path', parser.add_argument('--checkpoint_path',
type=str, type=str,
help='Name of saved checkpoint to load weights from', help='Name of saved checkpoint to load weights from',
required=True) default=None)
parser.add_argument('--face',
type=str,
help='Filepath of video/image that contains faces to use',
required=True)
parser.add_argument( parser.add_argument(
'--audio', '--audio',
type=str, type=str,
help='Filepath of video/audio file to use as raw audio source', help='Filepath of video/audio file to use as raw audio source',
required=True) required=True)
parser.add_argument('--face',
type=str,
help='Filepath of video/image that contains faces to use',
required=True)
parser.add_argument('--outfile', parser.add_argument('--outfile',
type=str, type=str,
help='Video path to save result. See default for an e.g.', help='Video path to save result. See default for an e.g.',
......
...@@ -6,11 +6,13 @@ import json, subprocess, random, string ...@@ -6,11 +6,13 @@ import json, subprocess, random, string
from tqdm import tqdm from tqdm import tqdm
from glob import glob from glob import glob
import paddle import paddle
from paddle.utils.download import get_weights_path_from_url
from ppgan.faceutils import face_detection from ppgan.faceutils import face_detection
from ppgan.utils import audio from ppgan.utils import audio
from ppgan.models.generators.wav2lip import Wav2Lip from ppgan.models.generators.wav2lip import Wav2Lip
from .base_predictor import BasePredictor from .base_predictor import BasePredictor
WAV2LIP_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/wav2lip_hq.pdparams'
mel_step_size = 16 mel_step_size = 16
...@@ -216,7 +218,11 @@ class Wav2LipPredictor(BasePredictor): ...@@ -216,7 +218,11 @@ class Wav2LipPredictor(BasePredictor):
gen = self.datagen(full_frames.copy(), mel_chunks) gen = self.datagen(full_frames.copy(), mel_chunks)
model = Wav2Lip() model = Wav2Lip()
weights = paddle.load(self.args.checkpoint_path) if self.args.checkpoint_path is None:
model_weights_path = get_weights_path_from_url(WAV2LIP_WEIGHT_URL)
weights = paddle.load(model_weights_path)
else:
weights = paddle.load(self.args.checkpoint_path)
model.load_dict(weights) model.load_dict(weights)
model.eval() model.eval()
print("Model loaded") print("Model loaded")
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.utils.download import get_weights_path_from_url
from .base_model import BaseModel from .base_model import BaseModel
from .builder import MODELS from .builder import MODELS
...@@ -25,7 +26,7 @@ from .wav2lip_model import cosine_loss, get_sync_loss ...@@ -25,7 +26,7 @@ from .wav2lip_model import cosine_loss, get_sync_loss
from ..solver import build_optimizer from ..solver import build_optimizer
from ..modules.init import init_weights from ..modules.init import init_weights
lipsync_weight_path = '/workspace/PaddleGAN/lipsync_expert.pdparams' SYNCNET_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/syncnet.pdparams'
@MODELS.register() @MODELS.register()
...@@ -65,7 +66,8 @@ class Wav2LipModelHq(BaseModel): ...@@ -65,7 +66,8 @@ class Wav2LipModelHq(BaseModel):
distribution='uniform') distribution='uniform')
if self.is_train: if self.is_train:
self.nets['netDS'] = build_discriminator(discriminator_sync) self.nets['netDS'] = build_discriminator(discriminator_sync)
params = paddle.load(lipsync_weight_path) weights_path = get_weights_path_from_url(SYNCNET_WEIGHT_URL)
params = paddle.load(weights_path)
self.nets['netDS'].load_dict(params) self.nets['netDS'].load_dict(params)
self.nets['netDH'] = build_discriminator(discriminator_hq) self.nets['netDH'] = build_discriminator(discriminator_hq)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import paddle import paddle
from paddle.utils.download import get_weights_path_from_url
from .base_model import BaseModel from .base_model import BaseModel
from .builder import MODELS from .builder import MODELS
...@@ -22,6 +23,7 @@ from .discriminators.builder import build_discriminator ...@@ -22,6 +23,7 @@ from .discriminators.builder import build_discriminator
from ..solver import build_optimizer from ..solver import build_optimizer
from ..modules.init import init_weights from ..modules.init import init_weights
SYNCNET_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/syncnet.pdparams'
syncnet_T = 5 syncnet_T = 5
syncnet_mel_step_size = 16 syncnet_mel_step_size = 16
...@@ -74,7 +76,8 @@ class Wav2LipModel(BaseModel): ...@@ -74,7 +76,8 @@ class Wav2LipModel(BaseModel):
init_weights(self.nets['netG'], distribution='uniform') init_weights(self.nets['netG'], distribution='uniform')
if self.is_train: if self.is_train:
self.nets['netD'] = build_discriminator(discriminator) self.nets['netD'] = build_discriminator(discriminator)
params = paddle.load(lipsync_weight_path) weights_path = get_weights_path_from_url(SYNCNET_WEIGHT_URL)
params = paddle.load(weights_path)
self.nets['netD'].load_dict(params) self.nets['netD'].load_dict(params)
if self.is_train: if self.is_train:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册