Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
43ffc949
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
大约 1 年 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
43ffc949
编写于
1月 18, 2021
作者:
L
lijianshe02
提交者:
GitHub
1月 18, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add automatic weight download (#146)
* add automatic weight download
上级
edd62113
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
20 addition
and
10 deletion
+20
-10
applications/tools/wav2lip.py
applications/tools/wav2lip.py
+5
-6
ppgan/apps/wav2lip_predictor.py
ppgan/apps/wav2lip_predictor.py
+7
-1
ppgan/models/wav2lip_hq_model.py
ppgan/models/wav2lip_hq_model.py
+4
-2
ppgan/models/wav2lip_model.py
ppgan/models/wav2lip_model.py
+4
-1
未找到文件。
applications/tools/wav2lip.py
浏览文件 @
43ffc949
...
@@ -10,17 +10,16 @@ parser = argparse.ArgumentParser(
...
@@ -10,17 +10,16 @@ parser = argparse.ArgumentParser(
parser
.
add_argument
(
'--checkpoint_path'
,
parser
.
add_argument
(
'--checkpoint_path'
,
type
=
str
,
type
=
str
,
help
=
'Name of saved checkpoint to load weights from'
,
help
=
'Name of saved checkpoint to load weights from'
,
required
=
True
)
default
=
None
)
parser
.
add_argument
(
'--face'
,
type
=
str
,
help
=
'Filepath of video/image that contains faces to use'
,
required
=
True
)
parser
.
add_argument
(
parser
.
add_argument
(
'--audio'
,
'--audio'
,
type
=
str
,
type
=
str
,
help
=
'Filepath of video/audio file to use as raw audio source'
,
help
=
'Filepath of video/audio file to use as raw audio source'
,
required
=
True
)
required
=
True
)
parser
.
add_argument
(
'--face'
,
type
=
str
,
help
=
'Filepath of video/image that contains faces to use'
,
required
=
True
)
parser
.
add_argument
(
'--outfile'
,
parser
.
add_argument
(
'--outfile'
,
type
=
str
,
type
=
str
,
help
=
'Video path to save result. See default for an e.g.'
,
help
=
'Video path to save result. See default for an e.g.'
,
...
...
ppgan/apps/wav2lip_predictor.py
浏览文件 @
43ffc949
...
@@ -6,11 +6,13 @@ import json, subprocess, random, string
...
@@ -6,11 +6,13 @@ import json, subprocess, random, string
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
glob
import
glob
from
glob
import
glob
import
paddle
import
paddle
from
paddle.utils.download
import
get_weights_path_from_url
from
ppgan.faceutils
import
face_detection
from
ppgan.faceutils
import
face_detection
from
ppgan.utils
import
audio
from
ppgan.utils
import
audio
from
ppgan.models.generators.wav2lip
import
Wav2Lip
from
ppgan.models.generators.wav2lip
import
Wav2Lip
from
.base_predictor
import
BasePredictor
from
.base_predictor
import
BasePredictor
WAV2LIP_WEIGHT_URL
=
'https://paddlegan.bj.bcebos.com/models/wav2lip_hq.pdparams'
mel_step_size
=
16
mel_step_size
=
16
...
@@ -216,7 +218,11 @@ class Wav2LipPredictor(BasePredictor):
...
@@ -216,7 +218,11 @@ class Wav2LipPredictor(BasePredictor):
gen
=
self
.
datagen
(
full_frames
.
copy
(),
mel_chunks
)
gen
=
self
.
datagen
(
full_frames
.
copy
(),
mel_chunks
)
model
=
Wav2Lip
()
model
=
Wav2Lip
()
weights
=
paddle
.
load
(
self
.
args
.
checkpoint_path
)
if
self
.
args
.
checkpoint_path
is
None
:
model_weights_path
=
get_weights_path_from_url
(
WAV2LIP_WEIGHT_URL
)
weights
=
paddle
.
load
(
model_weights_path
)
else
:
weights
=
paddle
.
load
(
self
.
args
.
checkpoint_path
)
model
.
load_dict
(
weights
)
model
.
load_dict
(
weights
)
model
.
eval
()
model
.
eval
()
print
(
"Model loaded"
)
print
(
"Model loaded"
)
...
...
ppgan/models/wav2lip_hq_model.py
浏览文件 @
43ffc949
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
import
paddle
import
paddle
import
paddle.nn.functional
as
F
import
paddle.nn.functional
as
F
from
paddle.utils.download
import
get_weights_path_from_url
from
.base_model
import
BaseModel
from
.base_model
import
BaseModel
from
.builder
import
MODELS
from
.builder
import
MODELS
...
@@ -25,7 +26,7 @@ from .wav2lip_model import cosine_loss, get_sync_loss
...
@@ -25,7 +26,7 @@ from .wav2lip_model import cosine_loss, get_sync_loss
from
..solver
import
build_optimizer
from
..solver
import
build_optimizer
from
..modules.init
import
init_weights
from
..modules.init
import
init_weights
lipsync_weight_path
=
'/workspace/PaddleGAN/lipsync_exper
t.pdparams'
SYNCNET_WEIGHT_URL
=
'https://paddlegan.bj.bcebos.com/models/syncne
t.pdparams'
@
MODELS
.
register
()
@
MODELS
.
register
()
...
@@ -65,7 +66,8 @@ class Wav2LipModelHq(BaseModel):
...
@@ -65,7 +66,8 @@ class Wav2LipModelHq(BaseModel):
distribution
=
'uniform'
)
distribution
=
'uniform'
)
if
self
.
is_train
:
if
self
.
is_train
:
self
.
nets
[
'netDS'
]
=
build_discriminator
(
discriminator_sync
)
self
.
nets
[
'netDS'
]
=
build_discriminator
(
discriminator_sync
)
params
=
paddle
.
load
(
lipsync_weight_path
)
weights_path
=
get_weights_path_from_url
(
SYNCNET_WEIGHT_URL
)
params
=
paddle
.
load
(
weights_path
)
self
.
nets
[
'netDS'
].
load_dict
(
params
)
self
.
nets
[
'netDS'
].
load_dict
(
params
)
self
.
nets
[
'netDH'
]
=
build_discriminator
(
discriminator_hq
)
self
.
nets
[
'netDH'
]
=
build_discriminator
(
discriminator_hq
)
...
...
ppgan/models/wav2lip_model.py
浏览文件 @
43ffc949
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
paddle
import
paddle
from
paddle.utils.download
import
get_weights_path_from_url
from
.base_model
import
BaseModel
from
.base_model
import
BaseModel
from
.builder
import
MODELS
from
.builder
import
MODELS
...
@@ -22,6 +23,7 @@ from .discriminators.builder import build_discriminator
...
@@ -22,6 +23,7 @@ from .discriminators.builder import build_discriminator
from
..solver
import
build_optimizer
from
..solver
import
build_optimizer
from
..modules.init
import
init_weights
from
..modules.init
import
init_weights
SYNCNET_WEIGHT_URL
=
'https://paddlegan.bj.bcebos.com/models/syncnet.pdparams'
syncnet_T
=
5
syncnet_T
=
5
syncnet_mel_step_size
=
16
syncnet_mel_step_size
=
16
...
@@ -74,7 +76,8 @@ class Wav2LipModel(BaseModel):
...
@@ -74,7 +76,8 @@ class Wav2LipModel(BaseModel):
init_weights
(
self
.
nets
[
'netG'
],
distribution
=
'uniform'
)
init_weights
(
self
.
nets
[
'netG'
],
distribution
=
'uniform'
)
if
self
.
is_train
:
if
self
.
is_train
:
self
.
nets
[
'netD'
]
=
build_discriminator
(
discriminator
)
self
.
nets
[
'netD'
]
=
build_discriminator
(
discriminator
)
params
=
paddle
.
load
(
lipsync_weight_path
)
weights_path
=
get_weights_path_from_url
(
SYNCNET_WEIGHT_URL
)
params
=
paddle
.
load
(
weights_path
)
self
.
nets
[
'netD'
].
load_dict
(
params
)
self
.
nets
[
'netD'
].
load_dict
(
params
)
if
self
.
is_train
:
if
self
.
is_train
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录