PaddlePaddle / PaddleGAN

Unverified commit b306aa73
Authored by LielinJiang on Sep 03, 2020; committed via GitHub on Sep 03, 2020.

Merge pull request #16 from LielinJiang/adapt-to-2.0-api

Adapt to api 2.0 again

Parents: 4a3ba224, abd3250d

Showing 19 changed files with 466 additions and 397 deletions (+466 -397).
Changed files (19):

applications/DAIN/predict.py            +3    -3
applications/DeOldify/predict.py        +27   -34
applications/DeepRemaster/predict.py    +222  -182
applications/EDVR/predict.py            +41   -46
configs/cyclegan_cityscapes.yaml        +6    -6
configs/cyclegan_horse2zebra.yaml       +6    -6
configs/pix2pix_cityscapes.yaml         +7    -7
configs/pix2pix_cityscapes_2gpus.yaml   +6    -6
configs/pix2pix_facades.yaml            +6    -5
ppgan/datasets/base_dataset.py          +10   -6
ppgan/datasets/builder.py               +18   -20
ppgan/engine/trainer.py                 +27   -21
ppgan/models/base_model.py              +10   -5
ppgan/models/cycle_gan_model.py         +33   -21
ppgan/models/pix2pix_model.py           +17   -12
ppgan/solver/lr_scheduler.py            +16   -5
ppgan/solver/optimizer.py               +4    -6
ppgan/utils/logger.py                   +4    -4
ppgan/utils/setup.py                    +3    -2
applications/DAIN/predict.py (+3 -3)

@@ -11,7 +11,7 @@ from imageio import imread, imsave
 import cv2
 import paddle.fluid as fluid
-from paddle.incubate.hapi.download import get_path_from_url
+from paddle.utils.download import get_path_from_url
 import networks
 from util import *
@@ -19,6 +19,7 @@ from my_args import parser
 DAIN_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DAIN_weight.tar'

 def infer_engine(model_dir,
                  run_mode='fluid',
                  batch_size=1,
@@ -91,7 +92,6 @@ class VideoFrameInterp(object):
         self.exe, self.program, self.fetch_targets = executor(model_path,
                                                               use_gpu=use_gpu)
     def run(self):
         frame_path_input = os.path.join(self.output_path, 'frames-input')
         frame_path_interpolated = os.path.join(self.output_path, ...
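The only substantive change here is the new home of the pretrained-weight download helper in Paddle 2.0. A minimal sketch of the migrated call, assuming the helper still takes the URL plus a local cache directory as in this repository's other call sites; the './weights' directory below is illustrative, not taken from this file:

from paddle.utils.download import get_path_from_url   # was: paddle.incubate.hapi.download

DAIN_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DAIN_weight.tar'

# Download (and cache) the weight archive, returning the local extracted path.
weight_path = get_path_from_url(DAIN_WEIGHT_URL, './weights')   # cache dir is an assumption
print(weight_path)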
applications/DeOldify/predict.py (+27 -34)

@@ -15,15 +15,19 @@ from PIL import Image
 from tqdm import tqdm
 from paddle import fluid
 from model import build_model
-from paddle.incubate.hapi.download import get_path_from_url
+from paddle.utils.download import get_path_from_url

 parser = argparse.ArgumentParser(description='DeOldify')
 parser.add_argument('--input', type=str, default='none', help='Input video')
 parser.add_argument('--output', type=str, default='output', help='output dir')
 parser.add_argument('--weight_path', type=str, default='none',
                     help='Path to the reference image directory')

 DeOldify_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'

 def frames_to_video_ffmpeg(framepath, videopath, r):
     ffmpeg = ['ffmpeg ', ' -loglevel ', ' error ']
     cmd = ffmpeg + [
@@ -90,7 +94,7 @@ class DeOldifyPredictor():
     def run_single(self, img_path):
         ori_img = Image.open(img_path).convert('LA').convert('RGB')
         img = self.norm(ori_img)
-        x = paddle.to_tensor(img[np.newaxis,...])
+        x = paddle.to_tensor(img[np.newaxis, ...])
         out = self.model(x)
         pred_img = self.denorm(out.numpy()[0])
@@ -118,7 +122,6 @@ class DeOldifyPredictor():
         frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
         for frame in tqdm(frames):
             pred_img = self.run_single(frame)
@@ -127,13 +130,14 @@ class DeOldifyPredictor():
         frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
         vid_out_path = os.path.join(output_path,
                                     '{}_deoldify_out.mp4'.format(base_name))
         frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path, str(int(fps)))
         return frame_pattern_combined, vid_out_path

 def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
     ffmpeg = ['ffmpeg ', ' -loglevel ', ' error ']
     vid_name = vid_path.split('/')[-1].split('.')[0]
@@ -147,21 +151,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
     if ss is not None and t is not None and r is not None:
         cmd = ffmpeg + [
             ' -ss ', ss, ' -t ', t, ' -i ', vid_path, ' -r ', r,
             ' -qscale:v ', ' 0.1 ', ' -start_number ', ' 0 ', outformat
         ]
     else:
         cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
@@ -177,11 +168,13 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
     return out_full_path

 if __name__ == '__main__':
     paddle.enable_imperative()
     args = parser.parse_args()
     predictor = DeOldifyPredictor(args.input,
                                   args.output,
                                   weight_path=args.weight_path)
     frames_path, temp_video_path = predictor.run()
     print('output video path:', temp_video_path)
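For reference, a minimal sketch of the dynamic-graph inference step performed by run_single above. Here model, norm and denorm stand in for the predictor's own members, and their exact tensor layouts are assumed; the script enables dynamic-graph mode first with paddle.enable_imperative(), as in __main__:

import numpy as np
import paddle
from PIL import Image

def colorize_single(model, norm, denorm, img_path):
    # Force a grayscale image while keeping three channels, as run_single does.
    ori_img = Image.open(img_path).convert('LA').convert('RGB')
    img = norm(ori_img)                          # repo helper: image -> normalized CHW array (assumed)
    x = paddle.to_tensor(img[np.newaxis, ...])   # add the batch axis the model expects
    out = model(x)
    return denorm(out.numpy()[0])                # repo helper: network output -> displayable image (assumed)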
applications/DeepRemaster/predict.py (+222 -182)

@@ -15,20 +15,35 @@ import argparse
 import subprocess
 import utils
 from remasternet import NetworkR, NetworkC
-from paddle.incubate.hapi.download import get_path_from_url
+from paddle.utils.download import get_path_from_url

 DeepRemaster_weight_url = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams'

 parser = argparse.ArgumentParser(description='Remastering')
 parser.add_argument('--input', type=str, default=None, help='Input video')
 parser.add_argument('--output', type=str, default='output', help='output dir')
 parser.add_argument('--reference_dir', type=str, default=None,
                     help='Path to the reference image directory')
 parser.add_argument('--colorization', action='store_true', default=False,
                     help='Remaster without colorization')
 parser.add_argument('--mindim', type=int, default='360',
                     help='Length of minimum image edges')

 class DeepReasterPredictor:
     def __init__(self, input, output, weight_path=None, colorization=False,
                  reference_dir=None, mindim=360):
         self.input = input
         self.output = os.path.join(output, 'DeepRemaster')
         self.colorization = colorization
@@ -48,55 +63,59 @@ class DeepReasterPredictor:
         self.modelC.load_dict(state_dict['modelC'])
         self.modelC.eval()

     def run(self):
         outputdir = self.output
         outputdir_in = os.path.join(outputdir, 'input/')
         os.makedirs(outputdir_in, exist_ok=True)
         outputdir_out = os.path.join(outputdir, 'output/')
         os.makedirs(outputdir_out, exist_ok=True)

         # Prepare reference images
         if self.colorization:
             if self.reference_dir is not None:
                 import glob
                 ext_list = ['png', 'jpg', 'bmp']
                 reference_files = []
                 for ext in ext_list:
                     reference_files += glob.glob(self.reference_dir + '/*.' + ext, recursive=True)
                 aspect_mean = 0
                 minedge_dim = 256
                 refs = []
                 for v in reference_files:
                     refimg = Image.open(v).convert('RGB')
                     w, h = refimg.size
                     aspect_mean += w / h
                     refs.append(refimg)
                 aspect_mean /= len(reference_files)
                 target_w = int(256 * aspect_mean) if aspect_mean > 1 else 256
                 target_h = 256 if aspect_mean >= 1 else int(256 / aspect_mean)
                 refimgs = []
                 for i, v in enumerate(refs):
                     refimg = utils.addMergin(v, target_w=target_w, target_h=target_h)
                     refimg = np.array(refimg).astype('float32').transpose(2, 0, 1) / 255.0
                     refimgs.append(refimg)
                 refimgs = paddle.to_tensor(np.array(refimgs).astype('float32'))
                 refimgs = paddle.unsqueeze(refimgs, 0)

         # Load video
         cap = cv2.VideoCapture(self.input)
         nframes = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
         v_w = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
         v_h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
         minwh = min(v_w, v_h)
         scale = 1
         if minwh != self.mindim:
             scale = self.mindim / minwh
         t_w = round(v_w * scale / 16.) * 16
         t_h = round(v_h * scale / 16.) * 16
         fps = cap.get(cv2.CAP_PROP_FPS)
         pbar = tqdm(total=nframes)
         block = 5
@@ -105,12 +124,12 @@ class DeepReasterPredictor:
         with paddle.no_grad():
             it = 0
             while True:
                 frame_pos = it * block
                 if frame_pos >= nframes:
                     break
                 cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
                 if block >= nframes - frame_pos:
                     proc_g = nframes - frame_pos
                 else:
                     proc_g = block
@@ -123,77 +142,96 @@ class DeepReasterPredictor:
                 nchannels = frame.shape[2]
                 if nchannels == 1 or self.colorization:
                     frame_l = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
                     cv2.imwrite(outputdir_in + '%07d.png' % index, frame_l)
                     frame_l = paddle.to_tensor(frame_l.astype('float32'))
                     frame_l = paddle.reshape(frame_l, [frame_l.shape[0], frame_l.shape[1], 1])
                     frame_l = paddle.transpose(frame_l, [2, 0, 1])
                     frame_l /= 255.
                     frame_l = paddle.reshape(frame_l,
                                              [1, frame_l.shape[0], 1, frame_l.shape[1], frame_l.shape[2]])
                 elif nchannels == 3:
                     cv2.imwrite(outputdir_in + '%07d.png' % index, frame)
                     frame = frame[:, :, ::-1]  ## BGR -> RGB
                     frame_l, frame_ab = utils.convertRGB2LABTensor(frame)
                     frame_l = frame_l.transpose([2, 0, 1])
                     frame_ab = frame_ab.transpose([2, 0, 1])
                     frame_l = frame_l.reshape([1, frame_l.shape[0], 1, frame_l.shape[1], frame_l.shape[2]])
                     frame_ab = frame_ab.reshape([1, frame_ab.shape[0], 1, frame_ab.shape[1], frame_ab.shape[2]])

                 if input is not None:
                     paddle.concat((input, frame_l), 2)
                 input = frame_l if i == 0 else paddle.concat((input, frame_l), 2)
                 if nchannels == 3 and not self.colorization:
                     gtC = frame_ab if i == 0 else paddle.concat((gtC, frame_ab), 2)

                 input = paddle.to_tensor(input)
                 output_l = self.modelR(input)  # [B, C, T, H, W]

                 # Save restoration output without colorization when using the option [--disable_colorization]
                 if not self.colorization:
                     for i in range(proc_g):
                         index = frame_pos + i
                         if nchannels == 3:
                             out_l = output_l.detach()[0, :, i]
                             out_ab = gtC[0, :, i]
                             out = paddle.concat((out_l, out_ab), axis=0).detach().numpy().transpose((1, 2, 0))
                             out = Image.fromarray(np.uint8(utils.convertLAB2RGB(out) * 255))
                             out.save(outputdir_out + '%07d.png' % (index))
                         else:
                             raise ValueError('channels of imag3 must be 3!')
                 # Perform colorization
                 else:
                     if self.reference_dir is None:
                         output_ab = self.modelC(output_l)
                     else:
                         output_ab = self.modelC(output_l, refimgs)
                     output_l = output_l.detach()
                     output_ab = output_ab.detach()
                     for i in range(proc_g):
                         index = frame_pos + i
                         out_l = output_l[0, :, i, :, :]
                         out_c = output_ab[0, :, i, :, :]
                         output = paddle.concat((out_l, out_c), axis=0).numpy().transpose((1, 2, 0))
                         output = Image.fromarray(np.uint8(utils.convertLAB2RGB(output) * 255))
                         output.save(outputdir_out + '%07d.png' % index)
                 it = it + 1
                 pbar.update(proc_g)

         # Save result videos
         outfile = os.path.join(outputdir, self.input.split('/')[-1].split('.')[0])
         cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_in.mp4' % (
             fps, outputdir_in, fps, outfile)
         subprocess.call(cmd, shell=True)
         cmd = 'ffmpeg -y -r %d -i %s%%07d.png -vcodec libx264 -pix_fmt yuv420p -r %d %s_out.mp4' % (
             fps, outputdir_out, fps, outfile)
         subprocess.call(cmd, shell=True)
         cmd = 'ffmpeg -y -i %s_in.mp4 -vf "[in] pad=2.01*iw:ih [left];movie=%s_out.mp4[right];[left][right] overlay=main_w/2:0,scale=2*iw/2:2*ih/2[out]" %s_comp.mp4' % (
             outfile, outfile, outfile)
         subprocess.call(cmd, shell=True)

         cap.release()
         pbar.close()
@@ -203,7 +241,9 @@ class DeepReasterPredictor:
 if __name__ == "__main__":
     args = parser.parse_args()
     paddle.disable_static()
     predictor = DeepReasterPredictor(args.input,
                                      args.output,
                                      colorization=args.colorization,
                                      reference_dir=args.reference_dir,
                                      mindim=args.mindim)
     predictor.run()
\ No newline at end of file
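Most of this rewrite is tensor plumbing: every decoded frame is packed into the [B, C, T, H, W] layout that the temporal restoration model consumes. A condensed sketch of that packing for a single grayscale frame, assuming frame_l is the uint8 (H, W) array produced by cv2.cvtColor above:

import paddle

def frame_to_5d(frame_l):
    t = paddle.to_tensor(frame_l.astype('float32'))
    t = paddle.reshape(t, [t.shape[0], t.shape[1], 1])   # (H, W) -> (H, W, 1)
    t = paddle.transpose(t, [2, 0, 1]) / 255.            # -> (C=1, H, W), scaled to [0, 1]
    # -> (B=1, C=1, T=1, H, W); frames are later joined along the T axis with paddle.concat(..., 2)
    return paddle.reshape(t, [1, t.shape[0], 1, t.shape[1], t.shape[2]])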
applications/EDVR/predict.py (+41 -46)

@@ -28,30 +28,29 @@ import paddle.fluid as fluid
 import cv2
 from data import EDVRDataset
-from paddle.incubate.hapi.download import get_path_from_url
+from paddle.utils.download import get_path_from_url

 EDVR_weight_url = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar'

 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--input', type=str, default=None, help='input video path')
     parser.add_argument('--output', type=str, default='output', help='output path')
     parser.add_argument('--weight_path', type=str, default=None, help='weight path')
     args = parser.parse_args()
     return args

 def get_img(pred):
     print('pred shape', pred.shape)
     pred = pred.squeeze()
@@ -63,6 +62,7 @@ def get_img(pred):
     pred = pred[:, :, ::-1]  # rgb -> bgr
     return pred

 def save_img(img, framename):
     dirname = os.path.dirname(framename)
     if not os.path.exists(dirname):
@@ -84,19 +84,8 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
     if ss is not None and t is not None and r is not None:
         cmd = ffmpeg + [
             ' -ss ', ss, ' -t ', t, ' -i ', vid_path, ' -r ', r,
             ' -qscale:v ', ' 0.1 ', ' -start_number ', ' 0 ', outformat
         ]
     else:
         cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
@@ -134,7 +123,8 @@ class EDVRPredictor:
         self.input = input
         self.output = os.path.join(output, 'EDVR')
         place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
         self.exe = fluid.Executor(place)
         if weight_path is None:
@@ -177,15 +167,18 @@ class EDVRPredictor:
         for infer_iter, data in enumerate(dataset):
             data_feed_in = [data[0]]
             infer_outs = self.exe.run(self.infer_prog,
                                       fetch_list=self.fetch_list,
                                       feed={self.feed_list[0]: np.array(data_feed_in)})
             infer_result_list = [item for item in infer_outs]
             frame_path = data[1]
             img_i = get_img(infer_result_list[0])
             save_img(img_i, os.path.join(pred_frame_path, os.path.basename(frame_path)))
             prev_time = cur_time
             cur_time = time.time()
@@ -194,13 +187,15 @@ class EDVRPredictor:
             print('Processed {} samples'.format(infer_iter + 1))
         frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
         vid_out_path = os.path.join(self.output,
                                     '{}_edvr_out.mp4'.format(base_name))
         frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path, str(int(fps)))
         return frame_pattern_combined, vid_out_path

 if __name__ == "__main__":
     args = parse_args()
     predictor = EDVRPredictor(args.input, args.output, args.weight_path)
     predictor.run()
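Unlike the other applications, EDVR still runs through the static-graph fluid API in this commit; only the download import moved. A minimal sketch of the executor setup kept by EDVRPredictor:

import paddle.fluid as fluid

# Pick a GPU place when Paddle was built with CUDA, otherwise fall back to CPU.
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
exe = fluid.Executor(place)
# Inference then feeds batches through exe.run(infer_prog, fetch_list=..., feed=...) as shown above.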
configs/cyclegan_cityscapes.yaml (+6 -6)

@@ -60,7 +60,8 @@ dataset:
 optimizer:
   name: Adam
   beta1: 0.5
-  lr_scheduler:
-    name: linear
-    learning_rate: 0.0002
-    start_epoch: 100
+
+lr_scheduler:
+  name: linear
+  learning_rate: 0.0002
+  start_epoch: 100
@@ -72,4 +73,3 @@ log_config:
 snapshot_config:
   interval: 5
configs/cyclegan_horse2zebra.yaml (+6 -6)

@@ -59,7 +59,8 @@ dataset:
 optimizer:
   name: Adam
   beta1: 0.5
-  lr_scheduler:
-    name: linear
-    learning_rate: 0.0002
-    start_epoch: 100
+
+lr_scheduler:
+  name: linear
+  learning_rate: 0.0002
+  start_epoch: 100
@@ -71,4 +72,3 @@ log_config:
 snapshot_config:
   interval: 5
configs/pix2pix_cityscapes.yaml (+7 -7)

@@ -25,7 +25,7 @@ dataset:
   train:
     name: PairedDataset
     dataroot: data/cityscapes
-    num_workers: 0
+    num_workers: 4
     phase: train
     max_dataset_size: inf
     direction: BtoA
@@ -57,7 +57,8 @@ dataset:
 optimizer:
   name: Adam
   beta1: 0.5
-  lr_scheduler:
-    name: linear
-    learning_rate: 0.0002
-    start_epoch: 100
+
+lr_scheduler:
+  name: linear
+  learning_rate: 0.0002
+  start_epoch: 100
@@ -69,4 +70,3 @@ log_config:
 snapshot_config:
   interval: 5
configs/pix2pix_cityscapes_2gpus.yaml (+6 -6)

@@ -56,7 +56,8 @@ dataset:
 optimizer:
   name: Adam
   beta1: 0.5
-  lr_scheduler:
-    name: linear
-    learning_rate: 0.0004
-    start_epoch: 100
+
+lr_scheduler:
+  name: linear
+  learning_rate: 0.0004
+  start_epoch: 100
@@ -68,4 +69,3 @@ log_config:
 snapshot_config:
   interval: 5
configs/pix2pix_facades.yaml (+6 -5)

@@ -56,7 +56,8 @@ dataset:
 optimizer:
   name: Adam
   beta1: 0.5
-  lr_scheduler:
-    name: linear
-    learning_rate: 0.0002
-    start_epoch: 100
+
+lr_scheduler:
+  name: linear
+  learning_rate: 0.0002
+  start_epoch: 100
ppgan/datasets/base_dataset.py (+10 -6)

@@ -6,7 +6,7 @@ from paddle.io import Dataset
 from PIL import Image
 import cv2
-import paddle.incubate.hapi.vision.transforms as transforms
+import paddle.vision.transforms as transforms
 from .transforms import transforms as T
 from abc import ABC, abstractmethod
@@ -14,7 +14,6 @@ from abc import ABC, abstractmethod
 class BaseDataset(Dataset, ABC):
     """This class is an abstract base class (ABC) for datasets.
     """
     def __init__(self, cfg):
         """Initialize the class; save the options in the class
@@ -60,8 +59,11 @@ def get_params(cfg, size):
     return {'crop_pos': (x, y), 'flip': flip}

 def get_transform(cfg,
                   params=None,
                   grayscale=False,
                   method=cv2.INTER_CUBIC,
                   convert=True):
     transform_list = []
     if grayscale:
         print('grayscale not support for now!!!')
@@ -92,5 +94,7 @@ def get_transform(cfg, params=None, grayscale=False, method=cv2.INTER_CUBIC, con
     if convert:
         transform_list += [transforms.Permute(to_rgb=True)]
         transform_list += [
             transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
         ]
     return transforms.Compose(transform_list)
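A minimal sketch of the pipeline get_transform ends up composing, using the new paddle.vision.transforms import path; the values mirror the diff, and the input is assumed to be an HWC uint8 image:

import paddle.vision.transforms as transforms   # was: paddle.incubate.hapi.vision.transforms

transform = transforms.Compose([
    transforms.Permute(to_rgb=True),   # HWC -> CHW (with BGR -> RGB), as used in this diff
    transforms.Normalize((127.5, 127.5, 127.5),
                         (127.5, 127.5, 127.5)),   # map [0, 255] to roughly [-1, 1]
])
# tensor = transform(img)   # img: HWC uint8 numpy array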
ppgan/datasets/builder.py (+18 -20)

@@ -3,12 +3,11 @@ import paddle
 import numbers
 import numpy as np
 from multiprocessing import Manager
-from paddle import ParallelEnv
-from paddle.incubate.hapi.distributed import DistributedBatchSampler
+from paddle.distributed import ParallelEnv
+from paddle.io import DistributedBatchSampler
 from ..utils.registry import Registry

 DATASETS = Registry("DATASETS")
@@ -60,14 +59,12 @@ class DictDataLoader():
         place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) \
             if ParallelEnv().nranks > 1 else paddle.fluid.CUDAPlace(0)
         sampler = DistributedBatchSampler(
             self.dataset,
             batch_size=batch_size,
             shuffle=True if is_train else False,
             drop_last=True if is_train else False)
         self.dataloader = paddle.io.DataLoader(
             self.dataset,
             batch_sampler=sampler,
             places=place,
             num_workers=num_workers)
@@ -83,7 +80,9 @@ class DictDataLoader():
         j = 0
         for k in self.dataset.keys:
             if k in self.dataset.tensor_keys_set:
                 return_dict[k] = data[j] if isinstance(data, (list, tuple)) else data
                 j += 1
             else:
                 return_dict[k] = self.get_items_by_indexs(k, data[-1])
@@ -104,7 +103,6 @@ class DictDataLoader():
         return current_items

 def build_dataloader(cfg, is_train=True):
     dataset = DATASETS.get(cfg.name)(cfg)
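DistributedBatchSampler now lives in paddle.io rather than paddle.incubate.hapi.distributed. A sketch of the loader construction above with the dict-dataset bookkeeping stripped out; names follow the diff:

import paddle
from paddle.distributed import ParallelEnv
from paddle.io import DataLoader, DistributedBatchSampler

def make_loader(dataset, batch_size, is_train, num_workers=0):
    # Multi-GPU runs bind to the device of this process; single-GPU runs use card 0.
    place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) \
        if ParallelEnv().nranks > 1 else paddle.fluid.CUDAPlace(0)
    sampler = DistributedBatchSampler(dataset,
                                      batch_size=batch_size,
                                      shuffle=is_train,
                                      drop_last=is_train)
    return DataLoader(dataset,
                      batch_sampler=sampler,
                      places=place,
                      num_workers=num_workers)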
ppgan/engine/trainer.py (+27 -21)

@@ -4,7 +4,7 @@ import time
 import logging
 import paddle
-from paddle import ParallelEnv, DataParallel
+from paddle.distributed import ParallelEnv
 from ..datasets.builder import build_dataloader
 from ..models.builder import build_model
@@ -19,7 +19,8 @@ class Trainer:
         self.train_dataloader = build_dataloader(cfg.dataset.train)
         if 'lr_scheduler' in cfg.optimizer:
             cfg.optimizer.lr_scheduler.step_per_epoch = len(self.train_dataloader)
         # build model
         self.model = build_model(cfg)
@@ -50,7 +51,8 @@ class Trainer:
         for name in self.model.model_names:
             if isinstance(name, str):
                 net = getattr(self.model, 'net' + name)
-                setattr(self.model, 'net' + name, DataParallel(net, strategy))
+                setattr(self.model, 'net' + name, paddle.DataParallel(net, strategy))

     def train(self):
@@ -74,14 +76,17 @@ class Trainer:
                 self.visual('visual_train')
                 step_start_time = time.time()
             self.logger.info('train one epoch time: {}'.format(time.time() - start_time))
+            self.model.lr_scheduler.step()
             if epoch % self.weight_interval == 0:
                 self.save(epoch, 'weight', keep=-1)
             self.save(epoch)

     def test(self):
         if not hasattr(self, 'test_dataloader'):
             self.test_dataloader = build_dataloader(self.cfg.dataset.test, is_train=False)
         # data[0]: img, data[1]: img path index
         # test batch size must be 1
@@ -105,7 +110,8 @@ class Trainer:
             self.visual('visual_test', visual_results=visual_results)
             if i % self.log_interval == 0:
                 self.logger.info('Test iter: [%d/%d]' % (i, len(self.test_dataloader)))

     def print_log(self):
         losses = self.model.get_current_losses()
@@ -143,7 +149,8 @@ class Trainer:
         makedirs(os.path.join(self.output_dir, results_dir))
         for label, image in visual_results.items():
             image_numpy = tensor2img(image)
             img_path = os.path.join(self.output_dir, results_dir, msg + '%s.png' % (label))
             save_image(image_numpy, img_path)

     def save(self, epoch, name='checkpoint', keep=1):
@@ -175,8 +182,8 @@ class Trainer:
         if keep > 0:
             try:
                 checkpoint_name_to_be_removed = os.path.join(
                     self.output_dir, 'epoch_%s_%s.pkl' % (epoch - keep, name))
                 if os.path.exists(checkpoint_name_to_be_removed):
                     os.remove(checkpoint_name_to_be_removed)
@@ -205,4 +212,3 @@ class Trainer:
         if isinstance(name, str):
             net = getattr(self.model, 'net' + name)
             net.set_dict(state_dicts['net' + name])
\ No newline at end of file
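Two training-loop changes are worth calling out: sub-networks are now wrapped with paddle.DataParallel rather than the old top-level DataParallel import, and the learning-rate scheduler is stepped once per epoch on the model. A condensed sketch, assuming model and strategy come from the surrounding trainer (the parallel strategy object set up earlier in __init__):

import paddle

def wrap_data_parallel(model, strategy):
    # Replace each sub-network 'net<NAME>' with a DataParallel copy for multi-GPU training.
    for name in model.model_names:
        if isinstance(name, str):
            net = getattr(model, 'net' + name)
            setattr(model, 'net' + name, paddle.DataParallel(net, strategy))

# ...and inside Trainer.train(), after each epoch:
#     self.model.lr_scheduler.step()   # new in this commit: advance the epoch-based scheduler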
ppgan/models/base_model.py (+10 -5)

@@ -5,6 +5,7 @@ import numpy as np
 from collections import OrderedDict
 from abc import ABC, abstractmethod
+from ..solver.lr_scheduler import build_lr_scheduler

 class BaseModel(ABC):
@@ -16,7 +17,6 @@ class BaseModel(ABC):
     -- <optimize_parameters>: calculate losses, gradients, and update network weights.
     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
     """
     def __init__(self, opt):
         """Initialize the BaseModel class.
@@ -33,7 +33,9 @@ class BaseModel(ABC):
         """
         self.opt = opt
         self.isTrain = opt.isTrain
         self.save_dir = os.path.join(opt.output_dir, opt.model.name)  # save all the checkpoints to save_dir
         self.loss_names = []
         self.model_names = []
@@ -75,6 +77,8 @@ class BaseModel(ABC):
         """Calculate losses, gradients, and update network weights; called in every training iteration"""
         pass

+    def build_lr_scheduler(self):
+        self.lr_scheduler = build_lr_scheduler(self.opt.lr_scheduler)
+
     def eval(self):
         """Make models eval mode during test time"""
@@ -114,10 +118,11 @@ class BaseModel(ABC):
         errors_ret = OrderedDict()
         for name in self.loss_names:
             if isinstance(name, str):
                 errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
         return errors_ret

     def set_requires_grad(self, nets, requires_grad=False):
         """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
         Parameters:
ppgan/models/cycle_gan_model.py (+33 -21)

@@ -1,5 +1,5 @@
 import paddle
-from paddle import ParallelEnv
+from paddle.distributed import ParallelEnv
 from .base_model import BaseModel
 from .builder import MODELS
@@ -23,7 +23,6 @@ class CycleGANModel(BaseModel):
     CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
     """
     def __init__(self, opt):
         """Initialize the CycleGAN class.
@@ -32,7 +31,9 @@ class CycleGANModel(BaseModel):
         """
         BaseModel.__init__(self, opt)
         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
         self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
         visual_names_A = ['real_A', 'fake_B', 'rec_A']
         visual_names_B = ['real_B', 'fake_A', 'rec_B']
@@ -62,7 +63,8 @@ class CycleGANModel(BaseModel):
         if self.isTrain:
             if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                 assert (opt.dataset.train.input_nc == opt.dataset.train.output_nc)
             # create image buffer to store previously generated images
             self.fake_A_pool = ImagePool(opt.dataset.train.pool_size)
             # create image buffer to store previously generated images
@@ -72,8 +74,17 @@ class CycleGANModel(BaseModel):
             self.criterionCycle = paddle.nn.L1Loss()
             self.criterionIdt = paddle.nn.L1Loss()
-            self.optimizer_G = build_optimizer(opt.optimizer,
-                                               parameter_list=self.netG_A.parameters() + self.netG_B.parameters())
-            self.optimizer_D = build_optimizer(opt.optimizer,
-                                               parameter_list=self.netD_A.parameters() + self.netD_B.parameters())
+            self.build_lr_scheduler()
+            self.optimizer_G = build_optimizer(
+                opt.optimizer,
+                self.lr_scheduler,
+                parameter_list=self.netG_A.parameters() + self.netG_B.parameters())
+            self.optimizer_D = build_optimizer(
+                opt.optimizer,
+                self.lr_scheduler,
+                parameter_list=self.netD_A.parameters() + self.netD_B.parameters())
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
@@ -107,7 +118,6 @@ class CycleGANModel(BaseModel):
         elif 'B_paths' in input:
             self.image_paths = input['B_paths']

     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
         if hasattr(self, 'real_A'):
@@ -118,7 +128,6 @@ class CycleGANModel(BaseModel):
         self.fake_A = self.netG_B(self.real_B)  # G_B(B)
         self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

     def backward_D_basic(self, netD, real, fake):
         """Calculate GAN loss for the discriminator
@@ -166,10 +175,12 @@ class CycleGANModel(BaseModel):
         if lambda_idt > 0:
             # G_A should be identity if real_B is fed: ||G_A(B) - B||
             self.idt_A = self.netG_A(self.real_B)
             self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
             # G_B should be identity if real_A is fed: ||G_B(A) - A||
             self.idt_B = self.netG_B(self.real_A)
             self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
         else:
             self.loss_idt_A = 0
             self.loss_idt_B = 0
@@ -179,9 +190,11 @@ class CycleGANModel(BaseModel):
         # GAN loss D_B(G_B(B))
         self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
         # Forward cycle loss || G_B(G_A(A)) - A||
         self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
         # Backward cycle loss || G_A(G_B(B)) - B||
         self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
         # combined loss and calculate gradients
         self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
@@ -218,4 +231,3 @@ class CycleGANModel(BaseModel):
         self.backward_D_B()
         # update D_A and D_B's weights
         self.optimizer_D.minimize(self.loss_D_A + self.loss_D_B)
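The optimizer wiring changes the same way for CycleGAN and Pix2Pix: the model first builds a single LR scheduler through the new BaseModel.build_lr_scheduler() helper and then passes it explicitly to build_optimizer, instead of letting build_optimizer create a scheduler from the optimizer config. A condensed sketch of the pattern, with names as in the diff:

# Inside the model's __init__:
self.build_lr_scheduler()          # sets self.lr_scheduler from the config's lr_scheduler section
self.optimizer_G = build_optimizer(
    opt.optimizer, self.lr_scheduler,
    parameter_list=self.netG_A.parameters() + self.netG_B.parameters())
self.optimizer_D = build_optimizer(
    opt.optimizer, self.lr_scheduler,
    parameter_list=self.netD_A.parameters() + self.netD_B.parameters())
# Both optimizers share one scheduler object, so the trainer's single
# lr_scheduler.step() per epoch updates the learning rate for G and D together.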
ppgan/models/pix2pix_model.py (+17 -12)

@@ -1,5 +1,5 @@
 import paddle
-from paddle import ParallelEnv
+from paddle.distributed import ParallelEnv
 from .base_model import BaseModel
 from .builder import MODELS
@@ -22,7 +22,6 @@ class Pix2PixModel(BaseModel):
     pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
     """
     def __init__(self, opt):
         """Initialize the pix2pix class.
@@ -48,15 +47,21 @@ class Pix2PixModel(BaseModel):
         if self.isTrain:
             self.netD = build_discriminator(opt.model.discriminator)

         if self.isTrain:
             # define loss functions
             self.criterionGAN = GANLoss(opt.model.gan_mode)
             self.criterionL1 = paddle.nn.L1Loss()

             # build optimizers
-            self.optimizer_G = build_optimizer(opt.optimizer, parameter_list=self.netG.parameters())
-            self.optimizer_D = build_optimizer(opt.optimizer, parameter_list=self.netD.parameters())
+            self.build_lr_scheduler()
+            self.optimizer_G = build_optimizer(
+                opt.optimizer,
+                self.lr_scheduler,
+                parameter_list=self.netG.parameters())
+            self.optimizer_D = build_optimizer(
+                opt.optimizer,
+                self.lr_scheduler,
+                parameter_list=self.netD.parameters())
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
@@ -76,7 +81,6 @@ class Pix2PixModel(BaseModel):
         self.real_B = paddle.to_tensor(input['B' if AtoB else 'A'])
         self.image_paths = input['A_paths' if AtoB else 'B_paths']

     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
         self.fake_B = self.netG(self.real_A)  # G(A)
@@ -112,7 +116,8 @@ class Pix2PixModel(BaseModel):
         pred_fake = self.netD(fake_AB)
         self.loss_G_GAN = self.criterionGAN(pred_fake, True)
         # Second, G(A) = B
         self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
         # combine loss and calculate gradients
         self.loss_G = self.loss_G_GAN + self.loss_G_L1
ppgan/solver/lr_scheduler.py (+16 -5)

@@ -6,13 +6,23 @@ def build_lr_scheduler(cfg):
     # TODO: add more learning rate scheduler
     if name == 'linear':
-        return LinearDecay(**cfg)
+        def lambda_rule(epoch):
+            lr_l = 1.0 - max(0, epoch + 1 - cfg.start_epoch) / float(cfg.decay_epochs + 1)
+            return lr_l
+
+        scheduler = paddle.optimizer.lr_scheduler.LambdaLR(cfg.learning_rate,
+                                                           lr_lambda=lambda_rule)
+        return scheduler
     else:
         raise NotImplementedError

-class LinearDecay(paddle.fluid.dygraph.learning_rate_scheduler.LearningRateDecay):
+# paddle.optimizer.lr_scheduler
+class LinearDecay(paddle.optimizer.lr_scheduler._LRScheduler):
     def __init__(self, learning_rate, step_per_epoch, start_epoch, decay_epochs):
         super(LinearDecay, self).__init__()
         self.learning_rate = learning_rate
         self.start_epoch = start_epoch
@@ -21,5 +31,6 @@ class LinearDecay(paddle.fluid.dygraph.learning_rate_scheduler.LearningRateDecay
     def step(self):
         cur_epoch = int(self.step_num // self.step_per_epoch)
         decay_rate = 1.0 - max(0, cur_epoch + 1 - self.start_epoch) / float(self.decay_epochs + 1)
         return self.create_lr_var(decay_rate * self.learning_rate)
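The 'linear' policy is now built on paddle.optimizer.lr_scheduler.LambdaLR (the 2.0-beta name used in this diff): the learning rate stays at its base value for the first start_epoch epochs and then decays linearly towards zero over decay_epochs. A small sketch of the schedule, with illustrative config values:

import paddle

base_lr, start_epoch, decay_epochs = 0.0002, 100, 100   # illustrative values

def lambda_rule(epoch):
    # 1.0 during the first `start_epoch` epochs, then a linear ramp down to 0.
    return 1.0 - max(0, epoch + 1 - start_epoch) / float(decay_epochs + 1)

scheduler = paddle.optimizer.lr_scheduler.LambdaLR(base_lr, lr_lambda=lambda_rule)

# Effective LR for an epoch is base_lr * lambda_rule(epoch); the trainer advances
# the schedule once per epoch with scheduler.step() (see ppgan/engine/trainer.py above).
for epoch in range(start_epoch + decay_epochs):
    effective_lr = base_lr * lambda_rule(epoch)
    scheduler.step()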
ppgan/solver/optimizer.py (+4 -6)

@@ -4,13 +4,11 @@ import paddle
 from .lr_scheduler import build_lr_scheduler

-def build_optimizer(cfg, parameter_list=None):
+def build_optimizer(cfg, lr_scheduler, parameter_list=None):
     cfg_copy = copy.deepcopy(cfg)
-    lr_scheduler_cfg = cfg_copy.pop('lr_scheduler', None)
-    lr_scheduler = build_lr_scheduler(lr_scheduler_cfg)
     opt_name = cfg_copy.pop('name')
     return getattr(paddle.optimizer, opt_name)(lr_scheduler,
                                                parameters=parameter_list,
                                                **cfg_copy)
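build_optimizer now receives the scheduler as an explicit argument and dispatches on the optimizer name via getattr. A small usage sketch with an illustrative config dict mirroring the YAML optimizer section; scheduler and net stand in for the shared LR scheduler and any paddle.nn.Layer:

import paddle

# The first positional argument of the paddle.optimizer classes is the learning rate,
# which may be a scheduler object; remaining config keys (here beta1) become constructor kwargs.
opt_cfg = {'name': 'Adam', 'beta1': 0.5}          # illustrative, mirrors the YAML optimizer section
optimizer = getattr(paddle.optimizer, opt_cfg['name'])(
    scheduler,                                     # shared LR scheduler built above (assumed)
    parameters=net.parameters(),                   # `net` stands for any paddle.nn.Layer (assumed)
    beta1=opt_cfg['beta1'])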
ppgan/utils/logger.py (+4 -4)

@@ -2,7 +2,7 @@ import logging
 import os
 import sys
-from paddle import ParallelEnv
+from paddle.distributed import ParallelEnv

 def setup_logger(output=None, name="ppgan"):
@@ -23,8 +23,8 @@ def setup_logger(output=None, name="ppgan"):
     logger.propagate = False

     plain_formatter = logging.Formatter(
         "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
         datefmt="%m/%d %H:%M:%S")
     # stdout logging: master only
     local_rank = ParallelEnv().local_rank
     if local_rank == 0:
ppgan/utils/setup.py (+3 -2)

@@ -2,7 +2,7 @@ import os
 import time
 import paddle
-from paddle import ParallelEnv
+from paddle.distributed import ParallelEnv
 from .logger import setup_logger
@@ -12,7 +12,8 @@ def setup(args, cfg):
     cfg.isTrain = False
     cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
     cfg.output_dir = os.path.join(cfg.output_dir,
                                   str(cfg.model.name) + cfg.timestamp)
     logger = setup_logger(cfg.output_dir)