Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
b306aa73
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
1 年多 前同步成功
通知
98
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b306aa73
编写于
9月 03, 2020
作者:
L
LielinJiang
提交者:
GitHub
9月 03, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #16 from LielinJiang/adapt-to-2.0-api
Adapt to api 2.0 again
上级
4a3ba224
abd3250d
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
466 addition
and
397 deletion
+466
-397
applications/DAIN/predict.py
applications/DAIN/predict.py
+3
-3
applications/DeOldify/predict.py
applications/DeOldify/predict.py
+27
-34
applications/DeepRemaster/predict.py
applications/DeepRemaster/predict.py
+222
-182
applications/EDVR/predict.py
applications/EDVR/predict.py
+41
-46
configs/cyclegan_cityscapes.yaml
configs/cyclegan_cityscapes.yaml
+6
-6
configs/cyclegan_horse2zebra.yaml
configs/cyclegan_horse2zebra.yaml
+6
-6
configs/pix2pix_cityscapes.yaml
configs/pix2pix_cityscapes.yaml
+7
-7
configs/pix2pix_cityscapes_2gpus.yaml
configs/pix2pix_cityscapes_2gpus.yaml
+6
-6
configs/pix2pix_facades.yaml
configs/pix2pix_facades.yaml
+6
-5
ppgan/datasets/base_dataset.py
ppgan/datasets/base_dataset.py
+10
-6
ppgan/datasets/builder.py
ppgan/datasets/builder.py
+18
-20
ppgan/engine/trainer.py
ppgan/engine/trainer.py
+27
-21
ppgan/models/base_model.py
ppgan/models/base_model.py
+10
-5
ppgan/models/cycle_gan_model.py
ppgan/models/cycle_gan_model.py
+33
-21
ppgan/models/pix2pix_model.py
ppgan/models/pix2pix_model.py
+17
-12
ppgan/solver/lr_scheduler.py
ppgan/solver/lr_scheduler.py
+16
-5
ppgan/solver/optimizer.py
ppgan/solver/optimizer.py
+4
-6
ppgan/utils/logger.py
ppgan/utils/logger.py
+4
-4
ppgan/utils/setup.py
ppgan/utils/setup.py
+3
-2
未找到文件。
applications/DAIN/predict.py
浏览文件 @
b306aa73
...
@@ -11,7 +11,7 @@ from imageio import imread, imsave
...
@@ -11,7 +11,7 @@ from imageio import imread, imsave
import
cv2
import
cv2
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.
incubate.hapi
.download
import
get_path_from_url
from
paddle.
utils
.download
import
get_path_from_url
import
networks
import
networks
from
util
import
*
from
util
import
*
...
@@ -19,6 +19,7 @@ from my_args import parser
...
@@ -19,6 +19,7 @@ from my_args import parser
DAIN_WEIGHT_URL
=
'https://paddlegan.bj.bcebos.com/applications/DAIN_weight.tar'
DAIN_WEIGHT_URL
=
'https://paddlegan.bj.bcebos.com/applications/DAIN_weight.tar'
def
infer_engine
(
model_dir
,
def
infer_engine
(
model_dir
,
run_mode
=
'fluid'
,
run_mode
=
'fluid'
,
batch_size
=
1
,
batch_size
=
1
,
...
@@ -91,7 +92,6 @@ class VideoFrameInterp(object):
...
@@ -91,7 +92,6 @@ class VideoFrameInterp(object):
self
.
exe
,
self
.
program
,
self
.
fetch_targets
=
executor
(
model_path
,
self
.
exe
,
self
.
program
,
self
.
fetch_targets
=
executor
(
model_path
,
use_gpu
=
use_gpu
)
use_gpu
=
use_gpu
)
def
run
(
self
):
def
run
(
self
):
frame_path_input
=
os
.
path
.
join
(
self
.
output_path
,
'frames-input'
)
frame_path_input
=
os
.
path
.
join
(
self
.
output_path
,
'frames-input'
)
frame_path_interpolated
=
os
.
path
.
join
(
self
.
output_path
,
frame_path_interpolated
=
os
.
path
.
join
(
self
.
output_path
,
...
@@ -272,7 +272,7 @@ class VideoFrameInterp(object):
...
@@ -272,7 +272,7 @@ class VideoFrameInterp(object):
os
.
remove
(
video_pattern_output
)
os
.
remove
(
video_pattern_output
)
frames_to_video_ffmpeg
(
frame_pattern_combined
,
video_pattern_output
,
frames_to_video_ffmpeg
(
frame_pattern_combined
,
video_pattern_output
,
r2
)
r2
)
return
frame_pattern_combined
,
video_pattern_output
return
frame_pattern_combined
,
video_pattern_output
...
...
applications/DeOldify/predict.py
浏览文件 @
b306aa73
...
@@ -15,15 +15,19 @@ from PIL import Image
...
@@ -15,15 +15,19 @@ from PIL import Image
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
paddle
import
fluid
from
paddle
import
fluid
from
model
import
build_model
from
model
import
build_model
from
paddle.
incubate.hapi
.download
import
get_path_from_url
from
paddle.
utils
.download
import
get_path_from_url
parser
=
argparse
.
ArgumentParser
(
description
=
'DeOldify'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'DeOldify'
)
parser
.
add_argument
(
'--input'
,
type
=
str
,
default
=
'none'
,
help
=
'Input video'
)
parser
.
add_argument
(
'--input'
,
type
=
str
,
default
=
'none'
,
help
=
'Input video'
)
parser
.
add_argument
(
'--output'
,
type
=
str
,
default
=
'output'
,
help
=
'output dir'
)
parser
.
add_argument
(
'--output'
,
type
=
str
,
default
=
'output'
,
help
=
'output dir'
)
parser
.
add_argument
(
'--weight_path'
,
type
=
str
,
default
=
'none'
,
help
=
'Path to the reference image directory'
)
parser
.
add_argument
(
'--weight_path'
,
type
=
str
,
default
=
'none'
,
help
=
'Path to the reference image directory'
)
DeOldify_weight_url
=
'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
DeOldify_weight_url
=
'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
def
frames_to_video_ffmpeg
(
framepath
,
videopath
,
r
):
def
frames_to_video_ffmpeg
(
framepath
,
videopath
,
r
):
ffmpeg
=
[
'ffmpeg '
,
' -loglevel '
,
' error '
]
ffmpeg
=
[
'ffmpeg '
,
' -loglevel '
,
' error '
]
cmd
=
ffmpeg
+
[
cmd
=
ffmpeg
+
[
...
@@ -56,9 +60,9 @@ class DeOldifyPredictor():
...
@@ -56,9 +60,9 @@ class DeOldifyPredictor():
def
norm
(
self
,
img
,
render_factor
=
32
,
render_base
=
16
):
def
norm
(
self
,
img
,
render_factor
=
32
,
render_base
=
16
):
target_size
=
render_factor
*
render_base
target_size
=
render_factor
*
render_base
img
=
img
.
resize
((
target_size
,
target_size
),
resample
=
Image
.
BILINEAR
)
img
=
img
.
resize
((
target_size
,
target_size
),
resample
=
Image
.
BILINEAR
)
img
=
np
.
array
(
img
).
transpose
([
2
,
0
,
1
]).
astype
(
'float32'
)
/
255.0
img
=
np
.
array
(
img
).
transpose
([
2
,
0
,
1
]).
astype
(
'float32'
)
/
255.0
img_mean
=
np
.
array
([
0.485
,
0.456
,
0.406
]).
reshape
((
3
,
1
,
1
))
img_mean
=
np
.
array
([
0.485
,
0.456
,
0.406
]).
reshape
((
3
,
1
,
1
))
img_std
=
np
.
array
([
0.229
,
0.224
,
0.225
]).
reshape
((
3
,
1
,
1
))
img_std
=
np
.
array
([
0.229
,
0.224
,
0.225
]).
reshape
((
3
,
1
,
1
))
...
@@ -69,13 +73,13 @@ class DeOldifyPredictor():
...
@@ -69,13 +73,13 @@ class DeOldifyPredictor():
def
denorm
(
self
,
img
):
def
denorm
(
self
,
img
):
img_mean
=
np
.
array
([
0.485
,
0.456
,
0.406
]).
reshape
((
3
,
1
,
1
))
img_mean
=
np
.
array
([
0.485
,
0.456
,
0.406
]).
reshape
((
3
,
1
,
1
))
img_std
=
np
.
array
([
0.229
,
0.224
,
0.225
]).
reshape
((
3
,
1
,
1
))
img_std
=
np
.
array
([
0.229
,
0.224
,
0.225
]).
reshape
((
3
,
1
,
1
))
img
*=
img_std
img
*=
img_std
img
+=
img_mean
img
+=
img_mean
img
=
img
.
transpose
((
1
,
2
,
0
))
img
=
img
.
transpose
((
1
,
2
,
0
))
return
(
img
*
255
).
clip
(
0
,
255
).
astype
(
'uint8'
)
return
(
img
*
255
).
clip
(
0
,
255
).
astype
(
'uint8'
)
def
post_process
(
self
,
raw_color
,
orig
):
def
post_process
(
self
,
raw_color
,
orig
):
color_np
=
np
.
asarray
(
raw_color
)
color_np
=
np
.
asarray
(
raw_color
)
orig_np
=
np
.
asarray
(
orig
)
orig_np
=
np
.
asarray
(
orig
)
...
@@ -86,11 +90,11 @@ class DeOldifyPredictor():
...
@@ -86,11 +90,11 @@ class DeOldifyPredictor():
final
=
cv2
.
cvtColor
(
hires
,
cv2
.
COLOR_YUV2BGR
)
final
=
cv2
.
cvtColor
(
hires
,
cv2
.
COLOR_YUV2BGR
)
final
=
Image
.
fromarray
(
final
)
final
=
Image
.
fromarray
(
final
)
return
final
return
final
def
run_single
(
self
,
img_path
):
def
run_single
(
self
,
img_path
):
ori_img
=
Image
.
open
(
img_path
).
convert
(
'LA'
).
convert
(
'RGB'
)
ori_img
=
Image
.
open
(
img_path
).
convert
(
'LA'
).
convert
(
'RGB'
)
img
=
self
.
norm
(
ori_img
)
img
=
self
.
norm
(
ori_img
)
x
=
paddle
.
to_tensor
(
img
[
np
.
newaxis
,...])
x
=
paddle
.
to_tensor
(
img
[
np
.
newaxis
,
...])
out
=
self
.
model
(
x
)
out
=
self
.
model
(
x
)
pred_img
=
self
.
denorm
(
out
.
numpy
()[
0
])
pred_img
=
self
.
denorm
(
out
.
numpy
()[
0
])
...
@@ -118,20 +122,20 @@ class DeOldifyPredictor():
...
@@ -118,20 +122,20 @@ class DeOldifyPredictor():
frames
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
out_path
,
'*.png'
)))
frames
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
out_path
,
'*.png'
)))
for
frame
in
tqdm
(
frames
):
for
frame
in
tqdm
(
frames
):
pred_img
=
self
.
run_single
(
frame
)
pred_img
=
self
.
run_single
(
frame
)
frame_name
=
os
.
path
.
basename
(
frame
)
frame_name
=
os
.
path
.
basename
(
frame
)
pred_img
.
save
(
os
.
path
.
join
(
pred_frame_path
,
frame_name
))
pred_img
.
save
(
os
.
path
.
join
(
pred_frame_path
,
frame_name
))
frame_pattern_combined
=
os
.
path
.
join
(
pred_frame_path
,
'%08d.png'
)
frame_pattern_combined
=
os
.
path
.
join
(
pred_frame_path
,
'%08d.png'
)
vid_out_path
=
os
.
path
.
join
(
output_path
,
'{}_deoldify_out.mp4'
.
format
(
base_name
))
frames_to_video_ffmpeg
(
frame_pattern_combined
,
vid_out_path
,
str
(
int
(
fps
)))
return
frame_pattern_combined
,
vid_out_path
vid_out_path
=
os
.
path
.
join
(
output_path
,
'{}_deoldify_out.mp4'
.
format
(
base_name
))
frames_to_video_ffmpeg
(
frame_pattern_combined
,
vid_out_path
,
str
(
int
(
fps
)))
return
frame_pattern_combined
,
vid_out_path
def
dump_frames_ffmpeg
(
vid_path
,
outpath
,
r
=
None
,
ss
=
None
,
t
=
None
):
def
dump_frames_ffmpeg
(
vid_path
,
outpath
,
r
=
None
,
ss
=
None
,
t
=
None
):
...
@@ -147,21 +151,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
...
@@ -147,21 +151,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
if
ss
is
not
None
and
t
is
not
None
and
r
is
not
None
:
if
ss
is
not
None
and
t
is
not
None
and
r
is
not
None
:
cmd
=
ffmpeg
+
[
cmd
=
ffmpeg
+
[
' -ss '
,
' -ss '
,
ss
,
' -t '
,
t
,
' -i '
,
vid_path
,
' -r '
,
r
,
' -qscale:v '
,
ss
,
' 0.1 '
,
' -start_number '
,
' 0 '
,
outformat
' -t '
,
t
,
' -i '
,
vid_path
,
' -r '
,
r
,
' -qscale:v '
,
' 0.1 '
,
' -start_number '
,
' 0 '
,
outformat
]
]
else
:
else
:
cmd
=
ffmpeg
+
[
' -i '
,
vid_path
,
' -start_number '
,
' 0 '
,
outformat
]
cmd
=
ffmpeg
+
[
' -i '
,
vid_path
,
' -start_number '
,
' 0 '
,
outformat
]
...
@@ -177,11 +168,13 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
...
@@ -177,11 +168,13 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
return
out_full_path
return
out_full_path
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
paddle
.
enable_imperative
()
paddle
.
enable_imperative
()
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
predictor
=
DeOldifyPredictor
(
args
.
input
,
args
.
output
,
weight_path
=
args
.
weight_path
)
predictor
=
DeOldifyPredictor
(
args
.
input
,
args
.
output
,
weight_path
=
args
.
weight_path
)
frames_path
,
temp_video_path
=
predictor
.
run
()
frames_path
,
temp_video_path
=
predictor
.
run
()
print
(
'output video path:'
,
temp_video_path
)
print
(
'output video path:'
,
temp_video_path
)
\ No newline at end of file
applications/DeepRemaster/predict.py
浏览文件 @
b306aa73
...
@@ -15,195 +15,235 @@ import argparse
...
@@ -15,195 +15,235 @@ import argparse
import
subprocess
import
subprocess
import
utils
import
utils
from
remasternet
import
NetworkR
,
NetworkC
from
remasternet
import
NetworkR
,
NetworkC
from
paddle.
incubate.hapi
.download
import
get_path_from_url
from
paddle.
utils
.download
import
get_path_from_url
DeepRemaster_weight_url
=
'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams'
DeepRemaster_weight_url
=
'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams'
parser
=
argparse
.
ArgumentParser
(
description
=
'Remastering'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Remastering'
)
parser
.
add_argument
(
'--input'
,
type
=
str
,
default
=
None
,
help
=
'Input video'
)
parser
.
add_argument
(
'--input'
,
type
=
str
,
default
=
None
,
help
=
'Input video'
)
parser
.
add_argument
(
'--output'
,
type
=
str
,
default
=
'output'
,
help
=
'output dir'
)
parser
.
add_argument
(
'--output'
,
type
=
str
,
default
=
'output'
,
help
=
'output dir'
)
parser
.
add_argument
(
'--reference_dir'
,
type
=
str
,
default
=
None
,
help
=
'Path to the reference image directory'
)
parser
.
add_argument
(
'--reference_dir'
,
parser
.
add_argument
(
'--colorization'
,
action
=
'store_true'
,
default
=
False
,
help
=
'Remaster without colorization'
)
type
=
str
,
parser
.
add_argument
(
'--mindim'
,
type
=
int
,
default
=
'360'
,
help
=
'Length of minimum image edges'
)
default
=
None
,
help
=
'Path to the reference image directory'
)
parser
.
add_argument
(
'--colorization'
,
action
=
'store_true'
,
default
=
False
,
help
=
'Remaster without colorization'
)
parser
.
add_argument
(
'--mindim'
,
type
=
int
,
default
=
'360'
,
help
=
'Length of minimum image edges'
)
class
DeepReasterPredictor
:
class
DeepReasterPredictor
:
def
__init__
(
self
,
input
,
output
,
weight_path
=
None
,
colorization
=
False
,
reference_dir
=
None
,
mindim
=
360
):
def
__init__
(
self
,
self
.
input
=
input
input
,
self
.
output
=
os
.
path
.
join
(
output
,
'DeepRemaster'
)
output
,
self
.
colorization
=
colorization
weight_path
=
None
,
self
.
reference_dir
=
reference_dir
colorization
=
False
,
self
.
mindim
=
mindim
reference_dir
=
None
,
mindim
=
360
):
if
weight_path
is
None
:
self
.
input
=
input
weight_path
=
get_path_from_url
(
DeepRemaster_weight_url
,
cur_path
)
self
.
output
=
os
.
path
.
join
(
output
,
'DeepRemaster'
)
self
.
colorization
=
colorization
state_dict
,
_
=
paddle
.
load
(
weight_path
)
self
.
reference_dir
=
reference_dir
self
.
mindim
=
mindim
self
.
modelR
=
NetworkR
()
self
.
modelR
.
load_dict
(
state_dict
[
'modelR'
])
if
weight_path
is
None
:
self
.
modelR
.
eval
()
weight_path
=
get_path_from_url
(
DeepRemaster_weight_url
,
cur_path
)
if
colorization
:
self
.
modelC
=
NetworkC
()
state_dict
,
_
=
paddle
.
load
(
weight_path
)
self
.
modelC
.
load_dict
(
state_dict
[
'modelC'
])
self
.
modelC
.
eval
()
self
.
modelR
=
NetworkR
()
self
.
modelR
.
load_dict
(
state_dict
[
'modelR'
])
self
.
modelR
.
eval
()
def
run
(
self
):
if
colorization
:
outputdir
=
self
.
output
self
.
modelC
=
NetworkC
()
outputdir_in
=
os
.
path
.
join
(
outputdir
,
'input/'
)
self
.
modelC
.
load_dict
(
state_dict
[
'modelC'
])
os
.
makedirs
(
outputdir_in
,
exist_ok
=
True
)
self
.
modelC
.
eval
()
outputdir_out
=
os
.
path
.
join
(
outputdir
,
'output/'
)
os
.
makedirs
(
outputdir_out
,
exist_ok
=
True
)
def
run
(
self
):
outputdir
=
self
.
output
# Prepare reference images
outputdir_in
=
os
.
path
.
join
(
outputdir
,
'input/'
)
if
self
.
colorization
:
os
.
makedirs
(
outputdir_in
,
exist_ok
=
True
)
if
self
.
reference_dir
is
not
None
:
outputdir_out
=
os
.
path
.
join
(
outputdir
,
'output/'
)
import
glob
os
.
makedirs
(
outputdir_out
,
exist_ok
=
True
)
ext_list
=
[
'png'
,
'jpg'
,
'bmp'
]
reference_files
=
[]
# Prepare reference images
for
ext
in
ext_list
:
if
self
.
colorization
:
reference_files
+=
glob
.
glob
(
self
.
reference_dir
+
'/*.'
+
ext
,
recursive
=
True
)
if
self
.
reference_dir
is
not
None
:
aspect_mean
=
0
import
glob
minedge_dim
=
256
ext_list
=
[
'png'
,
'jpg'
,
'bmp'
]
refs
=
[]
reference_files
=
[]
for
v
in
reference_files
:
for
ext
in
ext_list
:
refimg
=
Image
.
open
(
v
).
convert
(
'RGB'
)
reference_files
+=
glob
.
glob
(
self
.
reference_dir
+
'/*.'
+
w
,
h
=
refimg
.
size
ext
,
aspect_mean
+=
w
/
h
recursive
=
True
)
refs
.
append
(
refimg
)
aspect_mean
=
0
aspect_mean
/=
len
(
reference_files
)
minedge_dim
=
256
target_w
=
int
(
256
*
aspect_mean
)
if
aspect_mean
>
1
else
256
refs
=
[]
target_h
=
256
if
aspect_mean
>=
1
else
int
(
256
/
aspect_mean
)
for
v
in
reference_files
:
refimg
=
Image
.
open
(
v
).
convert
(
'RGB'
)
refimgs
=
[]
w
,
h
=
refimg
.
size
for
i
,
v
in
enumerate
(
refs
):
aspect_mean
+=
w
/
h
refimg
=
utils
.
addMergin
(
v
,
target_w
=
target_w
,
target_h
=
target_h
)
refs
.
append
(
refimg
)
refimg
=
np
.
array
(
refimg
).
astype
(
'float32'
).
transpose
(
2
,
0
,
1
)
/
255.0
aspect_mean
/=
len
(
reference_files
)
refimgs
.
append
(
refimg
)
target_w
=
int
(
256
*
aspect_mean
)
if
aspect_mean
>
1
else
256
refimgs
=
paddle
.
to_tensor
(
np
.
array
(
refimgs
).
astype
(
'float32'
))
target_h
=
256
if
aspect_mean
>=
1
else
int
(
256
/
aspect_mean
)
refimgs
=
paddle
.
unsqueeze
(
refimgs
,
0
)
refimgs
=
[]
for
i
,
v
in
enumerate
(
refs
):
# Load video
refimg
=
utils
.
addMergin
(
v
,
cap
=
cv2
.
VideoCapture
(
self
.
input
)
target_w
=
target_w
,
nframes
=
int
(
cap
.
get
(
cv2
.
CAP_PROP_FRAME_COUNT
))
target_h
=
target_h
)
v_w
=
cap
.
get
(
cv2
.
CAP_PROP_FRAME_WIDTH
)
refimg
=
np
.
array
(
refimg
).
astype
(
'float32'
).
transpose
(
v_h
=
cap
.
get
(
cv2
.
CAP_PROP_FRAME_HEIGHT
)
2
,
0
,
1
)
/
255.0
minwh
=
min
(
v_w
,
v_h
)
refimgs
.
append
(
refimg
)
scale
=
1
refimgs
=
paddle
.
to_tensor
(
np
.
array
(
refimgs
).
astype
(
'float32'
))
if
minwh
!=
self
.
mindim
:
scale
=
self
.
mindim
/
minwh
refimgs
=
paddle
.
unsqueeze
(
refimgs
,
0
)
t_w
=
round
(
v_w
*
scale
/
16.
)
*
16
# Load video
t_h
=
round
(
v_h
*
scale
/
16.
)
*
16
cap
=
cv2
.
VideoCapture
(
self
.
input
)
fps
=
cap
.
get
(
cv2
.
CAP_PROP_FPS
)
nframes
=
int
(
cap
.
get
(
cv2
.
CAP_PROP_FRAME_COUNT
))
pbar
=
tqdm
(
total
=
nframes
)
v_w
=
cap
.
get
(
cv2
.
CAP_PROP_FRAME_WIDTH
)
block
=
5
v_h
=
cap
.
get
(
cv2
.
CAP_PROP_FRAME_HEIGHT
)
minwh
=
min
(
v_w
,
v_h
)
# Process
scale
=
1
with
paddle
.
no_grad
():
if
minwh
!=
self
.
mindim
:
it
=
0
scale
=
self
.
mindim
/
minwh
while
True
:
frame_pos
=
it
*
block
t_w
=
round
(
v_w
*
scale
/
16.
)
*
16
if
frame_pos
>=
nframes
:
t_h
=
round
(
v_h
*
scale
/
16.
)
*
16
break
fps
=
cap
.
get
(
cv2
.
CAP_PROP_FPS
)
cap
.
set
(
cv2
.
CAP_PROP_POS_FRAMES
,
frame_pos
)
pbar
=
tqdm
(
total
=
nframes
)
if
block
>=
nframes
-
frame_pos
:
block
=
5
proc_g
=
nframes
-
frame_pos
else
:
# Process
proc_g
=
block
with
paddle
.
no_grad
():
it
=
0
input
=
None
while
True
:
gtC
=
None
frame_pos
=
it
*
block
for
i
in
range
(
proc_g
):
if
frame_pos
>=
nframes
:
index
=
frame_pos
+
i
break
_
,
frame
=
cap
.
read
()
cap
.
set
(
cv2
.
CAP_PROP_POS_FRAMES
,
frame_pos
)
frame
=
cv2
.
resize
(
frame
,
(
t_w
,
t_h
))
if
block
>=
nframes
-
frame_pos
:
nchannels
=
frame
.
shape
[
2
]
proc_g
=
nframes
-
frame_pos
if
nchannels
==
1
or
self
.
colorization
:
else
:
frame_l
=
cv2
.
cvtColor
(
frame
,
cv2
.
COLOR_RGB2GRAY
)
proc_g
=
block
cv2
.
imwrite
(
outputdir_in
+
'%07d.png'
%
index
,
frame_l
)
frame_l
=
paddle
.
to_tensor
(
frame_l
.
astype
(
'float32'
))
input
=
None
frame_l
=
paddle
.
reshape
(
frame_l
,
[
frame_l
.
shape
[
0
],
frame_l
.
shape
[
1
],
1
])
gtC
=
None
frame_l
=
paddle
.
transpose
(
frame_l
,
[
2
,
0
,
1
])
for
i
in
range
(
proc_g
):
frame_l
/=
255.
index
=
frame_pos
+
i
_
,
frame
=
cap
.
read
()
frame_l
=
paddle
.
reshape
(
frame_l
,
[
1
,
frame_l
.
shape
[
0
],
1
,
frame_l
.
shape
[
1
],
frame_l
.
shape
[
2
]])
frame
=
cv2
.
resize
(
frame
,
(
t_w
,
t_h
))
elif
nchannels
==
3
:
nchannels
=
frame
.
shape
[
2
]
cv2
.
imwrite
(
outputdir_in
+
'%07d.png'
%
index
,
frame
)
if
nchannels
==
1
or
self
.
colorization
:
frame
=
frame
[:,:,::
-
1
]
## BGR -> RGB
frame_l
=
cv2
.
cvtColor
(
frame
,
cv2
.
COLOR_RGB2GRAY
)
frame_l
,
frame_ab
=
utils
.
convertRGB2LABTensor
(
frame
)
cv2
.
imwrite
(
outputdir_in
+
'%07d.png'
%
index
,
frame_l
)
frame_l
=
frame_l
.
transpose
([
2
,
0
,
1
])
frame_l
=
paddle
.
to_tensor
(
frame_l
.
astype
(
'float32'
))
frame_ab
=
frame_ab
.
transpose
([
2
,
0
,
1
])
frame_l
=
paddle
.
reshape
(
frame_l
=
frame_l
.
reshape
([
1
,
frame_l
.
shape
[
0
],
1
,
frame_l
.
shape
[
1
],
frame_l
.
shape
[
2
]])
frame_l
,
[
frame_l
.
shape
[
0
],
frame_l
.
shape
[
1
],
1
])
frame_ab
=
frame_ab
.
reshape
([
1
,
frame_ab
.
shape
[
0
],
1
,
frame_ab
.
shape
[
1
],
frame_ab
.
shape
[
2
]])
frame_l
=
paddle
.
transpose
(
frame_l
,
[
2
,
0
,
1
])
frame_l
/=
255.
if
input
is
not
None
:
paddle
.
concat
(
(
input
,
frame_l
),
2
)
frame_l
=
paddle
.
reshape
(
frame_l
,
[
1
,
frame_l
.
shape
[
0
],
1
,
frame_l
.
shape
[
1
],
input
=
frame_l
if
i
==
0
else
paddle
.
concat
(
(
input
,
frame_l
),
2
)
frame_l
.
shape
[
2
]
if
nchannels
==
3
and
not
self
.
colorization
:
])
gtC
=
frame_ab
if
i
==
0
else
paddle
.
concat
(
(
gtC
,
frame_ab
),
2
)
elif
nchannels
==
3
:
cv2
.
imwrite
(
outputdir_in
+
'%07d.png'
%
index
,
frame
)
input
=
paddle
.
to_tensor
(
input
)
frame
=
frame
[:,
:,
::
-
1
]
## BGR -> RGB
frame_l
,
frame_ab
=
utils
.
convertRGB2LABTensor
(
frame
)
frame_l
=
frame_l
.
transpose
([
2
,
0
,
1
])
output_l
=
self
.
modelR
(
input
)
# [B, C, T, H, W]
frame_ab
=
frame_ab
.
transpose
([
2
,
0
,
1
])
frame_l
=
frame_l
.
reshape
([
# Save restoration output without colorization when using the option [--disable_colorization]
1
,
frame_l
.
shape
[
0
],
1
,
frame_l
.
shape
[
1
],
if
not
self
.
colorization
:
frame_l
.
shape
[
2
]
for
i
in
range
(
proc_g
):
])
index
=
frame_pos
+
i
frame_ab
=
frame_ab
.
reshape
([
if
nchannels
==
3
:
1
,
frame_ab
.
shape
[
0
],
1
,
frame_ab
.
shape
[
1
],
out_l
=
output_l
.
detach
()[
0
,:,
i
]
frame_ab
.
shape
[
2
]
out_ab
=
gtC
[
0
,:,
i
]
])
out
=
paddle
.
concat
((
out_l
,
out_ab
),
axis
=
0
).
detach
().
numpy
().
transpose
((
1
,
2
,
0
))
if
input
is
not
None
:
out
=
Image
.
fromarray
(
np
.
uint8
(
utils
.
convertLAB2RGB
(
out
)
*
255
)
)
paddle
.
concat
((
input
,
frame_l
),
2
)
out
.
save
(
outputdir_out
+
'%07d.png'
%
(
index
)
)
else
:
input
=
frame_l
if
i
==
0
else
paddle
.
concat
(
raise
ValueError
(
'channels of imag3 must be 3!'
)
(
input
,
frame_l
),
2
)
if
nchannels
==
3
and
not
self
.
colorization
:
# Perform colorization
gtC
=
frame_ab
if
i
==
0
else
paddle
.
concat
(
else
:
(
gtC
,
frame_ab
),
2
)
if
self
.
reference_dir
is
None
:
output_ab
=
self
.
modelC
(
output_l
)
input
=
paddle
.
to_tensor
(
input
)
else
:
output_ab
=
self
.
modelC
(
output_l
,
refimgs
)
output_l
=
self
.
modelR
(
input
)
# [B, C, T, H, W]
output_l
=
output_l
.
detach
()
output_ab
=
output_ab
.
detach
()
# Save restoration output without colorization when using the option [--disable_colorization]
if
not
self
.
colorization
:
for
i
in
range
(
proc_g
):
for
i
in
range
(
proc_g
):
index
=
frame_pos
+
i
index
=
frame_pos
+
i
if
nchannels
==
3
:
out_l
=
output_l
[
0
,:,
i
,:,:]
out_l
=
output_l
.
detach
()[
0
,
:,
i
]
out_c
=
output_ab
[
0
,:,
i
,:,:]
out_ab
=
gtC
[
0
,
:,
i
]
output
=
paddle
.
concat
((
out_l
,
out_c
),
axis
=
0
).
numpy
().
transpose
((
1
,
2
,
0
))
output
=
Image
.
fromarray
(
np
.
uint8
(
utils
.
convertLAB2RGB
(
output
)
*
255
)
)
out
=
paddle
.
concat
(
output
.
save
(
outputdir_out
+
'%07d.png'
%
index
)
(
out_l
,
out_ab
),
axis
=
0
).
detach
().
numpy
().
transpose
((
1
,
2
,
0
))
it
=
it
+
1
out
=
Image
.
fromarray
(
pbar
.
update
(
proc_g
)
np
.
uint8
(
utils
.
convertLAB2RGB
(
out
)
*
255
))
out
.
save
(
outputdir_out
+
'%07d.png'
%
(
index
))
# Save result videos
else
:
outfile
=
os
.
path
.
join
(
outputdir
,
self
.
input
.
split
(
'/'
)[
-
1
].
split
(
'.'
)[
0
])
raise
ValueError
(
'channels of imag3 must be 3!'
)
cmd
=
'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_in.mp4'
%
(
fps
,
outputdir_in
,
fps
,
outfile
)
subprocess
.
call
(
cmd
,
shell
=
True
)
# Perform colorization
cmd
=
'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_out.mp4'
%
(
fps
,
outputdir_out
,
fps
,
outfile
)
else
:
subprocess
.
call
(
cmd
,
shell
=
True
)
if
self
.
reference_dir
is
None
:
cmd
=
'ffmpeg -y -i %s_in.mp4 -vf "[in] pad=2.01*iw:ih [left];movie=%s_out.mp4[right];[left][right] overlay=main_w/2:0,scale=2*iw/2:2*ih/2[out]" %s_comp.mp4'
%
(
outfile
,
outfile
,
outfile
)
output_ab
=
self
.
modelC
(
output_l
)
subprocess
.
call
(
cmd
,
shell
=
True
)
else
:
output_ab
=
self
.
modelC
(
output_l
,
refimgs
)
cap
.
release
()
output_l
=
output_l
.
detach
()
pbar
.
close
()
output_ab
=
output_ab
.
detach
()
return
outputdir_out
,
'%s_out.mp4'
%
outfile
for
i
in
range
(
proc_g
):
index
=
frame_pos
+
i
out_l
=
output_l
[
0
,
:,
i
,
:,
:]
out_c
=
output_ab
[
0
,
:,
i
,
:,
:]
output
=
paddle
.
concat
(
(
out_l
,
out_c
),
axis
=
0
).
numpy
().
transpose
((
1
,
2
,
0
))
output
=
Image
.
fromarray
(
np
.
uint8
(
utils
.
convertLAB2RGB
(
output
)
*
255
))
output
.
save
(
outputdir_out
+
'%07d.png'
%
index
)
it
=
it
+
1
pbar
.
update
(
proc_g
)
# Save result videos
outfile
=
os
.
path
.
join
(
outputdir
,
self
.
input
.
split
(
'/'
)[
-
1
].
split
(
'.'
)[
0
])
cmd
=
'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_in.mp4'
%
(
fps
,
outputdir_in
,
fps
,
outfile
)
subprocess
.
call
(
cmd
,
shell
=
True
)
cmd
=
'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_out.mp4'
%
(
fps
,
outputdir_out
,
fps
,
outfile
)
subprocess
.
call
(
cmd
,
shell
=
True
)
cmd
=
'ffmpeg -y -i %s_in.mp4 -vf "[in] pad=2.01*iw:ih [left];movie=%s_out.mp4[right];[left][right] overlay=main_w/2:0,scale=2*iw/2:2*ih/2[out]" %s_comp.mp4'
%
(
outfile
,
outfile
,
outfile
)
subprocess
.
call
(
cmd
,
shell
=
True
)
cap
.
release
()
pbar
.
close
()
return
outputdir_out
,
'%s_out.mp4'
%
outfile
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
paddle
.
disable_static
()
paddle
.
disable_static
()
predictor
=
DeepReasterPredictor
(
args
.
input
,
args
.
output
,
colorization
=
args
.
colorization
,
predictor
=
DeepReasterPredictor
(
args
.
input
,
reference_dir
=
args
.
reference_dir
,
mindim
=
args
.
mindim
)
args
.
output
,
predictor
.
run
()
colorization
=
args
.
colorization
,
reference_dir
=
args
.
reference_dir
,
\ No newline at end of file
mindim
=
args
.
mindim
)
predictor
.
run
()
applications/EDVR/predict.py
浏览文件 @
b306aa73
...
@@ -28,30 +28,29 @@ import paddle.fluid as fluid
...
@@ -28,30 +28,29 @@ import paddle.fluid as fluid
import
cv2
import
cv2
from
data
import
EDVRDataset
from
data
import
EDVRDataset
from
paddle.
incubate.hapi
.download
import
get_path_from_url
from
paddle.
utils
.download
import
get_path_from_url
EDVR_weight_url
=
'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar'
EDVR_weight_url
=
'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar'
def
parse_args
():
def
parse_args
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
parser
.
add_argument
(
'--input'
,
'--input'
,
type
=
str
,
type
=
str
,
default
=
None
,
default
=
None
,
help
=
'input video path'
)
help
=
'input video path'
)
parser
.
add_argument
(
'--output'
,
parser
.
add_argument
(
type
=
str
,
'--output'
,
default
=
'output'
,
type
=
str
,
help
=
'output path'
)
default
=
'output'
,
parser
.
add_argument
(
'--weight_path'
,
help
=
'output path'
)
type
=
str
,
parser
.
add_argument
(
default
=
None
,
'--weight_path'
,
help
=
'weight path'
)
type
=
str
,
default
=
None
,
help
=
'weight path'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
return
args
return
args
def
get_img
(
pred
):
def
get_img
(
pred
):
print
(
'pred shape'
,
pred
.
shape
)
print
(
'pred shape'
,
pred
.
shape
)
pred
=
pred
.
squeeze
()
pred
=
pred
.
squeeze
()
...
@@ -59,10 +58,11 @@ def get_img(pred):
...
@@ -59,10 +58,11 @@ def get_img(pred):
pred
=
pred
*
255
pred
=
pred
*
255
pred
=
pred
.
round
()
pred
=
pred
.
round
()
pred
=
pred
.
astype
(
'uint8'
)
pred
=
pred
.
astype
(
'uint8'
)
pred
=
np
.
transpose
(
pred
,
(
1
,
2
,
0
))
# chw -> hwc
pred
=
np
.
transpose
(
pred
,
(
1
,
2
,
0
))
# chw -> hwc
pred
=
pred
[:,
:,
::
-
1
]
# rgb -> bgr
pred
=
pred
[:,
:,
::
-
1
]
# rgb -> bgr
return
pred
return
pred
def
save_img
(
img
,
framename
):
def
save_img
(
img
,
framename
):
dirname
=
os
.
path
.
dirname
(
framename
)
dirname
=
os
.
path
.
dirname
(
framename
)
if
not
os
.
path
.
exists
(
dirname
):
if
not
os
.
path
.
exists
(
dirname
):
...
@@ -84,19 +84,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
...
@@ -84,19 +84,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
if
ss
is
not
None
and
t
is
not
None
and
r
is
not
None
:
if
ss
is
not
None
and
t
is
not
None
and
r
is
not
None
:
cmd
=
ffmpeg
+
[
cmd
=
ffmpeg
+
[
' -ss '
,
' -ss '
,
ss
,
' -t '
,
t
,
' -i '
,
vid_path
,
' -r '
,
r
,
' -qscale:v '
,
ss
,
' 0.1 '
,
' -start_number '
,
' 0 '
,
outformat
' -t '
,
t
,
' -i '
,
vid_path
,
' -r '
,
r
,
' -qscale:v '
,
' 0.1 '
,
' -start_number '
,
' 0 '
,
outformat
]
]
else
:
else
:
cmd
=
ffmpeg
+
[
' -i '
,
vid_path
,
' -start_number '
,
' 0 '
,
outformat
]
cmd
=
ffmpeg
+
[
' -i '
,
vid_path
,
' -start_number '
,
' 0 '
,
outformat
]
...
@@ -134,20 +123,21 @@ class EDVRPredictor:
...
@@ -134,20 +123,21 @@ class EDVRPredictor:
self
.
input
=
input
self
.
input
=
input
self
.
output
=
os
.
path
.
join
(
output
,
'EDVR'
)
self
.
output
=
os
.
path
.
join
(
output
,
'EDVR'
)
place
=
fluid
.
CUDAPlace
(
0
)
if
fluid
.
is_compiled_with_cuda
()
else
fluid
.
CPUPlace
()
place
=
fluid
.
CUDAPlace
(
0
)
if
fluid
.
is_compiled_with_cuda
()
else
fluid
.
CPUPlace
()
self
.
exe
=
fluid
.
Executor
(
place
)
self
.
exe
=
fluid
.
Executor
(
place
)
if
weight_path
is
None
:
if
weight_path
is
None
:
weight_path
=
get_path_from_url
(
EDVR_weight_url
,
cur_path
)
weight_path
=
get_path_from_url
(
EDVR_weight_url
,
cur_path
)
print
(
weight_path
)
print
(
weight_path
)
model_filename
=
'EDVR_model.pdmodel'
model_filename
=
'EDVR_model.pdmodel'
params_filename
=
'EDVR_params.pdparams'
params_filename
=
'EDVR_params.pdparams'
out
=
fluid
.
io
.
load_inference_model
(
dirname
=
weight_path
,
out
=
fluid
.
io
.
load_inference_model
(
dirname
=
weight_path
,
model_filename
=
model_filename
,
model_filename
=
model_filename
,
params_filename
=
params_filename
,
params_filename
=
params_filename
,
executor
=
self
.
exe
)
executor
=
self
.
exe
)
self
.
infer_prog
,
self
.
feed_list
,
self
.
fetch_list
=
out
self
.
infer_prog
,
self
.
feed_list
,
self
.
fetch_list
=
out
...
@@ -176,16 +166,19 @@ class EDVRPredictor:
...
@@ -176,16 +166,19 @@ class EDVRPredictor:
cur_time
=
time
.
time
()
cur_time
=
time
.
time
()
for
infer_iter
,
data
in
enumerate
(
dataset
):
for
infer_iter
,
data
in
enumerate
(
dataset
):
data_feed_in
=
[
data
[
0
]]
data_feed_in
=
[
data
[
0
]]
infer_outs
=
self
.
exe
.
run
(
self
.
infer_prog
,
infer_outs
=
self
.
exe
.
run
(
fetch_list
=
self
.
fetch_list
,
self
.
infer_prog
,
feed
=
{
self
.
feed_list
[
0
]:
np
.
array
(
data_feed_in
)})
fetch_list
=
self
.
fetch_list
,
feed
=
{
self
.
feed_list
[
0
]:
np
.
array
(
data_feed_in
)})
infer_result_list
=
[
item
for
item
in
infer_outs
]
infer_result_list
=
[
item
for
item
in
infer_outs
]
frame_path
=
data
[
1
]
frame_path
=
data
[
1
]
img_i
=
get_img
(
infer_result_list
[
0
])
img_i
=
get_img
(
infer_result_list
[
0
])
save_img
(
img_i
,
os
.
path
.
join
(
pred_frame_path
,
os
.
path
.
basename
(
frame_path
)))
save_img
(
img_i
,
os
.
path
.
join
(
pred_frame_path
,
os
.
path
.
basename
(
frame_path
)))
prev_time
=
cur_time
prev_time
=
cur_time
cur_time
=
time
.
time
()
cur_time
=
time
.
time
()
...
@@ -194,13 +187,15 @@ class EDVRPredictor:
...
@@ -194,13 +187,15 @@ class EDVRPredictor:
print
(
'Processed {} samples'
.
format
(
infer_iter
+
1
))
print
(
'Processed {} samples'
.
format
(
infer_iter
+
1
))
frame_pattern_combined
=
os
.
path
.
join
(
pred_frame_path
,
'%08d.png'
)
frame_pattern_combined
=
os
.
path
.
join
(
pred_frame_path
,
'%08d.png'
)
vid_out_path
=
os
.
path
.
join
(
self
.
output
,
'{}_edvr_out.mp4'
.
format
(
base_name
))
vid_out_path
=
os
.
path
.
join
(
self
.
output
,
frames_to_video_ffmpeg
(
frame_pattern_combined
,
vid_out_path
,
str
(
int
(
fps
)))
'{}_edvr_out.mp4'
.
format
(
base_name
))
frames_to_video_ffmpeg
(
frame_pattern_combined
,
vid_out_path
,
str
(
int
(
fps
)))
return
frame_pattern_combined
,
vid_out_path
return
frame_pattern_combined
,
vid_out_path
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
args
=
parse_args
()
predictor
=
EDVRPredictor
(
args
.
input
,
args
.
output
,
args
.
weight_path
)
predictor
=
EDVRPredictor
(
args
.
input
,
args
.
output
,
args
.
weight_path
)
predictor
.
run
()
predictor
.
run
()
configs/cyclegan_cityscapes.yaml
浏览文件 @
b306aa73
...
@@ -60,11 +60,12 @@ dataset:
...
@@ -60,11 +60,12 @@ dataset:
optimizer
:
optimizer
:
name
:
Adam
name
:
Adam
beta1
:
0.5
beta1
:
0.5
lr_scheduler
:
name
:
linear
lr_scheduler
:
learning_rate
:
0.0002
name
:
linear
start_epoch
:
100
learning_rate
:
0.0002
decay_epochs
:
100
start_epoch
:
100
decay_epochs
:
100
log_config
:
log_config
:
interval
:
100
interval
:
100
...
@@ -72,4 +73,3 @@ log_config:
...
@@ -72,4 +73,3 @@ log_config:
snapshot_config
:
snapshot_config
:
interval
:
5
interval
:
5
configs/cyclegan_horse2zebra.yaml
浏览文件 @
b306aa73
...
@@ -59,11 +59,12 @@ dataset:
...
@@ -59,11 +59,12 @@ dataset:
optimizer
:
optimizer
:
name
:
Adam
name
:
Adam
beta1
:
0.5
beta1
:
0.5
lr_scheduler
:
name
:
linear
lr_scheduler
:
learning_rate
:
0.0002
name
:
linear
start_epoch
:
100
learning_rate
:
0.0002
decay_epochs
:
100
start_epoch
:
100
decay_epochs
:
100
log_config
:
log_config
:
interval
:
100
interval
:
100
...
@@ -71,4 +72,3 @@ log_config:
...
@@ -71,4 +72,3 @@ log_config:
snapshot_config
:
snapshot_config
:
interval
:
5
interval
:
5
configs/pix2pix_cityscapes.yaml
浏览文件 @
b306aa73
...
@@ -25,7 +25,7 @@ dataset:
...
@@ -25,7 +25,7 @@ dataset:
train
:
train
:
name
:
PairedDataset
name
:
PairedDataset
dataroot
:
data/cityscapes
dataroot
:
data/cityscapes
num_workers
:
0
num_workers
:
4
phase
:
train
phase
:
train
max_dataset_size
:
inf
max_dataset_size
:
inf
direction
:
BtoA
direction
:
BtoA
...
@@ -57,11 +57,12 @@ dataset:
...
@@ -57,11 +57,12 @@ dataset:
optimizer
:
optimizer
:
name
:
Adam
name
:
Adam
beta1
:
0.5
beta1
:
0.5
lr_scheduler
:
name
:
linear
lr_scheduler
:
learning_rate
:
0.0002
name
:
linear
start_epoch
:
100
learning_rate
:
0.0002
decay_epochs
:
100
start_epoch
:
100
decay_epochs
:
100
log_config
:
log_config
:
interval
:
100
interval
:
100
...
@@ -69,4 +70,3 @@ log_config:
...
@@ -69,4 +70,3 @@ log_config:
snapshot_config
:
snapshot_config
:
interval
:
5
interval
:
5
configs/pix2pix_cityscapes_2gpus.yaml
浏览文件 @
b306aa73
...
@@ -56,11 +56,12 @@ dataset:
...
@@ -56,11 +56,12 @@ dataset:
optimizer
:
optimizer
:
name
:
Adam
name
:
Adam
beta1
:
0.5
beta1
:
0.5
lr_scheduler
:
name
:
linear
lr_scheduler
:
learning_rate
:
0.0004
name
:
linear
start_epoch
:
100
learning_rate
:
0.0004
decay_epochs
:
100
start_epoch
:
100
decay_epochs
:
100
log_config
:
log_config
:
interval
:
100
interval
:
100
...
@@ -68,4 +69,3 @@ log_config:
...
@@ -68,4 +69,3 @@ log_config:
snapshot_config
:
snapshot_config
:
interval
:
5
interval
:
5
configs/pix2pix_facades.yaml
浏览文件 @
b306aa73
...
@@ -56,11 +56,12 @@ dataset:
...
@@ -56,11 +56,12 @@ dataset:
optimizer
:
optimizer
:
name
:
Adam
name
:
Adam
beta1
:
0.5
beta1
:
0.5
lr_scheduler
:
name
:
linear
lr_scheduler
:
learning_rate
:
0.0002
name
:
linear
start_epoch
:
100
learning_rate
:
0.0002
decay_epochs
:
100
start_epoch
:
100
decay_epochs
:
100
log_config
:
log_config
:
interval
:
100
interval
:
100
...
...
ppgan/datasets/base_dataset.py
浏览文件 @
b306aa73
...
@@ -6,7 +6,7 @@ from paddle.io import Dataset
...
@@ -6,7 +6,7 @@ from paddle.io import Dataset
from
PIL
import
Image
from
PIL
import
Image
import
cv2
import
cv2
import
paddle.
incubate.hapi.
vision.transforms
as
transforms
import
paddle.vision.transforms
as
transforms
from
.transforms
import
transforms
as
T
from
.transforms
import
transforms
as
T
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
...
@@ -14,7 +14,6 @@ from abc import ABC, abstractmethod
...
@@ -14,7 +14,6 @@ from abc import ABC, abstractmethod
class
BaseDataset
(
Dataset
,
ABC
):
class
BaseDataset
(
Dataset
,
ABC
):
"""This class is an abstract base class (ABC) for datasets.
"""This class is an abstract base class (ABC) for datasets.
"""
"""
def
__init__
(
self
,
cfg
):
def
__init__
(
self
,
cfg
):
"""Initialize the class; save the options in the class
"""Initialize the class; save the options in the class
...
@@ -60,8 +59,11 @@ def get_params(cfg, size):
...
@@ -60,8 +59,11 @@ def get_params(cfg, size):
return
{
'crop_pos'
:
(
x
,
y
),
'flip'
:
flip
}
return
{
'crop_pos'
:
(
x
,
y
),
'flip'
:
flip
}
def
get_transform
(
cfg
,
def
get_transform
(
cfg
,
params
=
None
,
grayscale
=
False
,
method
=
cv2
.
INTER_CUBIC
,
convert
=
True
):
params
=
None
,
grayscale
=
False
,
method
=
cv2
.
INTER_CUBIC
,
convert
=
True
):
transform_list
=
[]
transform_list
=
[]
if
grayscale
:
if
grayscale
:
print
(
'grayscale not support for now!!!'
)
print
(
'grayscale not support for now!!!'
)
...
@@ -89,8 +91,10 @@ def get_transform(cfg, params=None, grayscale=False, method=cv2.INTER_CUBIC, con
...
@@ -89,8 +91,10 @@ def get_transform(cfg, params=None, grayscale=False, method=cv2.INTER_CUBIC, con
transform_list
.
append
(
transforms
.
RandomHorizontalFlip
())
transform_list
.
append
(
transforms
.
RandomHorizontalFlip
())
elif
params
[
'flip'
]:
elif
params
[
'flip'
]:
transform_list
.
append
(
transforms
.
RandomHorizontalFlip
(
1.0
))
transform_list
.
append
(
transforms
.
RandomHorizontalFlip
(
1.0
))
if
convert
:
if
convert
:
transform_list
+=
[
transforms
.
Permute
(
to_rgb
=
True
)]
transform_list
+=
[
transforms
.
Permute
(
to_rgb
=
True
)]
transform_list
+=
[
transforms
.
Normalize
((
127.5
,
127.5
,
127.5
),
(
127.5
,
127.5
,
127.5
))]
transform_list
+=
[
transforms
.
Normalize
((
127.5
,
127.5
,
127.5
),
(
127.5
,
127.5
,
127.5
))
]
return
transforms
.
Compose
(
transform_list
)
return
transforms
.
Compose
(
transform_list
)
ppgan/datasets/builder.py
浏览文件 @
b306aa73
...
@@ -3,12 +3,11 @@ import paddle
...
@@ -3,12 +3,11 @@ import paddle
import
numbers
import
numbers
import
numpy
as
np
import
numpy
as
np
from
multiprocessing
import
Manager
from
multiprocessing
import
Manager
from
paddle
import
ParallelEnv
from
paddle
.distributed
import
ParallelEnv
from
paddle.i
ncubate.hapi.distributed
import
DistributedBatchSampler
from
paddle.i
o
import
DistributedBatchSampler
from
..utils.registry
import
Registry
from
..utils.registry
import
Registry
DATASETS
=
Registry
(
"DATASETS"
)
DATASETS
=
Registry
(
"DATASETS"
)
...
@@ -21,7 +20,7 @@ class DictDataset(paddle.io.Dataset):
...
@@ -21,7 +20,7 @@ class DictDataset(paddle.io.Dataset):
single_item
=
dataset
[
0
]
single_item
=
dataset
[
0
]
self
.
keys
=
single_item
.
keys
()
self
.
keys
=
single_item
.
keys
()
for
k
,
v
in
single_item
.
items
():
for
k
,
v
in
single_item
.
items
():
if
not
isinstance
(
v
,
(
numbers
.
Number
,
np
.
ndarray
)):
if
not
isinstance
(
v
,
(
numbers
.
Number
,
np
.
ndarray
)):
setattr
(
self
,
k
,
Manager
().
dict
())
setattr
(
self
,
k
,
Manager
().
dict
())
...
@@ -32,9 +31,9 @@ class DictDataset(paddle.io.Dataset):
...
@@ -32,9 +31,9 @@ class DictDataset(paddle.io.Dataset):
def
__getitem__
(
self
,
index
):
def
__getitem__
(
self
,
index
):
ori_map
=
self
.
dataset
[
index
]
ori_map
=
self
.
dataset
[
index
]
tmp_list
=
[]
tmp_list
=
[]
for
k
,
v
in
ori_map
.
items
():
for
k
,
v
in
ori_map
.
items
():
if
isinstance
(
v
,
(
numbers
.
Number
,
np
.
ndarray
)):
if
isinstance
(
v
,
(
numbers
.
Number
,
np
.
ndarray
)):
tmp_list
.
append
(
v
)
tmp_list
.
append
(
v
)
...
@@ -60,17 +59,15 @@ class DictDataLoader():
...
@@ -60,17 +59,15 @@ class DictDataLoader():
place
=
paddle
.
fluid
.
CUDAPlace
(
ParallelEnv
().
dev_id
)
\
place
=
paddle
.
fluid
.
CUDAPlace
(
ParallelEnv
().
dev_id
)
\
if
ParallelEnv
().
nranks
>
1
else
paddle
.
fluid
.
CUDAPlace
(
0
)
if
ParallelEnv
().
nranks
>
1
else
paddle
.
fluid
.
CUDAPlace
(
0
)
sampler
=
DistributedBatchSampler
(
sampler
=
DistributedBatchSampler
(
self
.
dataset
,
self
.
dataset
,
batch_size
=
batch_size
,
batch_size
=
batch_size
,
shuffle
=
True
if
is_train
else
False
,
shuffle
=
True
if
is_train
else
False
,
drop_last
=
True
if
is_train
else
False
)
drop_last
=
True
if
is_train
else
False
)
self
.
dataloader
=
paddle
.
io
.
DataLoader
(
self
.
dataloader
=
paddle
.
io
.
DataLoader
(
self
.
dataset
,
self
.
dataset
,
batch_sampler
=
sampler
,
batch_sampler
=
sampler
,
places
=
place
,
places
=
place
,
num_workers
=
num_workers
)
num_workers
=
num_workers
)
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
...
@@ -83,7 +80,9 @@ class DictDataLoader():
...
@@ -83,7 +80,9 @@ class DictDataLoader():
j
=
0
j
=
0
for
k
in
self
.
dataset
.
keys
:
for
k
in
self
.
dataset
.
keys
:
if
k
in
self
.
dataset
.
tensor_keys_set
:
if
k
in
self
.
dataset
.
tensor_keys_set
:
return_dict
[
k
]
=
data
[
j
]
if
isinstance
(
data
,
(
list
,
tuple
))
else
data
return_dict
[
k
]
=
data
[
j
]
if
isinstance
(
data
,
(
list
,
tuple
))
else
data
j
+=
1
j
+=
1
else
:
else
:
return_dict
[
k
]
=
self
.
get_items_by_indexs
(
k
,
data
[
-
1
])
return_dict
[
k
]
=
self
.
get_items_by_indexs
(
k
,
data
[
-
1
])
...
@@ -104,13 +103,12 @@ class DictDataLoader():
...
@@ -104,13 +103,12 @@ class DictDataLoader():
return
current_items
return
current_items
def
build_dataloader
(
cfg
,
is_train
=
True
):
def
build_dataloader
(
cfg
,
is_train
=
True
):
dataset
=
DATASETS
.
get
(
cfg
.
name
)(
cfg
)
dataset
=
DATASETS
.
get
(
cfg
.
name
)(
cfg
)
batch_size
=
cfg
.
get
(
'batch_size'
,
1
)
batch_size
=
cfg
.
get
(
'batch_size'
,
1
)
num_workers
=
cfg
.
get
(
'num_workers'
,
0
)
num_workers
=
cfg
.
get
(
'num_workers'
,
0
)
dataloader
=
DictDataLoader
(
dataset
,
batch_size
,
is_train
,
num_workers
)
dataloader
=
DictDataLoader
(
dataset
,
batch_size
,
is_train
,
num_workers
)
return
dataloader
return
dataloader
\ No newline at end of file
ppgan/engine/trainer.py
浏览文件 @
b306aa73
...
@@ -4,7 +4,7 @@ import time
...
@@ -4,7 +4,7 @@ import time
import
logging
import
logging
import
paddle
import
paddle
from
paddle
import
ParallelEnv
,
DataParallel
from
paddle
.distributed
import
ParallelEnv
from
..datasets.builder
import
build_dataloader
from
..datasets.builder
import
build_dataloader
from
..models.builder
import
build_model
from
..models.builder
import
build_model
...
@@ -17,10 +17,11 @@ class Trainer:
...
@@ -17,10 +17,11 @@ class Trainer:
# build train dataloader
# build train dataloader
self
.
train_dataloader
=
build_dataloader
(
cfg
.
dataset
.
train
)
self
.
train_dataloader
=
build_dataloader
(
cfg
.
dataset
.
train
)
if
'lr_scheduler'
in
cfg
.
optimizer
:
if
'lr_scheduler'
in
cfg
.
optimizer
:
cfg
.
optimizer
.
lr_scheduler
.
step_per_epoch
=
len
(
self
.
train_dataloader
)
cfg
.
optimizer
.
lr_scheduler
.
step_per_epoch
=
len
(
self
.
train_dataloader
)
# build model
# build model
self
.
model
=
build_model
(
cfg
)
self
.
model
=
build_model
(
cfg
)
# multiple gpus prepare
# multiple gpus prepare
...
@@ -44,16 +45,17 @@ class Trainer:
...
@@ -44,16 +45,17 @@ class Trainer:
# time count
# time count
self
.
time_count
=
{}
self
.
time_count
=
{}
def
distributed_data_parallel
(
self
):
def
distributed_data_parallel
(
self
):
strategy
=
paddle
.
prepare_context
()
strategy
=
paddle
.
prepare_context
()
for
name
in
self
.
model
.
model_names
:
for
name
in
self
.
model
.
model_names
:
if
isinstance
(
name
,
str
):
if
isinstance
(
name
,
str
):
net
=
getattr
(
self
.
model
,
'net'
+
name
)
net
=
getattr
(
self
.
model
,
'net'
+
name
)
setattr
(
self
.
model
,
'net'
+
name
,
DataParallel
(
net
,
strategy
))
setattr
(
self
.
model
,
'net'
+
name
,
paddle
.
DataParallel
(
net
,
strategy
))
def
train
(
self
):
def
train
(
self
):
for
epoch
in
range
(
self
.
start_epoch
,
self
.
epochs
):
for
epoch
in
range
(
self
.
start_epoch
,
self
.
epochs
):
self
.
current_epoch
=
epoch
self
.
current_epoch
=
epoch
start_time
=
step_start_time
=
time
.
time
()
start_time
=
step_start_time
=
time
.
time
()
...
@@ -64,24 +66,27 @@ class Trainer:
...
@@ -64,24 +66,27 @@ class Trainer:
# data input should be dict
# data input should be dict
self
.
model
.
set_input
(
data
)
self
.
model
.
set_input
(
data
)
self
.
model
.
optimize_parameters
()
self
.
model
.
optimize_parameters
()
self
.
data_time
=
data_time
-
step_start_time
self
.
data_time
=
data_time
-
step_start_time
self
.
step_time
=
time
.
time
()
-
step_start_time
self
.
step_time
=
time
.
time
()
-
step_start_time
if
i
%
self
.
log_interval
==
0
:
if
i
%
self
.
log_interval
==
0
:
self
.
print_log
()
self
.
print_log
()
if
i
%
self
.
visual_interval
==
0
:
if
i
%
self
.
visual_interval
==
0
:
self
.
visual
(
'visual_train'
)
self
.
visual
(
'visual_train'
)
step_start_time
=
time
.
time
()
step_start_time
=
time
.
time
()
self
.
logger
.
info
(
'train one epoch time: {}'
.
format
(
time
.
time
()
-
start_time
))
self
.
logger
.
info
(
'train one epoch time: {}'
.
format
(
time
.
time
()
-
start_time
))
self
.
model
.
lr_scheduler
.
step
()
if
epoch
%
self
.
weight_interval
==
0
:
if
epoch
%
self
.
weight_interval
==
0
:
self
.
save
(
epoch
,
'weight'
,
keep
=-
1
)
self
.
save
(
epoch
,
'weight'
,
keep
=-
1
)
self
.
save
(
epoch
)
self
.
save
(
epoch
)
def
test
(
self
):
def
test
(
self
):
if
not
hasattr
(
self
,
'test_dataloader'
):
if
not
hasattr
(
self
,
'test_dataloader'
):
self
.
test_dataloader
=
build_dataloader
(
self
.
cfg
.
dataset
.
test
,
is_train
=
False
)
self
.
test_dataloader
=
build_dataloader
(
self
.
cfg
.
dataset
.
test
,
is_train
=
False
)
# data[0]: img, data[1]: img path index
# data[0]: img, data[1]: img path index
# test batch size must be 1
# test batch size must be 1
...
@@ -103,14 +108,15 @@ class Trainer:
...
@@ -103,14 +108,15 @@ class Trainer:
visual_results
.
update
({
name
:
img_tensor
[
j
]})
visual_results
.
update
({
name
:
img_tensor
[
j
]})
self
.
visual
(
'visual_test'
,
visual_results
=
visual_results
)
self
.
visual
(
'visual_test'
,
visual_results
=
visual_results
)
if
i
%
self
.
log_interval
==
0
:
if
i
%
self
.
log_interval
==
0
:
self
.
logger
.
info
(
'Test iter: [%d/%d]'
%
(
i
,
len
(
self
.
test_dataloader
)))
self
.
logger
.
info
(
'Test iter: [%d/%d]'
%
(
i
,
len
(
self
.
test_dataloader
)))
def
print_log
(
self
):
def
print_log
(
self
):
losses
=
self
.
model
.
get_current_losses
()
losses
=
self
.
model
.
get_current_losses
()
message
=
'Epoch: %d, iters: %d '
%
(
self
.
current_epoch
,
self
.
batch_id
)
message
=
'Epoch: %d, iters: %d '
%
(
self
.
current_epoch
,
self
.
batch_id
)
message
+=
'%s: %.6f '
%
(
'lr'
,
self
.
current_learning_rate
)
message
+=
'%s: %.6f '
%
(
'lr'
,
self
.
current_learning_rate
)
for
k
,
v
in
losses
.
items
():
for
k
,
v
in
losses
.
items
():
...
@@ -143,13 +149,14 @@ class Trainer:
...
@@ -143,13 +149,14 @@ class Trainer:
makedirs
(
os
.
path
.
join
(
self
.
output_dir
,
results_dir
))
makedirs
(
os
.
path
.
join
(
self
.
output_dir
,
results_dir
))
for
label
,
image
in
visual_results
.
items
():
for
label
,
image
in
visual_results
.
items
():
image_numpy
=
tensor2img
(
image
)
image_numpy
=
tensor2img
(
image
)
img_path
=
os
.
path
.
join
(
self
.
output_dir
,
results_dir
,
msg
+
'%s.png'
%
(
label
))
img_path
=
os
.
path
.
join
(
self
.
output_dir
,
results_dir
,
msg
+
'%s.png'
%
(
label
))
save_image
(
image_numpy
,
img_path
)
save_image
(
image_numpy
,
img_path
)
def
save
(
self
,
epoch
,
name
=
'checkpoint'
,
keep
=
1
):
def
save
(
self
,
epoch
,
name
=
'checkpoint'
,
keep
=
1
):
if
self
.
local_rank
!=
0
:
if
self
.
local_rank
!=
0
:
return
return
assert
name
in
[
'checkpoint'
,
'weight'
]
assert
name
in
[
'checkpoint'
,
'weight'
]
state_dicts
=
{}
state_dicts
=
{}
...
@@ -175,8 +182,8 @@ class Trainer:
...
@@ -175,8 +182,8 @@ class Trainer:
if
keep
>
0
:
if
keep
>
0
:
try
:
try
:
checkpoint_name_to_be_removed
=
os
.
path
.
join
(
self
.
output_dir
,
checkpoint_name_to_be_removed
=
os
.
path
.
join
(
'epoch_%s_%s.pkl'
%
(
epoch
-
keep
,
name
))
self
.
output_dir
,
'epoch_%s_%s.pkl'
%
(
epoch
-
keep
,
name
))
if
os
.
path
.
exists
(
checkpoint_name_to_be_removed
):
if
os
.
path
.
exists
(
checkpoint_name_to_be_removed
):
os
.
remove
(
checkpoint_name_to_be_removed
)
os
.
remove
(
checkpoint_name_to_be_removed
)
...
@@ -187,7 +194,7 @@ class Trainer:
...
@@ -187,7 +194,7 @@ class Trainer:
state_dicts
=
load
(
checkpoint_path
)
state_dicts
=
load
(
checkpoint_path
)
if
state_dicts
.
get
(
'epoch'
,
None
)
is
not
None
:
if
state_dicts
.
get
(
'epoch'
,
None
)
is
not
None
:
self
.
start_epoch
=
state_dicts
[
'epoch'
]
+
1
self
.
start_epoch
=
state_dicts
[
'epoch'
]
+
1
for
name
in
self
.
model
.
model_names
:
for
name
in
self
.
model
.
model_names
:
if
isinstance
(
name
,
str
):
if
isinstance
(
name
,
str
):
net
=
getattr
(
self
.
model
,
'net'
+
name
)
net
=
getattr
(
self
.
model
,
'net'
+
name
)
...
@@ -200,9 +207,8 @@ class Trainer:
...
@@ -200,9 +207,8 @@ class Trainer:
def
load
(
self
,
weight_path
):
def
load
(
self
,
weight_path
):
state_dicts
=
load
(
weight_path
)
state_dicts
=
load
(
weight_path
)
for
name
in
self
.
model
.
model_names
:
for
name
in
self
.
model
.
model_names
:
if
isinstance
(
name
,
str
):
if
isinstance
(
name
,
str
):
net
=
getattr
(
self
.
model
,
'net'
+
name
)
net
=
getattr
(
self
.
model
,
'net'
+
name
)
net
.
set_dict
(
state_dicts
[
'net'
+
name
])
net
.
set_dict
(
state_dicts
[
'net'
+
name
])
\ No newline at end of file
ppgan/models/base_model.py
浏览文件 @
b306aa73
...
@@ -5,6 +5,7 @@ import numpy as np
...
@@ -5,6 +5,7 @@ import numpy as np
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
..solver.lr_scheduler
import
build_lr_scheduler
class
BaseModel
(
ABC
):
class
BaseModel
(
ABC
):
...
@@ -16,7 +17,6 @@ class BaseModel(ABC):
...
@@ -16,7 +17,6 @@ class BaseModel(ABC):
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
"""
"""
def
__init__
(
self
,
opt
):
def
__init__
(
self
,
opt
):
"""Initialize the BaseModel class.
"""Initialize the BaseModel class.
...
@@ -33,8 +33,10 @@ class BaseModel(ABC):
...
@@ -33,8 +33,10 @@ class BaseModel(ABC):
"""
"""
self
.
opt
=
opt
self
.
opt
=
opt
self
.
isTrain
=
opt
.
isTrain
self
.
isTrain
=
opt
.
isTrain
self
.
save_dir
=
os
.
path
.
join
(
opt
.
output_dir
,
opt
.
model
.
name
)
# save all the checkpoints to save_dir
self
.
save_dir
=
os
.
path
.
join
(
opt
.
output_dir
,
opt
.
model
.
name
)
# save all the checkpoints to save_dir
self
.
loss_names
=
[]
self
.
loss_names
=
[]
self
.
model_names
=
[]
self
.
model_names
=
[]
self
.
visual_names
=
[]
self
.
visual_names
=
[]
...
@@ -75,6 +77,8 @@ class BaseModel(ABC):
...
@@ -75,6 +77,8 @@ class BaseModel(ABC):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
pass
pass
def
build_lr_scheduler
(
self
):
self
.
lr_scheduler
=
build_lr_scheduler
(
self
.
opt
.
lr_scheduler
)
def
eval
(
self
):
def
eval
(
self
):
"""Make models eval mode during test time"""
"""Make models eval mode during test time"""
...
@@ -114,10 +118,11 @@ class BaseModel(ABC):
...
@@ -114,10 +118,11 @@ class BaseModel(ABC):
errors_ret
=
OrderedDict
()
errors_ret
=
OrderedDict
()
for
name
in
self
.
loss_names
:
for
name
in
self
.
loss_names
:
if
isinstance
(
name
,
str
):
if
isinstance
(
name
,
str
):
errors_ret
[
name
]
=
float
(
getattr
(
self
,
'loss_'
+
name
))
# float(...) works for both scalar tensor and float number
errors_ret
[
name
]
=
float
(
getattr
(
self
,
'loss_'
+
name
)
)
# float(...) works for both scalar tensor and float number
return
errors_ret
return
errors_ret
def
set_requires_grad
(
self
,
nets
,
requires_grad
=
False
):
def
set_requires_grad
(
self
,
nets
,
requires_grad
=
False
):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters:
Parameters:
...
...
ppgan/models/cycle_gan_model.py
浏览文件 @
b306aa73
import
paddle
import
paddle
from
paddle
import
ParallelEnv
from
paddle
.distributed
import
ParallelEnv
from
.base_model
import
BaseModel
from
.base_model
import
BaseModel
from
.builder
import
MODELS
from
.builder
import
MODELS
...
@@ -23,7 +23,6 @@ class CycleGANModel(BaseModel):
...
@@ -23,7 +23,6 @@ class CycleGANModel(BaseModel):
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
"""
"""
def
__init__
(
self
,
opt
):
def
__init__
(
self
,
opt
):
"""Initialize the CycleGAN class.
"""Initialize the CycleGAN class.
...
@@ -32,12 +31,14 @@ class CycleGANModel(BaseModel):
...
@@ -32,12 +31,14 @@ class CycleGANModel(BaseModel):
"""
"""
BaseModel
.
__init__
(
self
,
opt
)
BaseModel
.
__init__
(
self
,
opt
)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self
.
loss_names
=
[
'D_A'
,
'G_A'
,
'cycle_A'
,
'idt_A'
,
'D_B'
,
'G_B'
,
'cycle_B'
,
'idt_B'
]
self
.
loss_names
=
[
'D_A'
,
'G_A'
,
'cycle_A'
,
'idt_A'
,
'D_B'
,
'G_B'
,
'cycle_B'
,
'idt_B'
]
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A
=
[
'real_A'
,
'fake_B'
,
'rec_A'
]
visual_names_A
=
[
'real_A'
,
'fake_B'
,
'rec_A'
]
visual_names_B
=
[
'real_B'
,
'fake_A'
,
'rec_B'
]
visual_names_B
=
[
'real_B'
,
'fake_A'
,
'rec_B'
]
# if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
# if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
if
self
.
isTrain
and
self
.
opt
.
lambda_identity
>
0.0
:
if
self
.
isTrain
and
self
.
opt
.
lambda_identity
>
0.0
:
visual_names_A
.
append
(
'idt_B'
)
visual_names_A
.
append
(
'idt_B'
)
visual_names_B
.
append
(
'idt_A'
)
visual_names_B
.
append
(
'idt_A'
)
...
@@ -62,18 +63,28 @@ class CycleGANModel(BaseModel):
...
@@ -62,18 +63,28 @@ class CycleGANModel(BaseModel):
if
self
.
isTrain
:
if
self
.
isTrain
:
if
opt
.
lambda_identity
>
0.0
:
# only works when input and output images have the same number of channels
if
opt
.
lambda_identity
>
0.0
:
# only works when input and output images have the same number of channels
assert
(
opt
.
dataset
.
train
.
input_nc
==
opt
.
dataset
.
train
.
output_nc
)
assert
(
opt
.
dataset
.
train
.
input_nc
==
opt
.
dataset
.
train
.
output_nc
)
# create image buffer to store previously generated images
# create image buffer to store previously generated images
self
.
fake_A_pool
=
ImagePool
(
opt
.
dataset
.
train
.
pool_size
)
self
.
fake_A_pool
=
ImagePool
(
opt
.
dataset
.
train
.
pool_size
)
# create image buffer to store previously generated images
# create image buffer to store previously generated images
self
.
fake_B_pool
=
ImagePool
(
opt
.
dataset
.
train
.
pool_size
)
self
.
fake_B_pool
=
ImagePool
(
opt
.
dataset
.
train
.
pool_size
)
# define loss functions
# define loss functions
self
.
criterionGAN
=
GANLoss
(
opt
.
model
.
gan_mode
)
self
.
criterionGAN
=
GANLoss
(
opt
.
model
.
gan_mode
)
self
.
criterionCycle
=
paddle
.
nn
.
L1Loss
()
self
.
criterionCycle
=
paddle
.
nn
.
L1Loss
()
self
.
criterionIdt
=
paddle
.
nn
.
L1Loss
()
self
.
criterionIdt
=
paddle
.
nn
.
L1Loss
()
self
.
optimizer_G
=
build_optimizer
(
opt
.
optimizer
,
parameter_list
=
self
.
netG_A
.
parameters
()
+
self
.
netG_B
.
parameters
())
self
.
build_lr_scheduler
()
self
.
optimizer_D
=
build_optimizer
(
opt
.
optimizer
,
parameter_list
=
self
.
netD_A
.
parameters
()
+
self
.
netD_B
.
parameters
())
self
.
optimizer_G
=
build_optimizer
(
opt
.
optimizer
,
self
.
lr_scheduler
,
parameter_list
=
self
.
netG_A
.
parameters
()
+
self
.
netG_B
.
parameters
())
self
.
optimizer_D
=
build_optimizer
(
opt
.
optimizer
,
self
.
lr_scheduler
,
parameter_list
=
self
.
netD_A
.
parameters
()
+
self
.
netD_B
.
parameters
())
self
.
optimizers
.
append
(
self
.
optimizer_G
)
self
.
optimizers
.
append
(
self
.
optimizer_G
)
self
.
optimizers
.
append
(
self
.
optimizer_D
)
self
.
optimizers
.
append
(
self
.
optimizer_D
)
...
@@ -90,7 +101,7 @@ class CycleGANModel(BaseModel):
...
@@ -90,7 +101,7 @@ class CycleGANModel(BaseModel):
"""
"""
mode
=
'train'
if
self
.
isTrain
else
'test'
mode
=
'train'
if
self
.
isTrain
else
'test'
AtoB
=
self
.
opt
.
dataset
[
mode
].
direction
==
'AtoB'
AtoB
=
self
.
opt
.
dataset
[
mode
].
direction
==
'AtoB'
if
AtoB
:
if
AtoB
:
if
'A'
in
input
:
if
'A'
in
input
:
self
.
real_A
=
paddle
.
to_tensor
(
input
[
'A'
])
self
.
real_A
=
paddle
.
to_tensor
(
input
[
'A'
])
...
@@ -107,17 +118,15 @@ class CycleGANModel(BaseModel):
...
@@ -107,17 +118,15 @@ class CycleGANModel(BaseModel):
elif
'B_paths'
in
input
:
elif
'B_paths'
in
input
:
self
.
image_paths
=
input
[
'B_paths'
]
self
.
image_paths
=
input
[
'B_paths'
]
def
forward
(
self
):
def
forward
(
self
):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
if
hasattr
(
self
,
'real_A'
):
if
hasattr
(
self
,
'real_A'
):
self
.
fake_B
=
self
.
netG_A
(
self
.
real_A
)
# G_A(A)
self
.
fake_B
=
self
.
netG_A
(
self
.
real_A
)
# G_A(A)
self
.
rec_A
=
self
.
netG_B
(
self
.
fake_B
)
# G_B(G_A(A))
self
.
rec_A
=
self
.
netG_B
(
self
.
fake_B
)
# G_B(G_A(A))
if
hasattr
(
self
,
'real_B'
):
if
hasattr
(
self
,
'real_B'
):
self
.
fake_A
=
self
.
netG_B
(
self
.
real_B
)
# G_B(B)
self
.
fake_A
=
self
.
netG_B
(
self
.
real_B
)
# G_B(B)
self
.
rec_B
=
self
.
netG_A
(
self
.
fake_A
)
# G_A(G_B(B))
self
.
rec_B
=
self
.
netG_A
(
self
.
fake_A
)
# G_A(G_B(B))
def
backward_D_basic
(
self
,
netD
,
real
,
fake
):
def
backward_D_basic
(
self
,
netD
,
real
,
fake
):
"""Calculate GAN loss for the discriminator
"""Calculate GAN loss for the discriminator
...
@@ -166,10 +175,12 @@ class CycleGANModel(BaseModel):
...
@@ -166,10 +175,12 @@ class CycleGANModel(BaseModel):
if
lambda_idt
>
0
:
if
lambda_idt
>
0
:
# G_A should be identity if real_B is fed: ||G_A(B) - B||
# G_A should be identity if real_B is fed: ||G_A(B) - B||
self
.
idt_A
=
self
.
netG_A
(
self
.
real_B
)
self
.
idt_A
=
self
.
netG_A
(
self
.
real_B
)
self
.
loss_idt_A
=
self
.
criterionIdt
(
self
.
idt_A
,
self
.
real_B
)
*
lambda_B
*
lambda_idt
self
.
loss_idt_A
=
self
.
criterionIdt
(
self
.
idt_A
,
self
.
real_B
)
*
lambda_B
*
lambda_idt
# G_B should be identity if real_A is fed: ||G_B(A) - A||
# G_B should be identity if real_A is fed: ||G_B(A) - A||
self
.
idt_B
=
self
.
netG_B
(
self
.
real_A
)
self
.
idt_B
=
self
.
netG_B
(
self
.
real_A
)
self
.
loss_idt_B
=
self
.
criterionIdt
(
self
.
idt_B
,
self
.
real_A
)
*
lambda_A
*
lambda_idt
self
.
loss_idt_B
=
self
.
criterionIdt
(
self
.
idt_B
,
self
.
real_A
)
*
lambda_A
*
lambda_idt
else
:
else
:
self
.
loss_idt_A
=
0
self
.
loss_idt_A
=
0
self
.
loss_idt_B
=
0
self
.
loss_idt_B
=
0
...
@@ -179,12 +190,14 @@ class CycleGANModel(BaseModel):
...
@@ -179,12 +190,14 @@ class CycleGANModel(BaseModel):
# GAN loss D_B(G_B(B))
# GAN loss D_B(G_B(B))
self
.
loss_G_B
=
self
.
criterionGAN
(
self
.
netD_B
(
self
.
fake_A
),
True
)
self
.
loss_G_B
=
self
.
criterionGAN
(
self
.
netD_B
(
self
.
fake_A
),
True
)
# Forward cycle loss || G_B(G_A(A)) - A||
# Forward cycle loss || G_B(G_A(A)) - A||
self
.
loss_cycle_A
=
self
.
criterionCycle
(
self
.
rec_A
,
self
.
real_A
)
*
lambda_A
self
.
loss_cycle_A
=
self
.
criterionCycle
(
self
.
rec_A
,
self
.
real_A
)
*
lambda_A
# Backward cycle loss || G_A(G_B(B)) - B||
# Backward cycle loss || G_A(G_B(B)) - B||
self
.
loss_cycle_B
=
self
.
criterionCycle
(
self
.
rec_B
,
self
.
real_B
)
*
lambda_B
self
.
loss_cycle_B
=
self
.
criterionCycle
(
self
.
rec_B
,
self
.
real_B
)
*
lambda_B
# combined loss and calculate gradients
# combined loss and calculate gradients
self
.
loss_G
=
self
.
loss_G_A
+
self
.
loss_G_B
+
self
.
loss_cycle_A
+
self
.
loss_cycle_B
+
self
.
loss_idt_A
+
self
.
loss_idt_B
self
.
loss_G
=
self
.
loss_G_A
+
self
.
loss_G_B
+
self
.
loss_cycle_A
+
self
.
loss_cycle_B
+
self
.
loss_idt_A
+
self
.
loss_idt_B
if
ParallelEnv
().
nranks
>
1
:
if
ParallelEnv
().
nranks
>
1
:
self
.
loss_G
=
self
.
netG_A
.
scale_loss
(
self
.
loss_G
)
self
.
loss_G
=
self
.
netG_A
.
scale_loss
(
self
.
loss_G
)
self
.
loss_G
.
backward
()
self
.
loss_G
.
backward
()
...
@@ -216,6 +229,5 @@ class CycleGANModel(BaseModel):
...
@@ -216,6 +229,5 @@ class CycleGANModel(BaseModel):
self
.
backward_D_A
()
self
.
backward_D_A
()
# calculate graidents for D_B
# calculate graidents for D_B
self
.
backward_D_B
()
self
.
backward_D_B
()
# update D_A and D_B's weights
# update D_A and D_B's weights
self
.
optimizer_D
.
minimize
(
self
.
loss_D_A
+
self
.
loss_D_B
)
self
.
optimizer_D
.
minimize
(
self
.
loss_D_A
+
self
.
loss_D_B
)
ppgan/models/pix2pix_model.py
浏览文件 @
b306aa73
import
paddle
import
paddle
from
paddle
import
ParallelEnv
from
paddle
.distributed
import
ParallelEnv
from
.base_model
import
BaseModel
from
.base_model
import
BaseModel
from
.builder
import
MODELS
from
.builder
import
MODELS
...
@@ -22,7 +22,6 @@ class Pix2PixModel(BaseModel):
...
@@ -22,7 +22,6 @@ class Pix2PixModel(BaseModel):
pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
"""
"""
def
__init__
(
self
,
opt
):
def
__init__
(
self
,
opt
):
"""Initialize the pix2pix class.
"""Initialize the pix2pix class.
...
@@ -48,15 +47,21 @@ class Pix2PixModel(BaseModel):
...
@@ -48,15 +47,21 @@ class Pix2PixModel(BaseModel):
if
self
.
isTrain
:
if
self
.
isTrain
:
self
.
netD
=
build_discriminator
(
opt
.
model
.
discriminator
)
self
.
netD
=
build_discriminator
(
opt
.
model
.
discriminator
)
if
self
.
isTrain
:
if
self
.
isTrain
:
# define loss functions
# define loss functions
self
.
criterionGAN
=
GANLoss
(
opt
.
model
.
gan_mode
)
self
.
criterionGAN
=
GANLoss
(
opt
.
model
.
gan_mode
)
self
.
criterionL1
=
paddle
.
nn
.
L1Loss
()
self
.
criterionL1
=
paddle
.
nn
.
L1Loss
()
# build optimizers
# build optimizers
self
.
optimizer_G
=
build_optimizer
(
opt
.
optimizer
,
parameter_list
=
self
.
netG
.
parameters
())
self
.
build_lr_scheduler
()
self
.
optimizer_D
=
build_optimizer
(
opt
.
optimizer
,
parameter_list
=
self
.
netD
.
parameters
())
self
.
optimizer_G
=
build_optimizer
(
opt
.
optimizer
,
self
.
lr_scheduler
,
parameter_list
=
self
.
netG
.
parameters
())
self
.
optimizer_D
=
build_optimizer
(
opt
.
optimizer
,
self
.
lr_scheduler
,
parameter_list
=
self
.
netD
.
parameters
())
self
.
optimizers
.
append
(
self
.
optimizer_G
)
self
.
optimizers
.
append
(
self
.
optimizer_G
)
self
.
optimizers
.
append
(
self
.
optimizer_D
)
self
.
optimizers
.
append
(
self
.
optimizer_D
)
...
@@ -75,7 +80,6 @@ class Pix2PixModel(BaseModel):
...
@@ -75,7 +80,6 @@ class Pix2PixModel(BaseModel):
self
.
real_A
=
paddle
.
to_tensor
(
input
[
'A'
if
AtoB
else
'B'
])
self
.
real_A
=
paddle
.
to_tensor
(
input
[
'A'
if
AtoB
else
'B'
])
self
.
real_B
=
paddle
.
to_tensor
(
input
[
'B'
if
AtoB
else
'A'
])
self
.
real_B
=
paddle
.
to_tensor
(
input
[
'B'
if
AtoB
else
'A'
])
self
.
image_paths
=
input
[
'A_paths'
if
AtoB
else
'B_paths'
]
self
.
image_paths
=
input
[
'A_paths'
if
AtoB
else
'B_paths'
]
def
forward
(
self
):
def
forward
(
self
):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
...
@@ -84,7 +88,7 @@ class Pix2PixModel(BaseModel):
...
@@ -84,7 +88,7 @@ class Pix2PixModel(BaseModel):
def
forward_test
(
self
,
input
):
def
forward_test
(
self
,
input
):
input
=
paddle
.
imperative
.
to_variable
(
input
)
input
=
paddle
.
imperative
.
to_variable
(
input
)
return
self
.
netG
(
input
)
return
self
.
netG
(
input
)
def
backward_D
(
self
):
def
backward_D
(
self
):
"""Calculate GAN loss for the discriminator"""
"""Calculate GAN loss for the discriminator"""
# Fake; stop backprop to the generator by detaching fake_B
# Fake; stop backprop to the generator by detaching fake_B
...
@@ -112,7 +116,8 @@ class Pix2PixModel(BaseModel):
...
@@ -112,7 +116,8 @@ class Pix2PixModel(BaseModel):
pred_fake
=
self
.
netD
(
fake_AB
)
pred_fake
=
self
.
netD
(
fake_AB
)
self
.
loss_G_GAN
=
self
.
criterionGAN
(
pred_fake
,
True
)
self
.
loss_G_GAN
=
self
.
criterionGAN
(
pred_fake
,
True
)
# Second, G(A) = B
# Second, G(A) = B
self
.
loss_G_L1
=
self
.
criterionL1
(
self
.
fake_B
,
self
.
real_B
)
*
self
.
opt
.
lambda_L1
self
.
loss_G_L1
=
self
.
criterionL1
(
self
.
fake_B
,
self
.
real_B
)
*
self
.
opt
.
lambda_L1
# combine loss and calculate gradients
# combine loss and calculate gradients
self
.
loss_G
=
self
.
loss_G_GAN
+
self
.
loss_G_L1
self
.
loss_G
=
self
.
loss_G_GAN
+
self
.
loss_G_L1
...
@@ -129,12 +134,12 @@ class Pix2PixModel(BaseModel):
...
@@ -129,12 +134,12 @@ class Pix2PixModel(BaseModel):
# update D
# update D
self
.
set_requires_grad
(
self
.
netD
,
True
)
self
.
set_requires_grad
(
self
.
netD
,
True
)
self
.
optimizer_D
.
clear_gradients
()
self
.
optimizer_D
.
clear_gradients
()
self
.
backward_D
()
self
.
backward_D
()
self
.
optimizer_D
.
minimize
(
self
.
loss_D
)
self
.
optimizer_D
.
minimize
(
self
.
loss_D
)
# update G
# update G
self
.
set_requires_grad
(
self
.
netD
,
False
)
self
.
set_requires_grad
(
self
.
netD
,
False
)
self
.
optimizer_G
.
clear_gradients
()
self
.
optimizer_G
.
clear_gradients
()
self
.
backward_G
()
self
.
backward_G
()
self
.
optimizer_G
.
minimize
(
self
.
loss_G
)
self
.
optimizer_G
.
minimize
(
self
.
loss_G
)
ppgan/solver/lr_scheduler.py
浏览文件 @
b306aa73
...
@@ -6,13 +6,23 @@ def build_lr_scheduler(cfg):
...
@@ -6,13 +6,23 @@ def build_lr_scheduler(cfg):
# TODO: add more learning rate scheduler
# TODO: add more learning rate scheduler
if
name
==
'linear'
:
if
name
==
'linear'
:
return
LinearDecay
(
**
cfg
)
def
lambda_rule
(
epoch
):
lr_l
=
1.0
-
max
(
0
,
epoch
+
1
-
cfg
.
start_epoch
)
/
float
(
cfg
.
decay_epochs
+
1
)
return
lr_l
scheduler
=
paddle
.
optimizer
.
lr_scheduler
.
LambdaLR
(
cfg
.
learning_rate
,
lr_lambda
=
lambda_rule
)
return
scheduler
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
class
LinearDecay
(
paddle
.
fluid
.
dygraph
.
learning_rate_scheduler
.
LearningRateDecay
):
# paddle.optimizer.lr_scheduler
def
__init__
(
self
,
learning_rate
,
step_per_epoch
,
start_epoch
,
decay_epochs
):
class
LinearDecay
(
paddle
.
optimizer
.
lr_scheduler
.
_LRScheduler
):
def
__init__
(
self
,
learning_rate
,
step_per_epoch
,
start_epoch
,
decay_epochs
):
super
(
LinearDecay
,
self
).
__init__
()
super
(
LinearDecay
,
self
).
__init__
()
self
.
learning_rate
=
learning_rate
self
.
learning_rate
=
learning_rate
self
.
start_epoch
=
start_epoch
self
.
start_epoch
=
start_epoch
...
@@ -21,5 +31,6 @@ class LinearDecay(paddle.fluid.dygraph.learning_rate_scheduler.LearningRateDecay
...
@@ -21,5 +31,6 @@ class LinearDecay(paddle.fluid.dygraph.learning_rate_scheduler.LearningRateDecay
    def step(self):
        """Return the learning rate for the current global step.

        Linearly decays the base learning rate towards zero over
        ``decay_epochs`` epochs once ``start_epoch`` has passed; before
        ``start_epoch`` the factor is clamped to 1.0, so the rate stays
        at ``learning_rate``.
        """
        # Epoch index derived from the global step counter.
        # NOTE(review): `self.step_num` and `create_lr_var` belong to the
        # old fluid LearningRateDecay base class; the new
        # paddle.optimizer.lr_scheduler._LRScheduler base may not provide
        # them -- confirm against the paddle 2.0 scheduler API.
        cur_epoch = int(self.step_num // self.step_per_epoch)
        # Decay factor: 1.0 until start_epoch, then decreasing linearly;
        # the +1 terms shift the schedule so decay begins on the epoch
        # after start_epoch and reaches ~0 after decay_epochs epochs.
        decay_rate = 1.0 - max(
            0, cur_epoch + 1 - self.start_epoch) / float(self.decay_epochs + 1)
        return self.create_lr_var(decay_rate * self.learning_rate)
ppgan/solver/optimizer.py
浏览文件 @
b306aa73
...
@@ -4,13 +4,11 @@ import paddle
...
@@ -4,13 +4,11 @@ import paddle
from
.lr_scheduler
import
build_lr_scheduler
from
.lr_scheduler
import
build_lr_scheduler
def build_optimizer(cfg, lr_scheduler, parameter_list=None):
    """Instantiate a paddle optimizer from a config node.

    Args:
        cfg: config mapping whose ``name`` entry selects the optimizer
            class inside ``paddle.optimizer``; all remaining entries are
            forwarded to its constructor as keyword arguments.
        lr_scheduler: learning-rate scheduler (or plain float learning
            rate) passed as the optimizer's first argument.
        parameter_list: parameters the optimizer will update.

    Returns:
        A ``paddle.optimizer.*`` instance.
    """
    # Work on a deep copy so pop() below never mutates the caller's config.
    opt_cfg = copy.deepcopy(cfg)
    opt_cls = getattr(paddle.optimizer, opt_cfg.pop('name'))
    return opt_cls(lr_scheduler, parameters=parameter_list, **opt_cfg)
ppgan/utils/logger.py
浏览文件 @
b306aa73
...
@@ -2,7 +2,7 @@ import logging
...
@@ -2,7 +2,7 @@ import logging
import
os
import
os
import
sys
import
sys
from
paddle
import
ParallelEnv
from
paddle
.distributed
import
ParallelEnv
def
setup_logger
(
output
=
None
,
name
=
"ppgan"
):
def
setup_logger
(
output
=
None
,
name
=
"ppgan"
):
...
@@ -23,8 +23,8 @@ def setup_logger(output=None, name="ppgan"):
...
@@ -23,8 +23,8 @@ def setup_logger(output=None, name="ppgan"):
logger
.
propagate
=
False
logger
.
propagate
=
False
plain_formatter
=
logging
.
Formatter
(
plain_formatter
=
logging
.
Formatter
(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s"
,
datefmt
=
"%m/%d %H:%M:%S"
"[%(asctime)s] %(name)s %(levelname)s: %(message)s"
,
)
datefmt
=
"%m/%d %H:%M:%S"
)
# stdout logging: master only
# stdout logging: master only
local_rank
=
ParallelEnv
().
local_rank
local_rank
=
ParallelEnv
().
local_rank
if
local_rank
==
0
:
if
local_rank
==
0
:
...
@@ -52,4 +52,4 @@ def setup_logger(output=None, name="ppgan"):
...
@@ -52,4 +52,4 @@ def setup_logger(output=None, name="ppgan"):
fh
.
setFormatter
(
plain_formatter
)
fh
.
setFormatter
(
plain_formatter
)
logger
.
addHandler
(
fh
)
logger
.
addHandler
(
fh
)
return
logger
return
logger
\ No newline at end of file
ppgan/utils/setup.py
浏览文件 @
b306aa73
...
@@ -2,7 +2,7 @@ import os
...
@@ -2,7 +2,7 @@ import os
import
time
import
time
import
paddle
import
paddle
from
paddle
import
ParallelEnv
from
paddle
.distributed
import
ParallelEnv
from
.logger
import
setup_logger
from
.logger
import
setup_logger
...
@@ -12,7 +12,8 @@ def setup(args, cfg):
...
@@ -12,7 +12,8 @@ def setup(args, cfg):
cfg
.
isTrain
=
False
cfg
.
isTrain
=
False
cfg
.
timestamp
=
time
.
strftime
(
'-%Y-%m-%d-%H-%M'
,
time
.
localtime
())
cfg
.
timestamp
=
time
.
strftime
(
'-%Y-%m-%d-%H-%M'
,
time
.
localtime
())
cfg
.
output_dir
=
os
.
path
.
join
(
cfg
.
output_dir
,
str
(
cfg
.
model
.
name
)
+
cfg
.
timestamp
)
cfg
.
output_dir
=
os
.
path
.
join
(
cfg
.
output_dir
,
str
(
cfg
.
model
.
name
)
+
cfg
.
timestamp
)
logger
=
setup_logger
(
cfg
.
output_dir
)
logger
=
setup_logger
(
cfg
.
output_dir
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录