Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
3128fb0d
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
大约 1 年 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3128fb0d
编写于
9月 11, 2020
作者:
L
LielinJiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine print log
上级
bb71d1a4
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
98 addition
and
56 deletion
+98
-56
applications/DAIN/util.py
applications/DAIN/util.py
+7
-8
applications/DeOldify/hook.py
applications/DeOldify/hook.py
+48
-20
applications/DeOldify/predict.py
applications/DeOldify/predict.py
+21
-11
applications/EDVR/predict.py
applications/EDVR/predict.py
+7
-8
applications/RealSR/predict.py
applications/RealSR/predict.py
+7
-8
applications/tools/video-enhance.py
applications/tools/video-enhance.py
+7
-0
ppgan/engine/trainer.py
ppgan/engine/trainer.py
+1
-1
未找到文件。
applications/DAIN/util.py
浏览文件 @
3128fb0d
...
@@ -55,12 +55,12 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
...
@@ -55,12 +55,12 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
cmd
=
ffmpeg
+
[
' -i '
,
vid_path
,
' -start_number '
,
' 0 '
,
outformat
]
cmd
=
ffmpeg
+
[
' -i '
,
vid_path
,
' -start_number '
,
' 0 '
,
outformat
]
cmd
=
''
.
join
(
cmd
)
cmd
=
''
.
join
(
cmd
)
print
(
cmd
)
if
os
.
system
(
cmd
)
==
0
:
if
os
.
system
(
cmd
)
==
0
:
p
rint
(
'Video: {} done'
.
format
(
vid_name
))
p
ass
else
:
else
:
print
(
'
V
ideo: {} error'
.
format
(
vid_name
))
print
(
'
ffmpeg process v
ideo: {} error'
.
format
(
vid_name
))
print
(
''
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
return
out_full_path
return
out_full_path
...
@@ -72,13 +72,12 @@ def frames_to_video_ffmpeg(framepath, videopath, r):
...
@@ -72,13 +72,12 @@ def frames_to_video_ffmpeg(framepath, videopath, r):
' libx264 '
,
' -pix_fmt '
,
' yuv420p '
,
' -crf '
,
' 16 '
,
videopath
' libx264 '
,
' -pix_fmt '
,
' yuv420p '
,
' -crf '
,
' 16 '
,
videopath
]
]
cmd
=
''
.
join
(
cmd
)
cmd
=
''
.
join
(
cmd
)
print
(
cmd
)
if
os
.
system
(
cmd
)
==
0
:
if
os
.
system
(
cmd
)
==
0
:
p
rint
(
'Video: {} done'
.
format
(
videopath
))
p
ass
else
:
else
:
print
(
'
V
ideo: {} error'
.
format
(
videopath
))
print
(
'
ffmpeg process v
ideo: {} error'
.
format
(
videopath
))
print
(
''
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
...
...
applications/DeOldify/hook.py
浏览文件 @
3128fb0d
...
@@ -3,14 +3,16 @@ import numpy as np
...
@@ -3,14 +3,16 @@ import numpy as np
import
paddle
import
paddle
import
paddle.nn
as
nn
import
paddle.nn
as
nn
def
is_listy
(
x
):
def
is_listy
(
x
):
return
isinstance
(
x
,
(
tuple
,
list
))
return
isinstance
(
x
,
(
tuple
,
list
))
class
Hook
():
class
Hook
():
"Create a hook on `m` with `hook_func`."
"Create a hook on `m` with `hook_func`."
def
__init__
(
self
,
m
,
hook_func
,
is_forward
=
True
,
detach
=
True
):
def
__init__
(
self
,
m
,
hook_func
,
is_forward
=
True
,
detach
=
True
):
self
.
hook_func
,
self
.
detach
,
self
.
stored
=
hook_func
,
detach
,
None
self
.
hook_func
,
self
.
detach
,
self
.
stored
=
hook_func
,
detach
,
None
f
=
m
.
register_forward_post_hook
if
is_forward
else
m
.
register_backward_hook
f
=
m
.
register_forward_post_hook
if
is_forward
else
m
.
register_backward_hook
self
.
hook
=
f
(
self
.
hook_fn
)
self
.
hook
=
f
(
self
.
hook_fn
)
self
.
removed
=
False
self
.
removed
=
False
...
@@ -18,64 +20,90 @@ class Hook():
...
@@ -18,64 +20,90 @@ class Hook():
def
hook_fn
(
self
,
module
,
input
,
output
):
def
hook_fn
(
self
,
module
,
input
,
output
):
"Applies `hook_func` to `module`, `input`, `output`."
"Applies `hook_func` to `module`, `input`, `output`."
if
self
.
detach
:
if
self
.
detach
:
input
=
(
o
.
detach
()
for
o
in
input
)
if
is_listy
(
input
)
else
input
.
detach
()
input
=
(
o
.
detach
()
output
=
(
o
.
detach
()
for
o
in
output
)
if
is_listy
(
output
)
else
output
.
detach
()
for
o
in
input
)
if
is_listy
(
input
)
else
input
.
detach
()
output
=
(
o
.
detach
()
for
o
in
output
)
if
is_listy
(
output
)
else
output
.
detach
()
self
.
stored
=
self
.
hook_func
(
module
,
input
,
output
)
self
.
stored
=
self
.
hook_func
(
module
,
input
,
output
)
def
remove
(
self
):
def
remove
(
self
):
"Remove the hook from the model."
"Remove the hook from the model."
if
not
self
.
removed
:
if
not
self
.
removed
:
self
.
hook
.
remove
()
self
.
hook
.
remove
()
self
.
removed
=
True
self
.
removed
=
True
def
__enter__
(
self
,
*
args
):
return
self
def
__exit__
(
self
,
*
args
):
self
.
remove
()
def
__enter__
(
self
,
*
args
):
return
self
def
__exit__
(
self
,
*
args
):
self
.
remove
()
class
Hooks
():
class
Hooks
():
"Create several hooks on the modules in `ms` with `hook_func`."
"Create several hooks on the modules in `ms` with `hook_func`."
def
__init__
(
self
,
ms
,
hook_func
,
is_forward
=
True
,
detach
=
True
):
def
__init__
(
self
,
ms
,
hook_func
,
is_forward
=
True
,
detach
=
True
):
self
.
hooks
=
[]
self
.
hooks
=
[]
try
:
try
:
for
m
in
ms
:
for
m
in
ms
:
self
.
hooks
.
append
(
Hook
(
m
,
hook_func
,
is_forward
,
detach
))
self
.
hooks
.
append
(
Hook
(
m
,
hook_func
,
is_forward
,
detach
))
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
pass
def
__getitem__
(
self
,
i
:
int
)
->
Hook
:
return
self
.
hooks
[
i
]
def
__len__
(
self
)
->
int
:
return
len
(
self
.
hooks
)
def
__iter__
(
self
):
return
iter
(
self
.
hooks
)
def
__getitem__
(
self
,
i
:
int
)
->
Hook
:
return
self
.
hooks
[
i
]
def
__len__
(
self
)
->
int
:
return
len
(
self
.
hooks
)
def
__iter__
(
self
):
return
iter
(
self
.
hooks
)
@
property
@
property
def
stored
(
self
):
return
[
o
.
stored
for
o
in
self
]
def
stored
(
self
):
return
[
o
.
stored
for
o
in
self
]
def
remove
(
self
):
def
remove
(
self
):
"Remove the hooks from the model."
"Remove the hooks from the model."
for
h
in
self
.
hooks
:
h
.
remove
()
for
h
in
self
.
hooks
:
h
.
remove
()
def
__enter__
(
self
,
*
args
):
return
self
def
__enter__
(
self
,
*
args
):
def
__exit__
(
self
,
*
args
):
self
.
remove
()
return
self
def
_hook_inner
(
m
,
i
,
o
):
return
o
if
isinstance
(
o
,
paddle
.
framework
.
Variable
)
else
o
if
is_listy
(
o
)
else
list
(
o
)
def
__exit__
(
self
,
*
args
):
self
.
remove
()
def
hook_output
(
module
,
detach
=
True
,
grad
=
False
):
def
_hook_inner
(
m
,
i
,
o
):
return
o
if
isinstance
(
o
,
paddle
.
framework
.
Variable
)
else
o
if
is_listy
(
o
)
else
list
(
o
)
def
hook_output
(
module
,
detach
=
True
,
grad
=
False
):
"Return a `Hook` that stores activations of `module` in `self.stored`"
"Return a `Hook` that stores activations of `module` in `self.stored`"
return
Hook
(
module
,
_hook_inner
,
detach
=
detach
,
is_forward
=
not
grad
)
return
Hook
(
module
,
_hook_inner
,
detach
=
detach
,
is_forward
=
not
grad
)
def
hook_outputs
(
modules
,
detach
=
True
,
grad
=
False
):
def
hook_outputs
(
modules
,
detach
=
True
,
grad
=
False
):
"Return `Hooks` that store activations of all `modules` in `self.stored`"
"Return `Hooks` that store activations of all `modules` in `self.stored`"
return
Hooks
(
modules
,
_hook_inner
,
detach
=
detach
,
is_forward
=
not
grad
)
return
Hooks
(
modules
,
_hook_inner
,
detach
=
detach
,
is_forward
=
not
grad
)
def
model_sizes
(
m
,
size
=
(
64
,
64
)):
def
model_sizes
(
m
,
size
=
(
64
,
64
)):
"Pass a dummy input through the model `m` to get the various sizes of activations."
"Pass a dummy input through the model `m` to get the various sizes of activations."
with
hook_outputs
(
m
)
as
hooks
:
with
hook_outputs
(
m
)
as
hooks
:
x
=
dummy_eval
(
m
,
size
)
x
=
dummy_eval
(
m
,
size
)
return
[
o
.
stored
.
shape
for
o
in
hooks
]
return
[
o
.
stored
.
shape
for
o
in
hooks
]
def
dummy_eval
(
m
,
size
=
(
64
,
64
)):
def
dummy_eval
(
m
,
size
=
(
64
,
64
)):
"Pass a `dummy_batch` in evaluation mode in `m` with `size`."
"Pass a `dummy_batch` in evaluation mode in `m` with `size`."
m
.
eval
()
m
.
eval
()
return
m
(
dummy_batch
(
size
))
return
m
(
dummy_batch
(
size
))
def
dummy_batch
(
size
=
(
64
,
64
),
ch_in
=
3
):
def
dummy_batch
(
size
=
(
64
,
64
),
ch_in
=
3
):
"Create a dummy batch to go through `m` with `size`."
"Create a dummy batch to go through `m` with `size`."
arr
=
np
.
random
.
rand
(
1
,
ch_in
,
*
size
).
astype
(
'float32'
)
*
2
-
1
arr
=
np
.
random
.
rand
(
1
,
ch_in
,
*
size
).
astype
(
'float32'
)
*
2
-
1
return
paddle
.
to_tensor
(
arr
)
return
paddle
.
to_tensor
(
arr
)
applications/DeOldify/predict.py
浏览文件 @
3128fb0d
...
@@ -20,6 +20,10 @@ from paddle.utils.download import get_path_from_url
...
@@ -20,6 +20,10 @@ from paddle.utils.download import get_path_from_url
parser
=
argparse
.
ArgumentParser
(
description
=
'DeOldify'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'DeOldify'
)
parser
.
add_argument
(
'--input'
,
type
=
str
,
default
=
'none'
,
help
=
'Input video'
)
parser
.
add_argument
(
'--input'
,
type
=
str
,
default
=
'none'
,
help
=
'Input video'
)
parser
.
add_argument
(
'--output'
,
type
=
str
,
default
=
'output'
,
help
=
'output dir'
)
parser
.
add_argument
(
'--output'
,
type
=
str
,
default
=
'output'
,
help
=
'output dir'
)
parser
.
add_argument
(
'--render_factor'
,
type
=
int
,
default
=
32
,
help
=
'model inputsize=render_factor*16'
)
parser
.
add_argument
(
'--weight_path'
,
parser
.
add_argument
(
'--weight_path'
,
type
=
str
,
type
=
str
,
default
=
None
,
default
=
None
,
...
@@ -35,20 +39,25 @@ def frames_to_video_ffmpeg(framepath, videopath, r):
...
@@ -35,20 +39,25 @@ def frames_to_video_ffmpeg(framepath, videopath, r):
' libx264 '
,
' -pix_fmt '
,
' yuv420p '
,
' -crf '
,
' 16 '
,
videopath
' libx264 '
,
' -pix_fmt '
,
' yuv420p '
,
' -crf '
,
' 16 '
,
videopath
]
]
cmd
=
''
.
join
(
cmd
)
cmd
=
''
.
join
(
cmd
)
print
(
cmd
)
if
os
.
system
(
cmd
)
==
0
:
if
os
.
system
(
cmd
)
==
0
:
p
rint
(
'Video: {} done'
.
format
(
videopath
))
p
ass
else
:
else
:
print
(
'
V
ideo: {} error'
.
format
(
videopath
))
print
(
'
ffmpeg process v
ideo: {} error'
.
format
(
videopath
))
print
(
''
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
class
DeOldifyPredictor
():
class
DeOldifyPredictor
():
def
__init__
(
self
,
input
,
output
,
batch_size
=
1
,
weight_path
=
None
):
def
__init__
(
self
,
input
,
output
,
batch_size
=
1
,
weight_path
=
None
,
render_factor
=
32
):
self
.
input
=
input
self
.
input
=
input
self
.
output
=
os
.
path
.
join
(
output
,
'DeOldify'
)
self
.
output
=
os
.
path
.
join
(
output
,
'DeOldify'
)
self
.
render_factor
=
render_factor
self
.
model
=
build_model
()
self
.
model
=
build_model
()
if
weight_path
is
None
:
if
weight_path
is
None
:
weight_path
=
get_path_from_url
(
DeOldify_weight_url
,
cur_path
)
weight_path
=
get_path_from_url
(
DeOldify_weight_url
,
cur_path
)
...
@@ -93,7 +102,7 @@ class DeOldifyPredictor():
...
@@ -93,7 +102,7 @@ class DeOldifyPredictor():
def
run_single
(
self
,
img_path
):
def
run_single
(
self
,
img_path
):
ori_img
=
Image
.
open
(
img_path
).
convert
(
'LA'
).
convert
(
'RGB'
)
ori_img
=
Image
.
open
(
img_path
).
convert
(
'LA'
).
convert
(
'RGB'
)
img
=
self
.
norm
(
ori_img
)
img
=
self
.
norm
(
ori_img
,
self
.
render_factor
)
x
=
paddle
.
to_tensor
(
img
[
np
.
newaxis
,
...])
x
=
paddle
.
to_tensor
(
img
[
np
.
newaxis
,
...])
out
=
self
.
model
(
x
)
out
=
self
.
model
(
x
)
...
@@ -158,12 +167,12 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
...
@@ -158,12 +167,12 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
cmd
=
ffmpeg
+
[
' -i '
,
vid_path
,
' -start_number '
,
' 0 '
,
outformat
]
cmd
=
ffmpeg
+
[
' -i '
,
vid_path
,
' -start_number '
,
' 0 '
,
outformat
]
cmd
=
''
.
join
(
cmd
)
cmd
=
''
.
join
(
cmd
)
print
(
cmd
)
if
os
.
system
(
cmd
)
==
0
:
if
os
.
system
(
cmd
)
==
0
:
p
rint
(
'Video: {} done'
.
format
(
vid_name
))
p
ass
else
:
else
:
print
(
'
V
ideo: {} error'
.
format
(
vid_name
))
print
(
'
ffmpeg process v
ideo: {} error'
.
format
(
vid_name
))
print
(
''
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
return
out_full_path
return
out_full_path
...
@@ -174,7 +183,8 @@ if __name__ == '__main__':
...
@@ -174,7 +183,8 @@ if __name__ == '__main__':
predictor
=
DeOldifyPredictor
(
args
.
input
,
predictor
=
DeOldifyPredictor
(
args
.
input
,
args
.
output
,
args
.
output
,
weight_path
=
args
.
weight_path
)
weight_path
=
args
.
weight_path
,
render_factor
=
args
.
render_factor
)
frames_path
,
temp_video_path
=
predictor
.
run
()
frames_path
,
temp_video_path
=
predictor
.
run
()
print
(
'output video path:'
,
temp_video_path
)
print
(
'output video path:'
,
temp_video_path
)
applications/EDVR/predict.py
浏览文件 @
3128fb0d
...
@@ -91,12 +91,12 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
...
@@ -91,12 +91,12 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
cmd
=
ffmpeg
+
[
' -i '
,
vid_path
,
' -start_number '
,
' 0 '
,
outformat
]
cmd
=
ffmpeg
+
[
' -i '
,
vid_path
,
' -start_number '
,
' 0 '
,
outformat
]
cmd
=
''
.
join
(
cmd
)
cmd
=
''
.
join
(
cmd
)
print
(
cmd
)
if
os
.
system
(
cmd
)
==
0
:
if
os
.
system
(
cmd
)
==
0
:
p
rint
(
'Video: {} done'
.
format
(
vid_name
))
p
ass
else
:
else
:
print
(
'
V
ideo: {} error'
.
format
(
vid_name
))
print
(
'
ffmpeg process v
ideo: {} error'
.
format
(
vid_name
))
print
(
''
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
return
out_full_path
return
out_full_path
...
@@ -108,13 +108,12 @@ def frames_to_video_ffmpeg(framepath, videopath, r):
...
@@ -108,13 +108,12 @@ def frames_to_video_ffmpeg(framepath, videopath, r):
' libx264 '
,
' -pix_fmt '
,
' yuv420p '
,
' -crf '
,
' 16 '
,
videopath
' libx264 '
,
' -pix_fmt '
,
' yuv420p '
,
' -crf '
,
' 16 '
,
videopath
]
]
cmd
=
''
.
join
(
cmd
)
cmd
=
''
.
join
(
cmd
)
print
(
cmd
)
if
os
.
system
(
cmd
)
==
0
:
if
os
.
system
(
cmd
)
==
0
:
p
rint
(
'Video: {} done'
.
format
(
videopath
))
p
ass
else
:
else
:
print
(
'
V
ideo: {} error'
.
format
(
videopath
))
print
(
'
ffmpeg process v
ideo: {} error'
.
format
(
videopath
))
print
(
''
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
...
...
applications/RealSR/predict.py
浏览文件 @
3128fb0d
...
@@ -34,13 +34,12 @@ def frames_to_video_ffmpeg(framepath, videopath, r):
...
@@ -34,13 +34,12 @@ def frames_to_video_ffmpeg(framepath, videopath, r):
' libx264 '
,
' -pix_fmt '
,
' yuv420p '
,
' -crf '
,
' 16 '
,
videopath
' libx264 '
,
' -pix_fmt '
,
' yuv420p '
,
' -crf '
,
' 16 '
,
videopath
]
]
cmd
=
''
.
join
(
cmd
)
cmd
=
''
.
join
(
cmd
)
print
(
cmd
)
if
os
.
system
(
cmd
)
==
0
:
if
os
.
system
(
cmd
)
==
0
:
p
rint
(
'Video: {} done'
.
format
(
videopath
))
p
ass
else
:
else
:
print
(
'
V
ideo: {} error'
.
format
(
videopath
))
print
(
'
ffmpeg process v
ideo: {} error'
.
format
(
videopath
))
print
(
''
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
...
@@ -129,12 +128,12 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
...
@@ -129,12 +128,12 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
cmd
=
ffmpeg
+
[
' -i '
,
vid_path
,
' -start_number '
,
' 0 '
,
outformat
]
cmd
=
ffmpeg
+
[
' -i '
,
vid_path
,
' -start_number '
,
' 0 '
,
outformat
]
cmd
=
''
.
join
(
cmd
)
cmd
=
''
.
join
(
cmd
)
print
(
cmd
)
if
os
.
system
(
cmd
)
==
0
:
if
os
.
system
(
cmd
)
==
0
:
p
rint
(
'Video: {} done'
.
format
(
vid_name
))
p
ass
else
:
else
:
print
(
'
V
ideo: {} error'
.
format
(
vid_name
))
print
(
'
ffmpeg process v
ideo: {} error'
.
format
(
vid_name
))
print
(
''
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
return
out_full_path
return
out_full_path
...
...
applications/tools/video-enhance.py
浏览文件 @
3128fb0d
...
@@ -51,6 +51,11 @@ parser.add_argument('--mindim',
...
@@ -51,6 +51,11 @@ parser.add_argument('--mindim',
type
=
int
,
type
=
int
,
default
=
360
,
default
=
360
,
help
=
'Length of minimum image edges'
)
help
=
'Length of minimum image edges'
)
# DeOldify args
parser
.
add_argument
(
'--render_factor'
,
type
=
int
,
default
=
32
,
help
=
'model inputsize=render_factor*16'
)
#process order support model name:[DAIN, DeepRemaster, DeOldify, RealSR, EDVR]
#process order support model name:[DAIN, DeepRemaster, DeOldify, RealSR, EDVR]
parser
.
add_argument
(
'--proccess_order'
,
parser
.
add_argument
(
'--proccess_order'
,
type
=
str
,
type
=
str
,
...
@@ -65,6 +70,7 @@ if __name__ == "__main__":
...
@@ -65,6 +70,7 @@ if __name__ == "__main__":
temp_video_path
=
None
temp_video_path
=
None
for
order
in
orders
:
for
order
in
orders
:
print
(
'Model {} proccess start..'
.
format
(
order
))
if
temp_video_path
is
None
:
if
temp_video_path
is
None
:
temp_video_path
=
args
.
input
temp_video_path
=
args
.
input
if
order
==
'DAIN'
:
if
order
==
'DAIN'
:
...
@@ -106,3 +112,4 @@ if __name__ == "__main__":
...
@@ -106,3 +112,4 @@ if __name__ == "__main__":
print
(
'Model {} output frames path:'
.
format
(
order
),
frames_path
)
print
(
'Model {} output frames path:'
.
format
(
order
),
frames_path
)
print
(
'Model {} output video path:'
.
format
(
order
),
temp_video_path
)
print
(
'Model {} output video path:'
.
format
(
order
),
temp_video_path
)
print
(
'Model {} proccess done!'
.
format
(
order
))
ppgan/engine/trainer.py
浏览文件 @
3128fb0d
...
@@ -47,7 +47,7 @@ class Trainer:
...
@@ -47,7 +47,7 @@ class Trainer:
self
.
time_count
=
{}
self
.
time_count
=
{}
def
distributed_data_parallel
(
self
):
def
distributed_data_parallel
(
self
):
strategy
=
paddle
.
prepare_context
()
strategy
=
paddle
.
distributed
.
prepare_context
()
for
name
in
self
.
model
.
model_names
:
for
name
in
self
.
model
.
model_names
:
if
isinstance
(
name
,
str
):
if
isinstance
(
name
,
str
):
net
=
getattr
(
self
.
model
,
'net'
+
name
)
net
=
getattr
(
self
.
model
,
'net'
+
name
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录